Saving trend figure

This commit is contained in:
Marcin Kostrzewski 2022-05-05 22:51:30 +02:00
parent abf429ee75
commit 435e0e8dad
3 changed files with 19 additions and 4 deletions

2
.gitignore vendored
View File

@ -18,4 +18,4 @@ __pycache__
evaluation_results.txt evaluation_results.txt
model_out model_out
trend.png

View File

@ -2,7 +2,7 @@ import torch
import sys import sys
from train_model import MLP, PlantsDataset, test from train_model import MLP, PlantsDataset, test
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
def load_model(): def load_model():
model = MLP() model = MLP()
@ -15,6 +15,15 @@ def load_dev_dataset(batch_size=64):
return DataLoader(plant_dev, batch_size=batch_size) return DataLoader(plant_dev, batch_size=batch_size)
def make_plot(values):
build_nums = list(range(1, len(values) + 1))
plt.xlabel('Build number')
plt.ylabel('MSE Loss')
plt.plot(build_nums, values, label='Model MSE Loss over builds')
plt.legend()
plt.savefig('trend.png')
def main(): def main():
model = load_model() model = load_model()
dataloader = load_dev_dataset() dataloader = load_dev_dataset()
@ -22,8 +31,13 @@ def main():
loss_fn = torch.nn.MSELoss() loss_fn = torch.nn.MSELoss()
loss = test(dataloader, model, loss_fn) loss = test(dataloader, model, loss_fn)
with open('evaluation_results.txt', 'a+') as f: with open('evaluation_results.txt', 'r+') as f:
f.read()
f.write(f'{str(loss)}\n') f.write(f'{str(loss)}\n')
f.flush()
f.seek(0)
values = [float(line) for line in f.readlines() if line]
make_plot(values)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,4 +1,5 @@
kaggle==1.5.12 kaggle==1.5.12
pandas==1.4.1 pandas==1.4.1
torch==1.11.0 torch==1.11.0
numpy~=1.22.3 numpy~=1.22.3
matplotlib==3.5.2