update dllib-mlflow.py
This commit is contained in:
parent
e10b25c9c9
commit
09d0026265
@ -291,6 +291,7 @@ def my_main(epochs):
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
# epochs = 1000
|
||||
# epochs = epochs
|
||||
mlflow.log_param("epochs", epochs)
|
||||
|
||||
def print_(loss):
|
||||
print ("The loss calculated: ", loss)
|
||||
@ -304,6 +305,8 @@ def my_main(epochs):
|
||||
loss = loss_fn(y_pred, y_train.squeeze(-1))
|
||||
print_(loss.item())
|
||||
|
||||
mlflow.log_param("loss", loss.item)
|
||||
|
||||
# Zero gradients
|
||||
optimizer.zero_grad()
|
||||
loss.backward() # Gradients
|
||||
@ -329,4 +332,3 @@ def my_main(epochs):
|
||||
|
||||
with mlflow.start_run() as run:
|
||||
my_main(epochs)
|
||||
mlflow.log_param("epochs", epochs)
|
Loading…
Reference in New Issue
Block a user