Added sacred
This commit is contained in:
parent
d76832d41e
commit
9911ceda3b
@ -4,12 +4,33 @@ import tensorflow as tf
|
|||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
import sys
|
import sys
|
||||||
|
import sacred
|
||||||
|
from sacred.observers import FileStorageObserver
|
||||||
|
|
||||||
def main():
|
ex = sacred.Experiment("Training model")
|
||||||
|
ex.observers.append(FileStorageObserver('training_experiment'))
|
||||||
|
|
||||||
|
|
||||||
|
@ex.config
|
||||||
|
def get_config():
|
||||||
no_of_epochs = 10
|
no_of_epochs = 10
|
||||||
if len(sys.argv) == 2:
|
if len(sys.argv) == 2:
|
||||||
no_of_epochs = int(sys.argv[1])
|
no_of_epochs = int(sys.argv[1])
|
||||||
|
|
||||||
|
|
||||||
|
@ex.capture
|
||||||
|
def evaluate_model(model, test_x, test_y):
|
||||||
|
test_loss, test_acc, test_rec = model.evaluate(test_x, test_y, verbose=1)
|
||||||
|
# print("Accuracy:", test_acc)
|
||||||
|
# print("Loss:", test_loss)
|
||||||
|
# print("Recall:", test_rec)
|
||||||
|
return f"Accuracy: {test_acc}, Loss: {test_loss}, Recall: {test_rec}"
|
||||||
|
|
||||||
|
|
||||||
|
@ex.main
|
||||||
|
def main(no_of_epochs, _run):
|
||||||
|
# no_of_epochs = get_config()
|
||||||
|
|
||||||
scaler = StandardScaler()
|
scaler = StandardScaler()
|
||||||
feature_names = ["BMI", "SleepTime", "Sex", "Diabetic", "PhysicalActivity", "Smoking", "AlcoholDrinking"]
|
feature_names = ["BMI", "SleepTime", "Sex", "Diabetic", "PhysicalActivity", "Smoking", "AlcoholDrinking"]
|
||||||
|
|
||||||
@ -61,11 +82,10 @@ def main():
|
|||||||
model.fit(train_X, train_Y, epochs=no_of_epochs)
|
model.fit(train_X, train_Y, epochs=no_of_epochs)
|
||||||
model.save("trained_model")
|
model.save("trained_model")
|
||||||
|
|
||||||
test_loss, test_acc, test_rec = model.evaluate(test_X, test_Y, verbose=1)
|
metrics = evaluate_model(model, test_X, test_Y)
|
||||||
print("Accuracy:", test_acc)
|
_run.log_scalar("model.eval", metrics)
|
||||||
print("Loss:", test_loss)
|
ex.add_artifact("trained_model/saved_model.pb")
|
||||||
print("Recall:", test_rec)
|
ex.add_artifact("trained_model/keras_metadata.pb")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
ex.run()
|
||||||
main()
|
|
||||||
|
Loading…
Reference in New Issue
Block a user