diff --git a/pytorch/pytorch.py b/pytorch/pytorch.py index 6c537d2..5dbd84b 100644 --- a/pytorch/pytorch.py +++ b/pytorch/pytorch.py @@ -186,7 +186,7 @@ def fit(epochs, lr, model, train_loader, val_loader, _log, _run, opt_func=torch. _run.info["epochs"] = epochs - signature = mlflow.models.signature.infer_signature(train_ds) + signature = mlflow.models.signature.infer_signature(inputs_array) tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme