neural_word_gap/inference.py

import lzma
import pickle

import regex as re
import torch
from torch import nn


class SimpleTrigramNeuralLanguageModel(nn.Module):
    """Predicts the gap word from the word directly before and after it."""

    def __init__(self, vocabulary_size, embedding_size):
        super(SimpleTrigramNeuralLanguageModel, self).__init__()
        # The layer name 'embedings' (sic) is kept so the keys match the
        # state dict saved in model1_5400.bin.
        self.embedings = nn.Embedding(vocabulary_size, embedding_size)
        self.linear = nn.Linear(embedding_size * 2, vocabulary_size)
        self.softmax = nn.Softmax(dim=1)
        # self.model = nn.Sequential(
        #     nn.Embedding(vocabulary_size, embedding_size),
        #     nn.Linear(embedding_size, vocabulary_size),
        #     nn.Softmax()
        # )

    def forward(self, x):
        # x is a pair of tensors: ids of the word before and after the gap.
        emb_1 = self.embedings(x[0])
        emb_2 = self.embedings(x[1])
        concated = self.linear(torch.cat((emb_1, emb_2), dim=1))
        y = self.softmax(concated)
        return y
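
# Note on the interface (assumption, not stated in the original file): a forward
# pass takes the two context word ids as a pair of 1-element batch tensors and
# returns a (1, vocabulary_size) tensor of probabilities for the gap word,
# e.g. model([torch.tensor([12]), torch.tensor([345])]).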

vocab_size = 20000
embed_size = 100
device = 'cpu'

model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size)
model.load_state_dict(torch.load('model1_5400.bin', map_location=device))
model.to(device)
model.eval()

with open('vocab.pickle', 'rb') as handle:
    vocab = pickle.load(handle)
vocab.set_default_index(vocab['<unk>'])

with lzma.open('dev-0/in.tsv.xz', 'rb') as file:
    for line in file:
        line = line.decode('utf-8').rstrip()
        # Keep only the left and right context columns of the TSV line.
        line_splitted = line.split('\t')[-2:]
        prev_word = line_splitted[0].split(' ')[-1]
        next_word = line_splitted[1].split(' ')[0]

        x = torch.tensor(vocab.forward([prev_word])).to(device)
        z = torch.tensor(vocab.forward([next_word])).to(device)

        with torch.no_grad():
            ypredicted = model([x, z])

        top = torch.topk(ypredicted[0], 5000)
        top_indices = top.indices.tolist()
        top_probs = top.values.tolist()
        top_words = vocab.lookup_tokens(top_indices)

        string_to_print = ''
        sum_probs = 0
        for w, p in zip(top_words, top_probs):
            if '<unk>' in w:
                continue
            # Keep only candidates that contain at least one letter.
            if re.search(r'\p{L}+', w):
                string_to_print += f"{w}:{p} "
                sum_probs += p

        if string_to_print == '':
            # Fallback distribution when no usable candidate was found.
            print("the:0.5 a:0.3 :0.2")
            continue

        # Assign the remaining probability mass to the unknown token.
        unknown_prob = 1 - sum_probs
        string_to_print += f":{unknown_prob}"
        print(string_to_print)
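
# Usage sketch (assumption, not part of the original script): the loop prints
# one line per input in the format "word:prob word:prob ... :remaining_prob",
# so dev-set predictions can be produced by redirecting stdout, e.g.
#   python inference.py > dev-0/out.tsv
# The output path and invocation above are illustrative, not taken from the repo.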