59 lines
2.0 KiB
Python
59 lines
2.0 KiB
Python
|
from model_train import tensorFromSentence, SOS_token, MAX_LENGTH, device, EOS_token
|
||
|
import pickle
|
||
|
from lstm_model import EncoderRNN, DecoderRNN
|
||
|
import sys
|
||
|
import torch
|
||
|
|
||
|
def _unpickle(path):
    """Load and return the single object stored in the pickle file at *path*.

    NOTE: unpickling can execute arbitrary code; only use this on trusted,
    locally produced files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Vocabulary objects used by this script: input_lang turns the sentences read
# from stdin into token indices, output_lang maps predicted indices back to
# words. Presumably saved by the training script -- confirm.
input_lang = _unpickle('data/pl_lang.pkl')
output_lang = _unpickle('data/en_lang.pkl')
|
||
|
|
||
|
|
||
|
def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    """Greedily translate *sentence* with a trained encoder/decoder pair.

    The sentence is converted to a tensor of token indices with
    ``tensorFromSentence(input_lang, ...)`` and fed through the encoder one
    token at a time. The encoder's final hidden state seeds the decoder,
    which then repeatedly emits its highest-scoring token and feeds it back
    in as the next input. Decoding stops at ``EOS_token`` or after
    ``max_length`` steps.

    Args:
        encoder: Trained encoder exposing ``initHidden()`` and called as
            ``encoder(token, hidden) -> (output, hidden)``.
        decoder: Trained decoder called as
            ``decoder(input, hidden) -> (output, hidden)``.
        sentence: Source-language sentence to translate.
        max_length: Maximum number of decoding steps, i.e. output words.

    Returns:
        List of output-language words looked up in ``output_lang.index2word``.
        The end-of-sentence token itself is not included.
    """
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)

        # The decoder is called with only (input, hidden), so per-step encoder
        # outputs are not needed; keep just the final hidden state. Not
        # buffering the outputs in a fixed max_length tensor also means inputs
        # longer than max_length no longer raise IndexError.
        encoder_hidden = encoder.initHidden()
        for token in input_tensor:
            _, encoder_hidden = encoder(token, encoder_hidden)

        # Start decoding from the start-of-sentence token, shape (1, 1).
        decoder_input = torch.tensor([[SOS_token]], dtype=torch.long, device=device)
        decoder_hidden = encoder_hidden

        decoded_words = []
        for _ in range(max_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            _, topi = decoder_output.topk(1)
            token_id = topi.item()
            if token_id == EOS_token:
                break
            decoded_words.append(output_lang.index2word[token_id])
            # Greedy decoding: the predicted token becomes the next input.
            decoder_input = topi.squeeze()

        return decoded_words
|
||
|
|
||
|
# Must equal the hidden size the saved checkpoints were built with;
# load_state_dict raises on a parameter-shape mismatch.
hidden_size = 256

encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
decoder = DecoderRNN(hidden_size, output_lang.n_words).to(device)

# map_location remaps the saved tensors onto the current device, so
# checkpoints saved on a GPU can be loaded on a CPU-only machine (and vice
# versa).
encoder.load_state_dict(torch.load('encoder.dict', map_location=device))
decoder.load_state_dict(torch.load('decoder.dict', map_location=device))

# Inference mode: disables training-only behavior such as dropout, if the
# models contain any.
encoder.eval()
decoder.eval()

# Translate stdin line by line, writing one space-joined translation per line.
for line in sys.stdin:
    line = line.rstrip()
    dec_words = evaluate(encoder, decoder, line, MAX_LENGTH)
    print(' '.join(dec_words))
|