80 lines
2.7 KiB
Python
80 lines
2.7 KiB
Python
|
import torch
|
||
|
from train import MyNeuralNetwork, load_data
|
||
|
from torch.utils.data import DataLoader
|
||
|
import csv
|
||
|
import os
|
||
|
import matplotlib.pyplot as plt
|
||
|
from typing import Tuple, List
|
||
|
|
||
|
def evaluate_model(
    model_path: str = "model.pt",
    data_path: str = "gender_classification_test.csv",
    batch_size: int = 32,
    threshold: float = 0.5,
) -> Tuple[List[torch.Tensor], float]:
    """Run the saved model over the test set and measure its accuracy.

    Args:
        model_path: Path of the state dict to load into ``MyNeuralNetwork``.
        data_path: CSV file handed to ``load_data`` to build the test dataset.
        batch_size: Number of samples per evaluation batch.
        threshold: A raw model output ``>= threshold`` counts as label 1,
            anything else as label 0.

    Returns:
        ``(predictions, accuracy)`` where ``predictions`` holds the raw model
        output for every test sample, one tensor per sample in dataset order,
        and ``accuracy`` is the fraction of thresholded outputs equal to the
        integer-truncated ground-truth labels.

    Raises:
        ValueError: If the test dataset yields no samples (accuracy would be
            undefined).
    """
    model = MyNeuralNetwork()
    # The model and the batches below stay on the CPU, so remap the checkpoint
    # there; otherwise a state dict saved from a GPU run fails to load.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()

    test_dataset = load_data(data_path)
    test_dataloader: DataLoader = DataLoader(test_dataset, batch_size=batch_size)

    predictions: List[torch.Tensor] = []
    total = 0
    correct = 0
    with torch.no_grad():
        for batch_data, batch_labels in test_dataloader:
            batch_predictions = model(batch_data)
            # One output per sample is assumed (shape (B,) or (B, 1)); flatten
            # so thresholded outputs line up with the per-sample labels.
            predicted = (batch_predictions.reshape(-1) >= threshold).long()
            # .long() truncates like int() did on the label values.
            targets = batch_labels.reshape(-1).long()
            total += predicted.numel()
            correct += int((predicted == targets).sum().item())
            predictions.extend(batch_predictions)

    if total == 0:
        raise ValueError(f"no samples found in test data {data_path!r}")
    return predictions, correct / total
|
||
|
|
||
|
def save_predictions(
    predictions: list,
    filename: str = "results.csv",
    threshold: float = 0.5,
) -> None:
    """Write one 0/1 label per prediction to a CSV file, overwriting it.

    The file gets a single ``predict`` header row followed by one label per
    prediction, in input order.

    Args:
        predictions: Raw model outputs. Plain floats and single-element
            tensors are both accepted, since each item goes through ``float()``.
        filename: Destination CSV path; any existing file is replaced.
        threshold: An output ``>= threshold`` is written as 1, otherwise 0.
            ``>=`` matches the rule used when accuracy is scored, so an output
            of exactly 0.5 is saved with the same label it was scored with.
    """
    with open(filename, "w", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(["predict"])
        writer.writerows(
            [1 if float(prediction) >= threshold else 0] for prediction in predictions
        )
|
||
|
|
||
|
def save_accuracy(accuracy: float, filename: str = "results.csv") -> None:
    """Append an accuracy value as a new row of a CSV file.

    If the file does not exist yet it is created with an ``accuracy`` header
    row first, so repeated calls build up a one-column history.

    NOTE(review): with the default filename, ``main()`` calls this right after
    ``save_predictions`` has rewritten ``results.csv`` with a ``predict`` header.
    In that flow the accuracy row lands after the prediction rows and no header
    is written. Confirm whether a separate accuracy-history file was intended.

    Args:
        accuracy: Value written as the row's single cell.
        filename: CSV path to append to (created if missing).
    """
    is_new_file = not os.path.exists(filename)
    # newline="" is what the csv module requires. Without it, rows written on
    # Windows end in "\r\r\n" and read back as extra blank rows.
    with open(filename, "a", newline="") as file:
        writer = csv.writer(file)
        if is_new_file:
            writer.writerow(["accuracy"])
        writer.writerow([accuracy])
|
||
|
|
||
|
def plot_accuracy(filename: str = "results.csv", output: str = "plot.png") -> None:
    """Plot the values stored in a one-column CSV and save the chart as an image.

    The first row of the file is treated as a header and skipped. Every
    following non-blank row's first cell is parsed as a float and plotted
    against its 1-based position ("build"). A missing file yields an empty plot.

    NOTE(review): with the default filename this reads ``results.csv``. In
    ``main()`` that file holds the ``save_predictions`` rows (0/1 labels)
    followed by the accuracy row, so those label rows are plotted too. Confirm
    the intended source of the accuracy history.

    Args:
        filename: CSV file to read the values from.
        output: Path the figure is saved to.
    """
    accuracy_results: List[float] = []
    if os.path.exists(filename):
        with open(filename, "r", newline="") as file:
            reader = csv.reader(file)
            next(reader, None)  # skip the header row
            # csv yields [] for blank lines; skip them instead of crashing on row[0].
            accuracy_results = [float(row[0]) for row in reader if row]

    # Numeric x positions give a normal numeric axis; the former string labels
    # made a categorical axis with one tick label per point.
    iterations = list(range(1, len(accuracy_results) + 1))

    # Use an explicit figure rather than pyplot's implicit global one, so
    # repeated calls don't draw on top of each other.
    fig, ax = plt.subplots()
    ax.plot(iterations, accuracy_results)
    ax.set_xlabel("build")
    ax.set_ylabel("accuracy")
    ax.set_title("Accuracies over builds.")
    fig.savefig(output)
    plt.close(fig)  # release the figure; pyplot keeps open figures alive
|
||
|
|
||
|
def main():
    """Evaluate the trained model, persist its outputs, and refresh the plot.

    Steps, in order: score the test set, write the per-sample labels, record
    the accuracy, then regenerate the accuracy chart.
    """
    model_outputs, test_accuracy = evaluate_model()

    save_predictions(model_outputs)
    save_accuracy(test_accuracy)

    plot_accuracy()


# Run the full evaluation pipeline only when executed as a script.
if __name__ == "__main__":
    main()
|