save model pb
Some checks failed
s444354-training/pipeline/head There was a failure building this commit
Some checks failed
s444354-training/pipeline/head There was a failure building this commit
This commit is contained in:
parent
6d0ecd01c4
commit
555fb25ac9
@ -130,8 +130,9 @@ model=WineQuality()
|
|||||||
def evaluate(model, val_loader):
|
def evaluate(model, val_loader):
|
||||||
outputs = [model.validation_step(batch) for batch in val_loader]
|
outputs = [model.validation_step(batch) for batch in val_loader]
|
||||||
return model.validation_epoch_end(outputs)
|
return model.validation_epoch_end(outputs)
|
||||||
|
|
||||||
def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
|
@ex.capture
|
||||||
|
def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD, _run):
|
||||||
history = []
|
history = []
|
||||||
optimizer = opt_func(model.parameters(), lr)
|
optimizer = opt_func(model.parameters(), lr)
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
@ -143,15 +144,25 @@ def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
|
|||||||
result = evaluate(model, val_loader)
|
result = evaluate(model, val_loader)
|
||||||
model.epoch_end(epoch, result, epochs)
|
model.epoch_end(epoch, result, epochs)
|
||||||
history.append(result)
|
history.append(result)
|
||||||
|
ex.add_artifact("saved_model.pb")
|
||||||
return history
|
return history
|
||||||
|
|
||||||
|
|
||||||
# In[12]:
|
# In[12]:
|
||||||
|
|
||||||
|
try:
|
||||||
|
numberOfEpochParam = int(sys.argv[1])
|
||||||
|
except:
|
||||||
|
# dafault val
|
||||||
|
numberOfEpochParam = 1500
|
||||||
|
|
||||||
|
|
||||||
|
@ex.config
|
||||||
|
def my_config():
|
||||||
|
epochs = numberOfEpochParam
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#epochs = int(sys.argv[1])
|
|
||||||
lr = 1e-6
|
|
||||||
history5 = fit(epochs, lr, model, train_loader, val_loader)
|
|
||||||
|
|
||||||
|
|
||||||
# In[27]:
|
# In[27]:
|
||||||
@ -184,4 +195,9 @@ with open("result.txt", "w+") as file:
|
|||||||
file.write(str(predict_single(input_, target, model)))
|
file.write(str(predict_single(input_, target, model)))
|
||||||
|
|
||||||
|
|
||||||
|
@ex.main
|
||||||
|
def main():
|
||||||
|
lr = 1e-6
|
||||||
|
history5 = fit(epochs, lr, model, train_loader, val_loader)
|
||||||
|
|
||||||
|
ex.run()
|
||||||
|
Loading…
Reference in New Issue
Block a user