"""Write next-word predictions for ``test-A`` using a trained bidirectional LSTM.

The script loads a pickled vocabulary (``vocab.pickle``) and model weights
(``lstm_step_10000.bin``), reads tab-separated rows from
``test-A/in.tsv.xz``, predicts the word following the second-to-last column
of every row and writes one line per row to ``test-A/out.tsv`` in the form
``word:prob word:prob ... :remaining_prob``.
"""
import copy
import itertools
import lzma
import os
import pdb
import pickle
import string
from collections import Counter

import numpy as np
import regex as re
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset

import utils

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Fall back to the CPU so the script still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

vocab_size = utils.vocab_size

# Token whose index is used for out-of-vocabulary words.
# NOTE(review): the literal is an empty string in this file -- confirm that
# this is really the unknown-token spelling stored in vocab.pickle.
UNKNOWN_TOKEN = ''
# Number of candidate words emitted per prediction.
TOP_K = 64
# Line emitted when no usable prediction is available for a row.
FALLBACK_PREDICTION = "the:0.2 a:0.3 :0.5"
# Built once; deletes every ASCII punctuation character via str.translate.
_STRIP_PUNCTUATION = str.maketrans('', '', string.punctuation)

# SECURITY: pickle.load executes arbitrary code from the file; only load a
# vocab.pickle that was produced by a trusted training run.
with open("vocab.pickle", 'rb') as handle:
    # presumably a torchtext Vocab (set_default_index / lookup_tokens are used)
    vocab = pickle.load(handle)

vocab.set_default_index(vocab[UNKNOWN_TOKEN])


class Model(nn.Module):
    """Embedding -> single-layer bidirectional LSTM -> linear layer over the vocabulary."""

    def __init__(self, vocab_size):
        """Build the layers for a vocabulary of ``vocab_size`` tokens."""
        super().__init__()
        self.lstm_size = 150
        self.embedding_dim = 200
        self.num_layers = 1

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # The LSTM is bidirectional, so its output feature size is 2 * lstm_size.
        self.fc = nn.Linear(self.lstm_size * 2, vocab_size)

    def forward(self, x, prev_state=None):
        """Return ``(logits, state)`` for the token-id tensor ``x``.

        ``prev_state`` is forwarded to ``nn.LSTM`` unchanged; ``None`` lets
        the LSTM start from its default zero state.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return a zero ``(h_0, c_0)`` pair on ``device``.

        Despite its name, ``sequence_length`` sizes the second dimension of
        the state tensors, which ``nn.LSTM`` interprets as the batch size
        (``predict`` passes ``x.size(0)``).  The name is kept for
        backward compatibility.
        """
        shape = (self.num_layers * 2, sequence_length, self.lstm_size)
        return (torch.zeros(shape, device=device),
                torch.zeros(shape, device=device))


model = Model(vocab_size=vocab_size).to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load('lstm_step_10000.bin', map_location=device))
model.eval()


def predict(model, text_splitted):
    """Return the ``TOP_K`` most probable next words and their probabilities.

    Args:
        model: a ``Model`` instance located on ``device``.
        text_splitted: sequence of context words; each is mapped through the
            module-level ``vocab``.

    Returns:
        ``(top_words, top_probs)`` -- two parallel lists ordered by
        decreasing probability.  Both lists are empty when ``text_splitted``
        is empty, because an empty sequence cannot be fed to the network.
    """
    model.eval()
    if not text_splitted:
        return [], []

    token_ids = torch.tensor(
        [[vocab[w] for w in text_splitted]], dtype=torch.long, device=device
    )
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        initial_state = model.init_state(token_ids.size(0))
        logits, _ = model(token_ids, initial_state)
        # Distribution predicted at the last position of the only sequence.
        probs = torch.nn.functional.softmax(logits[0][-1], dim=0)
        top = torch.topk(probs, TOP_K)

    top_words = vocab.lookup_tokens(top.indices.tolist())
    return top_words, top.values.tolist()


def _format_prediction(top_words, top_probs):
    """Render one output line from parallel word / probability lists.

    The unknown token is left out of the explicit list; the probability mass
    not covered by the listed words is appended as a trailing ``:prob`` entry.
    ``FALLBACK_PREDICTION`` is returned when no word remains.
    """
    parts = []
    known_mass = 0
    for word, prob in zip(top_words, top_probs):
        # NOTE(review): this test used to be `'' in w`, which is true for
        # every string and therefore discarded every prediction.  Skipping
        # only the unknown token is assumed to be the intent -- confirm.
        if word == UNKNOWN_TOKEN:
            continue
        parts.append(f"{word}:{prob}")
        known_mass += prob

    if not parts:
        return FALLBACK_PREDICTION

    parts.append(f":{1 - known_mass}")
    return " ".join(parts)


inference_result = []
# newline='\n' keeps line splitting identical to reading the file as bytes.
with lzma.open('test-A/in.tsv.xz', 'rt', encoding='utf-8', newline='\n') as file:
    for line in file:
        # Strip only the line terminator: a bare rstrip() would also remove
        # trailing tabs and shift the columns when the last field is empty.
        line = line.rstrip('\r\n').translate(_STRIP_PUNCTUATION)
        fields = line.split('\t')
        if len(fields) < 2:
            # Malformed row: keep the output aligned with the input.
            inference_result.append(FALLBACK_PREDICTION)
            continue

        left_context_words = list(utils.get_words_from_line(fields[-2]))
        top_words, top_probs = predict(model, left_context_words)
        inference_result.append(_format_prediction(top_words, top_probs))

with open('test-A/out.tsv', 'w', encoding='utf-8') as f:
    f.writelines(f"{row}\n" for row in inference_result)

print('All done')