diff --git a/pytorch/pytorch.py b/pytorch/pytorch.py index e5f3f16..9d28a36 100644 --- a/pytorch/pytorch.py +++ b/pytorch/pytorch.py @@ -130,8 +130,9 @@ model=WineQuality() def evaluate(model, val_loader): outputs = [model.validation_step(batch) for batch in val_loader] return model.validation_epoch_end(outputs) - -def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD): + +@ex.capture +def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD, _run): history = [] optimizer = opt_func(model.parameters(), lr) for epoch in range(epochs): @@ -143,15 +144,25 @@ def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD): result = evaluate(model, val_loader) model.epoch_end(epoch, result, epochs) history.append(result) + ex.add_artifact("saved_model.pb") return history # In[12]: +try: + numberOfEpochParam = int(sys.argv[1]) +except: + # dafault val + numberOfEpochParam = 1500 + + +@ex.config +def my_config(): + epochs = numberOfEpochParam + + -#epochs = int(sys.argv[1]) -lr = 1e-6 -history5 = fit(epochs, lr, model, train_loader, val_loader) # In[27]: @@ -184,4 +195,9 @@ with open("result.txt", "w+") as file: file.write(str(predict_single(input_, target, model))) +@ex.main +def main(): + lr = 1e-6 + history5 = fit(epochs, lr, model, train_loader, val_loader) +ex.run()