From 238b22ea2056a5a04fcb926821996320b6cd946a Mon Sep 17 00:00:00 2001 From: Marcin Kostrzewski Date: Fri, 6 May 2022 22:09:36 +0200 Subject: [PATCH] Log metrics --- train_model.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/train_model.py b/train_model.py index bdd4a1f..fe2a5f3 100644 --- a/train_model.py +++ b/train_model.py @@ -16,7 +16,7 @@ default_epochs = 5 device = "cuda" if torch.cuda.is_available() else "cpu" -def main(batch_size, epochs): +def main(batch_size, epochs, _run): print(f"Using {device} device") plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test') @@ -38,7 +38,8 @@ def main(batch_size, epochs): for t in range(epochs): print(f"Epoch {t + 1}\n-------------------------------") train(train_dataloader, model, loss_fn, optimizer) - test(test_dataloader, model, loss_fn) + last_loss = test(test_dataloader, model, loss_fn) + _run.log_scalar('training.loss', last_loss, t) print("Done!") torch.save(model.state_dict(), './model_out') @@ -61,8 +62,8 @@ def experiment_config(): @ex.automain -def run(batch_size, epochs): - main(batch_size, epochs) +def run(batch_size, epochs, _run): + main(batch_size, epochs, _run) ex.add_artifact('model_out')