naivebayes
This commit is contained in:
commit
0517754510
5111
healthcare-dataset-stroke-data.csv
Normal file
5111
healthcare-dataset-stroke-data.csv
Normal file
File diff suppressed because it is too large
Load Diff
390
main.ipynb
Normal file
390
main.ipynb
Normal file
@ -0,0 +1,390 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd \n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from sklearn.model_selection import train_test_split\n",
|
||||||
|
"\n",
|
||||||
|
"#Wczytanie i normalizacja danych\n",
|
||||||
|
"def NormalizeData(data):\n",
|
||||||
|
" for col in data.columns:\n",
|
||||||
|
" if data[col].dtype == object: \n",
|
||||||
|
" data[col] = data[col].str.lower()\n",
|
||||||
|
" if col == 'smoking_status':\n",
|
||||||
|
" data[col] = data[col].str.replace(\" \", \"_\")\n",
|
||||||
|
" if col == 'work_type':\n",
|
||||||
|
" data[col] = data[col].str.replace(\"-\", \"_\")\n",
|
||||||
|
" if col == 'bmi':\n",
|
||||||
|
" bins = [0, 21, 28, 40]\n",
|
||||||
|
" labels=['low','mid','high']\n",
|
||||||
|
" data[col] = pd.cut(data[col], bins=bins, labels=labels)\n",
|
||||||
|
" if col == 'age':\n",
|
||||||
|
" bins = [18, 30, 40, 50, 60, 70, 120]\n",
|
||||||
|
" labels = ['18-29', '30-39', '40-49', '50-59', '60-69', '70+']\n",
|
||||||
|
" data[col] = pd.cut(data[col], bins, labels = labels,include_lowest = True)\n",
|
||||||
|
" if col == 'stroke':\n",
|
||||||
|
" data[col] = data[col].replace({1: 'yes'})\n",
|
||||||
|
" data[col] = data[col].replace({0: 'no'})\n",
|
||||||
|
" if col == 'hypertension':\n",
|
||||||
|
" data[col] = data[col].replace({1: 'yes'})\n",
|
||||||
|
" data[col] = data[col].replace({0: 'no'})\n",
|
||||||
|
" if col == 'heart_disease':\n",
|
||||||
|
" data[col] = data[col].replace({1: 'yes'})\n",
|
||||||
|
" data[col] = data[col].replace({0: 'no'})\n",
|
||||||
|
" data = data.dropna()\n",
|
||||||
|
" return data\n",
|
||||||
|
"\n",
|
||||||
|
"data = pd.read_csv(\"healthcare-dataset-stroke-data.csv\")\n",
|
||||||
|
"data = NormalizeData(data)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/html": [
|
||||||
|
"<div>\n",
|
||||||
|
"<style scoped>\n",
|
||||||
|
" .dataframe tbody tr th:only-of-type {\n",
|
||||||
|
" vertical-align: middle;\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
" .dataframe tbody tr th {\n",
|
||||||
|
" vertical-align: top;\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
" .dataframe thead th {\n",
|
||||||
|
" text-align: right;\n",
|
||||||
|
" }\n",
|
||||||
|
"</style>\n",
|
||||||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
||||||
|
" <thead>\n",
|
||||||
|
" <tr style=\"text-align: right;\">\n",
|
||||||
|
" <th></th>\n",
|
||||||
|
" <th>id</th>\n",
|
||||||
|
" <th>gender</th>\n",
|
||||||
|
" <th>age</th>\n",
|
||||||
|
" <th>hypertension</th>\n",
|
||||||
|
" <th>heart_disease</th>\n",
|
||||||
|
" <th>ever_married</th>\n",
|
||||||
|
" <th>work_type</th>\n",
|
||||||
|
" <th>Residence_type</th>\n",
|
||||||
|
" <th>avg_glucose_level</th>\n",
|
||||||
|
" <th>bmi</th>\n",
|
||||||
|
" <th>smoking_status</th>\n",
|
||||||
|
" <th>stroke</th>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </thead>\n",
|
||||||
|
" <tbody>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>0</th>\n",
|
||||||
|
" <td>9046</td>\n",
|
||||||
|
" <td>male</td>\n",
|
||||||
|
" <td>60-69</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" <td>yes</td>\n",
|
||||||
|
" <td>yes</td>\n",
|
||||||
|
" <td>private</td>\n",
|
||||||
|
" <td>urban</td>\n",
|
||||||
|
" <td>228.69</td>\n",
|
||||||
|
" <td>high</td>\n",
|
||||||
|
" <td>formerly_smoked</td>\n",
|
||||||
|
" <td>yes</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>2</th>\n",
|
||||||
|
" <td>31112</td>\n",
|
||||||
|
" <td>male</td>\n",
|
||||||
|
" <td>70+</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" <td>yes</td>\n",
|
||||||
|
" <td>yes</td>\n",
|
||||||
|
" <td>private</td>\n",
|
||||||
|
" <td>rural</td>\n",
|
||||||
|
" <td>105.92</td>\n",
|
||||||
|
" <td>high</td>\n",
|
||||||
|
" <td>never_smoked</td>\n",
|
||||||
|
" <td>yes</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>3</th>\n",
|
||||||
|
" <td>60182</td>\n",
|
||||||
|
" <td>female</td>\n",
|
||||||
|
" <td>40-49</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" <td>yes</td>\n",
|
||||||
|
" <td>private</td>\n",
|
||||||
|
" <td>urban</td>\n",
|
||||||
|
" <td>171.23</td>\n",
|
||||||
|
" <td>high</td>\n",
|
||||||
|
" <td>smokes</td>\n",
|
||||||
|
" <td>yes</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>4</th>\n",
|
||||||
|
" <td>1665</td>\n",
|
||||||
|
" <td>female</td>\n",
|
||||||
|
" <td>70+</td>\n",
|
||||||
|
" <td>yes</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" <td>yes</td>\n",
|
||||||
|
" <td>self_employed</td>\n",
|
||||||
|
" <td>rural</td>\n",
|
||||||
|
" <td>174.12</td>\n",
|
||||||
|
" <td>mid</td>\n",
|
||||||
|
" <td>never_smoked</td>\n",
|
||||||
|
" <td>yes</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>5</th>\n",
|
||||||
|
" <td>56669</td>\n",
|
||||||
|
" <td>male</td>\n",
|
||||||
|
" <td>70+</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" <td>yes</td>\n",
|
||||||
|
" <td>private</td>\n",
|
||||||
|
" <td>urban</td>\n",
|
||||||
|
" <td>186.21</td>\n",
|
||||||
|
" <td>high</td>\n",
|
||||||
|
" <td>formerly_smoked</td>\n",
|
||||||
|
" <td>yes</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>...</th>\n",
|
||||||
|
" <td>...</td>\n",
|
||||||
|
" <td>...</td>\n",
|
||||||
|
" <td>...</td>\n",
|
||||||
|
" <td>...</td>\n",
|
||||||
|
" <td>...</td>\n",
|
||||||
|
" <td>...</td>\n",
|
||||||
|
" <td>...</td>\n",
|
||||||
|
" <td>...</td>\n",
|
||||||
|
" <td>...</td>\n",
|
||||||
|
" <td>...</td>\n",
|
||||||
|
" <td>...</td>\n",
|
||||||
|
" <td>...</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>5102</th>\n",
|
||||||
|
" <td>45010</td>\n",
|
||||||
|
" <td>female</td>\n",
|
||||||
|
" <td>50-59</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" <td>yes</td>\n",
|
||||||
|
" <td>private</td>\n",
|
||||||
|
" <td>rural</td>\n",
|
||||||
|
" <td>77.93</td>\n",
|
||||||
|
" <td>mid</td>\n",
|
||||||
|
" <td>never_smoked</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>5106</th>\n",
|
||||||
|
" <td>44873</td>\n",
|
||||||
|
" <td>female</td>\n",
|
||||||
|
" <td>70+</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" <td>yes</td>\n",
|
||||||
|
" <td>self_employed</td>\n",
|
||||||
|
" <td>urban</td>\n",
|
||||||
|
" <td>125.20</td>\n",
|
||||||
|
" <td>high</td>\n",
|
||||||
|
" <td>never_smoked</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>5107</th>\n",
|
||||||
|
" <td>19723</td>\n",
|
||||||
|
" <td>female</td>\n",
|
||||||
|
" <td>30-39</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" <td>yes</td>\n",
|
||||||
|
" <td>self_employed</td>\n",
|
||||||
|
" <td>rural</td>\n",
|
||||||
|
" <td>82.99</td>\n",
|
||||||
|
" <td>high</td>\n",
|
||||||
|
" <td>never_smoked</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>5108</th>\n",
|
||||||
|
" <td>37544</td>\n",
|
||||||
|
" <td>male</td>\n",
|
||||||
|
" <td>50-59</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" <td>yes</td>\n",
|
||||||
|
" <td>private</td>\n",
|
||||||
|
" <td>rural</td>\n",
|
||||||
|
" <td>166.29</td>\n",
|
||||||
|
" <td>mid</td>\n",
|
||||||
|
" <td>formerly_smoked</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>5109</th>\n",
|
||||||
|
" <td>44679</td>\n",
|
||||||
|
" <td>female</td>\n",
|
||||||
|
" <td>40-49</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" <td>yes</td>\n",
|
||||||
|
" <td>govt_job</td>\n",
|
||||||
|
" <td>urban</td>\n",
|
||||||
|
" <td>85.28</td>\n",
|
||||||
|
" <td>mid</td>\n",
|
||||||
|
" <td>unknown</td>\n",
|
||||||
|
" <td>no</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </tbody>\n",
|
||||||
|
"</table>\n",
|
||||||
|
"<p>3681 rows × 12 columns</p>\n",
|
||||||
|
"</div>"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
" id gender age hypertension heart_disease ever_married \\\n",
|
||||||
|
"0 9046 male 60-69 no yes yes \n",
|
||||||
|
"2 31112 male 70+ no yes yes \n",
|
||||||
|
"3 60182 female 40-49 no no yes \n",
|
||||||
|
"4 1665 female 70+ yes no yes \n",
|
||||||
|
"5 56669 male 70+ no no yes \n",
|
||||||
|
"... ... ... ... ... ... ... \n",
|
||||||
|
"5102 45010 female 50-59 no no yes \n",
|
||||||
|
"5106 44873 female 70+ no no yes \n",
|
||||||
|
"5107 19723 female 30-39 no no yes \n",
|
||||||
|
"5108 37544 male 50-59 no no yes \n",
|
||||||
|
"5109 44679 female 40-49 no no yes \n",
|
||||||
|
"\n",
|
||||||
|
" work_type Residence_type avg_glucose_level bmi smoking_status \\\n",
|
||||||
|
"0 private urban 228.69 high formerly_smoked \n",
|
||||||
|
"2 private rural 105.92 high never_smoked \n",
|
||||||
|
"3 private urban 171.23 high smokes \n",
|
||||||
|
"4 self_employed rural 174.12 mid never_smoked \n",
|
||||||
|
"5 private urban 186.21 high formerly_smoked \n",
|
||||||
|
"... ... ... ... ... ... \n",
|
||||||
|
"5102 private rural 77.93 mid never_smoked \n",
|
||||||
|
"5106 self_employed urban 125.20 high never_smoked \n",
|
||||||
|
"5107 self_employed rural 82.99 high never_smoked \n",
|
||||||
|
"5108 private rural 166.29 mid formerly_smoked \n",
|
||||||
|
"5109 govt_job urban 85.28 mid unknown \n",
|
||||||
|
"\n",
|
||||||
|
" stroke \n",
|
||||||
|
"0 yes \n",
|
||||||
|
"2 yes \n",
|
||||||
|
"3 yes \n",
|
||||||
|
"4 yes \n",
|
||||||
|
"5 yes \n",
|
||||||
|
"... ... \n",
|
||||||
|
"5102 no \n",
|
||||||
|
"5106 no \n",
|
||||||
|
"5107 no \n",
|
||||||
|
"5108 no \n",
|
||||||
|
"5109 no \n",
|
||||||
|
"\n",
|
||||||
|
"[3681 rows x 12 columns]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"data\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"#podział danych na treningowy i testowy \n",
|
||||||
|
"data_train, data_test = train_test_split(data, random_state = 42)\n",
|
||||||
|
"\n",
|
||||||
|
"X_train =data_train[['gender', 'age', 'ever_married', 'Residence_type', 'bmi','smoking_status', 'work_type','hypertension','heart_disease']]\n",
|
||||||
|
"Y_train = data_train['stroke']\n",
|
||||||
|
"\n",
|
||||||
|
"#rozdzielenie etykiet i cech\n",
|
||||||
|
"X_test =data_test[['gender', 'age', 'ever_married', 'Residence_type', 'bmi','smoking_status', 'work_type','hypertension','heart_disease']]\n",
|
||||||
|
"Y_test = data_test['stroke']"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"class NaiveBayes:\n",
|
||||||
|
" def __init__(self):\n",
|
||||||
|
" self.features = list\n",
|
||||||
|
" self.likelihoods = {}\n",
|
||||||
|
" self.class_priors = {}\n",
|
||||||
|
" self.pred_priors = {}\n",
|
||||||
|
"\n",
|
||||||
|
" self.X_train = np.array\n",
|
||||||
|
" self.y_train = np.array\n",
|
||||||
|
" self.train_size = int\n",
|
||||||
|
" self.num_feats = int\n",
|
||||||
|
" \n",
|
||||||
|
" def fit(self, x_train, y_train):\n",
|
||||||
|
"\n",
|
||||||
|
" self.features = list(X.columns)\n",
|
||||||
|
" self.X_train = x_train\n",
|
||||||
|
" self.y_train = y_train\n",
|
||||||
|
" self.train_size = X.shape[0]\n",
|
||||||
|
" self.num_feats = X.shape[1]\n",
|
||||||
|
"\n",
|
||||||
|
" for feature in self.features:\n",
|
||||||
|
" self.likelihoods[feature] = {}\n",
|
||||||
|
" self.pred_priors[feature] = {}\n",
|
||||||
|
"\n",
|
||||||
|
" for feat_val in np.unique(self.X_train[feature]):\n",
|
||||||
|
" self.pred_priors[feature].update({feat_val: 0})\n",
|
||||||
|
"\n",
|
||||||
|
" for outcome in np.unique(self.y_train):\n",
|
||||||
|
" self.likelihoods[feature].update({feat_val+'_'+outcome:0})\n",
|
||||||
|
" self.class_priors.update({outcome: 0})\n",
|
||||||
|
"\n",
|
||||||
|
" self._calc_class_prior()\n",
|
||||||
|
" self._calc_likelihoods()\n",
|
||||||
|
" self._calc_predictor_prior()\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.8.5"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user