import os
from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch
from torch.utils.data import Dataset
from torchaudio.datasets.utils import _load_waveform


_SAMPLE_RATE = 16000
_SPEAKERS = [
    "Aditi",
    "Amy",
    "Brian",
    "Emma",
    "Geraint",
    "Ivy",
    "Joanna",
    "Joey",
    "Justin",
    "Kendra",
    "Kimberly",
    "Matthew",
    "Nicole",
    "Raveena",
    "Russell",
    "Salli",
]


def _load_labels(file: Path, subset: str):
    """Load transcript, IOB, and intent labels for all utterances.

    Args:
        file (Path): The path to the label file.
        subset (str): Subset of the dataset to use. Options: [``"train"``, ``"valid"``, ``"test"``].

    Returns:
        Dictionary of labels, where the key is the filename of the audio,
        and the value is a Tuple of transcript, Inside-outside-beginning (IOB) label, and intention label.
    """
    labels = {}
    with open(file, "r") as f:
        for line in f:
            line = line.strip().split(" ")
            index = line[0]
            trans, iob_intent = " ".join(line[1:]).split("\t")
            trans = " ".join(trans.split(" ")[1:-1])
            iob = " ".join(iob_intent.split(" ")[1:-1])
            intent = iob_intent.split(" ")[-1]
            if subset in index:
                labels[index] = (trans, iob, intent)
    return labels


class Snips(Dataset):
    """*Snips* :cite:`coucke2018snips` dataset.

    Args:
        root (str or Path): Root directory where the dataset's top level directory is found.
        subset (str): Subset of the dataset to use. Options: [``"train"``, ``"valid"``, ``"test"``].
        speakers (List[str] or None, optional): The speaker list to include in the dataset. If ``None``,
            include all speakers in the subset. (Default: ``None``)
        audio_format (str, optional): The extension of the audio files. Options: [``"mp3"``, ``"wav"``].
            (Default: ``"mp3"``)
    """

    _trans_file = "all.iob.snips.txt"

    def __init__(
        self,
        root: Union[str, Path],
        subset: str,
        speakers: Optional[List[str]] = None,
        audio_format: str = "mp3",
    ) -> None:
        if subset not in ["train", "valid", "test"]:
            raise ValueError('`subset` must be one of ["train", "valid", "test"].')
        if audio_format not in ["mp3", "wav"]:
            raise ValueError('`audio_format` must be one of ["mp3", "wav"].')

        root = Path(root)
        self._path = root / "SNIPS"
        self.audio_path = self._path / subset
        if speakers is None:
            speakers = _SPEAKERS

        if not os.path.isdir(self._path):
            raise RuntimeError("Dataset not found.")

        self.audio_paths = self.audio_path.glob(f"*.{audio_format}")
        self.data = []
        # The speaker name is the filename prefix before the first "-".
        for audio_path in sorted(self.audio_paths):
            audio_name = str(audio_path.name)
            speaker = audio_name.split("-")[0]
            if speaker in speakers:
                self.data.append(audio_path)
        transcript_path = self._path / self._trans_file
        self.labels = _load_labels(transcript_path, subset)

    def get_metadata(self, n: int) -> Tuple[str, int, str, str, str, str]:
        """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
        but otherwise returns the same fields as :py:func:`__getitem__`.

        Args:
            n (int): The index of the sample to be loaded.

        Returns:
            Tuple of the following items:

            str:
                Path to audio
            int:
                Sample rate
            str:
                File name
            str:
                Transcription of audio
            str:
                Inside-outside-beginning (IOB) label of transcription
            str:
                Intention label of the audio.
        """
        audio_path = self.data[n]
        relpath = os.path.relpath(audio_path, self._path)
        file_name = audio_path.with_suffix("").name
        transcript, iob, intent = self.labels[file_name]
        return relpath, _SAMPLE_RATE, file_name, transcript, iob, intent

    def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded.

        Returns:
            Tuple of the following items:

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                File name
            str:
                Transcription of audio
            str:
                Inside-outside-beginning (IOB) label of transcription
            str:
                Intention label of the audio.
        """
        metadata = self.get_metadata(n)
        # metadata[0] is the audio path relative to the dataset root, metadata[1] the sample rate.
        waveform = _load_waveform(self._path, metadata[0], metadata[1])
        return (waveform,) + metadata[1:]

    def __len__(self) -> int:
        return len(self.data)