From 1c1a1d66581bf7470bbf03027d473dd7cdde5703 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Zar=C4=99ba?= Date: Thu, 11 May 2023 01:01:37 +0200 Subject: [PATCH] s --- train.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index 7a0dd7f..4a77406 100644 --- a/train.py +++ b/train.py @@ -6,6 +6,8 @@ import pandas as pd from sacred import Experiment from sacred.observers import MongoObserver, FileStorageObserver import os +import tensorflow as tf +from tensorflow.python.framework import tensor_spec os.environ["SACRED_NO_GIT"] = "1" @@ -70,13 +72,17 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta print('Test loss:', loss) model.save(model_file) + input_signature = { + 'input': tensor_spec.TensorSpec(shape=X_train[0].shape, dtype=X_train.dtype) + } + signature = infer_signature(input_signature, model.output) mlflow.keras.log_model(model, "model") mlflow.log_artifact("model.h5") signature = infer_signature(X_train, model.predict(X_train)) input_example = pd.DataFrame(X_train[:1]) - mlflow.keras.save_model(model, "model", signature=signature, input_example=input_example) + mlflow.keras.save_model(model, "model", signature=signature, input_example=input_example.to_dict('records')) return accuracy