ium_487184/evaluate.py
bartosz.maslanka.consultant 8c2f6e4e0f add jnks,etc
2023-06-28 22:39:02 +02:00

80 lines
2.7 KiB
Python

import torch
from train import MyNeuralNetwork, load_data
from torch.utils.data import DataLoader
import csv
import os
import matplotlib.pyplot as plt
from typing import Tuple, List
def evaluate_model(
    model_path: str = "model.pt",
    data_path: str = "gender_classification_test.csv",
    batch_size: int = 32,
    threshold: float = 0.5,
) -> Tuple[List[torch.Tensor], float]:
    """Run the saved model over the test set and measure its accuracy.

    Args:
        model_path: File holding the ``state_dict`` saved for ``MyNeuralNetwork``.
        data_path: CSV file passed to ``train.load_data`` to build the test set.
        batch_size: Number of samples per inference batch.
        threshold: A model output ``>= threshold`` is counted as class ``1``,
            anything below it as class ``0``.

    Returns:
        ``(predictions, accuracy)``: ``predictions`` holds the raw per-sample
        model outputs (tensors, in dataset order); ``accuracy`` is the fraction
        of samples whose thresholded output equals the integer label.

    Raises:
        ValueError: If the test set yields no samples, so accuracy is undefined.
    """
    model = MyNeuralNetwork()
    # load_state_dict copies the loaded values into the model's existing
    # parameters, so loading the file onto CPU is safe and avoids failing on
    # hosts without CUDA when the weights were saved from a GPU.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()

    test_dataloader: DataLoader = DataLoader(load_data(data_path), batch_size=batch_size)

    predictions: List[torch.Tensor] = []
    total = 0
    correct = 0
    with torch.no_grad():
        for batch_data, batch_labels in test_dataloader:
            batch_predictions = model(batch_data)
            # Comparing a prediction to the threshold and using it as a bool
            # assumes the model emits one score per sample -- TODO confirm.
            predicted_labels = [
                1 if prediction >= threshold else 0 for prediction in batch_predictions
            ]
            true_labels = [int(label) for label in batch_labels.tolist()]
            total += len(predicted_labels)
            correct += sum(p == t for p, t in zip(predicted_labels, true_labels))
            predictions.extend(batch_predictions)

    if total == 0:
        raise ValueError(f"test dataset {data_path!r} is empty; accuracy is undefined")
    return predictions, correct / total
def save_predictions(
    predictions: list,
    filename: str = "predictions.csv",
    threshold: float = 0.5,
) -> None:
    """Write thresholded class labels for *predictions* to a one-column CSV.

    The file is (re)created with a ``predict`` header followed by one ``0``/``1``
    row per prediction, in the given order.

    Args:
        predictions: Per-sample scores, either Python numbers or single-element
            tensors (anything ``float()`` accepts).
        filename: Output CSV path. Defaults to ``predictions.csv`` so it does not
            overwrite ``results.csv``, which ``save_accuracy`` and
            ``plot_accuracy`` use as the accuracy history.
        threshold: Scores ``>= threshold`` become ``1``, others ``0``. This matches
            the ``>= 0.5`` rule ``evaluate_model`` uses when computing accuracy.
    """
    with open(filename, "w", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(["predict"])
        writer.writerows(
            [1 if float(score) >= threshold else 0] for score in predictions
        )
def save_accuracy(accuracy: float, filename: str = "results.csv") -> None:
    """Append one accuracy value to the accuracy-history CSV.

    The file has a single column headed ``accuracy``. The header is written only
    when the file is missing or empty, so repeated calls accumulate one row per
    call.

    Args:
        accuracy: Value to record (written as-is via ``csv.writer``).
        filename: History file to append to.
    """
    # A header is needed when the file is new *or* exists but is empty;
    # otherwise the first value would later be skipped as if it were a header.
    needs_header = not os.path.exists(filename) or os.path.getsize(filename) == 0
    # Mode "a" creates a missing file. newline="" is what the csv module expects;
    # without it rows get an extra blank line on Windows.
    with open(filename, "a", newline="") as file:
        writer = csv.writer(file)
        if needs_header:
            writer.writerow(["accuracy"])
        writer.writerow([accuracy])
def plot_accuracy(filename: str = "results.csv", output: str = "plot.png") -> None:
    """Plot the accuracy history stored in *filename* and save it to *output*.

    The CSV's first row is treated as a header and skipped. The first cell of
    every following non-empty row is parsed as a float. Builds are numbered from
    1 on the x-axis. If *filename* does not exist, an empty plot is still saved.

    Args:
        filename: Accuracy-history CSV to read.
        output: Image path passed to ``savefig``.
    """
    accuracy_results: List[float] = []
    if os.path.exists(filename):
        with open(filename, "r", newline="") as file:
            reader = csv.reader(file)
            next(reader, None)  # skip the header row
            # Blank rows (e.g. from files written without newline="" on
            # Windows) come back as [] and would raise IndexError on row[0].
            accuracy_results = [float(row[0]) for row in reader if row]

    builds = [str(build) for build in range(1, len(accuracy_results) + 1)]
    # Use a dedicated figure rather than pyplot's implicit current one, so
    # repeated calls neither draw on top of earlier plots nor leak figures.
    fig, ax = plt.subplots()
    try:
        ax.plot(builds, accuracy_results)
        ax.set_xlabel("build")
        ax.set_ylabel("accuracy")
        ax.set_title("Accuracies over builds.")
        fig.savefig(output)
    finally:
        plt.close(fig)
def main() -> None:
    """Run the evaluation pipeline end to end.

    Evaluates the saved model, writes its thresholded predictions, records the
    resulting accuracy in the history file, and regenerates the accuracy plot.
    """
    evaluation = evaluate_model()
    predictions, accuracy = evaluation
    save_predictions(predictions)
    save_accuracy(accuracy)
    plot_accuracy()


if __name__ == "__main__":
    main()