diff --git a/train.py b/train.py index e2e3f8c..901f2f5 100644 --- a/train.py +++ b/train.py @@ -72,7 +72,7 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta print('Test accuracy:', accuracy) print('Test loss:', loss) - model.save(model_file) + model.save("model.h5") input_signature = { 'input': tensor_spec.TensorSpec(shape=X_train[0].shape, dtype=X_train.dtype) } @@ -87,7 +87,7 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta input_example = pd.DataFrame(X_train[:1]) mlflow.keras.save_model(model, "model", signature=signature, input_example=input_example.to_dict('records')) - return accuracy + return accuracy @ex.main def run_experiment():