DSIC-Bayes-continuous/bayes.py

35 lines
1.0 KiB
Python
Raw Permalink Normal View History

2021-05-26 15:25:00 +02:00
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import confusion_matrix, accuracy_score
import pandas as pd
import numpy as np
import pickle, os
import typing
2021-05-26 13:32:48 +02:00
class Bayes:
    """Gaussian naive Bayes classifier persisted to a pickle file.

    On construction the classifier is loaded from ``path`` if that file
    exists; otherwise a fresh, unfitted ``GaussianNB`` is created.
    ``train`` fits the classifier and writes it back to ``path``.

    Attributes:
        path: Location of the pickled classifier.
        model_exists: True when the classifier was loaded from ``path`` or
            has been trained and saved by this instance.
        classifier: The underlying estimator.
    """

    def __init__(self, path: str):
        """Load the classifier from *path*, or create a new one if absent.

        Security: ``pickle.load`` can execute arbitrary code, so *path*
        must only ever point at files you trust.
        """
        self.path = path
        self.model_exists = False
        if os.path.isfile(self.path):
            self.model_exists = True
            # NOTE(review): unpickling is only safe for trusted model files.
            with open(self.path, 'rb') as file:
                self.classifier = pickle.load(file)
        else:
            self.classifier = GaussianNB()

    def train(self, X: pd.DataFrame, Y: pd.Series) -> None:
        """Fit the classifier on (X, Y) and persist it to ``self.path``.

        The pickle is first written to a temporary sibling file and then
        moved over ``self.path`` with ``os.replace``. A failed or interrupted
        write therefore never leaves a truncated model on disk that would
        break the next ``Bayes(path)``. Any previous model stays intact.

        Args:
            X: Feature matrix.
            Y: Target labels aligned with the rows of X.
        """
        self.classifier.fit(X, Y)
        tmp_path = f"{self.path}.tmp"
        try:
            with open(tmp_path, 'wb') as file:
                pickle.dump(self.classifier, file)
            # Close (end of `with`) before replacing; required on Windows.
            os.replace(tmp_path, self.path)
        except BaseException:
            # Remove the partial temp file; the old model file is untouched.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        self.model_exists = True

    def predict(self, X: pd.DataFrame) -> np.ndarray:
        """Return the classifier's predictions for the rows of *X*."""
        predictions = self.classifier.predict(X)
        return predictions

    def eval(self, Y: pd.Series, Y_pred: np.ndarray) -> typing.Tuple[np.ndarray, np.float64]:
        """Score predictions against true labels.

        Args:
            Y: True labels.
            Y_pred: Predicted labels, e.g. the output of ``predict``.

        Returns:
            A ``(confusion_matrix, accuracy)`` pair computed with scikit-learn.
        """
        cm = confusion_matrix(Y, Y_pred)
        ac = accuracy_score(Y, Y_pred)
        return (cm, ac)