From 7714196702fa49e1e388f5ea36e31db041f1278b Mon Sep 17 00:00:00 2001 From: Robert Bendun Date: Mon, 15 May 2023 03:50:02 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Dockerfile | 2 +- src/tf_train.py | 17 +++++++++-------- src/tf_train_sacred.py | 21 +++++++++++++++++++++ training.jenkinsfile | 4 ++-- 4 files changed, 33 insertions(+), 11 deletions(-) create mode 100644 src/tf_train_sacred.py diff --git a/Dockerfile b/Dockerfile index 103fe81..ef1f9e6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ FROM ubuntu:22.04 RUN apt update && apt install -y vim make python3 python3-pip python-is-python3 gcc g++ golang wget unzip git -RUN pip install pandas matplotlib scikit-learn tensorflow +RUN pip install pandas matplotlib scikit-learn tensorflow sacred CMD "bash" diff --git a/src/tf_train.py b/src/tf_train.py index caeb2d1..2d6bb24 100755 --- a/src/tf_train.py +++ b/src/tf_train.py @@ -1,15 +1,15 @@ #!/usr/bin/env python3 +from sklearn.preprocessing import LabelEncoder import pandas as pd import tensorflow as tf -from sklearn.preprocessing import LabelEncoder pd.set_option('display.float_format', lambda x: '%.5f' % x) le = LabelEncoder() le.fit([x.strip() for x in pd.read_csv('./stop_times.categories.tsv', sep='\t')['stop_headsign'].to_numpy()]) -def load_data(path: str, le: LabelEncoder): - train = pd.read_csv(path, sep='\t', dtype={ 'departure_time': float, 'stop_id': str, 'stop_headsign': str }) +def load_data(path: str, le: LabelEncoder, open_file): + train = pd.read_csv(open_file(path), sep='\t', dtype={ 'departure_time': float, 'stop_id': str, 'stop_headsign': str }) departure_time = train['departure_time'].to_numpy() stop_id = train['stop_id'].to_numpy() @@ -23,7 +23,7 @@ def load_data(path: str, le: LabelEncoder): num_classes = len(le.classes_) -def train(epochs: int): +def train(epochs: int, open_file, add_artifact): global le model = tf.keras.Sequential([ @@ -38,18 +38,19 @@ def train(epochs: int): loss=tf.keras.losses.SparseCategoricalCrossentropy(), metrics=['accuracy']) - train_x, train_y, _ = load_data('./stop_times.train.tsv', le) + train_x, train_y, _ = load_data('./stop_times.train.tsv', le, open_file) train_x = tf.convert_to_tensor(train_x, dtype=tf.float32) train_y = tf.convert_to_tensor(train_y) - valid_x, valid_y, _ = load_data('./stop_times.valid.tsv', le) + valid_x, valid_y, _ = load_data('./stop_times.valid.tsv', le, open_file) valid_x = tf.convert_to_tensor(valid_x, dtype=tf.float32) valid_y = tf.convert_to_tensor(valid_y) - model.fit(train_x, train_y, validation_data=(valid_x, valid_y), epochs=epochs, batch_size=1024) + history = model.fit(train_x, train_y, validation_data=(valid_x, valid_y), epochs=epochs, batch_size=1024) model.save('model.keras') + add_artifact('model.keras', history.history) if __name__ == "__main__": import sys epochs = int('2' if len(sys.argv) != 2 else sys.argv[1]) - train(epochs) + train(epochs, lambda x: x, lambda _1, _2: ()) diff --git a/src/tf_train_sacred.py b/src/tf_train_sacred.py new file mode 100644 index 0000000..5da441e --- /dev/null +++ b/src/tf_train_sacred.py @@ -0,0 +1,21 @@ +from tf_train import * +from sacred import Experiment +from sacred.observers import FileStorageObserver + +ex = Experiment('s452639') +ex.observers.append(FileStorageObserver('experiments')) + +@ex.config +def params(): + epochs = 2 + +@ex.automain +def train_sacred(epochs: int): + def postprocess(artifact_path, history): + ex.add_artifact(artifact_path) + + for name, values in history.items(): + for value in values: + ex.log_scalar(name, value) + + train(epochs, lambda path: ex.open_resource(path), postprocess) diff --git a/training.jenkinsfile b/training.jenkinsfile index 7eb3f0d..fe78894 100644 --- a/training.jenkinsfile +++ b/training.jenkinsfile @@ -32,8 +32,8 @@ node { flatten: true, target: 'src/' - sh 'cd src; python tf_train.py $EPOCHS' - archiveArtifacts artifacts: 'src/model.keras', followSymlinks: false + sh 'cd src; python tf_train_sacred.py with epochs=$EPOCHS' + archiveArtifacts artifacts: 'src/model.keras,src/experiments/**/*', followSymlinks: false } }