158 lines
4.9 KiB
Python
158 lines
4.9 KiB
Python
|
import os
|
|||
|
from pathlib import Path
|
|||
|
from typing import List, Optional, Tuple, Union
|
|||
|
|
|||
|
import torch
|
|||
|
from torch.utils.data import Dataset
|
|||
|
from torchaudio.datasets.utils import _load_waveform
|
|||
|
|
|||
|
|
|||
|
_SAMPLE_RATE = 16000
|
|||
|
_SPEAKERS = [
|
|||
|
"Aditi",
|
|||
|
"Amy",
|
|||
|
"Brian",
|
|||
|
"Emma",
|
|||
|
"Geraint",
|
|||
|
"Ivy",
|
|||
|
"Joanna",
|
|||
|
"Joey",
|
|||
|
"Justin",
|
|||
|
"Kendra",
|
|||
|
"Kimberly",
|
|||
|
"Matthew",
|
|||
|
"Nicole",
|
|||
|
"Raveena",
|
|||
|
"Russell",
|
|||
|
"Salli",
|
|||
|
]
|
|||
|
|
|||
|
|
|||
|
def _load_labels(file: Path, subset: str):
|
|||
|
"""Load transcirpt, iob, and intent labels for all utterances.
|
|||
|
|
|||
|
Args:
|
|||
|
file (Path): The path to the label file.
|
|||
|
subset (str): Subset of the dataset to use. Options: [``"train"``, ``"valid"``, ``"test"``].
|
|||
|
|
|||
|
Returns:
|
|||
|
Dictionary of labels, where the key is the filename of the audio,
|
|||
|
and the label is a Tuple of transcript, Inside–outside–beginning (IOB) label, and intention label.
|
|||
|
"""
|
|||
|
labels = {}
|
|||
|
with open(file, "r") as f:
|
|||
|
for line in f:
|
|||
|
line = line.strip().split(" ")
|
|||
|
index = line[0]
|
|||
|
trans, iob_intent = " ".join(line[1:]).split("\t")
|
|||
|
trans = " ".join(trans.split(" ")[1:-1])
|
|||
|
iob = " ".join(iob_intent.split(" ")[1:-1])
|
|||
|
intent = iob_intent.split(" ")[-1]
|
|||
|
if subset in index:
|
|||
|
labels[index] = (trans, iob, intent)
|
|||
|
return labels
|
|||
|
|
|||
|
|
|||
|
class Snips(Dataset):
    """*Snips* :cite:`coucke2018snips` dataset.

    Args:
        root (str or Path): Root directory where the dataset's top level directory is found.
        subset (str): Subset of the dataset to use. Options: [``"train"``, ``"valid"``, ``"test"``].
        speakers (List[str] or None, optional): The speaker list to include in the dataset. If ``None``,
            include all speakers in the subset. (Default: ``None``)
        audio_format (str, optional): The extension of the audios. Options: [``"mp3"``, ``"wav"``].
            (Default: ``"mp3"``)
    """

    # Filename of the label file (transcript/IOB/intent) inside the
    # dataset's top-level "SNIPS" directory.
    _trans_file = "all.iob.snips.txt"

    def __init__(
        self,
        root: Union[str, Path],
        subset: str,
        speakers: Optional[List[str]] = None,
        audio_format: str = "mp3",
    ) -> None:
        if subset not in ["train", "valid", "test"]:
            raise ValueError('`subset` must be one of ["train", "valid", "test"].')
        if audio_format not in ["mp3", "wav"]:
            # Fixed: original message was missing the closing quote in "wav"].
            raise ValueError('`audio_format` must be one of ["mp3", "wav"].')

        root = Path(root)
        self._path = root / "SNIPS"
        self.audio_path = self._path / subset
        if speakers is None:
            speakers = _SPEAKERS

        if not os.path.isdir(self._path):
            raise RuntimeError("Dataset not found.")

        # Collect audio files whose filename prefix (the part before the
        # first "-") names one of the requested speakers.
        self.audio_paths = self.audio_path.glob(f"*.{audio_format}")
        self.data = []
        for audio_path in sorted(self.audio_paths):
            audio_name = str(audio_path.name)
            speaker = audio_name.split("-")[0]
            if speaker in speakers:
                self.data.append(audio_path)
        transcript_path = self._path / self._trans_file
        self.labels = _load_labels(transcript_path, subset)

    def get_metadata(self, n: int) -> Tuple[str, int, str, str, str, str]:
        """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
        but otherwise returns the same fields as :py:func:`__getitem__`.

        Args:
            n (int): The index of the sample to be loaded.

        Returns:
            Tuple of the following items:

            str:
                Path to audio
            int:
                Sample rate
            str:
                File name
            str:
                Transcription of audio
            str:
                Inside-outside-beginning (IOB) label of transcription
            str:
                Intention label of the audio.
        """
        audio_path = self.data[n]
        # Path relative to the dataset root, as expected by _load_waveform.
        relpath = os.path.relpath(audio_path, self._path)
        # Label keys are filenames without the audio extension.
        file_name = audio_path.with_suffix("").name
        transcript, iob, intent = self.labels[file_name]
        return relpath, _SAMPLE_RATE, file_name, transcript, iob, intent

    def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items:

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                File name
            str:
                Transcription of audio
            str:
                Inside-outside-beginning (IOB) label of transcription
            str:
                Intention label of the audio.
        """
        metadata = self.get_metadata(n)
        waveform = _load_waveform(self._path, metadata[0], metadata[1])
        return (waveform,) + metadata[1:]

    def __len__(self) -> int:
        """Return the number of audio files selected for this subset/speaker filter."""
        return len(self.data)
|