# (file listing metadata: 122 lines, 4.3 KiB, Python)
from pathlib import Path
|
||
|
from typing import Dict, Tuple, Union
|
||
|
|
||
|
import torchaudio
|
||
|
from torch import Tensor
|
||
|
from torch.utils.data import Dataset
|
||
|
from torchaudio._internal import download_url_to_file
|
||
|
from torchaudio.datasets.utils import _extract_zip
|
||
|
|
||
|
|
||
|
_URL = "https://datashare.ed.ac.uk/bitstream/handle/10283/3038/DR-VCTK.zip"
# Passed to ``download_url_to_file`` as ``hash_prefix`` to verify the downloaded archive
# (presumably a SHA-256 digest prefix -- confirm against the downloader's docs).
_CHECKSUM = "781f12f4406ed36ed27ae3bce55da47ba176e2d8bae67319e389e07b2c9bd769"
_SUPPORTED_SUBSETS = {"train", "test"}


class DR_VCTK(Dataset):
    """*Device Recorded VCTK (Small subset version)* :cite:`Sarfjoo2018DeviceRV` dataset.

    The extracted data is expected under ``root/DR-VCTK/DR-VCTK``; if that directory is
    missing, the archive ``root/DR-VCTK.zip`` is extracted (downloading it first when
    ``download=True``).

    Args:
        root (str or Path): Root directory where the dataset's top level directory is found.
        subset (str): The subset to use. Can be one of ``"train"`` and ``"test"``. (default: ``"train"``).
        download (bool):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
        url (str): The URL to download the dataset from.
            (default: ``"https://datashare.ed.ac.uk/bitstream/handle/10283/3038/DR-VCTK.zip"``)

    Raises:
        RuntimeError: If ``subset`` is not supported, or if the dataset is missing and
            ``download`` is ``False``.
    """

    def __init__(
        self,
        root: Union[str, Path],
        subset: str = "train",
        *,
        download: bool = False,
        url: str = _URL,
    ) -> None:
        if subset not in _SUPPORTED_SUBSETS:
            raise RuntimeError(
                f"The subset '{subset}' does not match any of the supported subsets: {_SUPPORTED_SUBSETS}"
            )

        root = Path(root).expanduser()
        archive = root / "DR-VCTK.zip"

        self._subset = subset
        # The archive extracts to a doubly nested "DR-VCTK/DR-VCTK" directory.
        self._path = root / "DR-VCTK" / "DR-VCTK"
        self._clean_audio_dir = self._path / f"clean_{self._subset}set_wav_16k"
        self._noisy_audio_dir = self._path / f"device-recorded_{self._subset}set_wav_16k"
        self._config_filepath = self._path / "configurations" / f"{self._subset}_ch_log.txt"

        if not self._path.is_dir():
            if not archive.is_file():
                if not download:
                    raise RuntimeError("Dataset not found. Please use `download=True` to download it.")
                # Create root first so the archive has an existing directory to be written into;
                # the downloader is not relied upon to create missing parent directories.
                root.mkdir(parents=True, exist_ok=True)
                download_url_to_file(url, archive, hash_prefix=_CHECKSUM)
            _extract_zip(archive, root)

        self._config = self._load_config(self._config_filepath)
        # Sorted so that sample indices are deterministic.
        self._filename_list = sorted(self._config)

    def _load_config(self, filepath: Union[str, Path]) -> Dict[str, Tuple[str, int]]:
        """Parse the subset's channel-log file into a lookup table.

        Each data row holds three tab-separated fields: filename, source and channel id.
        Leading header rows and blank lines are skipped.

        Args:
            filepath (str or Path): Path to the channel-log text file.

        Returns:
            Dict[str, Tuple[str, int]]: Maps each filename to ``(source, channel_id)``.
        """
        # The train log has two header rows, the test log has one.
        skip_rows = 2 if self._subset == "train" else 1

        config = {}
        with open(filepath, encoding="utf-8") as f:
            for i, raw_line in enumerate(f):
                # Lines read from a file keep their trailing newline, so strip before
                # testing for emptiness; otherwise a blank line would reach the split below.
                line = raw_line.strip()
                if i < skip_rows or not line:
                    continue
                filename, source, channel_id = line.split("\t")
                config[filename] = (source, int(channel_id))
        return config

    def _load_dr_vctk_item(self, filename: str) -> Tuple[Tensor, int, Tensor, int, str, str, str, int]:
        """Load the clean and device-recorded versions of ``filename`` with their metadata.

        Assumes ``filename`` has the form ``<speaker>_<utterance>.<ext>``.
        """
        speaker_id, utterance_id = filename.split(".")[0].split("_")
        source, channel_id = self._config[filename]
        file_clean_audio = self._clean_audio_dir / filename
        file_noisy_audio = self._noisy_audio_dir / filename
        waveform_clean, sample_rate_clean = torchaudio.load(file_clean_audio)
        waveform_noisy, sample_rate_noisy = torchaudio.load(file_noisy_audio)
        return (
            waveform_clean,
            sample_rate_clean,
            waveform_noisy,
            sample_rate_noisy,
            speaker_id,
            utterance_id,
            source,
            channel_id,
        )

    def __getitem__(self, n: int) -> Tuple[Tensor, int, Tensor, int, str, str, str, int]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Clean waveform
            int:
                Sample rate of the clean waveform
            Tensor:
                Noisy waveform
            int:
                Sample rate of the noisy waveform
            str:
                Speaker ID
            str:
                Utterance ID
            str:
                Source
            int:
                Channel ID
        """
        filename = self._filename_list[n]
        return self._load_dr_vctk_item(filename)

    def __len__(self) -> int:
        return len(self._filename_list)