From 3e53e29514aed95b12d434e43eaff71380b1f8e0 Mon Sep 17 00:00:00 2001 From: Jakub Henyk Date: Sun, 7 May 2023 18:31:23 +0200 Subject: [PATCH] fix13 --- train.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index 89c5c9f..8700078 100644 --- a/train.py +++ b/train.py @@ -10,7 +10,6 @@ import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import sys -import re class Net(nn.Module): def __init__(self): @@ -76,10 +75,9 @@ if __name__ == '__main__': optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) epochs = sys.argv[1] + print(epochs) - epochs_m = re.findall(r'\d+\.\d+', epochs) - - trainNet(trainloader, criterion, optimizer, int(float(epochs_m[0]))) + trainNet(trainloader, criterion, optimizer, int(float(epochs))) PATH = './cifar_net.pth' torch.save(net.state_dict(), PATH)