def predict(model, before, after, vocabulary=None, top_k=12):
    """Rank candidate gap words by the model score of their context.

    Every candidate ``word`` is scored as the text
    ``"{before} {word} {after}"`` via
    ``model.score(text, bos=False, eos=False)``; the ``top_k``
    highest-scoring candidates are kept.

    Args:
        model: Object exposing ``score(text, bos=..., eos=...)``. The
            script entry point passes a ``kenlm.Model``.
        before: Token placed to the left of the candidate word.
        after: Token placed to the right of the candidate word.
        vocabulary: Iterable of candidate words. Defaults to
            ``english_words_alpha_set`` when ``None``.
        top_k: Maximum number of candidates to report (default 12).

    Returns:
        A space-separated string of ``word:score`` entries ordered from
        the highest score to the lowest, terminated by an empty-word entry
        ``:{log10(0.99)}``.
    """
    # Function-scope import keeps this block self-contained; the module's
    # import section does not bring in heapq.
    import heapq

    if vocabulary is None:
        vocabulary = english_words_alpha_set

    # Lazily score each candidate; nlargest keeps only top_k in memory and
    # avoids rescanning the kept list for every vocabulary word.
    scored = (
        (word, model.score(' '.join([before, word, after]), bos=False, eos=False))
        for word in vocabulary
    )
    best = heapq.nlargest(top_k, scored, key=lambda pair: pair[1])

    # NOTE(review): the numbers written are the raw values returned by
    # model.score, not normalised probabilities like the ones in the
    # PREDICTION fallback string -- confirm the consumer accepts both.
    entries = [f'{word}:{score}' for word, score in best]
    entries.append(f':{log10(0.99)}')
    return ' '.join(entries)
def make_prediction(model, path, result_path):
    """Write one prediction line per row of a tab-separated input file.

    The file at ``path`` is loaded with pandas (no header row, quoting
    disabled). For each row, columns 6 and 7 are stringified, passed
    through ``clean`` and tokenised with ``word_tokenize``; judging by the
    original variable names these are presumably the text before and after
    the gap. If either side yields fewer than two tokens the constant
    ``PREDICTION`` is written; otherwise ``predict`` is called with the
    last token of column 6 and the first token of column 7.

    Args:
        model: Language model forwarded unchanged to ``predict``.
        path: Location of the input TSV (the entry point passes ``.xz``
            files).
        result_path: Output file, opened for writing as UTF-8; one line is
            written per input row.
    """
    frame = pd.read_csv(path, sep='\t', header=None, quoting=csv.QUOTE_NONE)
    with open(result_path, 'w', encoding='utf-8') as sink:
        for _, record in frame.iterrows():
            left_tokens = word_tokenize(clean(str(record[6])))
            right_tokens = word_tokenize(clean(str(record[7])))
            enough_context = len(left_tokens) >= 2 and len(right_tokens) >= 2
            if enough_context:
                line = predict(model, left_tokens[-1], right_tokens[0])
            else:
                # Too little context on one side: use the fixed fallback.
                line = PREDICTION
            sink.write(line + '\n')


if __name__ == "__main__":
    # Run the training steps, load the resulting model, then write
    # predictions for both data splits.
    create_train_file()
    train_model()
    model = kenlm.Model('model.arpa')
    for source, target in (
        ("dev-0/in.tsv.xz", "dev-0/out.tsv"),
        ("test-A/in.tsv.xz", "test-A/out.tsv"),
    ):
        make_prediction(model, source, target)