This commit is contained in:
Robert Bendun 2023-05-12 23:15:32 +02:00
parent 59ea069a07
commit 2deab093cf

View File

@ -35,7 +35,8 @@ model.compile(optimizer='adam',
metrics=['accuracy']) metrics=['accuracy'])
def train(): def train():
train_x, train_y, le = load_data('./stop_times.train.tsv', le) global model, le
train_x, train_y, _ = load_data('./stop_times.train.tsv', le)
train_x = tf.convert_to_tensor(train_x, dtype=tf.float32) train_x = tf.convert_to_tensor(train_x, dtype=tf.float32)
train_y = tf.convert_to_tensor(train_y) train_y = tf.convert_to_tensor(train_y)
@ -52,6 +53,7 @@ def train():
model.save_weights('model.ckpt') model.save_weights('model.ckpt')
def test(): def test():
global model, le
test_x, test_y, _ = load_data('./stop_times.test.tsv', le) test_x, test_y, _ = load_data('./stop_times.test.tsv', le)
test_x = tf.convert_to_tensor(test_x, dtype=tf.float32) test_x = tf.convert_to_tensor(test_x, dtype=tf.float32)
test_y = tf.convert_to_tensor(test_y) test_y = tf.convert_to_tensor(test_y)
@ -64,6 +66,6 @@ SUBCOMMANDS = {
} }
import sys import sys
assert sys.argv == 2 assert len(sys.argv) == 2
assert sys.argv[1] in SUBCOMMANDS.keys() assert sys.argv[1] in SUBCOMMANDS.keys()
SUBCOMMANDS[sys.argv[1]]() SUBCOMMANDS[sys.argv[1]]()