import lzma
import pickle
from collections import Counter, defaultdict

try:
    from tqdm import tqdm
except ImportError:
    # Progress bars are cosmetic: without tqdm, iterate silently instead of
    # failing at import time.
    def tqdm(iterable=None, **_kwargs):
        """Stand-in for ``tqdm.tqdm`` that returns *iterable* unchanged."""
        return iterable


# Default progress-bar total used by every corpus reader below.
# NOTE(review): presumably the number of lines in train/in.tsv.xz - confirm.
TRAIN_LINES = 432022

# Placeholder that replaces every out-of-vocabulary word.
_UNK = 'UNK'

# Deletion table for the punctuation stripped by clean_text.
_PUNCTUATION = str.maketrans('', '', '.,?!')


def _clean_field(field):
    """Replace the literal two-character sequence backslash-n with a space
    and delete the characters ``. , ? !``."""
    return field.replace(r'\n', ' ').translate(_PUNCTUATION)


def clean_text(line: str):
    """Return the cleaned text of columns 6 and 7 of a tab-separated line.

    The two columns (called prefix and suffix by the original code) are
    cleaned independently and joined with a single space.  Whitespace runs
    are left alone; the readers in this file tokenise with ``str.split()``.

    Raises:
        IndexError: if the line has fewer than 8 tab-separated columns.
    """
    separated = line.split('\t')
    return _clean_field(separated[6]) + ' ' + _clean_field(separated[7])


def _tokenized_lines(filename, total):
    """Yield the word list of every line of the xz-compressed TSV file."""
    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
        for line in tqdm(fid, total=total):
            yield clean_text(line).split()


def _as_vocab(V):
    """Return *V* as a dict; an iterable of (word, count) pairs is accepted
    too, because that is the form in which the vocabulary is pickled."""
    return V if isinstance(V, dict) else dict(V)


def _replace_unknown(words, vocab):
    """Keep in-vocabulary words and map every other word to 'UNK'."""
    return [word if word in vocab else _UNK for word in words]


def _windows(words, size):
    """Return the consecutive *size*-word windows of *words* as tuples."""
    return zip(*(words[offset:] for offset in range(size)))


def unigrams(filename, total=TRAIN_LINES):
    """Yield every word of the corpus in reading order.

    Args:
        filename: path of an xz-compressed, tab-separated corpus file.
        total: expected number of lines, only used by the progress bar.
    """
    for words in _tokenized_lines(filename, total):
        yield from words


def bigrams(filename, V: dict, total=TRAIN_LINES):
    """Yield (first_word, second_word) pairs with unknown words as 'UNK'.

    Pairs are built within a single line, like trigrams and tetragrams.
    The previous implementation carried the last word over to the next line
    and started with an empty first word, which produced pairs that never
    occur in the text.
    """
    vocab = _as_vocab(V)
    for words in _tokenized_lines(filename, total):
        yield from _windows(_replace_unknown(words, vocab), 2)


def trigrams(filename, V: dict, total=TRAIN_LINES):
    """Yield 3-word tuples per line with unknown words as 'UNK'."""
    vocab = _as_vocab(V)
    print('Trigrams')
    for words in _tokenized_lines(filename, total):
        # Map to the word itself, not to its count: V.get(word, 'UNK') would
        # return the vocabulary count for every known word.
        yield from _windows(_replace_unknown(words, vocab), 3)


def tetragrams(filename, V: dict, total=TRAIN_LINES):
    """Yield 4-word tuples per line with unknown words as 'UNK'."""
    vocab = _as_vocab(V)
    print('Tetragrams')
    for words in _tokenized_lines(filename, total):
        yield from _windows(_replace_unknown(words, vocab), 4)


def _ratio(count, context_count):
    """Return count / context_count, or 0.0 for an unseen context."""
    return count / context_count if context_count else 0.0


def P(first_word, second_word=None, third_word=None, fourth_word=None):
    """Return the relative-frequency estimate of the last word given the rest.

    With one word this is the unigram probability; with two, three or four
    words it is count(n-gram) / count(leading (n-1)-gram).  Reads the
    module-level tables ``V_common_dict``, ``total``, ``V2_bigrams_dict``,
    ``V3_trigrams_dict`` and ``V4_tetragrams_dict``.  An unseen context
    yields 0.0 instead of raising ZeroDivisionError.
    """
    if second_word is None:
        return _ratio(V_common_dict.get(first_word, 0), total)
    bigram = (first_word, second_word)
    if third_word is None:
        return _ratio(V2_bigrams_dict.get(bigram, 0),
                      V_common_dict.get(first_word, 0))
    trigram = bigram + (third_word,)
    if fourth_word is None:
        return _ratio(V3_trigrams_dict.get(trigram, 0),
                      V2_bigrams_dict.get(bigram, 0))
    return _ratio(V4_tetragrams_dict.get(trigram + (fourth_word,), 0),
                  V3_trigrams_dict.get(trigram, 0))
# Interpolation weight for each n-gram order (4 = tetragram ... 1 = unigram),
# listed from the highest order down so the terms are summed in that order.
_ORDER_WEIGHTS = ((4, 0.5), (3, 0.35), (2, 0.1), (1, 0.05))


def compute_tetragram_probability(tetragram):
    """Return the interpolated score of the LAST word of *tetragram*.

    *tetragram* is the preceding context followed by the scored word.  For a
    4-element input the result is
    ``0.5*P(w1,w2,w3,w4) + 0.35*P(w2,w3,w4) + 0.1*P(w3,w4) + 0.05*P(w4)``.
    Shorter inputs only use the orders that fit; longer inputs only use
    their last four elements.  An empty input scores 0.0.
    """
    ngram = list(tetragram)
    score = 0.0
    for order, weight in _ORDER_WEIGHTS:
        if order <= len(ngram):
            score += weight * P(*ngram[-order:])
    return score


def get_context(position, sentence):
    """Return the 7-item window ``sentence[position-3 : position+4]``.

    Slots that fall before the start or past the end of *sentence* are
    filled with the empty string, so the result always has 7 items.
    """
    size = len(sentence)
    return [sentence[index] if 0 <= index < size else ''
            for index in range(position - 3, position + 4)]


def compute_candidates(left_context, right_context):
    """Score every vocabulary word and format the five best as a string.

    Each word of the module-level ``V_common_dict`` is scored with
    compute_tetragram_probability on the last three left-context words plus
    the word.  The top five scores are normalised to sum to 1 and joined as
    ``word:probability`` items separated by spaces.  The last item has an
    empty word: it carries the probability of 'UNK' when 'UNK' is among the
    five, otherwise the probability of the fifth candidate.

    *right_context* is accepted for interface compatibility only.  P()
    conditions on preceding words and takes at most four of them, so the
    right context cannot be part of the scored n-gram (passing it raised
    TypeError before).
    """
    history = list(left_context[-3:])
    scores = {word: compute_tetragram_probability(history + [word])
              for word in V_common_dict}
    best = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:5]
    if not best:
        # Empty vocabulary: all probability mass goes to the empty-word slot.
        return ':1.0'
    mass = sum(score for _, score in best)
    if mass:
        ranked = [(word, score / mass) for word, score in best]
    else:
        # Every score is zero: fall back to a uniform distribution.
        ranked = [(word, 1 / len(best)) for word, _ in best]
    rest_index = next(
        (index for index, (word, _) in enumerate(ranked) if word == 'UNK'),
        len(ranked) - 1,
    )
    _, rest_probability = ranked.pop(rest_index)
    ranked.append(('', rest_probability))
    return ' '.join(f'{word}:{probability}' for word, probability in ranked)


def candidates(left_context, right_context):
    """Map words without a positive vocabulary count to 'UNK', then return
    compute_candidates() for the mapped contexts."""
    left_context = [w if V_common_dict.get(w) else 'UNK' for w in left_context]
    right_context = [w if V_common_dict.get(w) else 'UNK' for w in right_context]
    return compute_candidates(left_context, right_context)
def load_pickle(filename):
    """Return the object stored in the pickle file *filename*.

    Only load files written by this notebook: unpickling data from an
    untrusted source can execute arbitrary code.
    """
    with open(filename, 'rb') as handle:
        return pickle.load(handle)


def save_pickle(obj, filename):
    """Write *obj* to *filename* with the highest pickle protocol.

    Note the argument order: the object comes first, the path second.
    """
    with open(filename, 'wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)


def create_vocab(filename, word_limit, vocab_path='V.pickle'):
    """Count the corpus words and keep the *word_limit* most frequent ones.

    The summed count of all other words is stored under the key 'UNK'.  The
    vocabulary is pickled to *vocab_path* as a tuple of (word, count) pairs
    and also returned as a dict.

    Args:
        filename: corpus path handed to ``unigrams``.
        word_limit: number of most frequent words to keep.
        vocab_path: where the pickled vocabulary is written.
    """
    counts = Counter(unigrams(filename))
    vocab = dict(counts.most_common(word_limit))
    # Membership is tested against the dict built once above; rebuilding it
    # for every counted word made this step quadratic.
    unknown_count = sum(count for word, count in counts.items()
                        if word not in vocab)
    vocab['UNK'] = unknown_count
    save_pickle(tuple(vocab.items()), vocab_path)
    return vocab


# The pipeline below runs when this code is executed as a notebook or script
# (where __name__ is '__main__') but not when it is imported.
if __name__ == '__main__':
    create_vocab('train/in.tsv.xz', 1000)

    # V.pickle holds a tuple of (word, count) pairs, so it must be converted
    # to a dict before .values() or any vocabulary lookup is used.
    V_common_tuple = load_pickle('V.pickle')
    V_common_dict = dict(V_common_tuple)

    total = sum(V_common_dict.values())

    # The n-gram readers need the dict, not the pickled tuple.
    V2_bigrams = Counter(bigrams('train/in.tsv.xz', V_common_dict))
    V2_bigrams_dict = dict(V2_bigrams)

    save_pickle(V2_bigrams_dict, 'V2_bigrams.pickle')

    V2_bigrams_dict = load_pickle('V2_bigrams.pickle')
# Deletion table for the punctuation removed from the context columns.
_CONTEXT_PUNCTUATION = str.maketrans('', '', '.,?!')


def _context_words(field):
    """Split one context column into words after replacing the literal
    two-character sequence backslash-n with a space and deleting ``.,?!``."""
    return field.replace(r'\n', ' ').translate(_CONTEXT_PUNCTUATION).split()


def save_outs(folder_name):
    """Write one prediction line per input line of ``<folder_name>/in.tsv.xz``.

    For every line, the last three words of column 6 and the first three
    words of column 7 are passed to ``candidates``; its result is written,
    followed by a newline, to ``<folder_name>/out.tsv``.

    Raises:
        IndexError: if a line has fewer than 8 tab-separated columns.
    """
    source_path = f'{folder_name}/in.tsv.xz'
    target_path = f'{folder_name}/out.tsv'
    with lzma.open(source_path, mode='rt', encoding='utf-8') as fid, \
            open(target_path, 'w', encoding='utf-8') as out:
        for line in tqdm(fid):
            separated = line.split('\t')
            # candidates() itself maps out-of-vocabulary words to 'UNK', so
            # the raw words are passed through here.
            left_context = _context_words(separated[6])[-3:]
            right_context = _context_words(separated[7])[:3]
            out.write(candidates(left_context, right_context) + '\n')


# The pipeline below runs when this code is executed as a notebook or script
# (where __name__ is '__main__') but not when it is imported.
if __name__ == '__main__':
    V2_bigrams = Counter(bigrams('train/in.tsv.xz', V_common_dict))
    V2_bigrams_dict = dict(V2_bigrams)
    # save_pickle takes (obj, filename); the arguments used to be swapped.
    save_pickle(V2_bigrams_dict, 'V2_bigrams.pickle')

    V2_bigrams_dict = load_pickle('V2_bigrams.pickle')

    V3_trigrams = Counter(trigrams('train/in.tsv.xz', V_common_dict))
    V3_trigrams_dict = dict(V3_trigrams)

    save_pickle(V3_trigrams_dict, 'V3_trigrams.pickle')
    V3_trigrams_dict = load_pickle('V3_trigrams.pickle')

    V4_tetragrams = Counter(tetragrams('train/in.tsv.xz', V_common_dict))
    V4_tetragrams_dict = dict(V4_tetragrams)

    save_pickle(V4_tetragrams_dict, 'V4_tetragrams.pickle')
    V4_tetragrams_dict = load_pickle('V4_tetragrams.pickle')

    save_outs('dev-0')
    save_outs('test-A')