187 lines
5.8 KiB
Python
187 lines
5.8 KiB
Python
import os
|
|
import re
|
|
from pathlib import Path
|
|
from typing import Iterable, List, Tuple, Union
|
|
|
|
from torch.utils.data import Dataset
|
|
from torchaudio._internal import download_url_to_file
|
|
|
|
|
|
_CHECKSUMS = {
|
|
"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b": "209a8b4cd265013e96f4658632a9878103b0c5abf62b50d4ef3ae1be226b29e4", # noqa: E501
|
|
"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols": "408ccaae803641c6d7b626b6299949320c2dbca96b2220fd3fb17887b023b027", # noqa: E501
|
|
}
|
|
_PUNCTUATIONS = {
|
|
"!EXCLAMATION-POINT",
|
|
'"CLOSE-QUOTE',
|
|
'"DOUBLE-QUOTE',
|
|
'"END-OF-QUOTE',
|
|
'"END-QUOTE',
|
|
'"IN-QUOTES',
|
|
'"QUOTE',
|
|
'"UNQUOTE',
|
|
"#HASH-MARK",
|
|
"#POUND-SIGN",
|
|
"#SHARP-SIGN",
|
|
"%PERCENT",
|
|
"&ERSAND",
|
|
"'END-INNER-QUOTE",
|
|
"'END-QUOTE",
|
|
"'INNER-QUOTE",
|
|
"'QUOTE",
|
|
"'SINGLE-QUOTE",
|
|
"(BEGIN-PARENS",
|
|
"(IN-PARENTHESES",
|
|
"(LEFT-PAREN",
|
|
"(OPEN-PARENTHESES",
|
|
"(PAREN",
|
|
"(PARENS",
|
|
"(PARENTHESES",
|
|
")CLOSE-PAREN",
|
|
")CLOSE-PARENTHESES",
|
|
")END-PAREN",
|
|
")END-PARENS",
|
|
")END-PARENTHESES",
|
|
")END-THE-PAREN",
|
|
")PAREN",
|
|
")PARENS",
|
|
")RIGHT-PAREN",
|
|
")UN-PARENTHESES",
|
|
"+PLUS",
|
|
",COMMA",
|
|
"--DASH",
|
|
"-DASH",
|
|
"-HYPHEN",
|
|
"...ELLIPSIS",
|
|
".DECIMAL",
|
|
".DOT",
|
|
".FULL-STOP",
|
|
".PERIOD",
|
|
".POINT",
|
|
"/SLASH",
|
|
":COLON",
|
|
";SEMI-COLON",
|
|
";SEMI-COLON(1)",
|
|
"?QUESTION-MARK",
|
|
"{BRACE",
|
|
"{LEFT-BRACE",
|
|
"{OPEN-BRACE",
|
|
"}CLOSE-BRACE",
|
|
"}RIGHT-BRACE",
|
|
}
|
|
|
|
|
|
def _parse_dictionary(lines: Iterable[str], exclude_punctuations: bool) -> List[Tuple[str, List[str]]]:
    """Parse CMUDict entry lines into ``(word, phonemes)`` pairs.

    Each entry line holds a word followed by its phonemes, e.g.
    ``"HELLO  HH AH0 L OW1"``. The word is separated from the phonemes by a run
    of whitespace (two spaces in the distributed file), and the phonemes are
    separated from each other by whitespace.

    Args:
        lines: Raw lines of the dictionary file. Surrounding whitespace,
            including the trailing newline, is ignored. Blank lines and lines
            starting with ``;;;`` (comments) are skipped.
        exclude_punctuations: If ``True``, drop entries whose word is in
            ``_PUNCTUATIONS`` (e.g. ``!EXCLAMATION-POINT``). If ``False``, keep
            them and replace the word with its punctuation symbol (``!``).

    Returns:
        ``(word, phonemes)`` tuples in input order. An alternate pronunciation
        such as ``DATAPOINTS(1)`` is returned under its base word ``DATAPOINTS``.

    Raises:
        ValueError: If a non-blank, non-comment line has no phonemes.
    """
    alt_re = re.compile(r"\([0-9]+\)")
    cmudict: List[Tuple[str, List[str]]] = []
    for raw_line in lines:
        line = raw_line.strip()
        # Strip before testing: a blank line arrives as "\n", which is truthy and
        # would otherwise slip past this guard.
        if not line or line.startswith(";;;"):  # ignore blanks and comments
            continue

        # Split off the word at the first whitespace run. Splitting on a single
        # space would also split every phoneme apart (and turn the two-space
        # separator into an empty field), breaking the two-name unpacking.
        fields = line.split(maxsplit=1)
        if len(fields) != 2:
            raise ValueError(f"Malformed CMUDict entry (no phonemes): {raw_line!r}")
        word, phones = fields

        if word in _PUNCTUATIONS:
            if exclude_punctuations:
                continue
            # !EXCLAMATION-POINT -> !
            # --DASH -> --
            # ...ELLIPSIS -> ...
            if word.startswith("..."):
                word = "..."
            elif word.startswith("--"):
                word = "--"
            else:
                word = word[0]

        # A word with several pronunciations has "(N)" appended to the extra
        # entries, e.g. DATAPOINTS and DATAPOINTS(1); remove the suffix so all
        # of them are reported under DATAPOINTS.
        word = alt_re.sub("", word)
        cmudict.append((word, phones.split()))

    return cmudict
|
|
|
|
|
|
class CMUDict(Dataset):
    """*CMU Pronouncing Dictionary* :cite:`cmudict` (CMUDict) dataset.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        exclude_punctuations (bool, optional):
            When enabled, exclude the pronunciation of punctuations, such as
            `!EXCLAMATION-POINT` and `#HASH-MARK`. (default: ``True``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
        url (str, optional):
            The URL to download the dictionary from.
            (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b"``)
        url_symbols (str, optional):
            The URL to download the list of symbols from.
            (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols"``)

    Raises:
        RuntimeError: If ``root`` is not an existing directory, or if the
            dictionary or symbol file is missing from ``root`` while
            ``download`` is ``False``.
    """

    def __init__(
        self,
        root: Union[str, Path],
        exclude_punctuations: bool = True,
        *,
        download: bool = False,
        url: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b",
        url_symbols: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols",
    ) -> None:
        self.exclude_punctuations = exclude_punctuations

        self._root_path = Path(root)
        if not self._root_path.is_dir():
            raise RuntimeError(f"The root directory does not exist; {root}")

        # Local file names are the last path component of each download URL.
        dict_file = self._root_path / os.path.basename(url)
        symbol_file = self._root_path / os.path.basename(url_symbols)
        self._ensure_file(dict_file, url, download, "dictionary")
        self._ensure_file(symbol_file, url_symbols, download, "symbol")

        # Use an explicit encoding (matching the dictionary file) so loading does
        # not depend on the platform's locale encoding. Blank lines are skipped
        # so they do not turn into "" symbols.
        with open(symbol_file, "r", encoding="latin-1") as text:
            self._symbols = [line.strip() for line in text if line.strip()]

        with open(dict_file, "r", encoding="latin-1") as text:
            self._dictionary = _parse_dictionary(text, exclude_punctuations=self.exclude_punctuations)

    @staticmethod
    def _ensure_file(path: Path, url: str, download: bool, label: str) -> None:
        """Ensure ``path`` exists, fetching it from ``url`` when ``download`` is set.

        Args:
            path: Local destination of the file.
            url: Source URL. Its entry in ``_CHECKSUMS`` (``None`` if absent) is
                forwarded to ``download_url_to_file``.
            download: Whether a missing file may be downloaded.
            label: File kind used in the error message ("dictionary"/"symbol").

        Raises:
            RuntimeError: If ``path`` is missing and ``download`` is ``False``.
        """
        if path.exists():
            return
        if not download:
            raise RuntimeError(
                f"The {label} file is not found in the following location. "
                f"Set `download=True` to download it. {path}"
            )
        download_url_to_file(url, path, _CHECKSUMS.get(url, None))

    def __getitem__(self, n: int) -> Tuple[str, List[str]]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded.

        Returns:
            Tuple of a word and its phonemes

            str:
                Word
            List[str]:
                Phonemes
        """
        return self._dictionary[n]

    def __len__(self) -> int:
        """Return the number of ``(word, phonemes)`` entries loaded."""
        return len(self._dictionary)

    @property
    def symbols(self) -> List[str]:
        """list[str]: A list of phonemes symbols, such as ``"AA"``, ``"AE"``, ``"AH"``."""
        # Return a copy so callers cannot mutate the dataset's symbol list.
        return self._symbols.copy()
|