🎩
This commit is contained in:
parent
122d4c6357
commit
7714196702
@ -1,5 +1,5 @@
|
||||
FROM ubuntu:22.04
|
||||
|
||||
RUN apt update && apt install -y vim make python3 python3-pip python-is-python3 gcc g++ golang wget unzip git
|
||||
RUN pip install pandas matplotlib scikit-learn tensorflow
|
||||
RUN pip install pandas matplotlib scikit-learn tensorflow sacred
|
||||
CMD "bash"
|
||||
|
@ -1,15 +1,15 @@
|
||||
#!/usr/bin/env python3
|
||||
from sklearn.preprocessing import LabelEncoder
|
||||
import pandas as pd
|
||||
import tensorflow as tf
|
||||
from sklearn.preprocessing import LabelEncoder
|
||||
|
||||
pd.set_option('display.float_format', lambda x: '%.5f' % x)
|
||||
|
||||
le = LabelEncoder()
|
||||
le.fit([x.strip() for x in pd.read_csv('./stop_times.categories.tsv', sep='\t')['stop_headsign'].to_numpy()])
|
||||
|
||||
def load_data(path: str, le: LabelEncoder):
|
||||
train = pd.read_csv(path, sep='\t', dtype={ 'departure_time': float, 'stop_id': str, 'stop_headsign': str })
|
||||
def load_data(path: str, le: LabelEncoder, open_file):
|
||||
train = pd.read_csv(open_file(path), sep='\t', dtype={ 'departure_time': float, 'stop_id': str, 'stop_headsign': str })
|
||||
|
||||
departure_time = train['departure_time'].to_numpy()
|
||||
stop_id = train['stop_id'].to_numpy()
|
||||
@ -23,7 +23,7 @@ def load_data(path: str, le: LabelEncoder):
|
||||
num_classes = len(le.classes_)
|
||||
|
||||
|
||||
def train(epochs: int):
|
||||
def train(epochs: int, open_file, add_artifact):
|
||||
global le
|
||||
|
||||
model = tf.keras.Sequential([
|
||||
@ -38,18 +38,19 @@ def train(epochs: int):
|
||||
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
|
||||
metrics=['accuracy'])
|
||||
|
||||
train_x, train_y, _ = load_data('./stop_times.train.tsv', le)
|
||||
train_x, train_y, _ = load_data('./stop_times.train.tsv', le, open_file)
|
||||
train_x = tf.convert_to_tensor(train_x, dtype=tf.float32)
|
||||
train_y = tf.convert_to_tensor(train_y)
|
||||
|
||||
valid_x, valid_y, _ = load_data('./stop_times.valid.tsv', le)
|
||||
valid_x, valid_y, _ = load_data('./stop_times.valid.tsv', le, open_file)
|
||||
valid_x = tf.convert_to_tensor(valid_x, dtype=tf.float32)
|
||||
valid_y = tf.convert_to_tensor(valid_y)
|
||||
|
||||
model.fit(train_x, train_y, validation_data=(valid_x, valid_y), epochs=epochs, batch_size=1024)
|
||||
history = model.fit(train_x, train_y, validation_data=(valid_x, valid_y), epochs=epochs, batch_size=1024)
|
||||
model.save('model.keras')
|
||||
add_artifact('model.keras', history.history)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
epochs = int('2' if len(sys.argv) != 2 else sys.argv[1])
|
||||
train(epochs)
|
||||
train(epochs, lambda x: x, lambda _1, _2: ())
|
||||
|
21
src/tf_train_sacred.py
Normal file
21
src/tf_train_sacred.py
Normal file
@ -0,0 +1,21 @@
|
||||
from tf_train import *
|
||||
from sacred import Experiment
|
||||
from sacred.observers import FileStorageObserver
|
||||
|
||||
ex = Experiment('s452639')
|
||||
ex.observers.append(FileStorageObserver('experiments'))
|
||||
|
||||
@ex.config
|
||||
def params():
|
||||
epochs = 2
|
||||
|
||||
@ex.automain
|
||||
def train_sacred(epochs: int):
|
||||
def postprocess(artifact_path, history):
|
||||
ex.add_artifact(artifact_path)
|
||||
|
||||
for name, values in history.items():
|
||||
for value in values:
|
||||
ex.log_scalar(name, value)
|
||||
|
||||
train(epochs, lambda path: ex.open_resource(path), postprocess)
|
@ -32,8 +32,8 @@ node {
|
||||
flatten: true,
|
||||
target: 'src/'
|
||||
|
||||
sh 'cd src; python tf_train.py $EPOCHS'
|
||||
archiveArtifacts artifacts: 'src/model.keras', followSymlinks: false
|
||||
sh 'cd src; python tf_train_sacred.py with epochs=$EPOCHS'
|
||||
archiveArtifacts artifacts: 'src/model.keras,src/experiments/**/*', followSymlinks: false
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user