modified train.py and Dockerfile
This commit is contained in:
parent
5cb4a5040a
commit
3e2e907dd6
@ -28,5 +28,5 @@ RUN chmod +x ./train.py
|
||||
ARG epochs=5
|
||||
RUN echo $epochs
|
||||
|
||||
CMD python3 ./train.py
|
||||
CMD python3 ./train.py epochs
|
||||
#CMD python3 ./test.py
|
10
train.py
10
train.py
@ -9,7 +9,7 @@ import torchvision.transforms as transforms
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
|
||||
import sys
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
@ -31,8 +31,8 @@ class Net(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def trainNet(trainloader, criterion, optimizer):
|
||||
for epoch in range(20):
|
||||
def trainNet(trainloader, criterion, optimizer, epochs=20):
|
||||
for epoch in range(epochs):
|
||||
|
||||
for i, data in enumerate(trainloader, 0):
|
||||
inputs, labels = data
|
||||
@ -74,7 +74,9 @@ if __name__ == '__main__':
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
|
||||
|
||||
trainNet(trainloader, criterion, optimizer)
|
||||
epochs = sys.argv[1]
|
||||
|
||||
trainNet(trainloader, criterion, optimizer, epochs)
|
||||
|
||||
PATH = './cifar_net.pth'
|
||||
torch.save(net.state_dict(), PATH)
|
||||
|
Loading…
Reference in New Issue
Block a user