From 8c25eb8da80d0161cec69057eae8a6d469ebb9d1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Szymon=20Parafin=CC=81ski?=
Date: Tue, 25 Apr 2023 00:27:37 +0200
Subject: [PATCH] kenLM #3

---
 lab6/kenlm.ipynb     | 97 +++++++++++++++++++++++++++++++++++++++++++
 lab6/kenlm_script.py | 99 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 196 insertions(+)
 create mode 100644 lab6/kenlm.ipynb
 create mode 100644 lab6/kenlm_script.py

diff --git a/lab6/kenlm.ipynb b/lab6/kenlm.ipynb
new file mode 100644
index 0000000..cc3b4b9
--- /dev/null
+++ b/lab6/kenlm.ipynb
@@ -0,0 +1,97 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!xzcat -f1 ../train/in.tsv.xz | cut -f7,8 | sed 's/-\\\\n/ /g' | sed 's/\\\\n//g' | sed 's/\\\\//g' | ../../kenlm/build/bin/lmplz -o 3 > kenlm_model.arpa\n",
+    "!../../kenlm/build/bin/build_binary kenlm_model.arpa kenlm_model.binary"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "pycharm": {
+     "is_executing": true
+    }
+   },
+   "outputs": [],
+   "source": [
+    "import re\n",
+    "\n",
+    "CONTRACTIONS = {\n",
+    "    \"I'm\": \"I am\",\n",
+    "    \"you're\": \"you are\",\n",
+    "    \"he's\": \"he is\",\n",
+    "    \"she's\": \"she is\",\n",
+    "    \"it's\": \"it is\",\n",
+    "    \"we're\": \"we are\",\n",
+    "    \"they're\": \"they are\",\n",
+    "    \"aren't\": \"are not\",\n",
+    "    \"don't\": \"do not\",\n",
+    "    \"doesn't\": \"does not\",\n",
+    "    \"weren't\": \"were not\",\n",
+    "    \"'ll\": \" will\",\n",
+    "}\n",
+    "\n",
+    "\n",
+    "def formalize_text(text):\n",
+    "    # Replace contractions using regular expressions\n",
+    "    pattern = re.compile(r'\\b(' + '|'.join(CONTRACTIONS.keys()) + r')\\b')\n",
+    "    text = pattern.sub(lambda x: CONTRACTIONS[x.group()], text)\n",
+    "\n",
+    "    # Remove hyphens at the end of lines and replace newlines with spaces\n",
+    "    text = text.replace('-\\n', '')\n",
+    "    text = text.replace('\\n', ' ')\n",
+    "\n",
+    "    return text\n",
+    "\n",
+    "\n",
+    "def clean_string(text):\n",
+    "    text = formalize_text(text)\n",
+    "    text = re.sub(r\" -\\\\*\\\\n\", \"\", text)\n",
+    "    text = re.sub(r\"\\\\n\", \" \", text)\n",
+    "    text = text.strip()\n",
+    "    return text\n",
+    "\n",
+    "\n",
+    "train_text = \"\"\n",
+    "print(\"Reading train data...\")\n",
+    "with open(\"../train/in.tsv\", encoding=\"utf8\", mode=\"rt\") as file, open(\"../train/expected.tsv\", encoding=\"utf8\", mode=\"rt\") as expected:\n",
+    "    for t_line, e_line in zip(file, expected):\n",
+    "        t_line = t_line.split(\"\\t\")\n",
+    "        train_text += clean_string(t_line[-2]) + f\" {clean_string(e_line)} \" + clean_string(t_line[-1])\n",
+    "\n",
+    "# save train_text to file\n",
+    "print(\"saving to file...\")\n",
+    "with open(\"train_text.txt\", encoding=\"utf8\", mode=\"w\") as file:\n",
+    "    file.write(train_text)\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "python11",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.3"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/lab6/kenlm_script.py b/lab6/kenlm_script.py
new file mode 100644
index 0000000..a7c28e3
--- /dev/null
+++ b/lab6/kenlm_script.py
@@ -0,0 +1,99 @@
+from tqdm import tqdm
+import regex as re
+from english_words import get_english_words_set
+import kenlm
+import pickle
+import math
+import numpy as np
+
+path = 'kenlm_model.binary'
+model = kenlm.Model(path)
+
+CONTRACTIONS = {
+    "I'm": "I am",
+    "you're": "you are",
+    "he's": "he is",
+    "she's": "she is",
+    "it's": "it is",
+    "we're": "we are",
+    "they're": "they are",
+    "aren't": "are not",
+    "don't": "do not",
+    "doesn't": "does not",
+    "weren't": "were not",
+    "'ll": " will",
+}
+
+
+def formalize_text(text):
+    # Replace contractions using regular expressions
+    pattern = re.compile(r'\b(' + '|'.join(CONTRACTIONS.keys()) + r')\b')
+    text = pattern.sub(lambda x: CONTRACTIONS[x.group()], text)
+
+    # Remove hyphens at the end of lines and replace newlines with spaces
+    text = text.replace('-\n', '')
+    text = text.replace('\n', ' ')
+
+    return text
+
+
+def clean_string(text):
+    text = formalize_text(text)
+    # The corpus contains literal "\n" sequences; drop the hyphenated ones and turn the rest into spaces
+    text = re.sub(r" -\\*\\n", "", text)
+    text = re.sub(r"\\n", " ", text)
+    text = text.strip()
+    return text
+
+
+def p(text):
+    # Squash the KenLM log10 score into (0, 1) with a sigmoid so it can be used as a pseudo-probability
+    return 1 / (1 + math.exp(-model.score(text, bos=False, eos=False)))
+
+
+def perplexity(text):
+    return model.perplexity(text)
+
+
+def predict_probs_w1w2wi(w1, w2):
+    # Score every vocabulary word as a continuation of the context (w1, w2),
+    # interpolating unigram, bigram and trigram pseudo-probabilities.
+    best_scores = []
+    pred_str = ""
+    for word in V_counter:
+        w1w2 = ' '.join([w2, word])          # bigram: previous word + candidate
+        w1w2w3 = ' '.join([w1, w2, word])    # trigram: both context words + candidate
+
+        text_score = 0.1 * p(word) + 0.3 * p(w1w2) + 0.6 * p(w1w2w3)
+
+        # Keep the five best-scoring candidates, sorted from best to worst
+        if len(best_scores) < 5:
+            best_scores.append((word, text_score))
+            best_scores.sort(key=lambda tup: tup[1], reverse=True)
+        elif best_scores[-1][1] < text_score:
+            best_scores[-1] = (word, text_score)
+            best_scores.sort(key=lambda tup: tup[1], reverse=True)
+
+    # Output format: "word:prob word:prob ... :remaining_mass"
+    for word, prob in best_scores:
+        pred_str += f'{word}:{prob} '
+    pred_str += f':{1 - sum(prob for _, prob in best_scores)}'
+    return pred_str
+
+
+def run_predictions(source_folder):
+    print(f"Run predictions on {source_folder} data...")
+
+    with open(f"{source_folder}/in.tsv", encoding="utf8", mode="rt") as file:
+        input_lines = file.readlines()
+
+    with open(f"{source_folder}/out.tsv", "w", encoding="utf-8") as output_file:
+        for line in tqdm(input_lines):
+            line = line.split("\t")
+
+            # Use the last two words of the left context as the trigram context
+            w1, w2 = clean_string(line[-2]).split()[-2:]
+            out_line = predict_probs_w1w2wi(w1, w2)
+
+            output_file.write(out_line + "\n")
+
+
+with open('V_3000.pickle', 'rb') as handle:
+    V_counter = pickle.load(handle)
+
+run_predictions("../dev-0")
+# run_predictions("../test-A")
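
Note: kenlm_script.py loads its candidate vocabulary from V_3000.pickle, but neither file in this
patch creates that pickle. A minimal sketch of how such a vocabulary could be produced is given
below; it assumes the pickle is simply a Counter of the most frequent words in train_text.txt
(the "3000" in the file name is read as the vocabulary size), and the helper script name is
hypothetical, not part of this patch.

    # build_vocab.py -- hypothetical helper, not included in this patch
    import pickle
    import re
    from collections import Counter

    # Collect word frequencies from the cleaned training text produced by kenlm.ipynb
    with open("train_text.txt", encoding="utf8") as f:
        words = re.findall(r"[A-Za-z']+", f.read())

    # Keep the 3000 most frequent words as the candidate vocabulary
    vocab = Counter(dict(Counter(words).most_common(3000)))

    with open("V_3000.pickle", "wb") as handle:
        pickle.dump(vocab, handle)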