Compare commits
2 Commits
4660aed907
...
3e53e29514
Author | SHA1 | Date | |
---|---|---|---|
|
3e53e29514 | ||
|
7130fa7cdd |
@ -13,7 +13,6 @@ RUN python3 -m pip install pandas
|
|||||||
RUN python3 -m pip install numpy
|
RUN python3 -m pip install numpy
|
||||||
RUN python3 -m pip install torch
|
RUN python3 -m pip install torch
|
||||||
RUN python3 -m pip install torchvision
|
RUN python3 -m pip install torchvision
|
||||||
RUN python3 -m pip install regex
|
|
||||||
|
|
||||||
COPY ./zadanie1.py ./
|
COPY ./zadanie1.py ./
|
||||||
COPY ./Customers.csv ./
|
COPY ./Customers.csv ./
|
||||||
|
6
train.py
6
train.py
@ -10,7 +10,6 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import sys
|
import sys
|
||||||
import re
|
|
||||||
|
|
||||||
class Net(nn.Module):
|
class Net(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -76,10 +75,9 @@ if __name__ == '__main__':
|
|||||||
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
|
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
|
||||||
|
|
||||||
epochs = sys.argv[1]
|
epochs = sys.argv[1]
|
||||||
|
print(epochs)
|
||||||
|
|
||||||
epochs_m = re.findall(r'\d+\.\d+', epochs)
|
trainNet(trainloader, criterion, optimizer, int(float(epochs)))
|
||||||
|
|
||||||
trainNet(trainloader, criterion, optimizer, int(float(epochs_m[0])))
|
|
||||||
|
|
||||||
PATH = './cifar_net.pth'
|
PATH = './cifar_net.pth'
|
||||||
torch.save(net.state_dict(), PATH)
|
torch.save(net.state_dict(), PATH)
|
||||||
|
Loading…
Reference in New Issue
Block a user