from torch import nn
import torch
from torch.utils.data import IterableDataset
import itertools
import lzma
import regex as re
import pickle


class SimpleTrigramNeuralLanguageModel(nn.Module):
    def __init__(self, vocabulary_size, embedding_size):
        super(SimpleTrigramNeuralLanguageModel, self).__init__()
        # Attribute name kept as "embedings" so the keys match the saved checkpoint's state_dict.
        self.embedings = nn.Embedding(vocabulary_size, embedding_size)
        self.linear = nn.Linear(embedding_size * 2, vocabulary_size)
        self.softmax = nn.Softmax(dim=1)
        # self.model = nn.Sequential(
        #     nn.Embedding(vocabulary_size, embedding_size),
        #     nn.Linear(embedding_size, vocabulary_size),
        #     nn.Softmax()
        # )

    def forward(self, x):
        # x is a pair of token-id tensors: the word before the gap and the word after it.
        emb_1 = self.embedings(x[0])
        emb_2 = self.embedings(x[1])
        concated = self.linear(torch.cat((emb_1, emb_2), dim=1))
        y = self.softmax(concated)
        return y


vocab_size = 20000
embed_size = 100

model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size)
model.load_state_dict(torch.load('model1_5400.bin'))
model.eval()

with open("vocab.pickle", 'rb') as handle:
    vocab = pickle.load(handle)
vocab.set_default_index(vocab['<unk>'])

device = 'cpu'
# data = DataLoader(train_dataset, batch_size=5000)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.NLLLoss()

test_pred = ['ala', 'has', 'cat']
step = 0

with lzma.open('dev-0/in.tsv.xz', 'rb') as file:
    for line in file:
        line = line.decode('utf-8')
        line = line.rstrip()
        line_splitted = line.split('\t')[-2:]  # last two columns: left and right context
        prev = line_splitted[0].split(' ')[-1]       # last word of the left context
        next_word = line_splitted[1].split(' ')[0]   # first word of the right context

        x = torch.tensor(vocab.forward([prev]))
        z = torch.tensor(vocab.forward([next_word]))
        x = x.to(device)
        z = z.to(device)

        ypredicted = model([x, z])

        # Take the 5000 most probable words from the predicted distribution.
        top = torch.topk(ypredicted[0], 5000)
        top_indices = top.indices.tolist()
        top_probs = top.values.tolist()
        top_words = vocab.lookup_tokens(top_indices)

        string_to_print = ''
        sum_probs = 0
        for w, p in zip(top_words, top_probs):
            if '<unk>' in w:
                continue  # skip the unknown token
            if re.search(r'\p{L}+', w):
                string_to_print += f"{w}:{p} "
                sum_probs += p

        if string_to_print == '':
            print("the:0.5 a:0.3 :0.2")
            continue

        # Assign the remaining probability mass to the empty (unknown) token.
        unknown_prob = 1 - sum_probs
        string_to_print += f":{unknown_prob}"
        print(string_to_print)