Add model evaluation script
This commit is contained in:
parent
fe63ef269c
commit
28585de7c3
31
eval_model.py
Normal file
31
eval_model.py
Normal file
@ -0,0 +1,31 @@
|
||||
import torch
|
||||
import sys
|
||||
from train_model import MLP, PlantsDataset, test
|
||||
from torch.utils.data import DataLoader
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
|
||||
def load_model():
|
||||
model = MLP()
|
||||
model.load_state_dict(torch.load('./model_out'))
|
||||
return model
|
||||
|
||||
|
||||
def load_dev_dataset(batch_size=64):
|
||||
plant_dev = PlantsDataset('data/Plant_1_Generation_Data.csv.dev')
|
||||
return DataLoader(plant_dev, batch_size=batch_size)
|
||||
|
||||
|
||||
def main():
|
||||
model = load_model()
|
||||
dataloader = load_dev_dataset()
|
||||
|
||||
loss_fn = torch.nn.MSELoss()
|
||||
|
||||
with open('evaluation_results.txt', 'w') as f:
|
||||
with redirect_stdout(f):
|
||||
test(dataloader, model, loss_fn)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue
Block a user