diff --git a/sacred-fileobserver.py b/sacred-fileobserver.py index e9900b8..48628d3 100644 --- a/sacred-fileobserver.py +++ b/sacred-fileobserver.py @@ -23,7 +23,7 @@ def my_config(): batch_size = 16 @ex.capture -def prepare_model(epochs, batch_size): +def prepare_model(epochs, batch_size, _run): # odczytanie danych z plików avocado_train = pd.read_csv('avocado_train.csv') avocado_test = pd.read_csv('avocado_test.csv') @@ -55,6 +55,7 @@ def prepare_model(epochs, batch_size): # ewaluacja rmse = mean_squared_error(y_test, prediction) + _run.log_scalar("rmse", rmse) # zapisanie modelu model.save('avocado_model.h5')