import os from pathlib import Path from typing import Optional, Tuple, Union import torchaudio from torch import Tensor from torch.utils.data import Dataset from torchaudio._internal import download_url_to_file from torchaudio.datasets.utils import _extract_tar # The following lists prefixed with `filtered_` provide a filtered split # that: # # a. Mitigate a known issue with GTZAN (duplication) # # b. Provide a standard split for testing it against other # methods (e.g. the one in jordipons/sklearn-audio-transfer-learning). # # Those are used when GTZAN is initialised with the `filtered` keyword. # The split was taken from (github) jordipons/sklearn-audio-transfer-learning. gtzan_genres = [ "blues", "classical", "country", "disco", "hiphop", "jazz", "metal", "pop", "reggae", "rock", ] filtered_test = [ "blues.00012", "blues.00013", "blues.00014", "blues.00015", "blues.00016", "blues.00017", "blues.00018", "blues.00019", "blues.00020", "blues.00021", "blues.00022", "blues.00023", "blues.00024", "blues.00025", "blues.00026", "blues.00027", "blues.00028", "blues.00061", "blues.00062", "blues.00063", "blues.00064", "blues.00065", "blues.00066", "blues.00067", "blues.00068", "blues.00069", "blues.00070", "blues.00071", "blues.00072", "blues.00098", "blues.00099", "classical.00011", "classical.00012", "classical.00013", "classical.00014", "classical.00015", "classical.00016", "classical.00017", "classical.00018", "classical.00019", "classical.00020", "classical.00021", "classical.00022", "classical.00023", "classical.00024", "classical.00025", "classical.00026", "classical.00027", "classical.00028", "classical.00029", "classical.00034", "classical.00035", "classical.00036", "classical.00037", "classical.00038", "classical.00039", "classical.00040", "classical.00041", "classical.00049", "classical.00077", "classical.00078", "classical.00079", "country.00030", "country.00031", "country.00032", "country.00033", "country.00034", "country.00035", "country.00036", "country.00037", "country.00038", "country.00039", "country.00040", "country.00043", "country.00044", "country.00046", "country.00047", "country.00048", "country.00050", "country.00051", "country.00053", "country.00054", "country.00055", "country.00056", "country.00057", "country.00058", "country.00059", "country.00060", "country.00061", "country.00062", "country.00063", "country.00064", "disco.00001", "disco.00021", "disco.00058", "disco.00062", "disco.00063", "disco.00064", "disco.00065", "disco.00066", "disco.00069", "disco.00076", "disco.00077", "disco.00078", "disco.00079", "disco.00080", "disco.00081", "disco.00082", "disco.00083", "disco.00084", "disco.00085", "disco.00086", "disco.00087", "disco.00088", "disco.00091", "disco.00092", "disco.00093", "disco.00094", "disco.00096", "disco.00097", "disco.00099", "hiphop.00000", "hiphop.00026", "hiphop.00027", "hiphop.00030", "hiphop.00040", "hiphop.00043", "hiphop.00044", "hiphop.00045", "hiphop.00051", "hiphop.00052", "hiphop.00053", "hiphop.00054", "hiphop.00062", "hiphop.00063", "hiphop.00064", "hiphop.00065", "hiphop.00066", "hiphop.00067", "hiphop.00068", "hiphop.00069", "hiphop.00070", "hiphop.00071", "hiphop.00072", "hiphop.00073", "hiphop.00074", "hiphop.00075", "hiphop.00099", "jazz.00073", "jazz.00074", "jazz.00075", "jazz.00076", "jazz.00077", "jazz.00078", "jazz.00079", "jazz.00080", "jazz.00081", "jazz.00082", "jazz.00083", "jazz.00084", "jazz.00085", "jazz.00086", "jazz.00087", "jazz.00088", "jazz.00089", "jazz.00090", "jazz.00091", "jazz.00092", "jazz.00093", "jazz.00094", "jazz.00095", "jazz.00096", "jazz.00097", "jazz.00098", "jazz.00099", "metal.00012", "metal.00013", "metal.00014", "metal.00015", "metal.00022", "metal.00023", "metal.00025", "metal.00026", "metal.00027", "metal.00028", "metal.00029", "metal.00030", "metal.00031", "metal.00032", "metal.00033", "metal.00038", "metal.00039", "metal.00067", "metal.00070", "metal.00073", "metal.00074", "metal.00075", "metal.00078", "metal.00083", "metal.00085", "metal.00087", "metal.00088", "pop.00000", "pop.00001", "pop.00013", "pop.00014", "pop.00043", "pop.00063", "pop.00064", "pop.00065", "pop.00066", "pop.00069", "pop.00070", "pop.00071", "pop.00072", "pop.00073", "pop.00074", "pop.00075", "pop.00076", "pop.00077", "pop.00078", "pop.00079", "pop.00082", "pop.00088", "pop.00089", "pop.00090", "pop.00091", "pop.00092", "pop.00093", "pop.00094", "pop.00095", "pop.00096", "reggae.00034", "reggae.00035", "reggae.00036", "reggae.00037", "reggae.00038", "reggae.00039", "reggae.00040", "reggae.00046", "reggae.00047", "reggae.00048", "reggae.00052", "reggae.00053", "reggae.00064", "reggae.00065", "reggae.00066", "reggae.00067", "reggae.00068", "reggae.00071", "reggae.00079", "reggae.00082", "reggae.00083", "reggae.00084", "reggae.00087", "reggae.00088", "reggae.00089", "reggae.00090", "rock.00010", "rock.00011", "rock.00012", "rock.00013", "rock.00014", "rock.00015", "rock.00027", "rock.00028", "rock.00029", "rock.00030", "rock.00031", "rock.00032", "rock.00033", "rock.00034", "rock.00035", "rock.00036", "rock.00037", "rock.00039", "rock.00040", "rock.00041", "rock.00042", "rock.00043", "rock.00044", "rock.00045", "rock.00046", "rock.00047", "rock.00048", "rock.00086", "rock.00087", "rock.00088", "rock.00089", "rock.00090", ] filtered_train = [ "blues.00029", "blues.00030", "blues.00031", "blues.00032", "blues.00033", "blues.00034", "blues.00035", "blues.00036", "blues.00037", "blues.00038", "blues.00039", "blues.00040", "blues.00041", "blues.00042", "blues.00043", "blues.00044", "blues.00045", "blues.00046", "blues.00047", "blues.00048", "blues.00049", "blues.00073", "blues.00074", "blues.00075", "blues.00076", "blues.00077", "blues.00078", "blues.00079", "blues.00080", "blues.00081", "blues.00082", "blues.00083", "blues.00084", "blues.00085", "blues.00086", "blues.00087", "blues.00088", "blues.00089", "blues.00090", "blues.00091", "blues.00092", "blues.00093", "blues.00094", "blues.00095", "blues.00096", "blues.00097", "classical.00030", "classical.00031", "classical.00032", "classical.00033", "classical.00043", "classical.00044", "classical.00045", "classical.00046", "classical.00047", "classical.00048", "classical.00050", "classical.00051", "classical.00052", "classical.00053", "classical.00054", "classical.00055", "classical.00056", "classical.00057", "classical.00058", "classical.00059", "classical.00060", "classical.00061", "classical.00062", "classical.00063", "classical.00064", "classical.00065", "classical.00066", "classical.00067", "classical.00080", "classical.00081", "classical.00082", "classical.00083", "classical.00084", "classical.00085", "classical.00086", "classical.00087", "classical.00088", "classical.00089", "classical.00090", "classical.00091", "classical.00092", "classical.00093", "classical.00094", "classical.00095", "classical.00096", "classical.00097", "classical.00098", "classical.00099", "country.00019", "country.00020", "country.00021", "country.00022", "country.00023", "country.00024", "country.00025", "country.00026", "country.00028", "country.00029", "country.00065", "country.00066", "country.00067", "country.00068", "country.00069", "country.00070", "country.00071", "country.00072", "country.00073", "country.00074", "country.00075", "country.00076", "country.00077", "country.00078", "country.00079", "country.00080", "country.00081", "country.00082", "country.00083", "country.00084", "country.00085", "country.00086", "country.00087", "country.00088", "country.00089", "country.00090", "country.00091", "country.00092", "country.00093", "country.00094", "country.00095", "country.00096", "country.00097", "country.00098", "country.00099", "disco.00005", "disco.00015", "disco.00016", "disco.00017", "disco.00018", "disco.00019", "disco.00020", "disco.00022", "disco.00023", "disco.00024", "disco.00025", "disco.00026", "disco.00027", "disco.00028", "disco.00029", "disco.00030", "disco.00031", "disco.00032", "disco.00033", "disco.00034", "disco.00035", "disco.00036", "disco.00037", "disco.00039", "disco.00040", "disco.00041", "disco.00042", "disco.00043", "disco.00044", "disco.00045", "disco.00047", "disco.00049", "disco.00053", "disco.00054", "disco.00056", "disco.00057", "disco.00059", "disco.00061", "disco.00070", "disco.00073", "disco.00074", "disco.00089", "hiphop.00002", "hiphop.00003", "hiphop.00004", "hiphop.00005", "hiphop.00006", "hiphop.00007", "hiphop.00008", "hiphop.00009", "hiphop.00010", "hiphop.00011", "hiphop.00012", "hiphop.00013", "hiphop.00014", "hiphop.00015", "hiphop.00016", "hiphop.00017", "hiphop.00018", "hiphop.00019", "hiphop.00020", "hiphop.00021", "hiphop.00022", "hiphop.00023", "hiphop.00024", "hiphop.00025", "hiphop.00028", "hiphop.00029", "hiphop.00031", "hiphop.00032", "hiphop.00033", "hiphop.00034", "hiphop.00035", "hiphop.00036", "hiphop.00037", "hiphop.00038", "hiphop.00041", "hiphop.00042", "hiphop.00055", "hiphop.00056", "hiphop.00057", "hiphop.00058", "hiphop.00059", "hiphop.00060", "hiphop.00061", "hiphop.00077", "hiphop.00078", "hiphop.00079", "hiphop.00080", "jazz.00000", "jazz.00001", "jazz.00011", "jazz.00012", "jazz.00013", "jazz.00014", "jazz.00015", "jazz.00016", "jazz.00017", "jazz.00018", "jazz.00019", "jazz.00020", "jazz.00021", "jazz.00022", "jazz.00023", "jazz.00024", "jazz.00041", "jazz.00047", "jazz.00048", "jazz.00049", "jazz.00050", "jazz.00051", "jazz.00052", "jazz.00053", "jazz.00054", "jazz.00055", "jazz.00056", "jazz.00057", "jazz.00058", "jazz.00059", "jazz.00060", "jazz.00061", "jazz.00062", "jazz.00063", "jazz.00064", "jazz.00065", "jazz.00066", "jazz.00067", "jazz.00068", "jazz.00069", "jazz.00070", "jazz.00071", "jazz.00072", "metal.00002", "metal.00003", "metal.00005", "metal.00021", "metal.00024", "metal.00035", "metal.00046", "metal.00047", "metal.00048", "metal.00049", "metal.00050", "metal.00051", "metal.00052", "metal.00053", "metal.00054", "metal.00055", "metal.00056", "metal.00057", "metal.00059", "metal.00060", "metal.00061", "metal.00062", "metal.00063", "metal.00064", "metal.00065", "metal.00066", "metal.00069", "metal.00071", "metal.00072", "metal.00079", "metal.00080", "metal.00084", "metal.00086", "metal.00089", "metal.00090", "metal.00091", "metal.00092", "metal.00093", "metal.00094", "metal.00095", "metal.00096", "metal.00097", "metal.00098", "metal.00099", "pop.00002", "pop.00003", "pop.00004", "pop.00005", "pop.00006", "pop.00007", "pop.00008", "pop.00009", "pop.00011", "pop.00012", "pop.00016", "pop.00017", "pop.00018", "pop.00019", "pop.00020", "pop.00023", "pop.00024", "pop.00025", "pop.00026", "pop.00027", "pop.00028", "pop.00029", "pop.00031", "pop.00032", "pop.00033", "pop.00034", "pop.00035", "pop.00036", "pop.00038", "pop.00039", "pop.00040", "pop.00041", "pop.00042", "pop.00044", "pop.00046", "pop.00049", "pop.00050", "pop.00080", "pop.00097", "pop.00098", "pop.00099", "reggae.00000", "reggae.00001", "reggae.00002", "reggae.00004", "reggae.00006", "reggae.00009", "reggae.00011", "reggae.00012", "reggae.00014", "reggae.00015", "reggae.00016", "reggae.00017", "reggae.00018", "reggae.00019", "reggae.00020", "reggae.00021", "reggae.00022", "reggae.00023", "reggae.00024", "reggae.00025", "reggae.00026", "reggae.00027", "reggae.00028", "reggae.00029", "reggae.00030", "reggae.00031", "reggae.00032", "reggae.00042", "reggae.00043", "reggae.00044", "reggae.00045", "reggae.00049", "reggae.00050", "reggae.00051", "reggae.00054", "reggae.00055", "reggae.00056", "reggae.00057", "reggae.00058", "reggae.00059", "reggae.00060", "reggae.00063", "reggae.00069", "rock.00000", "rock.00001", "rock.00002", "rock.00003", "rock.00004", "rock.00005", "rock.00006", "rock.00007", "rock.00008", "rock.00009", "rock.00016", "rock.00017", "rock.00018", "rock.00019", "rock.00020", "rock.00021", "rock.00022", "rock.00023", "rock.00024", "rock.00025", "rock.00026", "rock.00057", "rock.00058", "rock.00059", "rock.00060", "rock.00061", "rock.00062", "rock.00063", "rock.00064", "rock.00065", "rock.00066", "rock.00067", "rock.00068", "rock.00069", "rock.00070", "rock.00091", "rock.00092", "rock.00093", "rock.00094", "rock.00095", "rock.00096", "rock.00097", "rock.00098", "rock.00099", ] filtered_valid = [ "blues.00000", "blues.00001", "blues.00002", "blues.00003", "blues.00004", "blues.00005", "blues.00006", "blues.00007", "blues.00008", "blues.00009", "blues.00010", "blues.00011", "blues.00050", "blues.00051", "blues.00052", "blues.00053", "blues.00054", "blues.00055", "blues.00056", "blues.00057", "blues.00058", "blues.00059", "blues.00060", "classical.00000", "classical.00001", "classical.00002", "classical.00003", "classical.00004", "classical.00005", "classical.00006", "classical.00007", "classical.00008", "classical.00009", "classical.00010", "classical.00068", "classical.00069", "classical.00070", "classical.00071", "classical.00072", "classical.00073", "classical.00074", "classical.00075", "classical.00076", "country.00000", "country.00001", "country.00002", "country.00003", "country.00004", "country.00005", "country.00006", "country.00007", "country.00009", "country.00010", "country.00011", "country.00012", "country.00013", "country.00014", "country.00015", "country.00016", "country.00017", "country.00018", "country.00027", "country.00041", "country.00042", "country.00045", "country.00049", "disco.00000", "disco.00002", "disco.00003", "disco.00004", "disco.00006", "disco.00007", "disco.00008", "disco.00009", "disco.00010", "disco.00011", "disco.00012", "disco.00013", "disco.00014", "disco.00046", "disco.00048", "disco.00052", "disco.00067", "disco.00068", "disco.00072", "disco.00075", "disco.00090", "disco.00095", "hiphop.00081", "hiphop.00082", "hiphop.00083", "hiphop.00084", "hiphop.00085", "hiphop.00086", "hiphop.00087", "hiphop.00088", "hiphop.00089", "hiphop.00090", "hiphop.00091", "hiphop.00092", "hiphop.00093", "hiphop.00094", "hiphop.00095", "hiphop.00096", "hiphop.00097", "hiphop.00098", "jazz.00002", "jazz.00003", "jazz.00004", "jazz.00005", "jazz.00006", "jazz.00007", "jazz.00008", "jazz.00009", "jazz.00010", "jazz.00025", "jazz.00026", "jazz.00027", "jazz.00028", "jazz.00029", "jazz.00030", "jazz.00031", "jazz.00032", "metal.00000", "metal.00001", "metal.00006", "metal.00007", "metal.00008", "metal.00009", "metal.00010", "metal.00011", "metal.00016", "metal.00017", "metal.00018", "metal.00019", "metal.00020", "metal.00036", "metal.00037", "metal.00068", "metal.00076", "metal.00077", "metal.00081", "metal.00082", "pop.00010", "pop.00053", "pop.00055", "pop.00058", "pop.00059", "pop.00060", "pop.00061", "pop.00062", "pop.00081", "pop.00083", "pop.00084", "pop.00085", "pop.00086", "reggae.00061", "reggae.00062", "reggae.00070", "reggae.00072", "reggae.00074", "reggae.00076", "reggae.00077", "reggae.00078", "reggae.00085", "reggae.00092", "reggae.00093", "reggae.00094", "reggae.00095", "reggae.00096", "reggae.00097", "reggae.00098", "reggae.00099", "rock.00038", "rock.00049", "rock.00050", "rock.00051", "rock.00052", "rock.00053", "rock.00054", "rock.00055", "rock.00056", "rock.00071", "rock.00072", "rock.00073", "rock.00074", "rock.00075", "rock.00076", "rock.00077", "rock.00078", "rock.00079", "rock.00080", "rock.00081", "rock.00082", "rock.00083", "rock.00084", "rock.00085", ] URL = "http://opihi.cs.uvic.ca/sound/genres.tar.gz" FOLDER_IN_ARCHIVE = "genres" _CHECKSUMS = { "http://opihi.cs.uvic.ca/sound/genres.tar.gz": "24347e0223d2ba798e0a558c4c172d9d4a19c00bb7963fe055d183dadb4ef2c6" } def load_gtzan_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, str]: """ Loads a file from the dataset and returns the raw waveform as a Torch Tensor, its sample rate as an integer, and its genre as a string. """ # Filenames are of the form label.id, e.g. blues.00078 label, _ = fileid.split(".") # Read wav file_audio = os.path.join(path, label, fileid + ext_audio) waveform, sample_rate = torchaudio.load(file_audio) return waveform, sample_rate, label class GTZAN(Dataset): """*GTZAN* :cite:`tzanetakis_essl_cook_2001` dataset. Note: Please see http://marsyas.info/downloads/datasets.html if you are planning to use this dataset to publish results. Note: As of October 2022, the download link is not currently working. Setting ``download=True`` in GTZAN dataset will result in a URL connection error. Args: root (str or Path): Path to the directory where the dataset is found or downloaded. url (str, optional): The URL to download the dataset from. (default: ``"http://opihi.cs.uvic.ca/sound/genres.tar.gz"``) folder_in_archive (str, optional): The top-level directory of the dataset. download (bool, optional): Whether to download the dataset if it is not found at root path. (default: ``False``). subset (str or None, optional): Which subset of the dataset to use. One of ``"training"``, ``"validation"``, ``"testing"`` or ``None``. If ``None``, the entire dataset is used. (default: ``None``). """ _ext_audio = ".wav" def __init__( self, root: Union[str, Path], url: str = URL, folder_in_archive: str = FOLDER_IN_ARCHIVE, download: bool = False, subset: Optional[str] = None, ) -> None: # super(GTZAN, self).__init__() # Get string representation of 'root' in case Path object is passed root = os.fspath(root) self.root = root self.url = url self.folder_in_archive = folder_in_archive self.download = download self.subset = subset if subset is not None and subset not in ["training", "validation", "testing"]: raise ValueError("When `subset` is not None, it must be one of ['training', 'validation', 'testing'].") archive = os.path.basename(url) archive = os.path.join(root, archive) self._path = os.path.join(root, folder_in_archive) if download: if not os.path.isdir(self._path): if not os.path.isfile(archive): checksum = _CHECKSUMS.get(url, None) download_url_to_file(url, archive, hash_prefix=checksum) _extract_tar(archive) if not os.path.isdir(self._path): raise RuntimeError("Dataset not found. Please use `download=True` to download it.") if self.subset is None: # Check every subdirectory under dataset root # which has the same name as the genres in # GTZAN (e.g. `root_dir'/blues/, `root_dir'/rock, etc.) # This lets users remove or move around song files, # useful when e.g. they want to use only some of the files # in a genre or want to label other files with a different # genre. self._walker = [] root = os.path.expanduser(self._path) for directory in gtzan_genres: fulldir = os.path.join(root, directory) if not os.path.exists(fulldir): continue songs_in_genre = os.listdir(fulldir) songs_in_genre.sort() for fname in songs_in_genre: name, ext = os.path.splitext(fname) if ext.lower() == ".wav" and "." in name: # Check whether the file is of the form # `gtzan_genre`.`5 digit number`.wav genre, num = name.split(".") if genre in gtzan_genres and len(num) == 5 and num.isdigit(): self._walker.append(name) else: if self.subset == "training": self._walker = filtered_train elif self.subset == "validation": self._walker = filtered_valid elif self.subset == "testing": self._walker = filtered_test def __getitem__(self, n: int) -> Tuple[Tensor, int, str]: """Load the n-th sample from the dataset. Args: n (int): The index of the sample to be loaded Returns: Tuple of the following items; Tensor: Waveform int: Sample rate str: Label """ fileid = self._walker[n] item = load_gtzan_item(fileid, self._path, self._ext_audio) waveform, sample_rate, label = item return waveform, sample_rate, label def __len__(self) -> int: return len(self._walker)