This commit is contained in:
Jakub Zaręba 2023-05-14 15:31:15 +02:00
parent e265d5a987
commit 0e25c6462e

View File

@ -73,20 +73,22 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta
print('Test loss:', loss) print('Test loss:', loss)
model.save("model.h5") model.save("model.h5")
X_train_numpy = X_train.values
signature = infer_signature(X_train_numpy, model.predict(X_train_numpy))
input_example = X_train.head(1).values
# input_signature = { # input_signature = {
# 'input': tensor_spec.TensorSpec(shape=X_train.iloc[0].shape, dtype=X_train.dtypes[0]) # 'input': tensor_spec.TensorSpec(shape=X_train.iloc[0].shape, dtype=X_train.dtypes[0])
# } # }
X_train_numpy = X_train.to_numpy()
signature = infer_signature(X_train_numpy, model.predict(X_train_numpy))
input_example = X_train.head(1).to_numpy()
mlflow.keras.log_model(model, "model") mlflow.keras.log_model(model, "model")
mlflow.log_artifact("model.h5") mlflow.log_artifact("model.h5")
signature = infer_signature(X_train, model.predict(X_train)) # Use the ndarray form for infer_signature and input_example
input_example = pd.DataFrame(X_train[:1]) signature = infer_signature(X_train_numpy, model.predict(X_train_numpy))
mlflow.keras.save_model(model, "model", signature=signature, input_example=input_example.to_dict('records')) input_example = X_train.head(1).values
mlflow.keras.save_model(model, "model", signature=signature, input_example=input_example)
return accuracy return accuracy