352 lines
11 KiB
Plaintext
352 lines
11 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import lzma\n",
|
|
"import pickle\n",
|
|
"from tqdm import tqdm\n",
|
|
"from collections import Counter, defaultdict"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def clean_text(line: str):\n",
"    \"\"\"Join the cleaned 7th and 8th tab-separated columns of `line` with one space.\n",
"\n",
"    In each column the literal two-character sequence backslash-n becomes a space\n",
"    and the characters . , ? ! are deleted. Nothing else is stripped.\n",
"    \"\"\"\n",
"    # Applied in this order to each column.\n",
"    # NOTE(review): the (' ', ' ') pair changes nothing as written; it may have been\n",
"    # meant to collapse double spaces - confirm against the original notebook.\n",
"    substitutions = ((r'\\n', ' '), (' ', ' '), ('.', ''), (',', ''), ('?', ''), ('!', ''))\n",
"    columns = line.split('\\t')\n",
"    cleaned = []\n",
"    for field in (columns[6], columns[7]):\n",
"        for old, new in substitutions:\n",
"            field = field.replace(old, new)\n",
"        cleaned.append(field)\n",
"    return ' '.join(cleaned)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def unigrams(filename):\n",
"    \"\"\"Yield each whitespace-separated token of the cleaned text, one input line at a time.\n",
"\n",
"    `filename` is opened as xz-compressed UTF-8 text. The progress bar (total fixed at\n",
"    432022) advances once per line, after that line's tokens have been yielded.\n",
"    \"\"\"\n",
"    with lzma.open(filename, mode='rt', encoding='utf-8') as fid, tqdm(total=432022) as pbar:\n",
"        for raw_line in fid:\n",
"            yield from clean_text(raw_line).split()\n",
"            pbar.update(1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def bigrams(filename, V: dict):\n",
"    \"\"\"Yield (previous_word, word) pairs over the whole xz-compressed file.\n",
"\n",
"    A word for which V.get(word) is None is replaced by 'UNK'. The previous word starts\n",
"    as '' and is carried across line boundaries, so the very first pair is ('', word).\n",
"    \"\"\"\n",
"    # Both the file and the progress bar are context-managed, so the bar is closed even\n",
"    # when the consumer stops iterating early or an exception is raised.\n",
"    with lzma.open(filename, mode='rt', encoding='utf-8') as fid, tqdm(total=432022) as pbar:\n",
"        first_word = ''\n",
"        for line in fid:\n",
"            for second_word in clean_text(line).split():\n",
"                if V.get(second_word) is None:\n",
"                    second_word = 'UNK'\n",
"                # split() never returns empty tokens, so every token produces a pair.\n",
"                yield first_word, second_word\n",
"                first_word = second_word\n",
"            pbar.update(1)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def trigrams(filename, V: dict):\n",
"    \"\"\"Yield every window of three consecutive words within each line of the file.\n",
"\n",
"    Words that are not keys of V are replaced by 'UNK'; windows never span two lines.\n",
"    \"\"\"\n",
"    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:\n",
"        print('Trigrams')\n",
"        for line in tqdm(fid, total=432022):\n",
"            # Keep the word itself: V.get(word, 'UNK') would store the word's count,\n",
"            # and P() looks trigrams up by their words.\n",
"            words = [word if word in V else 'UNK' for word in clean_text(line).split()]\n",
"            yield from zip(words, words[1:], words[2:])\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def tetragrams(filename, V: dict):\n",
"    \"\"\"Yield every window of four consecutive words within each line of the file.\n",
"\n",
"    Words that are not keys of V are replaced by 'UNK'; windows never span two lines.\n",
"    \"\"\"\n",
"    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:\n",
"        print('Tetragrams')\n",
"        for line in tqdm(fid, total=432022):\n",
"            # Keep the word itself: V.get(word, 'UNK') would store the word's count,\n",
"            # and P() looks tetragrams up by their words.\n",
"            words = [word if word in V else 'UNK' for word in clean_text(line).split()]\n",
"            yield from zip(words, words[1:], words[2:], words[3:])\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def P(first_word, second_word=None, third_word=None, fourth_word=None):\n",
"    \"\"\"Relative-frequency estimate of the last given word, conditioned on the words before it.\n",
"\n",
"    One word: count(word) / total. More words: count(all given words) / count(all but the last).\n",
"    Counts are read from the globals V_common_dict, V2_bigrams_dict, V3_trigrams_dict and\n",
"    V4_tetragrams_dict; a missing entry counts as 0. Returns 0.0 when the denominator is 0.\n",
"    \"\"\"\n",
"    if second_word is None:\n",
"        seen = V_common_dict.get(first_word, 0)\n",
"        history = total\n",
"    elif third_word is None:\n",
"        seen = V2_bigrams_dict.get((first_word, second_word), 0)\n",
"        history = V_common_dict.get(first_word, 0)\n",
"    elif fourth_word is None:\n",
"        seen = V3_trigrams_dict.get((first_word, second_word, third_word), 0)\n",
"        history = V2_bigrams_dict.get((first_word, second_word), 0)\n",
"    else:\n",
"        seen = V4_tetragrams_dict.get((first_word, second_word, third_word, fourth_word), 0)\n",
"        history = V3_trigrams_dict.get((first_word, second_word, third_word), 0)\n",
"    # An unseen history has count 0; report probability 0 rather than raise ZeroDivisionError.\n",
"    return seen / history if history else 0.0\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def compute_tetragram_probability(tetragram):\n",
"    \"\"\"Blend four P() estimates for item 3 of `tetragram`, using shrinking histories.\n",
"\n",
"    Weights are 0.5, 0.35, 0.1 and 0.05, from the longest history down to no history.\n",
"    \"\"\"\n",
"    weighted_estimates = (\n",
"        0.5 * P(*tetragram),\n",
"        0.35 * P(tetragram[1], *tetragram[2:]),\n",
"        0.1 * P(tetragram[2], *tetragram[3:]),\n",
"        0.05 * P(tetragram[3]),\n",
"    )\n",
"    return sum(weighted_estimates)\n",
|
|
"\n",
|
|
"def get_context(position, sentence):\n",
"    \"\"\"Return seven items: three before `position`, then four starting at `position`.\n",
"\n",
"    Slots before index 0 (first part) or at/after len(sentence) (second part) hold ''.\n",
"    \"\"\"\n",
"    before = ['' if i < 0 else sentence[i] for i in range(position - 3, position)]\n",
"    onward = ['' if i >= len(sentence) else sentence[i] for i in range(position, position + 4)]\n",
"    return before + onward\n",
|
|
"\n",
|
|
"def compute_candidates(left_context, right_context):\n",
"    \"\"\"Score every vocabulary word as the word following `left_context`; format the top five.\n",
"\n",
"    Returns 'word:prob word:prob ... :prob', with the five probabilities normalised to sum\n",
"    to 1. The last entry has an empty word: it is 'UNK' moved to the end when 'UNK' is among\n",
"    the five, otherwise the fifth candidate with its word blanked. Returns '' when the\n",
"    vocabulary is empty. `right_context` is kept for interface compatibility but is not\n",
"    used, because the scoring function only conditions on preceding words.\n",
"    \"\"\"\n",
"    # compute_tetragram_probability() needs exactly four items: three words of history plus\n",
"    # the candidate. Left-pad a short history with '' so it always has three items.\n",
"    history = ([''] * 3 + list(left_context))[-3:]\n",
"    candidate_probabilities = {\n",
"        word: compute_tetragram_probability(history + [word]) for word in V_common_dict\n",
"    }\n",
"    sorted_candidates = sorted(candidate_probabilities.items(), key=lambda x: x[1], reverse=True)[:5]\n",
"    if not sorted_candidates:\n",
"        return ''\n",
"    total_probability = sum(probability for _, probability in sorted_candidates)\n",
"    if total_probability > 0:\n",
"        normalized_candidates = [(word, probability / total_probability) for word, probability in sorted_candidates]\n",
"    else:\n",
"        # Every score is zero: spread the mass evenly instead of dividing by zero.\n",
"        normalized_candidates = [(word, 1 / len(sorted_candidates)) for word, _ in sorted_candidates]\n",
"    for index, (word, probability) in enumerate(normalized_candidates):\n",
"        if word == 'UNK':\n",
"            # Safe to mutate here because the loop stops immediately afterwards.\n",
"            del normalized_candidates[index]\n",
"            normalized_candidates.append(('', probability))\n",
"            break\n",
"    else:\n",
"        normalized_candidates[-1] = ('', normalized_candidates[-1][1])\n",
"    return ' '.join(f'{word}:{probability}' for word, probability in normalized_candidates)\n",
|
|
"\n",
|
|
"def candidates(left_context, right_context):\n",
"    \"\"\"Map both contexts onto the vocabulary, then delegate to compute_candidates().\n",
"\n",
"    A word is kept only when V_common_dict.get(word) is truthy; otherwise it becomes 'UNK'.\n",
"    \"\"\"\n",
"    def to_vocab(words):\n",
"        return [w if V_common_dict.get(w) else 'UNK' for w in words]\n",
"\n",
"    return compute_candidates(to_vocab(left_context), to_vocab(right_context))\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def create_vocab(filename, word_limit):\n",
"    \"\"\"Count the unigrams of `filename` and keep the `word_limit` most frequent words.\n",
"\n",
"    The counts of all other words are pooled under the key 'UNK'. The result is written\n",
"    to 'V.pickle' as a tuple of (word, count) pairs and returned as a dict.\n",
"    \"\"\"\n",
"    V = Counter(unigrams(filename))\n",
"    V_common_dict = dict(V.most_common(word_limit))\n",
"    # Pooled count = everything minus the kept words. Computing it as a difference avoids\n",
"    # rebuilding dict(V_common) once per distinct word. The right-hand side is evaluated\n",
"    # before the 'UNK' key is inserted.\n",
"    V_common_dict['UNK'] = sum(V.values()) - sum(V_common_dict.values())\n",
"    V_common_tuple = tuple(V_common_dict.items())\n",
"    with open('V.pickle', 'wb') as handle:\n",
"        pickle.dump(V_common_tuple, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
"    return V_common_dict\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def load_pickle(filename):\n",
"    \"\"\"Return the object unpickled from the file at `filename`.\"\"\"\n",
"    with open(filename, 'rb') as source:\n",
"        restored = pickle.load(source)\n",
"    return restored\n",
|
|
"\n",
|
|
"def save_pickle(obj, filename):\n",
"    \"\"\"Pickle `obj` into the file at `filename`, using the highest available protocol.\"\"\"\n",
"    newest_protocol = pickle.HIGHEST_PROTOCOL\n",
"    with open(filename, 'wb') as sink:\n",
"        pickle.dump(obj, sink, protocol=newest_protocol)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Build the vocabulary (1000 most frequent words plus 'UNK'); this also writes V.pickle.\n",
"create_vocab(filename='train/in.tsv.xz', word_limit=1000)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# V.pickle holds a tuple of (word, count) pairs (see create_vocab), so rebuild the dict\n",
"# before using dict methods such as .values().\n",
"with open('V.pickle', 'rb') as handle:\n",
"    V_common_dict = dict(pickle.load(handle))\n",
"\n",
"total = sum(V_common_dict.values())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"with open('V.pickle', 'rb') as handle:\n",
"    V_common_tuple = pickle.load(handle)\n",
"\n",
"V_common_dict = dict(V_common_tuple)\n",
"\n",
"total = sum(V_common_dict.values())\n",
"\n",
"# bigrams() looks words up with V.get(), so it must receive the dict, not the tuple of pairs.\n",
"V2_bigrams = Counter(bigrams('train/in.tsv.xz', V_common_dict))\n",
"V2_bigrams_dict = dict(V2_bigrams)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Write the bigram counts to disk, then read them back.\n",
"save_pickle(obj=V2_bigrams_dict, filename='V2_bigrams.pickle')\n",
"\n",
"V2_bigrams_dict = load_pickle(filename='V2_bigrams.pickle')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"V2_bigrams = Counter(bigrams('train/in.tsv.xz', V_common_dict))\n",
"V2_bigrams_dict = dict(V2_bigrams)\n",
"# save_pickle takes (obj, filename); the arguments were previously swapped.\n",
"save_pickle(V2_bigrams_dict, 'V2_bigrams.pickle')\n",
"\n",
"V2_bigrams_dict = load_pickle('V2_bigrams.pickle')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Count every trigram produced from the training file.\n",
"V3_trigrams = Counter(trigrams(filename='train/in.tsv.xz', V=V_common_dict))\n",
"V3_trigrams_dict = dict(V3_trigrams)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Write the trigram counts to disk, then read them back.\n",
"save_pickle(obj=V3_trigrams_dict, filename='V3_trigrams.pickle')\n",
"V3_trigrams_dict = load_pickle(filename='V3_trigrams.pickle')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Count every tetragram produced from the training file.\n",
"V4_tetragrams = Counter(tetragrams(filename='train/in.tsv.xz', V=V_common_dict))\n",
"V4_tetragrams_dict = dict(V4_tetragrams)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Write the tetragram counts to disk, then read them back.\n",
"save_pickle(obj=V4_tetragrams_dict, filename='V4_tetragrams.pickle')\n",
"V4_tetragrams_dict = load_pickle(filename='V4_tetragrams.pickle')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def save_outs(folder_name):\n",
"    \"\"\"Write one candidates() line to <folder_name>/out.tsv per line of <folder_name>/in.tsv.xz.\n",
"\n",
"    For each input line, the last three cleaned tokens of column 6 and the first three of\n",
"    column 7 are mapped onto the vocabulary and passed to candidates().\n",
"    \"\"\"\n",
"    def tokens(column):\n",
"        # Same character clean-up as clean_text(), followed by a whitespace split.\n",
"        cleaned = column.replace(r'\\n', ' ').replace(' ', ' ').replace('.', '').replace(',', '').replace('?', '').replace('!', '')\n",
"        return cleaned.split()\n",
"\n",
"    def in_vocab(words):\n",
"        return [x if V_common_dict.get(x) else 'UNK' for x in words]\n",
"\n",
"    with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid, \\\n",
"            open(f'{folder_name}/out.tsv', 'w', encoding='utf-8') as f:\n",
"        for line in tqdm(fid):\n",
"            separated = line.split('\\t')\n",
"            prefix = tokens(separated[6])\n",
"            suffix = tokens(separated[7])\n",
"            w = candidates(in_vocab(prefix[-3:]), in_vocab(suffix[:3]))\n",
"            f.write(w + '\\n')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Produce out.tsv for each evaluation split, in this order.\n",
"for folder in ('dev-0', 'test-A'):\n",
"    save_outs(folder)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.2"
|
|
},
|
|
"orig_nbformat": 4
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|