200 lines
28 KiB
Plaintext
200 lines
28 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"id": "absolute-lending",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import pandas as pd \n",
|
||
|
"import numpy as np\n",
|
||
|
"from sklearn.model_selection import train_test_split\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"\n",
|
||
|
"# Load and normalize the data\n",
|
||
def NormalizeData(data):
    """Clean and discretize the stroke data set so every feature is categorical.

    The input frame is copied first, so the caller's DataFrame is not modified.
    Each step is applied only when the corresponding column is present:

    * text columns are lower-cased;
    * spaces in ``smoking_status`` are replaced with underscores;
    * the 0/1 flags ``stroke``, ``hypertension`` and ``heart_disease`` become
      ``'no'`` / ``'yes'``;
    * ``bmi``, ``age`` and ``avg_glucose_level`` are cut into labelled ranges.

    Args:
        data: pandas DataFrame with the raw records.

    Returns:
        A new DataFrame with the transformations applied and every row that
        contains a missing value removed. Values outside the bin edges become
        NaN in ``pd.cut`` and are therefore removed as well (e.g. bmi below 19).
    """
    # Work on a copy: the column assignments below would otherwise leak into
    # the frame owned by the caller.
    data = data.copy()

    yes_no_columns = ('stroke', 'hypertension', 'heart_disease')

    # column -> (bin edges, labels, right-closed?)
    # 'age' and 'bmi' use left-closed bins [lo, hi) so that a boundary value
    # lands in the bin its label names (age 30 -> '30-39', bmi 25 ->
    # 'overweight'). The glucose labels do not imply a convention, so the
    # original right-closed bins are kept there.
    bin_specs = {
        'bmi': (
            [19, 25, 30, 35, 40, 90],
            ['correct', 'overweight', 'obesity_1', 'obesity_2', 'extreme'],
            False,
        ),
        'age': (
            [0, 30, 40, 50, 60, 70, 80, 90],
            ['0-29', '30-39', '40-49', '50-59', '60-69', '70-79', '80-89'],
            False,
        ),
        'avg_glucose_level': (
            [50, 70, 90, 110, 130, 150, 170, 190, 210, 230, 250, 270],
            ['50-70', '70-90', '90-110', '110-130', '130-150', '150-170',
             '170-190', '190-210', '210-230', '230-250', '250-270'],
            True,
        ),
    }

    for col in data.columns:
        column = data[col]
        # Accept both classic object columns and pandas' dedicated string dtype.
        if pd.api.types.is_object_dtype(column) or pd.api.types.is_string_dtype(column):
            data[col] = column.str.lower()
        if col == 'smoking_status':
            data[col] = data[col].str.replace(" ", "_", regex=False)
        if col in yes_no_columns:
            data[col] = data[col].replace({1: 'yes', 0: 'no'})
        if col in bin_specs:
            edges, bin_labels, right = bin_specs[col]
            data[col] = pd.cut(data[col], edges, labels=bin_labels,
                               right=right, include_lowest=True)

    return data.dropna()
|
||
|
"\n",
|
||
def count_a_priori_prob(dataset):
    """Compute the a priori probability of each stroke class.

    Args:
        dataset: frame-like object whose ``'stroke'`` column holds the class
            of every record as ``'yes'`` or ``'no'``.

    Returns:
        dict ``{'yes': p_yes, 'no': p_no}`` where each value is the fraction of
        all rows carrying that class. For an empty data set both values are
        0.0 instead of raising ZeroDivisionError.
    """
    stroke = dataset['stroke']
    data_length = len(stroke)
    if data_length == 0:
        # No observations: the fractions are undefined, report zero mass.
        return {'yes': 0.0, 'no': 0.0}

    # Count through boolean masks instead of materialising filtered frames.
    is_stroke_amount = int((stroke == 'yes').sum())
    no_stroke_amount = int((stroke == 'no').sum())
    return {
        'yes': float(is_stroke_amount) / float(data_length),
        'no': float(no_stroke_amount) / float(data_length),
    }
|
||
|
"\n",
|
||
def separate_labels_from_properties(X_train):
    """List the distinct values observed for every feature column.

    Args:
        X_train: pandas DataFrame holding only the feature columns.

    Returns:
        A list with one single-key dict per column, in column order:
        ``[{column_name: {distinct values}}, ...]``.
    """
    # One pass over the columns; no intermediate lookup table is needed.
    return [{column: set(X_train[column])} for column in X_train.columns]
|
||
|
"\n",
|
||
# Load the raw CSV (path is relative to the notebook's working directory)
# and turn every feature into a categorical value.
data = pd.read_csv("healthcare-dataset-stroke-data.csv")
data = NormalizeData(data)

# Split the data into a training set and a test set (fixed seed for reproducibility).
data_train, data_test = train_test_split(data, random_state = 42)

# Separate the features from the target labels.
X_train =data_train[['gender', 'age', 'bmi','smoking_status', 'work_type','hypertension','heart_disease']]
Y_train = data_train['stroke']

# Classifier inputs: the data set, the list of feature values, and the vector
# of a priori class probabilities.

# Build the vector of a priori class probabilities.
a_priori_prob = count_a_priori_prob(data_train)
# All values observed per feature column.
# NOTE(review): this `labels` is overwritten by a hand-written query further
# down in this cell before the classifier is built.
labels = separate_labels_from_properties(X_train)
|
||
|
"\n",
|
||
class NaiveBayes():
    """Naive Bayes estimate of the stroke / no-stroke posterior for one query.

    Attributes:
        dataset: DataFrame with a ``'stroke'`` column holding ``'yes'`` /
            ``'no'`` and one column per feature named in ``labels``.
        labels: list of single-key dicts ``{feature_name: {value, ...}}``
            naming the feature values of interest, e.g.
            ``[{'gender': {'male'}}, {'hypertension': {'yes'}}]``.
        a_priori_prob: dict with the prior of each class under the keys
            ``'yes'`` and ``'no'``.
    """

    # Fixed class order. It matches the (yes, no) layout of every likelihood
    # tuple, so position `idx` in a tuple always belongs to CLASSES[idx].
    # Iterating over set(dataset['stroke']) instead gives an arbitrary order
    # and can pair a class with the other class's likelihoods.
    CLASSES = ('yes', 'no')

    def __init__(self, dataset, labels, a_priori_prob):
        self.dataset = dataset
        self.labels = labels
        self.a_priori_prob = a_priori_prob

    def count_bayes(self):
        """Compute class posteriors and the per-feature likelihood table.

        For every requested ``feature.value`` the conditional frequencies
        P(value | stroke) and P(value | no stroke) are estimated from
        ``self.dataset``. The Bayes numerator of a class is the product of
        its likelihoods and its prior; posteriors are the numerators
        normalised by their sum.

        Returns:
            Tuple ``(posteriori, label_probs)``:

            * ``posteriori`` - ``[P(yes | query), P(no | query)]``. Both are
              NaN when the numerators sum to zero (the query was never
              observed), instead of raising ZeroDivisionError.
            * ``label_probs`` - list with one entry per class (always two);
              each entry is the same list of
              ``{'feature.value': (p_given_yes, p_given_no)}`` dicts.
        """
        stroke = self.dataset['stroke']
        # Class membership masks and sizes do not depend on the feature, so
        # they are computed once instead of once per feature value.
        in_class = {cls: stroke == cls for cls in self.CLASSES}
        class_size = {cls: int(mask.sum()) for cls, mask in in_class.items()}

        # Likelihood table: one {name.value: (P(v|yes), P(v|no))} per value.
        label_probs = []
        for label in self.labels:
            # Each entry is expected to be a single-key dict; use that key.
            label_name = next(iter(label))
            column = self.dataset[label_name]
            for label_value in label[label_name]:
                has_value = column == label_value
                likelihoods = tuple(
                    # A class with no rows contributes likelihood 0.0.
                    int((in_class[cls] & has_value).sum()) / class_size[cls]
                    if class_size[cls] else 0.0
                    for cls in self.CLASSES
                )
                key = "{}.{}".format(label_name, label_value)
                label_probs.append({key: likelihoods})

        # Bayes numerator per class: product of likelihoods times the prior.
        numerators = {}
        for idx, cls in enumerate(self.CLASSES):
            top = 1
            for label_prob in label_probs:
                top *= next(iter(label_prob.values()))[idx]
            numerators[cls] = top * self.a_priori_prob[cls]

        total = sum(numerators.values())
        if total:
            posteriori_return = [numerators['yes'] / total, numerators['no'] / total]
        else:
            # Zero evidence for both classes: the posterior is undefined.
            posteriori_return = [float('nan'), float('nan')]

        # Keep the historical shape: one copy of the table per class.
        label_probs_return = [list(label_probs) for _ in self.CLASSES]
        return posteriori_return, label_probs_return
|
||
|
"\n",
|
||
# Query to classify: one value per feature of interest.
labels = [{'age': {'70-79'}},{'hypertension': {'yes'}},{'heart_disease': {'yes'}},{'bmi': {'obesity_2'}},{'gender': {'male'}},{'smoking_status': {'smokes'}}]
naive_bayes = NaiveBayes(data_train, labels, a_priori_prob)
# NOTE(review): `labels` is rebound here from the query to the likelihood
# table returned by count_bayes(); the plot_priori(labels) call in a later
# cell relies on this rebinding.
posteriori, labels = naive_bayes.count_bayes()
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 21,
|
||
|
"id": "handled-lightweight",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"\n",
|
||
def plot_priori(labels):
    """Scatter-plot the per-class conditional probability of each feature value.

    Args:
        labels: likelihood tables as returned by ``NaiveBayes.count_bayes``:
            a list whose first element is a list of
            ``{'feature.value': (p_given_stroke, p_given_no_stroke)}`` dicts.
            Only the first element is plotted.

    Shows a 10x10 inch figure with one x position per ``feature.value``:
    green points for the stroke class, red points for the no-stroke class.
    """
    # Only the first table is used, as before (the old code looped over range(1)).
    table = labels[0]
    keys = [name for entry in table for name in entry]
    # reshape(-1, 2) keeps the column slicing valid even for an empty table.
    probs = np.array(
        [list(pair) for entry in table for pair in entry.values()],
        dtype=float,
    ).reshape(-1, 2)
    yes_probs = probs[:, 0]
    no_probs = probs[:, 1]

    plt.figure(figsize=(10,10))
    plt.scatter(keys, yes_probs, color ='green', label= 'Positive stroke')
    plt.scatter(keys, no_probs, color ='red', label= 'Negative stroke')
    plt.xlabel('feature.value')
    plt.ylabel('P(feature value | class)')
    plt.legend()
    plt.show()
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 22,
|
||
|
"id": "molecular-ladder",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAnsAAAI/CAYAAAAP/R/tAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAA530lEQVR4nO3deZxddZ3n/9cnm6FYIkucVkJSaRsfGpIQpEBHjM0iEAYNtBs45U9wZMp26bZb2p/piSOIxpWRDEi3hMambYsGQcBAo6C0ICC0VCAmEkTWhIAzHUEjUIAEPvPHOZXcVCq1JJXcyjev5+NRj7rne7bPWe6573vuOfdGZiJJkqQyjWp2AZIkSdp2DHuSJEkFM+xJkiQVzLAnSZJUMMOeJElSwcY0u4De9tlnn2xtbW12GZIkSQNasmTJbzJzYrPr6M+IC3utra10dXU1uwxJkqQBRcTKZtcwED/GlSRJKphhT5IkqWCGPUmSpIKNuGv2JEkqzQsvvMDq1at57rnnml2KttD48eOZNGkSY8eObXYpQ2bYkyRpG1u9ejW77747ra2tRESzy9EQZSZPPPEEq1evZurUqc0uZ8j8GFeSpG3sueeeY++99zbo7aAigr333nuHPTNr2JMkaTsw6O3YduTtZ9iTJEkqmGFPkqSdwOjRo5k1axbTp0/n3e9+N93d3UMa//HHH+dd73oXAEuXLuW6665b32/x4sV86UtfGtZ6ezzyyCNccsklQx7vzDPP5Oyzz94GFe14DHuSJO0EdtllF5YuXcovfvELxo0bxze+8Y0hjf+qV72KK664Atg07M2dO5d58+YNa709+gt769at2ybzLI1hT5KkEaZzeSetC1sZ9dlRtC5spXN557BOf/bs2TzwwAM8+eSTnHjiicycOZM3vvGNLFu2DICbb76ZWbNmMWvWLA466CCeeuopHnnkEaZPn84f/vAHPvOZz3DZZZcxa9YsLrvsMi6++GI+9rGPsXbtWqZMmcJLL70EwDPPPMN+++3HCy+8wIMPPsicOXM4+OCDmT17Nr/85S83qauv+c6bN49bbrmFWbNmcc4553DxxRczd+5cjjzySI466qjNLkOjCy+8kOOOO45nn32Wb3/72xx66KHMmjWLD33oQ7z44ovDum5HIsOeJEkjSOfyTjqu6WDl2pUkycq1K+m4pmPYAt+6dev4/ve/z4wZMzjjjDM46KCDWLZsGV/4whd4//vfD8DZZ5/N+eefz9KlS7nlllvYZZdd1o8/btw4zjrrLE466SSWLl3KSSedtL7fhAkTmDVrFjfffDMA1157Lcceeyxjx46lo6OD8847jyVLlnD22WfzkY98ZJPa+prvl770JWbPns3SpUv567/+awDuuusurrjiCm6++ebNLkOPr3/961x77bVcffXVPPLII1x22WXcdtttLF26lNGjR9PZObxBeiTye/YkSRpB5t84n+4XNr6ervuFbubfOJ/2Ge1bPN1nn32WWbNmAdWZvQ9+8IO84Q1v4Lvf/S4ARx55JE888QS///3vOeyww/jEJz5Be3s773jHO5g0adKg53PSSSdx2WWXccQRR3DppZfykY98hKeffpqf/vSnvPvd714/3PPPP7/JuIOd79FHH81ee+0FwK233trnMgB861vfYr/99uPqq69m7Nix3HjjjSxZsoRDDjlk/Tp5xSteMehl21EZ9iRJGkFWrV01pPbB6rlmbzDmzZvH8ccfz3XXXcdhhx3G9ddfz/jx4wc17ty5c/kf/+N/8OSTT7JkyRKOPPJInnnmGV7+8pcPOP++5tuXXXfddVC1zJgxg6VLl67/MuTM5JRTTuGLX/zioMYvhR/jSpI0gkyeMHlI7Vtj9uzZ6z/GvOmmm9hnn33YY489ePDBB5kxYwaf+tSnOOSQQza5vm733Xfnqaee6nOau+22G4cccggf//jHedvb3sbo0aPZY489mDp1KpdffjlQ/SLFz3/+80
3G7Wu+/c2rv2UAOOigg7jggguYO3cujz/+OEcddRRXXHEF//Ef/wHAk08+ycqVK4e41nY8hj1JkkaQBUctoGVsy0ZtLWNbWHDUgmGf15lnnsmSJUuYOXMm8+bN45/+6Z8AWLhwIdOnT2fmzJmMHTuW4447bqPxjjjiCFasWLH+Bo3eTjrpJL797W9vdD1fZ2cnF110EQceeCAHHHAA3/ve9zYZr6/5zpw5k9GjR3PggQdyzjnnDHoZerz5zW/m7LPP5vjjj+cVr3gFn//85znmmGOYOXMmRx99NL/+9a+3aN3tSCIzm13DRtra2rKrq6vZZUiSNGzuvfdeXve61w16+M7lncy/cT6r1q5i8oTJLDhqwVZdr6fh0dd2jIglmdnWpJIGxWv2JEkaYdpntBvuNGz8GFeSJKlghj1JkqSCGfYkSdoK2/rXLqSt5TV7kiRtoZ5fu+j5EuSeX7sAvOZOI4Zn9iRJ2kL9/dqFNFIY9iRJ2kLb6tcutoWI4PTTT1/fffbZZ3PmmWcO+3y+8IUvbNT9pje9adjn0WPhwoV0d3cPPGCDRx55hOnTp2+jikYmw54kSVtoe/7axdZ62ctexpVXXslvfvObbTqf3mHvpz/96TabV39h78UXX9xm893RGPYkSdpC2+zXLjo7obUVRo2q/ndu/U0fY8aMoaOjo89foVizZg3vfOc7OeSQQzjkkEO47bbb1rcfffTRHHDAAZx22mlMmTJlfVg88cQTOfjggznggANYtGgRUP227bPPPsusWbNob6+uWdxtt90AOPnkk/nXf/3X9fM89dRTueKKK3jxxRf55Cc/ySGHHMLMmTO54IILNqnvmWee4fjjj+fAAw9k+vTpXHbZZZx77rk8/vjjHHHEERxxxBHr53X66adz4IEHcvvtt/O1r32N6dOnM336dBYuXLjJdB966CEOOugg7rzzTh588EHmzJnDwQcfzOzZszf5ibgdWmaOqL+DDz44JUnaUXx72bdzyjlTMs6MnHLOlPz2sm9vMsyKFSuGMMFvZ7a0ZMKGv5aWqn0r7Lrrrrl27dqcMmVK/u53v8uvfvWrecYZZ2Rm5nvf+9685ZZbMjNz5cqV+drXvjYzMz/60Y/mF77whczM/P73v59ArlmzJjMzn3jiiczM7O7uzgMOOCB/85vfrJ9P7/lmZl555ZX5/ve/PzMzn3/++Zw0aVJ2d3fnBRdckJ/73OcyM/O5557Lgw8+OB966KGNpnHFFVfkaaedtr77d7/7XWZmTpkyZX09mZlAXnbZZZmZ2dXVldOnT8+nn346n3rqqZw2bVredddd+fDDD+cBBxyQv/zlL3PWrFm5dOnSzMw88sgj81e/+lVmZt5xxx15xBFHbLIO+9qOQFeOgPzU359340qStBWG/dcu5s+H3h9NdndX7e1bN5899tiD97///Zx77rnssssu69t/9KMfsWLFivXdv//973n66ae59dZbueqqqwCYM2cOe+655/phzj333PX9Hn30Ue6//3723nvvzc77uOOO4+Mf/zjPP/88P/jBD3jLW97CLrvswg033MCyZcu44oorAFi7di33338/U6dOXT/ujBkzOP300/nUpz7F2972NmbPnt3nPEaPHs073/lOAG699Vb+7M/+jF133RWAd7zjHdxyyy3MnTuXNWvWcMIJJ3DllVcybdo0nn76aX7605/y7ne/e/20nn/++cGt1B2AYU+SpJFk1WZu7thc+xD91V/9Fa9//ev5wAc+sL7tpZde4o477mD8+PGDmsZNN93Ej370I26//XZaWlo4/PDDee655/odZ/z48Rx++OFcf/31XHbZZZx88slA9Qnjeeedx7HHHrvZcV/zmtdw1113cd111/HpT3+ao446is985jN9zmP06NED1j9hwgQmT57MrbfeyrRp03jppZd4+ctfztKlSwccd0fkNXuSJI0kkzdzc8fm2odor7324j3veQ8XXXTR+rZjjjmG8847b313T+
g57LDD+M53vgPADTfcwG9/+1ugOvu255570tLSwi9/+UvuuOOO9eOOHTuWF154oc95n3TSSfzjP/4jt9xyC3PmzAH
|
||
|
"text/plain": [
|
||
|
"<Figure size 720x720 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
# `labels` is the likelihood table that count_bayes() returned in the first
# cell (not the query list), so that cell must have been run beforehand.
plot_priori(labels)
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.8.5"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|