From 023b0756a04f6fd6f141b6c35f767528c3725c63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Wa=C5=82=C4=99sa?= Date: Sun, 8 May 2022 20:38:37 +0200 Subject: [PATCH] Zaktualizuj 'ml_pytroch_sacred.py' --- ml_pytroch_sacred.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/ml_pytroch_sacred.py b/ml_pytroch_sacred.py index f4be1cc..3af01a5 100644 --- a/ml_pytroch_sacred.py +++ b/ml_pytroch_sacred.py @@ -100,21 +100,23 @@ def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD): history.append(result) return history - -input_size = len(input_cols) -output_size = len(output_cols) -model=Model_xPosition() -lr = 1e-5 -learning_proccess = fit(epochs, lr, model, train_loader, val_loader) def predict_single(input, target, model): inputs = input.unsqueeze(0) predictions = model(inputs) prediction = predictions[0].detach() return "Target: "+str(target)+" Predicted: "+str(prediction)+"\n" + +input_size = len(input_cols) +output_size = len(output_cols) +model=Model_xPosition() +lr = 1e-5 + + @ex.automain -def my_main(epochs): +def my_main(epochs): + learning_proccess = fit(epochs, lr, model, train_loader, val_loader) for i in random.sample(range(0, len(val_ds)), 10): input_, target = val_ds[i] print(predict_single(input_, target, model),end="")