fix
This commit is contained in:
parent
93a992d53d
commit
c579d63845
@ -18,7 +18,8 @@ from mlflow.tracking import MlflowClient
|
|||||||
|
|
||||||
|
|
||||||
mlflow.set_tracking_uri("http://172.17.0.1:5000")
|
mlflow.set_tracking_uri("http://172.17.0.1:5000")
|
||||||
|
mlflow.set_experiment("s434695")
|
||||||
|
client = MlflowClient()
|
||||||
|
|
||||||
def my_main(epochs, batch_size):
|
def my_main(epochs, batch_size):
|
||||||
|
|
||||||
@ -59,7 +60,6 @@ def my_main(epochs, batch_size):
|
|||||||
epochs = int(sys.argv[1]) if len(sys.argv) > 1 else 15
|
epochs = int(sys.argv[1]) if len(sys.argv) > 1 else 15
|
||||||
batch_size = int(sys.argv[2]) if len(sys.argv) > 2 else 16
|
batch_size = int(sys.argv[2]) if len(sys.argv) > 2 else 16
|
||||||
|
|
||||||
mlflow.set_experiment("s434695")
|
|
||||||
|
|
||||||
with mlflow.start_run():
|
with mlflow.start_run():
|
||||||
|
|
||||||
@ -70,4 +70,4 @@ with mlflow.start_run():
|
|||||||
mlflow.log_metric("rmse", rmse)
|
mlflow.log_metric("rmse", rmse)
|
||||||
#mlflow.keras.log_model(model, 'vgsales_model.h5')
|
#mlflow.keras.log_model(model, 'vgsales_model.h5')
|
||||||
mlflow.keras.save_model(model, "my_model", signature=mlflow.models.signature.infer_signature(x_train, y_train), input_example=x_train)
|
mlflow.keras.save_model(model, "my_model", signature=mlflow.models.signature.infer_signature(x_train, y_train), input_example=x_train)
|
||||||
mlflow.keras.log_model(model, "model",registered_model_name="s434695", signature= mlflow.models.signature.infer_signature(x_train, y_train), input_example= x_train[0])
|
mlflow.keras.log_model(model, "model",registered_model_name="s434695", signature=mlflow.models.signature.infer_signature(x_train, y_train), input_example=x_train)
|
Loading…
Reference in New Issue
Block a user