bayes_project/test.ipynb

233 lines
63 KiB
Plaintext
Raw Normal View History

2021-05-30 15:36:17 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "absolute-lending",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd \n",
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"import matplotlib.pyplot as plt\n",
"\n",
"#Wczytanie i normalizacja danych\n",
"def NormalizeData(data):\n",
" for col in data.columns:\n",
" if data[col].dtype == object: \n",
" data[col] = data[col].str.lower()\n",
" if col == 'smoking_status':\n",
" data[col] = data[col].str.replace(\" \", \"_\")\n",
" if col == 'stroke':\n",
" data[col] = data[col].replace({1: 'yes'})\n",
" data[col] = data[col].replace({0: 'no'})\n",
" if col == 'hypertension':\n",
" data[col] = data[col].replace({1: 'yes'})\n",
" data[col] = data[col].replace({0: 'no'})\n",
" if col == 'heart_disease':\n",
" data[col] = data[col].replace({1: 'yes'})\n",
" data[col] = data[col].replace({0: 'no'})\n",
" if col == 'bmi':\n",
" bins = [19,25,30,35,40,90]\n",
" labels=['correct','overweight','obesity_1','obesity_2','extreme']\n",
" data[col] = pd.cut(data[col], bins, labels = labels,include_lowest = True)\n",
" if col == 'age':\n",
" bins = [0, 30, 40, 50, 60, 70, 80, 90]\n",
" labels = ['0-29', '30-39', '40-49', '50-59', '60-69', '70-79', '80-89',]\n",
" data[col] = pd.cut(data[col], bins, labels = labels,include_lowest = True)\n",
" if col == 'avg_glucose_level':\n",
" bins = [50,70,90,110,130,150,170,190,210,230,250,270]\n",
" labels = ['50-70', '70-90', '90-110','110-130','130-150','150-170','170-190','190-210', '210-230','230-250','250-270']\n",
" data[col] = pd.cut(data[col], bins, labels = labels,include_lowest = True)\n",
" data = data.dropna()\n",
" return data\n",
"\n",
"def count_a_priori_prob(dataset):\n",
" is_stroke_amount = len(dataset[dataset.stroke == 'yes'])\n",
" no_stroke_amount = len(dataset[dataset.stroke == 'no'])\n",
" data_length = len(dataset.stroke)\n",
" return {'yes': float(is_stroke_amount)/float(data_length), 'no': float(no_stroke_amount)/float(data_length)}\n",
"\n",
"def separate_labels_from_properties(X_train):\n",
"\n",
" labels = X_train.columns\n",
" labels_values = {}\n",
" for label in labels:\n",
" labels_values[label] = set(X_train[label])\n",
" \n",
" to_return = []\n",
" for x in labels:\n",
" to_return.append({x: labels_values[x]})\n",
"\n",
" return to_return\n",
"\n",
"data = pd.read_csv(\"healthcare-dataset-stroke-data.csv\")\n",
"data = NormalizeData(data)\n",
"\n",
"#podział danych na treningowy i testowy \n",
"data_train, data_test = train_test_split(data, random_state = 42)\n",
"\n",
"#rozdzielenie etykiet i cech\n",
"X_train =data_train[['gender', 'age', 'bmi','smoking_status', 'work_type','hypertension','heart_disease']]\n",
"Y_train = data_train['stroke']\n",
"\n",
"#rozdzielenie etykiet i cech\n",
"# Dane wejściowe - zbiór danych, wektor etykiet, wektor prawdopodobieństw a priori dla klas.\n",
"\n",
"# Wygenerowanie wektora prawdopodobieństw a priori dla klas.\n",
"a_priori_prob = count_a_priori_prob(data_train)\n",
"labels = separate_labels_from_properties(X_train)\n",
"\n",
"class NaiveBayes():\n",
" def __init__(self, dataset, labels, a_priori_prob):\n",
" self.dataset = dataset\n",
" self.labels = labels\n",
" self.a_priori_prob = a_priori_prob\n",
" \n",
" def count_bayes(self):\n",
" label_probs_return = []\n",
" posteriori_return = []\n",
" final_probs = {'top_yes': 0.0, 'top_no': 0.0, 'total': 0.0}\n",
" \n",
" # self.labels - Wartości etykiet które nas interesują, opcjonalnie podane sa wszystkie.\n",
" # [{'gender': {'female', 'male', 'other'}}, {'age': {'50-59', '40-49', '60-69', '70+', '18-29', '30-39'}}, {'ever_married': {'no', 'yes'}}, {'Residence_type': {'rural', 'urban'}}, {'bmi': {'high', 'mid', 'low'}}, {'smoking_status': {'unknown', 'smokes', 'never_smoked', 'formerly_smoked'}}, {'work_type': {'self_employed', 'private', 'never_worked', 'govt_job'}}, {'hypertension': {'no', 'yes'}}, {'heart_disease': {'no', 'yes'}}]\n",
" # Dla kazdej z klas - 'yes', 'no'\n",
" for idx, cls in enumerate(list(set(self.dataset['stroke']))):\n",
" label_probs = []\n",
" for label in self.labels:\n",
" label_name = list(label.keys())[0]\n",
" for label_value in label[label_name]:\n",
" # Oblicz ilość występowania danej cechy w zbiorze danych np. heart_disease.yes\n",
"\n",
" amount_label_value_yes_class = len(self.dataset.loc[(self.dataset['stroke'] == 'yes') & (self.dataset[label_name] == label_value)])\n",
" amount_label_value_no_class = len(self.dataset.loc[(self.dataset['stroke'] == 'no') & (self.dataset[label_name] == label_value)])\n",
" amount_yes_class = len(self.dataset.loc[(self.dataset['stroke'] == 'yes')])\n",
" amount_no_class = len(self.dataset.loc[(self.dataset['stroke'] == 'no')]) \n",
" # Obliczenie P(heart_disease.yes|'stroke'|), P(heart_disease.yes|'no stroke') itd. dla kazdej cechy.\n",
" # Zapisujemy do listy w formacie (cecha.wartość: prob stroke, cecha.wartość: prob no stroke)\n",
" label_probs.append({str(label_name + \".\" + label_value):(amount_label_value_yes_class/amount_yes_class, amount_label_value_no_class/amount_no_class)})\n",
"\n",
" # Suma prawdopodobienstw mozliwych wartosci danej cechy dla danej klasy, powinna sumować się do 1.\n",
"# print(label_probs)\n",
" label_probs_return.append(label_probs)\n",
" # Obliczanie licznika wzoru Bayesa (mnozymy wartosci prob cech z prawdop apriori danej klasy):\n",
" top = 1\n",
" for label_prob in label_probs:\n",
" top *= list(label_prob.values())[0][idx]\n",
" top *= self.a_priori_prob[cls]\n",
"\n",
" final_probs[cls] = top\n",
" final_probs['total'] += top\n",
" \n",
"# print(\"Prawdopodobieństwo a posteriori dla klasy yes-stroke\", final_probs['yes']/final_probs['total'])\n",
"# print(\"Prawdopodobieństwo a posteriori dla klasy no-stroke\", final_probs['no']/final_probs['total'])\n",
" posteriori_return.append(final_probs['yes']/final_probs['total'])\n",
" posteriori_return.append(final_probs['no']/final_probs['total'])\n",
" return posteriori_return, label_probs_return\n",
"\n",
2021-05-31 16:05:08 +02:00
"labels = [{'age': {'70-79'}},{'hypertension': {'yes'}},{'heart_disease': {'yes'}},{'bmi': {'correct'}},{'gender': {'male'}},{'smoking_status': {'smokes'}}]\n",
2021-05-30 15:36:17 +02:00
"naive_bayes = NaiveBayes(data_train, labels, a_priori_prob)\n",
"posteriori, labels = naive_bayes.count_bayes()"
]
},
{
"cell_type": "code",
2021-05-31 16:05:08 +02:00
"execution_count": 2,
2021-05-30 15:36:17 +02:00
"id": "handled-lightweight",
"metadata": {},
"outputs": [],
"source": [
2021-05-31 16:05:08 +02:00
"def autolabel(rects, values ,ax):\n",
" # Attach some text labels.\n",
" for (rect, value) in zip(rects, values):\n",
" ax.text(rect.get_x() + rect.get_width() / 2.,\n",
" rect.get_y() + rect.get_height() / 2.,\n",
" '%.3f'%value,\n",
" ha = 'center',\n",
" va = 'center',\n",
" fontsize= 15,\n",
" color ='black')\n",
"def plot_priori(labels, posteriori): \n",
" keys =[ r\"$\\bf{\" + (x.split('.',1)[0]).replace('_', ' ')+ \"}$\" + '\\n' + x.split('.',1)[1] for i in range(1) for j in range(len(labels[i])) for x in labels[i][j].keys()]\n",
2021-05-30 15:36:17 +02:00
" aprori = [list(x) for i in range(1) for j in range(len(labels[i])) for x in labels[i][j].values()]\n",
" yes_aprori = np.array(aprori)[:,0]\n",
" no_aprori = np.array(aprori)[:,1]\n",
2021-05-31 16:05:08 +02:00
" \n",
" width = 0.55\n",
"\n",
" fig = plt.figure(figsize=(25,10))\n",
" \n",
" ax1 = fig.add_subplot(121)\n",
" rec1 = ax1.bar(keys,yes_aprori,width, color ='lime', label= 'Positive stroke')\n",
" rec2 = ax1.bar(keys,no_aprori,width, color ='crimson', bottom = yes_aprori, label= 'Negative stroke')\n",
" ax1.set_yticks(np.arange(0, 1.1,0.1))\n",
" ax1.set_ylabel('Probability',fontsize=18)\n",
" ax1.set_xlabel('\\nFeatures',fontsize=18)\n",
" ax1.tick_params(axis='x', which='major', labelsize=12)\n",
" autolabel(rec1,yes_aprori, ax1)\n",
" autolabel(rec2,no_aprori, ax1)\n",
" ax1.legend(fontsize=15)\n",
" \n",
" ax2 = fig.add_subplot(122)\n",
" rec3 = ax2.bar(0, posteriori[1],capsize=1 ,color=['crimson'], label='Negative stroke')\n",
" rec4 = ax2.bar(1, posteriori[0], color=['lime'],label='Positive stroke')\n",
" ax2.set_ylabel('Probability',fontsize=18)\n",
" ax2.set_xlabel('\\nClasses',fontsize=18)\n",
" ax2.set_xticks([0,1])\n",
" ax2.set_yticks(np.arange(0, 1.1,0.1))\n",
" ax2.tick_params(axis='x', which='major', labelsize=15)\n",
" autolabel(rec3,[posteriori[1]], ax2)\n",
" autolabel(rec4,[posteriori[0]], ax2)\n",
" ax2.legend(fontsize=15)\n",
" \n",
2021-05-30 15:36:17 +02:00
" plt.show()\n"
]
},
{
"cell_type": "code",
2021-05-31 16:05:08 +02:00
"execution_count": 3,
2021-05-30 15:36:17 +02:00
"id": "molecular-ladder",
"metadata": {},
"outputs": [
{
"data": {
2021-05-31 16:05:08 +02:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAABa4AAAKDCAYAAADy9p1tAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAACdGElEQVR4nOzdeZzVVf348deZfQZmhn0RZBERNzIFkRSXWELF3BWyXEvsm+LP3NI0BUtNRC1TU+prZFaiVoqCS/ZVy8QQy3JHQAQRFEVZZoFZzu+PO9xmYxiGWS7weva4j5nP+ZzP+bw/N8T3fXvuOSHGiCRJkiRJkiRJqSKtrQOQJEmSJEmSJKk6C9eSJEmSJEmSpJRi4VqSJEmSJEmSlFIsXEuSJEmSJEmSUoqFa0mSJEmSJElSSrFwLUmSJEmSJElKKW1WuA4h3BtC+DiE8PpmzocQwu0hhIUhhP+EEA5o7RglSZIkNZ45viRJkppLW864ngEc2cD5o4CBVa+JwM9bISZJkiRJTTcDc3xJkiQ1gzYrXMcY/wqsbqDLccB9MeEloEMIoWfrRCdJkiRpa5njS5IkqblktHUADegFLKt2/EFV24raHUMIE0nM2KBdu3ZD9txzz1YJUJIkSa3nlVde+STG2LWt49A2MceXJElSUkM5fioXrhstxjgdmA4wdOjQOH/+/DaOSJIkSc0thPB+W8eg1mOOL0mStONrKMdvyzWut2Q5sGu1495VbZIkSZK2T+b4kiRJapRULlzPAs6o2nl8OLAmxljnK4SSJEmSthvm+JIkSWqUNlsqJITwe+AIoEsI4QPgWiATIMZ4NzAHOBpYCBQDZ7dNpJIkSZIawxxfkiRJzaXNCtcxxq9t4XwEzm+lcCRJkiRtI3N8SZIkNZdUXipEkiRJkiRJkrQTsnAtSZIkSZIkSUopbbZUiCRJah1r167l448/pqysrK1DkTYrMzOTbt26UVBQ0NahSJIkpTxzfG0PtjXHt3AtSdIObO3atXz00Uf06tWL3NxcQghtHZJUR4yRkpISli9fDmDxWpIkqQHm+NoeNEeO71IhkiTtwD7++GN69epFXl6eCa1SVgiBvLw8evXqxccff9zW4UiSJKU0c3xtD5ojx7dwLUnSDqysrIzc3Ny2DkNqlNzcXL/uKkmStAXm+NqebEuOb+FakqQdnLMwtL3wz6okSVLjmDdpe7Etf1YtXEuSJEmSJEmSUoqFa0mSJEmSJElSSrFwLUnSTii00f+aYvLkyYQQkq9ddtmFk046iUWLFjXvexICd9xxR/J4+vTpPPLII3X69evXj0svvbRZ791UU6dO5bnnnmvWMVPp+SRJktR4i7oe2iavptiU448dO7bOuZNPPpkjjjhiG9+NpluwYAGTJ0/m888/r9E+Y8YMQgisX7++bQKrZt68eUyePLlZx0yl59vEwrUkSUp5hYWFzJ07l7lz5zJt2jReffVVRo0aRVFRUbPdY+7cuZxyyinJ480Vrv/0pz9x4YUXNtt9t0VLFK4lSZKk1vL000/z8ssvt3UYNSxYsIApU6bUKVyPGzeOuXPnkpeX1zaBVTNv3jymTJnS1mG0uIy2DkCSJGlLMjIyGD58OADDhw+nT58+HHroocyZM6dGsXlbbBp/S/bff/9muV9rKikpced5SZIkpZROnTrRq1cvrr/++nonjKSarl270rVr17YOY6vEGNmwYQM5OTltHUqTOONakiRtd4YMGQLAkiVLAPjkk08488wz6dy5M3l5eRxxxBHMnz+/xjWzZs1iyJAhtGvXjo4dO3LQQQfx/PPPJ89XXyrkiCOO4JVXXuHXv/51comSGTNmADWX0pgxYwZZWVl1ZmO88cYbhBB45plnkm2PPvooQ4cOJScnhx49enD55ZdTVlbW4HO+8MILHHrooRQUFFBQUMAXv/hFHnrooWQcn376KVOmTEnGuGn2dQiBW2+9lYsuuoiuXbsyePDgRr9PtS1fvpw999yT0aNHU1xcDMDf/vY3Dj/8cPLy8ujcuTPnnnsu69ata3AcSZIkqboQAldddRWzZs3itddea7Dv0qVLmTBhAp06dSIvL4+xY8fyzjvv1Olz1FFHkZubS//+/ZkxY0adZUfefvttJkyYwK677kpeXh777LMPP/nJT6isrATgueee46tf/SoA/fv3J4RAv379gLpLafTv35/LLrusTqynnHIKI0aMSB6vXr2aiRMn0r17d3Jycjj44IP5xz/+0eDzlpWVcemll9KnTx+ys7PZZZddOOGEE9i4cSMzZsxg0qRJyfcwhJB8xsmTJ9OlSxdeeOEFDjzwQHJycpKfHx588EEGDx5MdnY2u+66K1dddRXl5eUNxnHzzTeTk5PDrFmzACgtLeXyyy9n1113JTs7m/322485c+Y0OMa2sHAtSZK2O5sK1j169ADg+OOP56mnnmLatGnMnDmTyspKvvzlL7Nw4UIAFi1axMknn8zIkSN57LHH+O1vf8sxxxzD6tWr6x3/rrvuYs899+Too49OLlEybty4Ov2OP/54Qgj86U9/qtE+c+ZMunfvzpe//GUgkSSeeOKJDBs2jFmzZnHttdcyffp0rrzyys0+49q1aznmmGPYbbfd+MMf/sDDDz/M6aefniyS/+lPf6KwsJBvfvObyRgPOOCA5PU333wzK1as4De/+Q233357o96n+t7nww47jAEDBvD444+Tl5fH3//+d0aPHk2PHj14+OGH+clPfsKcOXM4++yzN/sskiRJUn1OOeUUBg4cyPXXX7/ZPqtXr2bEiBG888473H333Tz44IMUFRUxevRoSkpKgMTM4mOPPZa33nqLe++9l1tvvZXbb7+9ToF4+fLlDBo0iLvuuos5c+Zw7rnncu2113LTTTcBcMABBzBt2jQA/vjHPzJ37tw6uf4mp556arIovMn69euZPXs2EyZMAGDDhg2MHj2aZ555hptvvplHHnmErl27Mnr0aFauXLnZZ77xxhv57W9/yw9/+EP+/Oc/85Of/ITCwkIqKioYN24cl1xyCUDyc8Bdd92VvLa4uJgzzzyTb33rWzz55JMMGzaMp59+mvHjx3PAAQfw6KOPMmnSJKZNm8YFF1yw2Riuu+46rr32WmbNmsWxxx4LJNYfnzFjBt///vd57LHHOPDAAzn22GN59dVXNzvOtnCpEEmStF3YNBtg8eLFfOc73yE/P5/Ro0fz5JNP8ve//53nnnuOww8/HICRI0fSr18/br75Zu655x7+9a9/kZ+fz80335wc7+ijj97svfbee2/atWtH165dG1xCpEOHDhx55JHMnDmzRuF25syZnHzyyaSnpxNj5LLLLuOMM86okVBmZ2dz/vnnc+WVV9K5c+c6Yy9YsIA1a9Zwxx13kJ+fD8BXvvKV5Pn999+fjIwMevfuXW+MPXv2ZObMmcnjxrxP1S1cuJCRI0dy4IEH8vvf/56srCwArrjiCg4++OAaY/fq1YtRo0bx+uuvs++++272/ZIkSZKqS0tL48orr+Sb3/wm1113HXvssUedPrfddhtFRUW8+uqrdOrUCYBDDjmEfv36ce+993L++eczZ84c/v3vfzNv3jwOPPBAAIYNG0a/fv0YMGBAcqxRo0YxatQoIFHsHjFiBMXFxfziF7/gyiuvpKCggEGDBgGJfHvTbOv6TJgwgalTp/LSSy8l8/HHHnuMjRs3JpczvP/++3n99dd54403GDhwIACjR49m0KBB3HLLLTU+n1Q3b948TjvtNM4888xk26mnngpAbm5uMq76PgeUlJRw6623ctxxxyXbzjzzTI444gh+/etfA3DkkUcCcOWVV3L11VfTu3fvGmN8//vf52c/+xlPPPFE8rPDX/7yF2bPnl3j88RXvvIVFixYwPXXX1+niN8cnHEtSZJS3qeffkpmZiaZmZkMGjSIxYsXM3PmTHr27Mm8efPo1q1bMnkCaNeuHccccwwvvPACAIMHD2bNmjWceeaZPP300826qeP48eP5y1/+wqeffgrAq6++yoIFCxg/fjyQKEAvXbqUU089lfLy8uRr5MiRlJaW8vrrr9c77oABA2j
2021-05-30 15:36:17 +02:00
"text/plain": [
2021-05-31 16:05:08 +02:00
"<Figure size 1800x720 with 2 Axes>"
2021-05-30 15:36:17 +02:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-05-31 16:05:08 +02:00
"plot_priori(labels,posteriori)"
2021-05-30 15:36:17 +02:00
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}