diff --git a/train_model.py b/train_model.py index fa57638..fb441dd 100644 --- a/train_model.py +++ b/train_model.py @@ -197,7 +197,7 @@ def signature(model, in_alphabet, max_len): mock_x = [('abc', 'xyz')] mock_text, _ = encode_str(mock_x, in_alphabet, max_len) mock_y = model(mock_text) - return mlflow.models.signature.infer_signature(mock_text, mock_y) + return mlflow.models.signature.infer_signature(mock_text.numpy(), mock_y.numpy()) @ex.automain