Add predict
This commit is contained in:
parent
51d468f7aa
commit
e45b0118bd
108
src/predict.py
Normal file
108
src/predict.py
Normal file
@ -0,0 +1,108 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
import pickle
|
||||
import unicodedata
|
||||
import re
|
||||
|
||||
from Decoder import Decoder
|
||||
from Encoder import Encoder
|
||||
|
||||
# Run on GPU when available, otherwise CPU; all tensors below are created on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Special vocabulary token indices (must match the values used during training).
SOS=0  # start-of-sentence token
EOS=1  # end-of-sentence token
MAX_LENGTH=25  # maximum sentence length (in words) the decoder will emit
|
||||
|
||||
def clear_line(string):
|
||||
string = ''.join(
|
||||
c for c in unicodedata.normalize('NFD', string)
|
||||
if unicodedata.category(c) != 'Mn'
|
||||
)
|
||||
string = re.sub("[^a-z ]", "", string.lower())
|
||||
return string
|
||||
|
||||
def read_clear_data(in_file_path):
    """Read *in_file_path*, normalize every line, and return the cleaned sentences.

    Sentences with MAX_LENGTH or more words are replaced by a fixed marker
    string instead of being dropped, so output stays line-aligned with input.
    """
    print("Reading data...")
    with open(in_file_path) as in_file:
        cleaned = [clear_line(raw) for raw in in_file]
    # "To bedzie zle" (Polish: "this will be bad") marks over-length lines.
    return [
        s if len(s.split(' ')) < MAX_LENGTH else "To bedzie zle"
        for s in cleaned
    ]
|
||||
|
||||
def indexes_from_sentence(vocab, sentence):
    """Map each word of *sentence* to its index in *vocab*.

    Out-of-vocabulary words fall back to the index of 'jest' (a common
    Polish word used here as a stand-in for an <UNK> token).
    """
    indexes = []
    for word in sentence.split(' '):
        try:
            indexes.append(vocab.word2index[word])
        except KeyError:
            # Was a bare `except:`, which also hid unrelated errors
            # (typos, wrong vocab object); only a missing key is expected here.
            # NOTE(review): assumes 'jest' is always in the vocabulary — confirm.
            indexes.append(vocab.word2index['jest'])
    return indexes
|
||||
|
||||
def tensor_from_sentece(vocab, sentence):
    """Encode *sentence* as a column tensor (shape [n, 1]) of word indices, EOS-terminated.

    (Function name kept as-is, typo included, because other code in this
    file calls it by this name.)
    """
    token_ids = indexes_from_sentence(vocab, sentence) + [EOS]
    return torch.tensor(token_ids, dtype=torch.long, device=device).view(-1, 1)
|
||||
|
||||
# NOTE(review): this shadows the builtin `eval`; renaming (e.g. to `evaluate`)
# would be safer, but callers in this file use this name.
def eval(encoder, decoder, sentence, input_vocab, target_vocab, max_length=MAX_LENGTH):
    """Greedily decode *sentence* through the encoder/decoder pair.

    Returns a list of output words; ends with '<EOS>' if the decoder emitted
    the EOS token within max_length steps, otherwise contains max_length words.
    """
    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        input_tensor = tensor_from_sentece(input_vocab, sentence)
        input_length = input_tensor.size()[0]

        encoder_hidden = encoder.init_hidden(device)

        # Collected per-step encoder outputs; written but not read below —
        # presumably kept for an attention decoder variant.
        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

        # Feed the input one token at a time, threading the hidden state.
        for i in range(input_length):
            encoder_output, encoder_hidden = encoder(input_tensor[i], encoder_hidden)
            encoder_outputs[i] += encoder_output[0, 0]

        # Decoding starts from the SOS token...
        decoder_input = torch.tensor([[SOS]], device=device)

        # ...and from the encoder's final hidden state (the "context").
        decoder_hidden = encoder_hidden
        decoder_words = []

        # Greedy decoding: at each step take the single most likely token.
        for i in range(max_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            topv, topi = decoder_output.data.topk(1)
            if topi.item() == EOS:
                decoder_words.append('<EOS>')
                break
            else:
                decoder_words.append(target_vocab.index2word[topi.item()])
            # The predicted token becomes the next input (detached from the graph).
            decoder_input = topi.squeeze().detach()
        return decoder_words
|
||||
|
||||
|
||||
def main():
    """CLI entry point: load vocab and trained models, translate a file line by line.

    Flags: --in_f input text, --out_f output text, --vocab pickled
    (pairs, input_vocab, target_vocab), --encoder/--decoder state-dict paths.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--in_f')
    parser.add_argument('--out_f')
    parser.add_argument('--vocab')
    parser.add_argument('--encoder')
    parser.add_argument('--decoder')
    args = parser.parse_args()

    # NOTE(review): pickle can execute arbitrary code — only load trusted vocab files.
    with open(args.vocab, 'rb') as p:
        pairs, input_vocab, target_vocab = pickle.load(p)

    hidden_size = 256
    encoder = Encoder(input_vocab.size, hidden_size).to(device)
    decoder = Decoder(hidden_size, target_vocab.size).to(device)

    # map_location lets checkpoints saved on a GPU load on a CPU-only host
    # (without it, torch.load raises when CUDA is unavailable).
    encoder.load_state_dict(torch.load(args.encoder, map_location=device))
    decoder.load_state_dict(torch.load(args.decoder, map_location=device))

    data = read_clear_data(args.in_f)
    # 'w' instead of 'w+': the file is only written, never read back.
    with open(args.out_f, 'w') as f:
        for line in data:
            out = eval(encoder, decoder, line, input_vocab, target_vocab, max_length=MAX_LENGTH)
            f.write(" ".join(out) + "\n")


# Guard the entry point so importing this module doesn't run inference.
if __name__ == "__main__":
    main()
|
Loading…
Reference in New Issue
Block a user