Code reformat

This commit is contained in:
Marcin Kostrzewski 2022-05-06 20:20:22 +02:00
parent d02ebda808
commit 593b17cd2b
2 changed files with 12 additions and 13 deletions

View File

@ -1,8 +1,9 @@
import torch
import sys
from train_model import MLP, PlantsDataset, test
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from train_model import MLP, PlantsDataset, test
def load_model(): def load_model():
model = MLP() model = MLP()

View File

@ -1,17 +1,14 @@
from ast import arg import argparse
from sqlite3 import paramstyle
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import torch import torch
import argparse
from torch import nn from torch import nn
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
default_batch_size = 64 default_batch_size = 64
default_epochs = 5 default_epochs = 5
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
@ -134,5 +131,6 @@ def main():
torch.save(model.state_dict(), './model_out') torch.save(model.state_dict(), './model_out')
print("Model saved in ./model_out file.") print("Model saved in ./model_out file.")
if __name__ == "__main__": if __name__ == "__main__":
main() main()