modified train.py and Dockerfile

This commit is contained in:
Jakub Henyk 2023-05-07 18:22:23 +02:00
parent 5cb4a5040a
commit 3e2e907dd6
2 changed files with 7 additions and 5 deletions

View File

@ -28,5 +28,5 @@ RUN chmod +x ./train.py
ARG epochs=5 ARG epochs=5
RUN echo $epochs RUN echo $epochs
CMD python3 ./train.py CMD python3 ./train.py epochs
#CMD python3 ./test.py #CMD python3 ./test.py

View File

@ -9,7 +9,7 @@ import torchvision.transforms as transforms
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import sys
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
@ -31,8 +31,8 @@ class Net(nn.Module):
return x return x
def trainNet(trainloader, criterion, optimizer): def trainNet(trainloader, criterion, optimizer, epochs=20):
for epoch in range(20): for epoch in range(epochs):
for i, data in enumerate(trainloader, 0): for i, data in enumerate(trainloader, 0):
inputs, labels = data inputs, labels = data
@ -74,7 +74,9 @@ if __name__ == '__main__':
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
trainNet(trainloader, criterion, optimizer) epochs = sys.argv[1]
trainNet(trainloader, criterion, optimizer, epochs)
PATH = './cifar_net.pth' PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH) torch.save(net.state_dict(), PATH)