From 4660aed907989c8c769b68ba78a3ea0cc3999a7a Mon Sep 17 00:00:00 2001 From: Jakub Henyk Date: Sun, 7 May 2023 18:29:20 +0200 Subject: [PATCH] fix11 --- Dockerfile | 1 + train.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index edd189c..f299131 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,6 +13,7 @@ RUN python3 -m pip install pandas RUN python3 -m pip install numpy RUN python3 -m pip install torch RUN python3 -m pip install torchvision +RUN python3 -m pip install regex COPY ./zadanie1.py ./ COPY ./Customers.csv ./ diff --git a/train.py b/train.py index 216fce5..89c5c9f 100644 --- a/train.py +++ b/train.py @@ -10,6 +10,7 @@ import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import sys +import re class Net(nn.Module): def __init__(self): @@ -76,7 +77,9 @@ if __name__ == '__main__': epochs = sys.argv[1] - trainNet(trainloader, criterion, optimizer, int(float(epochs))) + epochs_m = re.findall(r'\d+\.\d+', epochs) + + trainNet(trainloader, criterion, optimizer, int(float(epochs_m[0]))) PATH = './cifar_net.pth' torch.save(net.state_dict(), PATH)