bugs
This commit is contained in:
parent
59ea069a07
commit
2deab093cf
@ -35,7 +35,8 @@ model.compile(optimizer='adam',
|
|||||||
metrics=['accuracy'])
|
metrics=['accuracy'])
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
train_x, train_y, le = load_data('./stop_times.train.tsv', le)
|
global model, le
|
||||||
|
train_x, train_y, _ = load_data('./stop_times.train.tsv', le)
|
||||||
train_x = tf.convert_to_tensor(train_x, dtype=tf.float32)
|
train_x = tf.convert_to_tensor(train_x, dtype=tf.float32)
|
||||||
train_y = tf.convert_to_tensor(train_y)
|
train_y = tf.convert_to_tensor(train_y)
|
||||||
|
|
||||||
@ -52,6 +53,7 @@ def train():
|
|||||||
model.save_weights('model.ckpt')
|
model.save_weights('model.ckpt')
|
||||||
|
|
||||||
def test():
|
def test():
|
||||||
|
global model, le
|
||||||
test_x, test_y, _ = load_data('./stop_times.test.tsv', le)
|
test_x, test_y, _ = load_data('./stop_times.test.tsv', le)
|
||||||
test_x = tf.convert_to_tensor(test_x, dtype=tf.float32)
|
test_x = tf.convert_to_tensor(test_x, dtype=tf.float32)
|
||||||
test_y = tf.convert_to_tensor(test_y)
|
test_y = tf.convert_to_tensor(test_y)
|
||||||
@ -64,6 +66,6 @@ SUBCOMMANDS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
assert sys.argv == 2
|
assert len(sys.argv) == 2
|
||||||
assert sys.argv[1] in SUBCOMMANDS.keys()
|
assert sys.argv[1] in SUBCOMMANDS.keys()
|
||||||
SUBCOMMANDS[sys.argv[1]]()
|
SUBCOMMANDS[sys.argv[1]]()
|
||||||
|
Loading…
Reference in New Issue
Block a user