# challenging-america-word-ga.../inference.py
import lzma
import os
import pickle
import string

import torch
from torch import nn

import utils

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = 'cuda'
vocab_size = utils.vocab_size
with open("vocab.pickle", 'rb') as handle:
vocab = pickle.load( handle)
vocab.set_default_index(vocab['<unk>'])
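
# Bidirectional LSTM language model used at inference time:
# token ids -> embedding (200d) -> 1-layer BiLSTM (150 units per direction)
# -> linear layer projecting the concatenated directions to vocab-sized logits.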
class Model(nn.Module):
    def __init__(self, vocab_size):
        super(Model, self).__init__()
        self.lstm_size = 150
        self.embedding_dim = 200
        self.num_layers = 1
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            batch_first=True,
            bidirectional=True,
            # dropout=0.2,
        )
        # The BiLSTM concatenates both directions, hence lstm_size * 2.
        self.fc = nn.Linear(self.lstm_size * 2, vocab_size)

    def forward(self, x, prev_state=None):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, batch_size):
        # Hidden/cell state shape: (num_layers * num_directions, batch, hidden);
        # with batch_first=True the second dimension is the batch size.
        return (torch.zeros(self.num_layers * 2, batch_size, self.lstm_size).to(device),
                torch.zeros(self.num_layers * 2, batch_size, self.lstm_size).to(device))
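
# Restore the weights saved at training step 10000 and switch to evaluation mode.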
model = Model(vocab_size=vocab_size).to(device)
model.load_state_dict(torch.load('lstm_step_10000.bin', map_location=device))
model.eval()
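
# Given the tokenized left context, return the 64 most probable next tokens
# and their softmax probabilities, taken from the logits at the last position.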
def predict(model, text_splitted):
    model.eval()
    words = text_splitted
    x = torch.tensor([[vocab[w] for w in words]]).to(device)
    # The hidden state is sized by the batch dimension (here a single sequence).
    state_h, state_c = model.init_state(x.size(0))
    with torch.no_grad():
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))
    last_word_logits = y_pred[0][-1]
    p = torch.nn.functional.softmax(last_word_logits, dim=0)
    top = torch.topk(p, 64)
    top_indices = top.indices.tolist()
    top_probs = top.values.tolist()
    top_words = vocab.lookup_tokens(top_indices)
    return top_words, top_probs
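
# Run inference over the test set. Each input row is tab-separated and the
# second-to-last field holds the left context of the gap. For each row we
# emit "word:prob word:prob ... :remainder", where the trailing ":p" entry
# carries the probability mass not assigned to any listed word.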
inference_result = []
with lzma.open('test-A/in.tsv.xz', 'r') as file:
    for line in file:
        line = line.decode("utf-8")
        line = line.rstrip()
        line = line.translate(str.maketrans('', '', string.punctuation))
        line_splitted_by_tab = line.split('\t')
        left_context = line_splitted_by_tab[-2]
        left_context_splitted = list(utils.get_words_from_line(left_context))
        top_words, top_probs = predict(model, left_context_splitted)
        string_to_print = ''
        sum_probs = 0
        for w, p in zip(top_words, top_probs):
            # Skip the unknown token; its mass is folded into the remainder.
            if w == '<unk>':
                continue
            string_to_print += f"{w}:{p} "
            sum_probs += p
        if string_to_print == '':
            # Fall back to a fixed distribution if every candidate was <unk>.
            inference_result.append("the:0.2 a:0.3 :0.5")
            continue
        unknown_prob = 1 - sum_probs
        string_to_print += f":{unknown_prob}"
        inference_result.append(string_to_print)
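
# Write one prediction line per input row.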
with open('test-A/out.tsv', 'w') as f:
    for line in inference_result:
        f.write(line + '\n')
print('All done')