#!/usr/bin/env python3
# Inference script: reads source-language sentences from stdin and prints
# greedy-decoded translations using the trained encoder/decoder pair.
import pickle
import sys

import torch

from lstm_model import EncoderRNN, DecoderRNN
from model_train import tensorFromSentence, SOS_token, MAX_LENGTH, device

# Vocabulary objects saved at training time; required to map words <-> indices.
# NOTE(review): pickle.load is only safe because these are our own artifacts —
# never point these paths at untrusted files.
with open('data/pl_lang.pkl', 'rb') as input_file:
    input_lang = pickle.load(input_file)

with open('data/en_lang.pkl', 'rb') as out_file:
    output_lang = pickle.load(out_file)
|
def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    """Greedy-decode ``sentence`` with a trained encoder/decoder pair.

    Parameters
    ----------
    encoder : trained EncoderRNN (consumes the whole source tensor at once).
    decoder : trained DecoderRNN (stepped one token at a time).
    sentence : str
        Source-language sentence; tokenized via ``tensorFromSentence`` using
        the module-level ``input_lang`` vocabulary.
    max_length : int
        Number of decoding steps to run. The loop always runs the full
        budget; callers truncate at the stop token afterwards.

    Returns
    -------
    torch.Tensor
        The per-step argmax token indices stacked along dim 0
        (length ``max_length``; trailing dims follow ``topk``'s output —
        presumably (1, 1) per step, verify against the decoder).
    """
    with torch.no_grad():
        # Encode the full source sentence in one pass; only the final
        # hidden state is used to seed the decoder (no attention here, so
        # the per-step encoder outputs are not needed — the original code
        # allocated and overwrote an unused `encoder_outputs` buffer).
        input_tensor = tensorFromSentence(input_lang, sentence)
        encoder_hidden = encoder.initHidden()
        _, encoder_hidden = encoder(input_tensor, encoder_hidden)

        # Greedy decoding: start from SOS, feed each predicted token back
        # in as the next input.
        decoder_input = torch.tensor([[SOS_token]], device=device)
        decoder_hidden = encoder_hidden

        decoded_words = []
        for di in range(max_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            _, topi = decoder_output.topk(1)  # argmax index; score is unused
            decoded_words.append(topi)
            decoder_input = topi.transpose(0, 1)

        return torch.stack(decoded_words)
|
# Rebuild the models with the same hidden size used at training time,
# restore the trained weights, and switch to inference mode.
hidden_size = 256
encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
decoder = DecoderRNN(hidden_size, output_lang.n_words).to(device)
encoder.load_state_dict(torch.load('encoder.dict'))
decoder.load_state_dict(torch.load('decoder.dict'))
encoder.eval()
decoder.eval()

# Translate every stdin line and print the decoded sentence(s).
for line in sys.stdin:
    dec_words = evaluate(encoder, decoder, line.rstrip(), MAX_LENGTH)
    # (steps, batch, 1) -> (batch, steps, 1) so we can walk each sentence.
    dec_words = dec_words.transpose(0, 1)
    for sen in dec_words:
        words = []
        for idx in sen:
            # NOTE(review): index 0 is treated as the stop token here —
            # confirm this matches the training vocabulary's EOS index.
            if idx == 0:
                break
            words.append(output_lang.index2word[idx.item()])
        print(' '.join(words))