s
This commit is contained in:
parent
7dc5e630f6
commit
1c1a1d6658
8
train.py
8
train.py
@ -6,6 +6,8 @@ import pandas as pd
|
||||
from sacred import Experiment
|
||||
from sacred.observers import MongoObserver, FileStorageObserver
|
||||
import os
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
|
||||
os.environ["SACRED_NO_GIT"] = "1"
|
||||
|
||||
@ -70,13 +72,17 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta
|
||||
print('Test loss:', loss)
|
||||
|
||||
model.save(model_file)
|
||||
input_signature = {
|
||||
'input': tensor_spec.TensorSpec(shape=X_train[0].shape, dtype=X_train.dtype)
|
||||
}
|
||||
signature = infer_signature(input_signature, model.output)
|
||||
|
||||
mlflow.keras.log_model(model, "model")
|
||||
mlflow.log_artifact("model.h5")
|
||||
|
||||
signature = infer_signature(X_train, model.predict(X_train))
|
||||
input_example = pd.DataFrame(X_train[:1])
|
||||
mlflow.keras.save_model(model, "model", signature=signature, input_example=input_example)
|
||||
mlflow.keras.save_model(model, "model", signature=signature, input_example=input_example.to_dict('records'))
|
||||
|
||||
return accuracy
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user