update dllib-mlflow.py
This commit is contained in:
parent
73bd0a68c7
commit
e10b25c9c9
@ -316,7 +316,7 @@ def my_main(epochs):
|
||||
pred = pred.detach().numpy()
|
||||
|
||||
print("The accuracy is", accuracy_score(labels_test_g, np.argmax(pred, axis=1)))
|
||||
# mlflow.log_metric("accuracy", accuracy_score(labels_test_g, np.argmax(pred, axis=1)))
|
||||
mlflow.log_metric("accuracy", accuracy_score(labels_test_g, np.argmax(pred, axis=1)))
|
||||
|
||||
pred = pd.DataFrame(pred)
|
||||
|
||||
@ -324,11 +324,9 @@ def my_main(epochs):
|
||||
|
||||
# save model
|
||||
torch.save(model, "games_model.pkl")
|
||||
acc = accuracy_score(labels_test_g, np.argmax(pred, axis=1))
|
||||
return acc
|
||||
|
||||
|
||||
|
||||
with mlflow.start_run() as run:
|
||||
acc = my_main(epochs)
|
||||
my_main(epochs)
|
||||
mlflow.log_param("epochs", epochs)
|
||||
mlflow.log_metric("accuracy", acc)
|
Loading…
Reference in New Issue
Block a user