diff --git a/vgsales-mlflow.py b/vgsales-mlflow.py index 399b42e..3ab8ce6 100644 --- a/vgsales-mlflow.py +++ b/vgsales-mlflow.py @@ -18,8 +18,7 @@ from mlflow.tracking import MlflowClient mlflow.set_tracking_uri("http://172.17.0.1:5000") -mlflow.set_experiment("s434695") -client = MlflowClient() + def my_main(epochs, batch_size): @@ -60,6 +59,7 @@ def my_main(epochs, batch_size): epochs = int(sys.argv[1]) if len(sys.argv) > 1 else 15 batch_size = int(sys.argv[2]) if len(sys.argv) > 2 else 16 +mlflow.set_experiment("s434695") with mlflow.start_run():