{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Zajęcia klasyfikacja"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Zbiór kleister"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pathlib\n",
    "from collections import Counter\n",
    "from sklearn.metrics import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "KLEISTER_PATH = pathlib.Path('/home/kuba/Syncthing/przedmioty/2020-02/IE/applica/kleister-nda')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pytanie\n",
    "\n",
    "Czy jurysdykcja musi być zapisana explicite w umowie?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_expected_jurisdiction(filepath):\n",
    "    dataset_expected_jurisdiction = []\n",
    "    with open(filepath,'r') as train_expected_file:\n",
    "        for line in train_expected_file:\n",
    "            key_values = line.rstrip('\\n').split(' ')\n",
    "            jurisdiction = None\n",
    "            for key_value in key_values:\n",
    "                key, value = key_value.split('=')\n",
    "                if key == 'jurisdiction':\n",
    "                    jurisdiction = value\n",
    "            if jurisdiction is None:\n",
    "                jurisdiction = 'NONE'\n",
    "            dataset_expected_jurisdiction.append(jurisdiction)\n",
    "    return dataset_expected_jurisdiction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_expected_jurisdiction = get_expected_jurisdiction(KLEISTER_PATH/'train'/'expected.tsv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "dev_expected_jurisdiction = get_expected_jurisdiction(KLEISTER_PATH/'dev-0'/'expected.tsv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "254"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_expected_jurisdiction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'NONE' in train_expected_jurisdiction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "31"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(set(train_expected_jurisdiction))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Czy wszystkie stany muszą występować w zbiorze trenującym w zbiorze kleister?\n",
    "\n",
    "https://en.wikipedia.org/wiki/U.S._state\n",
    "\n",
    "### Jaki jest baseline?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_counter = Counter(train_expected_jurisdiction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('New_York', 43),\n",
       " ('Delaware', 39),\n",
       " ('California', 32),\n",
       " ('Massachusetts', 15),\n",
       " ('Texas', 13),\n",
       " ('Illinois', 10),\n",
       " ('Oregon', 9),\n",
       " ('Florida', 9),\n",
       " ('Pennsylvania', 9),\n",
       " ('Missouri', 9),\n",
       " ('Ohio', 8),\n",
       " ('New_Jersey', 7),\n",
       " ('Georgia', 6),\n",
       " ('Indiana', 5),\n",
       " ('Nevada', 5),\n",
       " ('Colorado', 4),\n",
       " ('Virginia', 4),\n",
       " ('Washington', 4),\n",
       " ('Michigan', 3),\n",
       " ('Minnesota', 3),\n",
       " ('Connecticut', 2),\n",
       " ('Wisconsin', 2),\n",
       " ('Maine', 2),\n",
       " ('North_Carolina', 2),\n",
       " ('Kansas', 2),\n",
       " ('Utah', 2),\n",
       " ('Iowa', 1),\n",
       " ('Idaho', 1),\n",
       " ('South_Dakota', 1),\n",
       " ('South_Carolina', 1),\n",
       " ('Rhode_Island', 1)]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_counter.most_common(100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "most_common_answer = train_counter.most_common(100)[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'New_York'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "most_common_answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "dev_predictions_jurisdiction = [most_common_answer] * len(dev_expected_jurisdiction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['New_York',\n",
       " 'New_York',\n",
       " 'Delaware',\n",
       " 'Massachusetts',\n",
       " 'Delaware',\n",
       " 'Washington',\n",
       " 'Delaware',\n",
       " 'New_Jersey',\n",
       " 'New_York',\n",
       " 'NONE',\n",
       " 'NONE',\n",
       " 'Delaware',\n",
       " 'Delaware',\n",
       " 'Delaware',\n",
       " 'New_York',\n",
       " 'Massachusetts',\n",
       " 'Minnesota',\n",
       " 'California',\n",
       " 'New_York',\n",
       " 'California',\n",
       " 'Iowa',\n",
       " 'California',\n",
       " 'Virginia',\n",
       " 'North_Carolina',\n",
       " 'Arizona',\n",
       " 'Indiana',\n",
       " 'New_Jersey',\n",
       " 'California',\n",
       " 'Delaware',\n",
       " 'Georgia',\n",
       " 'New_York',\n",
       " 'New_York',\n",
       " 'California',\n",
       " 'Minnesota',\n",
       " 'California',\n",
       " 'Kentucky',\n",
       " 'Minnesota',\n",
       " 'Ohio',\n",
       " 'Michigan',\n",
       " 'California',\n",
       " 'Minnesota',\n",
       " 'California',\n",
       " 'Delaware',\n",
       " 'Illinois',\n",
       " 'Minnesota',\n",
       " 'Texas',\n",
       " 'New_Jersey',\n",
       " 'Delaware',\n",
       " 'Washington',\n",
       " 'NONE',\n",
       " 'Delaware',\n",
       " 'Oregon',\n",
       " 'Delaware',\n",
       " 'Delaware',\n",
       " 'Delaware',\n",
       " 'Massachusetts',\n",
       " 'California',\n",
       " 'NONE',\n",
       " 'Delaware',\n",
       " 'Illinois',\n",
       " 'Idaho',\n",
       " 'Washington',\n",
       " 'New_York',\n",
       " 'New_York',\n",
       " 'California',\n",
       " 'Utah',\n",
       " 'Delaware',\n",
       " 'Washington',\n",
       " 'Virginia',\n",
       " 'New_York',\n",
       " 'New_York',\n",
       " 'Illinois',\n",
       " 'California',\n",
       " 'Delaware',\n",
       " 'NONE',\n",
       " 'Texas',\n",
       " 'California',\n",
       " 'Washington',\n",
       " 'Delaware',\n",
       " 'Washington',\n",
       " 'New_York',\n",
       " 'Washington',\n",
       " 'Illinois']"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dev_expected_jurisdiction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy:  0.14457831325301204\n"
     ]
    }
   ],
   "source": [
    "counter = 0 \n",
    "for pred, exp in zip(dev_predictions_jurisdiction, dev_expected_jurisdiction):\n",
    "    if pred == exp:\n",
    "        counter +=1\n",
    "print('accuracy: ', counter/len(dev_predictions_jurisdiction))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.14457831325301204"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy_score(dev_predictions_jurisdiction, dev_expected_jurisdiction)"
   ]
  },
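  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A worked check of the baseline, assuming the counts are read correctly off the outputs above: 'New_York' appears 43 times among the 254 training labels, and (counted by hand from the printed list, not printed by the notebook) 12 times among the 83 dev labels. Always predicting 'New_York' therefore gives\n",
    "\n",
    "$accuracy_{train} = \\frac{43}{254} \\approx 0.1693$\n",
    "\n",
    "$accuracy_{dev} = \\frac{12}{83} \\approx 0.1446$\n",
    "\n",
    "The dev value agrees with the `accuracy_score` result. The names $accuracy_{train}$ and $accuracy_{dev}$ are introduced here only for illustration."
   ]
  },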
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Co jeżeli nazwy klas nie występują explicite w zbiorach?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "https://git.wmi.amu.edu.pl/kubapok/paranormal-or-skeptic-ISI-public\n",
    "    \n",
    "https://git.wmi.amu.edu.pl/kubapok/sport-text-classification-ball-ISI-public"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SPORT_PATH='/home/kuba/Syncthing/przedmioty/2020-02/ISI/zajecia6_klasyfikacja/repos/sport-text-classification-ball'\n",
    "\n",
    "SPORT_TRAIN=$SPORT_PATH/train/train.tsv.gz\n",
    "    \n",
    "SPORT_DEV_EXP=$SPORT_PATH/dev-0/expected.tsv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### jaki jest baseline dla sport classification ball?\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "zcat  $SPORT_TRAIN | awk '{print $1}'  | wc -l"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "zcat  $SPORT_TRAIN | awk '{print $1}'  | grep 1 | wc -l"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "cat  $SPORT_DEV_EXP | wc -l\n",
    "\n",
    "grep 1  $SPORT_DEV_EXP | wc -l"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sprytne podejście do klasyfikacji tekstu? Naiwny bayess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
   ],
   "source": [
    "from sklearn.datasets import fetch_20newsgroups\n",
    "# https://scikit-learn.org/0.19/datasets/twenty_newsgroups.html\n",
    "\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "import numpy as np\n",
    "import sklearn.metrics\n",
    "import gensim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "newsgroups = fetch_20newsgroups()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "newsgroups_text = newsgroups['data']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "newsgroups_text_tokenized = [list(set(gensim.utils.tokenize(x, lowercase = True))) for x in newsgroups_text]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "From: lerxst@wam.umd.edu (where's my thing)\n",
      "Subject: WHAT car is this!?\n",
      "Nntp-Posting-Host: rac3.wam.umd.edu\n",
      "Organization: University of Maryland, College Park\n",
      "Lines: 15\n",
      "\n",
      " I was wondering if anyone out there could enlighten me on this car I saw\n",
      "the other day. It was a 2-door sports car, looked to be from the late 60s/\n",
      "early 70s. It was called a Bricklin. The doors were really small. In addition,\n",
      "the front bumper was separate from the rest of the body. This is \n",
      "all I know. If anyone can tellme a model name, engine specs, years\n",
      "of production, where this car is made, history, or whatever info you\n",
      "have on this funky looking car, please e-mail.\n",
      "\n",
      "Thanks,\n",
      "- IL\n",
      "   ---- brought to you by your neighborhood Lerxst ----\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(newsgroups_text[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['lerxst', 'on', 'be', 'name', 'brought', 'late', 'front', 'umd', 'bumper', 'door', 'there', 'subject', 'day', 'early', 'history', 'me', 'neighborhood', 'university', 'mail', 'doors', 'by', 'funky', 'if', 'engine', 'know', 'years', 'maryland', 'your', 'rest', 'is', 'info', 'body', 'have', 'tellme', 'out', 'anyone', 'small', 'wam', 'il', 'organization', 'thanks', 'park', 'made', 'whatever', 'other', 'specs', 'wondering', 'lines', 'from', 'was', 'a', 'what', 'the', 's', 'or', 'please', 'all', 'rac', 'i', 'looked', 'really', 'edu', 'where', 'to', 'e', 'my', 'it', 'car', 'addition', 'can', 'of', 'production', 'in', 'saw', 'separate', 'you', 'thing', 'posting', 'bricklin', 'could', 'enlighten', 'nntp', 'model', 'were', 'host', 'looking', 'this', 'college', 'sports', 'called']\n"
     ]
    }
   ],
   "source": [
    "print(newsgroups_text_tokenized[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "Y = newsgroups['target']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([7, 4, 4, ..., 3, 1, 8])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "Y_names = newsgroups['target_names']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['alt.atheism',\n",
       " 'comp.graphics',\n",
       " 'comp.os.ms-windows.misc',\n",
       " 'comp.sys.ibm.pc.hardware',\n",
       " 'comp.sys.mac.hardware',\n",
       " 'comp.windows.x',\n",
       " 'misc.forsale',\n",
       " 'rec.autos',\n",
       " 'rec.motorcycles',\n",
       " 'rec.sport.baseball',\n",
       " 'rec.sport.hockey',\n",
       " 'sci.crypt',\n",
       " 'sci.electronics',\n",
       " 'sci.med',\n",
       " 'sci.space',\n",
       " 'soc.religion.christian',\n",
       " 'talk.politics.guns',\n",
       " 'talk.politics.mideast',\n",
       " 'talk.politics.misc',\n",
       " 'talk.religion.misc']"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'talk.politics.guns'"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y_names[16]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$P('talk.politics.guns' | 'gun')=  ?$ \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "$P(A|B) * P(A) = P(B) * P(B|A)$\n",
    "\n",
    "$P(A|B) = \\frac{P(B) * P(B|A)}{P(A)}$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$P('talk.politics.guns' | 'gun') * P('gun') = P('gun'|'talk.politics.guns') * P('talk.politics.guns')$\n",
    "\n",
    "\n",
    "$P('talk.politics.guns' | 'gun')  = \\frac{P('gun'|'talk.politics.guns') * P('talk.politics.guns')}{P('gun')}$\n",
    "\n",
    "\n",
    "$p1 = P('gun'|'talk.politics.guns')$\n",
    "\n",
    "\n",
    "$p2 = P('talk.politics.guns')$\n",
    "\n",
    "\n",
    "$p3 = P('gun')$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## obliczanie $p1 = P('gun'|'talk.politics.guns')$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "talk_politics_guns = [x for x,y in zip(newsgroups_text_tokenized,Y) if y == 16]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "546"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(talk_politics_guns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "253"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len([x for x in talk_politics_guns if 'gun' in x])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "p1 = len([x for x in talk_politics_guns if 'gun' in x]) / len(talk_politics_guns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.4633699633699634"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## obliczanie $p2 = P('talk.politics.guns')$\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "p2 = len(talk_politics_guns) / len(Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.048258794414000356"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## obliczanie $p3 = P('gun')$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "p3 = len([x for x in newsgroups_text_tokenized if 'gun' in x]) / len(Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.03270284603146544"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ostatecznie"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6837837837837839"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(p1 * p2) / p3"
   ]
  },
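  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The result can be checked by hand. The factor `len(talk_politics_guns)` = 546 appears in both $p1$ and $p2$, and `len(Y)` appears in both $p2$ and $p3$, so both cancel. The values 11314 (all documents) and 370 (documents containing 'gun') are not printed above; they are inferred here from the printed $p2$ and $p3$ and should be treated as assumptions.\n",
    "\n",
    "$\\frac{p1 * p2}{p3} = \\frac{\\frac{253}{546} * \\frac{546}{11314}}{\\frac{370}{11314}} = \\frac{253}{370} \\approx 0.6838$\n",
    "\n",
    "In other words, the posterior is simply the fraction of documents containing 'gun' that belong to 'talk.politics.guns'."
   ]
  },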
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_prob(index ):\n",
    "    talks_topic = [x for x,y in zip(newsgroups_text_tokenized,Y) if y == index]\n",
    "\n",
    "    len([x for x in talks_topic if 'gun' in x])\n",
    "\n",
    "    if len(talks_topic) == 0:\n",
    "        return 0.0\n",
    "    p1 = len([x for x in talks_topic if 'gun' in x]) / len(talks_topic)\n",
    "    p2 = len(talks_topic) / len(Y)\n",
    "    p3 = len([x for x in newsgroups_text_tokenized if 'gun' in x]) / len(Y)\n",
    "\n",
    "    if p3 == 0:\n",
    "        return 0.0\n",
    "    else: \n",
    "        return (p1 * p2)/ p3\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.01622 \t\t alt.atheism\n",
      "0.00000 \t\t comp.graphics\n",
      "0.00541 \t\t comp.os.ms-windows.misc\n",
      "0.01892 \t\t comp.sys.ibm.pc.hardware\n",
      "0.00270 \t\t comp.sys.mac.hardware\n",
      "0.00000 \t\t comp.windows.x\n",
      "0.01351 \t\t misc.forsale\n",
      "0.04054 \t\t rec.autos\n",
      "0.01892 \t\t rec.motorcycles\n",
      "0.00270 \t\t rec.sport.baseball\n",
      "0.00541 \t\t rec.sport.hockey\n",
      "0.03784 \t\t sci.crypt\n",
      "0.02973 \t\t sci.electronics\n",
      "0.00541 \t\t sci.med\n",
      "0.01622 \t\t sci.space\n",
      "0.00270 \t\t soc.religion.christian\n",
      "0.68378 \t\t talk.politics.guns\n",
      "0.04595 \t\t talk.politics.mideast\n",
      "0.03784 \t\t talk.politics.misc\n",
      "0.01622 \t\t talk.religion.misc\n",
      "1.00000 \t\tsuma\n"
     ]
    }
   ],
   "source": [
    "probs = []\n",
    "for i in range(len(Y_names)):\n",
    "    probs.append(get_prob(i))\n",
    "    print(\"%.5f\" %   get_prob(i),'\\t\\t', Y_names[i])\n",
    "    \n",
    "print(\"%.5f\" % sum(probs), '\\t\\tsuma',)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_prob2(index, word ):\n",
    "    talks_topic = [x for x,y in zip(newsgroups_text_tokenized,Y) if y == index]\n",
    "\n",
    "    len([x for x in talks_topic if word in x])\n",
    "\n",
    "    if len(talks_topic) == 0:\n",
    "        return 0.0\n",
    "    p1 = len([x for x in talks_topic if word in x]) / len(talks_topic)\n",
    "    p2 = len(talks_topic) / len(Y)\n",
    "    p3 = len([x for x in newsgroups_text_tokenized if word in x]) / len(Y)\n",
    "\n",
    "    if p3 == 0:\n",
    "        return 0.0\n",
    "    else: \n",
    "        return (p1 * p2)/ p3\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.20874 \t\t alt.atheism\n",
      "0.00850 \t\t comp.graphics\n",
      "0.00364 \t\t comp.os.ms-windows.misc\n",
      "0.00850 \t\t comp.sys.ibm.pc.hardware\n",
      "0.00243 \t\t comp.sys.mac.hardware\n",
      "0.00485 \t\t comp.windows.x\n",
      "0.00607 \t\t misc.forsale\n",
      "0.01092 \t\t rec.autos\n",
      "0.02063 \t\t rec.motorcycles\n",
      "0.01456 \t\t rec.sport.baseball\n",
      "0.01092 \t\t rec.sport.hockey\n",
      "0.00485 \t\t sci.crypt\n",
      "0.00364 \t\t sci.electronics\n",
      "0.00364 \t\t sci.med\n",
      "0.01092 \t\t sci.space\n",
      "0.41748 \t\t soc.religion.christian\n",
      "0.03398 \t\t talk.politics.guns\n",
      "0.02791 \t\t talk.politics.mideast\n",
      "0.02549 \t\t talk.politics.misc\n",
      "0.17233 \t\t talk.religion.misc\n",
      "1.00000 \t\tsuma\n"
     ]
    }
   ],
   "source": [
    "probs = []\n",
    "for i in range(len(Y_names)):\n",
    "    probs.append(get_prob2(i,'god'))\n",
    "    print(\"%.5f\" %   get_prob2(i,'god'),'\\t\\t', Y_names[i])\n",
    "    \n",
    "print(\"%.5f\" % sum(probs), '\\t\\tsuma',)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## założenie naiwnego bayesa"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$P(class | word1, word2, word3)  = \\frac{P(word1, word2, word3|class) * P(class)}{P(word1, word2, word3)}$\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**przy założeniu o niezależności zmiennych losowych $word1$, $word2$, $word3$**:\n",
    "\n",
    "\n",
    "$P(word1, word2, word3|class) = P(word1|class)* P(word2|class) *  P(word3|class)$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**ostatecznie:**\n",
    "\n",
    "\n",
    "$P(class | word1, word2, word3)  = \\frac{P(word1|class)* P(word2|class) *  P(word3|class)  * P(class)}{\\sum_k{P(word1|class_k)* P(word2|class_k) *  P(word3|class_k)  * P(class_k)}}$\n"
   ]
  },
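  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For classification only the most probable class is needed. The denominator above is the same for every class, so it can be dropped, and taking logarithms turns the product into a sum, which avoids floating-point underflow for long documents. A sketch of the resulting decision rule for a document with $n$ words (the symbols $\\hat{c}$, $n$ and $i$ are introduced here for illustration and do not appear elsewhere in this notebook):\n",
    "\n",
    "$\\hat{c} = \\arg\\max_k \\left[ \\log P(class_k) + \\sum_{i=1}^{n} \\log P(word_i|class_k) \\right]$\n",
    "\n",
    "If some $P(word_i|class_k)$ is zero the logarithm is undefined, which is one reason implementations usually smooth the counts, for example with add-one smoothing."
   ]
  },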
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## zadania domowe naiwny bayes1 ręcznie"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- analogicznie zaimplementować funkcję get_prob3(index, document_tokenized), argument document_tokenized ma być zbiorem słów dokumentu. funkcja ma być naiwnym klasyfikatorem bayesowskim (w przypadku wielu słów)\n",
    "- odpalić powyższy listing prawdopodobieństw z funkcją get_prob3 dla dokumentów: {'i','love','guns'} oraz {'is','there','life','after'\n",
    ",'death'}\n",
    "- zadanie proszę zrobić w jupyterze, wygenerować pdf (kod + wyniki odpalenia) i umieścić go jako zadanie w teams\n",
    "- termin 12.05, punktów: 40\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## zadania domowe naiwny bayes2 gotowa biblioteka"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- wybrać jedno z poniższych repozytoriów i je sforkować:\n",
    "  - https://git.wmi.amu.edu.pl/kubapok/paranormal-or-skeptic-ISI-public\n",
    "  - https://git.wmi.amu.edu.pl/kubapok/sport-text-classification-ball-ISI-public\n",
    "- stworzyć klasyfikator bazujący na naiwnym bayessie (może być gotowa biblioteka), może też korzystać z gotowych implementacji tfidf\n",
    "- stworzyć predykcje w plikach dev-0/out.tsv oraz test-A/out.tsv\n",
    "- wynik accuracy sprawdzony za pomocą narzędzia geval (patrz poprzednie zadanie) powinien wynosić conajmniej 0.67\n",
    "- proszę umieścić predykcję oraz skrypty generujące (w postaci tekstowej a nie jupyter) w repo, a w MS TEAMS umieścić link do swojego repo\n",
    "termin 12.05, 40 punktów\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}