From 2a5e48f63c4385a48120acfdf19ba741256cb758 Mon Sep 17 00:00:00 2001 From: Adrian Charkiewicz Date: Sat, 7 May 2022 21:45:27 +0200 Subject: [PATCH] assure that epochs is saved --- pytorch/pytorch.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch/pytorch.py b/pytorch/pytorch.py index 6b8f5df..c4ddd63 100644 --- a/pytorch/pytorch.py +++ b/pytorch/pytorch.py @@ -38,7 +38,7 @@ except: @ex.config def my_config(): epochs = numberOfEpochParam -my_config() + dataframe_raw = pd.read_csv("winequality-red.csv") @@ -150,7 +150,7 @@ def evaluate(model, val_loader): return model.validation_epoch_end(outputs) @ex.capture -def fit(epochs, lr, model, train_loader, val_loader,_run, opt_func=torch.optim.SGD): +def fit(epochs, lr, model, train_loader, val_loader, _run, opt_func=torch.optim.SGD): history = [] optimizer = opt_func(model.parameters(), lr) for epoch in range(epochs): @@ -216,6 +216,7 @@ with open("result.txt", "w+") as file: @ex.main def main(): lr = 1e-6 + my_config() history5 = fit(epochs, lr, model, train_loader, val_loader) ex.run()