# MUSDB-HQ dataset loader (torchaudio).
import os
|
||
|
from pathlib import Path
|
||
|
from typing import List, Optional, Tuple, Union
|
||
|
|
||
|
import torch
|
||
|
import torchaudio
|
||
|
from torch.utils.data import Dataset
|
||
|
from torchaudio._internal import download_url_to_file
|
||
|
from torchaudio.datasets.utils import _extract_zip
|
||
|
|
||
|
# Location of the MUSDB-HQ archive on Zenodo.
_URL = "https://zenodo.org/record/3338373/files/musdb18hq.zip"
# SHA-256 prefix used to verify the integrity of the downloaded archive.
_CHECKSUM = "baac80d0483c61d74b2e5f3be75fa557eec52898339e6aa45c1fa48833c5d21d"
# Every stem is stored as a WAV file.
_EXT = ".wav"
# Sample rate (Hz) of all tracks in the dataset; loads are validated against it.
_SAMPLE_RATE = 44100
# Track names held out of the training subset when ``split="validation"``.
_VALIDATION_SET = [
    "Actions - One Minute Smile",
    "Clara Berry And Wooldog - Waltz For My Victims",
    "Johnny Lokke - Promises & Lies",
    "Patrick Talbot - A Reason To Leave",
    "Triviul - Angelsaint",
    "Alexander Ross - Goodbye Bolero",
    "Fergessen - Nos Palpitants",
    "Leaf - Summerghost",
    "Skelpolu - Human Mistakes",
    "Young Griffo - Pennies",
    "ANiMAL - Rockshow",
    "James May - On The Line",
    "Meaxic - Take A Step",
    "Traffic Experiment - Sirens",
]
|
||
|
|
||
|
|
||
|
class MUSDB_HQ(Dataset):
    """*MUSDB_HQ* :cite:`MUSDB18HQ` dataset.

    Args:
        root (str or Path): Root directory where the dataset's top level directory is found
        subset (str): Subset of the dataset to use. Options: [``"train"``, ``"test"``].
        sources (List[str] or None, optional): Sources extract data from.
            List can contain the following options: [``"bass"``, ``"drums"``, ``"other"``, ``"mixture"``, ``"vocals"``].
            If ``None``, dataset consists of tracks except mixture.
            (default: ``None``)
        split (str or None, optional): Whether to split training set into train and validation set.
            If ``None``, no splitting occurs. If ``train`` or ``validation``, returns respective set.
            (default: ``None``)
        download (bool, optional): Whether to download the dataset if it is not found at root path.
            (default: ``False``)
    """

    def __init__(
        self,
        root: Union[str, Path],
        subset: str,
        sources: Optional[List[str]] = None,
        split: Optional[str] = None,
        download: bool = False,
    ) -> None:
        # No sources given (None or empty) -> default to the four stems, mixture excluded.
        self.sources = sources if sources else ["bass", "drums", "other", "vocals"]
        self.split = split

        zip_name = os.path.basename(_URL)
        archive = os.path.join(root, zip_name)
        # Strip the extension to obtain the extracted directory name.
        dataset_dirname = zip_name.rsplit(".", 2)[0]

        if subset not in ["test", "train"]:
            raise ValueError("`subset` must be one of ['test', 'train']")
        if self.split is not None and self.split not in ["train", "validation"]:
            raise ValueError("`split` must be one of ['train', 'validation']")

        base_path = os.path.join(root, dataset_dirname)
        self._path = os.path.join(base_path, subset)

        # Fetch and unpack the archive only when the subset directory is absent.
        if not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                if not download:
                    raise RuntimeError("Dataset not found. Please use `download=True` to download")
                download_url_to_file(_URL, archive, hash_prefix=_CHECKSUM)
            os.makedirs(base_path, exist_ok=True)
            _extract_zip(archive, base_path)

        self.names = self._collect_songs()

    def _get_track(self, name, source):
        # Audio files live at <subset path>/<song name>/<source>.wav
        return Path(self._path).joinpath(name, f"{source}{_EXT}")

    def _load_sample(self, n: int) -> Tuple[torch.Tensor, int, int, str]:
        # Load every configured source for the n-th song and stack them
        # into a single (num_sources, num_channels, num_frames) tensor.
        name = self.names[n]
        waveforms = []
        expected_frames = None
        for source in self.sources:
            track_path = self._get_track(name, source)
            waveform, sample_rate = torchaudio.load(str(track_path))
            if sample_rate != _SAMPLE_RATE:
                raise ValueError(f"expected sample rate {_SAMPLE_RATE}, but got {sample_rate}")
            # All sources of a song must share the same length.
            if expected_frames is None:
                expected_frames = waveform.shape[-1]
            elif waveform.shape[-1] != expected_frames:
                raise ValueError("num_frames do not match across sources")
            waveforms.append(waveform)
        return torch.stack(waveforms), _SAMPLE_RATE, expected_frames, name

    def _collect_songs(self):
        # The validation split is a fixed, predefined list of track names.
        if self.split == "validation":
            return _VALIDATION_SET
        subset_path = Path(self._path)
        names = []
        for dirpath, subdirs, _ in os.walk(subset_path, followlinks=True):
            song_dir = Path(dirpath)
            # Songs are the non-hidden leaf directories under the subset root.
            if song_dir == subset_path or song_dir.name.startswith(".") or subdirs:
                continue
            song = str(song_dir.relative_to(subset_path))
            # split == "train" at this point: exclude the held-out validation tracks.
            if self.split and song in _VALIDATION_SET:
                continue
            names.append(song)
        return sorted(names)

    def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, int, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded
        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            int:
                Num frames
            str:
                Track name
        """
        return self._load_sample(n)

    def __len__(self) -> int:
        return len(self.names)