sd
This commit is contained in:
parent
e265d5a987
commit
0e25c6462e
16
train.py
16
train.py
@ -73,20 +73,22 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta
|
|||||||
print('Test loss:', loss)
|
print('Test loss:', loss)
|
||||||
|
|
||||||
model.save("model.h5")
|
model.save("model.h5")
|
||||||
|
|
||||||
|
X_train_numpy = X_train.values
|
||||||
|
signature = infer_signature(X_train_numpy, model.predict(X_train_numpy))
|
||||||
|
input_example = X_train.head(1).values
|
||||||
|
|
||||||
# input_signature = {
|
# input_signature = {
|
||||||
# 'input': tensor_spec.TensorSpec(shape=X_train.iloc[0].shape, dtype=X_train.dtypes[0])
|
# 'input': tensor_spec.TensorSpec(shape=X_train.iloc[0].shape, dtype=X_train.dtypes[0])
|
||||||
# }
|
# }
|
||||||
|
|
||||||
X_train_numpy = X_train.to_numpy()
|
|
||||||
signature = infer_signature(X_train_numpy, model.predict(X_train_numpy))
|
|
||||||
input_example = X_train.head(1).to_numpy()
|
|
||||||
|
|
||||||
mlflow.keras.log_model(model, "model")
|
mlflow.keras.log_model(model, "model")
|
||||||
mlflow.log_artifact("model.h5")
|
mlflow.log_artifact("model.h5")
|
||||||
|
|
||||||
signature = infer_signature(X_train, model.predict(X_train))
|
# Use the ndarray form for infer_signature and input_example
|
||||||
input_example = pd.DataFrame(X_train[:1])
|
signature = infer_signature(X_train_numpy, model.predict(X_train_numpy))
|
||||||
mlflow.keras.save_model(model, "model", signature=signature, input_example=input_example.to_dict('records'))
|
input_example = X_train.head(1).values
|
||||||
|
mlflow.keras.save_model(model, "model", signature=signature, input_example=input_example)
|
||||||
|
|
||||||
return accuracy
|
return accuracy
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user