30 lines
668 B
Python
30 lines
668 B
Python
import torch
|
|
import sys
|
|
from train_model import MLP, PlantsDataset, test
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
def load_model(path='./model_out', map_location='cpu'):
    """Build an ``MLP`` and load saved weights into it.

    Args:
        path: File containing the state_dict saved for ``MLP``. Defaults to
            ``./model_out``, the location this script has always read.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so a
            checkpoint saved from GPU tensors can still be loaded on a
            machine without CUDA; ``load_state_dict`` then copies the values
            into the model's own parameters.

    Returns:
        The ``MLP`` instance with the loaded weights.
    """
    model = MLP()
    # torch.load unpickles the file -- only load checkpoints you trust.
    state_dict = torch.load(path, map_location=map_location)
    model.load_state_dict(state_dict)
    return model
|
|
|
|
|
|
def load_dev_dataset(batch_size=64, path='data/Plant_1_Generation_Data.csv.dev'):
    """Return a DataLoader over the plant generation dev split.

    Args:
        batch_size: Number of samples per batch.
        path: CSV file handed to ``PlantsDataset``. Defaults to the dev
            split path this script has always used.

    Returns:
        A ``DataLoader`` wrapping ``PlantsDataset(path)``; DataLoader's
        default ``shuffle=False`` keeps batches in dataset order.
    """
    plant_dev = PlantsDataset(path)
    return DataLoader(plant_dev, batch_size=batch_size)
|
|
|
|
|
|
def main(results_path='evaluation_results.txt'):
    """Evaluate the saved model on the dev split and record the loss.

    The value returned by ``test`` (called with an MSE loss) is appended to
    ``results_path`` as one line per run, so the file accumulates a history
    of evaluations.

    Args:
        results_path: Text file to append the loss to.

    Returns:
        The value returned by ``test`` (the same value written to the file).
    """
    model = load_model()
    # Put layers such as dropout/batch-norm into inference behaviour; this
    # is a no-op if ``test`` already switches the model to eval mode.
    model.eval()
    dataloader = load_dev_dataset()

    loss_fn = torch.nn.MSELoss()
    loss = test(dataloader, model, loss_fn)

    # Append-only: the file is never read here, so plain 'a' (not 'a+').
    # '!s' keeps the exact str(loss) text the original wrote.
    with open(results_path, 'a', encoding='utf-8') as f:
        f.write(f'{loss!s}\n')
    return loss
|
|
|
|
|
|
if __name__ == "__main__":
    # Only run an evaluation when executed as a script; importing this
    # module (e.g. to reuse load_model) has no side effects.
    main()