From 256b377625d14e37de0a035886ab533cc41752ac Mon Sep 17 00:00:00 2001 From: Adrian Charkiewicz Date: Sat, 7 May 2022 23:03:06 +0200 Subject: [PATCH] automain --- pytorch/pytorch.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch/pytorch.py b/pytorch/pytorch.py index faddc97..29e58b1 100644 --- a/pytorch/pytorch.py +++ b/pytorch/pytorch.py @@ -152,6 +152,7 @@ def evaluate(model, val_loader): @ex.capture def fit(epochs, lr, model, train_loader, val_loader, _run, opt_func=torch.optim.SGD): + epochs=epochs history = [] optimizer = opt_func(model.parameters(), lr) for epoch in range(epochs): @@ -195,11 +196,11 @@ with open("result.txt", "w+") as file: file.write(str(predict_single(input_, target, model))) -@ex.main +@ex.automain def main(epochs, _run): lr = 1e-6 my_config() print("number of epochs is: ", epochs) history5 = fit(epochs, lr, model, train_loader, val_loader) -ex.run() +#ex.run()