Add preddict
This commit is contained in:
parent
51d468f7aa
commit
e45b0118bd
108
src/predict.py
Normal file
108
src/predict.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
import pickle
|
||||||
|
import unicodedata
|
||||||
|
import re
|
||||||
|
|
||||||
|
from Decoder import Decoder
|
||||||
|
from Encoder import Encoder
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
SOS=0
|
||||||
|
EOS=1
|
||||||
|
MAX_LENGTH=25
|
||||||
|
|
||||||
|
def clear_line(string):
|
||||||
|
string = ''.join(
|
||||||
|
c for c in unicodedata.normalize('NFD', string)
|
||||||
|
if unicodedata.category(c) != 'Mn'
|
||||||
|
)
|
||||||
|
string = re.sub("[^a-z ]", "", string.lower())
|
||||||
|
return string
|
||||||
|
|
||||||
|
def read_clear_data(in_file_path):
|
||||||
|
print("Reading data...")
|
||||||
|
data = []
|
||||||
|
with open(in_file_path) as in_file:
|
||||||
|
for line in in_file:
|
||||||
|
string = clear_line(line)
|
||||||
|
if len(string.split(' ')) < MAX_LENGTH:
|
||||||
|
data.append(string)
|
||||||
|
else:
|
||||||
|
data.append("To bedzie zle")
|
||||||
|
return data
|
||||||
|
|
||||||
|
def indexes_from_sentence(vocab, sentence):
|
||||||
|
a = []
|
||||||
|
for word in sentence.split(' '):
|
||||||
|
try:
|
||||||
|
index = vocab.word2index[word]
|
||||||
|
a.append(index)
|
||||||
|
except:
|
||||||
|
a.append(vocab.word2index['jest'])
|
||||||
|
return a
|
||||||
|
#return [vocab.word2index[word] for word in sentence.split(' ')]
|
||||||
|
|
||||||
|
def tensor_from_sentece(vocab, sentence):
|
||||||
|
indexes = indexes_from_sentence(vocab, sentence)
|
||||||
|
indexes.append(EOS)
|
||||||
|
return torch.tensor(indexes, dtype=torch.long, device=device).view(-1,1)
|
||||||
|
|
||||||
|
def eval(encoder, decoder, sentence, input_vocab, target_vocab, max_length=MAX_LENGTH):
|
||||||
|
with torch.no_grad():
|
||||||
|
input_tensor = tensor_from_sentece(input_vocab, sentence)
|
||||||
|
input_length = input_tensor.size()[0]
|
||||||
|
|
||||||
|
encoder_hidden = encoder.init_hidden(device)
|
||||||
|
|
||||||
|
encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
|
||||||
|
|
||||||
|
for i in range(input_length):
|
||||||
|
encoder_output, encoder_hidden = encoder(input_tensor[i], encoder_hidden)
|
||||||
|
encoder_outputs[i] += encoder_output[0, 0]
|
||||||
|
|
||||||
|
decoder_input = torch.tensor([[SOS]], device=device)
|
||||||
|
|
||||||
|
decoder_hidden = encoder_hidden
|
||||||
|
decoder_words = []
|
||||||
|
|
||||||
|
for i in range(max_length):
|
||||||
|
decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
|
||||||
|
topv, topi = decoder_output.data.topk(1)
|
||||||
|
if topi.item() == EOS:
|
||||||
|
decoder_words.append('<EOS>')
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
decoder_words.append(target_vocab.index2word[topi.item()])
|
||||||
|
decoder_input = topi.squeeze().detach()
|
||||||
|
return decoder_words
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--in_f')
|
||||||
|
parser.add_argument('--out_f')
|
||||||
|
parser.add_argument('--vocab')
|
||||||
|
parser.add_argument('--encoder')
|
||||||
|
parser.add_argument('--decoder')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
with open(args.vocab, 'rb') as p:
|
||||||
|
pairs, input_vocab, target_vocab = pickle.load(p)
|
||||||
|
|
||||||
|
hidden_size = 256
|
||||||
|
encoder = Encoder(input_vocab.size, hidden_size).to(device)
|
||||||
|
decoder = Decoder(hidden_size,target_vocab.size).to(device)
|
||||||
|
|
||||||
|
encoder.load_state_dict(torch.load(args.encoder))
|
||||||
|
decoder.load_state_dict(torch.load(args.decoder))
|
||||||
|
|
||||||
|
data = read_clear_data(args.in_f)
|
||||||
|
with open(args.out_f, 'w+') as f:
|
||||||
|
for line in data:
|
||||||
|
out = eval(encoder, decoder, line, input_vocab, target_vocab, max_length=MAX_LENGTH)
|
||||||
|
f.write(" ".join(out) + "\n")
|
||||||
|
|
||||||
|
main()
|
Loading…
Reference in New Issue
Block a user