This commit is contained in:
Jakub Zaręba 2023-05-14 15:13:13 +02:00
parent e1ca9046a6
commit 156c19a688

View File

@ -72,7 +72,7 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta
print('Test accuracy:', accuracy)
print('Test loss:', loss)
model.save(model_file)
model.save("model.h5")
input_signature = {
'input': tensor_spec.TensorSpec(shape=X_train[0].shape, dtype=X_train.dtype)
}
@ -87,7 +87,7 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta
input_example = pd.DataFrame(X_train[:1])
mlflow.keras.save_model(model, "model", signature=signature, input_example=input_example.to_dict('records'))
return accuracy
return accuracy
@ex.main
def run_experiment():