update dllib-mlflow.py
This commit is contained in:
parent
09d0026265
commit
777e10f0ef
@ -305,13 +305,13 @@ def my_main(epochs):
|
|||||||
loss = loss_fn(y_pred, y_train.squeeze(-1))
|
loss = loss_fn(y_pred, y_train.squeeze(-1))
|
||||||
print_(loss.item())
|
print_(loss.item())
|
||||||
|
|
||||||
mlflow.log_param("loss", loss.item)
|
|
||||||
|
|
||||||
# Zero gradients
|
# Zero gradients
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward() # Gradients
|
loss.backward() # Gradients
|
||||||
optimizer.step() # Update
|
optimizer.step() # Update
|
||||||
|
|
||||||
|
mlflow.log_param("loss", loss.item)
|
||||||
|
|
||||||
# Prediction
|
# Prediction
|
||||||
x_test = Variable(torch.from_numpy(features_test_g)).float()
|
x_test = Variable(torch.from_numpy(features_test_g)).float()
|
||||||
pred = model(x_test)
|
pred = model(x_test)
|
||||||
|
Loading…
Reference in New Issue
Block a user