# Original listing metadata (web viewer header): 90 lines, 3.0 KiB, Python.
import os
from pathlib import Path
from typing import List, Tuple, Union

import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio._internal import download_url_to_file
from torchaudio.datasets.utils import _extract_tar


# Download metadata for the known releases of the dataset, keyed by release name.
# Fields of each release entry:
#   folder_in_archive: default for the ``folder_in_archive`` argument of ``YESNO``.
#   url:               default for the ``url`` argument of ``YESNO``.
#   checksum:          value handed to ``download_url_to_file`` as ``hash_prefix``
#                      when the archive is downloaded.
_RELEASE_CONFIGS = {
    "release1": {
        "folder_in_archive": "waves_yesno",
        "url": "http://www.openslr.org/resources/1/waves_yesno.tar.gz",
        "checksum": "c3f49e0cca421f96b75b41640749167b52118f232498667ca7a5f9416aef8e73",
    }
}


class YESNO(Dataset):
    """*YesNo* :cite:`YesNo` dataset.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        url (str, optional): The URL to download the dataset from.
            (default: ``"http://www.openslr.org/resources/1/waves_yesno.tar.gz"``)
        folder_in_archive (str, optional):
            The top-level directory of the dataset. (default: ``"waves_yesno"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).

    Raises:
        RuntimeError: If the directory ``root / folder_in_archive`` does not exist
            once the optional download step has run.
    """

    def __init__(
        self,
        root: Union[str, Path],
        url: str = _RELEASE_CONFIGS["release1"]["url"],
        folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"],
        download: bool = False,
    ) -> None:
        self._parse_filesystem(root, url, folder_in_archive, download)

    def _parse_filesystem(
        self,
        root: Union[str, Path],
        url: str,
        folder_in_archive: str,
        download: bool,
    ) -> None:
        """Locate (and optionally fetch) the dataset, then index its ``.wav`` files.

        Sets ``self._path`` to ``root / folder_in_archive`` and ``self._walker`` to
        the sorted file stems of every ``*.wav`` file directly inside that directory.

        Args:
            root (str or Path): Directory holding the archive and the extracted dataset.
            url (str): Download location; its basename is used as the archive file name.
            folder_in_archive (str): Name of the dataset directory below ``root``.
            download (bool): Whether to download/extract when the dataset directory is missing.

        Raises:
            RuntimeError: If the dataset directory does not exist after the download step.
        """
        root = Path(root)
        archive = root / os.path.basename(url)
        self._path = root / folder_in_archive

        if download and not self._path.is_dir():
            if not archive.is_file():
                # The download helper writes next to the destination file, so the
                # target directory has to exist before the transfer starts.
                root.mkdir(parents=True, exist_ok=True)
                # Only verify against a release checksum when the URL is that
                # release's URL; an unknown URL is downloaded without hash
                # verification instead of being rejected by a hash that
                # belongs to a different archive.
                checksum = next(
                    (cfg["checksum"] for cfg in _RELEASE_CONFIGS.values() if cfg["url"] == url),
                    None,
                )
                download_url_to_file(url, archive, hash_prefix=checksum)
            # Extraction also runs when the archive was already on disk but the
            # dataset directory was not.
            _extract_tar(archive)

        if not self._path.is_dir():
            raise RuntimeError("Dataset not found. Please use `download=True` to download it.")

        self._walker = sorted(p.stem for p in self._path.glob("*.wav"))

    def _load_item(self, fileid: str, path: Union[str, Path]) -> Tuple[Tensor, int, List[int]]:
        """Load one recording and the labels encoded in its file name.

        Args:
            fileid (str): File stem made of underscore-separated integers.
            path (str or Path): Directory that contains ``<fileid>.wav``.

        Returns:
            Tuple of waveform, sample rate, and the list of integer labels parsed
            from ``fileid``.
        """
        labels = [int(token) for token in fileid.split("_")]
        file_audio = os.path.join(path, fileid + ".wav")
        waveform, sample_rate = torchaudio.load(file_audio)
        return waveform, sample_rate, labels

    def __getitem__(self, n: int) -> Tuple[Tensor, int, List[int]]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            List[int]:
                labels
        """
        return self._load_item(self._walker[n], self._path)

    def __len__(self) -> int:
        """Return the number of ``.wav`` files indexed for this dataset."""
        return len(self._walker)