Traktor/myenv/Lib/site-packages/torchaudio/datasets/snips.py

158 lines
4.9 KiB
Python
Raw Permalink Normal View History

2024-05-26 05:12:46 +02:00
import os
from pathlib import Path
from typing import List, Optional, Tuple, Union
import torch
from torch.utils.data import Dataset
from torchaudio.datasets.utils import _load_waveform
# Sample rate (in Hz) reported for every utterance in the dataset.
_SAMPLE_RATE = 16000

# Speaker names recognized by the dataset. Audio file names start with the
# speaker name followed by "-", and this list is the default speaker filter
# when the caller does not supply one.
_SPEAKERS = [
    "Aditi",
    "Amy",
    "Brian",
    "Emma",
    "Geraint",
    "Ivy",
    "Joanna",
    "Joey",
    "Justin",
    "Kendra",
    "Kimberly",
    "Matthew",
    "Nicole",
    "Raveena",
    "Russell",
    "Salli",
]
def _load_labels(file: Path, subset: str):
    """Load transcript, IOB, and intent labels for all utterances of a subset.

    Each non-empty line of the label file is parsed as::

        <index> <first> <word> ... <word> <last>\\t<first> <iob> ... <iob> <intent>

    The first and last space-separated tokens of the text before the tab are
    dropped (presumably begin/end-of-sentence markers -- confirm against the
    data), as is the first token after the tab. The last token after the tab is
    the intent label.

    Args:
        file (Path): The path to the label file.
        subset (str): Subset of the dataset to use. Options: [``"train"``, ``"valid"``, ``"test"``].
            A line is kept when ``subset`` occurs as a substring of its index.

    Returns:
        Dictionary of labels, where the key is the filename of the audio (the line index),
        and the value is a Tuple of transcript, Inside-outside-beginning (IOB) label,
        and intention label.

    Raises:
        ValueError: If a non-empty line does not contain exactly one tab separator.
    """
    labels = {}
    # Read with an explicit encoding so parsing does not depend on the platform locale.
    with open(file, "r", encoding="utf-8") as f:
        for raw_line in f:
            stripped = raw_line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing empty line) carry no labels; skipping
                # them avoids a confusing tuple-unpacking error below.
                continue
            # Everything before the first space is the utterance index.
            index, _, remainder = stripped.partition(" ")
            trans_field, iob_intent_field = remainder.split("\t")
            # Drop the leading and trailing marker tokens around the transcript.
            trans = " ".join(trans_field.split(" ")[1:-1])
            iob_tokens = iob_intent_field.split(" ")
            # The first label belongs to the leading marker; the last one is the intent.
            iob = " ".join(iob_tokens[1:-1])
            intent = iob_tokens[-1]
            if subset in index:
                labels[index] = (trans, iob, intent)
    return labels
class Snips(Dataset):
    """*Snips* :cite:`coucke2018snips` dataset.

    Audio files are looked up in ``<root>/SNIPS/<subset>`` and labels are read from
    ``<root>/SNIPS/all.iob.snips.txt``.

    Args:
        root (str or Path): Root directory where the dataset's top level directory is found.
        subset (str): Subset of the dataset to use. Options: [``"train"``, ``"valid"``, ``"test"``].
        speakers (List[str] or None, optional): The speaker list to include in the dataset. If ``None``,
            include all speakers in the subset. (Default: ``None``)
        audio_format (str, optional): The extension of the audios. Options: [``"mp3"``, ``"wav"``].
            (Default: ``"mp3"``)

    Raises:
        ValueError: If ``subset`` or ``audio_format`` is not one of the supported options.
        RuntimeError: If the ``SNIPS`` directory does not exist under ``root``.
    """

    _trans_file = "all.iob.snips.txt"

    def __init__(
        self,
        root: Union[str, Path],
        subset: str,
        speakers: Optional[List[str]] = None,
        audio_format: str = "mp3",
    ) -> None:
        if subset not in ["train", "valid", "test"]:
            raise ValueError('`subset` must be one of ["train", "valid", "test"].')
        if audio_format not in ["mp3", "wav"]:
            raise ValueError('`audio_format` must be one of ["mp3", "wav"].')

        root = Path(root)
        self._path = root / "SNIPS"
        self.audio_path = self._path / subset
        if speakers is None:
            speakers = _SPEAKERS

        if not os.path.isdir(self._path):
            raise RuntimeError("Dataset not found.")

        # Materialize the glob into a sorted list: a bare generator would be
        # exhausted by the filtering below and leave a useless attribute behind.
        self.audio_paths = sorted(self.audio_path.glob(f"*.{audio_format}"))

        # File names start with "<speaker>-"; keep only the requested speakers.
        self.data = [
            audio_path
            for audio_path in self.audio_paths
            if audio_path.name.split("-")[0] in speakers
        ]

        transcript_path = self._path / self._trans_file
        self.labels = _load_labels(transcript_path, subset)

    def get_metadata(self, n: int) -> Tuple[str, int, str, str, str, str]:
        """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
        but otherwise returns the same fields as :py:func:`__getitem__`.

        Args:
            n (int): The index of the sample to be loaded.

        Returns:
            Tuple of the following items:

            str:
                Path to audio, relative to the ``SNIPS`` directory
            int:
                Sample rate
            str:
                File name (without extension)
            str:
                Transcription of audio
            str:
                Inside-outside-beginning (IOB) label of transcription
            str:
                Intention label of the audio.

        Raises:
            IndexError: If ``n`` is out of range.
            KeyError: If the label file has no entry for the audio file.
        """
        audio_path = self.data[n]
        relpath = os.path.relpath(audio_path, self._path)
        # Labels are keyed by the audio file name without its extension.
        file_name = audio_path.stem
        transcript, iob, intent = self.labels[file_name]
        return relpath, _SAMPLE_RATE, file_name, transcript, iob, intent

    def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items:

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                File name
            str:
                Transcription of audio
            str:
                Inside-outside-beginning (IOB) label of transcription
            str:
                Intention label of the audio.
        """
        metadata = self.get_metadata(n)
        # metadata[0] is the relative audio path and metadata[1] the sample rate.
        waveform = _load_waveform(self._path, metadata[0], metadata[1])
        # Swap the path for the loaded waveform; the remaining fields are unchanged.
        return (waveform,) + metadata[1:]

    def __len__(self) -> int:
        """Return the number of audio files selected for this subset and speaker list."""
        return len(self.data)