# 7.0 KiB  (notebook-export size artifact — commented out so the file parses)
# 7.0 KiB
import csv
import os
import subprocess
from collections import defaultdict, Counter
from math import log10

import kenlm
import pandas as pd
import regex as re
from english_words import english_words_alpha_set
from nltk import trigrams, word_tokenize
def preprocess(row):
    """Normalise one raw text row: lowercase, drop escaped line breaks,
    and strip all Unicode punctuation (regex ``\\p{P}`` class)."""
    cleaned = row.lower().replace('-\\\\n', '').replace('\\\\n', ' ')
    return re.sub(r'\p{P}', '', cleaned)
def kenlm_model():
    """Train a 4-gram KenLM model on the global X_train texts and return it.

    Dumps the training texts to a temporary ``train_file.txt``, runs KenLM's
    ``lmplz`` and ``build_binary`` tools, deletes the temp file, and loads
    the resulting binary model.

    Returns:
        kenlm.Model: the trained 4-gram language model.

    Raises:
        subprocess.CalledProcessError: if either KenLM tool exits non-zero.
    """
    with open("train_file.txt", "w+") as f:
        for text in X_train:
            f.write(str(text) + "\n")
    # The notebook original used IPython '!' shell magics here; subprocess
    # calls keep this file runnable as plain Python.
    kenlm_build_path = '/home/przemek/kenlm/build'
    # lmplz reads the corpus on stdin and writes the ARPA model to stdout.
    with open("train_file.txt") as corpus, open("model.arpa", "w") as arpa:
        subprocess.run([f'{kenlm_build_path}/bin/lmplz', '-o', '4'],
                       stdin=corpus, stdout=arpa, check=True)
    subprocess.run([f'{kenlm_build_path}/bin/build_binary',
                    'model.arpa', 'model.binary'], check=True)
    os.remove("train_file.txt")
    return kenlm.Model("model.binary")
def predict_word(w1, w3):
    """Predict the word most likely to fill the gap between w1 and w3.

    Every candidate from ``english_words_alpha_set`` is scored as the middle
    token of ``w1 <candidate> w3`` with the global KenLM ``model``; the 12
    best are kept.

    Args:
        w1: last token of the left context.
        w3: first token of the right context.

    Returns:
        str: challenge output line ``'word:logprob ... :residual'``, where the
        trailing ``:<logprob>`` term assigns leftover mass to any other word.
    """
    best_scores = []
    for word in english_words_alpha_set:
        text = ' '.join([w1, word, w3])
        text_score = model.score(text, bos=False, eos=False)
        if len(best_scores) < 12:
            best_scores.append((word, text_score))
        else:
            # Replace the current worst candidate when the new score beats it
            # (min() replaces the original hand-rolled None-sentinel scan).
            worst = min(best_scores, key=lambda pair: pair[1])
            if worst[1] < text_score:
                best_scores.remove(worst)
                best_scores.append((word, text_score))
    probs = sorted(best_scores, key=lambda pair: pair[1], reverse=True)
    # Single join instead of repeated '+=' string concatenation.
    pred_str = ' '.join(f'{word}:{prob}' for word, prob in probs)
    pred_str += f' :{log10(0.99)}'
    return pred_str
def word_gap_prediction(file, model):
    """Read ``<file>/in.tsv.xz`` and write one gap prediction per row to
    ``<file>/out.tsv``.

    Columns 6 and 7 of the input hold the left and right context texts.
    NOTE(review): the ``model`` parameter is not used here — ``predict_word``
    reads the module-level ``model`` instead; confirm whether threading it
    through was intended.
    """
    X_test = pd.read_csv(f'{file}/in.tsv.xz', sep='\t', header=None, quoting=csv.QUOTE_NONE, error_bad_lines=False)
    with open(f'{file}/out.tsv', 'w', encoding='utf-8') as output_file:
        for _, row in X_test.iterrows():
            left_tokens = word_tokenize(preprocess(str(row[6])))
            right_tokens = word_tokenize(preprocess(str(row[7])))
            if len(left_tokens) >= 2 and len(right_tokens) >= 2:
                prediction = predict_word(left_tokens[-1], right_tokens[0])
            else:
                # Context too short to score a trigram: emit a fixed default
                # distribution instead.
                prediction = 'the:0.04 be:0.04 to:0.04 and:0.02 not:0.02 or:0.02 a:0.02 :0.8'
            output_file.write(prediction + '\n')
in_file = 'train/in.tsv.xz'
out_file = 'train/expected.tsv'
# Load the first 10k training rows: columns 6/7 of in.tsv.xz are the left and
# right context texts; expected.tsv (column 0) holds the gap word.
X_train = pd.read_csv(in_file, sep='\t', header=None, quoting=csv.QUOTE_NONE, nrows=10000, error_bad_lines=False)
Y_train = pd.read_csv(out_file, sep='\t', header=None, quoting=csv.QUOTE_NONE, nrows=10000, error_bad_lines=False)
X_train = X_train[[6, 7]]
X_train = pd.concat([X_train, Y_train], axis=1)
# Join left context, gap word, and right context with explicit spaces — the
# original plain concatenation glued the gap word directly onto the last/first
# tokens of its contexts, corrupting the n-gram counts.
# NOTE(review): training text is not passed through preprocess() while test
# text is — confirm whether that mismatch is intended.
X_train = X_train[6] + ' ' + X_train[0] + ' ' + X_train[7]
model = kenlm_model()
# (Captured KenLM lmplz/build_binary console output removed here — it was
# notebook session residue, not Python code.)
# Produce out.tsv predictions for both evaluation splits.
for split in ("dev-0/", "test-A/"):
    word_gap_prediction(split, model)