From bb121718aaf7467256fb5c899e4fed4a5f8e8e00 Mon Sep 17 00:00:00 2001
From: Kacper Dudzic
Date: Sat, 15 Apr 2023 02:50:00 +0200
Subject: [PATCH] Delete 'bigram_model.py'

---
 bigram_model.py | 83 -------------------------------------------------
 1 file changed, 83 deletions(-)
 delete mode 100644 bigram_model.py

diff --git a/bigram_model.py b/bigram_model.py
deleted file mode 100644
index f0cdfec..0000000
--- a/bigram_model.py
+++ /dev/null
@@ -1,83 +0,0 @@
-from tqdm import tqdm
-from collections import Counter
-import mmap
-
-
-class BigramModel:
-    def __init__(self):
-        self.vocab = None
-        self.unigram_counts = None
-        self.bigram_counts = None
-
-    def get_num_lines(self, filename):
-        fp = open(filename, 'r+')
-        buf = mmap.mmap(fp.fileno(), 0)
-        lines = 0
-        while buf.readline():
-            lines += 1
-        fp.close()
-        return lines
-
-    def train(self, filename, vocab_size=5000):
-        def get_vocab(filename, vocab_size):
-            file_vocab = Counter()
-            with open(filename, encoding='utf-8') as f:
-                for line in tqdm(f, total=self.get_num_lines(filename), desc=f'Generating vocab'):
-                    line = ' '.join(line.strip().split('\t')[-2:]).replace(r'\n', ' ').split()
-                    line_vocab = Counter(line)
-                    file_vocab.update(line_vocab)
-            if len(file_vocab) > vocab_size:
-                file_vocab = [tup[0] for tup in file_vocab.most_common(vocab_size)]
-            else:
-                file_vocab = file_vocab.keys()
-            return file_vocab
-
-        def get_gram_counts(filename):
-            file_unigram_counts = Counter()
-            file_bigram_counts = Counter()
-            with open(filename, encoding='utf-8') as f:
-                for line in tqdm(f, total=self.get_num_lines(filename), desc=f'Generating unigram and bigram counts'):
-                    line = line.strip().replace(r'\n', ' ').split('\t')[-2:]
-                    line_unigram_counts = Counter(' '.join(line).split())
-                    file_unigram_counts.update(line_unigram_counts)
-                    line_left, line_right = line[0].split(), line[1].split()
-                    line_bigram_counts_left = Counter([tuple(line_left[i: i + 2]) for i in range(len(line_left) - 2 + 1)])
-                    line_bigram_counts_right = Counter([tuple(line_right[i: i + 2]) for i in range(len(line_right) - 2 + 1)])
-                    file_bigram_counts.update(line_bigram_counts_left)
-                    file_bigram_counts.update(line_bigram_counts_right)
-            return file_unigram_counts, file_bigram_counts
-
-        self.vocab = get_vocab(filename, vocab_size)
-        self.unigram_counts, self.bigram_counts = get_gram_counts(filename)
-
-    def get_bigram_prob(self, bigram, smoothing):
-        if smoothing:
-            return (self.bigram_counts.get(bigram, 0) + 1) / (self.unigram_counts.get(bigram[0], 0) + len(self.vocab) + 1)
-        else:
-            return self.bigram_counts.get(bigram, 0) / self.unigram_counts.get(bigram[0], 1)
-
-    def predict_gaps(self, filename, smoothing=True):
-        with open(filename, encoding='utf-8') as f, open('out.tsv', 'w', encoding='utf-8') as out:
-            for line in tqdm(f, total=self.get_num_lines(filename), desc=f'Generating gap predictions'):
-                line = line.strip().replace(r'\n', ' ').split('\t')[-2:]
-                left_context, right_context = line[0].split()[-1], line[1].split()[0]
-                context_probs = dict()
-                for word in self.vocab:
-                    left_context_prob = self.get_bigram_prob((left_context, word), smoothing)
-                    right_context_prob = self.get_bigram_prob((word, right_context), smoothing)
-                    #context_probs[word] = left_context_prob * right_context_prob
-                    context_probs[word] = max(left_context_prob, right_context_prob)
-                vocab_prob = sum(context_probs.values())
-                probs_string = '\t'.join([f'{unigram}:{prob}' for unigram, prob in sorted(context_probs.items(), key=lambda x: x[1], reverse=True) if prob > 0])
-                remaining_prob = 1.0 - vocab_prob
-                if remaining_prob > 0:
-                    probs_string += f'\t:{remaining_prob}\n'
-                else:
-                    probs_string += '\n'
-                out.write(probs_string)
-
-
-if __name__ == '__main__':
-    model = BigramModel()
-    model.train('train.tsv', vocab_size=5000)
-    model.predict_gaps('in.tsv', smoothing=True)
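
Not part of the patch itself: for reviewers checking what is being removed, below is a minimal, self-contained sketch of the add-one smoothed estimate that the deleted get_bigram_prob computed. The toy tokens, counts, and the bigram_prob helper name are illustrative assumptions, not data or code from the repository.

    from collections import Counter

    # Toy corpus standing in for the training data (illustrative only).
    tokens = 'the cat sat on the mat'.split()
    unigram_counts = Counter(tokens)
    bigram_counts = Counter(zip(tokens, tokens[1:]))
    vocab = set(tokens)

    def bigram_prob(bigram, smoothing=True):
        # Mirrors the deleted method: add-one smoothing over the vocabulary,
        # with an extra +1 in the denominator reserving mass for unseen words.
        if smoothing:
            return (bigram_counts.get(bigram, 0) + 1) / (unigram_counts.get(bigram[0], 0) + len(vocab) + 1)
        return bigram_counts.get(bigram, 0) / unigram_counts.get(bigram[0], 1)

    # ('the', 'cat') occurs once and 'the' occurs twice with |V| = 5,
    # so the smoothed estimate is (1 + 1) / (2 + 5 + 1) = 0.25.
    print(bigram_prob(('the', 'cat')))  # 0.25
    print(bigram_prob(('cat', 'mat')))  # unseen: (0 + 1) / (1 + 5 + 1) ~ 0.143

Note that because predict_gaps scored each vocabulary word with max(left, right) of two such conditionals, the per-line scores did not sum to 1; the trailing ':remaining' column it wrote to out.tsv was only emitted when that sum stayed below 1.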