From e265d5a98704d35d15a218291cdff3130cc6597a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Zar=C4=99ba?= Date: Sun, 14 May 2023 15:22:54 +0200 Subject: [PATCH] s --- train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 901f2f5..411b61b 100644 --- a/train.py +++ b/train.py @@ -73,9 +73,10 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta print('Test loss:', loss) model.save("model.h5") - input_signature = { - 'input': tensor_spec.TensorSpec(shape=X_train[0].shape, dtype=X_train.dtype) - } + # input_signature = { + # 'input': tensor_spec.TensorSpec(shape=X_train.iloc[0].shape, dtype=X_train.dtypes[0]) + # } + X_train_numpy = X_train.to_numpy() signature = infer_signature(X_train_numpy, model.predict(X_train_numpy)) input_example = X_train.head(1).to_numpy()