update dllib-mlflow.py
Some checks failed
s444356-evaluation/pipeline/head This commit looks good
s444356-training/pipeline/head There was a failure building this commit

This commit is contained in:
Maciej Czajka 2022-05-11 15:21:21 +02:00
parent e10b25c9c9
commit 09d0026265

View File

@ -291,6 +291,7 @@ def my_main(epochs):
loss_fn = nn.CrossEntropyLoss()
# epochs = 1000
# epochs = epochs
mlflow.log_param("epochs", epochs)
def print_(loss):
print ("The loss calculated: ", loss)
@ -304,6 +305,8 @@ def my_main(epochs):
loss = loss_fn(y_pred, y_train.squeeze(-1))
print_(loss.item())
mlflow.log_param("loss", loss.item)
# Zero gradients
optimizer.zero_grad()
loss.backward() # Gradients
@ -329,4 +332,3 @@ def my_main(epochs):
with mlflow.start_run() as run:
my_main(epochs)
mlflow.log_param("epochs", epochs)