fix11
This commit is contained in:
parent
a0dc9ef9aa
commit
4660aed907
@ -13,6 +13,7 @@ RUN python3 -m pip install pandas
|
||||
RUN python3 -m pip install numpy
|
||||
RUN python3 -m pip install torch
|
||||
RUN python3 -m pip install torchvision
|
||||
RUN python3 -m pip install regex
|
||||
|
||||
COPY ./zadanie1.py ./
|
||||
COPY ./Customers.csv ./
|
||||
|
5
train.py
5
train.py
@ -10,6 +10,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
import sys
|
||||
import re
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
@ -76,7 +77,9 @@ if __name__ == '__main__':
|
||||
|
||||
epochs = sys.argv[1]
|
||||
|
||||
trainNet(trainloader, criterion, optimizer, int(float(epochs)))
|
||||
epochs_m = re.findall(r'\d+\.\d+', epochs)
|
||||
|
||||
trainNet(trainloader, criterion, optimizer, int(float(epochs_m[0])))
|
||||
|
||||
PATH = './cifar_net.pth'
|
||||
torch.save(net.state_dict(), PATH)
|
||||
|
Loading…
Reference in New Issue
Block a user