signature
s444354-training/pipeline/head There was a failure building this commit Details

This commit is contained in:
Adrian Charkiewicz 2022-05-15 20:06:19 +02:00
parent 47571f8509
commit 6b4ac85f8a
1 changed files with 7 additions and 5 deletions

View File

@ -22,7 +22,8 @@ from sacred import Experiment
from sacred.observers import FileStorageObserver from sacred.observers import FileStorageObserver
from sacred.observers import MongoObserver from sacred.observers import MongoObserver
import mlflow import mlflow
import mlflow.keras import mlflow.pytorch
from mlflow.models import infer_signature
# In[2]: # In[2]:
ex = Experiment(save_git_info=False) ex = Experiment(save_git_info=False)
@ -185,16 +186,17 @@ def fit(epochs, lr, model, train_loader, val_loader, _log, _run, opt_func=torch.
_run.info["epochs"] = epochs _run.info["epochs"] = epochs
signature = mlflow.models.signature.infer_signature(house_price_features, linear_model.predict(house_price_features)) signature = mlflow.models.signature.infer_signature(train_ds)
tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme
sampleInp = [0.1, 0.1, 546.0, 555.223, 1., 1., 33.16376, 84.12426] input_example = [0.1, 0.1, 546.0, 555.223, 1., 1., 33.16376, 84.12426]
if tracking_url_type_store != "file": if tracking_url_type_store != "file":
mlflow.keras.log_model(model, "model", registered_model_name="red-wine-quality", signature=signature) mlflow.pytorch.log_model(model, "model", registered_model_name="s444354", signature=siganture, input_example=input_example)
else: else:
mlflow.keras.log_model(model, "model", signature=signature, input_example=np.array(sampleInp)) mlflow.pytorch.log_model(model, "model", signature=siganture, input_example=input_example)
mlflow.pytorch.save_model(model, "my_model", signature=siganture, input_example=input_example)