Traktor/myenv/Lib/site-packages/torchaudio/datasets/iemocap.py
2024-05-26 05:12:46 +02:00

148 lines
4.8 KiB
Python

import os
import re
from pathlib import Path
from typing import Optional, Tuple, Union
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import _load_waveform
_SAMPLE_RATE = 16000
def _get_wavs_paths(data_dir):
wav_dir = data_dir / "sentences" / "wav"
wav_paths = sorted(str(p) for p in wav_dir.glob("*/*.wav"))
relative_paths = []
for wav_path in wav_paths:
start = wav_path.find("Session")
wav_path = wav_path[start:]
relative_paths.append(wav_path)
return relative_paths
class IEMOCAP(Dataset):
"""*IEMOCAP* :cite:`iemocap` dataset.
Args:
root (str or Path): Root directory where the dataset's top level directory is found
sessions (Tuple[int]): Tuple of sessions (1-5) to use. (Default: ``(1, 2, 3, 4, 5)``)
utterance_type (str or None, optional): Which type(s) of utterances to include in the dataset.
Options: ("scripted", "improvised", ``None``). If ``None``, both scripted and improvised
data are used.
"""
def __init__(
self,
root: Union[str, Path],
sessions: Tuple[str] = (1, 2, 3, 4, 5),
utterance_type: Optional[str] = None,
):
root = Path(root)
self._path = root / "IEMOCAP"
if not os.path.isdir(self._path):
raise RuntimeError("Dataset not found.")
if utterance_type not in ["scripted", "improvised", None]:
raise ValueError("utterance_type must be one of ['scripted', 'improvised', or None]")
all_data = []
self.data = []
self.mapping = {}
for session in sessions:
session_name = f"Session{session}"
session_dir = self._path / session_name
# get wav paths
wav_paths = _get_wavs_paths(session_dir)
for wav_path in wav_paths:
wav_stem = str(Path(wav_path).stem)
all_data.append(wav_stem)
# add labels
label_dir = session_dir / "dialog" / "EmoEvaluation"
query = "*.txt"
if utterance_type == "scripted":
query = "*script*.txt"
elif utterance_type == "improvised":
query = "*impro*.txt"
label_paths = label_dir.glob(query)
for label_path in label_paths:
with open(label_path, "r") as f:
for line in f:
if not line.startswith("["):
continue
line = re.split("[\t\n]", line)
wav_stem = line[1]
label = line[2]
if wav_stem not in all_data:
continue
if label not in ["neu", "hap", "ang", "sad", "exc", "fru"]:
continue
self.mapping[wav_stem] = {}
self.mapping[wav_stem]["label"] = label
for wav_path in wav_paths:
wav_stem = str(Path(wav_path).stem)
if wav_stem in self.mapping:
self.data.append(wav_stem)
self.mapping[wav_stem]["path"] = wav_path
def get_metadata(self, n: int) -> Tuple[str, int, str, str, str]:
"""Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
but otherwise returns the same fields as :py:meth:`__getitem__`.
Args:
n (int): The index of the sample to be loaded
Returns:
Tuple of the following items;
str:
Path to audio
int:
Sample rate
str:
File name
str:
Label (one of ``"neu"``, ``"hap"``, ``"ang"``, ``"sad"``, ``"exc"``, ``"fru"``)
str:
Speaker
"""
wav_stem = self.data[n]
wav_path = self.mapping[wav_stem]["path"]
label = self.mapping[wav_stem]["label"]
speaker = wav_stem.split("_")[0]
return (wav_path, _SAMPLE_RATE, wav_stem, label, speaker)
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
Tuple of the following items;
Tensor:
Waveform
int:
Sample rate
str:
File name
str:
Label (one of ``"neu"``, ``"hap"``, ``"ang"``, ``"sad"``, ``"exc"``, ``"fru"``)
str:
Speaker
"""
metadata = self.get_metadata(n)
waveform = _load_waveform(self._path, metadata[0], metadata[1])
return (waveform,) + metadata[1:]
def __len__(self):
return len(self.data)