31 lines
1019 B
Python
31 lines
1019 B
Python
import torch
|
|
from model_train import MyNeuralNetwork, load_data
|
|
from torch.utils.data import DataLoader
|
|
import csv
|
|
|
|
|
|
def main(
    model_path: str = "model.pt",
    data_path: str = "home_loan_test.csv",
    output_path: str = "results.csv",
    batch_size: int = 32,
    threshold: float = 0.5,
) -> None:
    """Run the trained model over the test set and write 0/1 decisions to CSV.

    The steps are:

    1. Load a ``MyNeuralNetwork`` state dict from *model_path*.
    2. Build the test dataset with ``load_data(data_path)``.
    3. Run inference batch by batch without gradient tracking.
    4. Write one row per sample to *output_path* under a single ``predict``
       header.

    A sample is written as ``1`` when its model output is strictly greater
    than *threshold*, otherwise ``0``. Calling ``main()`` with no arguments
    reproduces the original hard-coded behaviour.

    Args:
        model_path: Path of the saved ``state_dict`` to load.
        data_path: File passed to ``load_data`` to build the test dataset.
        output_path: Destination CSV file; it is overwritten if it exists.
        batch_size: Number of samples per inference batch.
        threshold: Cut-off applied to each sample's model output.

    Raises:
        RuntimeError: From ``Tensor.item()`` if the model produces more
            than one value per sample.
    """
    model = _load_model(model_path)
    decisions = _predict_decisions(model, load_data(data_path), batch_size, threshold)
    _write_results(decisions, output_path)


def _load_model(model_path: str) -> torch.nn.Module:
    """Build a MyNeuralNetwork, load weights from *model_path*, and set eval mode."""
    model = MyNeuralNetwork()
    # map_location="cpu" lets a checkpoint saved from a GPU run load on a
    # CPU-only machine. This script never moves the model to another device.
    # NOTE(review): torch.load unpickles the file. If model.pt can come from
    # an untrusted source, pass weights_only=True (available in torch>=1.13).
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


def _predict_decisions(
    model: torch.nn.Module,
    dataset,
    batch_size: int,
    threshold: float,
) -> list:
    """Return one 0/1 decision per sample of *dataset*, in dataset order."""
    decisions = []
    # DataLoader does not shuffle by default, so output order matches input order.
    loader = DataLoader(dataset, batch_size=batch_size)
    with torch.no_grad():
        # The loader also yields labels, but only the predictions are written out.
        for batch_data, _ in loader:
            for output in model(batch_data):
                # Each row must hold a single score; .item() raises otherwise.
                # NOTE(review): a 0.5 cut-off suits a sigmoid/probability
                # output. If MyNeuralNetwork returns raw logits, the matching
                # cut-off is 0. Confirm against the model definition.
                decisions.append(1 if output.item() > threshold else 0)
    return decisions


def _write_results(decisions, output_path: str) -> None:
    """Write *decisions* to *output_path* as a one-column CSV headed ``predict``."""
    with open(output_path, "w", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        writer.writerow(["predict"])
        writer.writerows([decision] for decision in decisions)
|
|
|
|
# Run inference only when executed as a script, not when imported.
if __name__ == "__main__":
    main()