ium_444409/eval_model.py

44 lines
1.0 KiB
Python
Raw Permalink Normal View History

2022-05-06 20:20:22 +02:00
import matplotlib.pyplot as plt
2022-05-05 22:33:52 +02:00
import torch
from torch.utils.data import DataLoader
2022-05-06 20:20:22 +02:00
from model import MLP, PlantsDataset, test
2022-05-06 20:20:22 +02:00
2022-05-05 22:33:52 +02:00
def load_model(path='./model_out', map_location='cpu'):
    """Build an ``MLP`` and load trained weights into it.

    Args:
        path: File holding the saved state dict. Defaults to
            ``'./model_out'``.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so
            that a checkpoint saved from a CUDA device can still be read on a
            CPU-only machine. ``load_state_dict`` copies the loaded tensors
            into the freshly built model's parameters, so the resulting model
            is the same either way.

    Returns:
        The ``MLP`` instance with the loaded weights.
    """
    model = MLP()
    model.load_state_dict(torch.load(path, map_location=map_location))
    return model
def load_dev_dataset(batch_size=64):
    """Return a DataLoader over the dev split of the plant generation data.

    Args:
        batch_size: Number of samples per batch yielded by the loader.
    """
    dev_split = PlantsDataset('data/Plant_1_Generation_Data.csv.dev')
    loader = DataLoader(dev_split, batch_size=batch_size)
    return loader
2022-05-05 22:51:30 +02:00
def make_plot(values, path='trend.png'):
    """Plot the recorded MSE losses against build number and save the image.

    The i-th value (0-based) is plotted at build number i + 1.

    Args:
        values: Sized sequence of loss values, one per build, in order.
        path: Output image file. Defaults to ``'trend.png'``.
    """
    build_nums = list(range(1, len(values) + 1))
    # Draw on a dedicated figure and always close it. Relying on pyplot's
    # implicit "current figure" would make repeated calls in one process
    # stack every curve (and duplicate legend entries) into the same image
    # and leave figures open.
    fig, ax = plt.subplots()
    try:
        ax.set_xlabel('Build number')
        ax.set_ylabel('MSE Loss')
        ax.plot(build_nums, values, label='Model MSE Loss over builds')
        ax.legend()
        fig.savefig(path)
    finally:
        plt.close(fig)
2022-05-05 22:33:52 +02:00
def main():
    """Evaluate the saved model on the dev split and refresh the loss trend.

    Computes the dev-set MSE loss via ``test``, appends it as a new line of
    ``evaluation_results.txt``, then reads every recorded loss back from that
    file and redraws the trend plot with ``make_plot``.
    """
    model = load_model()
    dataloader = load_dev_dataset()

    loss_fn = torch.nn.MSELoss()
    loss = test(dataloader, model, loss_fn)

    results_path = 'evaluation_results.txt'
    # Append mode creates the file on the first run and keeps earlier results.
    with open(results_path, 'a', encoding='utf-8') as f:
        f.write(f'{loss!s}\n')

    with open(results_path, 'r', encoding='utf-8') as f:
        # Skip blank / whitespace-only lines. A bare '\n' is truthy, so
        # `if line` alone would not skip it, and float('\n') raises ValueError.
        values = [float(line) for line in f if line.strip()]
    make_plot(values)
2022-05-05 22:33:52 +02:00
if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not when imported.
    main()