Code reformat
This commit is contained in:
parent
d02ebda808
commit
593b17cd2b
@ -1,8 +1,9 @@
|
||||
import torch
|
||||
import sys
|
||||
from train_model import MLP, PlantsDataset, test
|
||||
from torch.utils.data import DataLoader
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from train_model import MLP, PlantsDataset, test
|
||||
|
||||
|
||||
def load_model():
|
||||
model = MLP()
|
||||
@ -27,16 +28,16 @@ def make_plot(values):
|
||||
def main():
|
||||
model = load_model()
|
||||
dataloader = load_dev_dataset()
|
||||
|
||||
|
||||
loss_fn = torch.nn.MSELoss()
|
||||
|
||||
loss = test(dataloader, model, loss_fn)
|
||||
with open('evaluation_results.txt', 'a+') as f:
|
||||
f.write(f'{str(loss)}\n')
|
||||
with open('evaluation_results.txt', 'r') as f:
|
||||
with open('evaluation_results.txt', 'r') as f:
|
||||
values = [float(line) for line in f.readlines() if line]
|
||||
make_plot(values)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
@ -1,17 +1,14 @@
|
||||
from ast import arg
|
||||
from sqlite3 import paramstyle
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import argparse
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
|
||||
default_batch_size = 64
|
||||
default_epochs = 5
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
@ -134,5 +131,6 @@ def main():
|
||||
torch.save(model.state_dict(), './model_out')
|
||||
print("Model saved in ./model_out file.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
Loading…
Reference in New Issue
Block a user