modified train.py and Dockerfile
This commit is contained in:
parent
5cb4a5040a
commit
3e2e907dd6
@ -28,5 +28,5 @@ RUN chmod +x ./train.py
|
|||||||
ARG epochs=5
|
ARG epochs=5
|
||||||
RUN echo $epochs
|
RUN echo $epochs
|
||||||
|
|
||||||
CMD python3 ./train.py
|
CMD python3 ./train.py epochs
|
||||||
#CMD python3 ./test.py
|
#CMD python3 ./test.py
|
10
train.py
10
train.py
@ -9,7 +9,7 @@ import torchvision.transforms as transforms
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
|
import sys
|
||||||
|
|
||||||
class Net(nn.Module):
|
class Net(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -31,8 +31,8 @@ class Net(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def trainNet(trainloader, criterion, optimizer):
|
def trainNet(trainloader, criterion, optimizer, epochs=20):
|
||||||
for epoch in range(20):
|
for epoch in range(epochs):
|
||||||
|
|
||||||
for i, data in enumerate(trainloader, 0):
|
for i, data in enumerate(trainloader, 0):
|
||||||
inputs, labels = data
|
inputs, labels = data
|
||||||
@ -74,7 +74,9 @@ if __name__ == '__main__':
|
|||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
|
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
|
||||||
|
|
||||||
trainNet(trainloader, criterion, optimizer)
|
epochs = sys.argv[1]
|
||||||
|
|
||||||
|
trainNet(trainloader, criterion, optimizer, epochs)
|
||||||
|
|
||||||
PATH = './cifar_net.pth'
|
PATH = './cifar_net.pth'
|
||||||
torch.save(net.state_dict(), PATH)
|
torch.save(net.state_dict(), PATH)
|
||||||
|
Loading…
Reference in New Issue
Block a user