Add model evaluation script

This commit is contained in:
Marcin Kostrzewski 2022-05-05 22:33:52 +02:00
parent fe63ef269c
commit 28585de7c3

31
eval_model.py Normal file
View File

@ -0,0 +1,31 @@
import torch
import sys
from train_model import MLP, PlantsDataset, test
from torch.utils.data import DataLoader
from contextlib import redirect_stdout
def load_model():
model = MLP()
model.load_state_dict(torch.load('./model_out'))
return model
def load_dev_dataset(batch_size=64):
plant_dev = PlantsDataset('data/Plant_1_Generation_Data.csv.dev')
return DataLoader(plant_dev, batch_size=batch_size)
def main():
model = load_model()
dataloader = load_dev_dataset()
loss_fn = torch.nn.MSELoss()
with open('evaluation_results.txt', 'w') as f:
with redirect_stdout(f):
test(dataloader, model, loss_fn)
if __name__ == "__main__":
main()