Code reformat

This commit is contained in:
Marcin Kostrzewski 2022-05-06 20:20:22 +02:00
parent d02ebda808
commit 593b17cd2b
2 changed files with 12 additions and 13 deletions

View File

@ -1,8 +1,9 @@
import torch
import sys
from train_model import MLP, PlantsDataset, test
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from train_model import MLP, PlantsDataset, test
def load_model():
model = MLP()
@ -27,16 +28,16 @@ def make_plot(values):
def main():
model = load_model()
dataloader = load_dev_dataset()
loss_fn = torch.nn.MSELoss()
loss = test(dataloader, model, loss_fn)
with open('evaluation_results.txt', 'a+') as f:
f.write(f'{str(loss)}\n')
with open('evaluation_results.txt', 'r') as f:
with open('evaluation_results.txt', 'r') as f:
values = [float(line) for line in f.readlines() if line]
make_plot(values)
if __name__ == "__main__":
main()
main()

View File

@ -1,17 +1,14 @@
from ast import arg
from sqlite3 import paramstyle
import argparse
import numpy as np
import pandas as pd
import torch
import argparse
from torch import nn
from torch.utils.data import DataLoader, Dataset
default_batch_size = 64
default_epochs = 5
device = "cuda" if torch.cuda.is_available() else "cpu"
@ -134,5 +131,6 @@ def main():
torch.save(model.state_dict(), './model_out')
print("Model saved in ./model_out file.")
if __name__ == "__main__":
main()
main()