added sacred
This commit is contained in:
parent
39de53d780
commit
e95bd5a701
3
.gitignore
vendored
3
.gitignore
vendored
@ -3,4 +3,5 @@ __pycache__
|
||||
suicide_model.h5
|
||||
test.csv
|
||||
train.csv
|
||||
validate.csv
|
||||
validate.csv
|
||||
sacred_runs/
|
67
sacred_training.py
Normal file
67
sacred_training.py
Normal file
@ -0,0 +1,67 @@
|
||||
import sys
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras import layers
|
||||
from tensorflow.keras.layers.experimental import preprocessing
|
||||
from sacred import Experiment
|
||||
from sacred.observers import FileStorageObserver
|
||||
from sklearn.metrics import mean_squared_error
|
||||
|
||||
ex = Experiment("434784_sacred_scopes", interactive=True)
|
||||
ex.observers.append(FileStorageObserver('sacred_runs'))
|
||||
|
||||
|
||||
@ex.config
|
||||
def my_config():
|
||||
epochs = 10
|
||||
batch_size = 10
|
||||
|
||||
|
||||
@ex.capture
|
||||
def prepare(epochs, batch_size, _run):
|
||||
sc = pd.read_csv('who_suicide_statistics.csv')
|
||||
sc.dropna(inplace=True)
|
||||
sc = pd.get_dummies(
|
||||
sc, columns=['age', 'sex', 'country'], prefix='', prefix_sep='')
|
||||
|
||||
train, validate, test = np.split(sc.sample(frac=1, random_state=42),
|
||||
[int(.6*len(sc)), int(.8*len(sc))])
|
||||
|
||||
X_train = train.loc[:, train.columns != 'suicides_no']
|
||||
y_train = train[['suicides_no']]
|
||||
X_test = test.loc[:, train.columns != 'suicides_no']
|
||||
y_test = test[['suicides_no']]
|
||||
|
||||
normalizer = preprocessing.Normalization()
|
||||
normalizer.adapt(np.array(X_train))
|
||||
|
||||
model = tf.keras.Sequential([
|
||||
normalizer,
|
||||
layers.Dense(units=1)
|
||||
])
|
||||
|
||||
model.compile(
|
||||
optimizer=tf.optimizers.Adam(learning_rate=0.1),
|
||||
loss='mean_absolute_error')
|
||||
|
||||
model.fit(
|
||||
X_train, y_train,
|
||||
batch_size=batch_size,
|
||||
epochs=epochs,
|
||||
validation_split=0.2)
|
||||
|
||||
model.save_weights('suicide_model.h5')
|
||||
predictions = model.predict(X_test)
|
||||
error = mean_squared_error(y_test, predictions)
|
||||
_run.info["mean_squared_error"] = str(error)
|
||||
return error
|
||||
|
||||
|
||||
@ex.automain
|
||||
def my_main(epochs, batch_size):
|
||||
print(prepare())
|
||||
|
||||
|
||||
r = ex.run()
|
||||
ex.add_artifact('suicide_model.h5')
|
Loading…
Reference in New Issue
Block a user