ium_434784/sacred_training.py

72 lines
2.0 KiB
Python
Raw Normal View History

2021-05-17 00:17:36 +02:00
import sys
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental import preprocessing
from sacred import Experiment
from sacred.observers import FileStorageObserver
from sklearn.metrics import mean_squared_error
2021-05-17 00:26:29 +02:00
from sacred.observers import MongoObserver
2021-05-17 00:17:36 +02:00
2021-05-17 00:31:05 +02:00
# Experiment handle used by the @ex.config / @ex.capture / @ex.automain
# decorators in this file.
ex = Experiment(
    "434784_sacred_scopes",
    interactive=False,
    save_git_info=False,
)

# Every run is reported to two observers: a directory on the local
# filesystem and a MongoDB database.
_file_observer = FileStorageObserver('sacred_runs/runs')
ex.observers.append(_file_observer)

# NOTE(review): the connection string below embeds a user name and password
# in source control; consider supplying it from the environment instead.
_mongo_observer = MongoObserver(
    url='mongodb://mongo_user:mongo_password_IUM_2021@172.17.0.1:27017',
    db_name='sacred',
)
ex.observers.append(_mongo_observer)
2021-05-17 00:17:36 +02:00
# Default hyper-parameters of the experiment. Both names are parameters of
# prepare(), which forwards them to model.fit().
@ex.config
def my_config():
    # number of training epochs
    epochs = 10
    # batch size used while fitting the model
    batch_size = 10
@ex.capture
def prepare(epochs, batch_size, _run):
    """Train a linear model on the WHO suicide statistics and score it.

    Reads ``who_suicide_statistics.csv`` from the working directory, drops
    rows with missing values, one-hot encodes the ``age``, ``sex`` and
    ``country`` columns and shuffles the rows with a fixed seed. The first
    60% of the shuffled rows are used for training and the last 20% for
    testing; the 20% in between is not used by this function. A
    normalisation layer followed by a single dense unit is fitted to
    predict ``suicides_no`` and the weights are saved to
    ``suicide_model.h5``.

    Args:
        epochs: Number of training epochs passed to ``model.fit``.
        batch_size: Batch size passed to ``model.fit``.
        _run: Run object that receives the test error, both in its
            ``info`` dict and as a scalar metric.

    Returns:
        The mean squared error of the predictions on the test split.
    """
    target = 'suicides_no'

    sc = pd.read_csv('who_suicide_statistics.csv')
    sc = sc.dropna()
    sc = pd.get_dummies(
        sc, columns=['age', 'sex', 'country'], prefix='', prefix_sep='')

    # Deterministic shuffle followed by positional cuts at 60% and 80%.
    # Plain iloc slices replace np.split on a DataFrame and avoid binding
    # the middle slice to a name that is never read.
    shuffled = sc.sample(frac=1, random_state=42)
    train_end = int(.6 * len(shuffled))
    test_start = int(.8 * len(shuffled))
    train = shuffled.iloc[:train_end]
    test = shuffled.iloc[test_start:]

    # Each split drops the target from its own columns (the test mask used
    # to be derived from the training frame). The float cast gives the
    # feature matrix one numeric dtype even when get_dummies returns bool
    # columns, as recent pandas versions do.
    X_train = train.drop(columns=target).astype('float32')
    y_train = train[[target]]
    X_test = test.drop(columns=target).astype('float32')
    y_test = test[[target]]

    normalizer = preprocessing.Normalization()
    normalizer.adapt(np.array(X_train))

    model = tf.keras.Sequential([
        normalizer,
        layers.Dense(units=1),
    ])
    model.compile(
        optimizer=tf.optimizers.Adam(learning_rate=0.1),
        loss='mean_absolute_error')
    model.fit(
        X_train, y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=0.2)
    model.save_weights('suicide_model.h5')

    predictions = model.predict(X_test)
    error = mean_squared_error(y_test, predictions)

    _run.info["mean_squared_error"] = str(error)
    # Log the metric as a plain float: int() would silently discard the
    # fractional part of the error.
    _run.log_scalar("mean_squared_error", float(error))
    return error
@ex.automain
def my_main(epochs, batch_size):
    """Run the training step and print the error it returns."""
    # prepare() is called without arguments; its parameters are expected to
    # be filled in through the @ex.capture decorator applied to it.
    test_error = prepare()
    print(test_error)
# Start a run of the experiment programmatically and keep its result in ``r``.
# NOTE(review): @ex.automain above presumably already launches a run when this
# file is executed as a script, which would make this a second run - confirm.
r = ex.run()
# Attach the weights file written by prepare() to the experiment as an artifact.
ex.add_artifact('suicide_model.h5')