Compare commits

...

2 Commits
ztm ... sacred

Author SHA1 Message Date
Robert Bendun
b097bdd620 🏄 2023-05-15 03:54:56 +02:00
Robert Bendun
7714196702 🎩 2023-05-15 03:50:02 +02:00
4 changed files with 33 additions and 11 deletions

View File

@ -1,5 +1,5 @@
FROM ubuntu:22.04 FROM ubuntu:22.04
RUN apt update && apt install -y vim make python3 python3-pip python-is-python3 gcc g++ golang wget unzip git RUN apt update && apt install -y vim make python3 python3-pip python-is-python3 gcc g++ golang wget unzip git
RUN pip install pandas matplotlib scikit-learn tensorflow RUN pip install pandas matplotlib scikit-learn tensorflow sacred
CMD "bash" CMD "bash"

View File

@ -1,15 +1,15 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from sklearn.preprocessing import LabelEncoder
import pandas as pd import pandas as pd
import tensorflow as tf import tensorflow as tf
from sklearn.preprocessing import LabelEncoder
pd.set_option('display.float_format', lambda x: '%.5f' % x) pd.set_option('display.float_format', lambda x: '%.5f' % x)
le = LabelEncoder() le = LabelEncoder()
le.fit([x.strip() for x in pd.read_csv('./stop_times.categories.tsv', sep='\t')['stop_headsign'].to_numpy()]) le.fit([x.strip() for x in pd.read_csv('./stop_times.categories.tsv', sep='\t')['stop_headsign'].to_numpy()])
def load_data(path: str, le: LabelEncoder): def load_data(path: str, le: LabelEncoder, open_file):
train = pd.read_csv(path, sep='\t', dtype={ 'departure_time': float, 'stop_id': str, 'stop_headsign': str }) train = pd.read_csv(open_file(path), sep='\t', dtype={ 'departure_time': float, 'stop_id': str, 'stop_headsign': str })
departure_time = train['departure_time'].to_numpy() departure_time = train['departure_time'].to_numpy()
stop_id = train['stop_id'].to_numpy() stop_id = train['stop_id'].to_numpy()
@ -23,7 +23,7 @@ def load_data(path: str, le: LabelEncoder):
num_classes = len(le.classes_) num_classes = len(le.classes_)
def train(epochs: int): def train(epochs: int, open_file, add_artifact):
global le global le
model = tf.keras.Sequential([ model = tf.keras.Sequential([
@ -38,18 +38,19 @@ def train(epochs: int):
loss=tf.keras.losses.SparseCategoricalCrossentropy(), loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=['accuracy']) metrics=['accuracy'])
train_x, train_y, _ = load_data('./stop_times.train.tsv', le) train_x, train_y, _ = load_data('./stop_times.train.tsv', le, open_file)
train_x = tf.convert_to_tensor(train_x, dtype=tf.float32) train_x = tf.convert_to_tensor(train_x, dtype=tf.float32)
train_y = tf.convert_to_tensor(train_y) train_y = tf.convert_to_tensor(train_y)
valid_x, valid_y, _ = load_data('./stop_times.valid.tsv', le) valid_x, valid_y, _ = load_data('./stop_times.valid.tsv', le, open_file)
valid_x = tf.convert_to_tensor(valid_x, dtype=tf.float32) valid_x = tf.convert_to_tensor(valid_x, dtype=tf.float32)
valid_y = tf.convert_to_tensor(valid_y) valid_y = tf.convert_to_tensor(valid_y)
model.fit(train_x, train_y, validation_data=(valid_x, valid_y), epochs=epochs, batch_size=1024) history = model.fit(train_x, train_y, validation_data=(valid_x, valid_y), epochs=epochs, batch_size=1024)
model.save('model.keras') model.save('model.keras')
add_artifact('model.keras', history.history)
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
epochs = int('2' if len(sys.argv) != 2 else sys.argv[1]) epochs = int('2' if len(sys.argv) != 2 else sys.argv[1])
train(epochs) train(epochs, lambda x: x, lambda _1, _2: ())

21
src/tf_train_sacred.py Normal file
View File

@ -0,0 +1,21 @@
from tf_train import *
from sacred import Experiment
from sacred.observers import FileStorageObserver
ex = Experiment('s452639')
ex.observers.append(FileStorageObserver('experiments'))
@ex.config
def params():
epochs = 2
@ex.automain
def train_sacred(epochs: int):
def postprocess(artifact_path, history):
ex.add_artifact(artifact_path)
for name, values in history.items():
for value in values:
ex.log_scalar(name, value)
train(epochs, lambda path: ex.open_resource(path), postprocess)

View File

@ -32,8 +32,8 @@ node {
flatten: true, flatten: true,
target: 'src/' target: 'src/'
sh 'cd src; python tf_train.py $EPOCHS' sh "cd src; python tf_train_sacred.py with epochs=${params.EPOCHS}"
archiveArtifacts artifacts: 'src/model.keras', followSymlinks: false archiveArtifacts artifacts: 'src/model.keras,src/experiments/**/*', followSymlinks: false
} }
} }