diff --git a/src/tf.py b/src/tf.py index f5459c9..f440e3b 100755 --- a/src/tf.py +++ b/src/tf.py @@ -35,7 +35,8 @@ model.compile(optimizer='adam', metrics=['accuracy']) def train(): - train_x, train_y, le = load_data('./stop_times.train.tsv', le) + global model, le + train_x, train_y, _ = load_data('./stop_times.train.tsv', le) train_x = tf.convert_to_tensor(train_x, dtype=tf.float32) train_y = tf.convert_to_tensor(train_y) @@ -52,6 +53,7 @@ def train(): model.save_weights('model.ckpt') def test(): + global model, le test_x, test_y, _ = load_data('./stop_times.test.tsv', le) test_x = tf.convert_to_tensor(test_x, dtype=tf.float32) test_y = tf.convert_to_tensor(test_y) @@ -64,6 +66,6 @@ SUBCOMMANDS = { } import sys -assert sys.argv == 2 +assert len(sys.argv) == 2 assert sys.argv[1] in SUBCOMMANDS.keys() SUBCOMMANDS[sys.argv[1]]()