# challenging-america-word-gaps / lab5.py
# Tetragram (4-gram) language model: gap prediction and text generation.
from tqdm import tqdm
from collections import Counter
import mmap
import pickle
from math import prod
from copy import deepcopy
import random
class TetragramModel:
    """Interpolated tetragram (4-gram) language model with <UNK> handling.

    Trained on a TSV corpus where the last two tab-separated fields of each
    line are the left and right text contexts (real newlines are escaped as
    the two characters ``\\n`` in the data).  Supports gap (cloze) prediction
    via linear interpolation of 1..4-gram probabilities, and simple text
    generation with n-gram backoff.
    """

    def __init__(self):
        # Set of kept vocabulary words (set for O(1) membership tests).
        self.vocab = None
        # Mapping 'unigrams'/'bigrams'/'trigrams'/'tetragrams' -> Counter.
        self.ngram_counts = None
        # Lazily cached total unigram count (see _unigram_total).
        self._unigram_total_cache = None

    def get_num_lines(self, filename):
        """Count lines of *filename* via mmap (used only for tqdm totals)."""
        # Context managers close both the mmap and the file (the original
        # leaked the mmap and closed the file manually).
        with open(filename, 'r+') as fp:
            with mmap.mmap(fp.fileno(), 0) as buf:
                lines = 0
                while buf.readline():
                    lines += 1
        return lines

    def train(self, filename, vocab_size=5000, load_ngrams=None):
        """Build the vocabulary from *filename* and collect n-gram counts.

        filename    -- training TSV; the last two fields of each line are text
        vocab_size  -- keep at most this many most-frequent words
        load_ngrams -- optional pickle path with precomputed n-gram Counters
                       (NOTE: pickle.load is only safe on trusted files)
        """
        def get_vocab(filename, vocab_size):
            # Count word frequencies over the whole file.
            file_vocab = Counter()
            with open(filename, encoding='utf-8') as f:
                for line in tqdm(f, total=self.get_num_lines(filename), desc='Generating vocab'):
                    # Join the last two tab fields; r'\n' is the corpus'
                    # escaped-newline marker, not an actual newline.
                    words = ' '.join(line.strip().split('\t')[-2:]).replace(r'\n', ' ').split()
                    file_vocab.update(words)
            # Return a set: membership is tested per word in hot loops.
            if len(file_vocab) > vocab_size:
                return {word for word, _ in file_vocab.most_common(vocab_size)}
            return set(file_vocab)

        def get_gram_counts(filename):
            gram_names = ['unigrams', 'bigrams', 'trigrams', 'tetragrams']
            ngram_counts = {name: Counter() for name in gram_names}
            with open(filename, encoding='utf-8') as f:
                for line in tqdm(f, total=self.get_num_lines(filename), desc='Generating n-gram counts'):
                    fields = line.strip().replace(r'\n', ' ').split('\t')[-2:]
                    for line_part in fields:
                        # Map out-of-vocabulary words to the <UNK> token.
                        words = [w if w in self.vocab else '<UNK>' for w in line_part.split()]
                        ngram_counts['unigrams'].update(words)
                        # Sliding windows of size 2..4 over the token list.
                        for n, name in zip((2, 3, 4), gram_names[1:]):
                            ngram_counts[name].update(
                                tuple(words[i:i + n]) for i in range(len(words) - n + 1))
            return ngram_counts

        self.vocab = get_vocab(filename, vocab_size)
        self._unigram_total_cache = None  # invalidate cache on (re)train
        if load_ngrams:
            print('Loading n-gram model from file...')
            with open(load_ngrams, "rb") as f:
                self.ngram_counts = pickle.load(f)
            print('Model loaded.')
        else:
            self.ngram_counts = get_gram_counts(filename)

    def _unigram_total(self):
        """Total token count, cached — recomputing the sum is O(|V|) and was
        previously done on every single unigram probability lookup."""
        if self._unigram_total_cache is None:
            self._unigram_total_cache = sum(self.ngram_counts['unigrams'].values())
        return self._unigram_total_cache

    def get_ngram_prob(self, ngram):
        """MLE probability of *ngram*.

        A plain string is treated as a unigram; tuples of length 2/3/4 as
        bi/tri/tetragrams conditioned on their (n-1)-gram prefix.  Returns 0
        when the prefix was never observed (ZeroDivisionError path).
        """
        try:
            if isinstance(ngram, str):
                return self.ngram_counts['unigrams'][ngram] / self._unigram_total()
            ngram_len = len(ngram)
            if ngram_len == 2:
                return self.ngram_counts['bigrams'][ngram] / self.ngram_counts['unigrams'][ngram[0]]
            elif ngram_len == 3:
                return self.ngram_counts['trigrams'][ngram] / self.ngram_counts['bigrams'][ngram[:2]]
            else:
                return self.ngram_counts['tetragrams'][ngram] / self.ngram_counts['trigrams'][ngram[:3]]
        except ZeroDivisionError:
            return 0

    def predict_gaps(self, filename, lambdas=(0.25, 0.25, 0.25, 0.25), top_k=False, k=10):
        """Predict the gap word between the two contexts of each input line.

        Writes one line per input row to ``out.tsv`` as tab-separated
        ``word:prob`` pairs, with a trailing ``:prob`` entry carrying the
        <UNK> mass.  *lambdas* are the interpolation weights for
        (uni, bi, tri, tetra)-gram probabilities; when *top_k* is set, only
        the best *k* candidates are kept and renormalized.
        """
        with open(filename, encoding='utf-8') as f, open('out.tsv', 'w', encoding='utf-8') as out:
            for line in tqdm(f, total=self.get_num_lines(filename), desc='Generating gap predictions'):
                fields = line.strip().replace(r'\n', ' ').split('\t')[-2:]
                left_context = [w if w in self.vocab else '<UNK>' for w in fields[0].split()]
                right_context = [w if w in self.vocab else '<UNK>' for w in fields[1].split()]
                context_probs = {}
                # Score every vocabulary word (plus <UNK>) as the gap filler.
                for word in list(self.vocab) + ['<UNK>']:
                    tetragrams_probs = []
                    # Slide the candidate through the four positions of a
                    # tetragram window spanning the gap: i left/right split.
                    for i in range(4):
                        tetragram = tuple(left_context[-4:][1 + i:] + [word] + right_context[:4][:-4 + i])
                        unigram_prob = self.get_ngram_prob(tetragram[-1])
                        bigram_prob = self.get_ngram_prob(tetragram[-2:])
                        trigram_prob = self.get_ngram_prob(tetragram[-3:])
                        tetragram_prob = self.get_ngram_prob(tetragram)
                        interpolated_prob = (unigram_prob * lambdas[0]) + (bigram_prob * lambdas[1]) + (
                                trigram_prob * lambdas[2]) + (tetragram_prob * lambdas[3])
                        tetragrams_probs.append(interpolated_prob)
                    context_probs[word] = prod(tetragrams_probs)
                if top_k:
                    sorted_top = sorted(context_probs.items(), key=lambda x: x[1], reverse=True)[:k]
                    probs_sum = sum(y for _, y in sorted_top)
                    sorted_top = [(x, y / probs_sum) for x, y in sorted_top]
                    probs_string = ''
                    unk_string = ''
                    for word, p in sorted_top:
                        if word == '<UNK>':
                            # The output format represents <UNK> as an empty token.
                            unk_string += f':{p}'
                        else:
                            probs_string += f'{word}:{p}\t'
                    # <UNK> (if present) always goes last.
                    probs_string += unk_string
                    out.write(probs_string.strip() + '\n')
                else:
                    probs_sum = sum(context_probs.values())
                    unk_prob = context_probs.pop('<UNK>') / probs_sum
                    normalized = [(w, p / probs_sum) for w, p in context_probs.items()]
                    probs_string = '\t'.join(f'{w}:{p}' for w, p in
                                             sorted(normalized, key=lambda x: x[1], reverse=True))
                    if unk_prob > 0:
                        probs_string += f"\t:{unk_prob}"
                    out.write(probs_string + '\n')

    def _pick_next(self, candidates, temperature):
        """Choose the next word from a non-empty list of (ngram, count) pairs.

        With probability (1 - temperature) explores among the top-3 most
        frequent continuations (falling back to a uniformly random vocabulary
        word when fewer than 3 candidates exist); otherwise takes the single
        most frequent continuation.
        """
        top3 = sorted(candidates, key=lambda x: x[1], reverse=True)[:3]
        if random.randint(1, 10) > temperature * 10:
            if len(candidates) < 3:
                return random.choice(list(self.ngram_counts['unigrams'].keys()))
            return random.choice(top3)[0][-1]
        return top3[0][0][-1]

    def generate(self, prompt, length, temperature=0.5):
        """Generate and print *length* words continuing *prompt*.

        Assumes the prompt contains at least three words.  Backs off from a
        3-word to a 2-word to a 1-word history; when no counted n-gram
        continues the context, picks a uniformly random vocabulary word.
        """
        generation = prompt.split()
        # Map OOV prompt words to <UNK> so lookups match counted n-grams.
        context = [w if w in self.vocab else '<UNK>' for w in generation[-3:]]
        for _ in tqdm(range(length), desc='Generating text'):
            next_word = None
            # Backoff ladder: tetragrams on 3-word history, then shorter.
            for hist_len, gram in ((3, 'tetragrams'), (2, 'trigrams'), (1, 'bigrams')):
                history = context[-hist_len:]
                candidates = [(key, cnt) for key, cnt in self.ngram_counts[gram].items()
                              if list(key[:hist_len]) == history and key[-1] != '<UNK>']
                if candidates:
                    next_word = self._pick_next(candidates, temperature)
                    break
            if next_word is None:
                next_word = random.choice(list(self.ngram_counts['unigrams'].keys()))
            generation.append(next_word)
            # Re-map OOV words each step (the original kept raw prompt words
            # here, so backoff lookups could never match <UNK> n-grams).
            context = [w if w in self.vocab else '<UNK>' for w in generation[-3:]]
        print(' '.join(generation))
# model = TetragramModel()
# model.train('train.tsv', vocab_size=X)
# model.predict_gaps('dev-0/in.tsv', top_k=True, k=5, lambdas=(X, X, X, X))
# model.predict_gaps('test-A/in.tsv', top_k=True, k=5, lambdas=(X, X, X, X))
# model.generate('According to recent news', length=50)
# model.generate('Recent studies have shown that', length=50)
# model.generate('Today I was taking a stroll in the park when suddenly', length=50)
# model.generate('The most unbelievable story ever told goes like this', length=50)
# model.generate('The war between', length=50)