Added Sacred
Some checks failed
s444380-training/pipeline/head There was a failure building this commit

This commit is contained in:
Kamil Guttmann 2022-05-07 19:44:54 +02:00
parent ca0d892c3b
commit 798937cb87
3 changed files with 89 additions and 73 deletions

View File

@ -16,6 +16,6 @@ RUN apt-get update && apt-get install -y python3-pip unzip && rm -rf /var/lib/ap
RUN export PATH="$PATH:/root/.local/bin"
RUN pip3 install kaggle pandas scikit-learn tensorflow keras matplotlib numpy
RUN pip3 install kaggle pandas scikit-learn tensorflow keras matplotlib numpy sacred
RUN mkdir /.kaggle && chmod o+w /.kaggle

View File

@ -28,8 +28,8 @@ pipeline {
stage("Train model") {
steps {
sh "chmod u+x ./train_model.py"
sh "python3 ./train_model.py $EPOCHS"
archiveArtifacts artifacts: "model/*, out.csv", onlyIfSuccessful: true
sh "python3 ./train_model.py with 'epcohs=$EPOCHS'"
archiveArtifacts artifacts: "model/*, out.csv, experiments/*/*", onlyIfSuccessful: true
}
}
}

View File

@ -7,7 +7,21 @@ from keras.layers import Dense
import tensorflow as tf
import numpy as np
from sacred import Experiment
from sacred.observers import FileStorageObserver, MongoObserver
ex = Experiment()
ex.observers.append(FileStorageObserver("experiments"))
#ex.observers.append(MongoObserver(url="mongodb://mongo_user:mongo_password_IUM_2021@localhost:27017", db_name="sacred"))
@ex.config
def config():
epochs = 10
@ex.automain
def main(epochs):
tf.config.set_visible_devices([], 'GPU')
# Read and split data
@ -65,13 +79,8 @@ model.add(Dense(128, activation='relu'))
model.add(Dense(num_categories, activation='softmax'))
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy', 'sparse_categorical_accuracy'])
if len(sys.argv) > 1:
epochs = sys.argv[1]
else:
epochs = 10
# Train model
model.fit(x_train, y_train, epochs=int(epochs), validation_data=(x_val, y_val))
history = model.fit(x_train, y_train, epochs=int(epochs), validation_data=(x_val, y_val))
# Make predictions
y_pred = model.predict(x_test)
@ -85,3 +94,10 @@ data_to_save.to_csv("out.csv")
# Save model
model.save("model")
ex.add_artifact("model/saved_model.pb")
# Log metrics
ex.log_scalar("loss", history.history["loss"])
ex.log_scalar("accuracy", history.history["accuracy"])
ex.log_scalar("val_loss", history.history["val_loss"])
ex.log_scalar("val_accuracy", history.history["val_accuracy"])