This commit is contained in:
parent
09fd3c44ab
commit
aab43d29a5
@ -76,7 +76,8 @@ def encode(batch: [(torch.tensor, str)], in_alphabet, max_len):
|
||||
|
||||
|
||||
def encode_str(batch: [(str, str)], in_alphabet, max_len):
|
||||
batch = [(torch.tensor([in_alphabet[letter] for letter in in_str], dtype=torch.int), out_str) for in_str, out_str in batch]
|
||||
batch = [(torch.tensor([in_alphabet[letter] for letter in in_str], dtype=torch.int), out_str) for in_str, out_str in
|
||||
batch]
|
||||
return encode(batch, in_alphabet, max_len)
|
||||
|
||||
|
||||
@ -198,6 +199,7 @@ def signature(model,in_alphabet,max_len):
|
||||
mock_y = model(mock_text)
|
||||
return mlflow.models.signature.infer_signature(mock_text, mock_y)
|
||||
|
||||
|
||||
@ex.automain
|
||||
def run(kernel_size, hidden_layers, data_file, epochs, teacher_forcing_probability, learning_rate, batch_size, max_len,
|
||||
total_out_len, model_file, out_lookup, in_lookup, mode):
|
||||
@ -227,20 +229,13 @@ def run(kernel_size, hidden_layers, data_file, epochs, teacher_forcing_probabili
|
||||
|
||||
cnn = CNN(kernel_size=kernel_size, hidden_layers=hidden_layers, channels=max_len, embedding_size=max_len,
|
||||
in_alphabet=in_alphabet, max_len=max_len).to(device)
|
||||
if os.path.isfile(model_file):
|
||||
cnn.load_state_dict(torch.load(model_file, map_location=torch.device('cpu')))
|
||||
else:
|
||||
if mode == 'train':
|
||||
train_model(cnn, learning_rate, in_alphabet, max_len, data, epochs, batch_size)
|
||||
torch.save(cnn.state_dict(), model_file)
|
||||
ex.add_artifact(model_file)
|
||||
|
||||
log_artifacts(model_file)
|
||||
mlflow.pytorch.log_model(cnn, "cnn-model", registered_model_name="PhoneticEdDistEmbeddings",
|
||||
signature=signature(cnn, in_alphabet, max_len))
|
||||
log_artifacts(model_file)
|
||||
else:
|
||||
print(model_file + " missing!")
|
||||
exit(2)
|
||||
|
||||
if mode == 'eval':
|
||||
cnn.eval()
|
||||
|
Loading…
Reference in New Issue
Block a user