This commit is contained in:
parent
09fd3c44ab
commit
aab43d29a5
@ -76,7 +76,8 @@ def encode(batch: [(torch.tensor, str)], in_alphabet, max_len):
|
|||||||
|
|
||||||
|
|
||||||
def encode_str(batch: [(str, str)], in_alphabet, max_len):
|
def encode_str(batch: [(str, str)], in_alphabet, max_len):
|
||||||
batch = [(torch.tensor([in_alphabet[letter] for letter in in_str], dtype=torch.int), out_str) for in_str, out_str in batch]
|
batch = [(torch.tensor([in_alphabet[letter] for letter in in_str], dtype=torch.int), out_str) for in_str, out_str in
|
||||||
|
batch]
|
||||||
return encode(batch, in_alphabet, max_len)
|
return encode(batch, in_alphabet, max_len)
|
||||||
|
|
||||||
|
|
||||||
@ -198,6 +199,7 @@ def signature(model,in_alphabet,max_len):
|
|||||||
mock_y = model(mock_text)
|
mock_y = model(mock_text)
|
||||||
return mlflow.models.signature.infer_signature(mock_text, mock_y)
|
return mlflow.models.signature.infer_signature(mock_text, mock_y)
|
||||||
|
|
||||||
|
|
||||||
@ex.automain
|
@ex.automain
|
||||||
def run(kernel_size, hidden_layers, data_file, epochs, teacher_forcing_probability, learning_rate, batch_size, max_len,
|
def run(kernel_size, hidden_layers, data_file, epochs, teacher_forcing_probability, learning_rate, batch_size, max_len,
|
||||||
total_out_len, model_file, out_lookup, in_lookup, mode):
|
total_out_len, model_file, out_lookup, in_lookup, mode):
|
||||||
@ -227,20 +229,13 @@ def run(kernel_size, hidden_layers, data_file, epochs, teacher_forcing_probabili
|
|||||||
|
|
||||||
cnn = CNN(kernel_size=kernel_size, hidden_layers=hidden_layers, channels=max_len, embedding_size=max_len,
|
cnn = CNN(kernel_size=kernel_size, hidden_layers=hidden_layers, channels=max_len, embedding_size=max_len,
|
||||||
in_alphabet=in_alphabet, max_len=max_len).to(device)
|
in_alphabet=in_alphabet, max_len=max_len).to(device)
|
||||||
if os.path.isfile(model_file):
|
|
||||||
cnn.load_state_dict(torch.load(model_file, map_location=torch.device('cpu')))
|
|
||||||
else:
|
|
||||||
if mode == 'train':
|
if mode == 'train':
|
||||||
train_model(cnn, learning_rate, in_alphabet, max_len, data, epochs, batch_size)
|
train_model(cnn, learning_rate, in_alphabet, max_len, data, epochs, batch_size)
|
||||||
torch.save(cnn.state_dict(), model_file)
|
torch.save(cnn.state_dict(), model_file)
|
||||||
ex.add_artifact(model_file)
|
ex.add_artifact(model_file)
|
||||||
|
log_artifacts(model_file)
|
||||||
mlflow.pytorch.log_model(cnn, "cnn-model", registered_model_name="PhoneticEdDistEmbeddings",
|
mlflow.pytorch.log_model(cnn, "cnn-model", registered_model_name="PhoneticEdDistEmbeddings",
|
||||||
signature=signature(cnn, in_alphabet, max_len))
|
signature=signature(cnn, in_alphabet, max_len))
|
||||||
log_artifacts(model_file)
|
|
||||||
else:
|
|
||||||
print(model_file + " missing!")
|
|
||||||
exit(2)
|
|
||||||
|
|
||||||
if mode == 'eval':
|
if mode == 'eval':
|
||||||
cnn.eval()
|
cnn.eval()
|
||||||
|
Loading…
Reference in New Issue
Block a user