import lzma
import pickle
import os
from collections import Counter
def get_line(line):
    # Columns 6 and 7 of the tab-separated line hold the text before and after
    # the gap; newlines inside a field are stored as a literal backslash-n,
    # which is why the two-character sequence r'\n' is replaced here.
    parts = line.split('\t')
    prefix = parts[6].replace(r'\n', ' ')
    suffix = parts[7].replace(r'\n', ' ')
    return prefix + ' ' + suffix
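# Quick sanity check of get_line on a hypothetical 8-column line (the column
# layout and sample text are assumptions for illustration only):
_demo = '\t'.join(['0', 'a', 'b', 'c', 'd', 'e', r'It was a dark\nand stormy', 'night .'])
assert get_line(_demo) == 'It was a dark and stormy night .'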
def read_words(path):
with lzma.open(path, 'rt', encoding='utf-8') as f:
for line in f:
text = get_line(line)
for word in text.split():
yield word
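# Usage sketch (commented out because it needs the training file on disk):
# read_words streams tokens lazily, so peeking at the corpus costs almost
# nothing in memory.
# from itertools import islice
# print(list(islice(read_words('train/in.tsv.xz'), 10)))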
from collections import defaultdict
# NOTE: early in-memory n-gram counter kept for reference; it is superseded by
# the streaming ngrams() generator defined further below, which is what the
# model-building code actually calls.
def ngrams(words, n=1):
    ngrams_counts = defaultdict(int)
    sum_counts = defaultdict(int)
    for i in range(len(words) - n + 1):
        ngram = tuple(words[i:i+n])
        ngrams_counts[ngram] += 1
        for j in range(1, n):
            sum_counts[ngram[:j]] += 1
    for key, value in ngrams_counts.items():
        sum_counts[key[:-1]] += value
    return ngrams_counts, sum_counts
vocab_size = 2000
# NOTE: the vocab_top_1000 name (and pickle filename) is historical; the dict
# actually holds the vocab_size most frequent words plus an '<unk>' entry.
if os.path.exists('model/vocab_top_1000.pkl'):
with open('model/vocab_top_1000.pkl', 'rb') as f:
vocab_top_1000 = pickle.load(f)
else:
    counter = Counter(read_words('train/in.tsv.xz'))
    vocab_top_1000 = dict(counter.most_common(vocab_size))
    unk = 0
    # Pool the counts of every word outside the top list into one '<unk>' entry.
    for word, count in counter.items():
        if word not in vocab_top_1000:
            unk += count
    vocab_top_1000['<unk>'] = unk
    os.makedirs('model', exist_ok=True)  # the cache directory may not exist yet
    with open('model/vocab_top_1000.pkl', 'wb') as f:
        pickle.dump(vocab_top_1000, f, protocol=pickle.HIGHEST_PROTOCOL)
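# Sanity check: the vocabulary maps each frequent word to its corpus count,
# plus the aggregated '<unk>' bucket, so the values sum to the total number of
# training tokens. ('the' is just an illustrative probe word for an English
# corpus, not something the model requires.)
print('vocabulary entries:', len(vocab_top_1000))
print("relative frequency of 'the':",
      vocab_top_1000.get('the', 0) / sum(vocab_top_1000.values()))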
def ngrams(filename, V: dict, n: int):
    # Stream n-grams of order n straight from the compressed TSV: each line is
    # left-padded with n-1 empty strings and out-of-vocabulary tokens are
    # mapped to '<unk>' before the window is emitted.
    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
        print(f'{n}-grams')
        for line in fid:
            text = get_line(line)
            words = [''] * (n - 1)
            for word in text.split():
                if word not in V:
                    word = '<unk>'
                words.append(word)
                yield tuple(words[-n:])
if os.path.exists('model/bigrams.pkl'):
with open('model/bigrams.pkl', 'rb') as f:
bigrams = pickle.load(f)
else:
bigrams = Counter(ngrams('train/in.tsv.xz', vocab_top_1000, 2))
with open('model/bigrams.pkl', 'wb') as f:
pickle.dump(bigrams, f, protocol=pickle.HIGHEST_PROTOCOL)
if os.path.exists('model/trigrams.pkl'):
with open('model/trigrams.pkl', 'rb') as f:
trigrams = pickle.load(f)
else:
trigrams = Counter(ngrams('train/in.tsv.xz', vocab_top_1000, 3))
with open('model/trigrams.pkl', 'wb') as f:
pickle.dump(trigrams, f, protocol=pickle.HIGHEST_PROTOCOL)
if os.path.exists('model/tetragrams.pkl'):
with open('model/tetragrams.pkl', 'rb') as f:
tetragrams = pickle.load(f)
else:
tetragrams = Counter(ngrams('train/in.tsv.xz', vocab_top_1000, 4))
with open('model/tetragrams.pkl', 'wb') as f:
pickle.dump(tetragrams, f, protocol=pickle.HIGHEST_PROTOCOL)
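# Illustrative query against the raw counters before wrapping them in helper
# functions: a maximum-likelihood bigram probability P(w2 | w1) is simply
# count(w1, w2) / count(w1), using the unigram count as the history count,
# exactly as probability() does below. (Probe words are hypothetical; unseen
# pairs return 0.)
def mle_bigram(w1, w2):
    history = vocab_top_1000.get(w1, 0)
    return bigrams.get((w1, w2), 0) / history if history else 0.0
print("P('the' | 'of') ~", mle_bigram('of', 'the'))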
total_tokens = sum(vocab_top_1000.values())  # hoisted out of the scoring loop
def probability(first, second=None, third=None, fourth=None):
    # Backoff chain of maximum-likelihood estimates over the cached counters;
    # unseen n-grams or empty histories yield 0 instead of raising.
    # Unigram: P(w1) = count(w1) / total tokens
    if not second:
        return vocab_top_1000.get(first, 0) / total_tokens
    # Bigram: P(w2 | w1) = count(w1, w2) / count(w1)
    if not third:
        history = vocab_top_1000.get(first, 0)
        return bigrams[(first, second)] / history if (first, second) in bigrams and history else 0
    # Trigram: P(w3 | w1, w2) = count(w1, w2, w3) / count(w1, w2)
    if not fourth:
        history = bigrams.get((first, second), 0)
        return trigrams[(first, second, third)] / history if (first, second, third) in trigrams and history else 0
    # Tetragram: P(w4 | w1, w2, w3) = count(w1, w2, w3, w4) / count(w1, w2, w3)
    history = trigrams.get((first, second, third), 0)
    return tetragrams[(first, second, third, fourth)] / history if (first, second, third, fourth) in tetragrams and history else 0
def interpolate(tetragram):
    # Mix the n-gram orders with fixed, hand-picked weights. Empty strings mark
    # missing context words (they only occur at the edges of the window), so
    # drop them and interpolate over the words that remain.
    words = [w for w in tetragram if w]
    if len(words) == 4:
        return (0.4 * probability(*words) + 0.3 * probability(*words[1:])
                + 0.2 * probability(*words[2:]) + 0.1 * probability(words[3]))
    if len(words) == 3:
        return (0.5 * probability(*words) + 0.3 * probability(*words[1:])
                + 0.2 * probability(words[2]))
    if len(words) == 2:
        return 0.6 * probability(*words) + 0.4 * probability(words[1])
    return probability(words[0]) if words else 0
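# The mixture weights are ad hoc but sum to 1 within each branch, e.g.
# 0.4 + 0.3 + 0.2 + 0.1 for the full window, so interpolate() returns a convex
# combination of the n-gram estimates. Illustrative call (hypothetical context
# words; anything out of vocabulary simply contributes 0 at the higher orders):
# interpolate(('according', 'to', 'the', 'report'))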
def consider_context(left_context, right_context):
    # Score every vocabulary word as the gap filler by sliding a 4-gram window
    # across the three left-context and three right-context words.
    first, second, third = left_context
    fifth, sixth, seventh = right_context
    probs = []
    for word in vocab_top_1000:
        p1 = interpolate((first, second, third, word))
        p2 = interpolate((second, third, word, fifth))
        p3 = interpolate((third, word, fifth, sixth))
        p4 = interpolate((word, fifth, sixth, seventh))
        probs.append((word, p1 * p2 * p3 * p4))
    # Keep the five best candidates and renormalise their scores.
    probs = sorted(probs, key=lambda x: x[1], reverse=True)[:5]
    total_prob = sum(prob for _, prob in probs)
    if total_prob == 0:
        # No candidate scored at all; fall back to a uniform guess.
        norm = [(word, 1 / len(probs)) for word, _ in probs]
    else:
        norm = [(word, prob / total_prob) for word, prob in probs]
    # One entry must carry an empty word as the catch-all for everything not
    # listed; reuse the '<unk>' slot if present, else sacrifice the weakest.
    for index, elem in enumerate(norm):
        if elem[0] == '<unk>':
            norm.pop(index)
            norm.append(('', elem[1]))
            break
    else:
        norm[-1] = ('', norm[-1][1])
    return ' '.join([f'{x[0]}:{x[1]}' for x in norm])
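# consider_context returns one prediction line in the format execute() writes
# to out.tsv: up to four 'word:probability' pairs plus a trailing
# ':probability' entry whose empty word covers all unlisted words, e.g.
#   the:0.35 a:0.25 an:0.2 :0.2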
def execute(path):
    with lzma.open(f'{path}/in.tsv.xz', 'rt', encoding='utf-8') as f, \
            open(f'{path}/out.tsv', 'w', encoding='utf-8') as out:
        for line in f:
            prefix, suffix = line.split('\t')[6:8]
            prefix = prefix.replace(r'\n', ' ').split()[-3:]
            suffix = suffix.replace(r'\n', ' ').split()[:3]
            # Map OOV tokens to '<unk>', keeping the word itself (the original
            # vocab_top_1000.get(x, '<unk>') returned the word's *count* here).
            left = [x if x in vocab_top_1000 else '<unk>' for x in prefix]
            right = [x if x in vocab_top_1000 else '<unk>' for x in suffix]
            # Pad short contexts with empty strings so the three-way unpacking
            # in consider_context always succeeds.
            left = [''] * (3 - len(left)) + left
            right = right + [''] * (3 - len(right))
            result = consider_context(left, right)
            out.write(f"{result}\n")
# Generate predictions for both evaluation splits (writes dev-0/out.tsv and
# test-A/out.tsv next to the corresponding in.tsv.xz files).
execute('dev-0')
execute('test-A')