ium_z487179/ML/model_test.py

31 lines
1019 B
Python
Raw Normal View History

2023-06-10 22:49:15 +02:00
import torch
from model_train import MyNeuralNetwork, load_data
from torch.utils.data import DataLoader
import csv
def main(
    model_path: str = "model.pt",
    data_path: str = "home_loan_test.csv",
    output_path: str = "results.csv",
    *,
    batch_size: int = 32,
    threshold: float = 0.5,
) -> None:
    """Score the test set with the trained model and write 0/1 decisions to CSV.

    Loads ``MyNeuralNetwork`` weights from *model_path*, runs the model over
    every sample yielded by ``load_data(data_path)``, and writes *output_path*
    as a single ``predict`` column: ``1`` when a model output is strictly
    greater than *threshold*, otherwise ``0`` (one row per sample, in dataset
    order).

    Args:
        model_path: Path to a state dict saved for ``MyNeuralNetwork``.
        data_path: File passed to ``load_data`` to build the test dataset.
        output_path: Destination CSV file; overwritten if it already exists.
        batch_size: Number of samples scored per forward pass.
        threshold: Decision cut-off applied to each model output.

    Raises:
        ValueError: (from ``Tensor.item``) if a model output row holds more
            than one element.
    """
    model = MyNeuralNetwork()
    # The DataLoader below yields CPU tensors, so the weights must be on the
    # CPU as well; map_location also lets a checkpoint saved from a GPU run
    # load on a CPU-only machine. NOTE: torch.load unpickles the file — only
    # load checkpoints from a trusted source.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()

    test_dataloader = DataLoader(load_data(data_path), batch_size=batch_size)

    # Convert to plain ints per batch instead of keeping every output tensor
    # alive until the end of the run.
    decisions = []
    with torch.no_grad():
        # The dataset yields (data, label) pairs; labels are not needed to
        # produce predictions.
        for batch_data, _ in test_dataloader:
            batch_predictions = model(batch_data)
            # .item() requires each row to contain exactly one element.
            decisions.extend(
                1 if row.item() > threshold else 0 for row in batch_predictions
            )

    with open(output_path, "w", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        writer.writerow(["predict"])
        writer.writerows([decision] for decision in decisions)


if __name__ == "__main__":
    main()