diff --git a/run.py b/run.py index 7d50dc2..0519940 100644 --- a/run.py +++ b/run.py @@ -1,87 +1,278 @@ -from collections import defaultdict, Counter -from nltk import trigrams, word_tokenize -import csv -import regex as re -import pandas as pd - -X_train = pd.read_csv( - 'train/in.tsv.xz', - sep='\t', - header=None, - quoting=csv.QUOTE_NONE, - nrows=70000, - on_bad_lines='skip') - -Y_train = pd.read_csv( - 'train/expected.tsv', - sep='\t', - header=None, - quoting=csv.QUOTE_NONE, - nrows=70000, - on_bad_lines='skip') - -X_train = X_train[[6, 7]] -X_train = pd.concat([X_train, Y_train], axis=1) -X_train['row'] = X_train[6] + X_train[0] + X_train[7] - - -def preprocess(row): - return re.sub(r'\p{P}', '', row.lower().replace('-\\n', '').replace('\\n', ' ')) - - -def train(X_train, alpha): - model = defaultdict(lambda: defaultdict(lambda: 0)) - vocabulary = set() - - for _, (_, row) in enumerate(X_train.iterrows()): - text = preprocess(str(row['row'])) - words = word_tokenize(text) - for w1, w2, w3 in trigrams(words, pad_right=True, pad_left=True): - if w1 and w2 and w3: - model[(w1, w3)][w2] += 1 - vocabulary.add(w1) - vocabulary.add(w2) - vocabulary.add(w3) - - for _, w13 in enumerate(model): - count = float(sum(model[w13].values())) - denominator = count + alpha * len(vocabulary) - for w2 in model[w13]: - nominator = model[w13][w2] + alpha - model[w13][w2] = nominator / denominator - return model - - -def predict_word(before, after, model): - output = '' - p = 0.0 - Y_pred = dict(Counter(dict(model[before, after])).most_common(7)) - - for key, value in Y_pred.items(): - p += value - output += f'{key}:{value} ' - if p == 0.0: - output = 'the:0.04 be:0.04 to:0.04 and:0.02 not:0.02 or:0.02 a:0.02 :0.8' - return output - output += f':{max(1 - p, 0.01)}' - - return output - - -def prediction(file, model): - X_test = pd.read_csv(f'{file}/in.tsv.xz', sep='\t', header=None, quoting=csv.QUOTE_NONE, on_bad_lines='skip') - - with open(f'{file}/out.tsv', 'w', encoding='utf-8') as output_file: - for _, row in X_test.iterrows(): - before, after = word_tokenize(preprocess(str(row[6]))), word_tokenize(preprocess(str(row[7]))) - if len(before) < 2 or len(after) < 2: - output = 'the:0.04 be:0.04 to:0.04 and:0.02 not:0.02 or:0.02 a:0.02 :0.8' - else: - output = predict_word(before[-1], after[0], model) - output_file.write(output + '\n') - - -model = train(X_train, 0.0002) - -prediction('dev-0', model) -prediction('test-A', model) \ No newline at end of file +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "1da94494-ccbd-4f3c-9ca0-2241cfd9d361", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2f51e23a-93a0-4bf6-9c87-19da220e11bd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting english_words\n", + " Downloading english-words-1.1.0.tar.gz (1.1 MB)\n", + "\u001b[K |████████████████████████████████| 1.1 MB 1.5 MB/s eta 0:00:01\n", + "\u001b[?25hBuilding wheels for collected packages: english-words\n", + " Building wheel for english-words (setup.py) ... \u001b[?25ldone\n", + "\u001b[?25h Created wheel for english-words: filename=english_words-1.1.0-py3-none-any.whl size=1106680 sha256=ddaf5f4288a2022c2ce712aad0ba022e7b25d4d7e73c5637d6154abc5a899662\n", + " Stored in directory: /home/asadursk/.cache/pip/wheels/0e/24/52/b4989db82a438482aa65b3c6c0537e988fd40546b792747b1a\n", + "Successfully built english-words\n", + "Installing collected packages: english-words\n", + "Successfully installed english-words-1.1.0\n" + ] + } + ], + "source": [ + "!pip install english_words" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "d99975a7-aebe-4e26-b330-4be7f32204c5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting pypi-kenlm\n", + " Downloading pypi-kenlm-0.1.20210121.tar.gz (253 kB)\n", + "\u001b[K |████████████████████████████████| 253 kB 1.6 MB/s eta 0:00:01\n", + "\u001b[?25hBuilding wheels for collected packages: pypi-kenlm\n", + " Building wheel for pypi-kenlm (setup.py) ... \u001b[?25ldone\n", + "\u001b[?25h Created wheel for pypi-kenlm: filename=pypi_kenlm-0.1.20210121-cp39-cp39-linux_x86_64.whl size=311921 sha256=2fcde1a0b569c5d5aef6c61014559b38efc45ed4ae90357c1219816d9a5bbe9b\n", + " Stored in directory: /home/asadursk/.cache/pip/wheels/14/f0/7a/97db71356d1dc1b0c14bf48e0d01e5561d5d67ba869e4406d0\n", + "Successfully built pypi-kenlm\n", + "Installing collected packages: pypi-kenlm\n", + "Successfully installed pypi-kenlm-0.1.20210121\n" + ] + } + ], + "source": [ + "!python -m pip install pypi-kenlm" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "84560801-85f1-409b-a9c8-c209928276cc", + "metadata": {}, + "outputs": [], + "source": [ + "from collections import defaultdict, Counter\n", + "from nltk import trigrams, word_tokenize\n", + "from english_words import english_words_alpha_set\n", + "import csv\n", + "import regex as re\n", + "import pandas as pd\n", + "import kenlm\n", + "from math import log10" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "7a39272c-7929-42d8-98ba-8304570439af", + "metadata": {}, + "outputs": [], + "source": [ + "def preprocess(row):\n", + " return re.sub(r'\\p{P}', '', row.lower().replace('-\\\\\\\\n', '').replace('\\\\\\\\n', ' '))" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "2a330ad2-9b88-4fdd-bc04-635b5cb42c0d", + "metadata": {}, + "outputs": [], + "source": [ + "def kenlm_model():\n", + " with open(\"train_file.txt\", \"w+\") as f:\n", + " for text in X_train:\n", + " f.write(str(text) + \"\\n\")\n", + "\n", + " #%%\n", + " KENLM_BUILD_PATH='/home/asadursk/kenlm/build'\n", + " !$KENLM_BUILD_PATH/bin/lmplz -o 4 < train_file.txt > model.arpa\n", + " !$KENLM_BUILD_PATH/bin/build_binary model.arpa model.binary\n", + " !rm train_file.txt\n", + " \n", + " model = kenlm.Model(\"model.binary\")\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "e848ba36-f4eb-4bd6-9b19-fffea177bfa1", + "metadata": {}, + "outputs": [], + "source": [ + "def predict_word(w1, w3):\n", + " best_scores = []\n", + " for word in english_words_alpha_set:\n", + " text = ' '.join([w1, word, w3])\n", + " text_score = model.score(text, bos=False, eos=False)\n", + " if len(best_scores) < 12:\n", + " best_scores.append((word, text_score))\n", + " else:\n", + " is_better = False\n", + " worst_score = None\n", + " for score in best_scores:\n", + " if not worst_score:\n", + " worst_score = score\n", + " else:\n", + " if worst_score[1] > score[1]:\n", + " worst_score = score\n", + " if worst_score[1] < text_score:\n", + " best_scores.remove(worst_score)\n", + " best_scores.append((word, text_score))\n", + " probs = sorted(best_scores, key=lambda tup: tup[1], reverse=True)\n", + " pred_str = ''\n", + " for word, prob in probs:\n", + " pred_str += f'{word}:{prob} '\n", + " pred_str += f':{log10(0.99)}'\n", + " return pred_str" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "6babeba5-af91-4e9c-a235-781525594f45", + "metadata": {}, + "outputs": [], + "source": [ + "def word_gap_prediction(file, model):\n", + " X_test = pd.read_csv(f'{file}/in.tsv.xz', sep='\\t', header=None, quoting=csv.QUOTE_NONE, on_bad_lines=\"skip\")\n", + " with open(f'{file}/out.tsv', 'w', encoding='utf-8') as output_file:\n", + " for _, row in X_test.iterrows():\n", + " before, after = word_tokenize(preprocess(str(row[6]))), word_tokenize(preprocess(str(row[7])))\n", + " if len(before) < 2 or len(after) < 2:\n", + " output = 'to:0.015 be:0.015 the:0.015 not:0.01 and:0.02 a:0.02 :0.9'\n", + " else:\n", + " output = predict_word(before[-1], after[0])\n", + " output_file.write(output + '\\n')" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "8df4a04c-ae0d-46d7-8b76-1bcf6b424d7a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== 1/5 Counting and sorting n-grams ===\n", + "Reading /home/asadursk/challenging-america-word-gap-prediction-kenlm/train_file.txt\n", + "----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n", + "****************************************************************************************************\n", + "Unigram tokens 2787545 types 548500\n", + "=== 2/5 Calculating and sorting adjusted counts ===\n", + "Chain sizes: 1:6582000 2:865198656 3:1622247552 4:2595596032\n", + "Statistics:\n", + "1 548500 D1=0.85065 D2=1.01013 D3+=1.14959\n", + "2 1743634 D1=0.900957 D2=1.09827 D3+=1.20014\n", + "3 2511917 D1=0.957313 D2=1.22283 D3+=1.33724\n", + "4 2719775 D1=0.982576 D2=1.4205 D3+=1.65074\n", + "Memory estimate for binary LM:\n", + "type MB\n", + "probing 157 assuming -p 1.5\n", + "probing 184 assuming -r models -p 1.5\n", + "trie 82 without quantization\n", + "trie 51 assuming -q 8 -b 8 quantization \n", + "trie 74 assuming -a 22 array pointer compression\n", + "trie 43 assuming -a 22 -q 8 -b 8 array pointer compression and quantization\n", + "=== 3/5 Calculating and sorting initial probabilities ===\n", + "Chain sizes: 1:6582000 2:27898144 3:50238340 4:65274600\n", + "----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n", + "####################################################################################################\n", + "=== 4/5 Calculating and writing order-interpolated probabilities ===\n", + "Chain sizes: 1:6582000 2:27898144 3:50238340 4:65274600\n", + "----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n", + "####################################################################################################\n", + "=== 5/5 Writing ARPA model ===\n", + "----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n", + "****************************************************************************************************\n", + "Name:lmplz\tVmPeak:5126188 kB\tVmRSS:54384 kB\tRSSMax:1084112 kB\tuser:9.18382\tsys:2.72419\tCPU:11.9081\treal:9.09119\n", + "Reading model.arpa\n", + "----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n", + "****************************************************************************************************\n", + "SUCCESS\n" + ] + } + ], + "source": [ + "X_train = pd.read_csv('train/in.tsv.xz', sep='\\t', header=None, quoting=csv.QUOTE_NONE, nrows=10000, on_bad_lines=\"skip\")\n", + "Y_train = pd.read_csv('train/expected.tsv', sep='\\t', header=None, quoting=csv.QUOTE_NONE, nrows=10000, on_bad_lines=\"skip\")\n", + "\n", + "X_train = X_train[[6, 7]]\n", + "X_train = pd.concat([X_train, Y_train], axis=1)\n", + "X_train = X_train[6] + X_train[0] + X_train[7]\n", + "\n", + "model = kenlm_model()" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "5f9b4351-54b6-42de-8653-597b17c42766", + "metadata": {}, + "outputs": [], + "source": [ + "word_gap_prediction(\"dev-0/\", model)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "71076162-473b-40f2-93ab-0536a2172780", + "metadata": {}, + "outputs": [], + "source": [ + "word_gap_prediction(\"test-A/\", model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2481727e-94b5-49a0-9c21-0e105af6ef5b", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}