372 lines
12 KiB
Plaintext
372 lines
12 KiB
Plaintext
|
{
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 0,
|
||
|
"metadata": {
|
||
|
"colab": {
|
||
|
"provenance": [],
|
||
|
"gpuType": "T4",
|
||
|
"machine_shape": "hm"
|
||
|
},
|
||
|
"kernelspec": {
|
||
|
"name": "python3",
|
||
|
"display_name": "Python 3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"name": "python"
|
||
|
},
|
||
|
"accelerator": "GPU"
|
||
|
},
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"metadata": {
|
||
|
"colab": {
|
||
|
"base_uri": "https://localhost:8080/"
|
||
|
},
|
||
|
"id": "3dV_4SJ2xY_C",
|
||
|
"outputId": "28f0b228-e536-410e-8d45-a2063a04455b"
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"output_type": "stream",
|
||
|
"name": "stdout",
|
||
|
"text": [
|
||
|
"Mounted at /content/gdrive\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
# Colab-only setup: mount the user's Google Drive under /content/gdrive so the
# dataset folder referenced by DATA_DIR becomes readable from this runtime.
from google.colab import drive
drive.mount("/content/gdrive")
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"source": [
|
||
# Root folder of the dataset on the mounted Drive. The trailing slash is
# required: the rest of the notebook builds paths with plain `+` concatenation
# (e.g. DATA_DIR+'train/in.tsv.xz').
# %env DATA_DIR=/content/gdrive/MyDrive/data_gralinski
DATA_DIR="/content/gdrive/MyDrive/data_gralinski/"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"id": "VwdW1Qm3x9-N"
|
||
|
},
|
||
|
"execution_count": 2,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"source": [
|
||
|
"import lzma\n",
|
||
|
"import pickle\n",
|
||
|
"from collections import Counter"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"id": "irsty5KcyYkR"
|
||
|
},
|
||
|
"execution_count": 3,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"source": [
|
||
def clean_line(line: str) -> str:
    r"""Return the two context columns of one TSV record joined into one string.

    The record is split on tabs; column 7 (index 6) and column 8 (index 7)
    are taken as the left and right context. Inside each of them the literal
    two-character sequence ``\n`` (backslash followed by ``n``) is replaced
    by a space, and the two pieces are joined with a single space.

    The record's own line terminator is removed before splitting so that it
    does not end up glued to the end of the last column.

    Args:
        line: One raw record, with or without its trailing newline.

    Returns:
        ``"<left context> <right context>"``.

    Raises:
        IndexError: If the record has fewer than eight tab-separated fields.
    """
    fields = line.rstrip('\r\n').split('\t')
    prefix = fields[6].replace(r'\n', ' ')
    suffix = fields[7].replace(r'\n', ' ')
    return prefix + ' ' + suffix
|
||
|
],
|
||
|
"metadata": {
|
||
|
"id": "LXXtiKW3yY5J"
|
||
|
},
|
||
|
"execution_count": 4,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"source": [
|
||
def words(filename, total_lines=432022):
    """Yield every whitespace-separated token found in a corpus file.

    Each row of the xz-compressed TSV file is reduced to its context text by
    ``clean_line`` and then split on whitespace. A single-line progress
    indicator is rewritten on stdout for every row.

    Args:
        filename: Path to the ``.xz`` compressed, UTF-8 encoded TSV file.
        total_lines: Number of rows the progress percentage is scaled by.
            Must be non-zero. Defaults to 432022, the value that used to be
            hard-coded here.

    Yields:
        Tokens in file order.
    """
    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
        print('Words')
        for index, line in enumerate(fid, start=1):
            # '.2f' = two decimals. The previous ':2f' only set a minimum
            # field width of 2, so six decimals were printed.
            print(f'\rProgress: {index / total_lines * 100:.2f}%', end='')
            yield from clean_line(line).split()
        # Terminate the carriage-return progress line.
        print()
|
||
|
],
|
||
|
"metadata": {
|
||
|
"id": "y9r0wmD3ycIi"
|
||
|
},
|
||
|
"execution_count": 5,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"source": [
|
||
def bigrams(filename, V: dict, total_lines=432022):
    """Yield consecutive token pairs from every row of a corpus file.

    Tokens for which ``V.get(token)`` is ``None`` are replaced by ``'UNK'``
    before pairing. Pairs never span two rows, and the first token of a row
    only appears as the left element of the first pair.

    The original guard tested the *current* token, which is never empty, so
    an ``('', first_token)`` pair was emitted for each row. The n-gram is now
    emitted only once a real left neighbour exists, matching ``trigrams`` and
    ``tetragrams``.

    Args:
        filename: Path to the ``.xz`` compressed, UTF-8 encoded TSV file.
        V: Vocabulary mapping; only key presence (non-``None`` value) is used.
        total_lines: Number of rows the progress percentage is scaled by.
            Must be non-zero. Defaults to the previously hard-coded 432022.

    Yields:
        ``(first_word, second_word)`` tuples.
    """
    size = 2
    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
        print('Bigrams')
        for index, line in enumerate(fid, start=1):
            # '.2f' = two decimals (':2f' was a width and printed six).
            print(f'\rProgress: {index / total_lines * 100:.2f}%', end='')
            window = ()
            for token in clean_line(line).split():
                if V.get(token) is None:
                    token = 'UNK'
                # Slide the window and emit only when it is completely filled.
                window = (window + (token,))[-size:]
                if len(window) == size:
                    yield window
        print()
|
||
|
],
|
||
|
"metadata": {
|
||
|
"id": "HE3YfiHkycKt"
|
||
|
},
|
||
|
"execution_count": 6,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"source": [
|
||
def trigrams(filename, V: dict, total_lines=432022):
    """Yield consecutive token triples from every row of a corpus file.

    Tokens for which ``V.get(token)`` is ``None`` are replaced by ``'UNK'``.
    Triples never span two rows; a row with fewer than three tokens yields
    nothing.

    Args:
        filename: Path to the ``.xz`` compressed, UTF-8 encoded TSV file.
        V: Vocabulary mapping; only key presence (non-``None`` value) is used.
        total_lines: Number of rows the progress percentage is scaled by.
            Must be non-zero. Defaults to the previously hard-coded 432022.

    Yields:
        ``(first_word, second_word, third_word)`` tuples.
    """
    size = 3
    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
        print('Trigrams')
        for index, line in enumerate(fid, start=1):
            # '.2f' = two decimals. The previous ':2f' only set a minimum
            # field width of 2, so six decimals were printed.
            print(f'\rProgress: {index / total_lines * 100:.2f}%', end='')
            window = ()
            for token in clean_line(line).split():
                if V.get(token) is None:
                    token = 'UNK'
                # Slide the window and emit only when it is completely filled.
                window = (window + (token,))[-size:]
                if len(window) == size:
                    yield window
        print()
|
||
|
],
|
||
|
"metadata": {
|
||
|
"id": "lvHvJV6XycNZ"
|
||
|
},
|
||
|
"execution_count": 7,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"source": [
|
||
def tetragrams(filename, V: dict, total_lines=432022):
    """Yield consecutive token quadruples from every row of a corpus file.

    Tokens for which ``V.get(token)`` is ``None`` are replaced by ``'UNK'``.
    Quadruples never span two rows; a row with fewer than four tokens yields
    nothing.

    Args:
        filename: Path to the ``.xz`` compressed, UTF-8 encoded TSV file.
        V: Vocabulary mapping; only key presence (non-``None`` value) is used.
        total_lines: Number of rows the progress percentage is scaled by.
            Must be non-zero. Defaults to the previously hard-coded 432022.

    Yields:
        ``(first_word, second_word, third_word, fourth_word)`` tuples.
    """
    size = 4
    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
        print('Tetragrams')
        for index, line in enumerate(fid, start=1):
            # '.2f' = two decimals. The previous ':2f' only set a minimum
            # field width of 2, so six decimals were printed.
            print(f'\rProgress: {index / total_lines * 100:.2f}%', end='')
            window = ()
            for token in clean_line(line).split():
                if V.get(token) is None:
                    token = 'UNK'
                # Slide the window and emit only when it is completely filled.
                window = (window + (token,))[-size:]
                if len(window) == size:
                    yield window
        print()
|
||
|
],
|
||
|
"metadata": {
|
||
|
"id": "sOKeZN9cycP-"
|
||
|
},
|
||
|
"execution_count": 8,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"source": [
|
||
def P(first_word, second_word=None, third_word=None, fourth_word=None):
    """Return the relative frequency of the last given word after the others.

    The number of words supplied selects the n-gram order:

    * one word    -> ``V_common_dict[w1] / total``
    * two words   -> ``V2_dict[(w1, w2)] / V_common_dict[w1]``
    * three words -> ``V3_dict[(w1, w2, w3)] / V2_dict[(w1, w2)]``
    * four words  -> ``V4_dict[(w1, w2, w3, w4)] / V3_dict[(w1, w2, w3)]``

    All tables and ``total`` are module-level globals. If either the n-gram
    or its prefix is missing from its table, 0 is returned.
    """
    try:
        if second_word is None:
            numerator, denominator = V_common_dict[first_word], total
        elif third_word is None:
            numerator = V2_dict[(first_word, second_word)]
            denominator = V_common_dict[first_word]
        elif fourth_word is None:
            numerator = V3_dict[(first_word, second_word, third_word)]
            denominator = V2_dict[(first_word, second_word)]
        else:
            numerator = V4_dict[(first_word, second_word, third_word, fourth_word)]
            denominator = V3_dict[(first_word, second_word, third_word)]
    except KeyError:
        # Unseen n-gram or unseen history.
        return 0
    return numerator / denominator
|
||
|
],
|
||
|
"metadata": {
|
||
|
"id": "MN_RftZNycSB"
|
||
|
},
|
||
|
"execution_count": 9,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"source": [
|
||
def smoothed(tetragram):
    """Score the last word of a 4-tuple by a fixed-weight mix of n-gram orders.

    The result is ``0.5 * P(w1, w2, w3, w4) + 0.25 * P(w2, w3, w4)
    + 0.15 * P(w3, w4) + 0.1 * P(w4)``, i.e. every estimate is for the same
    final word with a progressively shorter history.

    Args:
        tetragram: Iterable of exactly four words.
    """
    w1, w2, w3, w4 = tetragram
    tetragram_term = 0.5 * P(w1, w2, w3, w4)
    trigram_term = 0.25 * P(w2, w3, w4)
    bigram_term = 0.15 * P(w3, w4)
    unigram_term = 0.1 * P(w4)
    # Summed left to right, highest order first.
    return tetragram_term + trigram_term + bigram_term + unigram_term
|
||
|
],
|
||
|
"metadata": {
|
||
|
"id": "n9wIsbLEycUd"
|
||
|
},
|
||
|
"execution_count": 10,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"source": [
|
||
def candidates(left_context, right_context):
    """Rank vocabulary words for the gap between two three-word contexts.

    Every key of the global ``V_common_dict`` is scored by the product of the
    four ``smoothed`` values of the 4-word windows that contain the gap. The
    five best scores are normalised to sum to 1 and rendered as
    ``"word:prob word:prob ... :prob"``. The final entry has an empty word:
    it takes the probability of ``'UNK'`` when ``'UNK'`` is among the top
    five, otherwise the probability of the fifth-ranked word (whose label is
    dropped).

    Args:
        left_context: Exactly three words preceding the gap.
        right_context: Exactly three words following the gap.

    Returns:
        The formatted candidate string. When there is nothing to rank (empty
        vocabulary) or every top score is zero, ``":1.0"`` is returned so the
        whole mass goes to the empty-word entry; the previous code raised
        ``NameError`` / ``ZeroDivisionError`` in those cases.

    Raises:
        ValueError: If either context does not hold exactly three items.
    """
    first, second, third = left_context
    fifth, sixth, seventh = right_context

    scores = {}
    for word in V_common_dict:
        p1 = smoothed((first, second, third, word))
        p2 = smoothed((second, third, word, fifth))
        p3 = smoothed((third, word, fifth, sixth))
        p4 = smoothed((word, fifth, sixth, seventh))
        scores[word] = p1 * p2 * p3 * p4

    top = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:5]
    # Computed once; also guards the division below.
    mass = sum(score for _, score in top)
    if not mass:
        return ':1.0'

    norm = [(word, float(score) / mass) for word, score in top]
    for position, (word, _) in enumerate(norm):
        if word == 'UNK':
            # Move the UNK mass to the trailing empty-word entry. Mutating
            # the list is safe because the loop stops immediately.
            norm.append(('', norm.pop(position)[1]))
            break
    else:
        # No UNK among the top entries: the last one becomes the empty-word entry.
        norm[-1] = ('', norm[-1][1])
    return ' '.join(f'{word}:{prob}' for word, prob in norm)
|
||
|
],
|
||
|
"metadata": {
|
||
|
"id": "l490B5KFycXj"
|
||
|
},
|
||
|
"execution_count": 11,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"source": [
|
||
def outputs(folder_name):
    r"""Write one line of gap predictions per input row of a data folder.

    Reads ``<folder_name>/in.tsv.xz`` and writes ``<folder_name>/out.tsv``.
    For every row, the last three tokens of column 7 (index 6) and the first
    three tokens of column 8 (index 7) are used as left and right context,
    after replacing literal ``\n`` sequences with spaces. Tokens that have no
    truthy entry in the global ``V_common_dict`` are mapped to ``'UNK'``.

    Contexts shorter than three tokens are padded with ``'UNK'`` on the outer
    side. Previously such rows made ``candidates`` fail while unpacking,
    which aborted the run and left a truncated ``out.tsv``.
    NOTE(review): ``'UNK'`` is used as the padding token for lack of a
    dedicated boundary symbol in the n-gram tables - confirm this is the
    desired fallback.

    Args:
        folder_name: Folder containing ``in.tsv.xz``; no trailing slash.

    Raises:
        IndexError: If a row has fewer than eight tab-separated fields.
    """
    width = 3
    print(f'Creating outputs in {folder_name}')
    with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid, \
            open(f'{folder_name}/out.tsv', 'w', encoding='utf-8') as f:
        for line in fid:
            separated = line.split('\t')
            prefix = separated[6].replace(r'\n', ' ').split()
            suffix = separated[7].replace(r'\n', ' ').split()
            left_context = [x if V_common_dict.get(x) else 'UNK' for x in prefix[-width:]]
            right_context = [x if V_common_dict.get(x) else 'UNK' for x in suffix[:width]]
            # Pad away from the gap so the words nearest to it keep their position.
            left_context = ['UNK'] * (width - len(left_context)) + left_context
            right_context = right_context + ['UNK'] * (width - len(right_context))
            f.write(candidates(left_context, right_context) + '\n')
|
||
|
],
|
||
|
"metadata": {
|
||
|
"id": "mMC84-OzycZ5"
|
||
|
},
|
||
|
"execution_count": 12,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"source": [
|
||
# Number of most frequent tokens kept when the vocabulary is counted from the corpus.
WORD_LIMIT = 3000
# NOTE(review): this is a full pass over the training corpus, but
# V_common_dict is rebound by the pickle load a few lines below, so the
# freshly counted vocabulary is discarded; V itself is only referenced by the
# disabled recipe that follows.
V = Counter(words(DATA_DIR+'train/in.tsv.xz'))
V_common_dict = dict(V.most_common(WORD_LIMIT))
# Disabled recipe: add an 'UNK' entry holding the summed counts of every
# token outside the top WORD_LIMIT, then cache the vocabulary.
# UNK = 0
# for key, value in V.items():
#     if V_common_dict.get(key) is None:
#         UNK += value
# V_common_dict['UNK'] = UNK
# with open('V.pickle', 'wb') as handle:
#     pickle.dump(V_common_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)


# Load the cached vocabulary (token -> count).
# NOTE(review): pickle.load can execute arbitrary code; only load files you
# produced yourself.
with open(DATA_DIR+'5/V.pickle', 'rb') as handle:
    V_common_dict = pickle.load(handle)

# Sum of all vocabulary counts; denominator of the unigram estimate in P().
total = sum(V_common_dict.values())

# Disabled recipe for the bigram table. Note that it reads the corpus and
# writes the pickle relative to the working directory, not under DATA_DIR.
# V2 = Counter(bigrams('train/in.tsv.xz', V_common_dict))
# V2_dict = dict(V2)
# with open('V2.pickle', 'wb') as handle:
#     pickle.dump(V2_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

# Load the cached bigram table, keyed by (w1, w2) tuples.
with open(DATA_DIR+'5/V2.pickle', 'rb') as handle:
    V2_dict = pickle.load(handle)

# Disabled recipe for the trigram table (same path caveat as above).
# V3 = Counter(trigrams('train/in.tsv.xz', V_common_dict))
# V3_dict = dict(V3)
# with open('V3.pickle', 'wb') as handle:
#     pickle.dump(V3_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

# Load the cached trigram table, keyed by (w1, w2, w3) tuples.
with open(DATA_DIR+'5/V3.pickle', 'rb') as handle:
    V3_dict = pickle.load(handle)

# The 4-gram table is the only one rebuilt from the corpus on every run.
# NOTE(review): it is saved as 'V4.pickle' in the working directory, unlike
# the other tables, which are read from DATA_DIR+'5/' - confirm that is intended.
V4 = Counter(tetragrams(DATA_DIR+'train/in.tsv.xz', V_common_dict))
V4_dict = dict(V4)
with open('V4.pickle', 'wb') as handle:
    pickle.dump(V4_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

# Disabled alternative: reuse the 4-gram table saved by a previous run.
# with open('V4.pickle', 'rb') as handle:
#     V4_dict = pickle.load(handle)
|
||
|
"\n",
|
||
|
"\n"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"id": "Fsvv3QJl7kWN",
|
||
|
"colab": {
|
||
|
"base_uri": "https://localhost:8080/"
|
||
|
},
|
||
|
"outputId": "3c8387a4-5ebe-41ae-aafa-2bf3999f0025"
|
||
|
},
|
||
|
"execution_count": 15,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"output_type": "stream",
|
||
|
"name": "stdout",
|
||
|
"text": [
|
||
|
"Words\n",
|
||
|
"Progress: 100.000000%\n",
|
||
|
"Tetragrams\n",
|
||
|
"Progress: 100.000000%\n"
|
||
|
]
|
||
|
}
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"source": [
|
||
# Generate <folder>/out.tsv for the dev-0 and test-A folders. DATA_DIR already
# ends with '/', so plain concatenation yields the folder path outputs() expects.
outputs(DATA_DIR+'dev-0')
outputs(DATA_DIR+'test-A')
|
||
|
],
|
||
|
"metadata": {
|
||
|
"colab": {
|
||
|
"base_uri": "https://localhost:8080/"
|
||
|
},
|
||
|
"id": "UK73WsKnB8ZP",
|
||
|
"outputId": "2c3d6171-bfb2-4fcd-cfc0-49cabdf9f0a9"
|
||
|
},
|
||
|
"execution_count": 17,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"output_type": "stream",
|
||
|
"name": "stdout",
|
||
|
"text": [
|
||
|
"Creating outputs in /content/gdrive/MyDrive/data_gralinski/dev-0\n",
|
||
|
"Creating outputs in /content/gdrive/MyDrive/data_gralinski/test-A\n"
|
||
|
]
|
||
|
}
|
||
|
]
|
||
|
}
|
||
|
]
|
||
|
}
|