update dllib-mlflow.py
This commit is contained in:
parent
e10b25c9c9
commit
09d0026265
@ -291,7 +291,8 @@ def my_main(epochs):
|
|||||||
loss_fn = nn.CrossEntropyLoss()
|
loss_fn = nn.CrossEntropyLoss()
|
||||||
# epochs = 1000
|
# epochs = 1000
|
||||||
# epochs = epochs
|
# epochs = epochs
|
||||||
|
mlflow.log_param("epochs", epochs)
|
||||||
|
|
||||||
def print_(loss):
|
def print_(loss):
|
||||||
print ("The loss calculated: ", loss)
|
print ("The loss calculated: ", loss)
|
||||||
|
|
||||||
@ -304,6 +305,8 @@ def my_main(epochs):
|
|||||||
loss = loss_fn(y_pred, y_train.squeeze(-1))
|
loss = loss_fn(y_pred, y_train.squeeze(-1))
|
||||||
print_(loss.item())
|
print_(loss.item())
|
||||||
|
|
||||||
|
mlflow.log_param("loss", loss.item)
|
||||||
|
|
||||||
# Zero gradients
|
# Zero gradients
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward() # Gradients
|
loss.backward() # Gradients
|
||||||
@ -328,5 +331,4 @@ def my_main(epochs):
|
|||||||
|
|
||||||
|
|
||||||
with mlflow.start_run() as run:
|
with mlflow.start_run() as run:
|
||||||
my_main(epochs)
|
my_main(epochs)
|
||||||
mlflow.log_param("epochs", epochs)
|
|
Loading…
Reference in New Issue
Block a user