{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": [],
   "gpuType": "T4",
   "machine_shape": "hm"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Word-gap prediction with an interpolated 4-gram model\n",
    "\n",
    "Counts 1- to 4-grams over the training corpus (`train/in.tsv.xz`), then, for every row of `dev-0` and `test-A`, writes the most likely words for the gap between the left and right context to `out.tsv`.\n",
    "\n",
    "The notebook runs on Google Colab and reads its data from a mounted Google Drive folder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "irsty5KcyYkR"
   },
   "outputs": [],
   "source": [
    "import lzma\n",
    "import os\n",
    "import pickle\n",
    "from collections import Counter, deque\n",
    "\n",
    "from google.colab import drive"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3dV_4SJ2xY_C"
   },
   "outputs": [],
   "source": [
    "drive.mount('/content/gdrive')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VwdW1Qm3x9-N"
   },
   "outputs": [],
   "source": [
    "# Root of the data on the mounted Drive; the notebook reads train/, dev-0/, test-A/ and 5/ below it.\n",
    "DATA_DIR = '/content/gdrive/MyDrive/data_gralinski/'\n",
    "TRAIN_FILE = DATA_DIR + 'train/in.tsv.xz'\n",
    "\n",
    "# Pickled unigram, bigram and trigram counts live here.\n",
    "CACHE_DIR = DATA_DIR + '5/'\n",
    "# Tetragram counts stay in the working directory, where this notebook has always written them.\n",
    "V4_CACHE = 'V4.pickle'\n",
    "\n",
    "WORD_LIMIT = 3000          # number of most frequent words kept in the vocabulary\n",
    "UNK = 'UNK'                # token substituted for every out-of-vocabulary word\n",
    "TOP_K = 5                  # number of candidates written per gap\n",
    "\n",
    "# Interpolation weights for the 4-, 3-, 2- and 1-gram estimates, in that order.\n",
    "INTERPOLATION_WEIGHTS = (0.5, 0.25, 0.15, 0.1)\n",
    "\n",
    "# 0-based TSV columns holding the text before and after the gap.\n",
    "PREFIX_COLUMN = 6\n",
    "SUFFIX_COLUMN = 7\n",
    "\n",
    "TRAIN_LINE_COUNT = 432022  # expected line count of the training file; only drives the progress display\n",
    "PROGRESS_EVERY = 1000      # refresh the progress display every this many lines"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reading the corpus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LXXtiKW3yY5J"
   },
   "outputs": [],
   "source": [
    "def split_context(line: str):\n",
    "    \"\"\"Return the (prefix, suffix) texts surrounding the gap in one TSV row.\n",
    "\n",
    "    Literal backslash-n sequences (presumably escaped line breaks in the corpus)\n",
    "    are replaced with spaces. The suffix keeps the row's trailing newline.\n",
    "\n",
    "    Raises:\n",
    "        IndexError: if the row has fewer than SUFFIX_COLUMN + 1 columns.\n",
    "    \"\"\"\n",
    "    columns = line.split('\\t')\n",
    "    prefix = columns[PREFIX_COLUMN].replace(r'\\n', ' ')\n",
    "    suffix = columns[SUFFIX_COLUMN].replace(r'\\n', ' ')\n",
    "    return prefix, suffix\n",
    "\n",
    "\n",
    "def clean_line(line: str):\n",
    "    \"\"\"Return the prefix and suffix of one TSV row joined by a single space.\"\"\"\n",
    "    prefix, suffix = split_context(line)\n",
    "    return prefix + ' ' + suffix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "y9r0wmD3ycIi"
   },
   "outputs": [],
   "source": [
    "def _read_lines(filename, label, total_lines=TRAIN_LINE_COUNT):\n",
    "    \"\"\"Yield the lines of an xz-compressed UTF-8 text file, printing progress.\n",
    "\n",
    "    Args:\n",
    "        filename: path of the .xz file to read.\n",
    "        label: heading printed once before reading starts.\n",
    "        total_lines: line count used as the 100% mark of the progress display.\n",
    "    \"\"\"\n",
    "    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:\n",
    "        print(label)\n",
    "        for index, line in enumerate(fid, start=1):\n",
    "            # Refreshing on every line floods the output and slows the pass down.\n",
    "            if index % PROGRESS_EVERY == 0 or index == total_lines:\n",
    "                print(f'\\rProgress: {(index / total_lines * 100):.2f}%', end='')\n",
    "            yield line\n",
    "        print()\n",
    "\n",
    "\n",
    "def words(filename):\n",
    "    \"\"\"Yield every whitespace-separated token from the context columns of the file.\"\"\"\n",
    "    for line in _read_lines(filename, 'Words'):\n",
    "        yield from clean_line(line).split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HE3YfiHkycKt"
   },
   "outputs": [],
   "source": [
    "def _ngrams(filename, V: dict, n: int, label: str):\n",
    "    \"\"\"Yield every run of n consecutive tokens in the file as a tuple.\n",
    "\n",
    "    Tokens that are not keys of V are replaced by UNK. The window is reset on\n",
    "    every line, so no n-gram spans two rows.\n",
    "    \"\"\"\n",
    "    for line in _read_lines(filename, label):\n",
    "        window = deque(maxlen=n)\n",
    "        for token in clean_line(line).split():\n",
    "            window.append(token if token in V else UNK)\n",
    "            if len(window) == n:\n",
    "                yield tuple(window)\n",
    "\n",
    "\n",
    "def bigrams(filename, V: dict):\n",
    "    \"\"\"Yield (first, second) token pairs.\n",
    "\n",
    "    Unlike the earlier version, no ('', word) pair is produced for the first\n",
    "    token of a line, which matches how trigrams and tetragrams behave.\n",
    "    \"\"\"\n",
    "    return _ngrams(filename, V, 2, 'Bigrams')\n",
    "\n",
    "\n",
    "def trigrams(filename, V: dict):\n",
    "    \"\"\"Yield (first, second, third) token triples.\"\"\"\n",
    "    return _ngrams(filename, V, 3, 'Trigrams')\n",
    "\n",
    "\n",
    "def tetragrams(filename, V: dict):\n",
    "    \"\"\"Yield (first, second, third, fourth) token quadruples.\"\"\"\n",
    "    return _ngrams(filename, V, 4, 'Tetragrams')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Language model\n",
    "\n",
    "The functions below read the module-level count tables `V_common_dict`, `V2_dict`, `V3_dict`, `V4_dict` and `total`, which are created in the *Build or load the counts* section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MN_RftZNycSB"
   },
   "outputs": [],
   "source": [
    "def P(first_word, second_word=None, third_word=None, fourth_word=None):\n",
    "    \"\"\"Relative-frequency estimate of the last given word after the preceding ones.\n",
    "\n",
    "    With one word this is count(word) / total; with k words it is the count of\n",
    "    the k-gram divided by the count of its first k - 1 words.\n",
    "\n",
    "    Returns:\n",
    "        The estimate, or 0 when a required count is missing or zero.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        if second_word is None:\n",
    "            return V_common_dict[first_word] / total\n",
    "        if third_word is None:\n",
    "            return V2_dict[(first_word, second_word)] / V_common_dict[first_word]\n",
    "        if fourth_word is None:\n",
    "            return (V3_dict[(first_word, second_word, third_word)]\n",
    "                    / V2_dict[(first_word, second_word)])\n",
    "        return (V4_dict[(first_word, second_word, third_word, fourth_word)]\n",
    "                / V3_dict[(first_word, second_word, third_word)])\n",
    "    except (KeyError, ZeroDivisionError):\n",
    "        # Unseen n-gram or history: lower orders supply the probability mass.\n",
    "        return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "n9wIsbLEycUd"
   },
   "outputs": [],
   "source": [
    "def smoothed(tetragram):\n",
    "    \"\"\"Linearly interpolated probability of the last word of a 4-token tuple.\n",
    "\n",
    "    Mixes the 4-, 3-, 2- and 1-gram estimates from P using INTERPOLATION_WEIGHTS.\n",
    "    \"\"\"\n",
    "    first, second, third, fourth = tetragram\n",
    "    weight4, weight3, weight2, weight1 = INTERPOLATION_WEIGHTS\n",
    "    return (weight4 * P(first, second, third, fourth)\n",
    "            + weight3 * P(second, third, fourth)\n",
    "            + weight2 * P(third, fourth)\n",
    "            + weight1 * P(fourth))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "l490B5KFycXj"
   },
   "outputs": [],
   "source": [
    "def candidates(left_context, right_context):\n",
    "    \"\"\"Rank vocabulary words for the gap between two contexts.\n",
    "\n",
    "    Each word is scored by the product of the smoothed probabilities of the\n",
    "    four 4-token windows that contain it. The TOP_K best scores are normalised\n",
    "    to sum to 1.\n",
    "\n",
    "    Args:\n",
    "        left_context: up to three tokens immediately before the gap.\n",
    "        right_context: up to three tokens immediately after the gap.\n",
    "\n",
    "    Returns:\n",
    "        A space-separated string of word:probability entries. The final entry\n",
    "        has an empty word (presumably the catch-all for every other word); it\n",
    "        carries the share of UNK when UNK is among the best, otherwise the\n",
    "        share of the lowest-ranked candidate.\n",
    "    \"\"\"\n",
    "    # Gaps near the start or end of a text have fewer than three neighbours;\n",
    "    # pad with UNK so the fixed-size windows below can always be built.\n",
    "    left = [UNK] * (3 - len(left_context)) + list(left_context[-3:])\n",
    "    right = list(right_context[:3]) + [UNK] * (3 - len(right_context))\n",
    "    first, second, third = left\n",
    "    fifth, sixth, seventh = right\n",
    "\n",
    "    scores = {}\n",
    "    for word in V_common_dict:\n",
    "        scores[word] = (smoothed((first, second, third, word))\n",
    "                        * smoothed((second, third, word, fifth))\n",
    "                        * smoothed((third, word, fifth, sixth))\n",
    "                        * smoothed((word, fifth, sixth, seventh)))\n",
    "\n",
    "    best = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:TOP_K]\n",
    "    if not best:\n",
    "        # Empty vocabulary: all mass goes to the empty-word entry.\n",
    "        return ':1.0'\n",
    "\n",
    "    mass = sum(score for _, score in best)\n",
    "    if mass > 0:\n",
    "        norm = [(word, float(score) / mass) for word, score in best]\n",
    "    else:\n",
    "        # Every score is zero, so there is nothing to rank by: split evenly.\n",
    "        norm = [(word, 1 / len(best)) for word, _ in best]\n",
    "\n",
    "    unk_index = next((i for i, (word, _) in enumerate(norm) if word == UNK), None)\n",
    "    if unk_index is not None:\n",
    "        _, unk_share = norm.pop(unk_index)\n",
    "        norm.append(('', unk_share))\n",
    "    else:\n",
    "        norm[-1] = ('', norm[-1][1])\n",
    "    return ' '.join(f'{word}:{share}' for word, share in norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mMC84-OzycZ5"
   },
   "outputs": [],
   "source": [
    "def outputs(folder_name):\n",
    "    \"\"\"Write one prediction line per row of folder_name/in.tsv.xz to folder_name/out.tsv.\n",
    "\n",
    "    The last three prefix tokens and the first three suffix tokens of every row\n",
    "    form the context; out-of-vocabulary tokens are replaced by UNK.\n",
    "    \"\"\"\n",
    "    print(f'Creating outputs in {folder_name}')\n",
    "    with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n",
    "        with open(f'{folder_name}/out.tsv', 'w', encoding='utf-8') as f:\n",
    "            for line in fid:\n",
    "                prefix, suffix = split_context(line)\n",
    "                left_context = [x if x in V_common_dict else UNK for x in prefix.split()[-3:]]\n",
    "                right_context = [x if x in V_common_dict else UNK for x in suffix.split()[:3]]\n",
    "                f.write(candidates(left_context, right_context) + '\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build or load the counts\n",
    "\n",
    "Each table is read from its pickle when the file exists and is otherwise computed from the training corpus and saved, so a full re-run only pays for what is missing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_or_build(cache_path, build):\n",
    "    \"\"\"Return the object pickled at cache_path, creating the file with build() if it is missing.\n",
    "\n",
    "    Args:\n",
    "        cache_path: path of the pickle file.\n",
    "        build: zero-argument callable producing the object when no cache exists.\n",
    "    \"\"\"\n",
    "    if os.path.exists(cache_path):\n",
    "        # NOTE: pickle.load can execute arbitrary code; only load files you created yourself.\n",
    "        with open(cache_path, 'rb') as handle:\n",
    "            return pickle.load(handle)\n",
    "    result = build()\n",
    "    os.makedirs(os.path.dirname(cache_path) or '.', exist_ok=True)\n",
    "    with open(cache_path, 'wb') as handle:\n",
    "        pickle.dump(result, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
    "    return result\n",
    "\n",
    "\n",
    "def build_vocabulary():\n",
    "    \"\"\"Count the WORD_LIMIT most frequent training words; all other occurrences are summed under UNK.\"\"\"\n",
    "    counts = Counter(words(TRAIN_FILE))\n",
    "    vocabulary = dict(counts.most_common(WORD_LIMIT))\n",
    "    vocabulary[UNK] = sum(counts.values()) - sum(vocabulary.values())\n",
    "    return vocabulary\n",
    "\n",
    "\n",
    "def count_ngrams(generator):\n",
    "    \"\"\"Count the tuples yielded by generator(TRAIN_FILE, V_common_dict) into a plain dict.\"\"\"\n",
    "    return dict(Counter(generator(TRAIN_FILE, V_common_dict)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Fsvv3QJl7kWN"
   },
   "outputs": [],
   "source": [
    "V_common_dict = load_or_build(CACHE_DIR + 'V.pickle', build_vocabulary)\n",
    "total = sum(V_common_dict.values())\n",
    "assert total > 0, 'vocabulary counts are empty'\n",
    "\n",
    "V2_dict = load_or_build(CACHE_DIR + 'V2.pickle', lambda: count_ngrams(bigrams))\n",
    "V3_dict = load_or_build(CACHE_DIR + 'V3.pickle', lambda: count_ngrams(trigrams))\n",
    "V4_dict = load_or_build(V4_CACHE, lambda: count_ngrams(tetragrams))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Write the predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "UK73WsKnB8ZP"
   },
   "outputs": [],
   "source": [
    "for split_name in ('dev-0', 'test-A'):\n",
    "    outputs(DATA_DIR + split_name)"
   ]
  }
 ]
}