mlflow pytorch lib update
This commit is contained in:
parent
1cd58567a0
commit
45ee10ccc0
@ -80,10 +80,10 @@ with mlflow.start_run():
|
||||
mlflow.set_experiment("s434700")
|
||||
tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme
|
||||
if tracking_url_type_store != "file":
|
||||
mlflow.keras.log_model(model, "model.pt", registered_model_name="s434700", signature=signature,
|
||||
input_example=test_input)
|
||||
mlflow.pytorch.log_model(model, "model.pt", registered_model_name="s434700", signature=signature,
|
||||
input_example=test_input)
|
||||
else:
|
||||
mlflow.keras.log_model(model, "model.pt",
|
||||
signature=signature, input_example=test_input)
|
||||
mlflow.keras.save_model(
|
||||
mlflow.pytorch.log_model(model, "model.pt",
|
||||
signature=signature, input_example=test_input)
|
||||
mlflow.pytorch.save_model(
|
||||
model, "model.pt", signature=signature, input_example=test_input)
|
||||
|
Loading…
Reference in New Issue
Block a user