wmt-2017-cs-en/predict.py

59 lines
2.0 KiB
Python
Raw Normal View History

2021-01-27 04:01:04 +01:00
from model_train import tensorFromSentence, SOS_token, MAX_LENGTH, device, EOS_token
import pickle
from lstm_model import EncoderRNN, DecoderRNN
import sys
import torch
# Restore the source- and target-side vocabulary objects that evaluate() and
# the model constructors below read (n_words, index2word).
# NOTE(review): pickle.load executes arbitrary code from the file it reads;
# these paths are assumed to be trusted local training artefacts -- confirm.
with open('data/pl_lang.pkl', 'rb') as input_file, \
        open('data/en_lang.pkl', 'rb') as out_file:
    input_lang = pickle.load(input_file)
    output_lang = pickle.load(out_file)
def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    """Greedily decode a translation of ``sentence``.

    The sentence is converted to a tensor of token indices with
    ``tensorFromSentence(input_lang, sentence)``, fed through ``encoder`` one
    token at a time, and the final encoder hidden state is used to seed
    ``decoder``. At every decoding step the highest-scoring output token is
    chosen and fed back as the next decoder input.

    Args:
        encoder: Model providing ``initHidden()`` and callable as
            ``encoder(token, hidden) -> (output, hidden)``.
        decoder: Model callable as ``decoder(token, hidden) -> (output, hidden)``.
        sentence: Source sentence as a string.
        max_length: Upper bound on the number of decoding steps. It no longer
            limits the input length: the per-step encoder outputs were stored
            in a ``max_length``-sized buffer that nothing read, and that
            buffer raised IndexError on longer inputs, so it has been removed.

    Returns:
        List of output words looked up in ``output_lang.index2word``. The
        end-of-sentence token itself is not included.
    """
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)

        # Only the final hidden state is carried over to the decoder; the
        # per-step encoder outputs are not used because the decoder is
        # called without any attention input.
        encoder_hidden = encoder.initHidden()
        for ei in range(input_tensor.size(0)):
            _, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)

        # Start decoding from the start-of-sentence token, shape (1, 1).
        decoder_input = torch.tensor(
            [[SOS_token]], dtype=torch.long, device=device)
        decoder_hidden = encoder_hidden

        decoded_words = []
        for _ in range(max_length):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden)
            # Greedy choice: index of the single best-scoring token.
            _, topi = decoder_output.topk(1)
            token_id = topi.item()
            if token_id == EOS_token:
                break
            decoded_words.append(output_lang.index2word[token_id])
            # Feed the prediction back in as the next decoder input.
            decoder_input = topi.squeeze().detach()

        return decoded_words
# Hidden size passed to both model constructors; load_state_dict() below
# rejects checkpoints whose parameter shapes do not match the models built here.
hidden_size = 256
encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
decoder = DecoderRNN(hidden_size, output_lang.n_words).to(device)

# map_location remaps the stored tensors onto the current device, so
# checkpoints saved on a GPU can still be loaded on a CPU-only machine.
encoder.load_state_dict(torch.load('encoder.dict', map_location=device))
decoder.load_state_dict(torch.load('decoder.dict', map_location=device))

# Inference only: put the modules in evaluation mode so that any layers that
# behave differently during training (e.g. dropout) are deterministic.
encoder.eval()
decoder.eval()

# Translate standard input line by line, writing one space-joined
# translation per input line to standard output.
for line in sys.stdin:
    line = line.rstrip()
    dec_words = evaluate(encoder, decoder, line, MAX_LENGTH)
    print(' '.join(dec_words))