update dllib-mlflow.py
All checks were successful
s444356-evaluation/pipeline/head This commit looks good
s444356-training/pipeline/head This commit looks good

This commit is contained in:
Maciej Czajka 2022-05-11 15:23:22 +02:00
parent 09d0026265
commit 777e10f0ef

View File

@ -305,13 +305,13 @@ def my_main(epochs):
loss = loss_fn(y_pred, y_train.squeeze(-1))
print_(loss.item())
mlflow.log_param("loss", loss.item)
# Zero gradients
optimizer.zero_grad()
loss.backward() # Gradients
optimizer.step() # Update
mlflow.log_param("loss", loss.item)
# Prediction
x_test = Variable(torch.from_numpy(features_test_g)).float()
pred = model(x_test)