2021-05-15 15:33:42 +02:00
|
|
|
import pandas as pd
|
|
|
|
import numpy as np
|
|
|
|
from tensorflow import keras
|
|
|
|
import matplotlib.pyplot as plt
|
2021-05-15 16:59:57 +02:00
|
|
|
from sklearn.metrics import accuracy_score
|
|
|
|
from sklearn.preprocessing import StandardScaler, LabelEncoder
|
2021-05-15 15:33:42 +02:00
|
|
|
|
2021-05-15 15:47:36 +02:00
|
|
|
# Path of the trained Keras model that this script evaluates.
MODEL_PATH = 'wine_model.h5'
model = keras.models.load_model(MODEL_PATH)
|
2021-05-15 15:33:42 +02:00
|
|
|
# Announce the start of the evaluation on stdout.
print("evaluating")
|
|
|
|
|
2021-05-15 16:59:57 +02:00
|
|
|
# Evaluation data: the 'quality' target column plus the raw wine features.
TEST_DATA_PATH = 'test.csv'
data = pd.read_csv(TEST_DATA_PATH)
|
2021-05-15 15:33:42 +02:00
|
|
|
|
2021-05-15 16:59:57 +02:00
|
|
|
# ---------------- Feature preparation ----------------
|
2021-05-15 16:40:43 +02:00
|
|
|
|
2021-05-15 16:59:57 +02:00
|
|
|
# Split the test set into the target and the raw input features.
y = data['quality']
x = data.drop('quality', axis=1)

# Engineered features. They must match, in name and order, the features the
# training script feeds to the model.
citricacid = x['fixed acidity'] * x['citric acid']
citric_acidity = pd.DataFrame(citricacid, columns=['citric_accidity'])

density_acidity = x['fixed acidity'] * x['density']
density_acidity = pd.DataFrame(density_acidity, columns=['density_acidity'])

# Join onto the feature frame `x`, not onto `data`. Joining onto `data`
# re-introduced the 'quality' column, so the model was fed the very label it
# is asked to predict.
# NOTE(review): the saved model's input width must match this feature set.
# If wine_model.h5 was trained with the same `data.join` mistake, it expects
# one extra column and has to be retrained with the corrected features.
x = x.join(citric_acidity).join(density_acidity)

print(y)

# Binarise quality: (2, 5] -> 'bad', (5, 8] -> 'nice'. Scores outside (2, 8]
# become NaN.
bins = (2, 5, 8)
gnames = ['bad', 'nice']
y = pd.cut(y, bins=bins, labels=gnames)

# Encode with the categorical codes (bad=0, nice=1, in `gnames` order).
# A LabelEncoder fitted on the test labels would change its mapping if one
# class is missing from this split. NaN (out-of-range quality) becomes -1 and
# can never match a prediction.
yenc = y.cat.codes.to_numpy()

# NOTE(review): this scaler is fitted on the test data itself. Ideally the
# scaler fitted during training is persisted and reused here, so that test
# features are transformed exactly as the training features were.
scale = StandardScaler()
scaled_x = scale.fit_transform(x)

y_pred = model.predict(scaled_x)
|
2021-05-15 15:33:42 +02:00
|
|
|
|
|
|
|
# Turn the network's raw outputs into a 1-D vector of integer class labels.
y_pred = np.asarray(y_pred)
if y_pred.ndim == 2 and y_pred.shape[1] > 1:
    # One score column per class (presumably a softmax head): take the most
    # likely class for each sample.
    y_pred = np.argmax(y_pred, axis=1)
else:
    # A single score per sample (assumed sigmoid; confirm against the
    # training script): round to the nearest class. np.around rounds half to
    # even, so an exact 0.5 maps to 0, as before. ravel() flattens an (n, 1)
    # column into a label vector.
    y_pred = np.around(y_pred, decimals=0).astype(int).ravel()
|
|
|
|
|
2021-05-15 16:59:57 +02:00
|
|
|
# Accuracy: the share of samples whose predicted class equals the encoded label.
results = accuracy_score(y_true=yenc, y_pred=y_pred)
|
2021-05-15 15:33:42 +02:00
|
|
|
# Append this run's accuracy to the history file, one value per line.
with open('results.txt', 'a', encoding="UTF-8") as f:
    f.write(str(results) + "\n")

# Reload the whole history, including the value just written. Blank lines
# are skipped so that a stray empty line does not make float() fail.
with open('results.txt', 'r', encoding="UTF-8") as f:
    lines = [line for line in f if line.strip()]

accuracies = [float(line) for line in lines]
builds = np.arange(len(accuracies))
print(accuracies)

# Plot accuracy against the 0-based run index and save it as an image.
fig = plt.figure(figsize=(10, 10))
chart = fig.add_subplot()
chart.set_ylabel("Accuracy")
chart.set_xlabel("Number of build")
chart.plot(builds, accuracies, "ro")
fig.savefig("evaluation.png")
# Release the figure; the script never displays it.
plt.close(fig)
|