37 lines
941 B
Python
37 lines
941 B
Python
|
import matplotlib.pyplot as plt
|
||
|
import torch
|
||
|
from torch.utils.data import DataLoader
|
||
|
from neutral_network import MLP, AtpDataset, test
|
||
|
|
||
|
def load_model():
|
||
|
model = MLP()
|
||
|
model.load_state_dict(torch.load('./model.zip'))
|
||
|
return model
|
||
|
|
||
|
def load_dev_dataset(batch_size=64):
|
||
|
atp_dev = AtpDataset('atp_dev.csv')
|
||
|
return DataLoader(atp_dev, batch_size=batch_size)
|
||
|
|
||
|
def make_plot(values):
|
||
|
build_nums = list(range(1, len(values) + 1))
|
||
|
plt.xlabel('Build number')
|
||
|
plt.ylabel('MSE Loss')
|
||
|
plt.plot(build_nums, values, label='Model MSE Loss over builds')
|
||
|
plt.legend()
|
||
|
plt.savefig('plot.png')
|
||
|
|
||
|
model = load_model()
|
||
|
dataloader = load_dev_dataset()
|
||
|
|
||
|
loss_fn = torch.nn.MSELoss()
|
||
|
|
||
|
loss = test(dataloader, model, loss_fn)
|
||
|
with open('eval_result.txt', 'a+') as f:
|
||
|
f.write(f'{str(loss)}\n')
|
||
|
with open('eval_result.txt', 'r') as f:
|
||
|
values = [float(line) for line in f.readlines() if line]
|
||
|
make_plot(values)
|
||
|
|
||
|
|
||
|
|