{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gap-filling word prediction with an interpolated 4-gram model\n",
    "\n",
    "For every line of `dev-0/in.tsv.xz` and `test-A/in.tsv.xz` this notebook predicts the word missing between the left context (TSV column index 6) and the right context (column index 7):\n",
    "\n",
    "1. build a vocabulary of the `VOCAB_SIZE` most frequent training words plus an `UNK` bucket;\n",
    "2. count bigrams, trigrams and tetragrams in `train/in.tsv.xz`, within each context column and never across the gap;\n",
    "3. score every vocabulary word with linearly interpolated n-gram probabilities of the word itself and of the first right-context word following it;\n",
    "4. write the `TOP_K` best candidates to `out.tsv` as `word:prob` pairs, the last pair carrying an empty word."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import lzma\n",
    "import pickle\n",
    "from collections import Counter, defaultdict\n",
    "\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TRAIN_FILE = 'train/in.tsv.xz'\n",
    "TRAIN_LINES = 432022  # expected line count of TRAIN_FILE; only used as the progress-bar total\n",
    "VOCAB_SIZE = 1000  # most frequent training words kept; every other word becomes UNK\n",
    "UNK = 'UNK'  # token that stands in for out-of-vocabulary words\n",
    "TOP_K = 5  # candidate pairs written per output line\n",
    "# Weights of P(w | w1 w2 w3), P(w | w2 w3), P(w | w3) and P(w); they sum to 1.\n",
    "INTERPOLATION_WEIGHTS = (0.5, 0.35, 0.1, 0.05)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocessing\n",
    "\n",
    "The raw context columns contain literal backslash-n escape sequences and punctuation. Both are stripped before the text is tokenized on whitespace."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PUNCTUATION = '.,?!'\n",
    "\n",
    "\n",
    "def normalize(text: str) -> str:\n",
    "    \"\"\"Replace literal backslash-n sequences with spaces and delete every character in PUNCTUATION.\"\"\"\n",
    "    text = text.replace(r'\\n', ' ')\n",
    "    for mark in PUNCTUATION:\n",
    "        text = text.replace(mark, '')\n",
    "    return text\n",
    "\n",
    "\n",
    "def split_line(line: str) -> tuple:\n",
    "    \"\"\"Return the normalized left (column 6) and right (column 7) context of a TSV line.\"\"\"\n",
    "    columns = line.split('\\t')\n",
    "    return normalize(columns[6]), normalize(columns[7])\n",
    "\n",
    "\n",
    "def clean_text(line: str) -> str:\n",
    "    \"\"\"Return both normalized context columns of a TSV line joined by a single space.\"\"\"\n",
    "    prefix, suffix = split_line(line)\n",
    "    return prefix + ' ' + suffix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## N-gram counting\n",
    "\n",
    "Each context column is tokenized on its own, so no n-gram joins the last left word to the first right word across the missing word. Words outside the vocabulary are replaced by `UNK` (the word itself, not its count, is compared against the vocabulary)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _segments(filename, desc=None):\n",
    "    \"\"\"Yield the token list of the left context, then of the right context, for every line of `filename`.\"\"\"\n",
    "    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:\n",
    "        for line in tqdm(fid, total=TRAIN_LINES, desc=desc):\n",
    "            for segment in split_line(line):\n",
    "                yield segment.split()\n",
    "\n",
    "\n",
    "def unigrams(filename):\n",
    "    \"\"\"Yield every word of both context columns of `filename`.\"\"\"\n",
    "    for words in _segments(filename, desc='Unigrams'):\n",
    "        yield from words\n",
    "\n",
    "\n",
    "def ngrams(filename, V, n, desc=None):\n",
    "    \"\"\"Yield n-gram tuples of words from `filename`, mapping words missing from vocabulary `V` to UNK.\n",
    "\n",
    "    `V` may be a dict keyed by word or the tuple of (word, count) pairs stored in V.pickle.\n",
    "    \"\"\"\n",
    "    vocab = V if isinstance(V, dict) else dict(V)\n",
    "    for words in _segments(filename, desc=desc):\n",
    "        tokens = [word if word in vocab else UNK for word in words]\n",
    "        yield from zip(*(tokens[i:] for i in range(n)))\n",
    "\n",
    "\n",
    "def bigrams(filename, V: dict):\n",
    "    \"\"\"Yield (w1, w2) pairs over `filename` with out-of-vocabulary words mapped to UNK.\"\"\"\n",
    "    return ngrams(filename, V, 2, desc='Bigrams')\n",
    "\n",
    "\n",
    "def trigrams(filename, V: dict):\n",
    "    \"\"\"Yield (w1, w2, w3) tuples over `filename` with out-of-vocabulary words mapped to UNK.\"\"\"\n",
    "    return ngrams(filename, V, 3, desc='Trigrams')\n",
    "\n",
    "\n",
    "def tetragrams(filename, V: dict):\n",
    "    \"\"\"Yield (w1, w2, w3, w4) tuples over `filename` with out-of-vocabulary words mapped to UNK.\"\"\"\n",
    "    return ngrams(filename, V, 4, desc='Tetragrams')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Vocabulary and pickle helpers\n",
    "\n",
    "`create_vocab` also writes `V.pickle` as a tuple of `(word, count)` pairs; `dict(load_pickle('V.pickle'))` restores the same dictionary in a later session. Only unpickle files this notebook wrote."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_pickle(filename):\n",
    "    \"\"\"Load and return the object pickled in `filename`.\"\"\"\n",
    "    with open(filename, 'rb') as handle:\n",
    "        return pickle.load(handle)\n",
    "\n",
    "\n",
    "def save_pickle(obj, filename):\n",
    "    \"\"\"Pickle `obj` into `filename` using the highest available protocol.\"\"\"\n",
    "    with open(filename, 'wb') as handle:\n",
    "        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_vocab(filename, word_limit):\n",
    "    \"\"\"Keep the `word_limit` most common words of `filename` and pool all other occurrences under UNK.\n",
    "\n",
    "    Saves the vocabulary to V.pickle as a tuple of (word, count) pairs and returns it as a dict.\n",
    "    \"\"\"\n",
    "    counts = Counter(unigrams(filename))\n",
    "    vocab = dict(counts.most_common(word_limit))\n",
    "    # Everything not kept is out of vocabulary; the remainder avoids a membership test per distinct word.\n",
    "    vocab[UNK] = sum(counts.values()) - sum(vocab.values())\n",
    "    save_pickle(tuple(vocab.items()), 'V.pickle')\n",
    "    return vocab"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Language model\n",
    "\n",
    "`P` gives maximum-likelihood conditional probabilities. `compute_tetragram_probability` interpolates them with `INTERPOLATION_WEIGHTS`. A candidate word scores\n",
    "\n",
    "$$P(w \\mid l_1 l_2 l_3) \\cdot P(r_1 \\mid l_2 l_3 w),$$\n",
    "\n",
    "where the second factor is dropped when the right context is empty. A missing left context is padded with `''`, which never occurs in the counts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def P(first_word, second_word=None, third_word=None, fourth_word=None):\n",
    "    \"\"\"Maximum-likelihood probability of the last given word conditioned on the words before it.\n",
    "\n",
    "    P(w1) = c(w1) / total, P(w2|w1) = c(w1 w2) / c(w1), and so on up to tetragrams.\n",
    "    Returns 0.0 when the conditioning history was never counted, instead of dividing by zero.\n",
    "    \"\"\"\n",
    "    if second_word is None:\n",
    "        numerator, denominator = V_common_dict.get(first_word, 0), total\n",
    "    elif third_word is None:\n",
    "        numerator = V2_bigrams_dict.get((first_word, second_word), 0)\n",
    "        denominator = V_common_dict.get(first_word, 0)\n",
    "    elif fourth_word is None:\n",
    "        numerator = V3_trigrams_dict.get((first_word, second_word, third_word), 0)\n",
    "        denominator = V2_bigrams_dict.get((first_word, second_word), 0)\n",
    "    else:\n",
    "        numerator = V4_tetragrams_dict.get((first_word, second_word, third_word, fourth_word), 0)\n",
    "        denominator = V3_trigrams_dict.get((first_word, second_word, third_word), 0)\n",
    "    return numerator / denominator if denominator else 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_tetragram_probability(tetragram):\n",
    "    \"\"\"Linearly interpolated P(w4 | w1 w2 w3) for a 4-word sequence (w1, w2, w3, w4).\"\"\"\n",
    "    w1, w2, w3, w4 = tetragram\n",
    "    weight4, weight3, weight2, weight1 = INTERPOLATION_WEIGHTS\n",
    "    return (weight4 * P(w1, w2, w3, w4) + weight3 * P(w2, w3, w4)\n",
    "            + weight2 * P(w3, w4) + weight1 * P(w4))\n",
    "\n",
    "\n",
    "def get_context(position, sentence):\n",
    "    \"\"\"Return the 3 words before `position` and the 4 words from `position` on, padding with '' past either end.\"\"\"\n",
    "    left = [sentence[i] if i >= 0 else '' for i in range(position - 3, position)]\n",
    "    right = [sentence[i] if i < len(sentence) else '' for i in range(position, position + 4)]\n",
    "    return left + right\n",
    "\n",
    "\n",
    "def _padded_history(left_context):\n",
    "    \"\"\"Return the last 3 words of `left_context`, left-padded with '' when it is shorter.\"\"\"\n",
    "    history = list(left_context[-3:])\n",
    "    return [''] * (3 - len(history)) + history\n",
    "\n",
    "\n",
    "def compute_candidates(left_context, right_context):\n",
    "    \"\"\"Score every vocabulary word as the gap filler and format the TOP_K best as 'word:prob' pairs.\n",
    "\n",
    "    Scores are normalized over the kept candidates. The last pair always has an empty word: the UNK\n",
    "    candidate is moved there if it ranked, otherwise the weakest candidate's word is blanked.\n",
    "    \"\"\"\n",
    "    history = _padded_history(left_context)\n",
    "    next_word = right_context[0] if right_context else None\n",
    "    scores = {}\n",
    "    for word in V_common_dict:\n",
    "        score = compute_tetragram_probability(history + [word])\n",
    "        if next_word is not None:\n",
    "            # How well the candidate predicts the first word after the gap.\n",
    "            score *= compute_tetragram_probability(history[1:] + [word, next_word])\n",
    "        scores[word] = score\n",
    "    best = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:TOP_K]\n",
    "    if not best:\n",
    "        return ''\n",
    "    mass = sum(score for _, score in best)\n",
    "    if mass > 0:\n",
    "        ranked = [(word, score / mass) for word, score in best]\n",
    "    else:\n",
    "        # Every kept score is zero: split the mass uniformly instead of dividing by zero.\n",
    "        ranked = [(word, 1 / len(best)) for word, _ in best]\n",
    "    for index, (word, probability) in enumerate(ranked):\n",
    "        if word == UNK:\n",
    "            ranked.pop(index)\n",
    "            ranked.append(('', probability))\n",
    "            break\n",
    "    else:\n",
    "        ranked[-1] = ('', ranked[-1][1])\n",
    "    return ' '.join(f'{word}:{probability}' for word, probability in ranked)\n",
    "\n",
    "\n",
    "def candidates(left_context, right_context):\n",
    "    \"\"\"Map out-of-vocabulary context words to UNK and return the formatted gap candidates.\"\"\"\n",
    "    left_context = [w if w in V_common_dict else UNK for w in left_context]\n",
    "    right_context = [w if w in V_common_dict else UNK for w in right_context]\n",
    "    return compute_candidates(left_context, right_context)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_outs(folder_name):\n",
    "    \"\"\"Write gap candidates for every line of `<folder_name>/in.tsv.xz` to `<folder_name>/out.tsv`.\"\"\"\n",
    "    with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n",
    "        with open(f'{folder_name}/out.tsv', 'w', encoding='utf-8') as out:\n",
    "            for line in tqdm(fid, desc=folder_name):\n",
    "                prefix, suffix = split_line(line)\n",
    "                out.write(candidates(prefix.split()[-3:], suffix.split()[:3]) + '\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "Counts are always recomputed from `TRAIN_FILE` and then saved, so stale pickles from earlier versions are never silently reused."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "V_common_dict = create_vocab(TRAIN_FILE, VOCAB_SIZE)\n",
    "total = sum(V_common_dict.values())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "V2_bigrams_dict = dict(Counter(bigrams(TRAIN_FILE, V_common_dict)))\n",
    "save_pickle(V2_bigrams_dict, 'V2_bigrams.pickle')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "V3_trigrams_dict = dict(Counter(trigrams(TRAIN_FILE, V_common_dict)))\n",
    "save_pickle(V3_trigrams_dict, 'V3_trigrams.pickle')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "V4_tetragrams_dict = dict(Counter(tetragrams(TRAIN_FILE, V_common_dict)))\n",
    "save_pickle(V4_tetragrams_dict, 'V4_tetragrams.pickle')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_outs('dev-0')\n",
    "save_outs('test-A')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.2"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}