"""Greedy seq2seq translation of stdin, one sentence per line.

Loads the pickled source vocabulary (data/pl_lang.pkl) and target vocabulary
(data/en_lang.pkl), restores the trained encoder/decoder weights, then decodes
each input line and prints the resulting words separated by spaces.
"""
import pickle
import sys

import torch

from lstm_model import EncoderRNN, DecoderRNN
from model_train import tensorFromSentence, SOS_token, MAX_LENGTH, device, EOS_token

# Vocabulary objects pickled at training time. pickle can execute arbitrary
# code while loading, so these paths must only point at trusted artifacts.
with open('data/pl_lang.pkl', 'rb') as input_file:
    input_lang = pickle.load(input_file)
with open('data/en_lang.pkl', 'rb') as out_file:
    output_lang = pickle.load(out_file)


def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    """Greedily decode *sentence* with the given encoder/decoder pair.

    Args:
        encoder: module called as ``encoder(token, hidden)`` returning
            ``(output, hidden)``; must provide ``initHidden()``.
        decoder: module called as ``decoder(token, hidden)`` returning
            ``(output, hidden)``; the index of the top-scoring entry of
            ``output`` is looked up in ``output_lang.index2word``.
        sentence: source-language text, converted via ``tensorFromSentence``.
        max_length: upper bound on the number of decoding steps.

    Returns:
        List of decoded target-language words, not including EOS.
    """
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)

        # Run the encoder over the source tokens one at a time. Only the final
        # hidden state is used: this decoder takes no encoder outputs, so they
        # are not stored (which also means long inputs no longer overflow a
        # fixed max_length buffer).
        encoder_hidden = encoder.initHidden()
        for ei in range(input_tensor.size(0)):
            _, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)

        decoder_input = torch.tensor([[SOS_token]], dtype=torch.long, device=device)
        decoder_hidden = encoder_hidden
        decoded_words = []
        for _ in range(max_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            _, topi = decoder_output.topk(1)
            token = topi.item()
            if token == EOS_token:
                break
            decoded_words.append(output_lang.index2word[token])
            # Feed the prediction back in as the next input (greedy decoding).
            decoder_input = topi.squeeze().detach()
        return decoded_words


hidden_size = 256  # must match the hidden size the saved weights were trained with
encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
decoder = DecoderRNN(hidden_size, output_lang.n_words).to(device)
# map_location lets weights saved on a GPU be restored on a CPU-only machine.
encoder.load_state_dict(torch.load('encoder.dict', map_location=device))
decoder.load_state_dict(torch.load('decoder.dict', map_location=device))
# Inference mode: puts any dropout / normalization layers into eval behavior.
encoder.eval()
decoder.eval()


def main():
    """Translate stdin line by line, printing exactly one output line per input line."""
    for line in sys.stdin:
        line = line.rstrip()
        if not line:
            # Keep output aligned with input instead of feeding an empty
            # sentence to the model (tensorFromSentence behavior on "" is
            # unverified here).
            print()
            continue
        dec_words = evaluate(encoder, decoder, line, MAX_LENGTH)
        print(' '.join(dec_words))


if __name__ == '__main__':
    main()