From 28585de7c3f3e3a07e0959b22fe52e620684aea0 Mon Sep 17 00:00:00 2001 From: Marcin Kostrzewski Date: Thu, 5 May 2022 22:33:52 +0200 Subject: [PATCH] Add model evaluation script --- eval_model.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 eval_model.py diff --git a/eval_model.py b/eval_model.py new file mode 100644 index 0000000..cac9a81 --- /dev/null +++ b/eval_model.py @@ -0,0 +1,31 @@ +import torch +import sys +from train_model import MLP, PlantsDataset, test +from torch.utils.data import DataLoader +from contextlib import redirect_stdout + + +def load_model(): + model = MLP() + model.load_state_dict(torch.load('./model_out')) + return model + + +def load_dev_dataset(batch_size=64): + plant_dev = PlantsDataset('data/Plant_1_Generation_Data.csv.dev') + return DataLoader(plant_dev, batch_size=batch_size) + + +def main(): + model = load_model() + dataloader = load_dev_dataset() + + loss_fn = torch.nn.MSELoss() + + with open('evaluation_results.txt', 'w') as f: + with redirect_stdout(f): + test(dataloader, model, loss_fn) + + +if __name__ == "__main__": + main() \ No newline at end of file