178 lines
9.5 KiB
Python
178 lines
9.5 KiB
Python
from tqdm import tqdm
|
|
from collections import Counter
|
|
import mmap
|
|
import pickle
|
|
from math import prod
|
|
from copy import deepcopy
|
|
import random
|
|
|
|
|
|
class TetragramModel:
    """Count-based word n-gram language model covering orders one to four.

    Training reads a tab-separated file and uses the last two columns of
    every line, which ``predict_gaps`` treats as the text to the left and to
    the right of a gap. Words outside the vocabulary are replaced by the
    ``'<UNK>'`` token before counting or scoring.

    Attributes:
        vocab: List of known words, or ``None`` before ``train`` has run.
            When the corpus has more distinct words than ``vocab_size`` the
            list holds the most frequent ones, most frequent first.
        ngram_counts: Dict with the keys ``'unigrams'``, ``'bigrams'``,
            ``'trigrams'`` and ``'tetragrams'``, each mapping to a
            ``Counter``. Unigram keys are strings; higher orders use tuples
            of strings. ``None`` before ``train`` has run.
    """

    _UNK = '<UNK>'
    _GRAM_NAMES = ('unigrams', 'bigrams', 'trigrams', 'tetragrams')

    def __init__(self):
        self.vocab = None
        self.ngram_counts = None

    def get_num_lines(self, filename):
        """Return the number of lines in ``filename``.

        The file is opened read-only in binary mode, so no write permission
        is needed and an empty file yields ``0`` instead of an error. A final
        line without a trailing newline is counted.

        Args:
            filename: Path of the file to inspect.

        Returns:
            The line count as an ``int``.
        """
        with open(filename, 'rb') as fp:
            return sum(1 for _ in fp)

    @staticmethod
    def _context_fields(line):
        """Return the last two tab-separated fields of ``line`` as a list.

        Only the line terminator is removed before splitting, so an empty
        trailing field is kept as ``''`` instead of shifting the columns.
        The two-character sequence "backslash, n" (presumably an escaped
        line break in the data) is replaced by a space.
        """
        return line.rstrip('\r\n').replace(r'\n', ' ').split('\t')[-2:]

    @classmethod
    def _mask_unknown(cls, words, known):
        """Return ``words`` with every word not in ``known`` replaced by ``'<UNK>'``."""
        return [word if word in known else cls._UNK for word in words]

    def _build_vocab(self, filename, vocab_size):
        """Collect the vocabulary of ``filename`` as a list of words.

        Returns the ``vocab_size`` most frequent words when the corpus has
        more distinct words than that, otherwise every word in first-seen
        order.
        """
        word_counts = Counter()
        with open(filename, encoding='utf-8') as f:
            for line in tqdm(f, total=self.get_num_lines(filename), desc=f'Generating vocab'):
                for field in self._context_fields(line):
                    word_counts.update(field.split())
        if len(word_counts) > vocab_size:
            return [word for word, _ in word_counts.most_common(vocab_size)]
        return list(word_counts)

    def _count_ngrams(self, filename, vocab):
        """Count all 1- to 4-grams of ``filename`` after masking words not in ``vocab``.

        Each of the two context fields of a line is counted on its own, so no
        n-gram spans the gap between them.
        """
        known = frozenset(vocab)  # O(1) membership instead of scanning a list per word
        ngram_counts = {name: Counter() for name in self._GRAM_NAMES}
        with open(filename, encoding='utf-8') as f:
            for line in tqdm(f, total=self.get_num_lines(filename), desc=f'Generating n-gram counts'):
                for field in self._context_fields(line):
                    tokens = self._mask_unknown(field.split(), known)
                    ngram_counts['unigrams'].update(tokens)
                    for order, name in enumerate(self._GRAM_NAMES[1:], start=2):
                        # zip over shifted copies yields every window of `order` tokens.
                        shifted = (tokens[offset:] for offset in range(order))
                        ngram_counts[name].update(zip(*shifted))
        return ngram_counts

    def train(self, filename, vocab_size=5000, load_ngrams=None):
        """Build the vocabulary and the n-gram counts.

        Args:
            filename: Tab-separated training file; the last two columns of
                each line are used.
            vocab_size: Maximum number of distinct words to keep.
            load_ngrams: Optional path of a pickled ``ngram_counts`` dict.
                When given, counts are loaded from it instead of being
                computed; the vocabulary is still built from ``filename``.
        """
        self.vocab = self._build_vocab(filename, vocab_size)
        if load_ngrams:
            print('Loading n-gram model from file...')
            # SECURITY: unpickling can execute arbitrary code. Only pass files
            # that were produced by a trusted source.
            with open(load_ngrams, "rb") as f:
                self.ngram_counts = pickle.load(f)
            print('Model loaded.')
        else:
            self.ngram_counts = self._count_ngrams(filename, self.vocab)

    def get_ngram_prob(self, ngram, unigram_total=None):
        """Return the maximum-likelihood probability of ``ngram``.

        A string is scored as a unigram (count divided by the total unigram
        count). A sequence is scored as the count of the full n-gram divided
        by the count of its prefix: length 2 uses bigram/unigram, length 3
        trigram/bigram, and any other length tetragram/trigram.

        Args:
            ngram: A word, or a sequence of words.
            unigram_total: Optional precomputed sum of all unigram counts.
                Callers that score many unigrams should pass it to avoid
                re-summing the whole table on every call.

        Returns:
            The probability, or ``0`` when the denominator count is zero.
        """
        counts = self.ngram_counts
        if isinstance(ngram, str):
            if unigram_total is None:
                unigram_total = sum(counts['unigrams'].values())
            numerator = counts['unigrams'].get(ngram, 0)
            denominator = unigram_total
        else:
            ngram = tuple(ngram)  # counter keys are tuples; also accept lists
            order = len(ngram)
            if order == 2:
                numerator = counts['bigrams'].get(ngram, 0)
                denominator = counts['unigrams'].get(ngram[0], 0)
            elif order == 3:
                numerator = counts['trigrams'].get(ngram, 0)
                denominator = counts['bigrams'].get(ngram[:2], 0)
            else:
                numerator = counts['tetragrams'].get(ngram, 0)
                denominator = counts['trigrams'].get(ngram[:3], 0)
        if not denominator:
            return 0
        return numerator / denominator

    @staticmethod
    def _context_frames(left_context, right_context):
        """Return the four (left, right) tuples that surround a gap word.

        Frame ``i`` takes up to ``3 - i`` words from the end of the left
        context and up to ``i`` words from the start of the right context,
        so with enough context every frame plus the gap word is a 4-gram.
        Short contexts simply give shorter frames.
        """
        frames = []
        for shift in range(4):
            left_size = 3 - shift
            # A [-0:] slice would return the whole list, hence the guard.
            left_part = tuple(left_context[-left_size:]) if left_size else ()
            frames.append((left_part, tuple(right_context[:shift])))
        return frames

    def _interpolated_prob(self, window, lambdas, unigram_total):
        """Return the lambda-weighted sum of the 1- to 4-gram probabilities of ``window``.

        Each probability is that of the window's last word given the
        preceding one, two and three words (or as many as the window has).
        """
        probs = (
            self.get_ngram_prob(window[-1], unigram_total),
            self.get_ngram_prob(window[-2:], unigram_total),
            self.get_ngram_prob(window[-3:], unigram_total),
            self.get_ngram_prob(window, unigram_total),
        )
        return sum(prob * weight for prob, weight in zip(probs, lambdas))

    def _score_gap_candidates(self, left_context, right_context, candidates, lambdas, unigram_total):
        """Return a dict mapping every candidate word to its unnormalised gap score.

        The score is the product of the interpolated probabilities of the
        four windows that contain the candidate.
        """
        frames = self._context_frames(left_context, right_context)
        return {
            word: prod(
                self._interpolated_prob(left + (word,) + right, lambdas, unigram_total)
                for left, right in frames
            )
            for word in candidates
        }

    @staticmethod
    def _normalize(pairs, total):
        """Divide every score in ``pairs`` by ``total``.

        Falls back to a uniform distribution when ``total`` is zero, which
        happens when no candidate has any support in the counts.
        """
        pairs = list(pairs)
        if total != 0:
            return [(word, score / total) for word, score in pairs]
        uniform = 1 / len(pairs) if pairs else 0
        return [(word, uniform) for word, _ in pairs]

    @classmethod
    def _format_gap_predictions(cls, scores, top_k, k):
        """Render candidate scores as one tab-separated output line.

        Known words are written as ``word:prob`` in descending probability;
        the ``'<UNK>'`` entry is written last as ``:prob``. With ``top_k``
        only the ``k`` best candidates are kept and renormalised, and the
        ``'<UNK>'`` entry appears only if it is among them. Without
        ``top_k`` the ``'<UNK>'`` entry is omitted when its probability is 0.
        """
        def by_score(item):
            return item[1]

        if top_k:
            best = sorted(scores.items(), key=by_score, reverse=True)[:k]
            ranked = cls._normalize(best, sum(score for _, score in best))
        else:
            normalized = cls._normalize(scores.items(), sum(scores.values()))
            ranked = sorted(normalized, key=by_score, reverse=True)

        parts = [f'{word}:{prob}' for word, prob in ranked if word != cls._UNK]
        parts.extend(
            f':{prob}' for word, prob in ranked
            if word == cls._UNK and (top_k or prob > 0)
        )
        return '\t'.join(parts)

    def predict_gaps(self, filename, lambdas=(0.25, 0.25, 0.25, 0.25), top_k=False, k=10,
                     out_filename='out.tsv'):
        """Write a word distribution for the gap of every line in ``filename``.

        Args:
            filename: Tab-separated input; the last two columns of each line
                are the left and right context of the gap. A line with a
                single column is treated as left context only.
            lambdas: Interpolation weights for the unigram, bigram, trigram
                and tetragram probabilities, in that order.
            top_k: When true, output only the ``k`` best candidates.
            k: Number of candidates kept when ``top_k`` is true.
            out_filename: Path of the output file; one line is written per
                input line.
        """
        known = frozenset(self.vocab)
        # Candidate list and unigram total do not change between lines.
        candidates = list(dict.fromkeys([*self.vocab, self._UNK]))
        unigram_total = sum(self.ngram_counts['unigrams'].values())

        with open(filename, encoding='utf-8') as f, open(out_filename, 'w', encoding='utf-8') as out:
            for line in tqdm(f, total=self.get_num_lines(filename), desc=f'Generating gap predictions'):
                fields = self._context_fields(line)
                left_text = fields[0]
                right_text = fields[1] if len(fields) > 1 else ''
                left_context = self._mask_unknown(left_text.split(), known)
                right_context = self._mask_unknown(right_text.split(), known)
                scores = self._score_gap_candidates(
                    left_context, right_context, candidates, lambdas, unigram_total)
                out.write(self._format_gap_predictions(scores, top_k, k) + '\n')

    def _next_word(self, context, fallback_pool, temperature):
        """Pick the next word for ``context`` with tetragram -> trigram -> bigram back-off.

        Args:
            context: Tuple of up to three already masked preceding words.
            fallback_pool: Words to draw from uniformly when no n-gram
                continues the context.
            temperature: See ``generate``.

        Returns:
            The chosen word, or ``None`` when nothing matches and
            ``fallback_pool`` is empty.
        """
        for prefix_len, name in ((3, 'tetragrams'), (2, 'trigrams'), (1, 'bigrams')):
            prefix = tuple(context[-prefix_len:])
            if len(prefix) < prefix_len:
                continue  # context too short for this order
            followers = [
                (gram, count) for gram, count in self.ngram_counts[name].items()
                if gram[:prefix_len] == prefix and gram[-1] != self._UNK
            ]
            if not followers:
                continue
            top3 = sorted(followers, key=lambda item: item[1], reverse=True)[:3]
            greedy_word = top3[0][0][-1]
            explore = random.randint(1, 10) > temperature * 10
            if not explore:
                return greedy_word
            if len(followers) < 3:
                # Too few continuations to sample from: jump to a random word.
                return random.choice(fallback_pool) if fallback_pool else greedy_word
            return random.choice(top3)[0][-1]
        return random.choice(fallback_pool) if fallback_pool else None

    def generate(self, prompt, length, temperature=0.5):
        """Continue ``prompt`` by up to ``length`` words, print and return the text.

        NOTE(review): ``temperature`` works the opposite way from the usual
        convention. Each step takes the single most frequent continuation
        with probability of roughly ``temperature`` and samples otherwise,
        so ``1.0`` is fully greedy and ``0`` always samples. Kept as is
        because callers may rely on it.

        Args:
            prompt: Whitespace-separated seed text. Fewer than three words
                is allowed; the model then starts from a lower n-gram order.
            length: Maximum number of words to append.
            temperature: Greediness in ``[0, 1]``, see the note above.

        Returns:
            The prompt followed by the generated words, joined by spaces.
        """
        generation = prompt.split()
        known = frozenset(self.vocab)
        # '<UNK>' is excluded so it never shows up in the generated text.
        fallback_pool = [word for word in self.ngram_counts['unigrams'] if word != self._UNK]

        for _ in tqdm(range(length), desc=f'Generating text'):
            # Mask on every step so out-of-vocabulary prompt words still match '<UNK>' contexts.
            context = tuple(self._mask_unknown(generation[-3:], known))
            next_word = self._next_word(context, fallback_pool, temperature)
            if next_word is None:
                break  # nothing to draw from, e.g. an untrained or empty model
            generation.append(next_word)

        text = ' '.join(generation)
        print(text)
        return text
|
|
|
|
|
|
# model = TetragramModel()
|
|
# model.train('train.tsv', vocab_size=X)
|
|
# model.predict_gaps('dev-0/in.tsv', top_k=True, k=5, lambdas=(X, X, X, X))
|
|
# model.predict_gaps('test-A/in.tsv', top_k=True, k=5, lambdas=(X, X, X, X))
|
|
# model.generate('According to recent news', length=50)
|
|
# model.generate('Recent studies have shown that', length=50)
|
|
# model.generate('Today I was taking a stroll in the park when suddenly', length=50)
|
|
# model.generate('The most unbelievable story ever told goes like this', length=50)
|
|
# model.generate('The war between', length=50)
|