diff --git a/tetragram_final.ipynb b/tetragram_final.ipynb new file mode 100644 index 0000000..abcc21d --- /dev/null +++ b/tetragram_final.ipynb @@ -0,0 +1,301 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import lzma\n", + "import pickle\n", + "import os\n", + "from collections import Counter" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_line(line):\n", + " parts = line.split('\\t')\n", + " prefix = parts[6].replace(r'\\n', ' ')\n", + " suffix = parts[7].replace(r'\\n', ' ')\n", + " return prefix + ' ' + suffix" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def read_words(path):\n", + " with lzma.open(path, 'rt', encoding='utf-8') as f:\n", + " for line in f:\n", + " text = get_line(line)\n", + " for word in text.split():\n", + " yield word" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from collections import defaultdict\n", + "import bisect\n", + "import itertools\n", + "\n", + "\n", + "def ngrams(words, n=1):\n", + " ngrams_counts = defaultdict(int)\n", + " sum_counts = defaultdict(int)\n", + " \n", + " for i in range(len(words) - n + 1):\n", + " ngram = tuple(words[i:i+n])\n", + " ngrams_counts[ngram] += 1\n", + " for j in range(1, n):\n", + " sum_counts[ngram[:j]] += 1\n", + " \n", + " for key, value in ngrams_counts.items():\n", + " sum_counts[key[:-1]] += value\n", + " \n", + " return ngrams_counts, sum_counts\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vocab_size = 2000\n", + "if os.path.exists('model/vocab_top_1000.pkl'):\n", + " with open('model/vocab_top_1000.pkl', 'rb') as f:\n", + " vocab_top_1000 = pickle.load(f)\n", + "else:\n", + " counter = Counter(read_words('train/in.tsv.xz'))\n", + " vocab_top_1000 = dict(counter.most_common(vocab_size))\n", + " unk = 0\n", + " for word, count in counter.items():\n", + " if not vocab_top_1000.get(word):\n", + " unk += count\n", + " vocab_top_1000[''] = unk\n", + " with open('model/vocab_top_1000.pkl', 'wb') as f:\n", + " pickle.dump(vocab_top_1000, f, protocol=pickle.HIGHEST_PROTOCOL)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def ngrams(filename, V: dict, n: int):\n", + " with lzma.open(filename, mode='rt', encoding='utf-8') as fid:\n", + " print(f'{n}-grams')\n", + " ngram_func = {\n", + " 1: lambda w: (w,),\n", + " 2: lambda w1, w2: (w1, w2),\n", + " 3: lambda w1, w2, w3: (w1, w2, w3),\n", + " 4: lambda w1, w2, w3, w4: (w1, w2, w3, w4),\n", + " }[n]\n", + " for line in fid:\n", + " text = get_line(line)\n", + " words = [''] * (n-1)\n", + " for word in text.split():\n", + " if V.get(word) is None:\n", + " word = ''\n", + " words.append(word)\n", + " yield ngram_func(*words[-n:])\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if os.path.exists('model/bigrams.pkl'):\n", + " with open('model/bigrams.pkl', 'rb') as f:\n", + " bigrams = pickle.load(f)\n", + "else:\n", + " bigrams = Counter(ngrams('train/in.tsv.xz', vocab_top_1000, 2))\n", + " with open('model/bigrams.pkl', 'wb') as f:\n", + " pickle.dump(bigrams, f, protocol=pickle.HIGHEST_PROTOCOL)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if os.path.exists('model/trigrams.pkl'):\n", + " with open('model/trigrams.pkl', 'rb') as f:\n", + " trigrams = pickle.load(f)\n", + "else:\n", + " trigrams = Counter(ngrams('train/in.tsv.xz', vocab_top_1000, 3))\n", + " with open('model/trigrams.pkl', 'wb') as f:\n", + " pickle.dump(trigrams, f, protocol=pickle.HIGHEST_PROTOCOL)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if os.path.exists('model/tetragrams.pkl'):\n", + " with open('model/tetragrams.pkl', 'rb') as f:\n", + " tetragrams = pickle.load(f)\n", + "else:\n", + " tetragrams = Counter(ngrams('train/in.tsv.xz', vocab_top_1000, 4))\n", + " with open('model/tetragrams.pkl', 'wb') as f:\n", + " pickle.dump(tetragrams, f, protocol=pickle.HIGHEST_PROTOCOL)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from collections import defaultdict\n", + "import bisect\n", + "import itertools\n", + "\n", + "def probability(first, second=None, third=None, fourth=None):\n", + "# Unigram\n", + " if not second:\n", + " return vocab_top_1000.get(first, 0) / sum(vocab_top_1000.values())\n", + " \n", + " # Bigram\n", + " bigram_key = (first, second)\n", + " if bigram_key in bigrams:\n", + " if not third:\n", + " return bigrams[bigram_key] / vocab_top_1000.get(first, 0)\n", + " \n", + " # Trigram\n", + " trigram_key = (first, second, third)\n", + " if trigram_key in trigrams:\n", + " if not fourth:\n", + " return trigrams[trigram_key] / bigrams.get(bigram_key, 0)\n", + " \n", + " # Tetragram\n", + " tetragram_key = (first, second, third, fourth)\n", + " if tetragram_key in tetragrams:\n", + " return tetragrams[tetragram_key] / trigrams.get(trigram_key, 0)\n", + " \n", + " # Key not found\n", + " return 0\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def interpolate(tetragram):\n", + " first, second, third, fourth = tetragram\n", + " if first and second and third and fourth:\n", + " return 0.4 * probability(first, second, third, fourth) + 0.3 * probability(second, third, fourth) + 0.2 * probability(third, fourth) + 0.1 * probability(fourth)\n", + " elif first and second and third:\n", + " return 0.5 * probability(first, second, third) + 0.3 * probability(second, third) + 0.2 * probability(third)\n", + " elif first and second:\n", + " return 0.6 * probability(first, second) + 0.4 * probability(second) \n", + " return probability(first)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def consider_context(left_context, right_context):\n", + " first, second, third = left_context\n", + " fifth, sixth, seventh = right_context\n", + " \n", + " probs = []\n", + " for word in vocab_top_1000:\n", + " p1 = interpolate((first, second, third, word))\n", + " p2 = interpolate((second, third, word, fifth))\n", + " p3 = interpolate((third, word, fifth, sixth)) \n", + " p4 = interpolate((word, fifth, sixth, seventh))\n", + " prob = p1 * p2 * p3 * p4\n", + " probs.append((word, prob))\n", + " \n", + " probs = sorted(probs, key=lambda x: x[1], reverse=True)[:5]\n", + " total_prob = sum(prob for _, prob in probs)\n", + " \n", + " norm = [(word, prob/total_prob) for word, prob in probs]\n", + " for index, elem in enumerate(norm):\n", + " if elem[0] == '':\n", + " norm.pop(index)\n", + " norm.append(('', elem[1]))\n", + " break\n", + " else:\n", + " norm[-1] = ('', norm[-1][1])\n", + " \n", + " return ' '.join([f'{x[0]}:{x[1]}' for x in norm])\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def execute(path):\n", + " with lzma.open(f'{path}/in.tsv.xz', 'rt', encoding='utf-8') as f, \\\n", + " open(f'{path}/out.tsv', 'w', encoding='utf-8') as out:\n", + " for line in f:\n", + " prefix, suffix = line.split('\\t')[6:8]\n", + " prefix = prefix.replace(r'\\n', ' ').split()[-3:]\n", + " suffix = suffix.replace(r'\\n', ' ').split()[:3]\n", + " left = [vocab_top_1000.get(x, '') for x in prefix]\n", + " right = [vocab_top_1000.get(x, '') for x in suffix]\n", + " result = consider_context(left, right)\n", + " out.write(f\"{result}\\n\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "execute('dev-0')\n", + "execute('test-A')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}