challenging-america-word-ga.../inference.py

import lzma
import os
import pickle
import string

import torch
from torch import nn

import utils

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = 'cuda'

vocab_size = utils.vocab_size

# Load the vocabulary built at training time; unseen tokens map to <unk>.
with open("vocab.pickle", 'rb') as handle:
    vocab = pickle.load(handle)
vocab.set_default_index(vocab['<unk>'])

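# A minimal sketch (an assumption, not taken from this repo: set_default_index
# and lookup_tokens point at a torchtext Vocab) of how a compatible
# vocab.pickle could have been produced at training time:
#
#     from torchtext.vocab import build_vocab_from_iterator
#     vocab = build_vocab_from_iterator(tokenized_lines,  # hypothetical iterator of token lists
#                                       max_tokens=utils.vocab_size,
#                                       specials=['<unk>'])
#     with open('vocab.pickle', 'wb') as handle:
#         pickle.dump(vocab, handle)
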
class Model(nn.Module):
    def __init__(self, vocab_size):
        super(Model, self).__init__()
        self.lstm_size = 150
        self.embedding_dim = 200
        self.num_layers = 1

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            batch_first=True,
            bidirectional=True,
            # dropout=0.2,
        )
        # The LSTM is bidirectional, so the classifier sees 2 * hidden_size features.
        self.fc = nn.Linear(self.lstm_size * 2, vocab_size)

    def forward(self, x, prev_state=None):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, batch_size):
        # Hidden/cell state shape: (num_layers * num_directions, batch, hidden_size).
        return (torch.zeros(self.num_layers * 2, batch_size, self.lstm_size).to(device),
                torch.zeros(self.num_layers * 2, batch_size, self.lstm_size).to(device))


model = Model(vocab_size=vocab_size).to(device)
model.load_state_dict(torch.load('lstm_step_10000.bin'))
model.eval()

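# 'lstm_step_10000.bin' is loaded as a plain state_dict, so presumably it was
# saved during training with something like:
#
#     torch.save(model.state_dict(), 'lstm_step_10000.bin')
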
def predict(model, text_splitted):
    model.eval()
    words = text_splitted

    # Encode the left context as a single-example batch of vocabulary indices.
    x = torch.tensor([[vocab[w] for w in words]]).to(device)

    # x.size()[0] is the batch size (1), matching the hidden-state layout above.
    state_h, state_c = model.init_state(x.size()[0])

    with torch.no_grad():
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))

    # Distribution over the vocabulary at the position after the last context word.
    last_word_logits = y_pred[0][-1]
    p = torch.nn.functional.softmax(last_word_logits, dim=0)

    top = torch.topk(p, 64)
    top_indices = top.indices.tolist()
    top_probs = top.values.tolist()
    top_words = vocab.lookup_tokens(top_indices)
    return top_words, top_probs
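
# Usage sketch (hypothetical context): calling
#     top_words, top_probs = predict(model, ['left', 'context', 'words'])
# yields the 64 most likely next tokens and their softmax probabilities,
# most probable first.
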
inference_result = []
with lzma.open('test-A/in.tsv.xz', 'r') as file:
    for line in file:
        line = line.decode("utf-8")
        line = line.rstrip()
        line = line.translate(str.maketrans('', '', string.punctuation))
        line_splitted_by_tab = line.split('\t')
        # The left context is the second-to-last tab-separated field.
        left_context = line_splitted_by_tab[-2]

        left_context_splitted = list(utils.get_words_from_line(left_context))

        top_words, top_probs = predict(model, left_context_splitted)

        string_to_print = ''
        sum_probs = 0
        for w, p in zip(top_words, top_probs):
            # Skip <unk>; its mass is folded into the final catch-all entry.
            if '<unk>' in w:
                continue
            string_to_print += f"{w}:{p} "
            sum_probs += p

        # Fall back to a fixed distribution when every candidate was <unk>.
        if string_to_print == '':
            inference_result.append("the:0.2 a:0.3 :0.5")
            continue

        unknown_prob = 1 - sum_probs
        string_to_print += f":{unknown_prob}"
        inference_result.append(string_to_print)
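
# Each output line uses the challenge's expected format: space-separated
# word:probability pairs, closed by a bare `:probability` entry carrying the
# remaining (out-of-list / unknown) probability mass.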
with open('test-A/out.tsv', 'w') as f:
    for line in inference_result:
        f.write(line + '\n')

print('All done')
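
# Assumed invocation, from the challenge root so that vocab.pickle,
# lstm_step_10000.bin and test-A/in.tsv.xz resolve relative to the working
# directory:
#
#     python inference.py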