naive-bayes-gaussian/naive_bayes.ipynb

689 lines
257 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "TQqrOdkY6nsy"
},
"source": [
"# **Klasyfikacja za pomocą naiwnej metody bayesowskiej z rozkładem normalnym**"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SSaJsYOhz8h8"
},
"source": [
"![rozklady.jpg](
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AlcfRFCPSXIj"
},
"source": [
"# **Twierdzenie Bayesa**\n",
"![bayes.svg](
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rcpTnWjOh5dq"
},
"source": [
"P(A) -- oznacza prawdopodobieństwo a-priori wystąpienia klasy A (tj. prawdopodobieństwo, że dowolny przykład należy do klasy A)\n",
"\n",
"P(B|A) -- oznacza prawdopodobieństwo a-posteriori, że B należy do \n",
"klasy A\n",
"\n",
"P(B) -- znacza prawdopodobieństwo a-priori wystąpienia przykładu B "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Yabcm4Rei2ue"
},
"source": [
"![GaussianNB.png](
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dsf6FnlgjiOL"
},
"source": [
"# Funkcja gęstości prawdopodobieństwa rozkładu normalnego \n",
"![gestosc.svg](
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "v0oeHebytjNp"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import scipy.stats as stats\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"sns.set(style=\"whitegrid\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "fOYTA3VVtjNw"
},
"outputs": [],
"source": [
"class NaiveBayesClassifier():\n",
" def calc_prior(self, features, target):\n",
" '''\n",
" Wyliczenie prawdopodobieństwa a priori\n",
" '''\n",
" self.prior = (features.groupby(target).apply(lambda x: len(x)) / self.rows).to_numpy()\n",
"\n",
" return self.prior\n",
" \n",
" def calc_statistics(self, features, target):\n",
" '''\n",
" Wyliczenie średnich i wariancji dla danych\n",
" ''' \n",
" self.mean = features.groupby(target).apply(np.mean).to_numpy()\n",
" self.var = features.groupby(target).apply(np.var).to_numpy()\n",
" \n",
" return self.mean, self.var\n",
" \n",
" def gaussian_density(self, class_idx, x): \n",
" '''\n",
" Wyliczenie prawdopodobieństwa z rozkładu normalnego \n",
" (1/√2pi*σ) * exp((-1/2)*((x-μ)^2)/(2*σ²))\n",
" μ -średnia\n",
" σ² - wariancja\n",
" σ - odchylenie standardowe\n",
" '''\n",
" mean = self.mean[class_idx]\n",
" var = self.var[class_idx]\n",
" numerator = np.exp((-1/2)*((x-mean)**2) / (2 * var))\n",
" denominator = np.sqrt(2 * np.pi * var)\n",
" prob = numerator / denominator\n",
" return prob\n",
" \n",
" def calc_posterior(self, x):\n",
" '''\n",
" Wyliczenie prawdopodobieństwa a posteriori i zwrócenie klasy, dla której prawdopodobieństwo jest najwyższe\n",
" '''\n",
" posteriors = []\n",
" posteriors_no_log = []\n",
"\n",
" # calculate posterior probability for each class\n",
" for i in range(self.count):\n",
" prior = np.log(self.prior[i]) # Do predykcji używane jest prawodopodobieństwo logarytmiczne\n",
" prior_no_log = self.prior[i] # Zwykłe prawdopodobieństwo liczymy, żeby zwrócić je z predykcjami\n",
"\n",
" conditional = np.sum(np.log(self.gaussian_density(i, x))) \n",
" conditional_no_log = np.prod(self.gaussian_density(i, x))\n",
"\n",
" posterior = prior + conditional\n",
" posterior_no_log = prior_no_log * conditional_no_log\n",
"\n",
" posteriors.append(posterior)\n",
" posteriors_no_log.append(posterior_no_log)\n",
"\n",
" # Zwracamy klasę o największym prawdopodobieństwie\n",
" return self.classes[np.argmax(posteriors)], np.max(posteriors_no_log)\n",
"\n",
" def fit(self, features, target):\n",
" '''\n",
" Główna metoda trenująca model\n",
" '''\n",
" self.classes = np.unique(target)\n",
" self.count = len(self.classes)\n",
" self.feature_nums = features.shape[1]\n",
" self.rows = features.shape[0]\n",
" \n",
" self.calc_statistics(features, target)\n",
" self.calc_prior(features, target)\n",
" \n",
" def predict(self, features):\n",
" '''\n",
" Predykcja wartości dla każdego wiersza\n",
" '''\n",
" preds = [self.calc_posterior(f) for f in features.to_numpy()]\n",
" return preds\n",
"\n",
" def accuracy(self, y_test, y_pred):\n",
" '''\n",
" Wyliczenie accuracy modelu\n",
" '''\n",
" accuracy = np.sum(y_test == y_pred) / len(y_test)\n",
" return accuracy\n",
"\n",
" def visualize(self, y_true, y_pred, target):\n",
" '''\n",
" Narysowanie wykresu porównującego rozkład klas prawdziwych i przewidzianych\n",
" '''\n",
" tr = pd.DataFrame(data=y_true, columns=[target])\n",
" pr = pd.DataFrame(data=y_pred, columns=[target])\n",
" \n",
" \n",
" fig, ax = plt.subplots(1, 2, sharex='col', sharey='row', figsize=(15,6))\n",
" \n",
" sns.countplot(x=target, data=tr, ax=ax[0], palette='viridis', alpha=0.7, hue=target, dodge=False)\n",
" sns.countplot(x=target, data=pr, ax=ax[1], palette='viridis', alpha=0.7, hue=target, dodge=False)\n",
" \n",
"\n",
" fig.suptitle('True vs Predicted Comparison', fontsize=20)\n",
"\n",
" ax[0].tick_params(labelsize=12)\n",
" ax[1].tick_params(labelsize=12)\n",
" ax[0].set_title(\"Prawdziwe wartości\", fontsize=18)\n",
" ax[1].set_title(\"Predyckje\", fontsize=18)\n",
" plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 382
},
"id": "5-riUAGntjN2",
"outputId": "f87f047d-bc71-41ef-a43a-17b6f7cf84c3"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(2948, 9) (2948,)\n",
"(328, 9) (328,)\n"
]
}
],
"source": [
"# Preprocessing danych\n",
"\n",
"# Uzupełnienie pustych wartości w kolumnach\n",
"def fill_nan(df):\n",
" for index, column in enumerate(df.columns[:9]):\n",
" df[column] = df[column].fillna(df.groupby('Potability')[column].transform('mean'))\n",
" return df\n",
"\n",
"# Wczytywanie danych\n",
"df = pd.read_csv(\"water_potability.csv\")\n",
"\n",
"df = fill_nan(df)\n",
"\n",
"# Zrandomizowanie kolejności danych w datasecie\n",
"df = df.sample(frac=1, random_state=10).reset_index(drop=True)\n",
"\n",
"# Podział na atrybuty i przewidywane wartości\n",
"X, y = df.iloc[:, :-1], df.iloc[:, -1]\n",
"\n",
"# Normalizacja i skalowanie danych\n",
"from sklearn.preprocessing import StandardScaler\n",
"sc = StandardScaler()\n",
"X = sc.fit_transform(X.to_numpy())\n",
"X = pd.DataFrame(X, columns=df.columns.values.tolist()[:-1])\n",
"\n",
"# Podział na dane trenujące i testowe, z uwzględnieniem równego rozłożenia danych\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, stratify=y, random_state=1)\n",
"\n",
"print(X_train.shape, y_train.shape)\n",
"print(X_test.shape, y_test.shape)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "O82SGzK6tjN5"
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>ph</th>\n",
" <th>Hardness</th>\n",
" <th>Solids</th>\n",
" <th>Chloramines</th>\n",
" <th>Sulfate</th>\n",
" <th>Conductivity</th>\n",
" <th>Organic_carbon</th>\n",
" <th>Trihalomethanes</th>\n",
" <th>Turbidity</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1022</th>\n",
" <td>0.003078</td>\n",
" <td>0.688791</td>\n",
" <td>0.846257</td>\n",
" <td>1.428934</td>\n",
" <td>-0.858263</td>\n",
" <td>0.002792</td>\n",
" <td>0.913790</td>\n",
" <td>0.232417</td>\n",
" <td>2.319505</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3191</th>\n",
" <td>-0.587365</td>\n",
" <td>0.223203</td>\n",
" <td>-0.731867</td>\n",
" <td>0.397503</td>\n",
" <td>0.759893</td>\n",
" <td>0.330607</td>\n",
" <td>0.094379</td>\n",
" <td>0.282563</td>\n",
" <td>0.235024</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>0.003078</td>\n",
" <td>-0.241037</td>\n",
" <td>0.773051</td>\n",
" <td>0.580019</td>\n",
" <td>1.334369</td>\n",
" <td>-0.049130</td>\n",
" <td>-1.121422</td>\n",
" <td>-0.200432</td>\n",
" <td>-0.946356</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2068</th>\n",
" <td>-2.176058</td>\n",
" <td>1.443006</td>\n",
" <td>-1.626771</td>\n",
" <td>-4.164610</td>\n",
" <td>-0.033706</td>\n",
" <td>-1.050763</td>\n",
" <td>-0.391328</td>\n",
" <td>-0.398649</td>\n",
" <td>-0.298341</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1484</th>\n",
" <td>0.213047</td>\n",
" <td>0.403036</td>\n",
" <td>-0.464729</td>\n",
" <td>0.070417</td>\n",
" <td>0.021560</td>\n",
" <td>-0.952776</td>\n",
" <td>-0.213330</td>\n",
" <td>0.111419</td>\n",
" <td>-0.235893</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>691</th>\n",
" <td>0.003078</td>\n",
" <td>1.199106</td>\n",
" <td>-0.003483</td>\n",
" <td>-0.670308</td>\n",
" <td>-0.069513</td>\n",
" <td>0.185754</td>\n",
" <td>-0.466010</td>\n",
" <td>0.031975</td>\n",
" <td>0.676276</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1283</th>\n",
" <td>-2.034004</td>\n",
" <td>-1.508135</td>\n",
" <td>0.255310</td>\n",
" <td>0.083839</td>\n",
" <td>-1.413707</td>\n",
" <td>0.694074</td>\n",
" <td>-1.110579</td>\n",
" <td>0.232996</td>\n",
" <td>2.544703</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2818</th>\n",
" <td>-0.702987</td>\n",
" <td>-0.575677</td>\n",
" <td>0.755056</td>\n",
" <td>0.664695</td>\n",
" <td>0.021560</td>\n",
" <td>-0.489334</td>\n",
" <td>0.371852</td>\n",
" <td>-2.272990</td>\n",
" <td>-1.764684</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1330</th>\n",
" <td>1.525943</td>\n",
" <td>0.497074</td>\n",
" <td>-0.714355</td>\n",
" <td>-1.024237</td>\n",
" <td>-1.022037</td>\n",
" <td>-0.327074</td>\n",
" <td>-1.107341</td>\n",
" <td>0.517432</td>\n",
" <td>-1.230528</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1926</th>\n",
" <td>-0.043558</td>\n",
" <td>-0.882359</td>\n",
" <td>-0.456141</td>\n",
" <td>-0.770271</td>\n",
" <td>0.795189</td>\n",
" <td>0.560306</td>\n",
" <td>-1.086081</td>\n",
" <td>-1.356820</td>\n",
" <td>0.172521</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2948 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" ph Hardness Solids Chloramines Sulfate Conductivity \\\n",
"1022 0.003078 0.688791 0.846257 1.428934 -0.858263 0.002792 \n",
"3191 -0.587365 0.223203 -0.731867 0.397503 0.759893 0.330607 \n",
"13 0.003078 -0.241037 0.773051 0.580019 1.334369 -0.049130 \n",
"2068 -2.176058 1.443006 -1.626771 -4.164610 -0.033706 -1.050763 \n",
"1484 0.213047 0.403036 -0.464729 0.070417 0.021560 -0.952776 \n",
"... ... ... ... ... ... ... \n",
"691 0.003078 1.199106 -0.003483 -0.670308 -0.069513 0.185754 \n",
"1283 -2.034004 -1.508135 0.255310 0.083839 -1.413707 0.694074 \n",
"2818 -0.702987 -0.575677 0.755056 0.664695 0.021560 -0.489334 \n",
"1330 1.525943 0.497074 -0.714355 -1.024237 -1.022037 -0.327074 \n",
"1926 -0.043558 -0.882359 -0.456141 -0.770271 0.795189 0.560306 \n",
"\n",
" Organic_carbon Trihalomethanes Turbidity \n",
"1022 0.913790 0.232417 2.319505 \n",
"3191 0.094379 0.282563 0.235024 \n",
"13 -1.121422 -0.200432 -0.946356 \n",
"2068 -0.391328 -0.398649 -0.298341 \n",
"1484 -0.213330 0.111419 -0.235893 \n",
"... ... ... ... \n",
"691 -0.466010 0.031975 0.676276 \n",
"1283 -1.110579 0.232996 2.544703 \n",
"2818 0.371852 -2.272990 -1.764684 \n",
"1330 -1.107341 0.517432 -1.230528 \n",
"1926 -1.086081 -1.356820 0.172521 \n",
"\n",
"[2948 rows x 9 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "a3jkTMFLtjN6"
},
"outputs": [],
"source": [
"# Trenowanie modelu klasyfikatora\n",
"x = NaiveBayesClassifier()\n",
"x.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "CoC22aNgtjN9"
},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Predykcja wartości dla danych testowych\n",
"predictions = x.predict(X_test)\n",
"\n",
"# Prawdopodobieństwa kolejnych predykcji\n",
"probabilities = [p[1] for p in predictions]\n",
"\n",
"# Przewidziana wartość\n",
"predictions = [p[0] for p in predictions]\n",
"predictions[0]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "JR06zodmtjN9"
},
"outputs": [
{
"data": {
"text/plain": [
"0.6280487804878049"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Wyliczenie accuracy modelu\n",
"x.accuracy(y_test, predictions)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "1jW0QPootjN_"
},
"outputs": [
{
"data": {
"text/plain": [
"0.14084507042253522"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import f1_score\n",
"\n",
"f1_score(y_test, predictions)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "vEVogTmAtjOA"
},
"outputs": [
{
"data": {
"text/plain": [
"0 0.609756\n",
"1 0.390244\n",
"Name: Potability, dtype: float64"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_test.value_counts(normalize=True)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "jCVOdBZytjOB"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA4oAAAGnCAYAAADxIKk5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzs3XdYFFf7N/AvvUuRYo8E3LUgRWmComDDFiXWJ0ZjbxFbfIwajT5qophEjcSIiYrdWCIaK9aIGECR2MWCBcFCFaUIKPP+4W/nddxFF6Up3891eV1yzpmZeweWm3vnnBkNQRAEEBEREREREf0fzYoOgIiIiIiIiCoXFopEREREREQkwUKRiIiIiIiIJFgoEhERERERkQQLRSIiIiIiIpJgoUhEREREREQSLBSJiOi9dOzYMcjlcqxatUrS/umnn8LFxaWCoiq59y3e91lQUBDkcjkuXLhQ0aEQEVV62hUdABFRaZPL5SUaP3/+fHz66adlFM3769q1a+jWrZukTVtbG+bm5mjatCkGDhyIFi1aVFB0ZefYsWMYNWoUpkyZgqFDh1Z0OMVKT0/Hxo0bceLECSQmJiI7OxvGxsaQyWTw8fFBz549YWFhUdFhEhHRe4qFIhF9cMaOHavUtnbtWjx58gQDBw5EtWrVJH2NGjUqr9DeSxYWFvjss88AAE+fPsXly5dx9OhRHD16FPPmzUPv3r0rOEKp4OBgFBQUVHQYZWr//v2YPn06cnNz8fHHH6Njx44wNzdHVlYWzp07hx9//BEhISE4fvw4jI2NKzrcSmP48OHo1asX6tSpU9GhEBFVeiwUieiDExgYqNQWFhaGJ0+e4IsvvuAfiSVUvXp1pXO6fv16zJs3Dz/88AMCAgKgrV150knt2rUrOoQyFRERgUmTJkFfXx+LFi1Cly5dlMZcunQJc+fO/eAL5pKysLDgVVYiIjVxjSIR0f9RrBV7+vQpFi9ejPbt28PBwQFz5swB8Pr1TdeuXYNcLhfHviwnJwe//PILunXrBicnJ7i4uOCzzz7DwYMH1YorOzsbTk5O8PX1hSAIKsdMnjwZcrkcp06dEtuioqIwbNgwtGrVCg4ODmjZsiX69euH3377Ta3jvk7fvn2hra2NrKws3LlzB4B0zWBsbCyGDh0KNzc3yOVyZGRkiNsmJSXh22+/hZ+fHxwcHODh4YGxY8fiypUrKo/14MEDTJkyBR4eHnBycsKnn36KvXv3Fhvb69b8HTt2DMOHD4enpyccHBzQpk0bBAYG4vTp0wCAcePGYdSoUQCAhQsXQi6Xi/9e/r4LgoAdO3agf//+cHV1RdOmTdG1a1f89ttvKCwsVHnsHTt2oHv37mjatCm8vb0xbdo0yXlRR2FhIWbPno2ioiLMmTNHZZEIAE2aNMHGjRthamoqaf/3338xevRoeHh4wMHBAW3btsV3332nMo5x48ZBLpcjPT0dq1evhr+/P5o2bYp27dph9erV4rhdu3YhICAATk5O8Pb2xoIFC5TOQU5ODuRyOUaOHInk5GRMnDgRHh4ecHR0RK9evVS+F54+fYq1a9di6NChaNOmjfizMnToUERFRal83R4eHujatSsePXqEuXPnok2bNmjcuLG4jrW493BJ3iv379/HzJkzxZi8vLwwfvx4xMfHK43dsGED5HI5Dhw4gIiICPznP/+Bi4sLmjdvjjFjxojvHSKiyqjyfARMRFQJFBUVYeTIkbh16xZatmwJMzOzd7pClZGRgQEDBuDGjRtwdHRE7969UVhYiBMnTiAwMBBfffUVRowY8dp9GBsbo127dtizZw9iYmLg6ekp6c/Ozsbhw4dRu3ZtuLm5AQDCw8Mxbtw4mJmZwc/PD1ZWVsjMzMSNGzewZcuWNx7zXURHR2PRokXw9PREr169kJaWJl5xjIuLw4gRI5CTkwMfHx907NgR6enpOHToECIiIvD777/Dw8ND3NfDhw/Rt29fPHjwQCwU79+/j6+//hqtWrUqUVwLFixAaGgoTExM0LZtW9jY2ODhw4eIjY3F/v374ebmBn9/f+jo6GDPnj3w9vaWFJzW1tYAXhSJkyZNwr59+1C7dm34+/vDyMgIZ86cwU8//YTY2FiEhIRAU/P/fxa7bNkyLF26FGZmZujZsycMDQ1x/PhxcUqvuk6cOIHk5GTUr18fXbt2fe1YLS0tydf79u3Df//7X2hqasLf3x81atTA2bNnsW7dOhw5cgSbN2+GjY2N0n5mz56N2NhYtGnTBi1btsThw4cRFBQEQRBQWFiIFStWoG3btnB3d0dERARCQ0OhoaGBr7/+Wmlf6enp6Nu3L2xsbNC7d29kZmZi//79CAwMxKxZsyTn4+HDhwgKCkKzZs3QsmVLmJub4+HDhzh69CgGDx6MH3/8UeU5yMvLw+eff47CwkL4+PjAwMAANWvWLPY8leS9kpCQgM8//xwZGRlo2bIlPvnkEyQlJSE8PBx///03QkJCVK7b3bdvH44cOYI2bdqgX79+uHr1Ko4cOYKLFy9i3759nB5MRJWTQERUBfj6+goymUy4e/dusWMCAgIEmUwm9OzZU8jKylLqX7BggSCTyYTz588r9V29elWQyWTC//73P0l7YGCgIJPJhA0bNkjac3Nzhf79+wuNGjUSbt68+cb4IyMjBZlMJnz99ddKfdu3bxdkMpnw888/i21DhgwRZDKZcPv2baXx6enpbzzey6+pS5cuSn1r164VZDKZ4O7uLhQWFgqCIAhHjx4VZDKZIJPJhF27dilt8/TpU6FVq1aCs7OzcO7cOUnf3bt3BQ8PD8HX11fcnyAIwqRJkwSZTCYsXrxYMv706dOCXC4XZDKZsHLlSklfQECA4OzsLGkLDw8XZDKZ4O/vL6SlpUn6ioqKhAcPHohfK17Hq/tVWL9+vSCTyYSvvvpKyM/Pl+xH8TOydetWsf3GjRtCo0aNBC8vL8lxCgsLhaFDhwoymUwp3uIsXLhQkMlkwsyZM9Uar5CZmSm4uLgITZo0Ufr5Xbx4sSCTyYQvv/xS0q742e3YsaPknKWlpQkuLi6Ci4uL4OXlJdy5c0fsy83NFXx9fQUnJyfhyZMnYnt2drb4szF16lShqKhI7EtISBCcnZ2Fpk2bSs5Pbm6u8PDhQ6XXkpGRIbRr107w9vaW/KwIgiC4u7sLMplMGDlypPD06VOlbVW9h0vyXunXr58gk8mENWvWSNpPnjwpyOVywdvbW/IzofhZadKkiXDmzBnJNnPmzFH5u4GIqLLg1FMiold89dVXSje8eRsPHjzAwYMH4e7ujv79+0v6DAwMMHHiRDx//vy10ygVWrRogRo1aiA8PBy5ubmSvp07dwIAevToIWnX0NCAnp6e0r5KukYrPT0dwcHBCA4Oxg8//IDBgwfju+++A/Biyuur6xObN2+OTz75RGk/4eHhePjwIYYOHQpHR0dJX506dfDFF18gOTkZ//77L4AX0xUPHDgAc3NzjBw5UjLe1dUVHTp0UPs1rF+/HgAwY8YMVK9eXdKnoaGh8kpacdatWwcDAwPMnTsXurq6kv1MnDgRBgYG2L17t9i+c+dOPH/+HIMHD5YcR1tbG1OmTFH7uACQmpoKAKhRo0aJtjtw4ABycnIQEBCApk2bSvpGjx4NKysrHDlypNgpqC+fs+rVq6NVq1bIycnBoEGDUK9ePbHPwMAAHTt2RF5eHm7fvq20Lx0dHUyaNAkaGhpi28cff4x+/fohPz9f8l4wMDAQr+K+zNzcHN27d0dqaiquXr2q8vVOnz5d5c9+cdR5r9y8eRNxcXGwtbXFgAEDJOO8vLzQtm1bpKam4u+//1baz6effopmzZpJ2vr06QMAfFQHEVVanHpKRPSKV/+Qfltnz56FIAh49uwZgoODlfoVBd/NmzffuC9NTU10794dK1aswMGDB8WiMCkpCadPn4arq6vkD/Zu3bohMjIS3bt3R+fOneHh4YFmzZqp/MP7TTIyMvDLL78AeDGd0czMDL6+vhg4cCC8vLyUxr9aBCqcPXsWAHD79m2V5+P69esAXkzvc3Nzw7V
"text/plain": [
"<Figure size 1080x432 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"x.visualize(y_test, predictions, 'Potability')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "aw8Tefprhjnn"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/lib/python3/dist-packages/matplotlib/cbook/__init__.py:1377: FutureWarning: Support for multi-dimensional indexing (e.g. `obj[:, None]`) is deprecated and will be removed in a future version. Convert to a numpy array before indexing instead.\n",
" x[:, None]\n",
"/usr/lib/python3/dist-packages/matplotlib/axes/_base.py:237: FutureWarning: Support for multi-dimensional indexing (e.g. `obj[:, None]`) is deprecated and will be removed in a future version. Convert to a numpy array before indexing instead.\n",
" x = x[:, np.newaxis]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXMAAAELCAYAAAAry2Y+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3XtcFPX+P/DX7rLLVYUQFNPykhiKpqJHMz3eUxTFjMRU8FJ6vKX57WYPTa3ItHrksbxlP29odkozvJKWmpdKT5KZpB4M8QYIKKjIIiy78/uDdmNhL7MXdpfp9Xw8fMjufOYz75nPe947O7s7IxMEQQAREdVpcncHQEREjmMxJyKSABZzIiIJYDEnIpIAFnMiIglgMScikgAWcyJyqoSEBGzbts3weNmyZejWrRueeOIJN0YlfSzmHuTkyZP45z//WeP56jsHkSucOnUKo0ePRlRUFP7xj39g9OjR+O2332zqIzc3Fxs2bMC+ffvwww8/WG0/d+5cLFu2zN6Q/9a83B0AEXmee/fuYerUqVi0aBGio6Oh0Whw6tQpqFQqm/rJzs5GYGAggoODaylS0uORuRv069cPn3zyCYYMGYKuXbvi9ddfR1lZmbvDIjLIysoCAMTExEChUMDHxwc9e/bEo48+io8//hgvv/yyoe3169fRpk0bVFRUGPXx448/YtKkScjPz0enTp0wd+5cAMCsWbPwxBNPICoqCmPHjsXFixcBAF988QV2796NdevWoVOnTpg6dSoAIC8vDy+88AK6d++Ofv36ITk52RWboM5hMXcTfdJ+++23yMrKwqpVq9wdEpFBixYtoFAo8Nprr+HIkSO4c+eOzX306NEDn376KUJDQ3H69GksWbIEAPDPf/4T+/fvx08//YS2bdsaXhji4+MxbNgwPPfcczh9+jTWrFkDnU6HadOmoU2bNjh69Cg2bdqETZs24dixY05dXylgMXeTsWPHIiwsDIGBgZg2bRr27t0LAMjPz0eXLl2M/qWlpbk5Wvq7CQgIwNatWyGTyfDGG2/g8ccfx9SpU3Hz5k2H+46Li0NAQABUKhVeeOEFXLhwAcXFxSbbnj17FoWFhZg5cyZUKhWaNWuGUaNGYd++fQ7HITU8Z+4mYWFhhr+bNGmC/Px8AEBoaCiOHj1q1DYhIcGlsREBQKtWrQxH05mZmXjllVewePFitGjRwu4+tVotli1bhm+++QaFhYWQyyuPJ4uKilCvXr0a7bOzsw0HOFX7qPqYKrGYu0lubq7h75ycHISGhroxGiLLWrVqhZEjR+KLL75A27Ztcf/+fcM0W47Wd+/ejYMHD2LDhg1o2rQpiouL0bVrV+gv3iqTyYzah4WFoWnTpjhw4IBzVkTCeJrFTbZu3YobN27g9u3bhg9DiTxFZmYm1q9fjxs3bgCoPPjYs2cPHnvsMURERODnn39GTk4OiouL8cknn4jut6SkBCqVCkFBQSgtLcWHH35oND04OBjXr183PO7QoQMCAgKwdu1a3L9/H1qtFhkZGTZ/RfLvgMXcTWJiYjBp0iQMGDAAzZo1w7Rp09wdEpFBQEAAzpw5g2eeeQYdO3bEqFGjEB4ejrlz5+KJJ57AkCFDMHz4cIwcORJ9+/YV3e+IESPQpEkT9OrVC0OHDkXHjh2NpsfFxeGPP/5Aly5dMH36dCgUCqxevRoXLlxA//790b17d8yfPx/37t1z9irXeTLenML1+vXrh6SkJPTo0cPdoRCRRPDInIhIAljMiYgkgKdZiIgkgEfmREQSYPV75kuXLsX+/fuRnZ2N3bt3Izw8XHTnOp0OJSUlUCqVNb4/SuQoQRCg0Wjg7+9v+PGJLezNbeY11SZ789pqMe/fvz8SExMxduxYm4MqKSlBRkaGzfMR2SI8PNzkrwetsTe3mdfkCrbmtdVi7sjPZpVKpSEoWy+dWVV6ejoiIyPtnt+UE+m5SN57HuUV2hrTVF4K9OjQGD/+dsNouv75Y7/mQKur9lGDDJDLZNBVf/5PEQ8H4aVx4rflifRcrN/9u8n+5HIZJg1rh+6RYSbnq75eKi8FEodGmGxvjrlt7qz+naG8vBwZGRmGPLOVvbltKq/NbRdzeeTI9rKUu2K5a8zMqY19vDpn5a6lsTZZG6pQyGXo1bGJxZywN69r9ef8+regKpUK3t7eDvXl6PzVbUrNQMGdcjNTtdj1wzUThdTc89b9dO6mTeuwKTUDhcUVFqf3jmpu8vma66U1294SU/E6s39ncfWpDlN5bW67mMsjR7aX5dwVy71jZoqz9/HqnJW7to11TWJzwta8dsm1WdLT0x3uw9lXDiwoKrU43dyg2FPI9WxZB2vxFRSVmuzP3Hzm2ltS2/3XdVXz2tx2MZcvjmwva7lhSz+eNGa1HYuzctfWsRbbztHxcEkxj4yMdOhVNy0tDVFRUU6MCAhJvWVxp5DLTZ8yMfe8GLasg7X4QoJ8TfZnbj5z7c0xt82d1b8zlJWVOeVAwV5V89rcdjGXL45sL2u5IbofN4yZObWxj1fnrNy1dazFttPHYW9e/22/mpgYHQFvpcLkNG+lAoO7PVRjuv55L0XNtz8yWeX5MHMee8S222YlRkeY7U8hlyExOsLsfKbiNtfeVrXdf11lbruYyyNHtpel3BXr7zhmzspdS2NtqjZU5aWQ1UpOACKOzJOSknDgwAHcvHkTEydORGBgoOFGCnVZn6hmAIDk1PMoKCo1vFqGBPkiMToCfaKaIaJFMJJTz+NmUSkaVnt+bcpZFKs1AIB6fkpMGdHeqL+qHnskGEnTetoVX9XlVF2Wfrql9aoetzPUdv+u5MzctrRdzOWRvaovK8Cv8oOye2oNGgb54uFgGa7cEoymFas1JnP878RZuatv//9SzuCuWmuxNlRVdd91dk4AtfwLUP3bBU88zeIqdTX2uhC3s/KrrixXrLowdtUx5r/Ym19/29MsRERSwmJORCQBLOZERBLAYk5EJAEs5kREEsBiTkQkASzmREQSwGJORCQBLOZERBLAYk5EJAEs5kREEsBiTkQkASzmREQSwGJORCQBLOZERBLAYk5EJAEs5kREEsBiTkQkASzmREQSwGJORCQBLOZERBLAYk5EJAEs5kREEsBiTkQkASzmREQSwGJORCQBLOZERBLAYk5EJAEs5kREEsBiTkQkASzmREQSwGJORCQBLOZERBLAYk5EJAEs5kREEsBiTkQkASzmREQSwGJORCQBLOZERBLAYk5EJAEs5kREEsBiTkQkASzmREQSwGJORCQBLOZERBLAYk5EJAEs5kREEsBiTkQkASzmREQSIKqYZ2VlIT4+HoMGDUJ8fDwuX75cy2ERuQZzm6TCS0yjhQsXYsyYMYiNjcXOnTuxYMECJCcn273Q79OuITn1PG4WlSLATwkAuKfWoGGQLxKjI9AnqpnJedamnEWxWmP0vLdSDpVSUeN5j7L1ursjMAgJ8oW2QovC4nKj5+v5KY3GoJ6VfqqOoblxE9PG0WU4ytm57anM7T8AIJMBguCGoKrzoP1Er56fElNGtAcAJKeeR0FRKeRyGXQ6AQ38FHge15yek/ayemR+69YtnDt3DjExMQCAmJgYnDt3DoWFhXYt8Pu0a1ix7QwKikohAChWa1Cs1kAAUFBUihXbzuD7tGtG8/yWVYLlX5w2mYhlGp1nF3IPU1BUWqOQA6gxBr9llZjto/oYmho3MW0scXR+MZyd257q+7Rr+Pd/TO8/gIcUcg9VrNbgw62/YPkXp1FQVAoA0OkqN9gdtdbpOekIq8U8NzcXjRo1gkKhAAAoFAqEhoYiNzfXrgUmp55HmUZrdnqZRovk1PNGzx08cxcVWmacq5RptDh45q7Z6abGsPq4iWljiaPzi+Hs3PZUyannodVx/7GXAJitP87OSUeIOs3iqPT0dMPf+lc3SwqKSpGWlmZ4fEdtvvjXdRp1IbIOLUHrIe9CJle4OxyDO2qt0RhUZW4Mq46bmDaWODq/K1TNa09TdRuJ2eccdf3kOtRr8hgaNOtS68syJ2PPq2je91Wo/Bu6dLmekpNWi3l
"text/plain": [
"<Figure size 432x288 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ph_val = X_test[\"ph\"]\n",
"sulfate_val = X_test[\"Sulfate\"]\n",
"hard_val = X_test[\"Hardness\"]\n",
"carb_val = X_test[\"Organic_carbon\"]\n",
"turb_val = X_test[\"Turbidity\"]\n",
"ch_val = X_test[\"Chloramines\"]\n",
"\n",
"\n",
"figure, axes = plt.subplots(nrows=3, ncols=2)\n",
"\n",
"axes[0, 0].plot(ph_val, predictions, 'bo')\n",
"axes[0, 0].set_title(\"pH\")\n",
"\n",
"axes[0, 1].plot(sulfate_val, predictions, 'bo')\n",
"axes[0, 1].set_title(\"Sulfate\")\n",
"\n",
"axes[1, 0].plot(hard_val, predictions, 'bo')\n",
"axes[1, 0].set_title(\"Hardness\")\n",
"\n",
"axes[1, 1].plot(carb_val, predictions, 'bo')\n",
"axes[1, 1].set_title(\"Organic carbon\")\n",
"\n",
"axes[2, 0].plot(turb_val, predictions, 'bo')\n",
"axes[2, 0].set_title(\"Turbidity\")\n",
"\n",
"axes[2, 1].plot(ch_val, predictions, 'bo')\n",
"axes[2, 1].set_title(\"Chloramines\")\n",
"\n",
"plt.show()"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "naive_bayes.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}