"""Evaluate the trained MLP on the Plant 1 generation dev split.

Loads the weights saved at ``MODEL_PATH``, runs ``train_model.test`` over the
dev CSV with an MSE loss, and captures everything ``test`` prints into
``RESULTS_PATH``.
"""
import sys
from contextlib import redirect_stdout

import torch
from torch.utils.data import DataLoader

from train_model import MLP, PlantsDataset, test

MODEL_PATH = './model_out'
DEV_DATA_PATH = 'data/Plant_1_Generation_Data.csv.dev'
RESULTS_PATH = 'evaluation_results.txt'


def load_model(path=MODEL_PATH):
    """Build an ``MLP`` and load its saved ``state_dict`` from *path*.

    Args:
        path: File written by ``torch.save(model.state_dict(), ...)``.

    Returns:
        The ``MLP`` instance with the loaded weights, in eval mode.
    """
    model = MLP()
    # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
    # machine without CUDA. load_state_dict copies the values into the
    # model's existing parameters, so the model's own device is unchanged.
    state_dict = torch.load(path, map_location='cpu')
    model.load_state_dict(state_dict)
    # Evaluation must not run layers such as dropout / batch-norm (if MLP
    # has any) in training mode; harmless if `test` also calls eval().
    model.eval()
    return model


def load_dev_dataset(batch_size=64, path=DEV_DATA_PATH):
    """Wrap the dev-split ``PlantsDataset`` at *path* in a ``DataLoader``.

    Args:
        batch_size: Number of samples per batch.
        path: CSV file passed to ``PlantsDataset``.

    Returns:
        A ``DataLoader`` over the dev set (default ordering, no shuffling).
    """
    plant_dev = PlantsDataset(path)
    return DataLoader(plant_dev, batch_size=batch_size)


def main(output_path=RESULTS_PATH):
    """Evaluate the saved model on the dev set and write the report.

    Everything ``test`` prints to stdout is redirected into *output_path*
    (overwritten on each run).
    """
    model = load_model()
    dataloader = load_dev_dataset()
    loss_fn = torch.nn.MSELoss()
    # Explicit encoding so the report does not depend on the platform locale.
    with open(output_path, 'w', encoding='utf-8') as f:
        with redirect_stdout(f):
            test(dataloader, model, loss_fn)


if __name__ == "__main__":
    main()