From d1376a046ad103f5bd454f75a5ea2ad18f8aa4d3 Mon Sep 17 00:00:00 2001 From: Adam Wojdyla Date: Mon, 9 May 2022 00:23:12 +0200 Subject: [PATCH] params update 2 --- lab07_sacred.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lab07_sacred.py b/lab07_sacred.py index 18c487f..0908f29 100644 --- a/lab07_sacred.py +++ b/lab07_sacred.py @@ -76,7 +76,7 @@ def prepare_labels_features(dataset): return lab, feat -@ex.main +@ex.automain def my_main(epochs, _run): # Prepare dataset print("Loading dataset...") @@ -96,17 +96,17 @@ def my_main(epochs, _run): # number of epochs is parametrized try: - epochs = int(epochs) + epochs_n = int(epochs) except Exception as e: print(e) print("Setting default epochs value to 1000.") - epochs = 10 + epochs_n = 10 - print(f"Number of epochs: {epochs}") + print(f"Number of epochs: {epochs_n}") print("Starting model training...") x_train, y_train = Variable(torch.from_numpy(features_train)).float(), Variable(torch.from_numpy(labels_train)).long() - for epoch in range(1, epochs + 1): + for epoch in range(1, epochs_n + 1): print("Epoch #", epoch) y_pred = model(x_train) loss = loss_fn(y_pred, y_train) @@ -138,5 +138,5 @@ def my_main(epochs, _run): pd_predictions = pd.DataFrame(pred) pd_predictions.to_csv("./prediction_results.csv") -ex.run() +# ex.run() ex.add_artifact("CarPrices_pytorch_model.pkl") \ No newline at end of file