85 lines
4.3 KiB
Python
85 lines
4.3 KiB
Python
from collections import Counter
|
|
import lzma
|
|
import os
|
|
|
|
class BigramModel:
    """Bigram model that fills a single-word gap between a left and right context.

    Training and prediction inputs are xz-compressed TSV files whose last two
    tab-separated fields hold the text to the left and to the right of the gap.
    A candidate word ``w`` is scored as ``P(w | left) * P(right | w)``, where
    ``left`` is the last word before the gap and ``right`` the first word after.
    """

    # Line written when the candidates cannot be ranked (all scores equal) or
    # when every selected candidate has zero probability.
    _FALLBACK_PREDICTION = 'the:0.2 be:0.2 of:0.2'

    def __init__(self):
        # List of candidate words for the gap (set by train()).
        self.vocab = None
        # Counter: token -> number of occurrences in the training contexts.
        self.unigram_counts = None
        # Counter: (token, next_token) -> occurrences; pairs are taken within
        # each context field only, never across the gap.
        self.bigram_counts = None

    @staticmethod
    def _parse_line(raw_line):
        """Return ``(left_tokens, right_tokens)`` for one raw (bytes) TSV line.

        Literal two-character sequences backslash + 'n' in the text are treated
        as spaces. The line is split on tabs *before* any whitespace is removed,
        so an empty last field stays an (empty) context instead of shifting the
        columns.
        """
        fields = raw_line.decode('utf-8').replace(r'\n', ' ').split('\t')[-2:]
        return fields[0].split(), fields[1].split()

    def train(self, filename, vocab_size=5000):
        """Count unigrams and bigrams in *filename* and build the vocabulary.

        Args:
            filename: path to an xz-compressed TSV training file.
            vocab_size: maximum number of candidate words. If the corpus has
                more distinct tokens, the ``vocab_size`` most frequent are kept;
                otherwise all tokens are kept in first-seen order.
        """
        print('Generating unigram and bigram counts')
        unigram_counts = Counter()
        bigram_counts = Counter()
        # A single pass over the compressed file; the vocabulary is derived
        # from the unigram counts afterwards (it counts exactly the same tokens).
        with lzma.open(filename, 'r') as f:
            for raw_line in f:
                left, right = self._parse_line(raw_line)
                unigram_counts.update(left)
                unigram_counts.update(right)
                # Adjacent pairs inside each context; no pair spans the gap.
                bigram_counts.update(zip(left, left[1:]))
                bigram_counts.update(zip(right, right[1:]))

        print('Generating vocab')
        if len(unigram_counts) > vocab_size:
            self.vocab = [word for word, _ in unigram_counts.most_common(vocab_size)]
        else:
            self.vocab = list(unigram_counts)
        self.unigram_counts = unigram_counts
        self.bigram_counts = bigram_counts

    def get_bigram_prob(self, bigram, smoothing):
        """Return the estimated probability of ``bigram[1]`` following ``bigram[0]``.

        With *smoothing*, add-one smoothing is used: 1 is added to the bigram
        count and ``len(self.vocab) + 1`` to the count of the first word.
        NOTE(review): the extra +1 presumably reserves mass for an unknown
        word -- confirm. Without smoothing, the plain ratio of counts is
        returned; an unseen first word yields 0.0.
        """
        count = self.bigram_counts.get(bigram, 0)
        if smoothing:
            history = self.unigram_counts.get(bigram[0], 0)
            return (count + 1) / (history + len(self.vocab) + 1)
        # Default of 1 avoids dividing by zero for an unseen first word.
        return count / self.unigram_counts.get(bigram[0], 1)

    @classmethod
    def _format_prediction(cls, context_probs, topk):
        """Render ``word -> score`` as a space-separated ``word:prob`` line (no newline)."""
        # All scores identical (e.g. all zero): there is nothing to rank.
        if len(set(context_probs.values())) == 1:
            return cls._FALLBACK_PREDICTION
        ranked = sorted(context_probs.items(), key=lambda item: item[1], reverse=True)[:topk]
        total = sum(prob for _, prob in ranked)
        if total <= 0:
            return cls._FALLBACK_PREDICTION
        normalized = [(word, prob / total) for word, prob in ranked]
        # NOTE(review): only the LAST two of the top-k entries are written. The
        # author's original comment described this as "removing last two
        # entries" and reported much better scores with it; the slice is kept.
        entries = [f'{word}:{prob}' for word, prob in normalized[-2:] if prob > 0]
        # Space-separated like the fallback; an all-zero selection would
        # otherwise produce an empty output line.
        return ' '.join(entries) if entries else cls._FALLBACK_PREDICTION

    def predict_gaps(self, path, smoothing=True, topk=5):
        """Write one predicted distribution per gap listed in ``path/in.tsv.xz``.

        Output goes to ``path/out.tsv``: one line per input line, formatted as
        space-separated ``word:prob`` pairs.

        Args:
            path: directory containing ``in.tsv.xz``.
            smoothing: forwarded to :meth:`get_bigram_prob`.
            topk: number of best-scoring candidates to normalize over.
        """
        print('Making predictions')
        in_path = os.path.join(path, 'in.tsv.xz')
        out_path = os.path.join(path, 'out.tsv')
        with lzma.open(in_path, 'r') as f, open(out_path, 'w', encoding='utf-8') as out:
            for raw_line in f:
                left_words, right_words = self._parse_line(raw_line)
                # A gap at the very start/end of a text has no neighbour on that
                # side; that side then contributes a neutral factor of 1.
                left = left_words[-1] if left_words else None
                right = right_words[0] if right_words else None

                context_probs = {}
                for word in self.vocab:
                    prob = 1.0
                    if left is not None:
                        prob *= self.get_bigram_prob((left, word), smoothing)
                    if right is not None:
                        prob *= self.get_bigram_prob((word, right), smoothing)
                    context_probs[word] = prob

                out.write(self._format_prediction(context_probs, topk) + '\n')
|
|
|
|
|
|
if __name__ == '__main__':
    import subprocess

    # Challenge checkout: holds the train/ and dev-0/ splits and the geval binary.
    DATA_DIR = 'challenging-america-word-gap-prediction'

    # Hyper-parameter sweep; each list currently holds a single value.
    for vocab_size in [5000]:
        model = BigramModel()
        model.train(os.path.join(DATA_DIR, 'train', 'in.tsv.xz'), vocab_size=vocab_size)
        for topk in [5]:
            model.predict_gaps(os.path.join(DATA_DIR, 'dev-0'), smoothing=False, topk=topk)
            print(f'topk:{topk} vocab:{vocab_size}')
            # Run the evaluator inside the challenge directory without changing
            # this process's working directory and without going through a shell.
            result = subprocess.run(['./geval', '--test-name', 'dev-0'], cwd=DATA_DIR)
            print(result.returncode)
|