update dllib-mlflow.py
All checks were successful
s444356-evaluation/pipeline/head This commit looks good
s444356-training/pipeline/head This commit looks good

This commit is contained in:
Maciej Czajka 2022-05-11 15:17:17 +02:00
parent 73bd0a68c7
commit e10b25c9c9

View File

@ -316,7 +316,7 @@ def my_main(epochs):
pred = pred.detach().numpy() pred = pred.detach().numpy()
print("The accuracy is", accuracy_score(labels_test_g, np.argmax(pred, axis=1))) print("The accuracy is", accuracy_score(labels_test_g, np.argmax(pred, axis=1)))
# mlflow.log_metric("accuracy", accuracy_score(labels_test_g, np.argmax(pred, axis=1))) mlflow.log_metric("accuracy", accuracy_score(labels_test_g, np.argmax(pred, axis=1)))
pred = pd.DataFrame(pred) pred = pd.DataFrame(pred)
@ -324,11 +324,9 @@ def my_main(epochs):
# save model # save model
torch.save(model, "games_model.pkl") torch.save(model, "games_model.pkl")
acc = accuracy_score(labels_test_g, np.argmax(pred, axis=1))
return acc
with mlflow.start_run() as run: with mlflow.start_run() as run:
acc = my_main(epochs) my_main(epochs)
mlflow.log_param("epochs", epochs) mlflow.log_param("epochs", epochs)
mlflow.log_metric("accuracy", acc)