Saving trend figure
This commit is contained in:
parent
abf429ee75
commit
435e0e8dad
2
.gitignore
vendored
2
.gitignore
vendored
@ -18,4 +18,4 @@ __pycache__
|
|||||||
|
|
||||||
evaluation_results.txt
|
evaluation_results.txt
|
||||||
model_out
|
model_out
|
||||||
|
trend.png
|
||||||
|
@ -2,7 +2,7 @@ import torch
|
|||||||
import sys
|
import sys
|
||||||
from train_model import MLP, PlantsDataset, test
|
from train_model import MLP, PlantsDataset, test
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
def load_model():
|
def load_model():
|
||||||
model = MLP()
|
model = MLP()
|
||||||
@ -15,6 +15,15 @@ def load_dev_dataset(batch_size=64):
|
|||||||
return DataLoader(plant_dev, batch_size=batch_size)
|
return DataLoader(plant_dev, batch_size=batch_size)
|
||||||
|
|
||||||
|
|
||||||
|
def make_plot(values):
|
||||||
|
build_nums = list(range(1, len(values) + 1))
|
||||||
|
plt.xlabel('Build number')
|
||||||
|
plt.ylabel('MSE Loss')
|
||||||
|
plt.plot(build_nums, values, label='Model MSE Loss over builds')
|
||||||
|
plt.legend()
|
||||||
|
plt.savefig('trend.png')
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
model = load_model()
|
model = load_model()
|
||||||
dataloader = load_dev_dataset()
|
dataloader = load_dev_dataset()
|
||||||
@ -22,8 +31,13 @@ def main():
|
|||||||
loss_fn = torch.nn.MSELoss()
|
loss_fn = torch.nn.MSELoss()
|
||||||
|
|
||||||
loss = test(dataloader, model, loss_fn)
|
loss = test(dataloader, model, loss_fn)
|
||||||
with open('evaluation_results.txt', 'a+') as f:
|
with open('evaluation_results.txt', 'r+') as f:
|
||||||
|
f.read()
|
||||||
f.write(f'{str(loss)}\n')
|
f.write(f'{str(loss)}\n')
|
||||||
|
f.flush()
|
||||||
|
f.seek(0)
|
||||||
|
values = [float(line) for line in f.readlines() if line]
|
||||||
|
make_plot(values)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -2,3 +2,4 @@ kaggle==1.5.12
|
|||||||
pandas==1.4.1
|
pandas==1.4.1
|
||||||
torch==1.11.0
|
torch==1.11.0
|
||||||
numpy~=1.22.3
|
numpy~=1.22.3
|
||||||
|
matplotlib==3.5.2
|
Loading…
Reference in New Issue
Block a user