fix
Some checks failed
s434749-training/pipeline/head There was a failure building this commit

This commit is contained in:
Alagris 2021-05-23 19:56:11 +02:00
parent 09fd3c44ab
commit aab43d29a5

View File

@ -76,7 +76,8 @@ def encode(batch: [(torch.tensor, str)], in_alphabet, max_len):
def encode_str(batch: [(str, str)], in_alphabet, max_len): def encode_str(batch: [(str, str)], in_alphabet, max_len):
batch = [(torch.tensor([in_alphabet[letter] for letter in in_str], dtype=torch.int), out_str) for in_str, out_str in batch] batch = [(torch.tensor([in_alphabet[letter] for letter in in_str], dtype=torch.int), out_str) for in_str, out_str in
batch]
return encode(batch, in_alphabet, max_len) return encode(batch, in_alphabet, max_len)
@ -198,6 +199,7 @@ def signature(model,in_alphabet,max_len):
mock_y = model(mock_text) mock_y = model(mock_text)
return mlflow.models.signature.infer_signature(mock_text, mock_y) return mlflow.models.signature.infer_signature(mock_text, mock_y)
@ex.automain @ex.automain
def run(kernel_size, hidden_layers, data_file, epochs, teacher_forcing_probability, learning_rate, batch_size, max_len, def run(kernel_size, hidden_layers, data_file, epochs, teacher_forcing_probability, learning_rate, batch_size, max_len,
total_out_len, model_file, out_lookup, in_lookup, mode): total_out_len, model_file, out_lookup, in_lookup, mode):
@ -227,20 +229,13 @@ def run(kernel_size, hidden_layers, data_file, epochs, teacher_forcing_probabili
cnn = CNN(kernel_size=kernel_size, hidden_layers=hidden_layers, channels=max_len, embedding_size=max_len, cnn = CNN(kernel_size=kernel_size, hidden_layers=hidden_layers, channels=max_len, embedding_size=max_len,
in_alphabet=in_alphabet, max_len=max_len).to(device) in_alphabet=in_alphabet, max_len=max_len).to(device)
if os.path.isfile(model_file):
cnn.load_state_dict(torch.load(model_file, map_location=torch.device('cpu')))
else:
if mode == 'train': if mode == 'train':
train_model(cnn, learning_rate, in_alphabet, max_len, data, epochs, batch_size) train_model(cnn, learning_rate, in_alphabet, max_len, data, epochs, batch_size)
torch.save(cnn.state_dict(), model_file) torch.save(cnn.state_dict(), model_file)
ex.add_artifact(model_file) ex.add_artifact(model_file)
log_artifacts(model_file)
mlflow.pytorch.log_model(cnn, "cnn-model", registered_model_name="PhoneticEdDistEmbeddings", mlflow.pytorch.log_model(cnn, "cnn-model", registered_model_name="PhoneticEdDistEmbeddings",
signature=signature(cnn, in_alphabet, max_len)) signature=signature(cnn, in_alphabet, max_len))
log_artifacts(model_file)
else:
print(model_file + " missing!")
exit(2)
if mode == 'eval': if mode == 'eval':
cnn.eval() cnn.eval()