This commit is contained in:
Jakub Zaręba 2023-05-11 01:01:37 +02:00
parent 7dc5e630f6
commit 1c1a1d6658

View File

@ -6,6 +6,8 @@ import pandas as pd
from sacred import Experiment from sacred import Experiment
from sacred.observers import MongoObserver, FileStorageObserver from sacred.observers import MongoObserver, FileStorageObserver
import os import os
import tensorflow as tf
from tensorflow.python.framework import tensor_spec
os.environ["SACRED_NO_GIT"] = "1" os.environ["SACRED_NO_GIT"] = "1"
@ -70,13 +72,17 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta
print('Test loss:', loss) print('Test loss:', loss)
model.save(model_file) model.save(model_file)
input_signature = {
'input': tensor_spec.TensorSpec(shape=X_train[0].shape, dtype=X_train.dtype)
}
signature = infer_signature(input_signature, model.output)
mlflow.keras.log_model(model, "model") mlflow.keras.log_model(model, "model")
mlflow.log_artifact("model.h5") mlflow.log_artifact("model.h5")
signature = infer_signature(X_train, model.predict(X_train)) signature = infer_signature(X_train, model.predict(X_train))
input_example = pd.DataFrame(X_train[:1]) input_example = pd.DataFrame(X_train[:1])
mlflow.keras.save_model(model, "model", signature=signature, input_example=input_example) mlflow.keras.save_model(model, "model", signature=signature, input_example=input_example.to_dict('records'))
return accuracy return accuracy