Delete 'bigram_model.py'
This commit is contained in:
parent 9332c1957b
commit bb121718aa
@ -1,83 +0,0 @@
from collections import Counter
import mmap

from tqdm import tqdm


class BigramModel:

    def __init__(self):
        self.vocab = None
        self.unigram_counts = None
        self.bigram_counts = None

    def get_num_lines(self, filename):
        # Count lines with a read-only mmap so tqdm can show accurate progress.
        with open(filename, 'rb') as fp:
            buf = mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ)
            lines = 0
            while buf.readline():
                lines += 1
            buf.close()
        return lines

    def train(self, filename, vocab_size=5000):

        def get_vocab(filename, vocab_size):
            # Collect word frequencies over the last two TSV columns
            # (left and right context) and keep the most frequent words.
            file_vocab = Counter()
            with open(filename, encoding='utf-8') as f:
                for line in tqdm(f, total=self.get_num_lines(filename), desc='Generating vocab'):
                    tokens = ' '.join(line.strip().split('\t')[-2:]).replace(r'\n', ' ').split()
                    file_vocab.update(Counter(tokens))
            if len(file_vocab) > vocab_size:
                return [word for word, _ in file_vocab.most_common(vocab_size)]
            return list(file_vocab.keys())

        def get_gram_counts(filename):
            # Count unigrams over both contexts, and bigrams within each
            # context separately so that no bigram spans the gap itself.
            file_unigram_counts = Counter()
            file_bigram_counts = Counter()
            with open(filename, encoding='utf-8') as f:
                for line in tqdm(f, total=self.get_num_lines(filename), desc='Generating unigram and bigram counts'):
                    contexts = line.strip().replace(r'\n', ' ').split('\t')[-2:]
                    file_unigram_counts.update(Counter(' '.join(contexts).split()))
                    line_left, line_right = contexts[0].split(), contexts[1].split()
                    # zip(tokens, tokens[1:]) yields each adjacent token pair.
                    file_bigram_counts.update(Counter(zip(line_left, line_left[1:])))
                    file_bigram_counts.update(Counter(zip(line_right, line_right[1:])))
            return file_unigram_counts, file_bigram_counts

        self.vocab = get_vocab(filename, vocab_size)
        self.unigram_counts, self.bigram_counts = get_gram_counts(filename)

    def get_bigram_prob(self, bigram, smoothing):
        # Laplace-style smoothing: add one to the bigram count and grow the
        # denominator by the vocabulary size (plus one for unseen words).
        if smoothing:
            return (self.bigram_counts.get(bigram, 0) + 1) / (self.unigram_counts.get(bigram[0], 0) + len(self.vocab) + 1)
        # Unsmoothed MLE; the default denominator of 1 avoids division by zero.
        return self.bigram_counts.get(bigram, 0) / self.unigram_counts.get(bigram[0], 1)

    def predict_gaps(self, filename, smoothing=True):
        with open(filename, encoding='utf-8') as f, open('out.tsv', 'w', encoding='utf-8') as out:
            for line in tqdm(f, total=self.get_num_lines(filename), desc='Generating gap predictions'):
                contexts = line.strip().replace(r'\n', ' ').split('\t')[-2:]
                # Score each candidate word against the token directly before
                # and the token directly after the gap.
                left_context, right_context = contexts[0].split()[-1], contexts[1].split()[0]
                context_probs = {}
                for word in self.vocab:
                    left_context_prob = self.get_bigram_prob((left_context, word), smoothing)
                    right_context_prob = self.get_bigram_prob((word, right_context), smoothing)
                    # context_probs[word] = left_context_prob * right_context_prob
                    context_probs[word] = max(left_context_prob, right_context_prob)
                vocab_prob = sum(context_probs.values())
                probs_string = '\t'.join(
                    f'{unigram}:{prob}'
                    for unigram, prob in sorted(context_probs.items(), key=lambda x: x[1], reverse=True)
                    if prob > 0
                )
                # Assign whatever mass is left to the unknown-word bucket.
                remaining_prob = 1.0 - vocab_prob
                if remaining_prob > 0:
                    probs_string += f'\t:{remaining_prob}\n'
                else:
                    probs_string += '\n'
                out.write(probs_string)


if __name__ == '__main__':
    model = BigramModel()
    model.train('train.tsv', vocab_size=5000)
    model.predict_gaps('in.tsv', smoothing=True)