challenging-america-word-ga.../tetragram_final.ipynb
2023-04-23 19:59:11 +02:00

302 lines
9.4 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import lzma\n",
"import pickle\n",
"import os\n",
"from collections import Counter"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_line(line):\n",
"    \"\"\"Join the two context columns of a tab-separated record into one string.\n",
"\n",
"    Columns 6 and 7 (zero-based) of `line` are taken, every literal\n",
"    backslash-n sequence in them is replaced by a space, and the two pieces\n",
"    are returned separated by a single space.\n",
"    \"\"\"\n",
"    columns = line.split('\\t')\n",
"    left_text, right_text = (columns[i].replace(r'\\n', ' ') for i in (6, 7))\n",
"    return f'{left_text} {right_text}'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def read_words(path):\n",
"    \"\"\"Yield every whitespace-separated token of the context text in an xz file.\n",
"\n",
"    The file at `path` is opened as UTF-8 text through `lzma`; each line is\n",
"    reduced to its context text by `get_line` and split on whitespace.\n",
"    \"\"\"\n",
"    with lzma.open(path, 'rt', encoding='utf-8') as records:\n",
"        for record in records:\n",
"            yield from get_line(record).split()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"import bisect\n",
"import itertools\n",
"\n",
"\n",
"def ngrams(words, n=1):\n",
"    \"\"\"Count the n-grams of an indexable token sequence.\n",
"\n",
"    Returns `(ngrams_counts, sum_counts)`, two `defaultdict(int)` objects:\n",
"    `ngrams_counts` maps each n-token tuple to its number of occurrences;\n",
"    `sum_counts` is incremented once per occurrence for every proper prefix\n",
"    of length 1..n-1, and afterwards the (n-1)-token prefix of each distinct\n",
"    n-gram additionally receives that n-gram's total count.\n",
"\n",
"    `words` must support `len()` and slicing.\n",
"\n",
"    NOTE(review): a later cell defines another function named `ngrams`\n",
"    (taking a filename), which replaces this one once it has been run.\n",
"    NOTE(review): for n >= 2 the (n-1)-token prefix is counted in both\n",
"    passes -- confirm that this double counting is intended.\n",
"    \"\"\"\n",
"    ngrams_counts = defaultdict(int)\n",
"    sum_counts = defaultdict(int)\n",
"\n",
"    last_start = len(words) - n\n",
"    for start in range(last_start + 1):\n",
"        gram = tuple(words[start:start + n])\n",
"        ngrams_counts[gram] += 1\n",
"        for prefix_len in range(1, n):\n",
"            sum_counts[gram[:prefix_len]] += 1\n",
"\n",
"    for gram, occurrences in ngrams_counts.items():\n",
"        sum_counts[gram[:-1]] += occurrences\n",
"\n",
"    return ngrams_counts, sum_counts\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Unigram vocabulary: the `vocab_size` most frequent training tokens mapped to\n",
"# their counts, plus an '<unk>' entry holding the summed count of every other\n",
"# token. The result is cached as a pickle so later runs skip the corpus scan.\n",
"# NOTE: the names say 'top_1000' but the size is set by `vocab_size`; an\n",
"# existing cache file is reused as-is even if `vocab_size` has been changed.\n",
"vocab_size = 2000\n",
"if os.path.exists('model/vocab_top_1000.pkl'):\n",
"    # NOTE(review): unpickling can run arbitrary code -- only load caches you\n",
"    # created yourself.\n",
"    with open('model/vocab_top_1000.pkl', 'rb') as f:\n",
"        vocab_top_1000 = pickle.load(f)\n",
"else:\n",
"    counter = Counter(read_words('train/in.tsv.xz'))\n",
"    vocab_top_1000 = dict(counter.most_common(vocab_size))\n",
"    # Everything outside the kept vocabulary is pooled into '<unk>'. The sum is\n",
"    # fully evaluated before the '<unk>' key is inserted.\n",
"    vocab_top_1000['<unk>'] = sum(\n",
"        count for word, count in counter.items() if word not in vocab_top_1000\n",
"    )\n",
"    # Create the cache directory first so the slow count is not lost to a\n",
"    # FileNotFoundError when 'model/' does not exist yet.\n",
"    os.makedirs('model', exist_ok=True)\n",
"    with open('model/vocab_top_1000.pkl', 'wb') as f:\n",
"        pickle.dump(vocab_top_1000, f, protocol=pickle.HIGHEST_PROTOCOL)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def ngrams(filename, V: dict, n: int):\n",
"    \"\"\"Yield the n-grams of every line of an xz-compressed TSV file.\n",
"\n",
"    Each line is reduced to its context text with `get_line` and split on\n",
"    whitespace. Tokens that are not keys of `V` are replaced by '<unk>'.\n",
"    Every line starts from a window of `n - 1` empty strings, so one tuple of\n",
"    exactly `n` items is produced per token, the earliest ones padded with ''.\n",
"\n",
"    Args:\n",
"        filename: path of the xz file, opened as UTF-8 text.\n",
"        V: vocabulary; only key membership is used.\n",
"        n: n-gram order, any integer >= 1.\n",
"\n",
"    Yields:\n",
"        Tuples of `n` strings.\n",
"\n",
"    Raises:\n",
"        ValueError: if `n` is smaller than 1 (raised on first iteration).\n",
"    \"\"\"\n",
"    if n < 1:\n",
"        raise ValueError(f'n must be >= 1, got {n}')\n",
"    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:\n",
"        print(f'{n}-grams')\n",
"        padding = ('',) * (n - 1)\n",
"        for line in fid:\n",
"            window = padding\n",
"            for word in get_line(line).split():\n",
"                if word not in V:\n",
"                    word = '<unk>'\n",
"                # Slide the window: append the token, keep only the last n.\n",
"                window = (window + (word,))[-n:]\n",
"                yield window\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Bigram counts over the training set, cached as a pickle so later runs skip\n",
"# the corpus scan.\n",
"if os.path.exists('model/bigrams.pkl'):\n",
"    # NOTE(review): unpickling can run arbitrary code -- only load caches you\n",
"    # created yourself.\n",
"    with open('model/bigrams.pkl', 'rb') as f:\n",
"        bigrams = pickle.load(f)\n",
"else:\n",
"    bigrams = Counter(ngrams('train/in.tsv.xz', vocab_top_1000, 2))\n",
"    # Create the cache directory first so the slow count is not lost to a\n",
"    # FileNotFoundError when 'model/' does not exist yet.\n",
"    os.makedirs('model', exist_ok=True)\n",
"    with open('model/bigrams.pkl', 'wb') as f:\n",
"        pickle.dump(bigrams, f, protocol=pickle.HIGHEST_PROTOCOL)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Trigram counts over the training set, cached as a pickle so later runs skip\n",
"# the corpus scan.\n",
"if os.path.exists('model/trigrams.pkl'):\n",
"    # NOTE(review): unpickling can run arbitrary code -- only load caches you\n",
"    # created yourself.\n",
"    with open('model/trigrams.pkl', 'rb') as f:\n",
"        trigrams = pickle.load(f)\n",
"else:\n",
"    trigrams = Counter(ngrams('train/in.tsv.xz', vocab_top_1000, 3))\n",
"    # Create the cache directory first so the slow count is not lost to a\n",
"    # FileNotFoundError when 'model/' does not exist yet.\n",
"    os.makedirs('model', exist_ok=True)\n",
"    with open('model/trigrams.pkl', 'wb') as f:\n",
"        pickle.dump(trigrams, f, protocol=pickle.HIGHEST_PROTOCOL)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Tetragram counts over the training set, cached as a pickle so later runs\n",
"# skip the corpus scan.\n",
"if os.path.exists('model/tetragrams.pkl'):\n",
"    # NOTE(review): unpickling can run arbitrary code -- only load caches you\n",
"    # created yourself.\n",
"    with open('model/tetragrams.pkl', 'rb') as f:\n",
"        tetragrams = pickle.load(f)\n",
"else:\n",
"    tetragrams = Counter(ngrams('train/in.tsv.xz', vocab_top_1000, 4))\n",
"    # Create the cache directory first so the slow count is not lost to a\n",
"    # FileNotFoundError when 'model/' does not exist yet.\n",
"    os.makedirs('model', exist_ok=True)\n",
"    with open('model/tetragrams.pkl', 'wb') as f:\n",
"        pickle.dump(tetragrams, f, protocol=pickle.HIGHEST_PROTOCOL)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"import bisect\n",
"import itertools\n",
"\n",
"def probability(first, second=None, third=None, fourth=None):\n",
"    \"\"\"Relative-frequency estimate of the last given token given the ones before it.\n",
"\n",
"    The n-gram is formed from the leading arguments up to the first `None`:\n",
"\n",
"    * one token    -> count(first) / total of all counts in `vocab_top_1000`\n",
"    * two tokens   -> bigrams[(first, second)] / vocab_top_1000[first]\n",
"    * three tokens -> trigrams[...] / bigrams[(first, second)]\n",
"    * four tokens  -> tetragrams[...] / trigrams[(first, second, third)]\n",
"\n",
"    Only `None` marks an argument as absent; an empty string is looked up like\n",
"    any other token (the n-gram tables use '' as line-start padding).\n",
"\n",
"    Returns:\n",
"        The ratio as a float, or 0 when the n-gram or its history has no\n",
"        recorded count (this also prevents a ZeroDivisionError).\n",
"    \"\"\"\n",
"    # Collect the supplied tokens, stopping at the first missing one.\n",
"    ngram = []\n",
"    for token in (first, second, third, fourth):\n",
"        if token is None:\n",
"            break\n",
"        ngram.append(token)\n",
"    ngram = tuple(ngram)\n",
"\n",
"    # Unigram\n",
"    if len(ngram) <= 1:\n",
"        total = sum(vocab_top_1000.values())\n",
"        return vocab_top_1000.get(first, 0) / total if total else 0\n",
"\n",
"    # Higher orders: count(ngram) / count(history)\n",
"    counts_by_order = {2: bigrams, 3: trigrams, 4: tetragrams}\n",
"    numerator = counts_by_order[len(ngram)].get(ngram, 0)\n",
"    history = ngram[:-1]\n",
"    if len(history) == 1:\n",
"        denominator = vocab_top_1000.get(history[0], 0)\n",
"    else:\n",
"        denominator = counts_by_order[len(history)].get(history, 0)\n",
"\n",
"    # Key not found (or history never counted)\n",
"    if not numerator or not denominator:\n",
"        return 0\n",
"    return numerator / denominator\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def interpolate(tetragram):\n",
"    \"\"\"Linearly interpolated probability for a 4-tuple of tokens.\n",
"\n",
"    Falsy tokens (None or '') mark missing context. Leading missing tokens\n",
"    are skipped, then the contiguous run of real tokens that follows is\n",
"    scored: the last token of the run is the predicted one and the earlier\n",
"    ones are its history. The run length selects the weights:\n",
"\n",
"    * 4 tokens: 0.4 / 0.3 / 0.2 / 0.1 over 4-, 3-, 2- and 1-gram estimates\n",
"    * 3 tokens: 0.5 / 0.3 / 0.2\n",
"    * 2 tokens: 0.6 / 0.4\n",
"    * 1 token : the plain unigram estimate\n",
"\n",
"    Without leading padding this gives the same value as before; with\n",
"    leading padding it backs off to the available tokens rather than\n",
"    returning the probability of the padding token itself.\n",
"\n",
"    Raises:\n",
"        ValueError: if `tetragram` does not have exactly four items.\n",
"    \"\"\"\n",
"    first, second, third, fourth = tetragram\n",
"    tokens = [first, second, third, fourth]\n",
"\n",
"    # Skip padding at the front of the tuple.\n",
"    start = 0\n",
"    while start < len(tokens) and not tokens[start]:\n",
"        start += 1\n",
"\n",
"    # Keep the tokens up to the next gap.\n",
"    run = []\n",
"    for token in tokens[start:]:\n",
"        if not token:\n",
"            break\n",
"        run.append(token)\n",
"\n",
"    if not run:\n",
"        return probability(first)\n",
"    if len(run) == 1:\n",
"        return probability(run[0])\n",
"\n",
"    weights = {4: (0.4, 0.3, 0.2, 0.1), 3: (0.5, 0.3, 0.2), 2: (0.6, 0.4)}[len(run)]\n",
"    # Highest order first: run[0:], run[1:], ... down to the unigram.\n",
"    return sum(weight * probability(*run[i:]) for i, weight in enumerate(weights))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def consider_context(left_context, right_context):\n",
"    \"\"\"Predict the gap word between two contexts and format the distribution.\n",
"\n",
"    Every key of `vocab_top_1000` is scored by the product of the four\n",
"    `interpolate` values of the 4-token windows that contain the candidate.\n",
"    The five best candidates are renormalised to sum to 1. The '<unk>'\n",
"    entry, or the last entry when '<unk>' is not among them, is emitted with\n",
"    an empty word so its probability becomes the ':p' remainder.\n",
"\n",
"    Args:\n",
"        left_context: up to three tokens before the gap (nearest last).\n",
"        right_context: up to three tokens after the gap (nearest first).\n",
"            Shorter contexts are padded with '' on the outer side.\n",
"\n",
"    Returns:\n",
"        Space-separated 'word:prob' items ending with a ':prob' item, or\n",
"        ':1.0' when no candidate has a non-zero score.\n",
"    \"\"\"\n",
"    # Pad to exactly three tokens so the unpacking cannot fail on short lines.\n",
"    first, second, third = ([''] * 3 + list(left_context))[-3:]\n",
"    fifth, sixth, seventh = (list(right_context) + [''] * 3)[:3]\n",
"\n",
"    probs = []\n",
"    for word in vocab_top_1000:\n",
"        p1 = interpolate((first, second, third, word))\n",
"        p2 = interpolate((second, third, word, fifth))\n",
"        p3 = interpolate((third, word, fifth, sixth))\n",
"        p4 = interpolate((word, fifth, sixth, seventh))\n",
"        probs.append((word, p1 * p2 * p3 * p4))\n",
"\n",
"    probs = sorted(probs, key=lambda x: x[1], reverse=True)[:5]\n",
"    total_prob = sum(prob for _, prob in probs)\n",
"    if not total_prob:\n",
"        # Nothing scored above zero: give all mass to the remainder rather\n",
"        # than dividing by zero.\n",
"        return ':1.0'\n",
"\n",
"    norm = [(word, prob / total_prob) for word, prob in probs]\n",
"    unk_index = next((i for i, (word, _) in enumerate(norm) if word == '<unk>'), None)\n",
"    if unk_index is not None:\n",
"        # Move the '<unk>' mass to the end as the unnamed remainder.\n",
"        _, unk_prob = norm.pop(unk_index)\n",
"        norm.append(('', unk_prob))\n",
"    else:\n",
"        # No '<unk>' in the top five: the last candidate becomes the remainder.\n",
"        norm[-1] = ('', norm[-1][1])\n",
"\n",
"    return ' '.join([f'{x[0]}:{x[1]}' for x in norm])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def execute(path):\n",
"    \"\"\"Write gap-word predictions for every record of `{path}/in.tsv.xz`.\n",
"\n",
"    For each input line the last three tokens of column 6 and the first three\n",
"    tokens of column 7 (zero-based, literal backslash-n sequences replaced by\n",
"    spaces) are mapped onto the vocabulary and passed to `consider_context`;\n",
"    its result is written as one line of `{path}/out.tsv`.\n",
"    \"\"\"\n",
"    def to_vocab(tokens):\n",
"        # Keep the token itself when it is a vocabulary key, else '<unk>'.\n",
"        # Note that `vocab_top_1000.get(token, '<unk>')` would return the\n",
"        # token's count rather than the token, which never matches an n-gram key.\n",
"        return [token if token in vocab_top_1000 else '<unk>' for token in tokens]\n",
"\n",
"    with lzma.open(f'{path}/in.tsv.xz', 'rt', encoding='utf-8') as f, \\\n",
"            open(f'{path}/out.tsv', 'w', encoding='utf-8') as out:\n",
"        for line in f:\n",
"            prefix, suffix = line.split('\\t')[6:8]\n",
"            left = to_vocab(prefix.replace(r'\\n', ' ').split()[-3:])\n",
"            right = to_vocab(suffix.replace(r'\\n', ' ').split()[:3])\n",
"            result = consider_context(left, right)\n",
"            out.write(f'{result}\\n')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"execute('dev-0')\n",
"execute('test-A')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}