naivebayes

This commit is contained in:
s434732 2021-05-29 12:20:28 +02:00
commit 0517754510
2 changed files with 5501 additions and 0 deletions

File diff suppressed because it is too large Load Diff

390
main.ipynb Normal file
View File

@ -0,0 +1,390 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd \n",
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"#Wczytanie i normalizacja danych\n",
"def NormalizeData(data):\n",
" for col in data.columns:\n",
" if data[col].dtype == object: \n",
" data[col] = data[col].str.lower()\n",
" if col == 'smoking_status':\n",
" data[col] = data[col].str.replace(\" \", \"_\")\n",
" if col == 'work_type':\n",
" data[col] = data[col].str.replace(\"-\", \"_\")\n",
" if col == 'bmi':\n",
" bins = [0, 21, 28, 40]\n",
" labels=['low','mid','high']\n",
" data[col] = pd.cut(data[col], bins=bins, labels=labels)\n",
" if col == 'age':\n",
" bins = [18, 30, 40, 50, 60, 70, 120]\n",
" labels = ['18-29', '30-39', '40-49', '50-59', '60-69', '70+']\n",
" data[col] = pd.cut(data[col], bins, labels = labels,include_lowest = True)\n",
" if col == 'stroke':\n",
" data[col] = data[col].replace({1: 'yes'})\n",
" data[col] = data[col].replace({0: 'no'})\n",
" if col == 'hypertension':\n",
" data[col] = data[col].replace({1: 'yes'})\n",
" data[col] = data[col].replace({0: 'no'})\n",
" if col == 'heart_disease':\n",
" data[col] = data[col].replace({1: 'yes'})\n",
" data[col] = data[col].replace({0: 'no'})\n",
" data = data.dropna()\n",
" return data\n",
"\n",
"data = pd.read_csv(\"healthcare-dataset-stroke-data.csv\")\n",
"data = NormalizeData(data)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>gender</th>\n",
" <th>age</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>ever_married</th>\n",
" <th>work_type</th>\n",
" <th>Residence_type</th>\n",
" <th>avg_glucose_level</th>\n",
" <th>bmi</th>\n",
" <th>smoking_status</th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>9046</td>\n",
" <td>male</td>\n",
" <td>60-69</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>private</td>\n",
" <td>urban</td>\n",
" <td>228.69</td>\n",
" <td>high</td>\n",
" <td>formerly_smoked</td>\n",
" <td>yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>31112</td>\n",
" <td>male</td>\n",
" <td>70+</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>private</td>\n",
" <td>rural</td>\n",
" <td>105.92</td>\n",
" <td>high</td>\n",
" <td>never_smoked</td>\n",
" <td>yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>60182</td>\n",
" <td>female</td>\n",
" <td>40-49</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>private</td>\n",
" <td>urban</td>\n",
" <td>171.23</td>\n",
" <td>high</td>\n",
" <td>smokes</td>\n",
" <td>yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1665</td>\n",
" <td>female</td>\n",
" <td>70+</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>self_employed</td>\n",
" <td>rural</td>\n",
" <td>174.12</td>\n",
" <td>mid</td>\n",
" <td>never_smoked</td>\n",
" <td>yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>56669</td>\n",
" <td>male</td>\n",
" <td>70+</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>private</td>\n",
" <td>urban</td>\n",
" <td>186.21</td>\n",
" <td>high</td>\n",
" <td>formerly_smoked</td>\n",
" <td>yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5102</th>\n",
" <td>45010</td>\n",
" <td>female</td>\n",
" <td>50-59</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>private</td>\n",
" <td>rural</td>\n",
" <td>77.93</td>\n",
" <td>mid</td>\n",
" <td>never_smoked</td>\n",
" <td>no</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5106</th>\n",
" <td>44873</td>\n",
" <td>female</td>\n",
" <td>70+</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>self_employed</td>\n",
" <td>urban</td>\n",
" <td>125.20</td>\n",
" <td>high</td>\n",
" <td>never_smoked</td>\n",
" <td>no</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5107</th>\n",
" <td>19723</td>\n",
" <td>female</td>\n",
" <td>30-39</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>self_employed</td>\n",
" <td>rural</td>\n",
" <td>82.99</td>\n",
" <td>high</td>\n",
" <td>never_smoked</td>\n",
" <td>no</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5108</th>\n",
" <td>37544</td>\n",
" <td>male</td>\n",
" <td>50-59</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>private</td>\n",
" <td>rural</td>\n",
" <td>166.29</td>\n",
" <td>mid</td>\n",
" <td>formerly_smoked</td>\n",
" <td>no</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5109</th>\n",
" <td>44679</td>\n",
" <td>female</td>\n",
" <td>40-49</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>govt_job</td>\n",
" <td>urban</td>\n",
" <td>85.28</td>\n",
" <td>mid</td>\n",
" <td>unknown</td>\n",
" <td>no</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>3681 rows × 12 columns</p>\n",
"</div>"
],
"text/plain": [
" id gender age hypertension heart_disease ever_married \\\n",
"0 9046 male 60-69 no yes yes \n",
"2 31112 male 70+ no yes yes \n",
"3 60182 female 40-49 no no yes \n",
"4 1665 female 70+ yes no yes \n",
"5 56669 male 70+ no no yes \n",
"... ... ... ... ... ... ... \n",
"5102 45010 female 50-59 no no yes \n",
"5106 44873 female 70+ no no yes \n",
"5107 19723 female 30-39 no no yes \n",
"5108 37544 male 50-59 no no yes \n",
"5109 44679 female 40-49 no no yes \n",
"\n",
" work_type Residence_type avg_glucose_level bmi smoking_status \\\n",
"0 private urban 228.69 high formerly_smoked \n",
"2 private rural 105.92 high never_smoked \n",
"3 private urban 171.23 high smokes \n",
"4 self_employed rural 174.12 mid never_smoked \n",
"5 private urban 186.21 high formerly_smoked \n",
"... ... ... ... ... ... \n",
"5102 private rural 77.93 mid never_smoked \n",
"5106 self_employed urban 125.20 high never_smoked \n",
"5107 self_employed rural 82.99 high never_smoked \n",
"5108 private rural 166.29 mid formerly_smoked \n",
"5109 govt_job urban 85.28 mid unknown \n",
"\n",
" stroke \n",
"0 yes \n",
"2 yes \n",
"3 yes \n",
"4 yes \n",
"5 yes \n",
"... ... \n",
"5102 no \n",
"5106 no \n",
"5107 no \n",
"5108 no \n",
"5109 no \n",
"\n",
"[3681 rows x 12 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"#podział danych na treningowy i testowy \n",
"data_train, data_test = train_test_split(data, random_state = 42)\n",
"\n",
"X_train =data_train[['gender', 'age', 'ever_married', 'Residence_type', 'bmi','smoking_status', 'work_type','hypertension','heart_disease']]\n",
"Y_train = data_train['stroke']\n",
"\n",
"#rozdzielenie etykiet i cech\n",
"X_test =data_test[['gender', 'age', 'ever_married', 'Residence_type', 'bmi','smoking_status', 'work_type','hypertension','heart_disease']]\n",
"Y_test = data_test['stroke']"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"class NaiveBayes:\n",
" def __init__(self):\n",
" self.features = list\n",
" self.likelihoods = {}\n",
" self.class_priors = {}\n",
" self.pred_priors = {}\n",
"\n",
" self.X_train = np.array\n",
" self.y_train = np.array\n",
" self.train_size = int\n",
" self.num_feats = int\n",
" \n",
" def fit(self, x_train, y_train):\n",
"\n",
" self.features = list(X.columns)\n",
" self.X_train = x_train\n",
" self.y_train = y_train\n",
" self.train_size = X.shape[0]\n",
" self.num_feats = X.shape[1]\n",
"\n",
" for feature in self.features:\n",
" self.likelihoods[feature] = {}\n",
" self.pred_priors[feature] = {}\n",
"\n",
" for feat_val in np.unique(self.X_train[feature]):\n",
" self.pred_priors[feature].update({feat_val: 0})\n",
"\n",
" for outcome in np.unique(self.y_train):\n",
" self.likelihoods[feature].update({feat_val+'_'+outcome:0})\n",
" self.class_priors.update({outcome: 0})\n",
"\n",
" self._calc_class_prior()\n",
" self._calc_likelihoods()\n",
" self._calc_predictor_prior()\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}