{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "kenlm.ipynb", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "code", "source": [ "from google.colab import drive\n", "drive.mount('/content/gdrive')" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "GQG8KfEo5BwV", "outputId": "7899949c-5bc3-4d13-acb2-88aa47f46655" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Mounted at /content/gdrive\n" ] } ] }, { "cell_type": "code", "source": [ "!pip install https://github.com/kpu/kenlm/archive/master.zip" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "GsoWSBmH5DT3", "outputId": "f67d798f-54f8-4c90-bdef-590424b49dd5" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Collecting https://github.com/kpu/kenlm/archive/master.zip\n", " Using cached https://github.com/kpu/kenlm/archive/master.zip (550 kB)\n" ] } ] }, { "cell_type": "code", "source": [ "!pip install english_words" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "rwNPsafM6KSb", "outputId": "b4e21df6-cf55-4f7a-843c-a87f1acc6082" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Collecting english_words\n", " Downloading english-words-1.1.0.tar.gz (1.1 MB)\n", "\u001b[K |████████████████████████████████| 1.1 MB 5.4 MB/s \n", "\u001b[?25hBuilding wheels for collected packages: english-words\n", " Building wheel for english-words (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", " Created wheel for english-words: filename=english_words-1.1.0-py3-none-any.whl size=1106680 sha256=9959ed5d02a4c06063019ede18eebf1ef1be2562a62aa85f86a13d6a3fe1e34b\n", " Stored in directory: /root/.cache/pip/wheels/25/3d/4c/12a119ce90b46b4f90f9ddf41d719ecabb40faec6103379fc8\n", "Successfully built english-words\n", "Installing collected packages: english-words\n", "Successfully installed english-words-1.1.0\n" ] } ] }, { "cell_type": "code", "source": [ "import nltk\n", "nltk.download(\"punkt\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "02yP2lJ9_4dT", "outputId": "5de6ad9b-41e0-4577-9af3-4ceefe85f3d0" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "[nltk_data] Downloading package punkt to /root/nltk_data...\n", "[nltk_data] Unzipping tokenizers/punkt.zip.\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "True" ] }, "metadata": {}, "execution_count": 12 } ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "tt_ucItY484I", "outputId": "e2839c64-b3b9-42fb-c2cf-dc7dc60ad8ab" }, "outputs": [], "source": [ "import pandas as pd\n", "import csv\n", "import regex as re\n", "import kenlm\n", "from english_words import english_words_alpha_set\n", "from nltk import word_tokenize\n", "from math import log10\n", "from pathlib import Path\n", "import os\n", "import numpy as np\n", "\n", "\n", "KENLM_BUILD_PATH = Path(\"gdrive/My Drive/gonito/kenlm/build\")\n", "KENLM_LMPLZ_PATH = KENLM_BUILD_PATH / \"bin\" / \"lmplz\"\n", "KENLM_BUILD_BINARY_PATH = KENLM_BUILD_PATH / \"bin\" / \"build_binary\"\n", "# Fallback distribution over frequent English words, used when the context is too short to query the model.\n", "PREDICTION = 'the:0.03 be:0.03 to:0.03 of:0.025 and:0.025 a:0.025 in:0.020 that:0.020 have:0.015 I:0.010 it:0.010 for:0.010 not:0.010 on:0.010 with:0.010 he:0.010 as:0.010 you:0.010 do:0.010 at:0.010 :0.77'\n", "\n", "\n", "def clean(text):\n", "    # Lowercase, rejoin words hyphenated across line breaks (literal '-\\\\n' in the data),\n", "    # flatten escaped newlines, and strip all punctuation.\n", "    text = str(text).lower().replace(\"-\\\\n\", \"\").replace(\"\\\\n\", \" \")\n", "    return re.sub(r\"\\p{P}\", \"\", text)\n", "\n", "\n", "def create_train_data():\n", "    # Columns 6 and 7 of in.tsv.xz hold the left and right context; expected.tsv holds the gap word.\n", "    data = pd.read_csv(\"gdrive/My Drive/gonito/train/in.tsv.xz\", sep=\"\\t\", on_bad_lines=\"skip\", header=None, quoting=csv.QUOTE_NONE, nrows=50000)\n", "    train_labels = pd.read_csv(\"gdrive/My Drive/gonito/train/expected.tsv\", sep=\"\\t\", on_bad_lines=\"skip\", header=None, quoting=csv.QUOTE_NONE, nrows=50000)\n", "\n", "    train_data = data[[6, 7]]\n", "    train_data = pd.concat([train_data, train_labels], axis=1)\n", "\n", "    # Join left context, gap word and right context with spaces so tokens stay separated.\n", "    return train_data[6] + \" \" + train_data[0] + \" \" + train_data[7]\n", "\n", "\n", 
"def create_train_file(filename=\"gdrive/My Drive/gonito/train.txt\"):\n", "    with open(filename, \"w\") as f:\n", "        for line in create_train_data():\n", "            f.write(clean(line) + \"\\n\")\n", "\n", "\n", "def train_model():\n", "    # Build a 4-gram ARPA model with lmplz, then compile it to KenLM's binary format.\n", "    # All paths are quoted because \"My Drive\" contains a space, and the same\n", "    # Drive paths are used throughout so training and loading stay consistent.\n", "    lmplz_command = f'\"{KENLM_LMPLZ_PATH}\" -o 4 < \"gdrive/My Drive/gonito/train.txt\" > \"gdrive/My Drive/gonito/model.arpa\"'\n", "    build_binary_command = f'\"{KENLM_BUILD_BINARY_PATH}\" \"gdrive/My Drive/gonito/model.arpa\" \"gdrive/My Drive/gonito/model.binary\"'\n", "    os.system(lmplz_command)\n", "    os.system(build_binary_command)\n", "\n", "\n", "def softmax(x):\n", "    # Numerically stable softmax (kept for experimentation; not used below).\n", "    e_x = np.exp(x - np.max(x))\n", "    return e_x / e_x.sum(axis=0)\n", "\n", "\n", "def predict(model, before, after):\n", "    # Score every candidate word in its immediate context and keep the 12 best.\n", "    best_scores = []\n", "    for word in english_words_alpha_set:\n", "        text = ' '.join([before, word, after])\n", "        text_score = model.score(text, bos=False, eos=False)\n", "        if len(best_scores) < 12:\n", "            best_scores.append((word, text_score))\n", "        else:\n", "            worst_score = min(best_scores, key=lambda tup: tup[1])\n", "            if worst_score[1] < text_score:\n", "                best_scores.remove(worst_score)\n", "                best_scores.append((word, text_score))\n", "    best_scores.sort(key=lambda tup: tup[1], reverse=True)\n", "    # Emit the challenge's gap-filling format: word:score pairs plus a remainder entry.\n", "    pred_str = ' '.join(f'{word}:{score}' for word, score in best_scores)\n", "    pred_str += f' :{log10(0.99)}'\n", "    return pred_str\n", "\n", "\n", "def make_prediction(model, path, result_path):\n", "    data = pd.read_csv(path, sep='\\t', header=None, quoting=csv.QUOTE_NONE)\n", "    with open(result_path, 'w', encoding='utf-8') as file_out:\n", "        for _, row in data.iterrows():\n", "            before, after = word_tokenize(clean(str(row[6]))), word_tokenize(clean(str(row[7])))\n", "            if len(before) < 2 or len(after) < 2:\n", "                pred = PREDICTION\n", "            else:\n", "                pred = predict(model, before[-1], after[0])\n", "            file_out.write(pred + '\\n')\n", "\n", "\n", "create_train_file()\n", "train_model()\n", "model = kenlm.Model('gdrive/My Drive/gonito/model.binary')\n", "make_prediction(model, \"gdrive/My Drive/gonito/dev-0/in.tsv.xz\", \"gdrive/My Drive/gonito/dev-0/out.tsv\")\n", "make_prediction(model, \"gdrive/My Drive/gonito/test-A/in.tsv.xz\", \"gdrive/My Drive/gonito/test-A/out.tsv\")" ] },
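{ "cell_type": "markdown", "metadata": {}, "source": [ "As a quick sanity check, the cell below is a minimal illustrative sketch that was not part of the original pipeline: it scores a few made-up candidate words in a fixed context with the trained model. `kenlm.Model.score` returns a log10 score for the whole string, and `bos`/`eos` are disabled to match how `predict` calls it." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Minimal usage sketch (assumes `model` was loaded in the cell above;\n", "# the context and candidate words here are made-up examples).\n", "for word in ['house', 'dog', 'the']:\n", "    text = f'the old {word} was'\n", "    print(word, model.score(text, bos=False, eos=False))" ] } ] }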