Compare commits

..

No commits in common. "3e53e29514aed95b12d434e43eaff71380b1f8e0" and "4660aed907989c8c769b68ba78a3ea0cc3999a7a" have entirely different histories.

2 changed files with 5 additions and 2 deletions

View File

@ -13,6 +13,7 @@ RUN python3 -m pip install pandas
RUN python3 -m pip install numpy
RUN python3 -m pip install torch
RUN python3 -m pip install torchvision
RUN python3 -m pip install regex
COPY ./zadanie1.py ./
COPY ./Customers.csv ./

View File

@ -10,6 +10,7 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import sys
import re
class Net(nn.Module):
def __init__(self):
@ -75,9 +76,10 @@ if __name__ == '__main__':
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
epochs = sys.argv[1]
print(epochs)
trainNet(trainloader, criterion, optimizer, int(float(epochs)))
epochs_m = re.findall(r'\d+\.\d+', epochs)
trainNet(trainloader, criterion, optimizer, int(float(epochs_m[0])))
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)