"""Greedy-decode sentences read from stdin with a trained encoder/decoder pair.

Each input line is converted with ``tensorFromSentence`` using the vocabulary
pickled in ``data/pl_lang.pkl``, encoded, decoded greedily for ``MAX_LENGTH``
steps, and printed as space-separated words from the vocabulary pickled in
``data/en_lang.pkl``.
"""
import pickle
import sys

import torch

from model_train import tensorFromSentence, SOS_token, MAX_LENGTH, device
from lstm_model import EncoderRNN, DecoderRNN

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# these vocabularies must come from a trusted source (our own training run).
with open('data/pl_lang.pkl', 'rb') as input_file:
    input_lang = pickle.load(input_file)
with open('data/en_lang.pkl', 'rb') as out_file:
    output_lang = pickle.load(out_file)


def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    """Greedily decode ``sentence`` for exactly ``max_length`` steps.

    Args:
        encoder: Module providing ``initHidden()`` and callable as
            ``encoder(input_tensor, hidden) -> (output, hidden)``.
        decoder: Module callable as ``decoder(input, hidden) -> (output, hidden)``.
        sentence: Raw input-language sentence; converted with
            ``tensorFromSentence(input_lang, sentence)``.
        max_length: Number of decoding steps. Every step is always run; this
            function does not stop early on any end token.

    Returns:
        ``torch.stack`` of the top-1 index tensor from every decoding step,
        so the first dimension has size ``max_length``.
    """
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)
        encoder_hidden = encoder.initHidden()
        # Only the final hidden state seeds the decoder; the decoder is never
        # given the encoder outputs, so they are discarded.
        _, encoder_hidden = encoder(input_tensor, encoder_hidden)

        decoder_input = torch.tensor([[SOS_token]], device=device)
        decoder_hidden = encoder_hidden
        decoded_words = []
        for _ in range(max_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            _, topi = decoder_output.topk(1)
            decoded_words.append(topi)
            # Feed the prediction back in as the next decoder input (greedy).
            decoder_input = topi.transpose(0, 1)
        return torch.stack(decoded_words)


# Must match the size used when the checkpoints below were trained, otherwise
# load_state_dict fails on parameter shape mismatches.
hidden_size = 256
encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
decoder = DecoderRNN(hidden_size, output_lang.n_words).to(device)

# map_location remaps stored tensors onto the inference device, so checkpoints
# saved on a GPU can be loaded on a CPU-only machine (and vice versa).
encoder.load_state_dict(torch.load('encoder.dict', map_location=device))
decoder.load_state_dict(torch.load('decoder.dict', map_location=device))

# Switch to evaluation mode (affects layers such as dropout).
encoder.eval()
decoder.eval()


def _indices_to_words(indices):
    """Map one decoded sequence of index tensors to output-language words."""
    words = []
    for idx in indices:
        # NOTE(review): index 0 is treated as the stop marker here; confirm it
        # is the token id model_train reserves for that purpose.
        if idx == 0:
            break
        words.append(output_lang.index2word[idx.item()])
    return words


def main():
    """Decode every line of stdin and print one line per decoded sequence."""
    for line in sys.stdin:
        dec_words = evaluate(encoder, decoder, line.rstrip(), MAX_LENGTH)
        # Swap the leading step axis with the next one so that iteration
        # yields one decoded sequence at a time.
        for sen in dec_words.transpose(0, 1):
            print(' '.join(_indices_to_words(sen)))


if __name__ == "__main__":
    main()