From e95bd5a701eb52badacc25918c11cbd48257da7d Mon Sep 17 00:00:00 2001 From: Maciej Sobkowiak Date: Mon, 17 May 2021 00:17:36 +0200 Subject: [PATCH] added sacred --- .gitignore | 3 ++- sacred_training.py | 67 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 sacred_training.py diff --git a/.gitignore b/.gitignore index a5f500c..7ed5d44 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ __pycache__ suicide_model.h5 test.csv train.csv -validate.csv \ No newline at end of file +validate.csv +sacred_runs/ \ No newline at end of file diff --git a/sacred_training.py b/sacred_training.py new file mode 100644 index 0000000..d00eab9 --- /dev/null +++ b/sacred_training.py @@ -0,0 +1,67 @@ +import sys +import pandas as pd +import numpy as np +import tensorflow as tf +from tensorflow.keras import layers +from tensorflow.keras.layers.experimental import preprocessing +from sacred import Experiment +from sacred.observers import FileStorageObserver +from sklearn.metrics import mean_squared_error + +ex = Experiment("434784_sacred_scopes", interactive=True) +ex.observers.append(FileStorageObserver('sacred_runs')) + + +@ex.config +def my_config(): + epochs = 10 + batch_size = 10 + + +@ex.capture +def prepare(epochs, batch_size, _run): + sc = pd.read_csv('who_suicide_statistics.csv') + sc.dropna(inplace=True) + sc = pd.get_dummies( + sc, columns=['age', 'sex', 'country'], prefix='', prefix_sep='') + + train, validate, test = np.split(sc.sample(frac=1, random_state=42), + [int(.6*len(sc)), int(.8*len(sc))]) + + X_train = train.loc[:, train.columns != 'suicides_no'] + y_train = train[['suicides_no']] + X_test = test.loc[:, train.columns != 'suicides_no'] + y_test = test[['suicides_no']] + + normalizer = preprocessing.Normalization() + normalizer.adapt(np.array(X_train)) + + model = tf.keras.Sequential([ + normalizer, + layers.Dense(units=1) + ]) + + model.compile( + optimizer=tf.optimizers.Adam(learning_rate=0.1), + loss='mean_absolute_error') + + model.fit( + X_train, y_train, + batch_size=batch_size, + epochs=epochs, + validation_split=0.2) + + model.save_weights('suicide_model.h5') + predictions = model.predict(X_test) + error = mean_squared_error(y_test, predictions) + _run.info["mean_squared_error"] = str(error) + return error + + +@ex.automain +def my_main(epochs, batch_size): + print(prepare()) + + +r = ex.run() +ex.add_artifact('suicide_model.h5')