diff --git a/Dockerfile b/Dockerfile index 818d24c..edd189c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,5 +28,5 @@ RUN chmod +x ./train.py ARG epochs=5 RUN echo $epochs -CMD python3 ./train.py +CMD python3 ./train.py epochs #CMD python3 ./test.py \ No newline at end of file diff --git a/train.py b/train.py index 67343f7..1d6ec60 100644 --- a/train.py +++ b/train.py @@ -9,7 +9,7 @@ import torchvision.transforms as transforms import torch.nn as nn import torch.nn.functional as F import torch.optim as optim - +import sys class Net(nn.Module): def __init__(self): @@ -31,8 +31,8 @@ class Net(nn.Module): return x -def trainNet(trainloader, criterion, optimizer): - for epoch in range(20): +def trainNet(trainloader, criterion, optimizer, epochs=20): + for epoch in range(epochs): for i, data in enumerate(trainloader, 0): inputs, labels = data @@ -74,7 +74,9 @@ if __name__ == '__main__': criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) - trainNet(trainloader, criterion, optimizer) + epochs = sys.argv[1] + + trainNet(trainloader, criterion, optimizer, epochs) PATH = './cifar_net.pth' torch.save(net.state_dict(), PATH)