diff --git a/vgsales-mlflow.py b/vgsales-mlflow.py index 3ab8ce6..13bea19 100644 --- a/vgsales-mlflow.py +++ b/vgsales-mlflow.py @@ -18,7 +18,8 @@ from mlflow.tracking import MlflowClient mlflow.set_tracking_uri("http://172.17.0.1:5000") - +mlflow.set_experiment("s434695") +client = MlflowClient() def my_main(epochs, batch_size): @@ -59,7 +60,6 @@ def my_main(epochs, batch_size): epochs = int(sys.argv[1]) if len(sys.argv) > 1 else 15 batch_size = int(sys.argv[2]) if len(sys.argv) > 2 else 16 -mlflow.set_experiment("s434695") with mlflow.start_run(): @@ -70,4 +70,4 @@ with mlflow.start_run(): mlflow.log_metric("rmse", rmse) #mlflow.keras.log_model(model, 'vgsales_model.h5') mlflow.keras.save_model(model, "my_model", signature=mlflow.models.signature.infer_signature(x_train, y_train), input_example=x_train) - mlflow.keras.log_model(model, "model",registered_model_name="s434695", signature= mlflow.models.signature.infer_signature(x_train, y_train), input_example= x_train[0]) \ No newline at end of file + mlflow.keras.log_model(model, "model",registered_model_name="s434695", signature=mlflow.models.signature.infer_signature(x_train, y_train), input_example=x_train) \ No newline at end of file