From 407d6187afebc7a47728124526031da5a0ec8ad4 Mon Sep 17 00:00:00 2001 From: jakubknczny Date: Sun, 16 May 2021 16:04:31 +0200 Subject: [PATCH] slicing --- lab5/train/train.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/lab5/train/train.py b/lab5/train/train.py index 0716c6e..b8ca4c9 100644 --- a/lab5/train/train.py +++ b/lab5/train/train.py @@ -7,13 +7,16 @@ import tensorflow from tensorflow.keras import layers ex = Experiment("470607", interactive=False, save_git_info=False) -ex.observers.append(MongoObserver(url='mongodb://mongo_user:mongo_password_IUM_2021@172.17.0.1:27017', db_name='sacred')) +ex.observers.append( + MongoObserver(url='mongodb://mongo_user:mongo_password_IUM_2021@172.17.0.1:27017', db_name='sacred')) ex.observers.append(FileStorageObserver('my_runs')) + @ex.config def my_config(): learning_rate = float(sys.argv[1]) + @ex.capture def prepare_train_model(learning_rate, _run): _run.info["prepare_model"] = str(datetime.now()) @@ -28,10 +31,10 @@ def prepare_train_model(learning_rate, _run): Y_valid = pd.get_dummies(Y_valid) model = tensorflow.keras.Sequential([ - layers.Input(shape=(12,)), - layers.Dense(32), - layers.Dense(16), - layers.Dense(2, activation='softmax') + layers.Input(shape=(12,)), + layers.Dense(32), + layers.Dense(16), + layers.Dense(2, activation='softmax') ]) model.compile( @@ -43,7 +46,8 @@ def prepare_train_model(learning_rate, _run): model.save('grid-stability-dense.h5') - _run['history'] = str(history) + _run['history'] = str(history.history[:, -1]) + @ex.main def my_main(learning_rate): @@ -52,4 +56,3 @@ def my_main(learning_rate): r = ex.run() ex.add_artifact('grid-stability-dense.h5') -