from torch import nn import torch from torch.utils.data import IterableDataset import itertools import lzma import regex as re import pickle import scripts import os os.environ["CUDA_VISIBLE_DEVICES"] = "1" class SimpleTrigramNeuralLanguageModel(nn.Module): def __init__(self, vocabulary_size, embedding_size): super(SimpleTrigramNeuralLanguageModel, self).__init__() self.embedings = nn.Embedding(vocabulary_size, embedding_size) self.linear = nn.Linear(embedding_size*2, vocabulary_size) self.linear_first_layer = nn.Linear(embedding_size*2, embedding_size*2) self.relu = nn.ReLU() self.softmax = nn.Softmax() # self.model = nn.Sequential( # nn.Embedding(vocabulary_size, embedding_size), # nn.Linear(embedding_size, vocabulary_size), # nn.Softmax() # ) def forward(self, x): emb_1 = self.embedings(x[0]) emb_2 = self.embedings(x[1]) first_layer = self.linear_first_layer(torch.cat((emb_1, emb_2), dim=1)) after_relu = self.relu(first_layer) concated = self.linear(after_relu) y = self.softmax(concated) return y vocab_size = scripts.vocab_size embed_size = 100 device = 'cuda' model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size).to(device) model.load_state_dict(torch.load('batch_model_epoch_0.bin')) model.eval() with open("vocab.pickle", 'rb') as handle: vocab = pickle.load(handle) vocab.set_default_index(vocab['']) step = 0 with lzma.open('dev-0/in.tsv.xz', 'rb') as file: for line in file: line = line.decode('utf-8') line = line.rstrip() # line = line.lower() line = line.replace("\\\\n", ' ') line_splitted = line.split('\t')[-2:] prev = list(scripts.get_words_from_line(line_splitted[0]))[-1] next = list(scripts.get_words_from_line(line_splitted[1]))[0] # prev = line[0].split(' ')[-1] # next = line[1].split(' ')[0] x = torch.tensor(vocab.forward([prev])) z = torch.tensor(vocab.forward([next])) x = x.to(device) z = z.to(device) ypredicted = model([x, z]) try: top = torch.topk(ypredicted[0], 128) except: print(ypredicted[0]) raise Exception('aa') top_indices = top.indices.tolist() top_probs = top.values.tolist() top_words = vocab.lookup_tokens(top_indices) string_to_print = '' sum_probs = 0 for w, p in zip(top_words, top_probs): if '' in w: continue if re.search(r'\p{L}+', w): string_to_print += f"{w}:{p} " sum_probs += p if string_to_print == '': print(f"the:0.2 a:0.3 :0.5") continue unknow_prob = 1 - sum_probs string_to_print += f":{unknow_prob}" print(string_to_print)