From 0e25c6462ec2a9695d5a5616cafb67a0a145a97e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Zar=C4=99ba?= Date: Sun, 14 May 2023 15:31:15 +0200 Subject: [PATCH] sd --- train.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/train.py b/train.py index 411b61b..9ec2e83 100644 --- a/train.py +++ b/train.py @@ -73,20 +73,22 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta print('Test loss:', loss) model.save("model.h5") + + X_train_numpy = X_train.values + signature = infer_signature(X_train_numpy, model.predict(X_train_numpy)) + input_example = X_train.head(1).values + # input_signature = { # 'input': tensor_spec.TensorSpec(shape=X_train.iloc[0].shape, dtype=X_train.dtypes[0]) # } - X_train_numpy = X_train.to_numpy() - signature = infer_signature(X_train_numpy, model.predict(X_train_numpy)) - input_example = X_train.head(1).to_numpy() - mlflow.keras.log_model(model, "model") mlflow.log_artifact("model.h5") - signature = infer_signature(X_train, model.predict(X_train)) - input_example = pd.DataFrame(X_train[:1]) - mlflow.keras.save_model(model, "model", signature=signature, input_example=input_example.to_dict('records')) + # Use the ndarray form for infer_signature and input_example + signature = infer_signature(X_train_numpy, model.predict(X_train_numpy)) + input_example = X_train.head(1).values + mlflow.keras.save_model(model, "model", signature=signature, input_example=input_example) return accuracy