diff --git a/src/predict.py b/src/predict.py new file mode 100644 index 0000000..3c4a1d1 --- /dev/null +++ b/src/predict.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 + +import argparse +import torch +import pickle +import unicodedata +import re + +from Decoder import Decoder +from Encoder import Encoder + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +SOS=0 +EOS=1 +MAX_LENGTH=25 + +def clear_line(string): + string = ''.join( + c for c in unicodedata.normalize('NFD', string) + if unicodedata.category(c) != 'Mn' + ) + string = re.sub("[^a-z ]", "", string.lower()) + return string + +def read_clear_data(in_file_path): + print("Reading data...") + data = [] + with open(in_file_path) as in_file: + for line in in_file: + string = clear_line(line) + if len(string.split(' ')) < MAX_LENGTH: + data.append(string) + else: + data.append("To bedzie zle") + return data + +def indexes_from_sentence(vocab, sentence): + a = [] + for word in sentence.split(' '): + try: + index = vocab.word2index[word] + a.append(index) + except: + a.append(vocab.word2index['jest']) + return a + #return [vocab.word2index[word] for word in sentence.split(' ')] + +def tensor_from_sentece(vocab, sentence): + indexes = indexes_from_sentence(vocab, sentence) + indexes.append(EOS) + return torch.tensor(indexes, dtype=torch.long, device=device).view(-1,1) + +def eval(encoder, decoder, sentence, input_vocab, target_vocab, max_length=MAX_LENGTH): + with torch.no_grad(): + input_tensor = tensor_from_sentece(input_vocab, sentence) + input_length = input_tensor.size()[0] + + encoder_hidden = encoder.init_hidden(device) + + encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device) + + for i in range(input_length): + encoder_output, encoder_hidden = encoder(input_tensor[i], encoder_hidden) + encoder_outputs[i] += encoder_output[0, 0] + + decoder_input = torch.tensor([[SOS]], device=device) + + decoder_hidden = encoder_hidden + decoder_words = [] + + for i in range(max_length): + decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden) + topv, topi = decoder_output.data.topk(1) + if topi.item() == EOS: + decoder_words.append('') + break + else: + decoder_words.append(target_vocab.index2word[topi.item()]) + decoder_input = topi.squeeze().detach() + return decoder_words + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--in_f') + parser.add_argument('--out_f') + parser.add_argument('--vocab') + parser.add_argument('--encoder') + parser.add_argument('--decoder') + args = parser.parse_args() + + with open(args.vocab, 'rb') as p: + pairs, input_vocab, target_vocab = pickle.load(p) + + hidden_size = 256 + encoder = Encoder(input_vocab.size, hidden_size).to(device) + decoder = Decoder(hidden_size,target_vocab.size).to(device) + + encoder.load_state_dict(torch.load(args.encoder)) + decoder.load_state_dict(torch.load(args.decoder)) + + data = read_clear_data(args.in_f) + with open(args.out_f, 'w+') as f: + for line in data: + out = eval(encoder, decoder, line, input_vocab, target_vocab, max_length=MAX_LENGTH) + f.write(" ".join(out) + "\n") + +main()