bugs
This commit is contained in:
parent
59ea069a07
commit
2deab093cf
@ -35,7 +35,8 @@ model.compile(optimizer='adam',
|
||||
metrics=['accuracy'])
|
||||
|
||||
def train():
|
||||
train_x, train_y, le = load_data('./stop_times.train.tsv', le)
|
||||
global model, le
|
||||
train_x, train_y, _ = load_data('./stop_times.train.tsv', le)
|
||||
train_x = tf.convert_to_tensor(train_x, dtype=tf.float32)
|
||||
train_y = tf.convert_to_tensor(train_y)
|
||||
|
||||
@ -52,6 +53,7 @@ def train():
|
||||
model.save_weights('model.ckpt')
|
||||
|
||||
def test():
|
||||
global model, le
|
||||
test_x, test_y, _ = load_data('./stop_times.test.tsv', le)
|
||||
test_x = tf.convert_to_tensor(test_x, dtype=tf.float32)
|
||||
test_y = tf.convert_to_tensor(test_y)
|
||||
@ -64,6 +66,6 @@ SUBCOMMANDS = {
|
||||
}
|
||||
|
||||
import sys
|
||||
assert sys.argv == 2
|
||||
assert len(sys.argv) == 2
|
||||
assert sys.argv[1] in SUBCOMMANDS.keys()
|
||||
SUBCOMMANDS[sys.argv[1]]()
|
||||
|
Loading…
Reference in New Issue
Block a user