Added Sacred
Some checks failed
s444380-training/pipeline/head There was a failure building this commit
Some checks failed
s444380-training/pipeline/head There was a failure building this commit
This commit is contained in:
parent
ca0d892c3b
commit
798937cb87
@ -16,6 +16,6 @@ RUN apt-get update && apt-get install -y python3-pip unzip && rm -rf /var/lib/ap
|
||||
|
||||
RUN export PATH="$PATH:/root/.local/bin"
|
||||
|
||||
RUN pip3 install kaggle pandas scikit-learn tensorflow keras matplotlib numpy
|
||||
RUN pip3 install kaggle pandas scikit-learn tensorflow keras matplotlib numpy sacred
|
||||
|
||||
RUN mkdir /.kaggle && chmod o+w /.kaggle
|
||||
|
@ -28,8 +28,8 @@ pipeline {
|
||||
stage("Train model") {
|
||||
steps {
|
||||
sh "chmod u+x ./train_model.py"
|
||||
sh "python3 ./train_model.py $EPOCHS"
|
||||
archiveArtifacts artifacts: "model/*, out.csv", onlyIfSuccessful: true
|
||||
sh "python3 ./train_model.py with 'epcohs=$EPOCHS'"
|
||||
archiveArtifacts artifacts: "model/*, out.csv, experiments/*/*", onlyIfSuccessful: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -7,7 +7,21 @@ from keras.layers import Dense
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
from sacred import Experiment
|
||||
from sacred.observers import FileStorageObserver, MongoObserver
|
||||
|
||||
ex = Experiment()
|
||||
ex.observers.append(FileStorageObserver("experiments"))
|
||||
#ex.observers.append(MongoObserver(url="mongodb://mongo_user:mongo_password_IUM_2021@localhost:27017", db_name="sacred"))
|
||||
|
||||
|
||||
@ex.config
|
||||
def config():
|
||||
epochs = 10
|
||||
|
||||
|
||||
@ex.automain
|
||||
def main(epochs):
|
||||
tf.config.set_visible_devices([], 'GPU')
|
||||
|
||||
# Read and split data
|
||||
@ -65,13 +79,8 @@ model.add(Dense(128, activation='relu'))
|
||||
model.add(Dense(num_categories, activation='softmax'))
|
||||
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy', 'sparse_categorical_accuracy'])
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
epochs = sys.argv[1]
|
||||
else:
|
||||
epochs = 10
|
||||
|
||||
# Train model
|
||||
model.fit(x_train, y_train, epochs=int(epochs), validation_data=(x_val, y_val))
|
||||
history = model.fit(x_train, y_train, epochs=int(epochs), validation_data=(x_val, y_val))
|
||||
|
||||
# Make predictions
|
||||
y_pred = model.predict(x_test)
|
||||
@ -85,3 +94,10 @@ data_to_save.to_csv("out.csv")
|
||||
|
||||
# Save model
|
||||
model.save("model")
|
||||
ex.add_artifact("model/saved_model.pb")
|
||||
|
||||
# Log metrics
|
||||
ex.log_scalar("loss", history.history["loss"])
|
||||
ex.log_scalar("accuracy", history.history["accuracy"])
|
||||
ex.log_scalar("val_loss", history.history["val_loss"])
|
||||
ex.log_scalar("val_accuracy", history.history["val_accuracy"])
|
||||
|
Loading…
Reference in New Issue
Block a user