diff --git a/sacred_training.py b/sacred_training.py index 1eb1db1..252bfd9 100644 --- a/sacred_training.py +++ b/sacred_training.py @@ -58,6 +58,7 @@ def prepare(epochs, batch_size, _run): predictions = model.predict(X_test) error = mean_squared_error(y_test, predictions) _run.info["mean_squared_error"] = str(error) + _run.log_scalar("mean_squared_error", int(error)) return error