challenging-america-word-ga.../bigram-neural/model-predict.py

132 lines
3.8 KiB
Python
Raw Permalink Normal View History

2023-04-26 15:00:18 +02:00
import pickle
from torch.utils.data import IterableDataset
import itertools
from torch import nn
import torch
import lzma
from torch.utils.data import DataLoader
import pandas as pd
import tqdm
import regex as re
from nltk import word_tokenize
import csv
import nltk
# Tokenizer data needed by nltk.word_tokenize.
# NOTE(review): newer NLTK releases look up 'punkt_tab' instead of 'punkt';
# confirm against the installed NLTK version.
nltk.download("punkt")

# Size of the vocabulary the model was built with (input and output ids).
vocabulary_size = 20000
# Fallback prediction, in the "word:prob ... :rest" output format, used when
# the left context is too short to query the model (see predict_file).
most_common_en_word = "the:0.4 be:0.2 to:0.1 of:0.05 and:0.025 a:0.0125 :0.2125"

# Vocabulary saved alongside the model. This unpickles the file, so only use
# a trusted vocabulary.pickle.
with open('vocabulary.pickle', 'rb') as handle:
    vocab = pickle.load(handle)
def look_ahead_iterator(gen):
    """Yield ``(previous, current)`` pairs of consecutive items from *gen*.

    A pair is emitted only when its first element is not ``None``. This
    skips the very first item, and also any item that directly follows a
    ``None``.
    """
    previous = None
    for current in gen:
        pair = (previous, current)
        previous = current
        if pair[0] is not None:
            yield pair
def get_words_from_line(line):
    """Yield ``'<s>'``, the space-separated tokens of *line*, then ``'</s>'``.

    Trailing whitespace is stripped first. Splitting is on single spaces,
    so runs of spaces produce empty-string tokens.
    """
    tokens = line.rstrip().split(' ')
    yield '<s>'
    yield from tokens
    yield '</s>'
def get_word_lines_from_file(file_name):
    """Lazily yield one token generator per line of an xz-compressed file.

    Each raw line is decoded as UTF-8 and handed to ``get_words_from_line``.
    The file stays open until the returned generator is exhausted or closed.
    """
    with lzma.open(file_name, 'r') as compressed:
        for raw_line in compressed:
            text = raw_line.decode('utf-8')
            yield get_words_from_line(text)
class Bigrams(IterableDataset):
    """Iterable dataset of ``(previous_id, current_id)`` bigram pairs.

    Token ids come from the module-level ``vocab``. Construction sets that
    vocab's default index to the id of ``'<unk>'``, presumably so that
    out-of-vocabulary lookups resolve to ``<unk>``. Note that this mutates
    the shared vocab object.

    Args:
        text_file: Path to an xz-compressed text file, read line by line.
        vocabulary_size: Stored on the instance; not used by iteration.
    """

    def __init__(self, text_file, vocabulary_size):
        self.vocab = vocab
        unk_id = self.vocab['<unk>']
        self.vocab.set_default_index(unk_id)
        self.text_file = text_file
        self.vocabulary_size = vocabulary_size

    def __iter__(self):
        # Flatten the per-line token generators into one token stream, map
        # tokens to ids, then pair each id with its predecessor.
        token_stream = itertools.chain.from_iterable(
            get_word_lines_from_file(self.text_file))
        id_stream = (self.vocab[token] for token in token_stream)
        return look_ahead_iterator(id_stream)
# Constructing the dataset sets '<unk>' as the default index of the shared
# ``vocab`` (see Bigrams.__init__). predict_probs presumably relies on that
# for out-of-vocabulary words. The dataset itself is never iterated here.
train_dataset = Bigrams('train/in.tsv.xz', vocabulary_size)
# Embedding width; presumably must match the value used when model1.bin was
# trained (confirm against the training script).
embed_size = 100
class SimpleBigramNeuralLanguageModel(nn.Module):
    """Bigram LM: embed the previous word id, project to vocabulary, softmax.

    Args:
        vocabulary_size: Number of token ids (embedding rows and output units).
        embedding_size: Width of the word embedding.
    """

    def __init__(self, vocabulary_size, embedding_size):
        super().__init__()
        # Submodule order (0: Embedding, 1: Linear, 2: Softmax) defines the
        # state_dict keys, so it must match saved checkpoints.
        self.model = nn.Sequential(
            nn.Embedding(vocabulary_size, embedding_size),
            nn.Linear(embedding_size, vocabulary_size),
            # Normalise over the vocabulary (last) axis explicitly. The
            # implicit-dim form is deprecated and would pick dim 0 for
            # 3-D (batch, seq, vocab) input.
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        """Map token ids ``x`` to next-word probabilities.

        The result has shape ``x.shape + (vocabulary_size,)``.
        """
        return self.model(x)
# Use the GPU when present, otherwise fall back to CPU instead of crashing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SimpleBigramNeuralLanguageModel(vocabulary_size, embed_size).to(device)
# map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles data; only load trusted checkpoints.
model.load_state_dict(torch.load('model1.bin', map_location=device))
# Switch the model to inference mode.
model.eval()
def predict_probs(word1, n_best=5):
    """Return a formatted next-word prediction for the word after *word1*.

    Runs the module-level ``model`` on the id of ``word1``, looked up in the
    module-level ``vocab``, and emits the ``n_best`` most probable words. The
    output format is ``"w1:p1 w2:p2 ... :rest"``, with probabilities rounded
    to 3 decimals. ``'<unk>'`` is dropped from the listed words. Its
    probability is also left out of the trailing ``rest`` mass.

    Args:
        word1: Last token of the left context.
        n_best: Number of candidate words to list (default 5).

    Returns:
        The prediction line (without a trailing newline).
    """
    # Inference only: no_grad avoids building an autograd graph per call.
    with torch.no_grad():
        ixs = torch.tensor(vocab.forward([word1])).to(device)
        out = model(ixs)
    # Fetch at least one more candidate than needed, so that dropping
    # '<unk>' still leaves n_best words. Keep the original 10 as a floor.
    k = min(max(10, n_best + 1), out.shape[-1])
    top = torch.topk(out[0], k)
    top_words = vocab.lookup_tokens(top.indices.tolist())
    top_probs = top.values.tolist()

    unk_prob = 0
    candidates = []
    for word, prob in zip(top_words, top_probs):
        if word == '<unk>':
            unk_prob = prob
        else:
            candidates.append((word, prob))
    # topk already returns candidates in descending probability order.
    top_n = candidates[:n_best]

    listed = ''.join(f'{word}:{round(prob, 3)} ' for word, prob in top_n)
    rest = round(1 - sum(prob for _, prob in top_n) - unk_prob, 3)
    return f'{listed}:{rest}'
def prepare_text(text):
    """Normalise raw context text before tokenisation.

    Steps:
    - Lower-case the text.
    - Join words split by a hyphen before a literal backslash-n sequence.
    - Turn the remaining literal backslash-n sequences into spaces.
    - Delete every Unicode punctuation character (``\\p{P}``).
    """
    lowered = text.lower()
    dehyphenated = lowered.replace("-\\n", "")
    single_line = dehyphenated.replace("\\n", " ")
    return re.sub(r"\p{P}", "", single_line)
def predict_file(file):
    """Write one prediction per line of ``{file}/in.tsv.xz`` to ``{file}/out.tsv``.

    The input is streamed line by line as tab-separated UTF-8. Every input
    line therefore yields exactly one output line, keeping out.tsv aligned
    with the expected file. Column 6 is used as the left context; a line
    with fewer columns is treated as an empty context. Contexts with fewer
    than two tokens fall back to ``most_common_en_word``. Otherwise the last
    token is passed to ``predict_probs``.
    """
    # newline='\n' splits only on LF, so stray '\r' inside a field cannot
    # create extra lines.
    with lzma.open(f'{file}/in.tsv.xz', 'rt', encoding='utf-8', newline='\n') as file_in:
        with open(f'{file}/out.tsv', 'w', encoding='utf-8') as file_out:
            for line in tqdm.tqdm(file_in):
                fields = line.rstrip('\n').split('\t')
                left_context = fields[6] if len(fields) > 6 else ''
                before = word_tokenize(prepare_text(left_context))
                if len(before) < 2:
                    prediction = most_common_en_word
                else:
                    prediction = predict_probs(before[-1])
                file_out.write(prediction + '\n')
# Generate out.tsv for the dev and test splits, in that order.
for _split in ('dev-0', 'test-A'):
    predict_file(_split)