diff --git a/train_model.py b/train_model.py index fb441dd..a2e8049 100644 --- a/train_model.py +++ b/train_model.py @@ -197,7 +197,7 @@ def signature(model, in_alphabet, max_len): mock_x = [('abc', 'xyz')] mock_text, _ = encode_str(mock_x, in_alphabet, max_len) mock_y = model(mock_text) - return mlflow.models.signature.infer_signature(mock_text.numpy(), mock_y.numpy()) + return mlflow.models.signature.infer_signature(mock_text.detach().numpy(), mock_y.detach().numpy()) @ex.automain