diff --git a/learning/ml-mlflow.py b/learning/ml-mlflow.py index e880fca..116a085 100644 --- a/learning/ml-mlflow.py +++ b/learning/ml-mlflow.py @@ -72,8 +72,10 @@ with mlflow.start_run(): mlflow.log_param("test size", testset.size) mlflow.log_param("epochs", EPOCHS) + predicted = model(Variable(torch.from_numpy(x_train))).data.numpy() + signature = mlflow.models.signature.infer_signature( - x_train.values, model.predict(x_train.values)) + x_train.values, predicted) mlflow.set_experiment("s434700") tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme