raw
This commit is contained in:
parent
07d97c5267
commit
d7d5047587
1
.gitignore
vendored
1
.gitignore
vendored
@ -3,3 +3,4 @@ venv*
|
|||||||
.vscode*
|
.vscode*
|
||||||
__pycache__*
|
__pycache__*
|
||||||
music_genre.csv
|
music_genre.csv
|
||||||
|
music_genre.model
|
36
bayes.py
36
bayes.py
@ -1,3 +1,35 @@
|
|||||||
|
from sklearn.naive_bayes import GaussianNB
|
||||||
|
from sklearn.metrics import confusion_matrix, accuracy_score
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import pickle, os
|
||||||
|
import typing
|
||||||
|
|
||||||
class Bayes:
|
class Bayes:
|
||||||
def __init__(self):
|
def __init__(self, path: str):
|
||||||
pass
|
self.path = path
|
||||||
|
self.model_exists = False
|
||||||
|
if os.path.isfile(self.path):
|
||||||
|
self.model_exists = True
|
||||||
|
with open(self.path, 'rb') as file:
|
||||||
|
self.classifier = pickle.load(file)
|
||||||
|
else:
|
||||||
|
self.classifier = GaussianNB()
|
||||||
|
|
||||||
|
|
||||||
|
def train(self, X: pd.DataFrame, Y: pd.Series) -> None:
|
||||||
|
self.classifier.fit(X, Y)
|
||||||
|
with open(self.path, 'wb') as file:
|
||||||
|
pickle.dump(self.classifier, file)
|
||||||
|
self.model_exists = True
|
||||||
|
|
||||||
|
|
||||||
|
def predict(self, X: pd.DataFrame) -> np.ndarray:
|
||||||
|
predictions = self.classifier.predict(X)
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
|
||||||
|
def eval(self, Y: pd.Series, Y_pred: np.ndarray) -> typing.Tuple[np.ndarray, np.float64]:
|
||||||
|
cm = confusion_matrix(Y, Y_pred)
|
||||||
|
ac = accuracy_score(Y, Y_pred)
|
||||||
|
return (cm, ac)
|
@ -1,5 +1,7 @@
|
|||||||
|
from sklearn.model_selection import train_test_split
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import typing
|
||||||
|
|
||||||
class DataPreparator:
|
class DataPreparator:
|
||||||
genre_dict = {
|
genre_dict = {
|
||||||
@ -15,9 +17,16 @@ class DataPreparator:
|
|||||||
"rock" : 10
|
"rock" : 10
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def prepare_data(df: pd.DataFrame) -> pd.DataFrame:
|
def prepare_data(df: pd.DataFrame) -> pd.DataFrame:
|
||||||
data = deepcopy(df)
|
data = deepcopy(df)
|
||||||
column = df["label"].apply(lambda x: DataPreparator.genre_dict[x])
|
column = df["label"].apply(lambda x: DataPreparator.genre_dict[x])
|
||||||
data.insert(0, 'genre', column, 'float')
|
data.insert(0, 'genre', column, 'float')
|
||||||
data = data.drop(columns=['filename', 'label', 'length'])
|
data = data.drop(columns=['filename', 'label', 'length'])
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def train_test_split(df: pd.DataFrame) -> typing.Tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]:
|
||||||
|
X = df.drop(["genre"], axis=1)
|
||||||
|
Y = df["genre"]
|
||||||
|
return train_test_split(X, Y, test_size = 0.20, random_state = False)
|
17
main.py
17
main.py
@ -10,3 +10,20 @@ else:
|
|||||||
data_raw = pd.read_csv('music_genre_raw.csv')
|
data_raw = pd.read_csv('music_genre_raw.csv')
|
||||||
data = DataPreparator.prepare_data(data_raw)
|
data = DataPreparator.prepare_data(data_raw)
|
||||||
data.to_csv(filename, index=False)
|
data.to_csv(filename, index=False)
|
||||||
|
|
||||||
|
X_train, X_test, Y_train, Y_test = DataPreparator.train_test_split(data)
|
||||||
|
|
||||||
|
bayes = Bayes('music_genre.model')
|
||||||
|
if(not bayes.model_exists):
|
||||||
|
bayes.train(X_train, Y_train)
|
||||||
|
|
||||||
|
|
||||||
|
Y_predicted = bayes.predict(X_train)
|
||||||
|
eval_result = bayes.eval(Y_train, Y_predicted)
|
||||||
|
print("Train:")
|
||||||
|
print(eval_result[1])
|
||||||
|
|
||||||
|
Y_predicted = bayes.predict(X_test)
|
||||||
|
eval_result = bayes.eval(Y_test, Y_predicted)
|
||||||
|
print("Test:")
|
||||||
|
print(eval_result[1])
|
||||||
|
@ -1 +1,3 @@
|
|||||||
pandas==1.2.4
|
pandas==1.2.4
|
||||||
|
numpy==1.20.3
|
||||||
|
sklearn==0.0
|
Loading…
Reference in New Issue
Block a user