"""Evaluate the neural network's test-set predictions and track them per build.

The training job already writes its predictions to a CSV, so this script does
not load the model or the raw test set. It:

1. reads the true and predicted labels from ``PREDICTIONS_CSV``,
2. computes accuracy and macro-averaged precision / recall,
3. appends those metrics to the JSON history in ``HISTORY_FILE``
   (``{"eval_results": [...]}``, one entry per run),
4. plots every recorded run and saves the figure to ``PLOT_FILE``.
"""
import json
import os

import pandas as pd
import torch  # only needed if the model loading noted below is re-enabled

PREDICTIONS_CSV = 'neural_network_prediction_results.csv'
HISTORY_FILE = 'eval_results.json'
PLOT_FILE = 'metrics.png'

# Column names in the predictions CSV written by the training job.
TRUE_COLUMN = 'Testing Y'
PREDICTED_COLUMN = 'Predicted Y'

# Keys of one history entry, in plotting order.
METRIC_KEYS = ('Accuracy', 'Macro-Avg Precision', 'Macro-Avg Recall')

# Loading the model and the test data is unnecessary: the training job has
# already produced a CSV with its predictions. Kept for reference:
# model = torch.load('model.pkl')
# test_set = pd.read_csv('d_test.csv', encoding='latin-1')


def compute_metrics(y_true, y_predicted):
    """Return accuracy and macro-averaged precision/recall as a history entry.

    Args:
        y_true: Ground-truth labels.
        y_predicted: Predicted labels, aligned with ``y_true``.

    Returns:
        dict mapping each name in ``METRIC_KEYS`` to its score.
    """
    # Imported here so the history helpers in this module work without
    # scikit-learn installed.
    from sklearn.metrics import accuracy_score, precision_score, recall_score

    return {
        'Accuracy': accuracy_score(y_true, y_predicted),
        'Macro-Avg Precision': precision_score(y_true, y_predicted, average='macro'),
        'Macro-Avg Recall': recall_score(y_true, y_predicted, average='macro'),
    }


def append_eval_result(eval_results, filename=HISTORY_FILE):
    """Append one run's metrics to the JSON history file.

    The file holds ``{"eval_results": [entry, ...]}`` and is created if it
    does not exist yet.

    Args:
        eval_results: Metrics dict for this run (see ``compute_metrics``).
        filename: Path of the JSON history file.

    Returns:
        The full list of history entries, including the new one.

    Raises:
        json.JSONDecodeError: If an existing history file is not valid JSON.
    """
    try:
        with open(filename, encoding='utf-8') as file:
            file_data = json.load(file)
    except FileNotFoundError:
        file_data = {'eval_results': []}

    file_data['eval_results'].append(eval_results)

    # Write a fresh file and swap it in. Rewriting in place with seek(0) but
    # no truncate() can leave stale trailing bytes when the new JSON is
    # shorter, and a dump that fails halfway would destroy the history.
    tmp_filename = os.fspath(filename) + '.tmp'
    with open(tmp_filename, 'w', encoding='utf-8') as file:
        json.dump(file_data, file, indent=2)
    os.replace(tmp_filename, filename)

    return file_data['eval_results']


def plot_metric_history(history, output_path=PLOT_FILE, show=True):
    """Plot every metric across the recorded runs and save the figure.

    Args:
        history: List of history entries as returned by
            ``append_eval_result``; the x axis is the 1-based position of
            each entry ("Build").
        output_path: Where to save the figure.
        show: Also display the figure interactively after saving it.
    """
    # Imported here so the module can be used without matplotlib installed.
    import matplotlib.pyplot as plt

    builds = list(range(1, len(history) + 1))
    for key in METRIC_KEYS:
        plt.plot(builds, [entry[key] for entry in history], label=key)
    plt.xlabel('Build')
    plt.ylabel('Score')
    plt.legend()
    # Save BEFORE show(): once show() returns the figure may already be
    # closed, and savefig() would then write an empty image.
    plt.savefig(output_path)
    if show:
        plt.show()


def main():
    """Run the pipeline: predictions CSV -> metrics -> JSON history -> plot."""
    test_results = pd.read_csv(PREDICTIONS_CSV)
    eval_results = compute_metrics(test_results[TRUE_COLUMN],
                                   test_results[PREDICTED_COLUMN])
    history = append_eval_result(eval_results)
    plot_metric_history(history)


if __name__ == '__main__':
    main()