This commit is contained in:
Jakub Zaręba 2023-05-14 15:22:54 +02:00
parent 156c19a688
commit e265d5a987

View File

@ -73,9 +73,10 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta
print('Test loss:', loss) print('Test loss:', loss)
model.save("model.h5") model.save("model.h5")
input_signature = { # input_signature = {
'input': tensor_spec.TensorSpec(shape=X_train[0].shape, dtype=X_train.dtype) # 'input': tensor_spec.TensorSpec(shape=X_train.iloc[0].shape, dtype=X_train.dtypes[0])
} # }
X_train_numpy = X_train.to_numpy() X_train_numpy = X_train.to_numpy()
signature = infer_signature(X_train_numpy, model.predict(X_train_numpy)) signature = infer_signature(X_train_numpy, model.predict(X_train_numpy))
input_example = X_train.head(1).to_numpy() input_example = X_train.head(1).to_numpy()