wmt-2020-pl-en/predict.py
2021-01-31 16:54:20 +01:00

57 lines
1.8 KiB
Python

from model_train import tensorFromSentence, SOS_token, MAX_LENGTH, device
import pickle
from lstm_model import EncoderRNN, DecoderRNN
import sys
import torch
# Load the vocabulary objects (word<->index mappings) produced at training
# time. NOTE(review): pickle.load is only safe on trusted, locally produced
# files — never point these paths at untrusted input.
with open('data/pl_lang.pkl', 'rb') as src_vocab_file:
    input_lang = pickle.load(src_vocab_file)
with open('data/en_lang.pkl', 'rb') as tgt_vocab_file:
    output_lang = pickle.load(tgt_vocab_file)
def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    """Greedy-decode a single source `sentence` into target-token indices.

    Args:
        encoder: trained EncoderRNN; consumes the whole source sentence.
        decoder: trained DecoderRNN; generates one token per step.
        sentence: raw source-language string (tokenized by tensorFromSentence).
        max_length: number of decoding steps to run (always runs all of
            them; the caller truncates at the end-of-sentence index).

    Returns:
        Tensor of predicted token indices, stacked along dim 0 — one entry
        per decoding step.
    """
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)
        encoder_hidden = encoder.initHidden()
        # Only the encoder's final hidden state seeds the decoder; there is
        # no attention over per-step encoder outputs, so they are discarded.
        # (The original also pre-allocated a zeros buffer for them that was
        # immediately overwritten and never read — removed as dead code.)
        encoder_output, encoder_hidden = encoder(input_tensor, encoder_hidden)
        decoder_input = torch.tensor([[SOS_token]], device=device)
        decoder_hidden = encoder_hidden
        decoded_words = []
        for di in range(max_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            # Greedy decoding: take the single highest-scoring token.
            topv, topi = decoder_output.topk(1)
            decoded_words.append(topi)
            # Feed the prediction back in as the next decoder input.
            decoder_input = topi.transpose(0, 1)
        return torch.stack(decoded_words)
# Rebuild the models with the same architecture used during training, then
# restore the saved weights. hidden_size must match the training-time value.
hidden_size = 256
encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
decoder = DecoderRNN(hidden_size, output_lang.n_words).to(device)
# map_location remaps tensors in the checkpoint to the active device, so a
# GPU-saved checkpoint still loads on a CPU-only machine (and vice versa).
encoder.load_state_dict(torch.load('encoder.dict', map_location=device))
decoder.load_state_dict(torch.load('decoder.dict', map_location=device))
# Inference mode: disables dropout/batch-norm training behavior.
encoder.eval()
decoder.eval()
# Translate every line read from stdin and print one decoded sentence per
# input line.
for raw_line in sys.stdin:
    source_sentence = raw_line.rstrip()
    predictions = evaluate(encoder, decoder, source_sentence, MAX_LENGTH)
    # (steps, batch) -> (batch, steps) so each row is one decoded sentence.
    for token_ids in predictions.transpose(0, 1):
        words = []
        for token_id in token_ids:
            # NOTE(review): decoding stops at index 0 — presumably the
            # sentence-terminator index in this vocabulary; verify against
            # the Lang class in model_train.
            if token_id == 0:
                break
            words.append(output_lang.index2word[token_id.item()])
        print(' '.join(words))