Added sacred
This commit is contained in:
parent
d76832d41e
commit
9911ceda3b
@ -4,12 +4,33 @@ import tensorflow as tf
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
import sys
|
||||
import sacred
|
||||
from sacred.observers import FileStorageObserver
|
||||
|
||||
def main():
|
||||
ex = sacred.Experiment("Training model")
|
||||
ex.observers.append(FileStorageObserver('training_experiment'))
|
||||
|
||||
|
||||
@ex.config
|
||||
def get_config():
|
||||
no_of_epochs = 10
|
||||
if len(sys.argv) == 2:
|
||||
no_of_epochs = int(sys.argv[1])
|
||||
|
||||
|
||||
@ex.capture
|
||||
def evaluate_model(model, test_x, test_y):
|
||||
test_loss, test_acc, test_rec = model.evaluate(test_x, test_y, verbose=1)
|
||||
# print("Accuracy:", test_acc)
|
||||
# print("Loss:", test_loss)
|
||||
# print("Recall:", test_rec)
|
||||
return f"Accuracy: {test_acc}, Loss: {test_loss}, Recall: {test_rec}"
|
||||
|
||||
|
||||
@ex.main
|
||||
def main(no_of_epochs, _run):
|
||||
# no_of_epochs = get_config()
|
||||
|
||||
scaler = StandardScaler()
|
||||
feature_names = ["BMI", "SleepTime", "Sex", "Diabetic", "PhysicalActivity", "Smoking", "AlcoholDrinking"]
|
||||
|
||||
@ -61,11 +82,10 @@ def main():
|
||||
model.fit(train_X, train_Y, epochs=no_of_epochs)
|
||||
model.save("trained_model")
|
||||
|
||||
test_loss, test_acc, test_rec = model.evaluate(test_X, test_Y, verbose=1)
|
||||
print("Accuracy:", test_acc)
|
||||
print("Loss:", test_loss)
|
||||
print("Recall:", test_rec)
|
||||
metrics = evaluate_model(model, test_X, test_Y)
|
||||
_run.log_scalar("model.eval", metrics)
|
||||
ex.add_artifact("trained_model/saved_model.pb")
|
||||
ex.add_artifact("trained_model/keras_metadata.pb")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
ex.run()
|
||||
|
Loading…
Reference in New Issue
Block a user