import copy
import itertools
import lzma
import os
import pdb
import pickle
import string
from collections import Counter

import numpy as np
import regex as re
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, IterableDataset

import utils
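
# Pin the process to the first GPU; CUDA_VISIBLE_DEVICES takes effect as long
# as it is set before the first CUDA call.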
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

device = 'cuda'

vocab_size = utils.vocab_size
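
# Load the vocabulary pickled during training; out-of-vocabulary words map to <unk>.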
with open("vocab.pickle", 'rb') as handle:
    vocab = pickle.load(handle)
vocab.set_default_index(vocab['<unk>'])
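

# Language model over the training vocabulary: token embeddings feed a
# bidirectional LSTM whose per-step output is projected onto vocabulary logits.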
class Model(nn.Module):
    def __init__(self, vocab_size):
        super(Model, self).__init__()
        self.lstm_size = 150
        self.embedding_dim = 200
        self.num_layers = 1

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            batch_first=True,
            bidirectional=True,
            # dropout=0.2,  # inter-layer dropout is a no-op with num_layers=1
        )
        # The bidirectional LSTM emits 2 * lstm_size features per time step.
        self.fc = nn.Linear(self.lstm_size * 2, vocab_size)

    def forward(self, x, prev_state=None):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, batch_size):
        # Even with batch_first=True, the state shape is
        # (num_layers * num_directions, batch_size, hidden_size).
        return (torch.zeros(self.num_layers * 2, batch_size, self.lstm_size).to(device),
                torch.zeros(self.num_layers * 2, batch_size, self.lstm_size).to(device))


model = Model(vocab_size=vocab_size).to(device)

model.load_state_dict(torch.load('lstm_step_10000.bin', map_location=device))
model.eval()
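

# Given a tokenized left context, return the 64 most probable next words and
# their probabilities under the model.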
def predict(model, text_splitted):
    model.eval()
    words = text_splitted
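
    # Encode the context as a batch containing a single index sequence.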
    x = torch.tensor([[vocab[w] for w in words]]).to(device)

    # x.size()[0] is the batch dimension (a single sequence here).
    state_h, state_c = model.init_state(x.size()[0])

    y_pred, (state_h, state_c) = model(x, (state_h, state_c))

    # Softmax over the logits at the final context position gives the
    # next-word distribution; keep the 64 most probable candidates.
    last_word_logits = y_pred[0][-1]
    p = torch.nn.functional.softmax(last_word_logits, dim=0)

    top = torch.topk(p, 64)
    top_indices = top.indices.tolist()
    top_probs = top.values.tolist()
    top_words = vocab.lookup_tokens(top_indices)
    return top_words, top_probs
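

# Each line of the compressed test input is a TSV record; the second-to-last
# column holds the left context of the word to be predicted.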
inference_result = []
with lzma.open('test-A/in.tsv.xz', 'r') as file:
    for line in file:
        line = line.decode("utf-8")
        line = line.rstrip()
        line = line.translate(str.maketrans('', '', string.punctuation))
        line_splitted_by_tab = line.split('\t')
        left_context = line_splitted_by_tab[-2]

        left_context_splitted = list(utils.get_words_from_line(left_context))

        top_words, top_probs = predict(model, left_context_splitted)

        string_to_print = ''
        sum_probs = 0
        for w, p in zip(top_words, top_probs):
            # Skip the unknown token; its probability mass is folded into the
            # remainder appended after the loop.
            if '<unk>' in w:
                continue
            string_to_print += f"{w}:{p} "
            sum_probs += p

        # If every candidate was <unk>, emit a fixed fallback distribution.
        if string_to_print == '':
            inference_result.append("the:0.2 a:0.3 :0.5")
            continue

        # Assign the leftover probability mass to the empty (unknown) token.
        unknown_prob = 1 - sum_probs
        string_to_print += f":{unknown_prob}"

        inference_result.append(string_to_print)
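
# Write one candidate distribution per input line.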
with open('test-A/out.tsv', 'w') as f:
    for line in inference_result:
        f.write(line + '\n')
print('All done')