from pathlib import Path
from typing import Dict, Tuple, Union

import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio._internal import download_url_to_file
from torchaudio.datasets.utils import _extract_zip


_URL = "https://datashare.ed.ac.uk/bitstream/handle/10283/3038/DR-VCTK.zip"
_CHECKSUM = "781f12f4406ed36ed27ae3bce55da47ba176e2d8bae67319e389e07b2c9bd769"
_SUPPORTED_SUBSETS = {"train", "test"}


class DR_VCTK(Dataset):
    """*Device Recorded VCTK (Small subset version)* :cite:`Sarfjoo2018DeviceRV` dataset.

    Args:
        root (str or Path): Root directory where the dataset's top level directory is found.
        subset (str): The subset to use. Can be one of ``"train"`` and ``"test"``. (default: ``"train"``).
        download (bool):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
        url (str): The URL to download the dataset from.
            (default: ``"https://datashare.ed.ac.uk/bitstream/handle/10283/3038/DR-VCTK.zip"``)

    Raises:
        RuntimeError: If ``subset`` is not supported, or if the dataset is neither
            extracted nor archived under ``root`` and ``download`` is ``False``.
    """

    def __init__(
        self,
        root: Union[str, Path],
        subset: str = "train",
        *,
        download: bool = False,
        url: str = _URL,
    ) -> None:
        if subset not in _SUPPORTED_SUBSETS:
            raise RuntimeError(
                f"The subset '{subset}' does not match any of the supported subsets: {_SUPPORTED_SUBSETS}"
            )

        root = Path(root).expanduser()
        archive = root / "DR-VCTK.zip"

        self._subset = subset
        # The dataset is looked up in a doubly nested "DR-VCTK/DR-VCTK" directory under root.
        self._path = root / "DR-VCTK" / "DR-VCTK"
        self._clean_audio_dir = self._path / f"clean_{self._subset}set_wav_16k"
        self._noisy_audio_dir = self._path / f"device-recorded_{self._subset}set_wav_16k"
        self._config_filepath = self._path / "configurations" / f"{self._subset}_ch_log.txt"

        if not self._path.is_dir():
            if not archive.is_file():
                if not download:
                    raise RuntimeError("Dataset not found. Please use `download=True` to download it.")
                download_url_to_file(url, archive, hash_prefix=_CHECKSUM)
            # An archive that is already present is reused instead of being downloaded again.
            _extract_zip(archive, root)

        self._config = self._load_config(self._config_filepath)
        # Sorted so that a given index always maps to the same file.
        self._filename_list = sorted(self._config)

    def _load_config(self, filepath: Union[str, Path]) -> Dict[str, Tuple[str, int]]:
        """Parse the channel log of the selected subset.

        Every data row is expected to hold three tab-separated fields:
        file name, source and channel ID.

        Args:
            filepath (str or Path): Path to the ``*_ch_log.txt`` file.

        Returns:
            Dict mapping each file name to a ``(source, channel_id)`` tuple.
        """
        # Skip header: two leading rows for the train log, one for the test log.
        skip_rows = 2 if self._subset == "train" else 1

        config = {}
        with open(filepath) as f:
            for i, line in enumerate(f):
                if i < skip_rows:
                    continue
                # Lines read from a file keep their newline, so emptiness has to be
                # tested after stripping; otherwise a blank (e.g. trailing) line
                # would fail the three-field unpacking below.
                line = line.strip()
                if not line:
                    continue
                filename, source, channel_id = line.split("\t")
                config[filename] = (source, int(channel_id))
        return config

    def _load_dr_vctk_item(self, filename: str) -> Tuple[Tensor, int, Tensor, int, str, str, str, int]:
        """Load the clean/noisy audio pair and the metadata for ``filename``.

        Args:
            filename (str): Audio file name; its stem must have the form
                ``<speaker_id>_<utterance_id>``.

        Returns:
            Tuple of clean waveform, clean sample rate, noisy waveform, noisy sample
            rate, speaker ID, utterance ID, source and channel ID.
        """
        # Everything before the first "." is split into exactly two "_"-separated parts.
        speaker_id, utterance_id = filename.split(".")[0].split("_")
        source, channel_id = self._config[filename]
        # The clean and device-recorded directories are indexed by the same file name.
        file_clean_audio = self._clean_audio_dir / filename
        file_noisy_audio = self._noisy_audio_dir / filename
        waveform_clean, sample_rate_clean = torchaudio.load(file_clean_audio)
        waveform_noisy, sample_rate_noisy = torchaudio.load(file_noisy_audio)
        return (
            waveform_clean,
            sample_rate_clean,
            waveform_noisy,
            sample_rate_noisy,
            speaker_id,
            utterance_id,
            source,
            channel_id,
        )

    def __getitem__(self, n: int) -> Tuple[Tensor, int, Tensor, int, str, str, str, int]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Clean waveform
            int:
                Sample rate of the clean waveform
            Tensor:
                Noisy waveform
            int:
                Sample rate of the noisy waveform
            str:
                Speaker ID
            str:
                Utterance ID
            str:
                Source
            int:
                Channel ID
        """
        filename = self._filename_list[n]
        return self._load_dr_vctk_item(filename)

    def __len__(self) -> int:
        """Return the number of files listed in the subset's channel log."""
        return len(self._filename_list)