This commit is contained in:
Anna Nowak 2021-05-26 15:25:00 +02:00
parent 07d97c5267
commit d7d5047587
5 changed files with 66 additions and 5 deletions

1
.gitignore vendored
View File

@ -3,3 +3,4 @@ venv*
.vscode* .vscode*
__pycache__* __pycache__*
music_genre.csv music_genre.csv
music_genre.model

View File

@ -1,3 +1,35 @@
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import confusion_matrix, accuracy_score
import pandas as pd
import numpy as np
import pickle, os
import typing
class Bayes: class Bayes:
def __init__(self): def __init__(self, path: str):
pass self.path = path
self.model_exists = False
if os.path.isfile(self.path):
self.model_exists = True
with open(self.path, 'rb') as file:
self.classifier = pickle.load(file)
else:
self.classifier = GaussianNB()
def train(self, X: pd.DataFrame, Y: pd.Series) -> None:
self.classifier.fit(X, Y)
with open(self.path, 'wb') as file:
pickle.dump(self.classifier, file)
self.model_exists = True
def predict(self, X: pd.DataFrame) -> np.ndarray:
predictions = self.classifier.predict(X)
return predictions
def eval(self, Y: pd.Series, Y_pred: np.ndarray) -> typing.Tuple[np.ndarray, np.float64]:
cm = confusion_matrix(Y, Y_pred)
ac = accuracy_score(Y, Y_pred)
return (cm, ac)

View File

@ -1,5 +1,7 @@
from sklearn.model_selection import train_test_split
from copy import deepcopy from copy import deepcopy
import pandas as pd import pandas as pd
import typing
class DataPreparator: class DataPreparator:
genre_dict = { genre_dict = {
@ -15,9 +17,16 @@ class DataPreparator:
"rock" : 10 "rock" : 10
} }
def prepare_data(df: pd.DataFrame) -> pd.DataFrame: def prepare_data(df: pd.DataFrame) -> pd.DataFrame:
data = deepcopy(df) data = deepcopy(df)
column = df["label"].apply(lambda x: DataPreparator.genre_dict[x]) column = df["label"].apply(lambda x: DataPreparator.genre_dict[x])
data.insert(0, 'genre', column, 'float') data.insert(0, 'genre', column, 'float')
data = data.drop(columns=['filename', 'label', 'length']) data = data.drop(columns=['filename', 'label', 'length'])
return data return data
def train_test_split(df: pd.DataFrame) -> typing.Tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]:
X = df.drop(["genre"], axis=1)
Y = df["genre"]
return train_test_split(X, Y, test_size = 0.20, random_state = False)

17
main.py
View File

@ -10,3 +10,20 @@ else:
data_raw = pd.read_csv('music_genre_raw.csv') data_raw = pd.read_csv('music_genre_raw.csv')
data = DataPreparator.prepare_data(data_raw) data = DataPreparator.prepare_data(data_raw)
data.to_csv(filename, index=False) data.to_csv(filename, index=False)
X_train, X_test, Y_train, Y_test = DataPreparator.train_test_split(data)
bayes = Bayes('music_genre.model')
if(not bayes.model_exists):
bayes.train(X_train, Y_train)
Y_predicted = bayes.predict(X_train)
eval_result = bayes.eval(Y_train, Y_predicted)
print("Train:")
print(eval_result[1])
Y_predicted = bayes.predict(X_test)
eval_result = bayes.eval(Y_test, Y_predicted)
print("Test:")
print(eval_result[1])

View File

@ -1 +1,3 @@
pandas==1.2.4 pandas==1.2.4
numpy==1.20.3
sklearn==0.0