diff --git a/train.py b/train.py index 901f2f5..411b61b 100644 --- a/train.py +++ b/train.py @@ -73,9 +73,10 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta print('Test loss:', loss) model.save("model.h5") - input_signature = { - 'input': tensor_spec.TensorSpec(shape=X_train[0].shape, dtype=X_train.dtype) - } + # input_signature = { + # 'input': tensor_spec.TensorSpec(shape=X_train.iloc[0].shape, dtype=X_train.dtypes[0]) + # } + X_train_numpy = X_train.to_numpy() signature = infer_signature(X_train_numpy, model.predict(X_train_numpy)) input_example = X_train.head(1).to_numpy()