219 lines
8.5 KiB
Python
219 lines
8.5 KiB
Python
|
import os
|
||
|
from pathlib import Path
|
||
|
from typing import Tuple, Union
|
||
|
|
||
|
import torchaudio
|
||
|
from torch import Tensor
|
||
|
from torch.utils.data import Dataset
|
||
|
from torchaudio._internal import download_url_to_file
|
||
|
from torchaudio.datasets.utils import _extract_tar
|
||
|
|
||
|
|
||
|
# Download/layout metadata for each supported TEDLIUM release, keyed by release name.
# Per release:
#   folder_in_archive  - top-level directory name under ``root``
#   url                - archive download location
#   checksum           - expected archive hash (passed as ``hash_prefix`` when downloading)
#   data_path          - sub-path inside ``folder_in_archive`` where data lives
#   subset             - default subset when none is given
#   supported_subsets  - subsets accepted by the dataset class
#   dict               - pronunciation dictionary filename inside ``folder_in_archive``
# Each release gets its own ``supported_subsets`` list (no shared objects).
_RELEASE_CONFIGS = dict(
    release1=dict(
        folder_in_archive="TEDLIUM_release1",
        url="http://www.openslr.org/resources/7/TEDLIUM_release1.tar.gz",
        checksum="30301975fd8c5cac4040c261c0852f57cfa8adbbad2ce78e77e4986957445f27",
        data_path="",
        subset="train",
        supported_subsets=["train", "test", "dev"],
        dict="TEDLIUM.150K.dic",
    ),
    release2=dict(
        folder_in_archive="TEDLIUM_release2",
        url="http://www.openslr.org/resources/19/TEDLIUM_release2.tar.gz",
        checksum="93281b5fcaaae5c88671c9d000b443cb3c7ea3499ad12010b3934ca41a7b9c58",
        data_path="",
        subset="train",
        supported_subsets=["train", "test", "dev"],
        dict="TEDLIUM.152k.dic",
    ),
    release3=dict(
        folder_in_archive="TEDLIUM_release-3",
        url="http://www.openslr.org/resources/51/TEDLIUM_release-3.tgz",
        checksum="ad1e454d14d1ad550bc2564c462d87c7a7ec83d4dc2b9210f22ab4973b9eccdb",
        data_path="data/",
        subset="train",
        supported_subsets=["train", "test", "dev"],
        dict="TEDLIUM.152k.dic",
    ),
)
|
||
|
|
||
|
|
||
|
class TEDLIUM(Dataset):
    """*Tedlium* :cite:`rousseau2012tedlium` dataset (releases 1, 2 and 3).

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        release (str, optional): Release version.
            Allowed values are ``"release1"``, ``"release2"`` or ``"release3"``.
            (default: ``"release1"``).
        subset (str, optional): The subset of dataset to use. Valid options are ``"train"``, ``"dev"``,
            and ``"test"``. Defaults to ``"train"``.
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
        audio_ext (str, optional): extension for audio file (default: ``".sph"``)

    Raises:
        RuntimeError: If ``release`` or ``subset`` is not supported, or if the data directory
            does not exist and ``download`` is ``False``.
    """

    def __init__(
        self,
        root: Union[str, Path],
        release: str = "release1",
        subset: str = "train",
        download: bool = False,
        audio_ext: str = ".sph",
    ) -> None:
        self._ext_audio = audio_ext

        if release not in _RELEASE_CONFIGS:
            raise RuntimeError(
                "The release {} does not match any of the supported tedlium releases {} ".format(
                    release,
                    list(_RELEASE_CONFIGS),
                )
            )
        config = _RELEASE_CONFIGS[release]
        folder_in_archive = config["folder_in_archive"]
        url = config["url"]

        # A falsy subset (e.g. "" or None) falls back to the release's default subset.
        subset = subset if subset else config["subset"]
        if subset not in config["supported_subsets"]:
            raise RuntimeError(
                "The subset {} does not match any of the supported tedlium subsets {} ".format(
                    subset,
                    config["supported_subsets"],
                )
            )

        # Get string representation of 'root' in case Path object is passed
        root = os.fspath(root)
        archive = os.path.join(root, os.path.basename(url))

        if release == "release3":
            # Release 3 uses a different layout: the train split sits directly under
            # ``data_path`` while dev/test are under ``legacy/<subset>``.
            if subset == "train":
                self._path = os.path.join(root, folder_in_archive, config["data_path"])
            else:
                self._path = os.path.join(root, folder_in_archive, "legacy", subset)
        else:
            self._path = os.path.join(root, folder_in_archive, config["data_path"], subset)

        if download:
            # Only fetch/extract when the data directory is missing; reuse an existing archive.
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    download_url_to_file(url, archive, hash_prefix=config["checksum"])
                _extract_tar(archive)
        elif not os.path.exists(self._path):
            raise RuntimeError(
                f"The path {self._path} doesn't exist. "
                "Please check the ``root`` path or set `download=True` to download it"
            )

        # One entry per transcript line: (file id, line index within its .stm file).
        self._filelist = self._index_transcripts(os.path.join(self._path, "stm"))

        # Pronunciation dictionary path; it is parsed lazily by ``phoneme_dict``.
        self._dict_path = os.path.join(root, folder_in_archive, config["dict"])
        self._phoneme_dict = None

    @staticmethod
    def _index_transcripts(stm_dir: str) -> list:
        """Enumerate every transcript line of the ``.stm`` files in ``stm_dir``.

        Args:
            stm_dir (str): Directory containing ``.stm`` transcript files.

        Returns:
            list of (str, int): ``(fileid, line)`` pairs, where ``fileid`` is the file name with
            its trailing ``.stm`` removed and ``line`` indexes a line of that file. Files are
            visited in sorted name order; files without the ``.stm`` suffix are ignored.
        """
        suffix = ".stm"
        samples = []
        for name in sorted(os.listdir(stm_dir)):
            if not name.endswith(suffix):
                continue
            with open(os.path.join(stm_dir, name)) as f:
                num_lines = sum(1 for _ in f)
            # Strip only the trailing suffix (``str.replace`` would also hit inner ".stm" text).
            fileid = name[: -len(suffix)]
            samples.extend((fileid, line) for line in range(num_lines))
        return samples

    def _load_tedlium_item(self, fileid: str, line: int, path: str) -> Tuple[Tensor, int, str, str, str, str]:
        """Loads a TEDLIUM dataset sample given a file name and corresponding sentence name.

        Args:
            fileid (str): File id to identify both text and audio files corresponding to the sample
            line (int): Line identifier for the sample inside the text file
            path (str): Dataset root path

        Returns:
            (Tensor, int, str, str, str, str):
            ``(waveform, sample_rate, transcript, talk_id, speaker_id, identifier)``
        """
        transcript_path = os.path.join(path, "stm", fileid + ".stm")
        with open(transcript_path) as f:
            transcript = f.readlines()[line]
        # The first six space-separated fields are metadata (the second is discarded); everything
        # after them is the transcript text, including any trailing newline from the file.
        talk_id, _, speaker_id, start_time, end_time, identifier, transcript = transcript.split(" ", 6)

        wave_path = os.path.join(path, "sph", fileid) + self._ext_audio
        waveform, sample_rate = self._load_audio(wave_path, start_time=start_time, end_time=end_time)

        return (waveform, sample_rate, transcript, talk_id, speaker_id, identifier)

    def _load_audio(
        self,
        path: str,
        start_time: Union[str, float],
        end_time: Union[str, float],
        sample_rate: int = 16000,
    ) -> Tuple[Tensor, int]:
        """Default load function used in TEDLIUM dataset, you can override this function to customize
        functionality and load individual sentences from a full ted audio talk file.

        Args:
            path (str): Path to audio file
            start_time (str or float): Time in seconds where the sample sentence starts
            end_time (str or float): Time in seconds where the sample sentence finishes
            sample_rate (int, optional): Rate used to convert the times above into frame indices
                (default: ``16000``). It is not checked against the file's actual rate.

        Returns:
            (Tensor, int): Audio tensor representation and sample rate (as reported by ``torchaudio.load``)
        """
        start_frame = int(float(start_time) * sample_rate)
        end_frame = int(float(end_time) * sample_rate)
        return torchaudio.load(path, frame_offset=start_frame, num_frames=end_frame - start_frame)

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                Transcript
            str:
                Talk ID
            str:
                Speaker ID
            str:
                Identifier
        """
        fileid, line = self._filelist[n]
        return self._load_tedlium_item(fileid, line, self._path)

    def __len__(self) -> int:
        """TEDLIUM dataset custom function overriding len default behaviour.

        Returns:
            int: TEDLIUM dataset length
        """
        return len(self._filelist)

    @property
    def phoneme_dict(self):
        """dict[str, tuple[str]]: Phonemes. Mapping from word to tuple of phonemes.

        Note that some words have empty phonemes. Blank lines in the dictionary file are skipped.
        """
        # ``None`` means "not parsed yet", so an empty dictionary is not re-read on every access.
        if self._phoneme_dict is None:
            phoneme_dict = {}
            with open(self._dict_path, "r", encoding="utf-8") as f:
                for line in f:
                    content = line.split()
                    if not content:
                        continue  # blank line: nothing to map
                    phoneme_dict[content[0]] = tuple(content[1:])  # content[1:] can be empty
            # Cache only after a complete read so a failed read doesn't leave a partial cache.
            self._phoneme_dict = phoneme_dict
        # Return a copy so callers cannot mutate the cached dictionary.
        return self._phoneme_dict.copy()
|