import torch
import sys
from train_model import MLP, PlantsDataset, test
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

def load_model(path='./model_out'):
    """Build an ``MLP`` and load its saved state dict from *path*.

    Args:
        path: File containing a state dict previously saved for ``MLP``.
            Defaults to ``'./model_out'``.

    Returns:
        The ``MLP`` instance with the loaded weights.
    """
    model = MLP()
    # map_location='cpu' lets a checkpoint written from a GPU run be
    # deserialized on a CPU-only machine; load_state_dict then copies the
    # tensors onto whatever device the freshly built model's parameters use.
    state_dict = torch.load(path, map_location='cpu')
    model.load_state_dict(state_dict)
    # NOTE(review): the model is returned in its default (training) mode;
    # confirm that `test` switches it to eval mode if MLP uses dropout/batchnorm.
    return model


def load_dev_dataset(batch_size=64):
    """Return a DataLoader over the Plant 1 generation dev split.

    Args:
        batch_size: Number of samples per batch yielded by the loader.
    """
    dev_csv_path = 'data/Plant_1_Generation_Data.csv.dev'
    return DataLoader(PlantsDataset(dev_csv_path), batch_size=batch_size)


def make_plot(values, path='trend.png'):
    """Plot one loss value per build and save the chart as an image.

    Args:
        values: Loss values ordered oldest build first; the i-th value is
            drawn at build number i + 1.
        path: Output image file. Defaults to ``'trend.png'``.
    """
    build_nums = list(range(1, len(values) + 1))
    # Draw on a dedicated figure instead of pyplot's implicit current figure,
    # so repeated calls don't stack lines/legends onto the same axes, and
    # close it afterwards so figures are not left open in memory.
    fig, ax = plt.subplots()
    try:
        ax.set_xlabel('Build number')
        ax.set_ylabel('MSE Loss')
        ax.plot(build_nums, values, label='Model MSE Loss over builds')
        ax.legend()
        fig.savefig(path)
    finally:
        plt.close(fig)


def _append_loss_history(path, loss):
    """Append *loss* to the one-float-per-line file at *path*; return all values.

    The file is created if it does not exist yet (e.g. on the first build).
    Blank lines are ignored when reading.

    Returns:
        list[float]: every recorded value, oldest first, ending with *loss*.
    """
    # 'a+' creates the file when missing (unlike 'r+') and always appends on
    # write, so seeking to the start only affects the read below.
    with open(path, 'a+', encoding='utf-8') as f:
        f.seek(0)
        history = f.read()
        # If the last line has no trailing newline, the new value would be
        # glued onto it and fail to parse later.
        if history and not history.endswith('\n'):
            f.write('\n')
        f.write(f'{loss}\n')
    values = [float(line) for line in history.splitlines() if line.strip()]
    values.append(loss)
    return values


def main():
    """Evaluate the saved model on the dev set and update the loss trend.

    Appends the dev-set MSE loss to ``evaluation_results.txt`` and re-plots
    the full history of recorded losses.
    """
    model = load_model()
    dataloader = load_dev_dataset()
    loss_fn = torch.nn.MSELoss()

    # NOTE: assumes `test` returns a scalar loss. float() accepts both a Python
    # number and a one-element tensor; str() of a tensor would write
    # 'tensor(...)', which cannot be parsed back as a float.
    loss = float(test(dataloader, model, loss_fn))
    values = _append_loss_history('evaluation_results.txt', loss)
    make_plot(values)


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()