# (listing metadata: 34 lines, 890 B, Python)
import os

import pandas as pd

from bayes import Bayes
from datapreparator import DataPreparator

filename = 'music_genre.csv'

# Load the prepared dataset. On the first run it is built from the raw CSV via
# DataPreparator.prepare_data and cached to `filename`; later runs read the cache.
# NOTE(review): the two paths do not produce identical frames. On the first run
# `data` keeps whatever index/dtypes prepare_data returned, while later runs get
# the CSV round-trip (written with index=False, so a fresh RangeIndex and dtypes
# re-inferred by read_csv). Confirm DataPreparator.train_test_split is
# insensitive to that difference.
if not os.path.isfile(filename):
    data_raw = pd.read_csv('music_genre_raw.csv')
    data = DataPreparator.prepare_data(data_raw)
    data.to_csv(filename, index=False)
else:
    data = pd.read_csv(filename)

# Split the prepared data into training and held-out test partitions:
# X_* are passed to the model's predict/train (inputs), Y_* are the
# expected labels the predictions are scored against.
(
    X_train,
    X_test,
    Y_train,
    Y_test,
) = DataPreparator.train_test_split(data)

# Bayes classifier bound to the '_model.model' file (presumably loaded by the
# constructor when that file already exists — confirm in bayes.py).
bayes = Bayes('_model.model')

# Fit only when no saved model is available; otherwise the saved one is reused.
# NOTE(review): `model_exists` is read as an attribute, not called. If it is a
# method on Bayes, this condition is always false and training never runs.
# NOTE(review): a reused model was trained on an earlier run's split; unless
# DataPreparator.train_test_split is deterministic, some of today's X_test may
# have been training data, inflating the test score below — verify.
if not bayes.model_exists:
    bayes.train(X_train, Y_train)

def _evaluate(model, label, X, Y):
    """Predict ``X`` with ``model``, score the predictions against ``Y``, and print.

    Prints ``"<label>:"`` followed by element ``[1]`` of
    ``model.eval(Y, predictions)``.

    Args:
        model: Object providing ``predict(X)`` and ``eval(Y, predictions)``.
        label: Heading printed before the score (e.g. ``"Train"``).
        X: Inputs passed to ``model.predict``.
        Y: Expected labels passed to ``model.eval``.

    Returns:
        ``(predictions, eval_result)`` so callers keep access to both.
    """
    predicted = model.predict(X)
    result = model.eval(Y, predicted)
    print(f"{label}:")
    # NOTE(review): only index 1 of eval's result is reported; what it holds
    # (e.g. a formatted report vs. a score) is defined in bayes.py — confirm.
    print(result[1])
    return predicted, result


# Score the model on the training split, then on the held-out test split.
# As before, Y_predicted / eval_result end up holding the test-set results.
Y_predicted, eval_result = _evaluate(bayes, "Train", X_train, Y_train)
Y_predicted, eval_result = _evaluate(bayes, "Test", X_test, Y_test)

# Result preview (disabled). Uncomment to compare the first test labels with
# the model's test-set predictions; bounded so short test splits don't raise.
# for i in range(min(100, len(Y_test))):
#     print(f"Expected: {Y_test.to_numpy()[i]}\tPred: {Y_predicted[i]}")

# Report on the prepared dataset via DataPreparator's helper.
DataPreparator.print_df_info(data)