import os
import re
from pathlib import Path
from typing import Iterable, List, Tuple, Union

from torch.utils.data import Dataset
from torchaudio._internal import download_url_to_file


_CHECKSUMS = {
    "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b": "209a8b4cd265013e96f4658632a9878103b0c5abf62b50d4ef3ae1be226b29e4",  # noqa: E501
    "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols": "408ccaae803641c6d7b626b6299949320c2dbca96b2220fd3fb17887b023b027",  # noqa: E501
}
# Dictionary entries whose "word" is a spelled-out punctuation mark; the first
# character(s) of each entry are the punctuation symbol itself.
_PUNCTUATIONS = {
    "!EXCLAMATION-POINT",
    '"CLOSE-QUOTE',
    '"DOUBLE-QUOTE',
    '"END-OF-QUOTE',
    '"END-QUOTE',
    '"IN-QUOTES',
    '"QUOTE',
    '"UNQUOTE',
    "#HASH-MARK",
    "#POUND-SIGN",
    "#SHARP-SIGN",
    "%PERCENT",
    "&AMPERSAND",
    "'END-INNER-QUOTE",
    "'END-QUOTE",
    "'INNER-QUOTE",
    "'QUOTE",
    "'SINGLE-QUOTE",
    "(BEGIN-PARENS",
    "(IN-PARENTHESES",
    "(LEFT-PAREN",
    "(OPEN-PARENTHESES",
    "(PAREN",
    "(PARENS",
    "(PARENTHESES",
    ")CLOSE-PAREN",
    ")CLOSE-PARENTHESES",
    ")END-PAREN",
    ")END-PARENS",
    ")END-PARENTHESES",
    ")END-THE-PAREN",
    ")PAREN",
    ")PARENS",
    ")RIGHT-PAREN",
    ")UN-PARENTHESES",
    "+PLUS",
    ",COMMA",
    "--DASH",
    "-DASH",
    "-HYPHEN",
    "...ELLIPSIS",
    ".DECIMAL",
    ".DOT",
    ".FULL-STOP",
    ".PERIOD",
    ".POINT",
    "/SLASH",
    ":COLON",
    ";SEMI-COLON",
    ";SEMI-COLON(1)",
    "?QUESTION-MARK",
    "{BRACE",
    "{LEFT-BRACE",
    "{OPEN-BRACE",
    "}CLOSE-BRACE",
    "}RIGHT-BRACE",
}

# Matches the "(N)" marker appended to alternate pronunciations, e.g. DATAPOINTS(1).
_ALT_PRONUNCIATION_RE = re.compile(r"\([0-9]+\)")


def _parse_dictionary(lines: Iterable[str], exclude_punctuations: bool) -> List[Tuple[str, List[str]]]:
    """Parse CMUDict-formatted lines into ``(word, phonemes)`` pairs.

    Args:
        lines (Iterable[str]): Lines of the dictionary file. Each entry line holds a word
            followed by its whitespace-separated phonemes; lines starting with ``;;;``
            are comments and blank lines are ignored.
        exclude_punctuations (bool): If ``True``, entries listed in ``_PUNCTUATIONS`` are
            dropped. Otherwise they are kept, with the word replaced by the punctuation
            symbol it spells out (e.g. ``!EXCLAMATION-POINT`` -> ``!``).

    Returns:
        List[Tuple[str, List[str]]]: ``(word, phonemes)`` tuples in file order. Alternate
        pronunciations (``WORD(1)``, ``WORD(2)``, ...) are returned under the base word.

    Raises:
        ValueError: If a non-comment line contains a word but no phonemes.
    """
    cmudict: List[Tuple[str, List[str]]] = []
    for raw_line in lines:
        # Strip first so lines consisting only of a newline are treated as blank.
        line = raw_line.strip()
        if not line or line.startswith(";;;"):  # ignore blanks and comments
            continue

        # The word is separated from its phonemes by a run of whitespace.
        word, phones = line.split(maxsplit=1)
        if word in _PUNCTUATIONS:
            if exclude_punctuations:
                continue
            # !EXCLAMATION-POINT -> !
            # --DASH -> --
            # ...ELLIPSIS -> ...
            if word.startswith("..."):
                word = "..."
            elif word.startswith("--"):
                word = "--"
            else:
                word = word[0]

        # If a word has multiple pronunciations, a "(number)" suffix is appended to it,
        # e.g. DATAPOINTS and DATAPOINTS(1); removing it maps DATAPOINTS(1) back to DATAPOINTS.
        word = _ALT_PRONUNCIATION_RE.sub("", word)
        cmudict.append((word, phones.split()))
    return cmudict


def _download_if_missing(path: Path, url: str, description: str, download: bool) -> None:
    """Make sure ``path`` exists, downloading it from ``url`` when allowed.

    Args:
        path (Path): Local destination of the file.
        url (str): Source URL; its checksum is looked up in ``_CHECKSUMS`` (``None`` if absent).
        description (str): Human-readable file kind used in the error message.
        download (bool): Whether a missing file may be downloaded.

    Raises:
        RuntimeError: If ``path`` does not exist and ``download`` is ``False``.
    """
    if path.exists():
        return
    if not download:
        raise RuntimeError(
            f"The {description} file is not found in the following location. "
            f"Set `download=True` to download it. {path}"
        )
    download_url_to_file(url, path, _CHECKSUMS.get(url))


class CMUDict(Dataset):
    """*CMU Pronouncing Dictionary* :cite:`cmudict` (CMUDict) dataset.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        exclude_punctuations (bool, optional):
            When enabled, exclude the pronunciation of punctuations, such as
            `!EXCLAMATION-POINT` and `#HASH-MARK`. (default: ``True``).
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
        url (str, optional):
            The URL to download the dictionary from.
            (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b"``)
        url_symbols (str, optional):
            The URL to download the list of symbols from.
            (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols"``)
    """

    def __init__(
        self,
        root: Union[str, Path],
        exclude_punctuations: bool = True,
        *,
        download: bool = False,
        url: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b",
        url_symbols: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols",
    ) -> None:
        self.exclude_punctuations = exclude_punctuations

        self._root_path = Path(root)
        if not self._root_path.is_dir():
            raise RuntimeError(f"The root directory does not exist; {root}")

        # Local file names mirror the last path component of each URL.
        dict_file = self._root_path / os.path.basename(url)
        symbol_file = self._root_path / os.path.basename(url_symbols)
        _download_if_missing(dict_file, url, "dictionary", download)
        _download_if_missing(symbol_file, url_symbols, "symbol", download)

        # Explicit encoding so decoding does not depend on the platform's locale default.
        with open(symbol_file, "r", encoding="latin-1") as text:
            self._symbols = [line.strip() for line in text]

        with open(dict_file, "r", encoding="latin-1") as text:
            self._dictionary = _parse_dictionary(text, exclude_punctuations=self.exclude_punctuations)

    def __getitem__(self, n: int) -> Tuple[str, List[str]]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded.

        Returns:
            Tuple of a word and its phonemes

            str:
                Word
            List[str]:
                Phonemes
        """
        return self._dictionary[n]

    def __len__(self) -> int:
        return len(self._dictionary)

    @property
    def symbols(self) -> List[str]:
        """list[str]: A list of phonemes symbols, such as ``"AA"``, ``"AE"``, ``"AH"``."""
        # Return a copy so callers cannot mutate the dataset's internal list.
        return self._symbols.copy()