diff --git a/cw/05_NDA_IE.ipynb b/cw/05_NDA_IE.ipynb index 592eb82..99a26b3 100644 --- a/cw/05_NDA_IE.ipynb +++ b/cw/05_NDA_IE.ipynb @@ -210,13 +210,6 @@ "\n", "Termin 5 maj 2021 (proszę w MS TEAMS podać link do repozytorium albo publicznego albo z dostępem dla kubapok i filipg na git.wmi)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/cw/06_klasyfikacja.ipynb b/cw/06_klasyfikacja.ipynb new file mode 100644 index 0000000..d682e7d --- /dev/null +++ b/cw/06_klasyfikacja.ipynb @@ -0,0 +1,964 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Zajęcia klasyfikacja" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Zbiór kleister" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import pathlib\n", + "from collections import Counter\n", + "from sklearn.metrics import *" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "KLEISTER_PATH = pathlib.Path('/home/kuba/Syncthing/przedmioty/2020-02/IE/applica/kleister-nda')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Pytanie\n", + "\n", + "Czy jurysdykcja musi być zapisana explicite w umowie?" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def get_expected_jurisdiction(filepath):\n", + " dataset_expected_jurisdiction = []\n", + " with open(filepath,'r') as train_expected_file:\n", + " for line in train_expected_file:\n", + " key_values = line.rstrip('\\n').split(' ')\n", + " jurisdiction = None\n", + " for key_value in key_values:\n", + " key, value = key_value.split('=')\n", + " if key == 'jurisdiction':\n", + " jurisdiction = value\n", + " if jurisdiction is None:\n", + " jurisdiction = 'NONE'\n", + " dataset_expected_jurisdiction.append(jurisdiction)\n", + " return dataset_expected_jurisdiction" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "train_expected_jurisdiction = get_expected_jurisdiction(KLEISTER_PATH/'train'/'expected.tsv')" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "dev_expected_jurisdiction = get_expected_jurisdiction(KLEISTER_PATH/'dev-0'/'expected.tsv')" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "254" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(train_expected_jurisdiction)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'NONE' in train_expected_jurisdiction" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "31" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(set(train_expected_jurisdiction))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Czy wszystkie stany muszą występować w zbiorze trenującym w zbiorze kleister?\n", + "\n", + "https://en.wikipedia.org/wiki/U.S._state\n", + "\n", + "### Jaki jest 
baseline?" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "train_counter = Counter(train_expected_jurisdiction)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('New_York', 43),\n", + " ('Delaware', 39),\n", + " ('California', 32),\n", + " ('Massachusetts', 15),\n", + " ('Texas', 13),\n", + " ('Illinois', 10),\n", + " ('Oregon', 9),\n", + " ('Florida', 9),\n", + " ('Pennsylvania', 9),\n", + " ('Missouri', 9),\n", + " ('Ohio', 8),\n", + " ('New_Jersey', 7),\n", + " ('Georgia', 6),\n", + " ('Indiana', 5),\n", + " ('Nevada', 5),\n", + " ('Colorado', 4),\n", + " ('Virginia', 4),\n", + " ('Washington', 4),\n", + " ('Michigan', 3),\n", + " ('Minnesota', 3),\n", + " ('Connecticut', 2),\n", + " ('Wisconsin', 2),\n", + " ('Maine', 2),\n", + " ('North_Carolina', 2),\n", + " ('Kansas', 2),\n", + " ('Utah', 2),\n", + " ('Iowa', 1),\n", + " ('Idaho', 1),\n", + " ('South_Dakota', 1),\n", + " ('South_Carolina', 1),\n", + " ('Rhode_Island', 1)]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_counter.most_common(100)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "most_common_answer = train_counter.most_common(100)[0][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'New_York'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "most_common_answer" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "dev_predictions_jurisdiction = [most_common_answer] * len(dev_expected_jurisdiction)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['New_York',\n", + " 'New_York',\n", + " 'Delaware',\n", + " 'Massachusetts',\n", + " 'Delaware',\n", + " 'Washington',\n", + " 'Delaware',\n", + " 'New_Jersey',\n", + " 'New_York',\n", + " 'NONE',\n", + " 'NONE',\n", + " 'Delaware',\n", + " 'Delaware',\n", + " 'Delaware',\n", + " 'New_York',\n", + " 'Massachusetts',\n", + " 'Minnesota',\n", + " 'California',\n", + " 'New_York',\n", + " 'California',\n", + " 'Iowa',\n", + " 'California',\n", + " 'Virginia',\n", + " 'North_Carolina',\n", + " 'Arizona',\n", + " 'Indiana',\n", + " 'New_Jersey',\n", + " 'California',\n", + " 'Delaware',\n", + " 'Georgia',\n", + " 'New_York',\n", + " 'New_York',\n", + " 'California',\n", + " 'Minnesota',\n", + " 'California',\n", + " 'Kentucky',\n", + " 'Minnesota',\n", + " 'Ohio',\n", + " 'Michigan',\n", + " 'California',\n", + " 'Minnesota',\n", + " 'California',\n", + " 'Delaware',\n", + " 'Illinois',\n", + " 'Minnesota',\n", + " 'Texas',\n", + " 'New_Jersey',\n", + " 'Delaware',\n", + " 'Washington',\n", + " 'NONE',\n", + " 'Delaware',\n", + " 'Oregon',\n", + " 'Delaware',\n", + " 'Delaware',\n", + " 'Delaware',\n", + " 'Massachusetts',\n", + " 'California',\n", + " 'NONE',\n", + " 'Delaware',\n", + " 'Illinois',\n", + " 'Idaho',\n", + " 'Washington',\n", + " 'New_York',\n", + " 'New_York',\n", + " 'California',\n", + " 'Utah',\n", + " 'Delaware',\n", + " 'Washington',\n", + " 'Virginia',\n", + " 'New_York',\n", + " 'New_York',\n", + " 'Illinois',\n", + " 'California',\n", + " 'Delaware',\n", + " 'NONE',\n", + 
" 'Texas',\n", + " 'California',\n", + " 'Washington',\n", + " 'Delaware',\n", + " 'Washington',\n", + " 'New_York',\n", + " 'Washington',\n", + " 'Illinois']" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dev_expected_jurisdiction" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "accuracy: 0.14457831325301204\n" + ] + } + ], + "source": [ + "counter = 0 \n", + "for pred, exp in zip(dev_predictions_jurisdiction, dev_expected_jurisdiction):\n", + " if pred == exp:\n", + " counter +=1\n", + "print('accuracy: ', counter/len(dev_predictions_jurisdiction))" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.14457831325301204" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "accuracy_score(dev_predictions_jurisdiction, dev_expected_jurisdiction)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Co jeżeli nazwy klas nie występują explicite w zbiorach?" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "https://git.wmi.amu.edu.pl/kubapok/paranormal-or-skeptic-ISI-public\n", + " \n", + "https://git.wmi.amu.edu.pl/kubapok/sport-text-classification-ball-ISI-public" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "SPORT_PATH='/home/kuba/Syncthing/przedmioty/2020-02/ISI/zajecia6_klasyfikacja/repos/sport-text-classification-ball'\n", + "\n", + "SPORT_TRAIN=$SPORT_PATH/train/train.tsv.gz\n", + " \n", + "SPORT_DEV_EXP=$SPORT_PATH/dev-0/expected.tsv" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### jaki jest baseline dla sport classification ball?\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "zcat $SPORT_TRAIN | awk '{print $1}' | wc -l" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "zcat $SPORT_TRAIN | awk '{print $1}' | grep 1 | wc -l" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "cat $SPORT_DEV_EXP | wc -l\n", + "\n", + "grep 1 $SPORT_DEV_EXP | wc -l" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Sprytne podejście do klasyfikacji tekstu? Naiwny bayess" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/kuba/anaconda3/lib/python3.8/site-packages/gensim/similarities/__init__.py:15: UserWarning: The gensim.similarities.levenshtein submodule is disabled, because the optional Levenshtein package is unavailable. Install Levenhstein (e.g. 
`pip install python-Levenshtein`) to suppress this warning.\n", + " warnings.warn(msg)\n" + ] + } + ], + "source": [ + "from sklearn.datasets import fetch_20newsgroups\n", + "# https://scikit-learn.org/0.19/datasets/twenty_newsgroups.html\n", + "\n", + "from sklearn.feature_extraction.text import TfidfVectorizer\n", + "import numpy as np\n", + "import sklearn.metrics\n", + "import gensim" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "newsgroups = fetch_20newsgroups()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "newsgroups_text = newsgroups['data']" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "newsgroups_text_tokenized = [list(set(gensim.utils.tokenize(x, lowercase = True))) for x in newsgroups_text]" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "From: lerxst@wam.umd.edu (where's my thing)\n", + "Subject: WHAT car is this!?\n", + "Nntp-Posting-Host: rac3.wam.umd.edu\n", + "Organization: University of Maryland, College Park\n", + "Lines: 15\n", + "\n", + " I was wondering if anyone out there could enlighten me on this car I saw\n", + "the other day. It was a 2-door sports car, looked to be from the late 60s/\n", + "early 70s. It was called a Bricklin. The doors were really small. In addition,\n", + "the front bumper was separate from the rest of the body. This is \n", + "all I know. If anyone can tellme a model name, engine specs, years\n", + "of production, where this car is made, history, or whatever info you\n", + "have on this funky looking car, please e-mail.\n", + "\n", + "Thanks,\n", + "- IL\n", + " ---- brought to you by your neighborhood Lerxst ----\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + } + ], + "source": [ + "print(newsgroups_text[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['where', 'name', 'looked', 'to', 'have', 'out', 'on', 'by', 'park', 'what', 'from', 'host', 'doors', 'day', 'be', 'organization', 'e', 'front', 'in', 'it', 'history', 'brought', 'know', 'addition', 'il', 'of', 'lines', 'i', 'your', 'bumper', 'there', 'please', 'me', 'separate', 'is', 'tellme', 'can', 'could', 'called', 'specs', 'college', 'this', 'thanks', 'looking', 'if', 'production', 'sports', 'lerxst', 'whatever', 'anyone', 'enlighten', 'saw', 'all', 'small', 'you', 'wam', 'mail', 'rest', 's', 'late', 'rac', 'funky', 'edu', 'info', 'the', 'wondering', 'years', 'door', 'posting', 'car', 'made', 'or', 'maryland', 'subject', 'bricklin', 'was', 'model', 'thing', 'university', 'engine', 'nntp', 'other', 'really', 'neighborhood', 'early', 'a', 'umd', 'my', 'body', 'were']\n" + ] + } + ], + "source": [ + "print(newsgroups_text_tokenized[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "Y = newsgroups['target']" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([7, 4, 4, ..., 3, 1, 8])" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Y" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "Y_names = 
newsgroups['target_names']" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['alt.atheism',\n", + " 'comp.graphics',\n", + " 'comp.os.ms-windows.misc',\n", + " 'comp.sys.ibm.pc.hardware',\n", + " 'comp.sys.mac.hardware',\n", + " 'comp.windows.x',\n", + " 'misc.forsale',\n", + " 'rec.autos',\n", + " 'rec.motorcycles',\n", + " 'rec.sport.baseball',\n", + " 'rec.sport.hockey',\n", + " 'sci.crypt',\n", + " 'sci.electronics',\n", + " 'sci.med',\n", + " 'sci.space',\n", + " 'soc.religion.christian',\n", + " 'talk.politics.guns',\n", + " 'talk.politics.mideast',\n", + " 'talk.politics.misc',\n", + " 'talk.religion.misc']" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Y_names" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'talk.politics.guns'" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Y_names[16]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "$P('talk.politics.guns' | 'gun')= ?$ \n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "$P(A|B) * P(A) = P(B) * P(B|A)$\n", + "\n", + "$P(A|B) = \\frac{P(B) * P(B|A)}{P(A)}$" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "$P('talk.politics.guns' | 'gun') * P('gun') = P('gun'|'talk.politics.guns') * P('talk.politics.guns')$\n", + "\n", + "\n", + "$P('talk.politics.guns' | 'gun') = \\frac{P('gun'|'talk.politics.guns') * P('talk.politics.guns')}{P('gun')}$\n", + "\n", + "\n", + "$p1 = P('gun'|'talk.politics.guns')$\n", + "\n", + "\n", + "$p2 = P('talk.politics.guns')$\n", + "\n", + "\n", + "$p3 = P('gun')$" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## obliczanie $p1 = P('gun'|'talk.politics.guns')$" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "# samodzielne wykonanie" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## obliczanie $p2 = P('talk.politics.guns')$\n" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "# samodzielne wykonanie" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## obliczanie $p3 = P('gun')$" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "# samodzielne wykonanie" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ostatecznie" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'p1' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;34m(\u001b[0m\u001b[0mp1\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mp2\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mp3\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name 'p1' is not defined" + ] + } + ], + "source": [ + "(p1 * p2) / p3" + ] + }, + { + "cell_type": "code", + 
"execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "def get_prob(index ):\n", + " talks_topic = [x for x,y in zip(newsgroups_text_tokenized,Y) if y == index]\n", + "\n", + " len([x for x in talks_topic if 'gun' in x])\n", + "\n", + " if len(talks_topic) == 0:\n", + " return 0.0\n", + " p1 = len([x for x in talks_topic if 'gun' in x]) / len(talks_topic)\n", + " p2 = len(talks_topic) / len(Y)\n", + " p3 = len([x for x in newsgroups_text_tokenized if 'gun' in x]) / len(Y)\n", + "\n", + " if p3 == 0:\n", + " return 0.0\n", + " else: \n", + " return (p1 * p2)/ p3\n" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.01622 \t\t alt.atheism\n", + "0.00000 \t\t comp.graphics\n", + "0.00541 \t\t comp.os.ms-windows.misc\n", + "0.01892 \t\t comp.sys.ibm.pc.hardware\n", + "0.00270 \t\t comp.sys.mac.hardware\n", + "0.00000 \t\t comp.windows.x\n", + "0.01351 \t\t misc.forsale\n", + "0.04054 \t\t rec.autos\n", + "0.01892 \t\t rec.motorcycles\n", + "0.00270 \t\t rec.sport.baseball\n", + "0.00541 \t\t rec.sport.hockey\n", + "0.03784 \t\t sci.crypt\n", + "0.02973 \t\t sci.electronics\n", + "0.00541 \t\t sci.med\n", + "0.01622 \t\t sci.space\n", + "0.00270 \t\t soc.religion.christian\n", + "0.68378 \t\t talk.politics.guns\n", + "0.04595 \t\t talk.politics.mideast\n", + "0.03784 \t\t talk.politics.misc\n", + "0.01622 \t\t talk.religion.misc\n", + "1.00000 \t\tsuma\n" + ] + } + ], + "source": [ + "probs = []\n", + "for i in range(len(Y_names)):\n", + " probs.append(get_prob(i))\n", + " print(\"%.5f\" % get_prob(i),'\\t\\t', Y_names[i])\n", + " \n", + "print(\"%.5f\" % sum(probs), '\\t\\tsuma',)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### zadanie samodzielne" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "def get_prob2(index, word ):\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "# listing dla get_prob2, słowo 'god'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## założenie naiwnego bayesa" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "$P(class | word1, word2, word3) = \\frac{P(word1, word2, word3|class) * P(class)}{P(word1, word2, word3)}$\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**przy założeniu o niezależności zmiennych losowych $word1$, $word2$, $word3$**:\n", + "\n", + "\n", + "$P(word1, word2, word3|class) = P(word1|class)* P(word2|class) * P(word3|class)$" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**ostatecznie:**\n", + "\n", + "\n", + "$P(class | word1, word2, word3) = \\frac{P(word1|class)* P(word2|class) * P(word3|class) * P(class)}{\\sum_k{P(word1|class_k)* P(word2|class_k) * P(word3|class_k) * P(class_k)}}$\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## zadania domowe naiwny bayes1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- analogicznie zaimplementować funkcję get_prob3(index, document_tokenized), argument document_tokenized ma być zbiorem słów dokumentu. 
Funkcja ma być naiwnym klasyfikatorem bayesowskim (w przypadku wielu słów)\n", + "- odpalić powyższy listing prawdopodobieństw z funkcją get_prob3 dla dokumentów: {'i','love','guns'} oraz {'is','there','life','after'\n", + ",'death'}\n", + "- zadanie proszę zrobić w jupyterze, wygenerować pdf (kod + wyniki odpalenia) i umieścić go jako zadanie w teams\n", + "- termin 12.05, punktów: 40\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## zadania domowe naiwny bayes2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- wybrać jedno z poniższych repozytoriów i je sforkować:\n", + " - https://git.wmi.amu.edu.pl/kubapok/paranormal-or-skeptic-ISI-public\n", + " - https://git.wmi.amu.edu.pl/kubapok/sport-text-classification-ball-ISI-public\n", + "- stworzyć klasyfikator bazujący na naiwnym bayesie (może być gotowa biblioteka)\n", + "- stworzyć predykcje w plikach dev-0/out.tsv oraz test-A/out.tsv\n", + "- wynik accuracy sprawdzony za pomocą narzędzia geval (patrz poprzednie zadanie) powinien wynosić co najmniej 0.67\n", + "- proszę umieścić predykcje oraz skrypty generujące (w postaci tekstowej a nie jupyter) w repo, a w MS TEAMS umieścić link do swojego repo\n", + "- termin 12.05, 40 punktów\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/cw/06_klasyfikacja_ODPOWIEDZI.ipynb b/cw/06_klasyfikacja_ODPOWIEDZI.ipynb new file mode 100644 index 0000000..f3b2299 --- /dev/null +++ b/cw/06_klasyfikacja_ODPOWIEDZI.ipynb @@ -0,0 +1,1111 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Zajęcia klasyfikacja" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Zbiór kleister" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import pathlib\n", + "from collections import Counter\n", + "from sklearn.metrics import *" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "KLEISTER_PATH = pathlib.Path('/home/kuba/Syncthing/przedmioty/2020-02/IE/applica/kleister-nda')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Pytanie\n", + "\n", + "Czy jurysdykcja musi być zapisana explicite w umowie?"
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def get_expected_jurisdiction(filepath):\n", + " dataset_expected_jurisdiction = []\n", + " with open(filepath,'r') as train_expected_file:\n", + " for line in train_expected_file:\n", + " key_values = line.rstrip('\\n').split(' ')\n", + " jurisdiction = None\n", + " for key_value in key_values:\n", + " key, value = key_value.split('=')\n", + " if key == 'jurisdiction':\n", + " jurisdiction = value\n", + " if jurisdiction is None:\n", + " jurisdiction = 'NONE'\n", + " dataset_expected_jurisdiction.append(jurisdiction)\n", + " return dataset_expected_jurisdiction" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "train_expected_jurisdiction = get_expected_jurisdiction(KLEISTER_PATH/'train'/'expected.tsv')" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "dev_expected_jurisdiction = get_expected_jurisdiction(KLEISTER_PATH/'dev-0'/'expected.tsv')" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "254" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(train_expected_jurisdiction)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'NONE' in train_expected_jurisdiction" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "31" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(set(train_expected_jurisdiction))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Czy wszystkie stany muszą występować w zbiorze trenującym w zbiorze kleister?\n", + "\n", + "https://en.wikipedia.org/wiki/U.S._state\n", + "\n", + "### Jaki jest baseline?" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "train_counter = Counter(train_expected_jurisdiction)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('New_York', 43),\n", + " ('Delaware', 39),\n", + " ('California', 32),\n", + " ('Massachusetts', 15),\n", + " ('Texas', 13),\n", + " ('Illinois', 10),\n", + " ('Oregon', 9),\n", + " ('Florida', 9),\n", + " ('Pennsylvania', 9),\n", + " ('Missouri', 9),\n", + " ('Ohio', 8),\n", + " ('New_Jersey', 7),\n", + " ('Georgia', 6),\n", + " ('Indiana', 5),\n", + " ('Nevada', 5),\n", + " ('Colorado', 4),\n", + " ('Virginia', 4),\n", + " ('Washington', 4),\n", + " ('Michigan', 3),\n", + " ('Minnesota', 3),\n", + " ('Connecticut', 2),\n", + " ('Wisconsin', 2),\n", + " ('Maine', 2),\n", + " ('North_Carolina', 2),\n", + " ('Kansas', 2),\n", + " ('Utah', 2),\n", + " ('Iowa', 1),\n", + " ('Idaho', 1),\n", + " ('South_Dakota', 1),\n", + " ('South_Carolina', 1),\n", + " ('Rhode_Island', 1)]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_counter.most_common(100)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "most_common_answer = train_counter.most_common(100)[0][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'New_York'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "most_common_answer" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "dev_predictions_jurisdiction = [most_common_answer] * len(dev_expected_jurisdiction)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['New_York',\n", + " 'New_York',\n", + " 'Delaware',\n", + " 'Massachusetts',\n", + " 'Delaware',\n", + " 'Washington',\n", + " 'Delaware',\n", + " 'New_Jersey',\n", + " 'New_York',\n", + " 'NONE',\n", + " 'NONE',\n", + " 'Delaware',\n", + " 'Delaware',\n", + " 'Delaware',\n", + " 'New_York',\n", + " 'Massachusetts',\n", + " 'Minnesota',\n", + " 'California',\n", + " 'New_York',\n", + " 'California',\n", + " 'Iowa',\n", + " 'California',\n", + " 'Virginia',\n", + " 'North_Carolina',\n", + " 'Arizona',\n", + " 'Indiana',\n", + " 'New_Jersey',\n", + " 'California',\n", + " 'Delaware',\n", + " 'Georgia',\n", + " 'New_York',\n", + " 'New_York',\n", + " 'California',\n", + " 'Minnesota',\n", + " 'California',\n", + " 'Kentucky',\n", + " 'Minnesota',\n", + " 'Ohio',\n", + " 'Michigan',\n", + " 'California',\n", + " 'Minnesota',\n", + " 'California',\n", + " 'Delaware',\n", + " 'Illinois',\n", + " 'Minnesota',\n", + " 'Texas',\n", + " 'New_Jersey',\n", + " 'Delaware',\n", + " 'Washington',\n", + " 'NONE',\n", + " 'Delaware',\n", + " 'Oregon',\n", + " 'Delaware',\n", + " 'Delaware',\n", + " 'Delaware',\n", + " 'Massachusetts',\n", + " 'California',\n", + " 'NONE',\n", + " 'Delaware',\n", + " 'Illinois',\n", + " 'Idaho',\n", + " 'Washington',\n", + " 'New_York',\n", + " 'New_York',\n", + " 'California',\n", + " 'Utah',\n", + " 'Delaware',\n", + " 'Washington',\n", + " 'Virginia',\n", + " 'New_York',\n", + " 'New_York',\n", + " 'Illinois',\n", + " 'California',\n", + " 'Delaware',\n", + " 'NONE',\n", + " 
'Texas',\n", + " 'California',\n", + " 'Washington',\n", + " 'Delaware',\n", + " 'Washington',\n", + " 'New_York',\n", + " 'Washington',\n", + " 'Illinois']" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dev_expected_jurisdiction" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "accuracy: 0.14457831325301204\n" + ] + } + ], + "source": [ + "counter = 0 \n", + "for pred, exp in zip(dev_predictions_jurisdiction, dev_expected_jurisdiction):\n", + " if pred == exp:\n", + " counter +=1\n", + "print('accuracy: ', counter/len(dev_predictions_jurisdiction))" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.14457831325301204" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "accuracy_score(dev_predictions_jurisdiction, dev_expected_jurisdiction)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Co jeżeli nazwy klas nie występują explicite w zbiorach?" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "https://git.wmi.amu.edu.pl/kubapok/paranormal-or-skeptic-ISI-public\n", + " \n", + "https://git.wmi.amu.edu.pl/kubapok/sport-text-classification-ball-ISI-public" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "SPORT_PATH='/home/kuba/Syncthing/przedmioty/2020-02/ISI/zajecia6_klasyfikacja/repos/sport-text-classification-ball'\n", + "\n", + "SPORT_TRAIN=$SPORT_PATH/train/train.tsv.gz\n", + " \n", + "SPORT_DEV_EXP=$SPORT_PATH/dev-0/expected.tsv" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### jaki jest baseline dla sport classification ball?\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "zcat $SPORT_TRAIN | awk '{print $1}' | wc -l" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "zcat $SPORT_TRAIN | awk '{print $1}' | grep 1 | wc -l" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "cat $SPORT_DEV_EXP | wc -l\n", + "\n", + "grep 1 $SPORT_DEV_EXP | wc -l" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Sprytne podejście do klasyfikacji tekstu? Naiwny bayess" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/kuba/anaconda3/lib/python3.8/site-packages/gensim/similarities/__init__.py:15: UserWarning: The gensim.similarities.levenshtein submodule is disabled, because the optional Levenshtein package is unavailable. Install Levenhstein (e.g. 
`pip install python-Levenshtein`) to suppress this warning.\n", + " warnings.warn(msg)\n" + ] + } + ], + "source": [ + "from sklearn.datasets import fetch_20newsgroups\n", + "# https://scikit-learn.org/0.19/datasets/twenty_newsgroups.html\n", + "\n", + "from sklearn.feature_extraction.text import TfidfVectorizer\n", + "import numpy as np\n", + "import sklearn.metrics\n", + "import gensim" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "newsgroups = fetch_20newsgroups()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "newsgroups_text = newsgroups['data']" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "newsgroups_text_tokenized = [list(set(gensim.utils.tokenize(x, lowercase = True))) for x in newsgroups_text]" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "From: lerxst@wam.umd.edu (where's my thing)\n", + "Subject: WHAT car is this!?\n", + "Nntp-Posting-Host: rac3.wam.umd.edu\n", + "Organization: University of Maryland, College Park\n", + "Lines: 15\n", + "\n", + " I was wondering if anyone out there could enlighten me on this car I saw\n", + "the other day. It was a 2-door sports car, looked to be from the late 60s/\n", + "early 70s. It was called a Bricklin. The doors were really small. In addition,\n", + "the front bumper was separate from the rest of the body. This is \n", + "all I know. If anyone can tellme a model name, engine specs, years\n", + "of production, where this car is made, history, or whatever info you\n", + "have on this funky looking car, please e-mail.\n", + "\n", + "Thanks,\n", + "- IL\n", + " ---- brought to you by your neighborhood Lerxst ----\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + } + ], + "source": [ + "print(newsgroups_text[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['lerxst', 'on', 'be', 'name', 'brought', 'late', 'front', 'umd', 'bumper', 'door', 'there', 'subject', 'day', 'early', 'history', 'me', 'neighborhood', 'university', 'mail', 'doors', 'by', 'funky', 'if', 'engine', 'know', 'years', 'maryland', 'your', 'rest', 'is', 'info', 'body', 'have', 'tellme', 'out', 'anyone', 'small', 'wam', 'il', 'organization', 'thanks', 'park', 'made', 'whatever', 'other', 'specs', 'wondering', 'lines', 'from', 'was', 'a', 'what', 'the', 's', 'or', 'please', 'all', 'rac', 'i', 'looked', 'really', 'edu', 'where', 'to', 'e', 'my', 'it', 'car', 'addition', 'can', 'of', 'production', 'in', 'saw', 'separate', 'you', 'thing', 'posting', 'bricklin', 'could', 'enlighten', 'nntp', 'model', 'were', 'host', 'looking', 'this', 'college', 'sports', 'called']\n" + ] + } + ], + "source": [ + "print(newsgroups_text_tokenized[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "Y = newsgroups['target']" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([7, 4, 4, ..., 3, 1, 8])" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Y" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "Y_names = 
newsgroups['target_names']" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['alt.atheism',\n", + " 'comp.graphics',\n", + " 'comp.os.ms-windows.misc',\n", + " 'comp.sys.ibm.pc.hardware',\n", + " 'comp.sys.mac.hardware',\n", + " 'comp.windows.x',\n", + " 'misc.forsale',\n", + " 'rec.autos',\n", + " 'rec.motorcycles',\n", + " 'rec.sport.baseball',\n", + " 'rec.sport.hockey',\n", + " 'sci.crypt',\n", + " 'sci.electronics',\n", + " 'sci.med',\n", + " 'sci.space',\n", + " 'soc.religion.christian',\n", + " 'talk.politics.guns',\n", + " 'talk.politics.mideast',\n", + " 'talk.politics.misc',\n", + " 'talk.religion.misc']" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Y_names" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'talk.politics.guns'" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Y_names[16]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "$P('talk.politics.guns' | 'gun')= ?$ \n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "$P(A|B) * P(A) = P(B) * P(B|A)$\n", + "\n", + "$P(A|B) = \\frac{P(B) * P(B|A)}{P(A)}$" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "$P('talk.politics.guns' | 'gun') * P('gun') = P('gun'|'talk.politics.guns') * P('talk.politics.guns')$\n", + "\n", + "\n", + "$P('talk.politics.guns' | 'gun') = \\frac{P('gun'|'talk.politics.guns') * P('talk.politics.guns')}{P('gun')}$\n", + "\n", + "\n", + "$p1 = P('gun'|'talk.politics.guns')$\n", + "\n", + "\n", + "$p2 = P('talk.politics.guns')$\n", + "\n", + "\n", + "$p3 = P('gun')$" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## obliczanie $p1 = P('gun'|'talk.politics.guns')$" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "talk_politics_guns = [x for x,y in zip(newsgroups_text_tokenized,Y) if y == 16]" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "546" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(talk_politics_guns)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "253" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len([x for x in talk_politics_guns if 'gun' in x])" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "p1 = len([x for x in talk_politics_guns if 'gun' in x]) / len(talk_politics_guns)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.4633699633699634" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "p1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## obliczanie $p2 = P('talk.politics.guns')$\n" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "p2 = len(talk_politics_guns) / len(Y)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + 
"metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.048258794414000356" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "p2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## obliczanie $p3 = P('gun')$" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "p3 = len([x for x in newsgroups_text_tokenized if 'gun' in x]) / len(Y)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.03270284603146544" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "p3" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ostatecznie" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.6837837837837839" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(p1 * p2) / p3" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "def get_prob(index ):\n", + " talks_topic = [x for x,y in zip(newsgroups_text_tokenized,Y) if y == index]\n", + "\n", + " len([x for x in talks_topic if 'gun' in x])\n", + "\n", + " if len(talks_topic) == 0:\n", + " return 0.0\n", + " p1 = len([x for x in talks_topic if 'gun' in x]) / len(talks_topic)\n", + " p2 = len(talks_topic) / len(Y)\n", + " p3 = len([x for x in newsgroups_text_tokenized if 'gun' in x]) / len(Y)\n", + "\n", + " if p3 == 0:\n", + " return 0.0\n", + " else: \n", + " return (p1 * p2)/ p3\n" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.01622 \t\t alt.atheism\n", + "0.00000 \t\t comp.graphics\n", + "0.00541 \t\t comp.os.ms-windows.misc\n", + "0.01892 \t\t comp.sys.ibm.pc.hardware\n", + "0.00270 \t\t comp.sys.mac.hardware\n", + "0.00000 \t\t comp.windows.x\n", + "0.01351 \t\t misc.forsale\n", + "0.04054 \t\t rec.autos\n", + "0.01892 \t\t rec.motorcycles\n", + "0.00270 \t\t rec.sport.baseball\n", + "0.00541 \t\t rec.sport.hockey\n", + "0.03784 \t\t sci.crypt\n", + "0.02973 \t\t sci.electronics\n", + "0.00541 \t\t sci.med\n", + "0.01622 \t\t sci.space\n", + "0.00270 \t\t soc.religion.christian\n", + "0.68378 \t\t talk.politics.guns\n", + "0.04595 \t\t talk.politics.mideast\n", + "0.03784 \t\t talk.politics.misc\n", + "0.01622 \t\t talk.religion.misc\n", + "1.00000 \t\tsuma\n" + ] + } + ], + "source": [ + "probs = []\n", + "for i in range(len(Y_names)):\n", + " probs.append(get_prob(i))\n", + " print(\"%.5f\" % get_prob(i),'\\t\\t', Y_names[i])\n", + " \n", + "print(\"%.5f\" % sum(probs), '\\t\\tsuma',)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "def get_prob2(index, word ):\n", + " talks_topic = [x for x,y in zip(newsgroups_text_tokenized,Y) if y == index]\n", + "\n", + " len([x for x in talks_topic if word in x])\n", + "\n", + " if len(talks_topic) == 0:\n", + " return 0.0\n", + " p1 = len([x for x in talks_topic if word in x]) / len(talks_topic)\n", + " p2 = len(talks_topic) / len(Y)\n", + " p3 = len([x for x in newsgroups_text_tokenized if word in x]) / len(Y)\n", + "\n", + " if p3 == 0:\n", + " return 0.0\n", + " else: \n", + " return (p1 * p2)/ p3\n" + ] + }, + 
{ + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.20874 \t\t alt.atheism\n", + "0.00850 \t\t comp.graphics\n", + "0.00364 \t\t comp.os.ms-windows.misc\n", + "0.00850 \t\t comp.sys.ibm.pc.hardware\n", + "0.00243 \t\t comp.sys.mac.hardware\n", + "0.00485 \t\t comp.windows.x\n", + "0.00607 \t\t misc.forsale\n", + "0.01092 \t\t rec.autos\n", + "0.02063 \t\t rec.motorcycles\n", + "0.01456 \t\t rec.sport.baseball\n", + "0.01092 \t\t rec.sport.hockey\n", + "0.00485 \t\t sci.crypt\n", + "0.00364 \t\t sci.electronics\n", + "0.00364 \t\t sci.med\n", + "0.01092 \t\t sci.space\n", + "0.41748 \t\t soc.religion.christian\n", + "0.03398 \t\t talk.politics.guns\n", + "0.02791 \t\t talk.politics.mideast\n", + "0.02549 \t\t talk.politics.misc\n", + "0.17233 \t\t talk.religion.misc\n", + "1.00000 \t\tsuma\n" + ] + } + ], + "source": [ + "probs = []\n", + "for i in range(len(Y_names)):\n", + " probs.append(get_prob2(i,'god'))\n", + " print(\"%.5f\" % get_prob2(i,'god'),'\\t\\t', Y_names[i])\n", + " \n", + "print(\"%.5f\" % sum(probs), '\\t\\tsuma',)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## założenie naiwnego bayesa" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "$P(class | word1, word2, word3) = \\frac{P(word1, word2, word3|class) * P(class)}{P(word1, word2, word3)}$\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**przy założeniu o niezależności zmiennych losowych $word1$, $word2$, $word3$**:\n", + "\n", + "\n", + "$P(word1, word2, word3|class) = P(word1|class)* P(word2|class) * P(word3|class)$" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**ostatecznie:**\n", + "\n", + "\n", + "$P(class | word1, word2, word3) = \\frac{P(word1|class)* P(word2|class) * P(word3|class) * P(class)}{\\sum_k{P(word1|class_k)* P(word2|class_k) * P(word3|class_k) * P(class_k)}}$\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## zadania domowe naiwny bayes1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- analogicznie zaimplementować funkcję get_prob3(index, document_tokenized), argument document_tokenized ma być zbiorem słów dokumentu. 
Funkcja ma być naiwnym klasyfikatorem bayesowskim (w przypadku wielu słów)\n", + "- odpalić powyższy listing prawdopodobieństw z funkcją get_prob3 dla dokumentów: {'i','love','guns'} oraz {'is','there','life','after'\n", + ",'death'}\n", + "- zadanie proszę zrobić w jupyterze, wygenerować pdf (kod + wyniki odpalenia) i umieścić go jako zadanie w teams\n", + "- termin 12.05, punktów: 40\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## zadania domowe naiwny bayes2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- wybrać jedno z poniższych repozytoriów i je sforkować:\n", + " - https://git.wmi.amu.edu.pl/kubapok/paranormal-or-skeptic-ISI-public\n", + " - https://git.wmi.amu.edu.pl/kubapok/sport-text-classification-ball-ISI-public\n", + "- stworzyć klasyfikator bazujący na naiwnym bayesie (może być gotowa biblioteka)\n", + "- stworzyć predykcje w plikach dev-0/out.tsv oraz test-A/out.tsv\n", + "- wynik accuracy sprawdzony za pomocą narzędzia geval (patrz poprzednie zadanie) powinien wynosić co najmniej 0.67\n", + "- proszę umieścić predykcje oraz skrypty generujące (w postaci tekstowej a nie jupyter) w repo, a w MS TEAMS umieścić link do swojego repo\n", + "- termin 12.05, 40 punktów\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}