diff --git a/eval_model.py b/eval_model.py new file mode 100644 index 0000000..cac9a81 --- /dev/null +++ b/eval_model.py @@ -0,0 +1,31 @@ +import torch +import sys +from train_model import MLP, PlantsDataset, test +from torch.utils.data import DataLoader +from contextlib import redirect_stdout + + +def load_model(): + model = MLP() + model.load_state_dict(torch.load('./model_out')) + return model + + +def load_dev_dataset(batch_size=64): + plant_dev = PlantsDataset('data/Plant_1_Generation_Data.csv.dev') + return DataLoader(plant_dev, batch_size=batch_size) + + +def main(): + model = load_model() + dataloader = load_dev_dataset() + + loss_fn = torch.nn.MSELoss() + + with open('evaluation_results.txt', 'w') as f: + with redirect_stdout(f): + test(dataloader, model, loss_fn) + + +if __name__ == "__main__": + main() \ No newline at end of file