689 lines
257 KiB
Plaintext
689 lines
257 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "TQqrOdkY6nsy"
|
||
},
|
||
"source": [
|
||
"# **Klasyfikacja za pomocą naiwnej metody bayesowskiej z rozkładem normalnym**"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "SSaJsYOhz8h8"
|
||
},
|
||
"source": [
|
||
""
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "AlcfRFCPSXIj"
|
||
},
|
||
"source": [
|
||
"# **Twierdzenie Bayesa**\n",
|
||
"\n",
"$$P(A \\mid B) = \\frac{P(B \\mid A)\\,P(A)}{P(B)}$$"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "rcpTnWjOh5dq"
|
||
},
|
||
"source": [
|
||
"P(A) -- oznacza prawdopodobieństwo a-priori wystąpienia klasy A (tj. prawdopodobieństwo, że dowolny przykład należy do klasy A)\n",
|
||
"\n",
|
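||
"P(B|A) -- oznacza wiarygodność, tj. prawdopodobieństwo warunkowe zaobserwowania przykładu B, jeśli należy on do klasy A\n",
"\n",
"P(A|B) -- oznacza prawdopodobieństwo a-posteriori, że przykład B należy do klasy A\n",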
|
||
"\n",
|
||
"P(B) -- oznacza prawdopodobieństwo a-priori wystąpienia przykładu B "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "Yabcm4Rei2ue"
|
||
},
|
||
"source": [
|
||
""
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "dsf6FnlgjiOL"
|
||
},
|
||
"source": [
|
||
"# Funkcja gęstości prawdopodobieństwa rozkładu normalnego \n",
|
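||
"\n",
"$$f(x) = \\frac{1}{\\sqrt{2\\pi\\sigma^2}}\\,\\exp\\left(-\\frac{(x-\\mu)^2}{2\\sigma^2}\\right)$$"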
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {
|
||
"id": "v0oeHebytjNp"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import scipy.stats as stats\n",
"import seaborn as sns\n",
"from sklearn.metrics import f1_score\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"sns.set(style=\"whitegrid\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {
|
||
"id": "fOYTA3VVtjNw"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"class NaiveBayesClassifier():\n",
|
||
" def calc_prior(self, features, target):\n",
|
||
" '''\n",
|
||
" Wyliczenie prawdopodobieństwa a priori\n",
|
||
" '''\n",
|
||
" self.prior = (features.groupby(target).apply(lambda x: len(x)) / self.rows).to_numpy()\n",
|
||
"\n",
|
||
" return self.prior\n",
|
||
" \n",
|
||
" def calc_statistics(self, features, target):\n",
|
||
" '''\n",
|
||
" Wyliczenie średnich i wariancji dla danych\n",
|
||
" ''' \n",
|
||
" self.mean = features.groupby(target).apply(np.mean).to_numpy()\n",
|
||
" self.var = features.groupby(target).apply(np.var).to_numpy()\n",
|
||
" \n",
|
||
" return self.mean, self.var\n",
|
||
" \n",
|
||
" def gaussian_density(self, class_idx, x): \n",
|
||
" '''\n",
|
||
" Wyliczenie prawdopodobieństwa z rozkładu normalnego \n",
|
||
" (1/√2pi*σ) * exp((-1/2)*((x-μ)^2)/(2*σ²))\n",
|
||
" μ -średnia\n",
|
||
" σ² - wariancja\n",
|
||
" σ - odchylenie standardowe\n",
|
||
" '''\n",
|
||
" mean = self.mean[class_idx]\n",
|
||
" var = self.var[class_idx]\n",
|
||
" numerator = np.exp((-1/2)*((x-mean)**2) / (2 * var))\n",
|
||
" denominator = np.sqrt(2 * np.pi * var)\n",
|
||
" prob = numerator / denominator\n",
|
||
" return prob\n",
|
||
" \n",
|
||
" def calc_posterior(self, x):\n",
|
||
" '''\n",
|
||
" Wyliczenie prawdopodobieństwa a posteriori i zwrócenie klasy, dla której prawdopodobieństwo jest najwyższe\n",
|
||
" '''\n",
|
||
" posteriors = []\n",
|
||
" posteriors_no_log = []\n",
|
||
"\n",
|
||
" # calculate posterior probability for each class\n",
|
||
" for i in range(self.count):\n",
|
||
" prior = np.log(self.prior[i]) # Do predykcji używane jest prawodopodobieństwo logarytmiczne\n",
|
||
" prior_no_log = self.prior[i] # Zwykłe prawdopodobieństwo liczymy, żeby zwrócić je z predykcjami\n",
|
||
"\n",
|
||
" conditional = np.sum(np.log(self.gaussian_density(i, x))) \n",
|
||
" conditional_no_log = np.prod(self.gaussian_density(i, x))\n",
|
||
"\n",
|
||
" posterior = prior + conditional\n",
|
||
" posterior_no_log = prior_no_log * conditional_no_log\n",
|
||
"\n",
|
||
" posteriors.append(posterior)\n",
|
||
" posteriors_no_log.append(posterior_no_log)\n",
|
||
"\n",
|
||
" # Zwracamy klasę o największym prawdopodobieństwie\n",
|
||
" return self.classes[np.argmax(posteriors)], np.max(posteriors_no_log)\n",
|
||
"\n",
|
||
" def fit(self, features, target):\n",
|
||
" '''\n",
|
||
" Główna metoda trenująca model\n",
|
||
" '''\n",
|
||
" self.classes = np.unique(target)\n",
|
||
" self.count = len(self.classes)\n",
|
||
" self.feature_nums = features.shape[1]\n",
|
||
" self.rows = features.shape[0]\n",
|
||
" \n",
|
||
" self.calc_statistics(features, target)\n",
|
||
" self.calc_prior(features, target)\n",
|
||
" \n",
|
||
" def predict(self, features):\n",
|
||
" '''\n",
|
||
" Predykcja wartości dla każdego wiersza\n",
|
||
" '''\n",
|
||
" preds = [self.calc_posterior(f) for f in features.to_numpy()]\n",
|
||
" return preds\n",
|
||
"\n",
|
||
" def accuracy(self, y_test, y_pred):\n",
|
||
" '''\n",
|
||
" Wyliczenie accuracy modelu\n",
|
||
" '''\n",
|
||
" accuracy = np.sum(y_test == y_pred) / len(y_test)\n",
|
||
" return accuracy\n",
|
||
"\n",
|
||
" def visualize(self, y_true, y_pred, target):\n",
|
||
" '''\n",
|
||
" Narysowanie wykresu porównującego rozkład klas prawdziwych i przewidzianych\n",
|
||
" '''\n",
|
||
" tr = pd.DataFrame(data=y_true, columns=[target])\n",
|
||
" pr = pd.DataFrame(data=y_pred, columns=[target])\n",
|
||
" \n",
|
||
" \n",
|
||
" fig, ax = plt.subplots(1, 2, sharex='col', sharey='row', figsize=(15,6))\n",
|
||
" \n",
|
||
" sns.countplot(x=target, data=tr, ax=ax[0], palette='viridis', alpha=0.7, hue=target, dodge=False)\n",
|
||
" sns.countplot(x=target, data=pr, ax=ax[1], palette='viridis', alpha=0.7, hue=target, dodge=False)\n",
|
||
" \n",
|
||
"\n",
|
||
" fig.suptitle('True vs Predicted Comparison', fontsize=20)\n",
|
||
"\n",
|
||
" ax[0].tick_params(labelsize=12)\n",
|
||
" ax[1].tick_params(labelsize=12)\n",
|
||
" ax[0].set_title(\"Prawdziwe wartości\", fontsize=18)\n",
|
||
" ax[1].set_title(\"Predyckje\", fontsize=18)\n",
|
||
" plt.show()\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/",
|
||
"height": 382
|
||
},
|
||
"id": "5-riUAGntjN2",
|
||
"outputId": "f87f047d-bc71-41ef-a43a-17b6f7cf84c3"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"(2948, 9) (2948,)\n",
|
||
"(328, 9) (328,)\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Preprocessing danych\n",
|
||
"\n",
|
||
"# Uzupełnienie pustych wartości w kolumnach\n",
|
||
"def fill_nan(df):\n",
|
||
" for index, column in enumerate(df.columns[:9]):\n",
|
||
" df[column] = df[column].fillna(df.groupby('Potability')[column].transform('mean'))\n",
|
||
" return df\n",
|
||
"\n",
|
||
"# Wczytywanie danych\n",
|
||
"df = pd.read_csv(\"water_potability.csv\")\n",
|
||
"\n",
|
||
"df = fill_nan(df)\n",
|
||
"\n",
|
||
"# Zrandomizowanie kolejności danych w datasecie\n",
|
||
"df = df.sample(frac=1, random_state=10).reset_index(drop=True)\n",
|
||
"\n",
|
||
"# Podział na atrybuty i przewidywane wartości\n",
|
||
"X, y = df.iloc[:, :-1], df.iloc[:, -1]\n",
|
||
"\n",
|
||
"# Normalizacja i skalowanie danych\n",
|
||
"from sklearn.preprocessing import StandardScaler\n",
|
||
"sc = StandardScaler()\n",
|
||
"X = sc.fit_transform(X.to_numpy())\n",
|
||
"X = pd.DataFrame(X, columns=df.columns.values.tolist()[:-1])\n",
|
||
"\n",
|
||
"# Podział na dane trenujące i testowe, z uwzględnieniem równego rozłożenia danych\n",
|
||
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, stratify=y, random_state=1)\n",
|
||
"\n",
|
||
"print(X_train.shape, y_train.shape)\n",
|
||
"print(X_test.shape, y_test.shape)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {
|
||
"id": "O82SGzK6tjN5"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>ph</th>\n",
|
||
" <th>Hardness</th>\n",
|
||
" <th>Solids</th>\n",
|
||
" <th>Chloramines</th>\n",
|
||
" <th>Sulfate</th>\n",
|
||
" <th>Conductivity</th>\n",
|
||
" <th>Organic_carbon</th>\n",
|
||
" <th>Trihalomethanes</th>\n",
|
||
" <th>Turbidity</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>1022</th>\n",
|
||
" <td>0.003078</td>\n",
|
||
" <td>0.688791</td>\n",
|
||
" <td>0.846257</td>\n",
|
||
" <td>1.428934</td>\n",
|
||
" <td>-0.858263</td>\n",
|
||
" <td>0.002792</td>\n",
|
||
" <td>0.913790</td>\n",
|
||
" <td>0.232417</td>\n",
|
||
" <td>2.319505</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3191</th>\n",
|
||
" <td>-0.587365</td>\n",
|
||
" <td>0.223203</td>\n",
|
||
" <td>-0.731867</td>\n",
|
||
" <td>0.397503</td>\n",
|
||
" <td>0.759893</td>\n",
|
||
" <td>0.330607</td>\n",
|
||
" <td>0.094379</td>\n",
|
||
" <td>0.282563</td>\n",
|
||
" <td>0.235024</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>13</th>\n",
|
||
" <td>0.003078</td>\n",
|
||
" <td>-0.241037</td>\n",
|
||
" <td>0.773051</td>\n",
|
||
" <td>0.580019</td>\n",
|
||
" <td>1.334369</td>\n",
|
||
" <td>-0.049130</td>\n",
|
||
" <td>-1.121422</td>\n",
|
||
" <td>-0.200432</td>\n",
|
||
" <td>-0.946356</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2068</th>\n",
|
||
" <td>-2.176058</td>\n",
|
||
" <td>1.443006</td>\n",
|
||
" <td>-1.626771</td>\n",
|
||
" <td>-4.164610</td>\n",
|
||
" <td>-0.033706</td>\n",
|
||
" <td>-1.050763</td>\n",
|
||
" <td>-0.391328</td>\n",
|
||
" <td>-0.398649</td>\n",
|
||
" <td>-0.298341</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1484</th>\n",
|
||
" <td>0.213047</td>\n",
|
||
" <td>0.403036</td>\n",
|
||
" <td>-0.464729</td>\n",
|
||
" <td>0.070417</td>\n",
|
||
" <td>0.021560</td>\n",
|
||
" <td>-0.952776</td>\n",
|
||
" <td>-0.213330</td>\n",
|
||
" <td>0.111419</td>\n",
|
||
" <td>-0.235893</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>691</th>\n",
|
||
" <td>0.003078</td>\n",
|
||
" <td>1.199106</td>\n",
|
||
" <td>-0.003483</td>\n",
|
||
" <td>-0.670308</td>\n",
|
||
" <td>-0.069513</td>\n",
|
||
" <td>0.185754</td>\n",
|
||
" <td>-0.466010</td>\n",
|
||
" <td>0.031975</td>\n",
|
||
" <td>0.676276</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1283</th>\n",
|
||
" <td>-2.034004</td>\n",
|
||
" <td>-1.508135</td>\n",
|
||
" <td>0.255310</td>\n",
|
||
" <td>0.083839</td>\n",
|
||
" <td>-1.413707</td>\n",
|
||
" <td>0.694074</td>\n",
|
||
" <td>-1.110579</td>\n",
|
||
" <td>0.232996</td>\n",
|
||
" <td>2.544703</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2818</th>\n",
|
||
" <td>-0.702987</td>\n",
|
||
" <td>-0.575677</td>\n",
|
||
" <td>0.755056</td>\n",
|
||
" <td>0.664695</td>\n",
|
||
" <td>0.021560</td>\n",
|
||
" <td>-0.489334</td>\n",
|
||
" <td>0.371852</td>\n",
|
||
" <td>-2.272990</td>\n",
|
||
" <td>-1.764684</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1330</th>\n",
|
||
" <td>1.525943</td>\n",
|
||
" <td>0.497074</td>\n",
|
||
" <td>-0.714355</td>\n",
|
||
" <td>-1.024237</td>\n",
|
||
" <td>-1.022037</td>\n",
|
||
" <td>-0.327074</td>\n",
|
||
" <td>-1.107341</td>\n",
|
||
" <td>0.517432</td>\n",
|
||
" <td>-1.230528</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1926</th>\n",
|
||
" <td>-0.043558</td>\n",
|
||
" <td>-0.882359</td>\n",
|
||
" <td>-0.456141</td>\n",
|
||
" <td>-0.770271</td>\n",
|
||
" <td>0.795189</td>\n",
|
||
" <td>0.560306</td>\n",
|
||
" <td>-1.086081</td>\n",
|
||
" <td>-1.356820</td>\n",
|
||
" <td>0.172521</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>2948 rows × 9 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" ph Hardness Solids Chloramines Sulfate Conductivity \\\n",
|
||
"1022 0.003078 0.688791 0.846257 1.428934 -0.858263 0.002792 \n",
|
||
"3191 -0.587365 0.223203 -0.731867 0.397503 0.759893 0.330607 \n",
|
||
"13 0.003078 -0.241037 0.773051 0.580019 1.334369 -0.049130 \n",
|
||
"2068 -2.176058 1.443006 -1.626771 -4.164610 -0.033706 -1.050763 \n",
|
||
"1484 0.213047 0.403036 -0.464729 0.070417 0.021560 -0.952776 \n",
|
||
"... ... ... ... ... ... ... \n",
|
||
"691 0.003078 1.199106 -0.003483 -0.670308 -0.069513 0.185754 \n",
|
||
"1283 -2.034004 -1.508135 0.255310 0.083839 -1.413707 0.694074 \n",
|
||
"2818 -0.702987 -0.575677 0.755056 0.664695 0.021560 -0.489334 \n",
|
||
"1330 1.525943 0.497074 -0.714355 -1.024237 -1.022037 -0.327074 \n",
|
||
"1926 -0.043558 -0.882359 -0.456141 -0.770271 0.795189 0.560306 \n",
|
||
"\n",
|
||
" Organic_carbon Trihalomethanes Turbidity \n",
|
||
"1022 0.913790 0.232417 2.319505 \n",
|
||
"3191 0.094379 0.282563 0.235024 \n",
|
||
"13 -1.121422 -0.200432 -0.946356 \n",
|
||
"2068 -0.391328 -0.398649 -0.298341 \n",
|
||
"1484 -0.213330 0.111419 -0.235893 \n",
|
||
"... ... ... ... \n",
|
||
"691 -0.466010 0.031975 0.676276 \n",
|
||
"1283 -1.110579 0.232996 2.544703 \n",
|
||
"2818 0.371852 -2.272990 -1.764684 \n",
|
||
"1330 -1.107341 0.517432 -1.230528 \n",
|
||
"1926 -1.086081 -1.356820 0.172521 \n",
|
||
"\n",
|
||
"[2948 rows x 9 columns]"
|
||
]
|
||
},
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"X_train"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {
|
||
"id": "a3jkTMFLtjN6"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Train the classifier on the training split\n",
|
||
"x = NaiveBayesClassifier()\n",
|
||
"x.fit(X_train, y_train)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {
|
||
"id": "CoC22aNgtjN9"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"0"
|
||
]
|
||
},
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Predykcja wartości dla danych testowych\n",
|
||
"predictions = x.predict(X_test)\n",
|
||
"\n",
|
||
"# Prawdopodobieństwa kolejnych predykcji\n",
|
||
"probabilities = [p[1] for p in predictions]\n",
|
||
"\n",
|
||
"# Przewidziana wartość\n",
|
||
"predictions = [p[0] for p in predictions]\n",
|
||
"predictions[0]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"metadata": {
|
||
"id": "JR06zodmtjN9"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"0.6280487804878049"
|
||
]
|
||
},
|
||
"execution_count": 7,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Compute the model's accuracy on the test set\n",
|
||
"x.accuracy(y_test, predictions)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"metadata": {
|
||
"id": "1jW0QPootjN_"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"0.14084507042253522"
|
||
]
|
||
},
|
||
"execution_count": 8,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.metrics import f1_score\n",
|
||
"\n",
|
||
"f1_score(y_test, predictions)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"metadata": {
|
||
"id": "vEVogTmAtjOA"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"0 0.609756\n",
|
||
"1 0.390244\n",
|
||
"Name: Potability, dtype: float64"
|
||
]
|
||
},
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Class proportions in the test set; the share of the most frequent class is\n",
|
||
"# the accuracy obtained by always predicting that class\n",
|
||
"y_test.value_counts(normalize=True)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {
|
||
"id": "jCVOdBZytjOB"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "\n",
|
||
"text/plain": [
|
||
"<Figure size 1080x432 with 2 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Compare the distribution of true and predicted classes on the test set\n",
|
||
"x.visualize(y_test, predictions, 'Potability')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"metadata": {
|
||
"id": "aw8Tefprhjnn"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"/usr/lib/python3/dist-packages/matplotlib/cbook/__init__.py:1377: FutureWarning: Support for multi-dimensional indexing (e.g. `obj[:, None]`) is deprecated and will be removed in a future version. Convert to a numpy array before indexing instead.\n",
|
||
" x[:, None]\n",
|
||
"/usr/lib/python3/dist-packages/matplotlib/axes/_base.py:237: FutureWarning: Support for multi-dimensional indexing (e.g. `obj[:, None]`) is deprecated and will be removed in a future version. Convert to a numpy array before indexing instead.\n",
|
||
" x = x[:, np.newaxis]\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "\n",
|
||
"text/plain": [
|
||
"<Figure size 432x288 with 6 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"ph_val = X_test[\"ph\"]\n",
|
||
"sulfate_val = X_test[\"Sulfate\"]\n",
|
||
"hard_val = X_test[\"Hardness\"]\n",
|
||
"carb_val = X_test[\"Organic_carbon\"]\n",
|
||
"turb_val = X_test[\"Turbidity\"]\n",
|
||
"ch_val = X_test[\"Chloramines\"]\n",
|
||
"\n",
|
||
"\n",
|
||
"figure, axes = plt.subplots(nrows=3, ncols=2)\n",
|
||
"\n",
|
||
"axes[0, 0].plot(ph_val, predictions, 'bo')\n",
|
||
"axes[0, 0].set_title(\"pH\")\n",
|
||
"\n",
|
||
"axes[0, 1].plot(sulfate_val, predictions, 'bo')\n",
|
||
"axes[0, 1].set_title(\"Sulfate\")\n",
|
||
"\n",
|
||
"axes[1, 0].plot(hard_val, predictions, 'bo')\n",
|
||
"axes[1, 0].set_title(\"Hardness\")\n",
|
||
"\n",
|
||
"axes[1, 1].plot(carb_val, predictions, 'bo')\n",
|
||
"axes[1, 1].set_title(\"Organic carbon\")\n",
|
||
"\n",
|
||
"axes[2, 0].plot(turb_val, predictions, 'bo')\n",
|
||
"axes[2, 0].set_title(\"Turbidity\")\n",
|
||
"\n",
|
||
"axes[2, 1].plot(ch_val, predictions, 'bo')\n",
|
||
"axes[2, 1].set_title(\"Chloramines\")\n",
|
||
"\n",
|
||
"plt.show()"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"colab": {
|
||
"collapsed_sections": [],
|
||
"name": "naive_bayes.ipynb",
|
||
"provenance": []
|
||
},
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.7.3"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 4
|
||
}
|