Compare commits
2 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
b097bdd620 | ||
|
7714196702 |
@ -1,5 +1,5 @@
|
|||||||
FROM ubuntu:22.04
|
FROM ubuntu:22.04
|
||||||
|
|
||||||
RUN apt update && apt install -y vim make python3 python3-pip python-is-python3 gcc g++ golang wget unzip git
|
RUN apt update && apt install -y vim make python3 python3-pip python-is-python3 gcc g++ golang wget unzip git
|
||||||
RUN pip install pandas matplotlib scikit-learn tensorflow
|
RUN pip install pandas matplotlib scikit-learn tensorflow sacred
|
||||||
CMD "bash"
|
CMD "bash"
|
||||||
|
@ -1,15 +1,15 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
from sklearn.preprocessing import LabelEncoder
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from sklearn.preprocessing import LabelEncoder
|
|
||||||
|
|
||||||
pd.set_option('display.float_format', lambda x: '%.5f' % x)
|
pd.set_option('display.float_format', lambda x: '%.5f' % x)
|
||||||
|
|
||||||
le = LabelEncoder()
|
le = LabelEncoder()
|
||||||
le.fit([x.strip() for x in pd.read_csv('./stop_times.categories.tsv', sep='\t')['stop_headsign'].to_numpy()])
|
le.fit([x.strip() for x in pd.read_csv('./stop_times.categories.tsv', sep='\t')['stop_headsign'].to_numpy()])
|
||||||
|
|
||||||
def load_data(path: str, le: LabelEncoder):
|
def load_data(path: str, le: LabelEncoder, open_file):
|
||||||
train = pd.read_csv(path, sep='\t', dtype={ 'departure_time': float, 'stop_id': str, 'stop_headsign': str })
|
train = pd.read_csv(open_file(path), sep='\t', dtype={ 'departure_time': float, 'stop_id': str, 'stop_headsign': str })
|
||||||
|
|
||||||
departure_time = train['departure_time'].to_numpy()
|
departure_time = train['departure_time'].to_numpy()
|
||||||
stop_id = train['stop_id'].to_numpy()
|
stop_id = train['stop_id'].to_numpy()
|
||||||
@ -23,7 +23,7 @@ def load_data(path: str, le: LabelEncoder):
|
|||||||
num_classes = len(le.classes_)
|
num_classes = len(le.classes_)
|
||||||
|
|
||||||
|
|
||||||
def train(epochs: int):
|
def train(epochs: int, open_file, add_artifact):
|
||||||
global le
|
global le
|
||||||
|
|
||||||
model = tf.keras.Sequential([
|
model = tf.keras.Sequential([
|
||||||
@ -38,18 +38,19 @@ def train(epochs: int):
|
|||||||
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
|
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
|
||||||
metrics=['accuracy'])
|
metrics=['accuracy'])
|
||||||
|
|
||||||
train_x, train_y, _ = load_data('./stop_times.train.tsv', le)
|
train_x, train_y, _ = load_data('./stop_times.train.tsv', le, open_file)
|
||||||
train_x = tf.convert_to_tensor(train_x, dtype=tf.float32)
|
train_x = tf.convert_to_tensor(train_x, dtype=tf.float32)
|
||||||
train_y = tf.convert_to_tensor(train_y)
|
train_y = tf.convert_to_tensor(train_y)
|
||||||
|
|
||||||
valid_x, valid_y, _ = load_data('./stop_times.valid.tsv', le)
|
valid_x, valid_y, _ = load_data('./stop_times.valid.tsv', le, open_file)
|
||||||
valid_x = tf.convert_to_tensor(valid_x, dtype=tf.float32)
|
valid_x = tf.convert_to_tensor(valid_x, dtype=tf.float32)
|
||||||
valid_y = tf.convert_to_tensor(valid_y)
|
valid_y = tf.convert_to_tensor(valid_y)
|
||||||
|
|
||||||
model.fit(train_x, train_y, validation_data=(valid_x, valid_y), epochs=epochs, batch_size=1024)
|
history = model.fit(train_x, train_y, validation_data=(valid_x, valid_y), epochs=epochs, batch_size=1024)
|
||||||
model.save('model.keras')
|
model.save('model.keras')
|
||||||
|
add_artifact('model.keras', history.history)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import sys
|
import sys
|
||||||
epochs = int('2' if len(sys.argv) != 2 else sys.argv[1])
|
epochs = int('2' if len(sys.argv) != 2 else sys.argv[1])
|
||||||
train(epochs)
|
train(epochs, lambda x: x, lambda _1, _2: ())
|
||||||
|
21
src/tf_train_sacred.py
Normal file
21
src/tf_train_sacred.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from tf_train import *
|
||||||
|
from sacred import Experiment
|
||||||
|
from sacred.observers import FileStorageObserver
|
||||||
|
|
||||||
|
ex = Experiment('s452639')
|
||||||
|
ex.observers.append(FileStorageObserver('experiments'))
|
||||||
|
|
||||||
|
@ex.config
|
||||||
|
def params():
|
||||||
|
epochs = 2
|
||||||
|
|
||||||
|
@ex.automain
|
||||||
|
def train_sacred(epochs: int):
|
||||||
|
def postprocess(artifact_path, history):
|
||||||
|
ex.add_artifact(artifact_path)
|
||||||
|
|
||||||
|
for name, values in history.items():
|
||||||
|
for value in values:
|
||||||
|
ex.log_scalar(name, value)
|
||||||
|
|
||||||
|
train(epochs, lambda path: ex.open_resource(path), postprocess)
|
@ -32,8 +32,8 @@ node {
|
|||||||
flatten: true,
|
flatten: true,
|
||||||
target: 'src/'
|
target: 'src/'
|
||||||
|
|
||||||
sh 'cd src; python tf_train.py $EPOCHS'
|
sh "cd src; python tf_train_sacred.py with epochs=${params.EPOCHS}"
|
||||||
archiveArtifacts artifacts: 'src/model.keras', followSymlinks: false
|
archiveArtifacts artifacts: 'src/model.keras,src/experiments/**/*', followSymlinks: false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user