Added sacred

This commit is contained in:
Andrzej Preibisz 2022-05-08 18:41:08 +02:00
parent d76832d41e
commit 9911ceda3b

View File

@ -4,12 +4,33 @@ import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import sys
import sacred
from sacred.observers import FileStorageObserver
def main():
ex = sacred.Experiment("Training model")
ex.observers.append(FileStorageObserver('training_experiment'))
@ex.config
def get_config():
no_of_epochs = 10
if len(sys.argv) == 2:
no_of_epochs = int(sys.argv[1])
@ex.capture
def evaluate_model(model, test_x, test_y):
test_loss, test_acc, test_rec = model.evaluate(test_x, test_y, verbose=1)
# print("Accuracy:", test_acc)
# print("Loss:", test_loss)
# print("Recall:", test_rec)
return f"Accuracy: {test_acc}, Loss: {test_loss}, Recall: {test_rec}"
@ex.main
def main(no_of_epochs, _run):
# no_of_epochs = get_config()
scaler = StandardScaler()
feature_names = ["BMI", "SleepTime", "Sex", "Diabetic", "PhysicalActivity", "Smoking", "AlcoholDrinking"]
@ -61,11 +82,10 @@ def main():
model.fit(train_X, train_Y, epochs=no_of_epochs)
model.save("trained_model")
test_loss, test_acc, test_rec = model.evaluate(test_X, test_Y, verbose=1)
print("Accuracy:", test_acc)
print("Loss:", test_loss)
print("Recall:", test_rec)
metrics = evaluate_model(model, test_X, test_Y)
_run.log_scalar("model.eval", metrics)
ex.add_artifact("trained_model/saved_model.pb")
ex.add_artifact("trained_model/keras_metadata.pb")
if __name__ == '__main__':
main()
ex.run()