mlflow pytorch lib update

This commit is contained in:
Filip Izydorczyk 2021-06-04 16:10:51 +02:00
parent 1cd58567a0
commit 45ee10ccc0

View File

@ -80,10 +80,10 @@ with mlflow.start_run():
mlflow.set_experiment("s434700")
tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme
if tracking_url_type_store != "file":
mlflow.keras.log_model(model, "model.pt", registered_model_name="s434700", signature=signature,
mlflow.pytorch.log_model(model, "model.pt", registered_model_name="s434700", signature=signature,
input_example=test_input)
else:
mlflow.keras.log_model(model, "model.pt",
mlflow.pytorch.log_model(model, "model.pt",
signature=signature, input_example=test_input)
mlflow.keras.save_model(
mlflow.pytorch.save_model(
model, "model.pt", signature=signature, input_example=test_input)