Saving trend figure
This commit is contained in:
parent
abf429ee75
commit
435e0e8dad
2
.gitignore
vendored
2
.gitignore
vendored
@ -18,4 +18,4 @@ __pycache__
|
||||
|
||||
evaluation_results.txt
|
||||
model_out
|
||||
|
||||
trend.png
|
||||
|
@ -2,7 +2,7 @@ import torch
|
||||
import sys
|
||||
from train_model import MLP, PlantsDataset, test
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def load_model():
|
||||
model = MLP()
|
||||
@ -15,6 +15,15 @@ def load_dev_dataset(batch_size=64):
|
||||
return DataLoader(plant_dev, batch_size=batch_size)
|
||||
|
||||
|
||||
def make_plot(values):
|
||||
build_nums = list(range(1, len(values) + 1))
|
||||
plt.xlabel('Build number')
|
||||
plt.ylabel('MSE Loss')
|
||||
plt.plot(build_nums, values, label='Model MSE Loss over builds')
|
||||
plt.legend()
|
||||
plt.savefig('trend.png')
|
||||
|
||||
|
||||
def main():
|
||||
model = load_model()
|
||||
dataloader = load_dev_dataset()
|
||||
@ -22,8 +31,13 @@ def main():
|
||||
loss_fn = torch.nn.MSELoss()
|
||||
|
||||
loss = test(dataloader, model, loss_fn)
|
||||
with open('evaluation_results.txt', 'a+') as f:
|
||||
with open('evaluation_results.txt', 'r+') as f:
|
||||
f.read()
|
||||
f.write(f'{str(loss)}\n')
|
||||
f.flush()
|
||||
f.seek(0)
|
||||
values = [float(line) for line in f.readlines() if line]
|
||||
make_plot(values)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -2,3 +2,4 @@ kaggle==1.5.12
|
||||
pandas==1.4.1
|
||||
torch==1.11.0
|
||||
numpy~=1.22.3
|
||||
matplotlib==3.5.2
|
Loading…
Reference in New Issue
Block a user