diff --git a/.gitignore b/.gitignore index 77347e4..680c817 100644 --- a/.gitignore +++ b/.gitignore @@ -18,4 +18,4 @@ __pycache__ evaluation_results.txt model_out - +trend.png diff --git a/eval_model.py b/eval_model.py index a14eea4..fef402f 100644 --- a/eval_model.py +++ b/eval_model.py @@ -2,7 +2,7 @@ import torch import sys from train_model import MLP, PlantsDataset, test from torch.utils.data import DataLoader - +import matplotlib.pyplot as plt def load_model(): model = MLP() @@ -15,6 +15,15 @@ def load_dev_dataset(batch_size=64): return DataLoader(plant_dev, batch_size=batch_size) +def make_plot(values): + build_nums = list(range(1, len(values) + 1)) + plt.xlabel('Build number') + plt.ylabel('MSE Loss') + plt.plot(build_nums, values, label='Model MSE Loss over builds') + plt.legend() + plt.savefig('trend.png') + + def main(): model = load_model() dataloader = load_dev_dataset() @@ -22,8 +31,13 @@ def main(): loss_fn = torch.nn.MSELoss() loss = test(dataloader, model, loss_fn) - with open('evaluation_results.txt', 'a+') as f: + with open('evaluation_results.txt', 'r+') as f: + f.read() f.write(f'{str(loss)}\n') + f.flush() + f.seek(0) + values = [float(line) for line in f.readlines() if line] + make_plot(values) if __name__ == "__main__": diff --git a/requirements.txt b/requirements.txt index 08cc030..a1e279b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ kaggle==1.5.12 pandas==1.4.1 torch==1.11.0 -numpy~=1.22.3 \ No newline at end of file +numpy~=1.22.3 +matplotlib==3.5.2 \ No newline at end of file