Code reformat
This commit is contained in:
parent
d02ebda808
commit
593b17cd2b
@ -1,8 +1,9 @@
|
||||
import torch
|
||||
import sys
|
||||
from train_model import MLP, PlantsDataset, test
|
||||
from torch.utils.data import DataLoader
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from train_model import MLP, PlantsDataset, test
|
||||
|
||||
|
||||
def load_model():
|
||||
model = MLP()
|
||||
|
@ -1,17 +1,14 @@
|
||||
from ast import arg
|
||||
from sqlite3 import paramstyle
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import argparse
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
|
||||
default_batch_size = 64
|
||||
default_epochs = 5
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
@ -134,5 +131,6 @@ def main():
|
||||
torch.save(model.state_dict(), './model_out')
|
||||
print("Model saved in ./model_out file.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue
Block a user