Log metrics
This commit is contained in:
parent
d0ab5ae997
commit
238b22ea20
@ -16,7 +16,7 @@ default_epochs = 5
|
|||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
def main(batch_size, epochs):
|
def main(batch_size, epochs, _run):
|
||||||
print(f"Using {device} device")
|
print(f"Using {device} device")
|
||||||
|
|
||||||
plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
|
plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
|
||||||
@ -38,7 +38,8 @@ def main(batch_size, epochs):
|
|||||||
for t in range(epochs):
|
for t in range(epochs):
|
||||||
print(f"Epoch {t + 1}\n-------------------------------")
|
print(f"Epoch {t + 1}\n-------------------------------")
|
||||||
train(train_dataloader, model, loss_fn, optimizer)
|
train(train_dataloader, model, loss_fn, optimizer)
|
||||||
test(test_dataloader, model, loss_fn)
|
last_loss = test(test_dataloader, model, loss_fn)
|
||||||
|
_run.log_scalar('training.loss', last_loss, t)
|
||||||
print("Done!")
|
print("Done!")
|
||||||
|
|
||||||
torch.save(model.state_dict(), './model_out')
|
torch.save(model.state_dict(), './model_out')
|
||||||
@ -61,8 +62,8 @@ def experiment_config():
|
|||||||
|
|
||||||
|
|
||||||
@ex.automain
|
@ex.automain
|
||||||
def run(batch_size, epochs):
|
def run(batch_size, epochs, _run):
|
||||||
main(batch_size, epochs)
|
main(batch_size, epochs, _run)
|
||||||
|
|
||||||
|
|
||||||
ex.add_artifact('model_out')
|
ex.add_artifact('model_out')
|
||||||
|
Loading…
Reference in New Issue
Block a user