Code reformat
This commit is contained in:
parent
d02ebda808
commit
593b17cd2b
@ -1,8 +1,9 @@
|
|||||||
import torch
|
|
||||||
import sys
|
|
||||||
from train_model import MLP, PlantsDataset, test
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from train_model import MLP, PlantsDataset, test
|
||||||
|
|
||||||
|
|
||||||
def load_model():
|
def load_model():
|
||||||
model = MLP()
|
model = MLP()
|
||||||
|
@ -1,17 +1,14 @@
|
|||||||
from ast import arg
|
import argparse
|
||||||
from sqlite3 import paramstyle
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
import argparse
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
|
||||||
default_batch_size = 64
|
default_batch_size = 64
|
||||||
default_epochs = 5
|
default_epochs = 5
|
||||||
|
|
||||||
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
@ -134,5 +131,6 @@ def main():
|
|||||||
torch.save(model.state_dict(), './model_out')
|
torch.save(model.state_dict(), './model_out')
|
||||||
print("Model saved in ./model_out file.")
|
print("Model saved in ./model_out file.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
Loading…
Reference in New Issue
Block a user