diff --git a/train.py b/train.py index 623a4c0..af8b6d3 100644 --- a/train.py +++ b/train.py @@ -103,14 +103,14 @@ with mlflow.start_run() as run: optimizer.step() torch.save(model,"classificationn_model.pt") - signature = infer_signature(X_train.numpy(), model(X_train).detach().numpy()) - input_example = {"input": X_train[0].numpy().tolist()} +signature = infer_signature(X_train.numpy(), model(X_train).detach().numpy()) +input_example = {"input": X_train[0].numpy().tolist()} - # Log model - tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme - if tracking_url_type_store != "file": - mlflow.pytorch.log_model(model, "model", signature=signature, input_example=input_example, registered_model_name="ClassificationModel") - else: - mlflow.pytorch.log_model(model, "model", signature=signature, input_example=input_example) +# Log model +tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme +if tracking_url_type_store != "file": + mlflow.pytorch.log_model(model, "model", signature=signature, input_example=input_example, registered_model_name="ClassificationModel") +else: + mlflow.pytorch.log_model(model, "model", signature=signature, input_example=input_example)