Log metrics

This commit is contained in:
Marcin Kostrzewski 2022-05-06 22:09:36 +02:00
parent d0ab5ae997
commit 238b22ea20

View File

@ -16,7 +16,7 @@ default_epochs = 5
device = "cuda" if torch.cuda.is_available() else "cpu"
def main(batch_size, epochs):
def main(batch_size, epochs, _run):
print(f"Using {device} device")
plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
@ -38,7 +38,8 @@ def main(batch_size, epochs):
for t in range(epochs):
print(f"Epoch {t + 1}\n-------------------------------")
train(train_dataloader, model, loss_fn, optimizer)
test(test_dataloader, model, loss_fn)
last_loss = test(test_dataloader, model, loss_fn)
_run.log_scalar('training.loss', last_loss, t)
print("Done!")
torch.save(model.state_dict(), './model_out')
@ -61,8 +62,8 @@ def experiment_config():
@ex.automain
def run(batch_size, epochs):
main(batch_size, epochs)
def run(batch_size, epochs, _run):
main(batch_size, epochs, _run)
ex.add_artifact('model_out')