challenging-america-word-ga.../tetragrams_final.ipynb

352 lines
11 KiB
Plaintext
Raw Normal View History

2023-04-23 17:17:33 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import lzma\n",
"import pickle\n",
"from tqdm import tqdm\n",
"from collections import Counter, defaultdict"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def clean_text(line: str):\n",
" separated = line.split('\\t')\n",
" prefix = separated[6].replace(r'\\n', ' ').replace(' ', ' ').replace('.', '').replace(',', '').replace('?', '').replace('!', '')\n",
" suffix = separated[7].replace(r'\\n', ' ').replace(' ', ' ').replace('.', '').replace(',', '').replace('?', '').replace('!', '')\n",
" return prefix + ' ' + suffix"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def unigrams(filename):\n",
" with lzma.open(filename, mode='rt', encoding='utf-8') as fid:\n",
" with tqdm(total=432022) as pbar:\n",
" for line in fid:\n",
" text = clean_text(line)\n",
" for word in text.split():\n",
" yield word\n",
" pbar.update(1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def bigrams(filename, V: dict):\n",
" with lzma.open(filename, mode='rt', encoding='utf-8') as fid:\n",
" pbar = tqdm(total=432022)\n",
" first_word = ''\n",
" for line in fid:\n",
" text = clean_text(line)\n",
" for second_word in text.split():\n",
" if V.get(second_word) is None:\n",
" second_word = 'UNK'\n",
" if second_word:\n",
" yield first_word, second_word\n",
" first_word = second_word\n",
" pbar.update(1)\n",
" pbar.close()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def trigrams(filename, V: dict):\n",
" with lzma.open(filename, mode='rt', encoding='utf-8') as fid:\n",
" print('Trigrams')\n",
" for line in tqdm(fid, total=432022):\n",
" text = clean_text(line)\n",
" words = text.split()\n",
" for i in range(len(words)-2):\n",
" trigram = tuple(V.get(word, 'UNK') for word in words[i:i+3])\n",
" yield trigram\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tetragrams(filename, V: dict):\n",
" with lzma.open(filename, mode='rt', encoding='utf-8') as fid:\n",
" print('Tetragrams')\n",
" for i, line in enumerate(tqdm(fid, total=432022)):\n",
" text = clean_text(line)\n",
" words = [V.get(word, 'UNK') for word in text.split()]\n",
" for first_word, second_word, third_word, fourth_word in zip(words, words[1:], words[2:], words[3:]):\n",
" yield first_word, second_word, third_word, fourth_word\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def P(first_word, second_word=None, third_word=None, fourth_word=None):\n",
" if second_word is None:\n",
" return V_common_dict.get(first_word, 0) / total\n",
" elif third_word is None:\n",
" return V2_bigrams_dict.get((first_word, second_word), 0) / V_common_dict.get(first_word, 0)\n",
" elif fourth_word is None:\n",
" return V3_trigrams_dict.get((first_word, second_word, third_word), 0) / V2_bigrams_dict.get((first_word, second_word), 0)\n",
" else:\n",
" return V4_tatragrams_dict.get((first_word, second_word, third_word, fourth_word), 0) / V3__trigrams_dict.get((first_word, second_word, third_word), 0)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def compute_tetragram_probability(tetragram):\n",
" return 0.5 * P(*tetragram) + 0.35 * P(tetragram[1], *tetragram[2:]) + \\\n",
" 0.1 * P(tetragram[2], *tetragram[3:]) + 0.05 * P(tetragram[3])\n",
"\n",
"def get_context(position, sentence):\n",
" context = []\n",
" for i in range(position-3, position):\n",
" if i < 0:\n",
" context.append('')\n",
" else:\n",
" context.append(sentence[i])\n",
" for i in range(position, position+4):\n",
" if i >= len(sentence):\n",
" context.append('')\n",
" else:\n",
" context.append(sentence[i])\n",
" return context\n",
"\n",
"def compute_candidates(left_context, right_context):\n",
" candidate_probabilities = {}\n",
" for word in V_common_dict:\n",
" tetragram = left_context[-3:] + [word] + right_context[:3]\n",
" probability = compute_tetragram_probability(tetragram)\n",
" candidate_probabilities[word] = probability\n",
" sorted_candidates = sorted(candidate_probabilities.items(), key=lambda x: x[1], reverse=True)[:5]\n",
" total_probability = sum([c[1] for c in sorted_candidates])\n",
" normalized_candidates = [(c[0], c[1] / total_probability) for c in sorted_candidates]\n",
" for index, elem in enumerate(normalized_candidates):\n",
" if 'UNK' in elem:\n",
" normalized_candidates.pop(index)\n",
" normalized_candidates.append(('', elem[1]))\n",
" break\n",
" else:\n",
" normalized_candidates[-1] = ('', normalized_candidates[-1][1])\n",
" return ' '.join([f'{x[0]}:{x[1]}' for x in normalized_candidates])\n",
"\n",
"def candidates(left_context, right_context):\n",
" left_context = [w if V_common_dict.get(w) else 'UNK' for w in left_context]\n",
" right_context = [w if V_common_dict.get(w) else 'UNK' for w in right_context]\n",
" return compute_candidates(left_context, right_context)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def create_vocab(filename, word_limit):\n",
" V = Counter(unigrams(filename))\n",
" V_common = V.most_common(word_limit)\n",
" UNK = sum(v for k, v in V.items() if k not in dict(V_common))\n",
" V_common_dict = dict(V_common)\n",
" V_common_dict['UNK'] = UNK\n",
" V_common_tuple = tuple((k, v) for k, v in V_common_dict.items())\n",
" with open('V.pickle', 'wb') as handle:\n",
" pickle.dump(V_common_tuple, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
" return V_common_dict\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def load_pickle(filename):\n",
" with open(filename, 'rb') as handle:\n",
" return pickle.load(handle)\n",
"\n",
"def save_pickle(obj, filename):\n",
" with open(filename, 'wb') as handle:\n",
" pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"create_vocab('train/in.tsv.xz', 1000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open('V.pickle', 'rb') as handle:\n",
" V_common_dict = pickle.load(handle)\n",
"\n",
"total = sum(V_common_dict.values())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open('V.pickle', 'rb') as handle:\n",
" V_common_tuple = pickle.load(handle)\n",
"\n",
"V_common_dict = dict(V_common_tuple)\n",
"\n",
"total = sum(V_common_dict.values())\n",
"\n",
"V2_bigrams = Counter(bigrams('train/in.tsv.xz', V_common_tuple))\n",
"V2_bigrams_dict = dict(V2_bigrams)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"save_pickle(V2_bigrams_dict,'V2_bigrams.pickle')\n",
"\n",
"V2_bigrams_dict = load_pickle('V2_bigrams.pickle')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"V2_bigrams = Counter(bigrams('train/in.tsv.xz', V_common_dict))\n",
"V2_bigrams_dict = dict(V2_bigrams)\n",
"save_pickle('V2_bigrams.pickle', V2_bigrams_dict)\n",
"\n",
"V2_bigrams_dict = load_pickle('V2_bigrams.pickle')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"V3_trigrams = Counter(trigrams('train/in.tsv.xz', V_common_dict))\n",
"V3_trigrams_dict = dict(V3_trigrams)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"save_pickle(V3_trigrams_dict, 'V3_trigrams.pickle')\n",
"V3_trigrams_dict = load_pickle('V3_trigrams.pickle')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"V4_tetragrams = Counter(tetragrams('train/in.tsv.xz', V_common_dict))\n",
"V4_tetragrams_dict = dict(V4_tetragrams)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"save_pickle(V4_tetragrams_dict, 'V4_tetragrams.pickle')\n",
"V4_tetragrams_dict = load_pickle('V4_tetragrams.pickle')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def save_outs(folder_name):\n",
" with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n",
" with open(f'{folder_name}/out.tsv', 'w', encoding='utf-8') as f:\n",
" for line in tqdm(fid):\n",
" separated = line.split('\\t')\n",
" prefix = separated[6].replace(r'\\n', ' ').replace(' ', ' ').replace('.', '').replace(',', '').replace('?', '').replace('!', '').split()\n",
" suffix = separated[7].replace(r'\\n', ' ').replace(' ', ' ').replace('.', '').replace(',', '').replace('?', '').replace('!', '').split()\n",
" left_context = [x if V_common_dict.get(x) else 'UNK' for x in prefix[-3:]]\n",
" right_context = [x if V_common_dict.get(x) else 'UNK' for x in suffix[:3]]\n",
" w = candidates(left_context, right_context)\n",
" f.write(w + '\\n')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"save_outs('dev-0')\n",
"save_outs('test-A')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.2"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}