302 lines
9.4 KiB
Plaintext
302 lines
9.4 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import lzma\n",
|
|
"import pickle\n",
|
|
"import os\n",
|
|
"from collections import Counter"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def get_line(line):\n",
"    # Pull the gap's prefix (column 6) and suffix (column 7) out of a TSV\n",
"    # row and join them into one text, turning literal '\\n' markers into\n",
"    # spaces.\n",
"    fields = line.split('\\t')\n",
"    before = fields[6].replace(r'\\n', ' ')\n",
"    after = fields[7].replace(r'\\n', ' ')\n",
"    return f'{before} {after}'"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def read_words(path):\n",
"    # Stream whitespace-separated tokens from an xz-compressed TSV file,\n",
"    # one word at a time (memory stays constant regardless of file size).\n",
"    with lzma.open(path, 'rt', encoding='utf-8') as handle:\n",
"        for raw_line in handle:\n",
"            yield from get_line(raw_line).split()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from collections import defaultdict\n",
"import bisect\n",
"import itertools\n",
"\n",
"\n",
"def ngrams(words, n=1):\n",
"    \"\"\"Count n-grams in a word sequence.\n",
"\n",
"    Returns (ngrams_counts, sum_counts): ngrams_counts maps each n-gram\n",
"    tuple to its frequency; sum_counts maps every proper prefix\n",
"    (lengths 1..n-1) to the number of n-grams extending it.\n",
"\n",
"    NOTE: this in-memory definition is shadowed by the file-based\n",
"    ngrams() defined in a later cell; kept for reference.\n",
"    \"\"\"\n",
"    ngrams_counts = defaultdict(int)\n",
"    sum_counts = defaultdict(int)\n",
"\n",
"    for i in range(len(words) - n + 1):\n",
"        ngram = tuple(words[i:i+n])\n",
"        ngrams_counts[ngram] += 1\n",
"\n",
"    # Derive prefix totals from the finished counts. The previous version\n",
"    # accumulated length-(n-1) prefixes both inside the counting loop and\n",
"    # again here via key[:-1], double-counting them.\n",
"    for key, value in ngrams_counts.items():\n",
"        for j in range(1, n):\n",
"            sum_counts[key[:j]] += value\n",
"\n",
"    return ngrams_counts, sum_counts\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"vocab_size = 2000\n",
"# NOTE: the name/path say 1000 but vocab_size is 2000; path kept for\n",
"# compatibility with existing caches.\n",
"vocab_path = 'model/vocab_top_1000.pkl'\n",
"if os.path.exists(vocab_path):\n",
"    with open(vocab_path, 'rb') as f:\n",
"        vocab_top_1000 = pickle.load(f)\n",
"else:\n",
"    counter = Counter(read_words('train/in.tsv.xz'))\n",
"    vocab_top_1000 = dict(counter.most_common(vocab_size))\n",
"    # Pool all out-of-vocabulary mass under the <unk> pseudo-word.\n",
"    unk = sum(count for word, count in counter.items()\n",
"              if word not in vocab_top_1000)\n",
"    vocab_top_1000['<unk>'] = unk\n",
"    os.makedirs('model', exist_ok=True)  # cache dir may not exist yet\n",
"    with open(vocab_path, 'wb') as f:\n",
"        pickle.dump(vocab_top_1000, f, protocol=pickle.HIGHEST_PROTOCOL)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def ngrams(filename, V: dict, n: int):\n",
"    \"\"\"Yield n-gram tuples from an xz-compressed TSV corpus.\n",
"\n",
"    Each line's text is tokenized; words not in vocabulary V map to\n",
"    '<unk>'. Lines are padded on the left with n-1 empty strings so the\n",
"    first word of a line also yields an n-gram. Works for any n >= 1\n",
"    (the previous lambda dispatch table only supported n <= 4).\n",
"    \"\"\"\n",
"    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:\n",
"        print(f'{n}-grams')\n",
"        for line in fid:\n",
"            text = get_line(line)\n",
"            words = [''] * (n - 1)\n",
"            for word in text.split():\n",
"                if word not in V:\n",
"                    word = '<unk>'\n",
"                words.append(word)\n",
"                # Only the trailing n words are ever needed; trimming\n",
"                # keeps the buffer O(n) instead of growing per line.\n",
"                del words[:-n]\n",
"                yield tuple(words)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Load cached bigram counts, or count them from the training corpus.\n",
"if os.path.exists('model/bigrams.pkl'):\n",
"    with open('model/bigrams.pkl', 'rb') as f:\n",
"        bigrams = pickle.load(f)\n",
"else:\n",
"    bigrams = Counter(ngrams('train/in.tsv.xz', vocab_top_1000, 2))\n",
"    os.makedirs('model', exist_ok=True)  # cache dir may not exist yet\n",
"    with open('model/bigrams.pkl', 'wb') as f:\n",
"        pickle.dump(bigrams, f, protocol=pickle.HIGHEST_PROTOCOL)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Load cached trigram counts, or count them from the training corpus.\n",
"if os.path.exists('model/trigrams.pkl'):\n",
"    with open('model/trigrams.pkl', 'rb') as f:\n",
"        trigrams = pickle.load(f)\n",
"else:\n",
"    trigrams = Counter(ngrams('train/in.tsv.xz', vocab_top_1000, 3))\n",
"    os.makedirs('model', exist_ok=True)  # cache dir may not exist yet\n",
"    with open('model/trigrams.pkl', 'wb') as f:\n",
"        pickle.dump(trigrams, f, protocol=pickle.HIGHEST_PROTOCOL)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Load cached 4-gram counts, or count them from the training corpus.\n",
"if os.path.exists('model/tetragrams.pkl'):\n",
"    with open('model/tetragrams.pkl', 'rb') as f:\n",
"        tetragrams = pickle.load(f)\n",
"else:\n",
"    tetragrams = Counter(ngrams('train/in.tsv.xz', vocab_top_1000, 4))\n",
"    os.makedirs('model', exist_ok=True)  # cache dir may not exist yet\n",
"    with open('model/tetragrams.pkl', 'wb') as f:\n",
"        pickle.dump(tetragrams, f, protocol=pickle.HIGHEST_PROTOCOL)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from collections import defaultdict\n",
"import bisect\n",
"import itertools\n",
"\n",
"def probability(first, second=None, third=None, fourth=None):\n",
"    \"\"\"Conditional probability of the last given word from n-gram tables.\n",
"\n",
"    With one argument: unigram relative frequency P(first). With k\n",
"    arguments: P(last | preceding) estimated from the k-gram counts,\n",
"    returning 0 when the k-gram was never observed. Empty-string padding\n",
"    is falsy, so it counts as a missing argument.\n",
"    \"\"\"\n",
"    # Unigram\n",
"    if not second:\n",
"        return vocab_top_1000.get(first, 0) / sum(vocab_top_1000.values())\n",
"\n",
"    # Bigram\n",
"    bigram_key = (first, second)\n",
"    if bigram_key in bigrams:\n",
"        if not third:\n",
"            # Guard: the context word can be absent from the unigram\n",
"            # table (e.g. '' padding); previously this divided by zero.\n",
"            denom = vocab_top_1000.get(first, 0)\n",
"            return bigrams[bigram_key] / denom if denom else 0\n",
"\n",
"    # Trigram\n",
"    trigram_key = (first, second, third)\n",
"    if trigram_key in trigrams:\n",
"        if not fourth:\n",
"            denom = bigrams.get(bigram_key, 0)\n",
"            return trigrams[trigram_key] / denom if denom else 0\n",
"\n",
"    # Tetragram\n",
"    tetragram_key = (first, second, third, fourth)\n",
"    if tetragram_key in tetragrams:\n",
"        denom = trigrams.get(trigram_key, 0)\n",
"        return tetragrams[tetragram_key] / denom if denom else 0\n",
"\n",
"    # Key not found in any table\n",
"    return 0\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def interpolate(tetragram):\n",
"    \"\"\"Linearly interpolate n-gram probabilities for a 4-word window.\n",
"\n",
"    Falls back to shorter contexts when leading words are empty or\n",
"    missing; weights favour the longest available context.\n",
"    \"\"\"\n",
"    w1, w2, w3, w4 = tetragram\n",
"    if all((w1, w2, w3, w4)):\n",
"        return (0.4 * probability(w1, w2, w3, w4)\n",
"                + 0.3 * probability(w2, w3, w4)\n",
"                + 0.2 * probability(w3, w4)\n",
"                + 0.1 * probability(w4))\n",
"    if all((w1, w2, w3)):\n",
"        return (0.5 * probability(w1, w2, w3)\n",
"                + 0.3 * probability(w2, w3)\n",
"                + 0.2 * probability(w3))\n",
"    if w1 and w2:\n",
"        return 0.6 * probability(w1, w2) + 0.4 * probability(w2)\n",
"    return probability(w1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def consider_context(left_context, right_context):\n",
"    \"\"\"Score every vocabulary word as the gap-filler between contexts.\n",
"\n",
"    left_context / right_context must be 3-word sequences (use ''\n",
"    padding when fewer words are available). Returns the top 5\n",
"    candidates, probability-normalized, as a 'word:prob' string; the\n",
"    <unk> candidate (or, failing that, the weakest one) is rewritten as\n",
"    the '' wildcard expected by the scorer.\n",
"    \"\"\"\n",
"    first, second, third = left_context\n",
"    fifth, sixth, seventh = right_context\n",
"\n",
"    probs = []\n",
"    for word in vocab_top_1000:\n",
"        p1 = interpolate((first, second, third, word))\n",
"        p2 = interpolate((second, third, word, fifth))\n",
"        p3 = interpolate((third, word, fifth, sixth))\n",
"        p4 = interpolate((word, fifth, sixth, seventh))\n",
"        probs.append((word, p1 * p2 * p3 * p4))\n",
"\n",
"    probs = sorted(probs, key=lambda x: x[1], reverse=True)[:5]\n",
"    total_prob = sum(prob for _, prob in probs)\n",
"    if total_prob == 0:\n",
"        # Every candidate scored 0; previously this raised\n",
"        # ZeroDivisionError. Give all mass to the wildcard instead.\n",
"        return ':1'\n",
"\n",
"    norm = [(word, prob / total_prob) for word, prob in probs]\n",
"    for index, elem in enumerate(norm):\n",
"        if elem[0] == '<unk>':\n",
"            # Replace <unk> with the '' everything-else candidate.\n",
"            norm.pop(index)\n",
"            norm.append(('', elem[1]))\n",
"            break\n",
"    else:\n",
"        # <unk> not in the top 5: sacrifice the weakest candidate.\n",
"        norm[-1] = ('', norm[-1][1])\n",
"\n",
"    return ' '.join([f'{x[0]}:{x[1]}' for x in norm])\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def execute(path):\n",
"    \"\"\"Write out.tsv predictions for every row of {path}/in.tsv.xz.\n",
"\n",
"    Takes up to 3 words on each side of the gap, maps out-of-vocabulary\n",
"    words to '<unk>', pads short contexts with '', and writes the scored\n",
"    candidate line produced by consider_context.\n",
"    \"\"\"\n",
"    with lzma.open(f'{path}/in.tsv.xz', 'rt', encoding='utf-8') as f, \\\n",
"            open(f'{path}/out.tsv', 'w', encoding='utf-8') as out:\n",
"        for line in f:\n",
"            prefix, suffix = line.split('\\t')[6:8]\n",
"            prefix = prefix.replace(r'\\n', ' ').split()[-3:]\n",
"            suffix = suffix.replace(r'\\n', ' ').split()[:3]\n",
"            # BUGFIX: vocab_top_1000.get(x, '<unk>') returned the word's\n",
"            # stored *count* (an int) for in-vocabulary words; keep the\n",
"            # word itself and substitute '<unk>' only when missing.\n",
"            left = [x if x in vocab_top_1000 else '<unk>' for x in prefix]\n",
"            right = [x if x in vocab_top_1000 else '<unk>' for x in suffix]\n",
"            # Pad to exactly 3 per side so consider_context can unpack\n",
"            # even when a row has fewer context words.\n",
"            left = [''] * (3 - len(left)) + left\n",
"            right = right + [''] * (3 - len(right))\n",
"            result = consider_context(left, right)\n",
"            out.write(f\"{result}\\n\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Produce predictions for both evaluation splits.\n",
"execute('dev-0')\n",
"execute('test-A')"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.8"
|
|
},
|
|
"orig_nbformat": 4
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|