Compare commits

...

2 Commits

Author SHA1 Message Date
Jakub Henyk
3e53e29514 fix13 2023-05-07 18:31:23 +02:00
Jakub Henyk
7130fa7cdd fix12 2023-05-07 18:31:14 +02:00
2 changed files with 2 additions and 5 deletions

View File

@ -13,7 +13,6 @@ RUN python3 -m pip install pandas
RUN python3 -m pip install numpy RUN python3 -m pip install numpy
RUN python3 -m pip install torch RUN python3 -m pip install torch
RUN python3 -m pip install torchvision RUN python3 -m pip install torchvision
RUN python3 -m pip install regex
COPY ./zadanie1.py ./ COPY ./zadanie1.py ./
COPY ./Customers.csv ./ COPY ./Customers.csv ./

View File

@ -10,7 +10,6 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import sys import sys
import re
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
@ -76,10 +75,9 @@ if __name__ == '__main__':
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
epochs = sys.argv[1] epochs = sys.argv[1]
print(epochs)
epochs_m = re.findall(r'\d+\.\d+', epochs) trainNet(trainloader, criterion, optimizer, int(float(epochs)))
trainNet(trainloader, criterion, optimizer, int(float(epochs_m[0])))
PATH = './cifar_net.pth' PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH) torch.save(net.state_dict(), PATH)