kenlm 434766
This commit is contained in:
parent
6eb5a5160f
commit
da9a7ccd36
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
180
kenlm.ipynb
Normal file
180
kenlm.ipynb
Normal file
@ -0,0 +1,180 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "738b7e97",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from collections import defaultdict, Counter\n",
|
||||
"from nltk import trigrams, word_tokenize\n",
|
||||
"from english_words import english_words_alpha_set\n",
|
||||
"import csv\n",
|
||||
"import regex as re\n",
|
||||
"import pandas as pd\n",
|
||||
"import kenlm\n",
|
||||
"from math import log10\n",
|
||||
"\n",
|
||||
"def preprocess(row):\n",
|
||||
" row = re.sub(r'\\p{P}', '', row.lower().replace('-\\\\n', '').replace('\\\\n', ' '))\n",
|
||||
" return row\n",
|
||||
"\n",
|
||||
"def kenlm_model():\n",
|
||||
" with open(\"train_file.txt\", \"w+\") as f:\n",
|
||||
" for text in X_train:\n",
|
||||
" f.write(str(text) + \"\\n\")\n",
|
||||
"\n",
|
||||
" #%%\n",
|
||||
" KENLM_BUILD_PATH='/home/przemek/kenlm/build'\n",
|
||||
" !$KENLM_BUILD_PATH/bin/lmplz -o 4 < train_file.txt > model.arpa\n",
|
||||
" !$KENLM_BUILD_PATH/bin/build_binary model.arpa model.binary\n",
|
||||
" !rm train_file.txt\n",
|
||||
" \n",
|
||||
" model = kenlm.Model(\"model.binary\")\n",
|
||||
" return model\n",
|
||||
"\n",
|
||||
"def predict_word(w1, w3):\n",
|
||||
" best_scores = []\n",
|
||||
" for word in english_words_alpha_set:\n",
|
||||
" text = ' '.join([w1, word, w3])\n",
|
||||
" text_score = model.score(text, bos=False, eos=False)\n",
|
||||
" if len(best_scores) < 12:\n",
|
||||
" best_scores.append((word, text_score))\n",
|
||||
" else:\n",
|
||||
" is_better = False\n",
|
||||
" worst_score = None\n",
|
||||
" for score in best_scores:\n",
|
||||
" if not worst_score:\n",
|
||||
" worst_score = score\n",
|
||||
" else:\n",
|
||||
" if worst_score[1] > score[1]:\n",
|
||||
" worst_score = score\n",
|
||||
" if worst_score[1] < text_score:\n",
|
||||
" best_scores.remove(worst_score)\n",
|
||||
" best_scores.append((word, text_score))\n",
|
||||
" probs = sorted(best_scores, key=lambda tup: tup[1], reverse=True)\n",
|
||||
" pred_str = ''\n",
|
||||
" for word, prob in probs:\n",
|
||||
" pred_str += f'{word}:{prob} '\n",
|
||||
" pred_str += f':{log10(0.99)}'\n",
|
||||
" return pred_str\n",
|
||||
" \n",
|
||||
"def word_gap_prediction(file, model):\n",
|
||||
" X_test = pd.read_csv(f'{file}/in.tsv.xz', sep='\\t', header=None, quoting=csv.QUOTE_NONE, error_bad_lines=False)\n",
|
||||
" with open(f'{file}/out.tsv', 'w', encoding='utf-8') as output_file:\n",
|
||||
" for _, row in X_test.iterrows():\n",
|
||||
" before, after = word_tokenize(preprocess(str(row[6]))), word_tokenize(preprocess(str(row[7])))\n",
|
||||
" if len(before) < 2 or len(after) < 2:\n",
|
||||
" output = 'the:0.04 be:0.04 to:0.04 and:0.02 not:0.02 or:0.02 a:0.02 :0.8'\n",
|
||||
" else:\n",
|
||||
" output = predict_word(before[-1], after[0])\n",
|
||||
" output_file.write(output + '\\n')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "4bb5fbab",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"=== 1/5 Counting and sorting n-grams ===\n",
|
||||
"Reading /home/przemek/challenging-america-word-gap-prediction/train_file.txt\n",
|
||||
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n",
|
||||
"****************************************************************************************************\n",
|
||||
"Unigram tokens 2787545 types 548500\n",
|
||||
"=== 2/5 Calculating and sorting adjusted counts ===\n",
|
||||
"Chain sizes: 1:6582000 2:1755888000 3:3292290048 4:5267664384\n",
|
||||
"Statistics:\n",
|
||||
"1 548500 D1=0.85065 D2=1.01013 D3+=1.14959\n",
|
||||
"2 1743634 D1=0.900957 D2=1.09827 D3+=1.20014\n",
|
||||
"3 2511917 D1=0.957313 D2=1.22283 D3+=1.33724\n",
|
||||
"4 2719775 D1=0.982576 D2=1.4205 D3+=1.65074\n",
|
||||
"Memory estimate for binary LM:\n",
|
||||
"type MB\n",
|
||||
"probing 157 assuming -p 1.5\n",
|
||||
"probing 184 assuming -r models -p 1.5\n",
|
||||
"trie 82 without quantization\n",
|
||||
"trie 51 assuming -q 8 -b 8 quantization \n",
|
||||
"trie 74 assuming -a 22 array pointer compression\n",
|
||||
"trie 43 assuming -a 22 -q 8 -b 8 array pointer compression and quantization\n",
|
||||
"=== 3/5 Calculating and sorting initial probabilities ===\n",
|
||||
"Chain sizes: 1:6582000 2:27898144 3:50238340 4:65274600\n",
|
||||
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n",
|
||||
"####################################################################################################\n",
|
||||
"=== 4/5 Calculating and writing order-interpolated probabilities ===\n",
|
||||
"Chain sizes: 1:6582000 2:27898144 3:50238340 4:65274600\n",
|
||||
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n",
|
||||
"####################################################################################################\n",
|
||||
"=== 5/5 Writing ARPA model ===\n",
|
||||
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n",
|
||||
"****************************************************************************************************\n",
|
||||
"Name:lmplz\tVmPeak:10263812 kB\tVmRSS:54588 kB\tRSSMax:2112048 kB\tuser:4.74374\tsys:1.56732\tCPU:6.31112\treal:4.63458\n",
|
||||
"Reading model.arpa\n",
|
||||
"----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n",
|
||||
"****************************************************************************************************\n",
|
||||
"SUCCESS\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"in_file = 'train/in.tsv.xz'\n",
|
||||
"out_file = 'train/expected.tsv'\n",
|
||||
"\n",
|
||||
"X_train = pd.read_csv(in_file, sep='\\t', header=None, quoting=csv.QUOTE_NONE, nrows=10000, error_bad_lines=False)\n",
|
||||
"Y_train = pd.read_csv(out_file, sep='\\t', header=None, quoting=csv.QUOTE_NONE, nrows=10000, error_bad_lines=False)\n",
|
||||
"\n",
|
||||
"X_train = X_train[[6, 7]]\n",
|
||||
"X_train = pd.concat([X_train, Y_train], axis=1)\n",
|
||||
"X_train = X_train[6] + X_train[0] + X_train[7]\n",
|
||||
"\n",
|
||||
"model = kenlm_model()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "93921d97",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"word_gap_prediction(\"dev-0/\", model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cdfad0ab",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"word_gap_prediction(\"test-A/\", model)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
BIN
model.binary
Normal file
BIN
model.binary
Normal file
Binary file not shown.
14828
test-A/out.tsv
14828
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user