1119 lines
24 KiB
Python
1119 lines
24 KiB
Python
|
import os
|
||
|
from pathlib import Path
|
||
|
from typing import Optional, Tuple, Union
|
||
|
|
||
|
import torchaudio
|
||
|
from torch import Tensor
|
||
|
from torch.utils.data import Dataset
|
||
|
from torchaudio._internal import download_url_to_file
|
||
|
from torchaudio.datasets.utils import _extract_tar
|
||
|
|
||
|
# The following lists prefixed with `filtered_` provide a filtered split
|
||
|
# that:
|
||
|
#
|
||
|
# a. Mitigate a known issue with GTZAN (duplication)
|
||
|
#
|
||
|
# b. Provide a standard split for testing it against other
|
||
|
# methods (e.g. the one in jordipons/sklearn-audio-transfer-learning).
|
||
|
#
|
||
|
# Those are used when GTZAN is initialised with the `filtered` keyword.
|
||
|
# The split was taken from (github) jordipons/sklearn-audio-transfer-learning.
|
||
|
|
||
|
gtzan_genres = [
|
||
|
"blues",
|
||
|
"classical",
|
||
|
"country",
|
||
|
"disco",
|
||
|
"hiphop",
|
||
|
"jazz",
|
||
|
"metal",
|
||
|
"pop",
|
||
|
"reggae",
|
||
|
"rock",
|
||
|
]
|
||
|
|
||
|
filtered_test = [
|
||
|
"blues.00012",
|
||
|
"blues.00013",
|
||
|
"blues.00014",
|
||
|
"blues.00015",
|
||
|
"blues.00016",
|
||
|
"blues.00017",
|
||
|
"blues.00018",
|
||
|
"blues.00019",
|
||
|
"blues.00020",
|
||
|
"blues.00021",
|
||
|
"blues.00022",
|
||
|
"blues.00023",
|
||
|
"blues.00024",
|
||
|
"blues.00025",
|
||
|
"blues.00026",
|
||
|
"blues.00027",
|
||
|
"blues.00028",
|
||
|
"blues.00061",
|
||
|
"blues.00062",
|
||
|
"blues.00063",
|
||
|
"blues.00064",
|
||
|
"blues.00065",
|
||
|
"blues.00066",
|
||
|
"blues.00067",
|
||
|
"blues.00068",
|
||
|
"blues.00069",
|
||
|
"blues.00070",
|
||
|
"blues.00071",
|
||
|
"blues.00072",
|
||
|
"blues.00098",
|
||
|
"blues.00099",
|
||
|
"classical.00011",
|
||
|
"classical.00012",
|
||
|
"classical.00013",
|
||
|
"classical.00014",
|
||
|
"classical.00015",
|
||
|
"classical.00016",
|
||
|
"classical.00017",
|
||
|
"classical.00018",
|
||
|
"classical.00019",
|
||
|
"classical.00020",
|
||
|
"classical.00021",
|
||
|
"classical.00022",
|
||
|
"classical.00023",
|
||
|
"classical.00024",
|
||
|
"classical.00025",
|
||
|
"classical.00026",
|
||
|
"classical.00027",
|
||
|
"classical.00028",
|
||
|
"classical.00029",
|
||
|
"classical.00034",
|
||
|
"classical.00035",
|
||
|
"classical.00036",
|
||
|
"classical.00037",
|
||
|
"classical.00038",
|
||
|
"classical.00039",
|
||
|
"classical.00040",
|
||
|
"classical.00041",
|
||
|
"classical.00049",
|
||
|
"classical.00077",
|
||
|
"classical.00078",
|
||
|
"classical.00079",
|
||
|
"country.00030",
|
||
|
"country.00031",
|
||
|
"country.00032",
|
||
|
"country.00033",
|
||
|
"country.00034",
|
||
|
"country.00035",
|
||
|
"country.00036",
|
||
|
"country.00037",
|
||
|
"country.00038",
|
||
|
"country.00039",
|
||
|
"country.00040",
|
||
|
"country.00043",
|
||
|
"country.00044",
|
||
|
"country.00046",
|
||
|
"country.00047",
|
||
|
"country.00048",
|
||
|
"country.00050",
|
||
|
"country.00051",
|
||
|
"country.00053",
|
||
|
"country.00054",
|
||
|
"country.00055",
|
||
|
"country.00056",
|
||
|
"country.00057",
|
||
|
"country.00058",
|
||
|
"country.00059",
|
||
|
"country.00060",
|
||
|
"country.00061",
|
||
|
"country.00062",
|
||
|
"country.00063",
|
||
|
"country.00064",
|
||
|
"disco.00001",
|
||
|
"disco.00021",
|
||
|
"disco.00058",
|
||
|
"disco.00062",
|
||
|
"disco.00063",
|
||
|
"disco.00064",
|
||
|
"disco.00065",
|
||
|
"disco.00066",
|
||
|
"disco.00069",
|
||
|
"disco.00076",
|
||
|
"disco.00077",
|
||
|
"disco.00078",
|
||
|
"disco.00079",
|
||
|
"disco.00080",
|
||
|
"disco.00081",
|
||
|
"disco.00082",
|
||
|
"disco.00083",
|
||
|
"disco.00084",
|
||
|
"disco.00085",
|
||
|
"disco.00086",
|
||
|
"disco.00087",
|
||
|
"disco.00088",
|
||
|
"disco.00091",
|
||
|
"disco.00092",
|
||
|
"disco.00093",
|
||
|
"disco.00094",
|
||
|
"disco.00096",
|
||
|
"disco.00097",
|
||
|
"disco.00099",
|
||
|
"hiphop.00000",
|
||
|
"hiphop.00026",
|
||
|
"hiphop.00027",
|
||
|
"hiphop.00030",
|
||
|
"hiphop.00040",
|
||
|
"hiphop.00043",
|
||
|
"hiphop.00044",
|
||
|
"hiphop.00045",
|
||
|
"hiphop.00051",
|
||
|
"hiphop.00052",
|
||
|
"hiphop.00053",
|
||
|
"hiphop.00054",
|
||
|
"hiphop.00062",
|
||
|
"hiphop.00063",
|
||
|
"hiphop.00064",
|
||
|
"hiphop.00065",
|
||
|
"hiphop.00066",
|
||
|
"hiphop.00067",
|
||
|
"hiphop.00068",
|
||
|
"hiphop.00069",
|
||
|
"hiphop.00070",
|
||
|
"hiphop.00071",
|
||
|
"hiphop.00072",
|
||
|
"hiphop.00073",
|
||
|
"hiphop.00074",
|
||
|
"hiphop.00075",
|
||
|
"hiphop.00099",
|
||
|
"jazz.00073",
|
||
|
"jazz.00074",
|
||
|
"jazz.00075",
|
||
|
"jazz.00076",
|
||
|
"jazz.00077",
|
||
|
"jazz.00078",
|
||
|
"jazz.00079",
|
||
|
"jazz.00080",
|
||
|
"jazz.00081",
|
||
|
"jazz.00082",
|
||
|
"jazz.00083",
|
||
|
"jazz.00084",
|
||
|
"jazz.00085",
|
||
|
"jazz.00086",
|
||
|
"jazz.00087",
|
||
|
"jazz.00088",
|
||
|
"jazz.00089",
|
||
|
"jazz.00090",
|
||
|
"jazz.00091",
|
||
|
"jazz.00092",
|
||
|
"jazz.00093",
|
||
|
"jazz.00094",
|
||
|
"jazz.00095",
|
||
|
"jazz.00096",
|
||
|
"jazz.00097",
|
||
|
"jazz.00098",
|
||
|
"jazz.00099",
|
||
|
"metal.00012",
|
||
|
"metal.00013",
|
||
|
"metal.00014",
|
||
|
"metal.00015",
|
||
|
"metal.00022",
|
||
|
"metal.00023",
|
||
|
"metal.00025",
|
||
|
"metal.00026",
|
||
|
"metal.00027",
|
||
|
"metal.00028",
|
||
|
"metal.00029",
|
||
|
"metal.00030",
|
||
|
"metal.00031",
|
||
|
"metal.00032",
|
||
|
"metal.00033",
|
||
|
"metal.00038",
|
||
|
"metal.00039",
|
||
|
"metal.00067",
|
||
|
"metal.00070",
|
||
|
"metal.00073",
|
||
|
"metal.00074",
|
||
|
"metal.00075",
|
||
|
"metal.00078",
|
||
|
"metal.00083",
|
||
|
"metal.00085",
|
||
|
"metal.00087",
|
||
|
"metal.00088",
|
||
|
"pop.00000",
|
||
|
"pop.00001",
|
||
|
"pop.00013",
|
||
|
"pop.00014",
|
||
|
"pop.00043",
|
||
|
"pop.00063",
|
||
|
"pop.00064",
|
||
|
"pop.00065",
|
||
|
"pop.00066",
|
||
|
"pop.00069",
|
||
|
"pop.00070",
|
||
|
"pop.00071",
|
||
|
"pop.00072",
|
||
|
"pop.00073",
|
||
|
"pop.00074",
|
||
|
"pop.00075",
|
||
|
"pop.00076",
|
||
|
"pop.00077",
|
||
|
"pop.00078",
|
||
|
"pop.00079",
|
||
|
"pop.00082",
|
||
|
"pop.00088",
|
||
|
"pop.00089",
|
||
|
"pop.00090",
|
||
|
"pop.00091",
|
||
|
"pop.00092",
|
||
|
"pop.00093",
|
||
|
"pop.00094",
|
||
|
"pop.00095",
|
||
|
"pop.00096",
|
||
|
"reggae.00034",
|
||
|
"reggae.00035",
|
||
|
"reggae.00036",
|
||
|
"reggae.00037",
|
||
|
"reggae.00038",
|
||
|
"reggae.00039",
|
||
|
"reggae.00040",
|
||
|
"reggae.00046",
|
||
|
"reggae.00047",
|
||
|
"reggae.00048",
|
||
|
"reggae.00052",
|
||
|
"reggae.00053",
|
||
|
"reggae.00064",
|
||
|
"reggae.00065",
|
||
|
"reggae.00066",
|
||
|
"reggae.00067",
|
||
|
"reggae.00068",
|
||
|
"reggae.00071",
|
||
|
"reggae.00079",
|
||
|
"reggae.00082",
|
||
|
"reggae.00083",
|
||
|
"reggae.00084",
|
||
|
"reggae.00087",
|
||
|
"reggae.00088",
|
||
|
"reggae.00089",
|
||
|
"reggae.00090",
|
||
|
"rock.00010",
|
||
|
"rock.00011",
|
||
|
"rock.00012",
|
||
|
"rock.00013",
|
||
|
"rock.00014",
|
||
|
"rock.00015",
|
||
|
"rock.00027",
|
||
|
"rock.00028",
|
||
|
"rock.00029",
|
||
|
"rock.00030",
|
||
|
"rock.00031",
|
||
|
"rock.00032",
|
||
|
"rock.00033",
|
||
|
"rock.00034",
|
||
|
"rock.00035",
|
||
|
"rock.00036",
|
||
|
"rock.00037",
|
||
|
"rock.00039",
|
||
|
"rock.00040",
|
||
|
"rock.00041",
|
||
|
"rock.00042",
|
||
|
"rock.00043",
|
||
|
"rock.00044",
|
||
|
"rock.00045",
|
||
|
"rock.00046",
|
||
|
"rock.00047",
|
||
|
"rock.00048",
|
||
|
"rock.00086",
|
||
|
"rock.00087",
|
||
|
"rock.00088",
|
||
|
"rock.00089",
|
||
|
"rock.00090",
|
||
|
]
|
||
|
|
||
|
filtered_train = [
|
||
|
"blues.00029",
|
||
|
"blues.00030",
|
||
|
"blues.00031",
|
||
|
"blues.00032",
|
||
|
"blues.00033",
|
||
|
"blues.00034",
|
||
|
"blues.00035",
|
||
|
"blues.00036",
|
||
|
"blues.00037",
|
||
|
"blues.00038",
|
||
|
"blues.00039",
|
||
|
"blues.00040",
|
||
|
"blues.00041",
|
||
|
"blues.00042",
|
||
|
"blues.00043",
|
||
|
"blues.00044",
|
||
|
"blues.00045",
|
||
|
"blues.00046",
|
||
|
"blues.00047",
|
||
|
"blues.00048",
|
||
|
"blues.00049",
|
||
|
"blues.00073",
|
||
|
"blues.00074",
|
||
|
"blues.00075",
|
||
|
"blues.00076",
|
||
|
"blues.00077",
|
||
|
"blues.00078",
|
||
|
"blues.00079",
|
||
|
"blues.00080",
|
||
|
"blues.00081",
|
||
|
"blues.00082",
|
||
|
"blues.00083",
|
||
|
"blues.00084",
|
||
|
"blues.00085",
|
||
|
"blues.00086",
|
||
|
"blues.00087",
|
||
|
"blues.00088",
|
||
|
"blues.00089",
|
||
|
"blues.00090",
|
||
|
"blues.00091",
|
||
|
"blues.00092",
|
||
|
"blues.00093",
|
||
|
"blues.00094",
|
||
|
"blues.00095",
|
||
|
"blues.00096",
|
||
|
"blues.00097",
|
||
|
"classical.00030",
|
||
|
"classical.00031",
|
||
|
"classical.00032",
|
||
|
"classical.00033",
|
||
|
"classical.00043",
|
||
|
"classical.00044",
|
||
|
"classical.00045",
|
||
|
"classical.00046",
|
||
|
"classical.00047",
|
||
|
"classical.00048",
|
||
|
"classical.00050",
|
||
|
"classical.00051",
|
||
|
"classical.00052",
|
||
|
"classical.00053",
|
||
|
"classical.00054",
|
||
|
"classical.00055",
|
||
|
"classical.00056",
|
||
|
"classical.00057",
|
||
|
"classical.00058",
|
||
|
"classical.00059",
|
||
|
"classical.00060",
|
||
|
"classical.00061",
|
||
|
"classical.00062",
|
||
|
"classical.00063",
|
||
|
"classical.00064",
|
||
|
"classical.00065",
|
||
|
"classical.00066",
|
||
|
"classical.00067",
|
||
|
"classical.00080",
|
||
|
"classical.00081",
|
||
|
"classical.00082",
|
||
|
"classical.00083",
|
||
|
"classical.00084",
|
||
|
"classical.00085",
|
||
|
"classical.00086",
|
||
|
"classical.00087",
|
||
|
"classical.00088",
|
||
|
"classical.00089",
|
||
|
"classical.00090",
|
||
|
"classical.00091",
|
||
|
"classical.00092",
|
||
|
"classical.00093",
|
||
|
"classical.00094",
|
||
|
"classical.00095",
|
||
|
"classical.00096",
|
||
|
"classical.00097",
|
||
|
"classical.00098",
|
||
|
"classical.00099",
|
||
|
"country.00019",
|
||
|
"country.00020",
|
||
|
"country.00021",
|
||
|
"country.00022",
|
||
|
"country.00023",
|
||
|
"country.00024",
|
||
|
"country.00025",
|
||
|
"country.00026",
|
||
|
"country.00028",
|
||
|
"country.00029",
|
||
|
"country.00065",
|
||
|
"country.00066",
|
||
|
"country.00067",
|
||
|
"country.00068",
|
||
|
"country.00069",
|
||
|
"country.00070",
|
||
|
"country.00071",
|
||
|
"country.00072",
|
||
|
"country.00073",
|
||
|
"country.00074",
|
||
|
"country.00075",
|
||
|
"country.00076",
|
||
|
"country.00077",
|
||
|
"country.00078",
|
||
|
"country.00079",
|
||
|
"country.00080",
|
||
|
"country.00081",
|
||
|
"country.00082",
|
||
|
"country.00083",
|
||
|
"country.00084",
|
||
|
"country.00085",
|
||
|
"country.00086",
|
||
|
"country.00087",
|
||
|
"country.00088",
|
||
|
"country.00089",
|
||
|
"country.00090",
|
||
|
"country.00091",
|
||
|
"country.00092",
|
||
|
"country.00093",
|
||
|
"country.00094",
|
||
|
"country.00095",
|
||
|
"country.00096",
|
||
|
"country.00097",
|
||
|
"country.00098",
|
||
|
"country.00099",
|
||
|
"disco.00005",
|
||
|
"disco.00015",
|
||
|
"disco.00016",
|
||
|
"disco.00017",
|
||
|
"disco.00018",
|
||
|
"disco.00019",
|
||
|
"disco.00020",
|
||
|
"disco.00022",
|
||
|
"disco.00023",
|
||
|
"disco.00024",
|
||
|
"disco.00025",
|
||
|
"disco.00026",
|
||
|
"disco.00027",
|
||
|
"disco.00028",
|
||
|
"disco.00029",
|
||
|
"disco.00030",
|
||
|
"disco.00031",
|
||
|
"disco.00032",
|
||
|
"disco.00033",
|
||
|
"disco.00034",
|
||
|
"disco.00035",
|
||
|
"disco.00036",
|
||
|
"disco.00037",
|
||
|
"disco.00039",
|
||
|
"disco.00040",
|
||
|
"disco.00041",
|
||
|
"disco.00042",
|
||
|
"disco.00043",
|
||
|
"disco.00044",
|
||
|
"disco.00045",
|
||
|
"disco.00047",
|
||
|
"disco.00049",
|
||
|
"disco.00053",
|
||
|
"disco.00054",
|
||
|
"disco.00056",
|
||
|
"disco.00057",
|
||
|
"disco.00059",
|
||
|
"disco.00061",
|
||
|
"disco.00070",
|
||
|
"disco.00073",
|
||
|
"disco.00074",
|
||
|
"disco.00089",
|
||
|
"hiphop.00002",
|
||
|
"hiphop.00003",
|
||
|
"hiphop.00004",
|
||
|
"hiphop.00005",
|
||
|
"hiphop.00006",
|
||
|
"hiphop.00007",
|
||
|
"hiphop.00008",
|
||
|
"hiphop.00009",
|
||
|
"hiphop.00010",
|
||
|
"hiphop.00011",
|
||
|
"hiphop.00012",
|
||
|
"hiphop.00013",
|
||
|
"hiphop.00014",
|
||
|
"hiphop.00015",
|
||
|
"hiphop.00016",
|
||
|
"hiphop.00017",
|
||
|
"hiphop.00018",
|
||
|
"hiphop.00019",
|
||
|
"hiphop.00020",
|
||
|
"hiphop.00021",
|
||
|
"hiphop.00022",
|
||
|
"hiphop.00023",
|
||
|
"hiphop.00024",
|
||
|
"hiphop.00025",
|
||
|
"hiphop.00028",
|
||
|
"hiphop.00029",
|
||
|
"hiphop.00031",
|
||
|
"hiphop.00032",
|
||
|
"hiphop.00033",
|
||
|
"hiphop.00034",
|
||
|
"hiphop.00035",
|
||
|
"hiphop.00036",
|
||
|
"hiphop.00037",
|
||
|
"hiphop.00038",
|
||
|
"hiphop.00041",
|
||
|
"hiphop.00042",
|
||
|
"hiphop.00055",
|
||
|
"hiphop.00056",
|
||
|
"hiphop.00057",
|
||
|
"hiphop.00058",
|
||
|
"hiphop.00059",
|
||
|
"hiphop.00060",
|
||
|
"hiphop.00061",
|
||
|
"hiphop.00077",
|
||
|
"hiphop.00078",
|
||
|
"hiphop.00079",
|
||
|
"hiphop.00080",
|
||
|
"jazz.00000",
|
||
|
"jazz.00001",
|
||
|
"jazz.00011",
|
||
|
"jazz.00012",
|
||
|
"jazz.00013",
|
||
|
"jazz.00014",
|
||
|
"jazz.00015",
|
||
|
"jazz.00016",
|
||
|
"jazz.00017",
|
||
|
"jazz.00018",
|
||
|
"jazz.00019",
|
||
|
"jazz.00020",
|
||
|
"jazz.00021",
|
||
|
"jazz.00022",
|
||
|
"jazz.00023",
|
||
|
"jazz.00024",
|
||
|
"jazz.00041",
|
||
|
"jazz.00047",
|
||
|
"jazz.00048",
|
||
|
"jazz.00049",
|
||
|
"jazz.00050",
|
||
|
"jazz.00051",
|
||
|
"jazz.00052",
|
||
|
"jazz.00053",
|
||
|
"jazz.00054",
|
||
|
"jazz.00055",
|
||
|
"jazz.00056",
|
||
|
"jazz.00057",
|
||
|
"jazz.00058",
|
||
|
"jazz.00059",
|
||
|
"jazz.00060",
|
||
|
"jazz.00061",
|
||
|
"jazz.00062",
|
||
|
"jazz.00063",
|
||
|
"jazz.00064",
|
||
|
"jazz.00065",
|
||
|
"jazz.00066",
|
||
|
"jazz.00067",
|
||
|
"jazz.00068",
|
||
|
"jazz.00069",
|
||
|
"jazz.00070",
|
||
|
"jazz.00071",
|
||
|
"jazz.00072",
|
||
|
"metal.00002",
|
||
|
"metal.00003",
|
||
|
"metal.00005",
|
||
|
"metal.00021",
|
||
|
"metal.00024",
|
||
|
"metal.00035",
|
||
|
"metal.00046",
|
||
|
"metal.00047",
|
||
|
"metal.00048",
|
||
|
"metal.00049",
|
||
|
"metal.00050",
|
||
|
"metal.00051",
|
||
|
"metal.00052",
|
||
|
"metal.00053",
|
||
|
"metal.00054",
|
||
|
"metal.00055",
|
||
|
"metal.00056",
|
||
|
"metal.00057",
|
||
|
"metal.00059",
|
||
|
"metal.00060",
|
||
|
"metal.00061",
|
||
|
"metal.00062",
|
||
|
"metal.00063",
|
||
|
"metal.00064",
|
||
|
"metal.00065",
|
||
|
"metal.00066",
|
||
|
"metal.00069",
|
||
|
"metal.00071",
|
||
|
"metal.00072",
|
||
|
"metal.00079",
|
||
|
"metal.00080",
|
||
|
"metal.00084",
|
||
|
"metal.00086",
|
||
|
"metal.00089",
|
||
|
"metal.00090",
|
||
|
"metal.00091",
|
||
|
"metal.00092",
|
||
|
"metal.00093",
|
||
|
"metal.00094",
|
||
|
"metal.00095",
|
||
|
"metal.00096",
|
||
|
"metal.00097",
|
||
|
"metal.00098",
|
||
|
"metal.00099",
|
||
|
"pop.00002",
|
||
|
"pop.00003",
|
||
|
"pop.00004",
|
||
|
"pop.00005",
|
||
|
"pop.00006",
|
||
|
"pop.00007",
|
||
|
"pop.00008",
|
||
|
"pop.00009",
|
||
|
"pop.00011",
|
||
|
"pop.00012",
|
||
|
"pop.00016",
|
||
|
"pop.00017",
|
||
|
"pop.00018",
|
||
|
"pop.00019",
|
||
|
"pop.00020",
|
||
|
"pop.00023",
|
||
|
"pop.00024",
|
||
|
"pop.00025",
|
||
|
"pop.00026",
|
||
|
"pop.00027",
|
||
|
"pop.00028",
|
||
|
"pop.00029",
|
||
|
"pop.00031",
|
||
|
"pop.00032",
|
||
|
"pop.00033",
|
||
|
"pop.00034",
|
||
|
"pop.00035",
|
||
|
"pop.00036",
|
||
|
"pop.00038",
|
||
|
"pop.00039",
|
||
|
"pop.00040",
|
||
|
"pop.00041",
|
||
|
"pop.00042",
|
||
|
"pop.00044",
|
||
|
"pop.00046",
|
||
|
"pop.00049",
|
||
|
"pop.00050",
|
||
|
"pop.00080",
|
||
|
"pop.00097",
|
||
|
"pop.00098",
|
||
|
"pop.00099",
|
||
|
"reggae.00000",
|
||
|
"reggae.00001",
|
||
|
"reggae.00002",
|
||
|
"reggae.00004",
|
||
|
"reggae.00006",
|
||
|
"reggae.00009",
|
||
|
"reggae.00011",
|
||
|
"reggae.00012",
|
||
|
"reggae.00014",
|
||
|
"reggae.00015",
|
||
|
"reggae.00016",
|
||
|
"reggae.00017",
|
||
|
"reggae.00018",
|
||
|
"reggae.00019",
|
||
|
"reggae.00020",
|
||
|
"reggae.00021",
|
||
|
"reggae.00022",
|
||
|
"reggae.00023",
|
||
|
"reggae.00024",
|
||
|
"reggae.00025",
|
||
|
"reggae.00026",
|
||
|
"reggae.00027",
|
||
|
"reggae.00028",
|
||
|
"reggae.00029",
|
||
|
"reggae.00030",
|
||
|
"reggae.00031",
|
||
|
"reggae.00032",
|
||
|
"reggae.00042",
|
||
|
"reggae.00043",
|
||
|
"reggae.00044",
|
||
|
"reggae.00045",
|
||
|
"reggae.00049",
|
||
|
"reggae.00050",
|
||
|
"reggae.00051",
|
||
|
"reggae.00054",
|
||
|
"reggae.00055",
|
||
|
"reggae.00056",
|
||
|
"reggae.00057",
|
||
|
"reggae.00058",
|
||
|
"reggae.00059",
|
||
|
"reggae.00060",
|
||
|
"reggae.00063",
|
||
|
"reggae.00069",
|
||
|
"rock.00000",
|
||
|
"rock.00001",
|
||
|
"rock.00002",
|
||
|
"rock.00003",
|
||
|
"rock.00004",
|
||
|
"rock.00005",
|
||
|
"rock.00006",
|
||
|
"rock.00007",
|
||
|
"rock.00008",
|
||
|
"rock.00009",
|
||
|
"rock.00016",
|
||
|
"rock.00017",
|
||
|
"rock.00018",
|
||
|
"rock.00019",
|
||
|
"rock.00020",
|
||
|
"rock.00021",
|
||
|
"rock.00022",
|
||
|
"rock.00023",
|
||
|
"rock.00024",
|
||
|
"rock.00025",
|
||
|
"rock.00026",
|
||
|
"rock.00057",
|
||
|
"rock.00058",
|
||
|
"rock.00059",
|
||
|
"rock.00060",
|
||
|
"rock.00061",
|
||
|
"rock.00062",
|
||
|
"rock.00063",
|
||
|
"rock.00064",
|
||
|
"rock.00065",
|
||
|
"rock.00066",
|
||
|
"rock.00067",
|
||
|
"rock.00068",
|
||
|
"rock.00069",
|
||
|
"rock.00070",
|
||
|
"rock.00091",
|
||
|
"rock.00092",
|
||
|
"rock.00093",
|
||
|
"rock.00094",
|
||
|
"rock.00095",
|
||
|
"rock.00096",
|
||
|
"rock.00097",
|
||
|
"rock.00098",
|
||
|
"rock.00099",
|
||
|
]
|
||
|
|
||
|
filtered_valid = [
|
||
|
"blues.00000",
|
||
|
"blues.00001",
|
||
|
"blues.00002",
|
||
|
"blues.00003",
|
||
|
"blues.00004",
|
||
|
"blues.00005",
|
||
|
"blues.00006",
|
||
|
"blues.00007",
|
||
|
"blues.00008",
|
||
|
"blues.00009",
|
||
|
"blues.00010",
|
||
|
"blues.00011",
|
||
|
"blues.00050",
|
||
|
"blues.00051",
|
||
|
"blues.00052",
|
||
|
"blues.00053",
|
||
|
"blues.00054",
|
||
|
"blues.00055",
|
||
|
"blues.00056",
|
||
|
"blues.00057",
|
||
|
"blues.00058",
|
||
|
"blues.00059",
|
||
|
"blues.00060",
|
||
|
"classical.00000",
|
||
|
"classical.00001",
|
||
|
"classical.00002",
|
||
|
"classical.00003",
|
||
|
"classical.00004",
|
||
|
"classical.00005",
|
||
|
"classical.00006",
|
||
|
"classical.00007",
|
||
|
"classical.00008",
|
||
|
"classical.00009",
|
||
|
"classical.00010",
|
||
|
"classical.00068",
|
||
|
"classical.00069",
|
||
|
"classical.00070",
|
||
|
"classical.00071",
|
||
|
"classical.00072",
|
||
|
"classical.00073",
|
||
|
"classical.00074",
|
||
|
"classical.00075",
|
||
|
"classical.00076",
|
||
|
"country.00000",
|
||
|
"country.00001",
|
||
|
"country.00002",
|
||
|
"country.00003",
|
||
|
"country.00004",
|
||
|
"country.00005",
|
||
|
"country.00006",
|
||
|
"country.00007",
|
||
|
"country.00009",
|
||
|
"country.00010",
|
||
|
"country.00011",
|
||
|
"country.00012",
|
||
|
"country.00013",
|
||
|
"country.00014",
|
||
|
"country.00015",
|
||
|
"country.00016",
|
||
|
"country.00017",
|
||
|
"country.00018",
|
||
|
"country.00027",
|
||
|
"country.00041",
|
||
|
"country.00042",
|
||
|
"country.00045",
|
||
|
"country.00049",
|
||
|
"disco.00000",
|
||
|
"disco.00002",
|
||
|
"disco.00003",
|
||
|
"disco.00004",
|
||
|
"disco.00006",
|
||
|
"disco.00007",
|
||
|
"disco.00008",
|
||
|
"disco.00009",
|
||
|
"disco.00010",
|
||
|
"disco.00011",
|
||
|
"disco.00012",
|
||
|
"disco.00013",
|
||
|
"disco.00014",
|
||
|
"disco.00046",
|
||
|
"disco.00048",
|
||
|
"disco.00052",
|
||
|
"disco.00067",
|
||
|
"disco.00068",
|
||
|
"disco.00072",
|
||
|
"disco.00075",
|
||
|
"disco.00090",
|
||
|
"disco.00095",
|
||
|
"hiphop.00081",
|
||
|
"hiphop.00082",
|
||
|
"hiphop.00083",
|
||
|
"hiphop.00084",
|
||
|
"hiphop.00085",
|
||
|
"hiphop.00086",
|
||
|
"hiphop.00087",
|
||
|
"hiphop.00088",
|
||
|
"hiphop.00089",
|
||
|
"hiphop.00090",
|
||
|
"hiphop.00091",
|
||
|
"hiphop.00092",
|
||
|
"hiphop.00093",
|
||
|
"hiphop.00094",
|
||
|
"hiphop.00095",
|
||
|
"hiphop.00096",
|
||
|
"hiphop.00097",
|
||
|
"hiphop.00098",
|
||
|
"jazz.00002",
|
||
|
"jazz.00003",
|
||
|
"jazz.00004",
|
||
|
"jazz.00005",
|
||
|
"jazz.00006",
|
||
|
"jazz.00007",
|
||
|
"jazz.00008",
|
||
|
"jazz.00009",
|
||
|
"jazz.00010",
|
||
|
"jazz.00025",
|
||
|
"jazz.00026",
|
||
|
"jazz.00027",
|
||
|
"jazz.00028",
|
||
|
"jazz.00029",
|
||
|
"jazz.00030",
|
||
|
"jazz.00031",
|
||
|
"jazz.00032",
|
||
|
"metal.00000",
|
||
|
"metal.00001",
|
||
|
"metal.00006",
|
||
|
"metal.00007",
|
||
|
"metal.00008",
|
||
|
"metal.00009",
|
||
|
"metal.00010",
|
||
|
"metal.00011",
|
||
|
"metal.00016",
|
||
|
"metal.00017",
|
||
|
"metal.00018",
|
||
|
"metal.00019",
|
||
|
"metal.00020",
|
||
|
"metal.00036",
|
||
|
"metal.00037",
|
||
|
"metal.00068",
|
||
|
"metal.00076",
|
||
|
"metal.00077",
|
||
|
"metal.00081",
|
||
|
"metal.00082",
|
||
|
"pop.00010",
|
||
|
"pop.00053",
|
||
|
"pop.00055",
|
||
|
"pop.00058",
|
||
|
"pop.00059",
|
||
|
"pop.00060",
|
||
|
"pop.00061",
|
||
|
"pop.00062",
|
||
|
"pop.00081",
|
||
|
"pop.00083",
|
||
|
"pop.00084",
|
||
|
"pop.00085",
|
||
|
"pop.00086",
|
||
|
"reggae.00061",
|
||
|
"reggae.00062",
|
||
|
"reggae.00070",
|
||
|
"reggae.00072",
|
||
|
"reggae.00074",
|
||
|
"reggae.00076",
|
||
|
"reggae.00077",
|
||
|
"reggae.00078",
|
||
|
"reggae.00085",
|
||
|
"reggae.00092",
|
||
|
"reggae.00093",
|
||
|
"reggae.00094",
|
||
|
"reggae.00095",
|
||
|
"reggae.00096",
|
||
|
"reggae.00097",
|
||
|
"reggae.00098",
|
||
|
"reggae.00099",
|
||
|
"rock.00038",
|
||
|
"rock.00049",
|
||
|
"rock.00050",
|
||
|
"rock.00051",
|
||
|
"rock.00052",
|
||
|
"rock.00053",
|
||
|
"rock.00054",
|
||
|
"rock.00055",
|
||
|
"rock.00056",
|
||
|
"rock.00071",
|
||
|
"rock.00072",
|
||
|
"rock.00073",
|
||
|
"rock.00074",
|
||
|
"rock.00075",
|
||
|
"rock.00076",
|
||
|
"rock.00077",
|
||
|
"rock.00078",
|
||
|
"rock.00079",
|
||
|
"rock.00080",
|
||
|
"rock.00081",
|
||
|
"rock.00082",
|
||
|
"rock.00083",
|
||
|
"rock.00084",
|
||
|
"rock.00085",
|
||
|
]
|
||
|
|
||
|
|
||
|
URL = "http://opihi.cs.uvic.ca/sound/genres.tar.gz"
|
||
|
FOLDER_IN_ARCHIVE = "genres"
|
||
|
_CHECKSUMS = {
|
||
|
"http://opihi.cs.uvic.ca/sound/genres.tar.gz": "24347e0223d2ba798e0a558c4c172d9d4a19c00bb7963fe055d183dadb4ef2c6"
|
||
|
}
|
||
|
|
||
|
|
||
|
def load_gtzan_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, str]:
|
||
|
"""
|
||
|
Loads a file from the dataset and returns the raw waveform
|
||
|
as a Torch Tensor, its sample rate as an integer, and its
|
||
|
genre as a string.
|
||
|
"""
|
||
|
# Filenames are of the form label.id, e.g. blues.00078
|
||
|
label, _ = fileid.split(".")
|
||
|
|
||
|
# Read wav
|
||
|
file_audio = os.path.join(path, label, fileid + ext_audio)
|
||
|
waveform, sample_rate = torchaudio.load(file_audio)
|
||
|
|
||
|
return waveform, sample_rate, label
|
||
|
|
||
|
|
||
|
class GTZAN(Dataset):
|
||
|
"""*GTZAN* :cite:`tzanetakis_essl_cook_2001` dataset.
|
||
|
|
||
|
Note:
|
||
|
Please see http://marsyas.info/downloads/datasets.html if you are planning to use
|
||
|
this dataset to publish results.
|
||
|
|
||
|
Note:
|
||
|
As of October 2022, the download link is not currently working. Setting ``download=True``
|
||
|
in GTZAN dataset will result in a URL connection error.
|
||
|
|
||
|
Args:
|
||
|
root (str or Path): Path to the directory where the dataset is found or downloaded.
|
||
|
url (str, optional): The URL to download the dataset from.
|
||
|
(default: ``"http://opihi.cs.uvic.ca/sound/genres.tar.gz"``)
|
||
|
folder_in_archive (str, optional): The top-level directory of the dataset.
|
||
|
download (bool, optional):
|
||
|
Whether to download the dataset if it is not found at root path. (default: ``False``).
|
||
|
subset (str or None, optional): Which subset of the dataset to use.
|
||
|
One of ``"training"``, ``"validation"``, ``"testing"`` or ``None``.
|
||
|
If ``None``, the entire dataset is used. (default: ``None``).
|
||
|
"""
|
||
|
|
||
|
_ext_audio = ".wav"
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
root: Union[str, Path],
|
||
|
url: str = URL,
|
||
|
folder_in_archive: str = FOLDER_IN_ARCHIVE,
|
||
|
download: bool = False,
|
||
|
subset: Optional[str] = None,
|
||
|
) -> None:
|
||
|
|
||
|
# super(GTZAN, self).__init__()
|
||
|
|
||
|
# Get string representation of 'root' in case Path object is passed
|
||
|
root = os.fspath(root)
|
||
|
|
||
|
self.root = root
|
||
|
self.url = url
|
||
|
self.folder_in_archive = folder_in_archive
|
||
|
self.download = download
|
||
|
self.subset = subset
|
||
|
|
||
|
if subset is not None and subset not in ["training", "validation", "testing"]:
|
||
|
raise ValueError("When `subset` is not None, it must be one of ['training', 'validation', 'testing'].")
|
||
|
|
||
|
archive = os.path.basename(url)
|
||
|
archive = os.path.join(root, archive)
|
||
|
self._path = os.path.join(root, folder_in_archive)
|
||
|
|
||
|
if download:
|
||
|
if not os.path.isdir(self._path):
|
||
|
if not os.path.isfile(archive):
|
||
|
checksum = _CHECKSUMS.get(url, None)
|
||
|
download_url_to_file(url, archive, hash_prefix=checksum)
|
||
|
_extract_tar(archive)
|
||
|
|
||
|
if not os.path.isdir(self._path):
|
||
|
raise RuntimeError("Dataset not found. Please use `download=True` to download it.")
|
||
|
|
||
|
if self.subset is None:
|
||
|
# Check every subdirectory under dataset root
|
||
|
# which has the same name as the genres in
|
||
|
# GTZAN (e.g. `root_dir'/blues/, `root_dir'/rock, etc.)
|
||
|
# This lets users remove or move around song files,
|
||
|
# useful when e.g. they want to use only some of the files
|
||
|
# in a genre or want to label other files with a different
|
||
|
# genre.
|
||
|
self._walker = []
|
||
|
|
||
|
root = os.path.expanduser(self._path)
|
||
|
|
||
|
for directory in gtzan_genres:
|
||
|
fulldir = os.path.join(root, directory)
|
||
|
|
||
|
if not os.path.exists(fulldir):
|
||
|
continue
|
||
|
|
||
|
songs_in_genre = os.listdir(fulldir)
|
||
|
songs_in_genre.sort()
|
||
|
for fname in songs_in_genre:
|
||
|
name, ext = os.path.splitext(fname)
|
||
|
if ext.lower() == ".wav" and "." in name:
|
||
|
# Check whether the file is of the form
|
||
|
# `gtzan_genre`.`5 digit number`.wav
|
||
|
genre, num = name.split(".")
|
||
|
if genre in gtzan_genres and len(num) == 5 and num.isdigit():
|
||
|
self._walker.append(name)
|
||
|
else:
|
||
|
if self.subset == "training":
|
||
|
self._walker = filtered_train
|
||
|
elif self.subset == "validation":
|
||
|
self._walker = filtered_valid
|
||
|
elif self.subset == "testing":
|
||
|
self._walker = filtered_test
|
||
|
|
||
|
def __getitem__(self, n: int) -> Tuple[Tensor, int, str]:
|
||
|
"""Load the n-th sample from the dataset.
|
||
|
|
||
|
Args:
|
||
|
n (int): The index of the sample to be loaded
|
||
|
|
||
|
Returns:
|
||
|
Tuple of the following items;
|
||
|
|
||
|
Tensor:
|
||
|
Waveform
|
||
|
int:
|
||
|
Sample rate
|
||
|
str:
|
||
|
Label
|
||
|
"""
|
||
|
fileid = self._walker[n]
|
||
|
item = load_gtzan_item(fileid, self._path, self._ext_audio)
|
||
|
waveform, sample_rate, label = item
|
||
|
return waveform, sample_rate, label
|
||
|
|
||
|
def __len__(self) -> int:
|
||
|
return len(self._walker)
|