Delete 'inference.py'
This commit is contained in:
parent
d713b4a83f
commit
4b3fb1c333
107
inference.py
107
inference.py
@ -1,107 +0,0 @@
|
|||||||
from torch import nn
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
from torch.utils.data import IterableDataset
|
|
||||||
import itertools
|
|
||||||
import lzma
|
|
||||||
import regex as re
|
|
||||||
import pickle
|
|
||||||
import scripts
|
|
||||||
import os
|
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
|
||||||
|
|
||||||
class SimpleTrigramNeuralLanguageModel(nn.Module):
|
|
||||||
def __init__(self, vocabulary_size, embedding_size):
|
|
||||||
super(SimpleTrigramNeuralLanguageModel, self).__init__()
|
|
||||||
self.embedings = nn.Embedding(vocabulary_size, embedding_size)
|
|
||||||
self.linear = nn.Linear(embedding_size*2, vocabulary_size)
|
|
||||||
|
|
||||||
self.linear_first_layer = nn.Linear(embedding_size*2, embedding_size*2)
|
|
||||||
self.relu = nn.ReLU()
|
|
||||||
self.softmax = nn.Softmax()
|
|
||||||
|
|
||||||
# self.model = nn.Sequential(
|
|
||||||
# nn.Embedding(vocabulary_size, embedding_size),
|
|
||||||
# nn.Linear(embedding_size, vocabulary_size),
|
|
||||||
# nn.Softmax()
|
|
||||||
# )
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
emb_1 = self.embedings(x[0])
|
|
||||||
emb_2 = self.embedings(x[1])
|
|
||||||
|
|
||||||
first_layer = self.linear_first_layer(torch.cat((emb_1, emb_2), dim=1))
|
|
||||||
after_relu = self.relu(first_layer)
|
|
||||||
concated = self.linear(after_relu)
|
|
||||||
|
|
||||||
y = self.softmax(concated)
|
|
||||||
|
|
||||||
return y
|
|
||||||
|
|
||||||
vocab_size = scripts.vocab_size
|
|
||||||
embed_size = 100
|
|
||||||
device = 'cuda'
|
|
||||||
|
|
||||||
model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size).to(device)
|
|
||||||
|
|
||||||
model.load_state_dict(torch.load('batch_model_epoch_0.bin'))
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
with open("vocab.pickle", 'rb') as handle:
|
|
||||||
vocab = pickle.load(handle)
|
|
||||||
vocab.set_default_index(vocab['<unk>'])
|
|
||||||
|
|
||||||
|
|
||||||
step = 0
|
|
||||||
|
|
||||||
|
|
||||||
with lzma.open('dev-0/in.tsv.xz', 'rb') as file:
|
|
||||||
for line in file:
|
|
||||||
line = line.decode('utf-8')
|
|
||||||
line = line.rstrip()
|
|
||||||
# line = line.lower()
|
|
||||||
line = line.replace("\\\\n", ' ')
|
|
||||||
|
|
||||||
|
|
||||||
line_splitted = line.split('\t')[-2:]
|
|
||||||
|
|
||||||
prev = list(scripts.get_words_from_line(line_splitted[0]))[-1]
|
|
||||||
next = list(scripts.get_words_from_line(line_splitted[1]))[0]
|
|
||||||
|
|
||||||
# prev = line[0].split(' ')[-1]
|
|
||||||
# next = line[1].split(' ')[0]
|
|
||||||
|
|
||||||
|
|
||||||
x = torch.tensor(vocab.forward([prev]))
|
|
||||||
z = torch.tensor(vocab.forward([next]))
|
|
||||||
x = x.to(device)
|
|
||||||
z = z.to(device)
|
|
||||||
ypredicted = model([x, z])
|
|
||||||
|
|
||||||
try:
|
|
||||||
|
|
||||||
top = torch.topk(ypredicted[0], 128)
|
|
||||||
except:
|
|
||||||
print(ypredicted[0])
|
|
||||||
raise Exception('aa')
|
|
||||||
top_indices = top.indices.tolist()
|
|
||||||
top_probs = top.values.tolist()
|
|
||||||
top_words = vocab.lookup_tokens(top_indices)
|
|
||||||
|
|
||||||
string_to_print = ''
|
|
||||||
sum_probs = 0
|
|
||||||
|
|
||||||
for w, p in zip(top_words, top_probs):
|
|
||||||
if '<unk>' in w:
|
|
||||||
continue
|
|
||||||
if re.search(r'\p{L}+', w):
|
|
||||||
string_to_print += f"{w}:{p} "
|
|
||||||
sum_probs += p
|
|
||||||
if string_to_print == '':
|
|
||||||
print(f"the:0.2 a:0.3 :0.5")
|
|
||||||
continue
|
|
||||||
unknow_prob = 1 - sum_probs
|
|
||||||
string_to_print += f":{unknow_prob}"
|
|
||||||
|
|
||||||
print(string_to_print)
|
|
Loading…
Reference in New Issue
Block a user