"""Interpolated tetragram (4-gram) language model for gap filling and generation."""

from tqdm import tqdm
from collections import Counter
import mmap
import pickle
from math import prod
import random


class TetragramModel:
    """N-gram language model (up to tetragrams) with linear interpolation.

    Out-of-vocabulary words are mapped to the empty-string sentinel ''.
    Input files are TSV; the last two tab-separated fields of each line
    hold the text (left and right context around a gap).
    """

    def __init__(self):
        self.vocab = None           # set of in-vocabulary words
        self.ngram_counts = None    # {'unigrams'|'bigrams'|'trigrams'|'tetragrams': Counter}
        self._unigram_total = None  # cached sum of unigram counts (P(w) denominator)

    def get_num_lines(self, filename):
        """Return the number of lines in *filename* (mmap-based fast count)."""
        # Open read-only: the original 'r+' needlessly required write access,
        # and the mmap was never closed. An empty file cannot be mmapped.
        with open(filename, 'rb') as fp:
            try:
                with mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ) as buf:
                    lines = 0
                    while buf.readline():
                        lines += 1
                    return lines
            except ValueError:
                return 0

    def train(self, filename, vocab_size=5000, load_ngrams=None):
        """Build the vocabulary from *filename* and collect n-gram counts.

        filename    : TSV training file.
        vocab_size  : keep at most this many most-frequent words; others map to ''.
        load_ngrams : optional path to a pickled n-gram count dict to load
                      instead of recounting.
        """

        def get_vocab(filename, vocab_size):
            # Count word frequencies over the whole file, keep the top vocab_size.
            # Returned as a set: membership tests dominate training/prediction,
            # and a list would make each `word in vocab` O(V).
            file_vocab = Counter()
            with open(filename, encoding='utf-8') as f:
                for line in tqdm(f, total=self.get_num_lines(filename),
                                 desc='Generating vocab'):
                    # r'\n' replaces LITERAL backslash-n sequences escaped in the TSV.
                    words = ' '.join(line.strip().split('\t')[-2:]).replace(r'\n', ' ').split()
                    file_vocab.update(words)
            if len(file_vocab) > vocab_size:
                return {word for word, _ in file_vocab.most_common(vocab_size)}
            return set(file_vocab)

        def get_gram_counts(filename):
            gram_names = ['unigrams', 'bigrams', 'trigrams', 'tetragrams']
            ngram_counts = {name: Counter() for name in gram_names}
            with open(filename, encoding='utf-8') as f:
                for line in tqdm(f, total=self.get_num_lines(filename),
                                 desc='Generating n-gram counts'):
                    parts = line.strip().replace(r'\n', ' ').split('\t')[-2:]
                    for part in parts:
                        # OOV words collapse to the '' sentinel before counting.
                        tokens = [w if w in self.vocab else '' for w in part.split()]
                        # Unigrams are keyed by plain strings; higher orders by tuples.
                        ngram_counts['unigrams'].update(tokens)
                        for order, name in ((2, 'bigrams'), (3, 'trigrams'), (4, 'tetragrams')):
                            ngram_counts[name].update(
                                tuple(tokens[i:i + order])
                                for i in range(len(tokens) - order + 1))
            return ngram_counts

        self.vocab = get_vocab(filename, vocab_size)
        if load_ngrams:
            print('Loading n-gram model from file...')
            # SECURITY: pickle.load executes arbitrary code; only load trusted files.
            with open(load_ngrams, "rb") as f:
                self.ngram_counts = pickle.load(f)
            print('Model loaded.')
        else:
            self.ngram_counts = get_gram_counts(filename)
        self._unigram_total = None  # invalidate cached denominator after (re)training

    def get_ngram_prob(self, ngram):
        """Return the MLE probability of *ngram*.

        A plain string is a unigram; tuples of length 2/3/4 are conditioned
        on their (n-1)-gram prefix. Returns 0 when the prefix was never seen.
        """
        try:
            if isinstance(ngram, str):
                # Hoisted: recomputing this sum on every call was O(V) inside
                # the prediction hot loop.
                if self._unigram_total is None:
                    self._unigram_total = sum(self.ngram_counts['unigrams'].values())
                return self.ngram_counts['unigrams'][ngram] / self._unigram_total
            n = len(ngram)
            if n == 2:
                return self.ngram_counts['bigrams'][ngram] / self.ngram_counts['unigrams'][ngram[0]]
            if n == 3:
                return self.ngram_counts['trigrams'][ngram] / self.ngram_counts['bigrams'][ngram[:2]]
            return self.ngram_counts['tetragrams'][ngram] / self.ngram_counts['trigrams'][ngram[:3]]
        except ZeroDivisionError:
            return 0

    def predict_gaps(self, filename, lambdas=(0.25, 0.25, 0.25, 0.25), top_k=False, k=10):
        """Score every vocabulary word for the gap in each input line.

        Writes one line per input line to out.tsv as 'word:prob\\t...' sorted
        by descending probability, with the OOV sentinel emitted last as
        ':prob'. With top_k=True only the k best candidates are kept and
        renormalized. *lambdas* weight the uni/bi/tri/tetragram probabilities.
        """
        with open(filename, encoding='utf-8') as f, \
                open('out.tsv', 'w', encoding='utf-8') as out:
            for line in tqdm(f, total=self.get_num_lines(filename),
                             desc='Generating gap predictions'):
                fields = line.strip().replace(r'\n', ' ').split('\t')[-2:]
                left_context = [w if w in self.vocab else '' for w in fields[0].split()]
                right_context = [w if w in self.vocab else '' for w in fields[1].split()]
                context_probs = {}
                # Plain list copy suffices; the original deepcopy of a list of
                # strings was pure overhead.
                candidates = list(self.vocab) + ['']
                for word in candidates:
                    window_probs = []
                    # Slide the candidate through all 4 positions of a
                    # tetragram window spanning the gap: (3-i) left words,
                    # the candidate, then i right words.
                    for i in range(4):
                        tetragram = tuple(left_context[-4:][1 + i:] + [word]
                                          + right_context[:4][:-4 + i])
                        interpolated = (self.get_ngram_prob(tetragram[-1]) * lambdas[0]
                                        + self.get_ngram_prob(tetragram[-2:]) * lambdas[1]
                                        + self.get_ngram_prob(tetragram[-3:]) * lambdas[2]
                                        + self.get_ngram_prob(tetragram) * lambdas[3])
                        window_probs.append(interpolated)
                    context_probs[word] = prod(window_probs)
                if top_k:
                    sorted_top = sorted(context_probs.items(),
                                        key=lambda kv: kv[1], reverse=True)[:k]
                    # Guard against an all-zero distribution (would crash on /0).
                    probs_sum = sum(p for _, p in sorted_top) or 1.0
                    sorted_top = [(w, p / probs_sum) for w, p in sorted_top]
                    entries = [f'{w}:{p}' for w, p in sorted_top if w != '']
                    entries += [f':{p}' for w, p in sorted_top if w == '']
                    out.write('\t'.join(entries) + '\n')
                else:
                    probs_sum = sum(context_probs.values()) or 1.0
                    unk_prob = context_probs.pop('') / probs_sum
                    normalized = sorted(((w, p / probs_sum) for w, p in context_probs.items()),
                                        key=lambda kv: kv[1], reverse=True)
                    probs_string = '\t'.join(f'{w}:{p}' for w, p in normalized)
                    if unk_prob > 0:
                        probs_string += f"\t:{unk_prob}"
                    out.write(probs_string + '\n')

    def generate(self, prompt, length, temperature=0.5):
        """Extend *prompt* by *length* sampled words and print the result.

        Backs off tetragram -> trigram -> bigram -> random unigram when the
        current 3-word context was never observed. NOTE(review): the
        temperature check is inverted relative to convention (HIGH temperature
        makes greedy top-1 picks MORE likely) — preserved as-is.
        Assumes the prompt is at least three words long.
        """

        def _pick(continuations):
            # continuations: [(ngram_tuple, count), ...] whose final word is not ''.
            top3 = sorted(continuations, key=lambda kv: kv[1],
                          reverse=True)[:min(3, len(continuations))]
            if random.randint(1, 10) > temperature * 10:
                if len(continuations) < 3:
                    # Too few observed continuations: fall back to a random word.
                    return random.choice(list(self.ngram_counts['unigrams'].keys()))
                return random.choice(top3)[0][-1]
            return top3[0][0][-1]

        generation = prompt.split()
        context = [w if w in self.vocab else '' for w in generation[-3:]]
        backoff = ((3, 'tetragrams'), (2, 'trigrams'), (1, 'bigrams'))
        for _ in tqdm(range(length), desc='Generating text'):
            for order, name in backoff:
                cands = [(k, v) for k, v in self.ngram_counts[name].items()
                         if list(k[:order]) == context[-order:] and k[-1] != '']
                if cands:
                    generation.append(_pick(cands))
                    break
            else:
                generation.append(random.choice(list(self.ngram_counts['unigrams'].keys())))
            # Fix: the original reset the context from raw words, dropping the
            # OOV->'' mapping after the first step; remap consistently.
            context = [w if w in self.vocab else '' for w in generation[-3:]]
        print(' '.join(generation))


# model = TetragramModel()
# model.train('train.tsv', vocab_size=X)
# model.predict_gaps('dev-0/in.tsv', top_k=True, k=5, lambdas=(X, X, X, X))
# model.predict_gaps('test-A/in.tsv', top_k=True, k=5, lambdas=(X, X, X, X))
# model.generate('According to recent news', length=50)
# model.generate('Recent studies have shown that', length=50)
# model.generate('Today I was taking a stroll in the park when suddenly', length=50)
# model.generate('The most unbelievable story ever told goes like this', length=50)
# model.generate('The war between', length=50)