From 593b17cd2b8fd1e818e4168870e2813c84c67d92 Mon Sep 17 00:00:00 2001 From: Marcin Kostrzewski Date: Fri, 6 May 2022 20:20:22 +0200 Subject: [PATCH] Code reformat --- eval_model.py | 15 ++++++++------- train_model.py | 10 ++++------ 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/eval_model.py b/eval_model.py index f83f923..e496104 100644 --- a/eval_model.py +++ b/eval_model.py @@ -1,8 +1,9 @@ -import torch -import sys -from train_model import MLP, PlantsDataset, test -from torch.utils.data import DataLoader import matplotlib.pyplot as plt +import torch +from torch.utils.data import DataLoader + +from train_model import MLP, PlantsDataset, test + def load_model(): model = MLP() @@ -27,16 +28,16 @@ def make_plot(values): def main(): model = load_model() dataloader = load_dev_dataset() - + loss_fn = torch.nn.MSELoss() loss = test(dataloader, model, loss_fn) with open('evaluation_results.txt', 'a+') as f: f.write(f'{str(loss)}\n') - with open('evaluation_results.txt', 'r') as f: + with open('evaluation_results.txt', 'r') as f: values = [float(line) for line in f.readlines() if line] make_plot(values) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/train_model.py b/train_model.py index cdaedeb..5aa3932 100644 --- a/train_model.py +++ b/train_model.py @@ -1,17 +1,14 @@ -from ast import arg -from sqlite3 import paramstyle +import argparse + import numpy as np import pandas as pd import torch -import argparse from torch import nn from torch.utils.data import DataLoader, Dataset - default_batch_size = 64 default_epochs = 5 - device = "cuda" if torch.cuda.is_available() else "cpu" @@ -134,5 +131,6 @@ def main(): torch.save(model.state_dict(), './model_out') print("Model saved in ./model_out file.") + if __name__ == "__main__": - main() \ No newline at end of file + main()