fix
Some checks failed
s434749-training/pipeline/head There was a failure building this commit

This commit is contained in:
Alagris 2021-05-23 19:56:11 +02:00
parent 09fd3c44ab
commit aab43d29a5

View File

@ -76,7 +76,8 @@ def encode(batch: [(torch.tensor, str)], in_alphabet, max_len):
def encode_str(batch: [(str, str)], in_alphabet, max_len):
batch = [(torch.tensor([in_alphabet[letter] for letter in in_str], dtype=torch.int), out_str) for in_str, out_str in batch]
batch = [(torch.tensor([in_alphabet[letter] for letter in in_str], dtype=torch.int), out_str) for in_str, out_str in
batch]
return encode(batch, in_alphabet, max_len)
@ -192,12 +193,13 @@ def cfg():
'u', 'v', 'w', 'x', 'y', 'z']
def signature(model,in_alphabet,max_len):
def signature(model, in_alphabet, max_len):
mock_x = [('abc', 'xyz')]
mock_text, _ = encode_str(mock_x, in_alphabet, max_len)
mock_y = model(mock_text)
return mlflow.models.signature.infer_signature(mock_text, mock_y)
@ex.automain
def run(kernel_size, hidden_layers, data_file, epochs, teacher_forcing_probability, learning_rate, batch_size, max_len,
total_out_len, model_file, out_lookup, in_lookup, mode):
@ -227,20 +229,13 @@ def run(kernel_size, hidden_layers, data_file, epochs, teacher_forcing_probabili
cnn = CNN(kernel_size=kernel_size, hidden_layers=hidden_layers, channels=max_len, embedding_size=max_len,
in_alphabet=in_alphabet, max_len=max_len).to(device)
if os.path.isfile(model_file):
cnn.load_state_dict(torch.load(model_file, map_location=torch.device('cpu')))
else:
if mode == 'train':
train_model(cnn, learning_rate, in_alphabet, max_len, data, epochs, batch_size)
torch.save(cnn.state_dict(), model_file)
ex.add_artifact(model_file)
mlflow.pytorch.log_model(cnn, "cnn-model", registered_model_name="PhoneticEdDistEmbeddings",
signature=signature(cnn,in_alphabet, max_len))
log_artifacts(model_file)
else:
print(model_file + " missing!")
exit(2)
if mode == 'train':
train_model(cnn, learning_rate, in_alphabet, max_len, data, epochs, batch_size)
torch.save(cnn.state_dict(), model_file)
ex.add_artifact(model_file)
log_artifacts(model_file)
mlflow.pytorch.log_model(cnn, "cnn-model", registered_model_name="PhoneticEdDistEmbeddings",
signature=signature(cnn, in_alphabet, max_len))
if mode == 'eval':
cnn.eval()