diff --git a/predict.py b/predict.py index 461db7c..61f9134 100644 --- a/predict.py +++ b/predict.py @@ -5,6 +5,7 @@ os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" from keras.models import load_model import pandas as pd from sklearn.metrics import confusion_matrix +import numpy as np def main(): @@ -13,7 +14,8 @@ def main(): y_test = pd.read_csv("data/y_test.csv") y_pred = model.predict(X_test) - y_pred = y_pred > 0.5 + y_pred = y_pred >= 0.5 + np.savetxt("data/y_pred.csv", y_pred, delimiter=",") cm = confusion_matrix(y_test, y_pred) print(