{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2a4fb731",
   "metadata": {},
   "source": [
    "MODEL TRIGRAMOWY - uwzględniamy dwa poprzednie słowa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c16d72a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import lzma\n",
    "import csv\n",
    "import re\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a1ff03c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_data(folder_name, test_data=False):\n",
    "    \"\"\"Load the gap contexts and, for labelled splits, the expected words.\n",
    "\n",
    "    Reads `{folder_name}/in.tsv.xz` and keeps columns 6 and 7 of every row,\n",
    "    which this notebook uses as the text before and after the gap. Literal\n",
    "    backslash-n sequences inside those columns are replaced with spaces.\n",
    "\n",
    "    Args:\n",
    "        folder_name: Directory containing `in.tsv.xz` (and `expected.tsv`).\n",
    "        test_data: When True, `expected.tsv` is not read.\n",
    "\n",
    "    Returns:\n",
    "        `data` when `test_data` is True, otherwise `(data, words)`, where\n",
    "        `data` is a list of `[left_context, right_context]` pairs and `words`\n",
    "        holds the first column of each `expected.tsv` row.\n",
    "    \"\"\"\n",
    "    # Close the archive deterministically instead of leaking the handle.\n",
    "    with lzma.open(f'{folder_name}/in.tsv.xz') as archive:\n",
    "        lines = archive.read().decode('UTF-8').split('\\n')\n",
    "    # Drop only the empty piece produced by a trailing newline, so a file\n",
    "    # without one does not silently lose its last row.\n",
    "    if lines and lines[-1] == '':\n",
    "        lines.pop()\n",
    "\n",
    "    rows = [line.split('\\t') for line in lines]\n",
    "    data = [[row[6].replace('\\\\n', ' '), row[7].replace('\\\\n', ' ')] for row in rows]\n",
    "\n",
    "    if test_data:\n",
    "        return data\n",
    "\n",
    "    # QUOTE_NONE: the file is raw TSV, so a word starting with a double quote\n",
    "    # must not be parsed as a quoted field spanning several lines.\n",
    "    with open(f'{folder_name}/expected.tsv', encoding='utf-8', newline='') as file:\n",
    "        reader = csv.reader(file, delimiter='\\t', quoting=csv.QUOTE_NONE)\n",
    "        # An empty row still yields an entry so words stay aligned with data.\n",
    "        words = [row[0] if row else '' for row in reader]\n",
    "\n",
    "    return data, words\n",
    "\n",
    "train_data, train_words = read_data('train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a4a73c19",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_example(data, words, idx):\n",
    "    \"\"\"Print example `idx` with its expected word upper-cased inside the gap.\"\"\"\n",
    "    print(f'{data[idx][0]} _____{words[idx].upper()}_____ {data[idx][1]}')\n",
    "\n",
    "# print_example(train_data, train_words, 13)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ce522af5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize(text, no_punctuation=True):\n",
    "    \"\"\"Lower-case `text` and split it into whitespace-separated tokens.\n",
    "\n",
    "    A hyphen followed by a space is removed first, which glues the two\n",
    "    surrounding pieces together. When `no_punctuation` is True, the\n",
    "    characters `) ( . , -` are replaced with spaces before splitting.\n",
    "    \"\"\"\n",
    "    text = re.sub(r'[\\-] ', '', text).lower()\n",
    "    if no_punctuation:\n",
    "        text = re.sub(r'[\\)\\(\\.\\,\\-]', ' ', text)\n",
    "    return text.split()\n",
    "\n",
    "\n",
    "def generate_N_grams(text, ngram=1, no_punctuation=True):\n",
    "    \"\"\"Return every run of `ngram` consecutive tokens of `text`, space-joined.\n",
    "\n",
    "    Args:\n",
    "        text: Raw text; it is tokenized with `tokenize`.\n",
    "        ngram: Number of tokens per n-gram.\n",
    "        no_punctuation: Forwarded to `tokenize`.\n",
    "\n",
    "    Returns:\n",
    "        List of n-gram strings in order of appearance; empty when the text\n",
    "        has fewer than `ngram` tokens.\n",
    "    \"\"\"\n",
    "    tokens = tokenize(text, no_punctuation)\n",
    "    # Zipping `ngram` staggered views of the token list yields sliding windows.\n",
    "    windows = zip(*(tokens[offset:] for offset in range(ngram)))\n",
    "    return [' '.join(window) for window in windows]\n",
    "\n",
    "# Only the first TRAIN_LIMIT training examples are used to build the model.\n",
    "TRAIN_LIMIT = 5000\n",
    "\n",
    "N_grams = []\n",
    "for (left_context, right_context), gap_word in zip(train_data[:TRAIN_LIMIT], train_words):\n",
    "    sentence = f'{left_context} {gap_word} {right_context}'\n",
    "    N_grams += generate_N_grams(sentence, 2)\n",
    "    N_grams += generate_N_grams(sentence, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "317ade72",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def check_prob(N_grams):\n",
    "    \"\"\"Estimate next-token probabilities for every context seen in `N_grams`.\n",
    "\n",
    "    Each n-gram is split at its last whitespace into `(context, next_token)`.\n",
    "\n",
    "    Args:\n",
    "        N_grams: Iterable of space-joined n-grams. Entries with fewer than\n",
    "            two tokens are skipped because they carry no context.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping each context to a dict `{next_token: probability}`;\n",
    "        the probabilities of one context are its counts divided by their sum.\n",
    "    \"\"\"\n",
    "    count = {}\n",
    "    for gram in N_grams:\n",
    "        parts = gram.rsplit(maxsplit=1)\n",
    "        if len(parts) < 2:\n",
    "            continue\n",
    "        context, next_token = parts\n",
    "        followers = count.setdefault(context, {})\n",
    "        followers[next_token] = followers.get(next_token, 0) + 1\n",
    "\n",
    "    # Turn raw counts into relative frequencies per context.\n",
    "    for followers in count.values():\n",
    "        total = sum(followers.values())\n",
    "        for next_token in followers:\n",
    "            followers[next_token] /= total\n",
    "\n",
    "    return count\n",
    "\n",
    "probs = check_prob(N_grams)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3a7ec4ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "dev_data, dev_words = read_data('dev-0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "86aeda02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoothing constants used by find_word.\n",
    "UNSEEN_DISCOUNT = 10       # divisor applied to the rarest two-word-context probability\n",
    "CERTAIN_SCORE = 0.1        # replaces a combined score of exactly 1\n",
    "MIN_REMAINING_MASS = 0.01  # reported for the wildcard when nothing is left\n",
    "\n",
    "\n",
    "def find_word(word_1, word_2):\n",
    "    \"\"\"Predict the gap word and format it as `word:score :remaining`.\n",
    "\n",
    "    Both contexts are looked up in the module-level `probs` table.\n",
    "\n",
    "    Args:\n",
    "        word_1: Last token before the gap (one-word context).\n",
    "        word_2: Last two tokens before the gap joined by a space (two-word\n",
    "            context). Any value absent from `probs` makes the function rely\n",
    "            on `word_1` alone.\n",
    "\n",
    "    Returns:\n",
    "        `':1'` when `word_1` has no entry in `probs`; otherwise the\n",
    "        best-scoring candidate, its score, and `1 - score` (or\n",
    "        MIN_REMAINING_MASS when that is 0) assigned to the empty wildcard.\n",
    "    \"\"\"\n",
    "    one_word_dist = probs.get(word_1)\n",
    "    if not one_word_dist:\n",
    "        return ':1'\n",
    "\n",
    "    two_word_dist = probs.get(word_2)\n",
    "    if not two_word_dist:\n",
    "        # Unknown two-word context: fall back to the one-word distribution.\n",
    "        candidates = one_word_dist\n",
    "    else:\n",
    "        # Loop-invariant, so it is computed once rather than per candidate.\n",
    "        unseen_factor = min(two_word_dist.values()) / UNSEEN_DISCOUNT\n",
    "        candidates = {}\n",
    "        for word, one_word_prob in one_word_dist.items():\n",
    "            if word in two_word_dist:\n",
    "                score = one_word_prob * two_word_dist[word]\n",
    "                if score == 1:\n",
    "                    score = CERTAIN_SCORE\n",
    "            else:\n",
    "                score = one_word_prob * unseen_factor\n",
    "            candidates[word] = score\n",
    "\n",
    "    # max() returns the first best entry, matching a stable descending sort.\n",
    "    best_word, best_score = max(candidates.items(), key=lambda item: item[1])\n",
    "    remaining = 1 - best_score\n",
    "    if remaining == 0:\n",
    "        remaining = MIN_REMAINING_MASS\n",
    "    return f'{best_word}:{best_score} :{remaining}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3b713dc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_words(data):\n",
    "    \"\"\"Predict a gap word for every row of `data`.\n",
    "\n",
    "    Only the left context (`row[0]`) is used. It is tokenized exactly like the\n",
    "    training text, and the last one and last two tokens become the lookup\n",
    "    contexts passed to `find_word`.\n",
    "\n",
    "    Returns:\n",
    "        List with one formatted prediction string per row.\n",
    "    \"\"\"\n",
    "    found_words = []\n",
    "    for row in data:\n",
    "        tokens = tokenize(row[0])\n",
    "        if not tokens:\n",
    "            # Nothing precedes the gap: emit the wildcard-only answer instead\n",
    "            # of failing on tokens[-1].\n",
    "            found_words.append(':1')\n",
    "            continue\n",
    "        # With a single token there is no two-word context; None is never a\n",
    "        # key of `probs`, so find_word uses the one-word distribution as is.\n",
    "        two_word_context = ' '.join(tokens[-2:]) if len(tokens) >= 2 else None\n",
    "        found_words.append(find_word(tokens[-1], two_word_context))\n",
    "    return found_words\n",
    "\n",
    "dev_found_words = find_words(dev_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "17be7468",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_data(folder, words):\n",
    "    \"\"\"Write one entry of `words` per line to `{folder}/out.tsv`.\"\"\"\n",
    "    # The context manager closes the file even if the write fails.\n",
    "    with open(f'{folder}/out.tsv', 'w', encoding='utf-8') as f:\n",
    "        f.write('\\n'.join(words) + '\\n')\n",
    "\n",
    "save_data('dev-0', dev_found_words)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b2e52242",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_data = read_data('test-A', True)\n",
    "test_found_words = find_words(test_data)\n",
    "save_data('test-A', test_found_words)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}