save model pb
s444354-training/pipeline/head There was a failure building this commit Details

This commit is contained in:
Adrian Charkiewicz 2022-05-07 21:28:44 +02:00
parent 6d0ecd01c4
commit 555fb25ac9
1 changed files with 21 additions and 5 deletions

View File

@ -131,7 +131,8 @@ def evaluate(model, val_loader):
outputs = [model.validation_step(batch) for batch in val_loader]
return model.validation_epoch_end(outputs)
def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
@ex.capture
def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD, _run):
history = []
optimizer = opt_func(model.parameters(), lr)
for epoch in range(epochs):
@ -143,15 +144,25 @@ def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
result = evaluate(model, val_loader)
model.epoch_end(epoch, result, epochs)
history.append(result)
ex.add_artifact("saved_model.pb")
return history
# In[12]:
try:
numberOfEpochParam = int(sys.argv[1])
except:
# dafault val
numberOfEpochParam = 1500
@ex.config
def my_config():
epochs = numberOfEpochParam
#epochs = int(sys.argv[1])
lr = 1e-6
history5 = fit(epochs, lr, model, train_loader, val_loader)
# In[27]:
@ -184,4 +195,9 @@ with open("result.txt", "w+") as file:
file.write(str(predict_single(input_, target, model)))
@ex.main
def main():
lr = 1e-6
history5 = fit(epochs, lr, model, train_loader, val_loader)
ex.run()