diff --git a/train.py b/train.py index 8061ce4..e07529b 100644 --- a/train.py +++ b/train.py @@ -102,7 +102,7 @@ with mlflow.start_run() as run: loss.backward() optimizer.step() - # Infer model signature to log it +torch.save(model,"classificationn_model.pt") signature = infer_signature(X_train.numpy(), model(X_train).detach().numpy()) input_example = {"input": X_train[0].numpy().tolist()}