From aab43d29a5bc77cb846458f99e40483410f78ae0 Mon Sep 17 00:00:00 2001 From: Alagris Date: Sun, 23 May 2021 19:56:11 +0200 Subject: [PATCH] fix --- train_model.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/train_model.py b/train_model.py index 8fe1dff..6e8c283 100644 --- a/train_model.py +++ b/train_model.py @@ -76,7 +76,8 @@ def encode(batch: [(torch.tensor, str)], in_alphabet, max_len): def encode_str(batch: [(str, str)], in_alphabet, max_len): - batch = [(torch.tensor([in_alphabet[letter] for letter in in_str], dtype=torch.int), out_str) for in_str, out_str in batch] + batch = [(torch.tensor([in_alphabet[letter] for letter in in_str], dtype=torch.int), out_str) for in_str, out_str in + batch] return encode(batch, in_alphabet, max_len) @@ -192,12 +193,13 @@ def cfg(): 'u', 'v', 'w', 'x', 'y', 'z'] -def signature(model,in_alphabet,max_len): +def signature(model, in_alphabet, max_len): mock_x = [('abc', 'xyz')] mock_text, _ = encode_str(mock_x, in_alphabet, max_len) mock_y = model(mock_text) return mlflow.models.signature.infer_signature(mock_text, mock_y) + @ex.automain def run(kernel_size, hidden_layers, data_file, epochs, teacher_forcing_probability, learning_rate, batch_size, max_len, total_out_len, model_file, out_lookup, in_lookup, mode): @@ -227,20 +229,13 @@ def run(kernel_size, hidden_layers, data_file, epochs, teacher_forcing_probabili cnn = CNN(kernel_size=kernel_size, hidden_layers=hidden_layers, channels=max_len, embedding_size=max_len, in_alphabet=in_alphabet, max_len=max_len).to(device) - if os.path.isfile(model_file): - cnn.load_state_dict(torch.load(model_file, map_location=torch.device('cpu'))) - else: - if mode == 'train': - train_model(cnn, learning_rate, in_alphabet, max_len, data, epochs, batch_size) - torch.save(cnn.state_dict(), model_file) - ex.add_artifact(model_file) - - mlflow.pytorch.log_model(cnn, "cnn-model", registered_model_name="PhoneticEdDistEmbeddings", - signature=signature(cnn,in_alphabet, max_len)) - log_artifacts(model_file) - else: - print(model_file + " missing!") - exit(2) + if mode == 'train': + train_model(cnn, learning_rate, in_alphabet, max_len, data, epochs, batch_size) + torch.save(cnn.state_dict(), model_file) + ex.add_artifact(model_file) + log_artifacts(model_file) + mlflow.pytorch.log_model(cnn, "cnn-model", registered_model_name="PhoneticEdDistEmbeddings", + signature=signature(cnn, in_alphabet, max_len)) if mode == 'eval': cnn.eval()