2021-05-26 15:25:00 +02:00
|
|
|
from sklearn.naive_bayes import GaussianNB
|
|
|
|
from sklearn.metrics import confusion_matrix, accuracy_score
|
|
|
|
import pandas as pd
|
|
|
|
import numpy as np
|
|
|
|
import pickle, os
|
|
|
|
import typing
|
|
|
|
|
2021-05-26 13:32:48 +02:00
|
|
|
class Bayes:
    """Gaussian naive Bayes classifier persisted to a pickle file.

    On construction the classifier is unpickled from ``path`` when that file
    exists; otherwise a fresh, untrained ``GaussianNB`` is created. Each call
    to :meth:`train` refits the classifier and saves it to ``path``.
    """

    def __init__(self, path: str):
        """Load the pickled classifier at *path*, or create a new one.

        Args:
            path: Location of the pickled model. It is read here when it is
                an existing file, and written by :meth:`train`.
        """
        self.path = path
        # True once a classifier has been loaded from, or saved to, self.path.
        self.model_exists = False
        if os.path.isfile(self.path):
            self.model_exists = True
            # SECURITY: pickle.load can execute arbitrary code; only point
            # `path` at model files produced by a trusted source.
            with open(self.path, 'rb') as file:
                self.classifier = pickle.load(file)
        else:
            self.classifier = GaussianNB()

    def train(self, X: pd.DataFrame, Y: pd.Series) -> None:
        """Fit the classifier on ``(X, Y)`` and save it to ``self.path``.

        The pickle is first written to a sibling ``.tmp`` file and then moved
        over ``self.path`` with :func:`os.replace`. A failure while pickling
        or writing therefore leaves any previously saved model intact instead
        of truncating it.

        Args:
            X: Training features, passed to the classifier's ``fit``.
            Y: Training targets, passed to the classifier's ``fit``.

        Raises:
            Whatever ``fit``, ``pickle.dump`` or the file operations raise.
            ``model_exists`` is only set to True after the save succeeds.
        """
        self.classifier.fit(X, Y)
        tmp_path = os.fspath(self.path) + '.tmp'
        try:
            with open(tmp_path, 'wb') as file:
                pickle.dump(self.classifier, file)
            # Swap the fully written file into place in a single step.
            os.replace(tmp_path, self.path)
        except BaseException:
            # Best-effort removal of the partial temp file; the original
            # error is what gets re-raised below.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
        self.model_exists = True

    def predict(self, X: pd.DataFrame) -> np.ndarray:
        """Return the classifier's predictions for *X*."""
        return self.classifier.predict(X)

    def eval(
        self, Y: pd.Series, Y_pred: np.ndarray
    ) -> typing.Tuple[np.ndarray, float]:
        """Score predicted labels against the true labels.

        Args:
            Y: True labels.
            Y_pred: Predicted labels, e.g. the output of :meth:`predict`.

        Returns:
            ``(confusion_matrix(Y, Y_pred), accuracy_score(Y, Y_pred))``.
        """
        cm = confusion_matrix(Y, Y_pred)
        ac = accuracy_score(Y, Y_pred)
        return cm, ac
|