{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gap-filling n-gram language model\n",
    "\n",
    "Predicts the word missing between a left and a right context. Word counts and 2-, 3- and 4-gram counts are collected from `train/in.tsv.xz` and cached as pickles under `model/`. Every vocabulary word is scored as the gap filler with linearly interpolated n-gram probabilities, and the best candidates for `dev-0` and `test-A` are written to each directory's `out.tsv`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import lzma\n",
    "import os\n",
    "import pickle\n",
    "from collections import Counter, deque"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TRAIN_PATH = 'train/in.tsv.xz'\n",
    "MODEL_DIR = 'model'\n",
    "VOCAB_SIZE = 2000  # most frequent training words kept as the vocabulary\n",
    "TOP_K = 5  # predictions written per line; the last slot always goes to UNK\n",
    "# Token for out-of-vocabulary words and for n-gram padding. Note that it is falsy,\n",
    "# and `probability` / `interpolate` treat falsy words as the end of an n-gram.\n",
    "UNK = ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_or_build(path, build):\n",
    "    \"\"\"Return the object pickled at `path`, or create it with `build()` and pickle it there.\n",
    "\n",
    "    Caches are never invalidated: delete the file after changing anything it depends on\n",
    "    (e.g. VOCAB_SIZE). `pickle.load` can run arbitrary code, so only load files this\n",
    "    notebook wrote itself.\n",
    "    \"\"\"\n",
    "    if os.path.exists(path):\n",
    "        with open(path, 'rb') as f:\n",
    "            return pickle.load(f)\n",
    "    obj = build()\n",
    "    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)\n",
    "    with open(path, 'wb') as f:\n",
    "        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)\n",
    "    return obj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_contexts(line):\n",
    "    \"\"\"Return the (left, right) context strings of a TSV row.\n",
    "\n",
    "    Columns 6 and 7 hold the text before and after the gap; literal backslash-n\n",
    "    sequences in them are replaced by spaces.\n",
    "    \"\"\"\n",
    "    parts = line.split('\\t')\n",
    "    left = parts[6].replace(r'\\n', ' ')\n",
    "    right = parts[7].replace(r'\\n', ' ')\n",
    "    return left, right\n",
    "\n",
    "\n",
    "def get_line(line):\n",
    "    \"\"\"Return the left and right context of a TSV row joined by a space.\n",
    "\n",
    "    NOTE(review): the gap word itself is missing from this text, so n-grams that\n",
    "    span the join are artificial. If the training gap words are available\n",
    "    (presumably train/expected.tsv -- TODO confirm), inserting them here would\n",
    "    give cleaner counts.\n",
    "    \"\"\"\n",
    "    left, right = get_contexts(line)\n",
    "    return left + ' ' + right"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_words(path):\n",
    "    \"\"\"Yield every whitespace-separated word of the joined contexts in an xz-compressed TSV file.\"\"\"\n",
    "    with lzma.open(path, 'rt', encoding='utf-8') as f:\n",
    "        for line in f:\n",
    "            yield from get_line(line).split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ngrams(filename, V: dict, n: int):\n",
    "    \"\"\"Yield the n-grams (tuples of n words) of the joined contexts in `filename`.\n",
    "\n",
    "    Words missing from `V` are replaced by UNK, and each line is left-padded with\n",
    "    n - 1 UNK tokens, so every word ends exactly one n-gram.\n",
    "    \"\"\"\n",
    "    if n < 1:\n",
    "        raise ValueError(f'n must be positive, got {n}')\n",
    "    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:\n",
    "        print(f'{n}-grams')\n",
    "        for line in fid:\n",
    "            # A bounded window keeps exactly the last n words of the current line.\n",
    "            window = deque([UNK] * (n - 1), maxlen=n)\n",
    "            for word in get_line(line).split():\n",
    "                window.append(word if word in V else UNK)\n",
    "                yield tuple(window)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Vocabulary and n-gram counts\n",
    "\n",
    "Counting streams the whole training file once per table, so results are cached as pickles under `MODEL_DIR`. The caches are never invalidated: delete them after changing `VOCAB_SIZE` or the tokenisation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_vocab(path, size):\n",
    "    \"\"\"Count the words of `path`: the `size` most frequent keep their own counts and\n",
    "    all remaining occurrences are summed into a single UNK entry.\"\"\"\n",
    "    counts = Counter(read_words(path))\n",
    "    top = dict(counts.most_common(size))\n",
    "    top[UNK] = sum(count for word, count in counts.items() if word not in top)\n",
    "    return top"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The cache keeps its original file name so existing caches still load;\n",
    "# it holds VOCAB_SIZE words, not necessarily 1000.\n",
    "vocab = load_or_build(os.path.join(MODEL_DIR, 'vocab_top_1000.pkl'),\n",
    "                      lambda: build_vocab(TRAIN_PATH, VOCAB_SIZE))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bigrams = load_or_build(os.path.join(MODEL_DIR, 'bigrams.pkl'),\n",
    "                        lambda: Counter(ngrams(TRAIN_PATH, vocab, 2)))\n",
    "trigrams = load_or_build(os.path.join(MODEL_DIR, 'trigrams.pkl'),\n",
    "                         lambda: Counter(ngrams(TRAIN_PATH, vocab, 3)))\n",
    "tetragrams = load_or_build(os.path.join(MODEL_DIR, 'tetragrams.pkl'),\n",
    "                           lambda: Counter(ngrams(TRAIN_PATH, vocab, 4)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Scoring candidates\n",
    "\n",
    "Each vocabulary word is tried as the gap filler; its score multiplies the interpolated probabilities of the four 4-word windows around the gap that contain it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Computed once here instead of on every unigram lookup (probability is called\n",
    "# several times per candidate word, for every vocabulary word, on every line).\n",
    "vocab_total = sum(vocab.values())\n",
    "\n",
    "\n",
    "def _ratio(numerator, denominator):\n",
    "    \"\"\"Return numerator / denominator, or 0.0 when the denominator is zero.\"\"\"\n",
    "    return numerator / denominator if denominator else 0.0\n",
    "\n",
    "\n",
    "def probability(first, second=None, third=None, fourth=None):\n",
    "    \"\"\"Count-ratio estimate of the last given word's probability given the words before it.\n",
    "\n",
    "    One word gives its unigram frequency; two, three or four words give the n-gram\n",
    "    count divided by the count of its first n - 1 words. A falsy argument (None or\n",
    "    UNK) ends the n-gram there, so e.g. probability(a, UNK, c) == probability(a).\n",
    "    Returns 0 if any of the bigram, trigram or tetragram along the way is unseen.\n",
    "    \"\"\"\n",
    "    if not second:\n",
    "        return _ratio(vocab.get(first, 0), vocab_total)\n",
    "\n",
    "    bigram_key = (first, second)\n",
    "    if bigram_key not in bigrams:\n",
    "        return 0\n",
    "    if not third:\n",
    "        return _ratio(bigrams[bigram_key], vocab.get(first, 0))\n",
    "\n",
    "    trigram_key = (first, second, third)\n",
    "    if trigram_key not in trigrams:\n",
    "        return 0\n",
    "    if not fourth:\n",
    "        return _ratio(trigrams[trigram_key], bigrams[bigram_key])\n",
    "\n",
    "    tetragram_key = (first, second, third, fourth)\n",
    "    if tetragram_key not in tetragrams:\n",
    "        return 0\n",
    "    return _ratio(tetragrams[tetragram_key], trigrams[trigram_key])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "INTERPOLATION_WEIGHTS = {\n",
    "    4: (0.4, 0.3, 0.2, 0.1),\n",
    "    3: (0.5, 0.3, 0.2),\n",
    "    2: (0.6, 0.4),\n",
    "    1: (1.0,),\n",
    "}\n",
    "\n",
    "\n",
    "def interpolate(tetragram):\n",
    "    \"\"\"Linearly interpolated probability of the last word of a 4-word window.\n",
    "\n",
    "    The window is cut at its first falsy word (keeping at least the first word);\n",
    "    for a window of length k, INTERPOLATION_WEIGHTS[k][i] weights the probability\n",
    "    of the (k - i)-gram obtained by dropping the first i words.\n",
    "\n",
    "    NOTE(review): after a cut the score belongs to the last kept word, not to\n",
    "    `fourth` -- e.g. interpolate((a, b, c, UNK)) scores c, and\n",
    "    interpolate((UNK, b, c, d)) is just the unigram probability of UNK. Backing\n",
    "    off from the left instead may be intended; confirm before changing, as it\n",
    "    alters every score.\n",
    "    \"\"\"\n",
    "    first, second, third, fourth = tetragram\n",
    "    if first and second and third and fourth:\n",
    "        window = (first, second, third, fourth)\n",
    "    elif first and second and third:\n",
    "        window = (first, second, third)\n",
    "    elif first and second:\n",
    "        window = (first, second)\n",
    "    else:\n",
    "        window = (first,)\n",
    "    weights = INTERPOLATION_WEIGHTS[len(window)]\n",
    "    return sum(weight * probability(*window[i:]) for i, weight in enumerate(weights))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def consider_context(left_context, right_context, top_k=TOP_K):\n",
    "    \"\"\"Score every vocabulary word as the filler of the gap between two contexts.\n",
    "\n",
    "    `left_context` holds up to 3 words before the gap and `right_context` up to 3\n",
    "    words after it, presumably already mapped to vocabulary words or UNK; shorter\n",
    "    contexts are padded with UNK on the side away from the gap. A candidate's score\n",
    "    is the product of the interpolated probabilities of the four 4-word windows\n",
    "    containing it. The `top_k` best candidates are renormalised to sum to 1 and\n",
    "    returned as space-separated `word:prob` pairs. UNK always comes last: if it is\n",
    "    not among the best candidates, it takes over the last slot's probability.\n",
    "    \"\"\"\n",
    "    first, second, third = ([UNK] * 3 + list(left_context))[-3:]\n",
    "    fifth, sixth, seventh = (list(right_context) + [UNK] * 3)[:3]\n",
    "\n",
    "    scored = []\n",
    "    for word in vocab:\n",
    "        score = (interpolate((first, second, third, word))\n",
    "                 * interpolate((second, third, word, fifth))\n",
    "                 * interpolate((third, word, fifth, sixth))\n",
    "                 * interpolate((word, fifth, sixth, seventh)))\n",
    "        scored.append((word, score))\n",
    "\n",
    "    best = sorted(scored, key=lambda item: item[1], reverse=True)[:top_k]\n",
    "    total = sum(score for _, score in best)\n",
    "    if total <= 0:\n",
    "        # No probability mass among the best candidates: give it all to UNK\n",
    "        # instead of dividing by zero.\n",
    "        return f'{UNK}:1.0'\n",
    "\n",
    "    norm = [(word, score / total) for word, score in best]\n",
    "    unk_index = next((i for i, (word, _) in enumerate(norm) if word == UNK), None)\n",
    "    if unk_index is None:\n",
    "        norm[-1] = (UNK, norm[-1][1])\n",
    "    else:\n",
    "        norm.append(norm.pop(unk_index))\n",
    "    return ' '.join(f'{word}:{prob}' for word, prob in norm)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Writing predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def execute(path):\n",
    "    \"\"\"Write one prediction line to `{path}/out.tsv` for every row of `{path}/in.tsv.xz`.\"\"\"\n",
    "    with lzma.open(f'{path}/in.tsv.xz', 'rt', encoding='utf-8') as f, \\\n",
    "            open(f'{path}/out.tsv', 'w', encoding='utf-8') as out:\n",
    "        for line in f:\n",
    "            left_text, right_text = get_contexts(line)\n",
    "            # vocab maps words to counts, so test membership: keep the word itself\n",
    "            # (not its count) and replace out-of-vocabulary words by UNK.\n",
    "            left = [word if word in vocab else UNK for word in left_text.split()[-3:]]\n",
    "            right = [word if word in vocab else UNK for word in right_text.split()[:3]]\n",
    "            out.write(f'{consider_context(left, right)}\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "execute('dev-0')\n",
    "execute('test-A')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}