added sacred

This commit is contained in:
Maciej Sobkowiak 2021-05-17 00:17:36 +02:00
parent 39de53d780
commit e95bd5a701
2 changed files with 69 additions and 1 deletions

1
.gitignore vendored
View File

@ -4,3 +4,4 @@ suicide_model.h5
test.csv test.csv
train.csv train.csv
validate.csv validate.csv
sacred_runs/

67
sacred_training.py Normal file
View File

@ -0,0 +1,67 @@
import sys
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental import preprocessing
from sacred import Experiment
from sacred.observers import FileStorageObserver
from sklearn.metrics import mean_squared_error
ex = Experiment("434784_sacred_scopes", interactive=True)
ex.observers.append(FileStorageObserver('sacred_runs'))
@ex.config
def my_config():
epochs = 10
batch_size = 10
@ex.capture
def prepare(epochs, batch_size, _run):
sc = pd.read_csv('who_suicide_statistics.csv')
sc.dropna(inplace=True)
sc = pd.get_dummies(
sc, columns=['age', 'sex', 'country'], prefix='', prefix_sep='')
train, validate, test = np.split(sc.sample(frac=1, random_state=42),
[int(.6*len(sc)), int(.8*len(sc))])
X_train = train.loc[:, train.columns != 'suicides_no']
y_train = train[['suicides_no']]
X_test = test.loc[:, train.columns != 'suicides_no']
y_test = test[['suicides_no']]
normalizer = preprocessing.Normalization()
normalizer.adapt(np.array(X_train))
model = tf.keras.Sequential([
normalizer,
layers.Dense(units=1)
])
model.compile(
optimizer=tf.optimizers.Adam(learning_rate=0.1),
loss='mean_absolute_error')
model.fit(
X_train, y_train,
batch_size=batch_size,
epochs=epochs,
validation_split=0.2)
model.save_weights('suicide_model.h5')
predictions = model.predict(X_test)
error = mean_squared_error(y_test, predictions)
_run.info["mean_squared_error"] = str(error)
return error
@ex.automain
def my_main(epochs, batch_size):
print(prepare())
r = ex.run()
ex.add_artifact('suicide_model.h5')