Add predict

SzamanFL 2021-01-24 00:14:40 +01:00
parent 51d468f7aa
commit e45b0118bd

src/predict.py (new file, 108 additions)

@@ -0,0 +1,108 @@
#!/usr/bin/env python3
import argparse
import torch
import pickle
import unicodedata
import re
from Decoder import Decoder
from Encoder import Encoder
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Special token indices and the maximum accepted sentence length (in tokens).
SOS = 0
EOS = 1
MAX_LENGTH = 25

def clear_line(string):
    # Strip diacritics, lowercase, and keep only ASCII letters and spaces.
    string = ''.join(
        c for c in unicodedata.normalize('NFD', string)
        if unicodedata.category(c) != 'Mn'
    )
    string = re.sub("[^a-z ]", "", string.lower())
    return string

def read_clear_data(in_file_path):
    print("Reading data...")
    data = []
    with open(in_file_path) as in_file:
        for line in in_file:
            string = clear_line(line)
            if len(string.split(' ')) < MAX_LENGTH:
                data.append(string)
            else:
                # Sentences with MAX_LENGTH or more tokens are replaced by a fixed
                # placeholder ("To bedzie zle", Polish for "this will be bad").
                data.append("To bedzie zle")
    return data

def indexes_from_sentence(vocab, sentence):
    indexes = []
    for word in sentence.split(' '):
        try:
            indexes.append(vocab.word2index[word])
        except KeyError:
            # Out-of-vocabulary words fall back to the index of 'jest'.
            indexes.append(vocab.word2index['jest'])
    return indexes
# return [vocab.word2index[word] for word in sentence.split(' ')]

def tensor_from_sentence(vocab, sentence):
    indexes = indexes_from_sentence(vocab, sentence)
    indexes.append(EOS)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)

def evaluate(encoder, decoder, sentence, input_vocab, target_vocab, max_length=MAX_LENGTH):
    with torch.no_grad():
        input_tensor = tensor_from_sentence(input_vocab, sentence)
        input_length = input_tensor.size()[0]
        encoder_hidden = encoder.init_hidden(device)
        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
        # Encode the input sentence token by token; encoder_outputs is filled here
        # but not consumed by the decoder below.
        for i in range(input_length):
            encoder_output, encoder_hidden = encoder(input_tensor[i], encoder_hidden)
            encoder_outputs[i] += encoder_output[0, 0]
        # Decode greedily from the SOS token until EOS or max_length tokens.
        decoder_input = torch.tensor([[SOS]], device=device)
        decoder_hidden = encoder_hidden
        decoder_words = []
        for i in range(max_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            topv, topi = decoder_output.data.topk(1)
            if topi.item() == EOS:
                decoder_words.append('<EOS>')
                break
            else:
                decoder_words.append(target_vocab.index2word[topi.item()])
            decoder_input = topi.squeeze().detach()
        return decoder_words

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--in_f')
    parser.add_argument('--out_f')
    parser.add_argument('--vocab')
    parser.add_argument('--encoder')
    parser.add_argument('--decoder')
    args = parser.parse_args()

    # The vocab pickle is expected to hold the training pairs and both vocabularies.
    with open(args.vocab, 'rb') as p:
        pairs, input_vocab, target_vocab = pickle.load(p)

    hidden_size = 256
    encoder = Encoder(input_vocab.size, hidden_size).to(device)
    decoder = Decoder(hidden_size, target_vocab.size).to(device)
    encoder.load_state_dict(torch.load(args.encoder))
    decoder.load_state_dict(torch.load(args.decoder))

    data = read_clear_data(args.in_f)
    with open(args.out_f, 'w+') as f:
        for line in data:
            out = evaluate(encoder, decoder, line, input_vocab, target_vocab, max_length=MAX_LENGTH)
            f.write(" ".join(out) + "\n")


if __name__ == '__main__':
    main()
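
For context, the script assumes the objects unpickled from --vocab expose word2index, index2word, and size, and that Encoder provides init_hidden(device) and hidden_size. A minimal sketch of that assumed interface (the Vocab class name is hypothetical; it only mirrors how input_vocab and target_vocab are used above):

# Hypothetical sketch, not part of this commit: the attributes predict.py reads
# from the objects unpickled from --vocab.
class Vocab:
    def __init__(self):
        self.word2index = {"jest": 2}                      # token -> index; 'jest' is the OOV fallback
        self.index2word = {0: "SOS", 1: "EOS", 2: "jest"}  # index -> token
        self.size = 3                                      # vocabulary size passed to Encoder/Decoder

The script itself would be run with --in_f and --out_f pointing at the input and output text files, and --vocab, --encoder, --decoder pointing at the pickled vocabularies and the saved model state dicts.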