39 lines
1.2 KiB
Python
39 lines
1.2 KiB
Python
from sklearn.model_selection import train_test_split
|
|
from copy import deepcopy
|
|
import pandas as pd
|
|
import typing
|
|
|
|
class DataPreparator:
    """Stateless helpers that turn the raw genre feature table into model input.

    Every helper is a static method, so it can be called either on the class
    (``DataPreparator.prepare_data(df)``) or on an instance.
    """

    # Maps the textual genre label found in the ``label`` column to the
    # integer code stored in the ``genre`` column.
    genre_dict: typing.Dict[str, int] = {
        "blues": 1,
        "classical": 2,
        "country": 3,
        "disco": 4,
        "hiphop": 5,
        "jazz": 6,
        "metal": 7,
        "pop": 8,
        "reggae": 9,
        "rock": 10,
    }

    # Feature columns used by ``train_test_split`` when the caller does not
    # supply ``feature_columns``.
    default_feature_columns: typing.Tuple[str, ...] = (
        "chroma_stft_mean",
        "chroma_stft_var",
        "rms_mean",
    )

    @staticmethod
    def prepare_data(df: pd.DataFrame) -> pd.DataFrame:
        """Return a new frame with labels encoded in a leading ``genre`` column.

        The ``label`` column is translated through ``genre_dict`` and inserted
        as the first column, named ``genre``. The ``filename``, ``label`` and
        ``length`` columns are removed. *df* itself is not modified.

        Raises:
            KeyError: if ``filename``, ``label`` or ``length`` is missing from
                *df*, or a label is not a key of ``genre_dict``.
            ValueError: if *df* already contains a ``genre`` column.
        """
        genre_codes = df["label"].apply(
            lambda label: DataPreparator.genre_dict[label]
        )
        # ``drop`` already returns a new frame, so no separate deep copy of
        # the input is needed to keep ``df`` untouched.
        data = df.drop(columns=["filename", "label", "length"])
        # The 4th positional argument of ``insert`` is ``allow_duplicates``,
        # not a dtype, so nothing is passed there; the codes are already ints.
        data.insert(0, "genre", genre_codes)
        return data

    @staticmethod
    def train_test_split(
        df: pd.DataFrame,
        *,
        feature_columns: typing.Optional[typing.Sequence[str]] = None,
        test_size: float = 0.20,
        random_state: typing.Optional[int] = 0,
    ) -> typing.Tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]:
        """Split *df* into train and test features and ``genre`` targets.

        Args:
            df: frame containing a ``genre`` column and the feature columns.
            feature_columns: columns used as features; defaults to
                ``default_feature_columns``.
            test_size: forwarded to scikit-learn's ``train_test_split``.
            random_state: seed forwarded to scikit-learn's
                ``train_test_split``; the default ``0`` keeps the split
                reproducible.

        Returns:
            The result of scikit-learn's ``train_test_split`` for the selected
            features and the ``genre`` column (train features, test features,
            train targets, test targets).
        """
        if feature_columns is None:
            feature_columns = DataPreparator.default_feature_columns
        features = df[list(feature_columns)]
        targets = df["genre"]
        # Class attributes are not visible as bare names inside a method body,
        # so this name resolves to the module-level import from
        # ``sklearn.model_selection`` and not to this method.
        return train_test_split(
            features, targets, test_size=test_size, random_state=random_state
        )

    @staticmethod
    def print_df_info(df: pd.DataFrame) -> None:
        """Print the number of rows of *df* for each genre in ``genre_dict``.

        One line is printed per genre, in ``genre_dict`` order. *df* must have
        a ``genre`` column holding the integer genre codes.
        """
        genre_column = df["genre"]
        for key, code in DataPreparator.genre_dict.items():
            # Summing the boolean mask counts matches without building a
            # filtered frame for every genre.
            count = int((genre_column == code).sum())
            print(f"Key: {key}\tCount: {count}")