This commit is contained in:
Jakub Zaręba 2023-05-11 01:01:37 +02:00
parent 7dc5e630f6
commit 1c1a1d6658

View File

@ -6,6 +6,8 @@ import pandas as pd
from sacred import Experiment
from sacred.observers import MongoObserver, FileStorageObserver
import os
import tensorflow as tf
from tensorflow.python.framework import tensor_spec
os.environ["SACRED_NO_GIT"] = "1"
@ -70,13 +72,17 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta
print('Test loss:', loss)
model.save(model_file)
input_signature = {
'input': tensor_spec.TensorSpec(shape=X_train[0].shape, dtype=X_train.dtype)
}
signature = infer_signature(input_signature, model.output)
mlflow.keras.log_model(model, "model")
mlflow.log_artifact("model.h5")
signature = infer_signature(X_train, model.predict(X_train))
input_example = pd.DataFrame(X_train[:1])
mlflow.keras.save_model(model, "model", signature=signature, input_example=input_example)
mlflow.keras.save_model(model, "model", signature=signature, input_example=input_example.to_dict('records'))
return accuracy