challenging-america-word-ga.../tetragram_final.ipynb

import lzma
import pickle
import os
from collections import Counter
def get_line(line):
    # Columns 6 and 7 of each TSV row hold the text before and after the gap;
    # literal '\n' sequences stand in for line breaks within a context.
    parts = line.split('\t')
    prefix = parts[6].replace(r'\n', ' ')
    suffix = parts[7].replace(r'\n', ' ')
    return prefix + ' ' + suffix
def read_words(path):
    with lzma.open(path, 'rt', encoding='utf-8') as f:
        for line in f:
            text = get_line(line)
            for word in text.split():
                yield word
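# Minimal sanity check of get_line (hypothetical row, not from the corpus):
# columns 6 and 7 carry the contexts, with literal '\n' marking line breaks.
sample = '\t'.join(['0', '-', '-', '-', '-', '-', r'the quick\nbrown', 'jumps over'])
assert get_line(sample) == 'the quick brown jumps over'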
from collections import defaultdict


def ngrams(words, n=1):
    # In-memory n-gram counter over a token list; superseded below by a
    # streaming generator of the same name that reads the corpus lazily.
    ngrams_counts = defaultdict(int)
    sum_counts = defaultdict(int)
    
    for i in range(len(words) - n + 1):
        ngram = tuple(words[i:i+n])
        ngrams_counts[ngram] += 1
        # Count only the prefixes shorter than n-1 here; the length-(n-1)
        # contexts are accumulated once from the type counts below, so
        # stopping at n-2 avoids counting them twice.
        for j in range(1, n - 1):
            sum_counts[ngram[:j]] += 1
    
    for key, value in ngrams_counts.items():
        sum_counts[key[:-1]] += value
    
    return ngrams_counts, sum_counts
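# Illustrative toy run of the counter above (hypothetical input, n=2):
counts, context_totals = ngrams(['a', 'b', 'a', 'b'], n=2)
# counts == {('a', 'b'): 2, ('b', 'a'): 1}
# context_totals == {('a',): 2, ('b',): 1}, i.e. how often each context occurs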
vocab_size = 2000
os.makedirs('model', exist_ok=True)  # the pickles below are written to model/
# Note: despite the historical "_1000" names, the vocabulary holds the top
# vocab_size (2000) words plus a single '<unk>' bucket.
if os.path.exists('model/vocab_top_1000.pkl'):
    with open('model/vocab_top_1000.pkl', 'rb') as f:
        vocab_top_1000 = pickle.load(f)
else:
    counter = Counter(read_words('train/in.tsv.xz'))
    vocab_top_1000 = dict(counter.most_common(vocab_size))
    # Pool the counts of every out-of-vocabulary word into '<unk>'.
    unk = 0
    for word, count in counter.items():
        if word not in vocab_top_1000:
            unk += count
    vocab_top_1000['<unk>'] = unk
    with open('model/vocab_top_1000.pkl', 'wb') as f:
        pickle.dump(vocab_top_1000, f, protocol=pickle.HIGHEST_PROTOCOL)
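# Quick sanity check (illustrative; the values depend on the training data):
print(len(vocab_top_1000))             # vocab_size + 1, counting the '<unk>' bucket
print(vocab_top_1000.get('<unk>', 0))  # total count of out-of-vocabulary tokens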
def ngrams(filename, V: dict, n: int):
    # Stream n-grams straight from the compressed corpus. Each line is
    # left-padded with n-1 empty strings so that line-initial words still
    # produce full-width n-grams, and OOV words are mapped to '<unk>'.
    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
        print(f'{n}-grams')
        for line in fid:
            text = get_line(line)
            words = [''] * (n - 1)
            for word in text.split():
                if word not in V:
                    word = '<unk>'
                words.append(word)
                yield tuple(words[-n:])
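# Standalone sketch of the padding/OOV scheme used by the generator above
# (toy vocabulary V; 'c' is out of vocabulary):
n, V = 3, {'a': 1, 'b': 1}
words = [''] * (n - 1)
for word in 'a b c'.split():
    words.append(word if word in V else '<unk>')
print([tuple(words[i:i + n]) for i in range(len(words) - n + 1)])
# [('', '', 'a'), ('', 'a', 'b'), ('a', 'b', '<unk>')]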
if os.path.exists('model/bigrams.pkl'):
    with open('model/bigrams.pkl', 'rb') as f:
        bigrams = pickle.load(f)
else:
    bigrams = Counter(ngrams('train/in.tsv.xz', vocab_top_1000, 2))
    with open('model/bigrams.pkl', 'wb') as f:
        pickle.dump(bigrams, f, protocol=pickle.HIGHEST_PROTOCOL)
if os.path.exists('model/trigrams.pkl'):
    with open('model/trigrams.pkl', 'rb') as f:
        trigrams = pickle.load(f)
else:
    trigrams = Counter(ngrams('train/in.tsv.xz', vocab_top_1000, 3))
    with open('model/trigrams.pkl', 'wb') as f:
        pickle.dump(trigrams, f, protocol=pickle.HIGHEST_PROTOCOL)
if os.path.exists('model/tetragrams.pkl'):
    with open('model/tetragrams.pkl', 'rb') as f:
        tetragrams = pickle.load(f)
else:
    tetragrams = Counter(ngrams('train/in.tsv.xz', vocab_top_1000, 4))
    with open('model/tetragrams.pkl', 'wb') as f:
        pickle.dump(tetragrams, f, protocol=pickle.HIGHEST_PROTOCOL)

# Token total for unigram estimates, computed once rather than on every call.
unigram_total = sum(vocab_top_1000.values())

def probability(first, second=None, third=None, fourth=None):
    # Maximum-likelihood estimate of P(last argument | preceding arguments),
    # returning 0 whenever the required n-gram was never observed.
    # Unigram
    if not second:
        return vocab_top_1000.get(first, 0) / unigram_total
    
    # Bigram
    bigram_key = (first, second)
    if bigram_key in bigrams:
        if not third:
            first_count = vocab_top_1000.get(first, 0)
            # Guard: the '' padding token occurs in the bigram table but has
            # no vocabulary count.
            return bigrams[bigram_key] / first_count if first_count else 0
        
        # Trigram
        trigram_key = (first, second, third)
        if trigram_key in trigrams:
            if not fourth:
                return trigrams[trigram_key] / bigrams[bigram_key]
            
            # Tetragram
            tetragram_key = (first, second, third, fourth)
            if tetragram_key in tetragrams:
                return tetragrams[tetragram_key] / trigrams[trigram_key]
    
    # n-gram not observed
    return 0
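# The function above returns maximum-likelihood estimates:
#   P(w2 | w1)         = count(w1, w2)      / count(w1)
#   P(w3 | w1, w2)     = count(w1, w2, w3)  / count(w1, w2)
#   P(w4 | w1, w2, w3) = count(w1, ..., w4) / count(w1, w2, w3)
# Hypothetical worked example (not real model counts): with
# vocab_top_1000['in'] == 200 and bigrams[('in', 'the')] == 50,
# probability('in', 'the') == 50 / 200 == 0.25, the MLE of P('the' | 'in').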

def interpolate(tetragram):
    # Linearly interpolate n-gram orders; the weights at each order sum to 1.
    # '' padding tokens are dropped first, so contexts shortened by padding
    # back off to the matching lower-order mixture instead of scoring 0.
    words = [w for w in tetragram if w]
    if len(words) == 4:
        first, second, third, fourth = words
        return 0.4 * probability(first, second, third, fourth) + 0.3 * probability(second, third, fourth) + 0.2 * probability(third, fourth) + 0.1 * probability(fourth)
    if len(words) == 3:
        first, second, third = words
        return 0.5 * probability(first, second, third) + 0.3 * probability(second, third) + 0.2 * probability(third)
    if len(words) == 2:
        return 0.6 * probability(words[0], words[1]) + 0.4 * probability(words[1])
    if len(words) == 1:
        return probability(words[0])
    return 0
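# Numeric sketch with hypothetical component estimates (P4 tetragram, P3
# trigram, P2 bigram, P1 unigram): for P4=0.5, P3=0.2, P2=0.1, P1=0.05,
# interpolate returns 0.4*0.5 + 0.3*0.2 + 0.2*0.1 + 0.1*0.05 = 0.285.
# Because each weight set sums to 1, the mixture is itself a probability.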
def consider_context(left_context, right_context):
    # Score every vocabulary word as the gap filler by sliding a tetragram
    # window across the three words on each side of the gap.
    first, second, third = left_context
    fifth, sixth, seventh = right_context
    
    probs = []
    for word in vocab_top_1000:
        p1 = interpolate((first, second, third, word))
        p2 = interpolate((second, third, word, fifth))
        p3 = interpolate((third, word, fifth, sixth))
        p4 = interpolate((word, fifth, sixth, seventh))
        probs.append((word, p1 * p2 * p3 * p4))
    
    # Keep the five best candidates and renormalise them to sum to 1.
    probs = sorted(probs, key=lambda x: x[1], reverse=True)[:5]
    total_prob = sum(prob for _, prob in probs)
    
    norm = [(word, prob / total_prob) for word, prob in probs]
    # The output format uses an empty word for the leftover probability mass:
    # reuse the '<unk>' slot when it made the top five, otherwise sacrifice
    # the weakest candidate.
    for index, elem in enumerate(norm):
        if elem[0] == '<unk>':
            norm.pop(index)
            norm.append(('', elem[1]))
            break
    else:
        norm[-1] = ('', norm[-1][1])
        
    return ' '.join(f'{word}:{prob}' for word, prob in norm)
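# Illustrative output line (hypothetical numbers): the five renormalised
# candidates, with the '' entry holding the leftover mass expected by the
# challenge's output format, e.g.
# 'the:0.41 a:0.25 his:0.14 this:0.12 :0.08'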
def execute(path):
    with lzma.open(f'{path}/in.tsv.xz', 'rt', encoding='utf-8') as f, \
        open(f'{path}/out.tsv', 'w', encoding='utf-8') as out:
        for line in f:
            prefix, suffix = line.split('\t')[6:8]
            # Pad with '' so the three-word unpacking in consider_context
            # also works for contexts shorter than three words.
            prefix = ([''] * 3 + prefix.replace(r'\n', ' ').split())[-3:]
            suffix = (suffix.replace(r'\n', ' ').split() + [''] * 3)[:3]
            # Keep in-vocabulary words and '' padding; map real OOV words to
            # '<unk>'. (vocab_top_1000 maps words to counts, so the previous
            # vocab_top_1000.get(x, '<unk>') substituted a count for the word.)
            left = [w if not w or w in vocab_top_1000 else '<unk>' for w in prefix]
            right = [w if not w or w in vocab_top_1000 else '<unk>' for w in suffix]
            result = consider_context(left, right)
            out.write(f"{result}\n")
execute('dev-0')
execute('test-A')
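# Each call above writes <path>/out.tsv with one 'word:prob ... :mass' line per
# input row, the format assumed by the challenge's (Gonito-style) evaluation.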