"""Evaluate the trained gender-classification model and record the results.

Running this script:

1. loads the weights in ``model.pt`` into ``MyNeuralNetwork`` and scores the
   test split in ``gender_classification_test.csv``;
2. writes one predicted 0/1 label per test sample to ``results.csv``;
3. appends the test accuracy to a per-run history file (``accuracy.csv``);
4. plots that accuracy history to ``plot.png``.
"""

import csv
import os
from typing import List, Sequence, Tuple, Union

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

from train import MyNeuralNetwork, load_data

# Score at or above which a prediction is labelled as class 1.
DECISION_THRESHOLD: float = 0.5

PREDICTIONS_FILE = "results.csv"
# Deliberately a different file from PREDICTIONS_FILE: save_predictions()
# truncates its file on every run, which used to wipe the accuracy history
# and mix 0/1 prediction rows into the accuracy plot.
ACCURACY_FILE = "accuracy.csv"


def _to_label(prediction: Union[torch.Tensor, float],
              threshold: float = DECISION_THRESHOLD) -> int:
    """Map one model output to a hard 0/1 label.

    ``float()`` accepts both Python floats and single-element tensors, so the
    same rule is used when scoring accuracy and when saving predictions.
    NOTE(review): assumes the model emits one probability-like score per
    sample (e.g. a sigmoid output) -- confirm against MyNeuralNetwork.
    """
    return 1 if float(prediction) >= threshold else 0


def evaluate_model(
    model_path: str = "model.pt",
    data_path: str = "gender_classification_test.csv",
    batch_size: int = 32,
) -> Tuple[List[torch.Tensor], float]:
    """Score the saved model on the test split.

    Args:
        model_path: File holding the state dict saved for ``MyNeuralNetwork``.
        data_path: Test CSV passed to ``train.load_data``.
        batch_size: Mini-batch size used while iterating the test set.

    Returns:
        A ``(predictions, accuracy)`` pair: the raw per-sample model outputs
        in dataset order, and the fraction of samples whose thresholded
        prediction equals its label.

    Raises:
        ValueError: If the test dataset yields no samples.
    """
    model = MyNeuralNetwork()
    # load_state_dict copies into the model's existing parameters, so mapping
    # to the CPU only affects the temporary loaded dict; it lets weights saved
    # from a GPU run load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()

    test_dataloader = DataLoader(load_data(data_path), batch_size=batch_size)

    predictions: List[torch.Tensor] = []
    total = 0
    correct = 0
    with torch.no_grad():
        for batch_data, batch_labels in test_dataloader:
            batch_predictions = model(batch_data)
            predicted = [_to_label(p) for p in batch_predictions]
            expected = [int(label) for label in batch_labels.tolist()]
            total += len(predicted)
            correct += sum(p == e for p, e in zip(predicted, expected))
            predictions.extend(batch_predictions)

    if total == 0:
        raise ValueError(f"no test samples found in {data_path!r}")
    return predictions, correct / total


def save_predictions(
    predictions: Sequence[Union[torch.Tensor, float]],
    filename: str = PREDICTIONS_FILE,
) -> None:
    """Write one 0/1 label per prediction to *filename*, overwriting it.

    The CSV has a single ``predict`` header column. Labels use the same
    threshold as ``evaluate_model`` so the file agrees with the reported
    accuracy.
    """
    with open(filename, "w", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(["predict"])
        writer.writerows([_to_label(p)] for p in predictions)


def save_accuracy(accuracy: float, filename: str = ACCURACY_FILE) -> None:
    """Append *accuracy* as one row of the history kept in *filename*.

    An ``accuracy`` header is written first when the file is missing or
    empty; otherwise the value is simply appended.
    """
    needs_header = (not os.path.exists(filename)
                    or os.path.getsize(filename) == 0)
    # newline="" is what the csv module requires; without it rows are
    # separated by blank lines on Windows.
    with open(filename, "a", newline="") as file:
        writer = csv.writer(file)
        if needs_header:
            writer.writerow(["accuracy"])
        writer.writerow([accuracy])


def plot_accuracy(filename: str = ACCURACY_FILE,
                  output: str = "plot.png") -> None:
    """Plot the accuracy history stored in *filename* and save it to *output*.

    The first row is treated as the header and skipped; blank rows are
    ignored. If *filename* does not exist an empty plot is still written.
    """
    accuracy_results: List[float] = []
    if os.path.exists(filename):
        with open(filename, newline="") as file:
            reader = csv.reader(file)
            next(reader, None)  # skip the "accuracy" header
            accuracy_results = [float(row[0]) for row in reader if row]
    # String x values give one categorical tick per run ("1", "2", ...).
    builds = [str(n) for n in range(1, len(accuracy_results) + 1)]

    # Use a dedicated figure and close it afterwards so repeated calls in the
    # same process do not draw on top of each other or leak figures.
    fig, ax = plt.subplots()
    ax.plot(builds, accuracy_results)
    ax.set_xlabel("build")
    ax.set_ylabel("accuracy")
    ax.set_title("Accuracies over builds.")
    fig.savefig(output)
    plt.close(fig)


def main() -> None:
    """Evaluate the model, save its predictions and accuracy, refresh the plot."""
    predictions, accuracy = evaluate_model()
    save_predictions(predictions)
    save_accuracy(accuracy)
    plot_accuracy()


if __name__ == "__main__":
    main()