s
This commit is contained in:
parent
156c19a688
commit
e265d5a987
7
train.py
7
train.py
@ -73,9 +73,10 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta
|
|||||||
print('Test loss:', loss)
|
print('Test loss:', loss)
|
||||||
|
|
||||||
model.save("model.h5")
|
model.save("model.h5")
|
||||||
input_signature = {
|
# input_signature = {
|
||||||
'input': tensor_spec.TensorSpec(shape=X_train[0].shape, dtype=X_train.dtype)
|
# 'input': tensor_spec.TensorSpec(shape=X_train.iloc[0].shape, dtype=X_train.dtypes[0])
|
||||||
}
|
# }
|
||||||
|
|
||||||
X_train_numpy = X_train.to_numpy()
|
X_train_numpy = X_train.to_numpy()
|
||||||
signature = infer_signature(X_train_numpy, model.predict(X_train_numpy))
|
signature = infer_signature(X_train_numpy, model.predict(X_train_numpy))
|
||||||
input_example = X_train.head(1).to_numpy()
|
input_example = X_train.head(1).to_numpy()
|
||||||
|
Loading…
Reference in New Issue
Block a user