s
This commit is contained in:
parent
7dc5e630f6
commit
1c1a1d6658
8
train.py
8
train.py
@ -6,6 +6,8 @@ import pandas as pd
|
|||||||
from sacred import Experiment
|
from sacred import Experiment
|
||||||
from sacred.observers import MongoObserver, FileStorageObserver
|
from sacred.observers import MongoObserver, FileStorageObserver
|
||||||
import os
|
import os
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.python.framework import tensor_spec
|
||||||
|
|
||||||
os.environ["SACRED_NO_GIT"] = "1"
|
os.environ["SACRED_NO_GIT"] = "1"
|
||||||
|
|
||||||
@ -70,13 +72,17 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta
|
|||||||
print('Test loss:', loss)
|
print('Test loss:', loss)
|
||||||
|
|
||||||
model.save(model_file)
|
model.save(model_file)
|
||||||
|
input_signature = {
|
||||||
|
'input': tensor_spec.TensorSpec(shape=X_train[0].shape, dtype=X_train.dtype)
|
||||||
|
}
|
||||||
|
signature = infer_signature(input_signature, model.output)
|
||||||
|
|
||||||
mlflow.keras.log_model(model, "model")
|
mlflow.keras.log_model(model, "model")
|
||||||
mlflow.log_artifact("model.h5")
|
mlflow.log_artifact("model.h5")
|
||||||
|
|
||||||
signature = infer_signature(X_train, model.predict(X_train))
|
signature = infer_signature(X_train, model.predict(X_train))
|
||||||
input_example = pd.DataFrame(X_train[:1])
|
input_example = pd.DataFrame(X_train[:1])
|
||||||
mlflow.keras.save_model(model, "model", signature=signature, input_example=input_example)
|
mlflow.keras.save_model(model, "model", signature=signature, input_example=input_example.to_dict('records'))
|
||||||
|
|
||||||
return accuracy
|
return accuracy
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user