1726 lines
66 KiB
Plaintext
1726 lines
66 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 56,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import pandas as pd\n",
|
|
"import sklearn.model_selection"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 57,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import requests\n",
|
|
"\n",
|
|
"# Remote CSV to fetch and the local path it is written to.\n",
|
|
"url = \"https://huggingface.co/datasets/mstz/wine/raw/main/Wine_Quality_Data.csv\"\n",
|
|
"save_path = \"Wine_Quality_Data.csv\"\n",
|
|
"\n",
|
|
"# requests applies no timeout by default, so a stalled server would hang this cell forever.\n",
|
|
"response = requests.get(url, timeout=30)\n",
|
|
"response.raise_for_status()  # stop on an HTTP error instead of saving the error body\n",
|
|
"\n",
|
|
"with open(save_path, \"wb\") as f:\n",
|
|
"    f.write(response.content)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 58,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"wine_dataset = pd.read_csv(\"Wine_Quality_Data.csv\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 59,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>fixed_acidity</th>\n",
|
|
" <th>volatile_acidity</th>\n",
|
|
" <th>citric_acid</th>\n",
|
|
" <th>residual_sugar</th>\n",
|
|
" <th>chlorides</th>\n",
|
|
" <th>free_sulfur_dioxide</th>\n",
|
|
" <th>total_sulfur_dioxide</th>\n",
|
|
" <th>density</th>\n",
|
|
" <th>pH</th>\n",
|
|
" <th>sulphates</th>\n",
|
|
" <th>alcohol</th>\n",
|
|
" <th>quality</th>\n",
|
|
" <th>color</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>7.4</td>\n",
|
|
" <td>0.70</td>\n",
|
|
" <td>0.00</td>\n",
|
|
" <td>1.9</td>\n",
|
|
" <td>0.076</td>\n",
|
|
" <td>11.0</td>\n",
|
|
" <td>34.0</td>\n",
|
|
" <td>0.9978</td>\n",
|
|
" <td>3.51</td>\n",
|
|
" <td>0.56</td>\n",
|
|
" <td>9.4</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>red</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>7.8</td>\n",
|
|
" <td>0.88</td>\n",
|
|
" <td>0.00</td>\n",
|
|
" <td>2.6</td>\n",
|
|
" <td>0.098</td>\n",
|
|
" <td>25.0</td>\n",
|
|
" <td>67.0</td>\n",
|
|
" <td>0.9968</td>\n",
|
|
" <td>3.20</td>\n",
|
|
" <td>0.68</td>\n",
|
|
" <td>9.8</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>red</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>7.8</td>\n",
|
|
" <td>0.76</td>\n",
|
|
" <td>0.04</td>\n",
|
|
" <td>2.3</td>\n",
|
|
" <td>0.092</td>\n",
|
|
" <td>15.0</td>\n",
|
|
" <td>54.0</td>\n",
|
|
" <td>0.9970</td>\n",
|
|
" <td>3.26</td>\n",
|
|
" <td>0.65</td>\n",
|
|
" <td>9.8</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>red</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>11.2</td>\n",
|
|
" <td>0.28</td>\n",
|
|
" <td>0.56</td>\n",
|
|
" <td>1.9</td>\n",
|
|
" <td>0.075</td>\n",
|
|
" <td>17.0</td>\n",
|
|
" <td>60.0</td>\n",
|
|
" <td>0.9980</td>\n",
|
|
" <td>3.16</td>\n",
|
|
" <td>0.58</td>\n",
|
|
" <td>9.8</td>\n",
|
|
" <td>6</td>\n",
|
|
" <td>red</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>7.4</td>\n",
|
|
" <td>0.70</td>\n",
|
|
" <td>0.00</td>\n",
|
|
" <td>1.9</td>\n",
|
|
" <td>0.076</td>\n",
|
|
" <td>11.0</td>\n",
|
|
" <td>34.0</td>\n",
|
|
" <td>0.9978</td>\n",
|
|
" <td>3.51</td>\n",
|
|
" <td>0.56</td>\n",
|
|
" <td>9.4</td>\n",
|
|
" <td>5</td>\n",
|
|
" <td>red</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" fixed_acidity volatile_acidity citric_acid residual_sugar chlorides \\\n",
|
|
"0 7.4 0.70 0.00 1.9 0.076 \n",
|
|
"1 7.8 0.88 0.00 2.6 0.098 \n",
|
|
"2 7.8 0.76 0.04 2.3 0.092 \n",
|
|
"3 11.2 0.28 0.56 1.9 0.075 \n",
|
|
"4 7.4 0.70 0.00 1.9 0.076 \n",
|
|
"\n",
|
|
" free_sulfur_dioxide total_sulfur_dioxide density pH sulphates \\\n",
|
|
"0 11.0 34.0 0.9978 3.51 0.56 \n",
|
|
"1 25.0 67.0 0.9968 3.20 0.68 \n",
|
|
"2 15.0 54.0 0.9970 3.26 0.65 \n",
|
|
"3 17.0 60.0 0.9980 3.16 0.58 \n",
|
|
"4 11.0 34.0 0.9978 3.51 0.56 \n",
|
|
"\n",
|
|
" alcohol quality color \n",
|
|
"0 9.4 5 red \n",
|
|
"1 9.8 5 red \n",
|
|
"2 9.8 5 red \n",
|
|
"3 9.8 6 red \n",
|
|
"4 9.4 5 red "
|
|
]
|
|
},
|
|
"execution_count": 59,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"wine_dataset.head()  # preview the first rows of the data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 60,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Encode the wine color as an integer label: red -> 1, white -> 0.\n",
|
|
"# Values that are already numeric pass through unchanged, so re-running this cell is harmless.\n",
|
|
"# The explicit astype avoids relying on replace()'s deprecated silent object -> int downcast\n",
|
|
"# and raises if an unexpected, non-numeric label is present.\n",
|
|
"color_codes = {'red': 1, 'white': 0}\n",
|
|
"wine_dataset['color'] = wine_dataset['color'].map(lambda label: color_codes.get(label, label)).astype('int64')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 61,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>fixed_acidity</th>\n",
|
|
" <th>volatile_acidity</th>\n",
|
|
" <th>citric_acid</th>\n",
|
|
" <th>residual_sugar</th>\n",
|
|
" <th>chlorides</th>\n",
|
|
" <th>free_sulfur_dioxide</th>\n",
|
|
" <th>total_sulfur_dioxide</th>\n",
|
|
" <th>density</th>\n",
|
|
" <th>pH</th>\n",
|
|
" <th>sulphates</th>\n",
|
|
" <th>alcohol</th>\n",
|
|
" <th>quality</th>\n",
|
|
" <th>color</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>count</th>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>mean</th>\n",
|
|
" <td>7.215307</td>\n",
|
|
" <td>0.339666</td>\n",
|
|
" <td>0.318633</td>\n",
|
|
" <td>5.443235</td>\n",
|
|
" <td>0.056034</td>\n",
|
|
" <td>30.525319</td>\n",
|
|
" <td>115.744574</td>\n",
|
|
" <td>0.994697</td>\n",
|
|
" <td>3.218501</td>\n",
|
|
" <td>0.531268</td>\n",
|
|
" <td>10.491801</td>\n",
|
|
" <td>5.818378</td>\n",
|
|
" <td>0.246114</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>std</th>\n",
|
|
" <td>1.296434</td>\n",
|
|
" <td>0.164636</td>\n",
|
|
" <td>0.145318</td>\n",
|
|
" <td>4.757804</td>\n",
|
|
" <td>0.035034</td>\n",
|
|
" <td>17.749400</td>\n",
|
|
" <td>56.521855</td>\n",
|
|
" <td>0.002999</td>\n",
|
|
" <td>0.160787</td>\n",
|
|
" <td>0.148806</td>\n",
|
|
" <td>1.192712</td>\n",
|
|
" <td>0.873255</td>\n",
|
|
" <td>0.430779</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>min</th>\n",
|
|
" <td>3.800000</td>\n",
|
|
" <td>0.080000</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" <td>0.600000</td>\n",
|
|
" <td>0.009000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>6.000000</td>\n",
|
|
" <td>0.987110</td>\n",
|
|
" <td>2.720000</td>\n",
|
|
" <td>0.220000</td>\n",
|
|
" <td>8.000000</td>\n",
|
|
" <td>3.000000</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>25%</th>\n",
|
|
" <td>6.400000</td>\n",
|
|
" <td>0.230000</td>\n",
|
|
" <td>0.250000</td>\n",
|
|
" <td>1.800000</td>\n",
|
|
" <td>0.038000</td>\n",
|
|
" <td>17.000000</td>\n",
|
|
" <td>77.000000</td>\n",
|
|
" <td>0.992340</td>\n",
|
|
" <td>3.110000</td>\n",
|
|
" <td>0.430000</td>\n",
|
|
" <td>9.500000</td>\n",
|
|
" <td>5.000000</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>50%</th>\n",
|
|
" <td>7.000000</td>\n",
|
|
" <td>0.290000</td>\n",
|
|
" <td>0.310000</td>\n",
|
|
" <td>3.000000</td>\n",
|
|
" <td>0.047000</td>\n",
|
|
" <td>29.000000</td>\n",
|
|
" <td>118.000000</td>\n",
|
|
" <td>0.994890</td>\n",
|
|
" <td>3.210000</td>\n",
|
|
" <td>0.510000</td>\n",
|
|
" <td>10.300000</td>\n",
|
|
" <td>6.000000</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>75%</th>\n",
|
|
" <td>7.700000</td>\n",
|
|
" <td>0.400000</td>\n",
|
|
" <td>0.390000</td>\n",
|
|
" <td>8.100000</td>\n",
|
|
" <td>0.065000</td>\n",
|
|
" <td>41.000000</td>\n",
|
|
" <td>156.000000</td>\n",
|
|
" <td>0.996990</td>\n",
|
|
" <td>3.320000</td>\n",
|
|
" <td>0.600000</td>\n",
|
|
" <td>11.300000</td>\n",
|
|
" <td>6.000000</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>max</th>\n",
|
|
" <td>15.900000</td>\n",
|
|
" <td>1.580000</td>\n",
|
|
" <td>1.660000</td>\n",
|
|
" <td>65.800000</td>\n",
|
|
" <td>0.611000</td>\n",
|
|
" <td>289.000000</td>\n",
|
|
" <td>440.000000</td>\n",
|
|
" <td>1.038980</td>\n",
|
|
" <td>4.010000</td>\n",
|
|
" <td>2.000000</td>\n",
|
|
" <td>14.900000</td>\n",
|
|
" <td>9.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" fixed_acidity volatile_acidity citric_acid residual_sugar \\\n",
|
|
"count 6497.000000 6497.000000 6497.000000 6497.000000 \n",
|
|
"mean 7.215307 0.339666 0.318633 5.443235 \n",
|
|
"std 1.296434 0.164636 0.145318 4.757804 \n",
|
|
"min 3.800000 0.080000 0.000000 0.600000 \n",
|
|
"25% 6.400000 0.230000 0.250000 1.800000 \n",
|
|
"50% 7.000000 0.290000 0.310000 3.000000 \n",
|
|
"75% 7.700000 0.400000 0.390000 8.100000 \n",
|
|
"max 15.900000 1.580000 1.660000 65.800000 \n",
|
|
"\n",
|
|
" chlorides free_sulfur_dioxide total_sulfur_dioxide density \\\n",
|
|
"count 6497.000000 6497.000000 6497.000000 6497.000000 \n",
|
|
"mean 0.056034 30.525319 115.744574 0.994697 \n",
|
|
"std 0.035034 17.749400 56.521855 0.002999 \n",
|
|
"min 0.009000 1.000000 6.000000 0.987110 \n",
|
|
"25% 0.038000 17.000000 77.000000 0.992340 \n",
|
|
"50% 0.047000 29.000000 118.000000 0.994890 \n",
|
|
"75% 0.065000 41.000000 156.000000 0.996990 \n",
|
|
"max 0.611000 289.000000 440.000000 1.038980 \n",
|
|
"\n",
|
|
" pH sulphates alcohol quality color \n",
|
|
"count 6497.000000 6497.000000 6497.000000 6497.000000 6497.000000 \n",
|
|
"mean 3.218501 0.531268 10.491801 5.818378 0.246114 \n",
|
|
"std 0.160787 0.148806 1.192712 0.873255 0.430779 \n",
|
|
"min 2.720000 0.220000 8.000000 3.000000 0.000000 \n",
|
|
"25% 3.110000 0.430000 9.500000 5.000000 0.000000 \n",
|
|
"50% 3.210000 0.510000 10.300000 6.000000 0.000000 \n",
|
|
"75% 3.320000 0.600000 11.300000 6.000000 0.000000 \n",
|
|
"max 4.010000 2.000000 14.900000 9.000000 1.000000 "
|
|
]
|
|
},
|
|
"execution_count": 61,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"wine_dataset.describe(include='all')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 62,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"<Axes: >"
|
|
]
|
|
},
|
|
"execution_count": 62,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Class balance of the color label (1 = red, 0 = white after encoding).\n",
|
|
"# The trailing semicolon keeps the Axes repr out of the cell output.\n",
|
|
"wine_dataset[\"color\"].value_counts().plot(\n",
|
|
"    kind=\"bar\", title=\"Wine samples per color (1 = red, 0 = white)\", xlabel=\"color\", ylabel=\"count\"\n",
|
|
");"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 63,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"1.2964337577998153"
|
|
]
|
|
},
|
|
"execution_count": 63,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"wine_dataset[\"fixed_acidity\"].std()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 64,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(array([], dtype=int64), array([], dtype=int64))"
|
|
]
|
|
},
|
|
"execution_count": 64,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"# Row/column positions of missing values; two empty arrays mean the data has no gaps.\n",
|
|
"np.where(wine_dataset.isnull())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 65,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Max-abs scaling: divide every column by its largest absolute value so all values fall in [-1, 1].\n",
|
|
"# One vectorised, column-aligned division replaces the per-column Python loop.\n",
|
|
"# NOTE(review): the maxima come from the full dataset, before the train/test split below, and the\n",
|
|
"# label columns (quality, color) are rescaled too -- confirm both are intended.\n",
|
|
"wine_dataset = wine_dataset / wine_dataset.abs().max()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 66,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>fixed_acidity</th>\n",
|
|
" <th>volatile_acidity</th>\n",
|
|
" <th>citric_acid</th>\n",
|
|
" <th>residual_sugar</th>\n",
|
|
" <th>chlorides</th>\n",
|
|
" <th>free_sulfur_dioxide</th>\n",
|
|
" <th>total_sulfur_dioxide</th>\n",
|
|
" <th>density</th>\n",
|
|
" <th>pH</th>\n",
|
|
" <th>sulphates</th>\n",
|
|
" <th>alcohol</th>\n",
|
|
" <th>quality</th>\n",
|
|
" <th>color</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>count</th>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" <td>6497.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>mean</th>\n",
|
|
" <td>0.453793</td>\n",
|
|
" <td>0.214978</td>\n",
|
|
" <td>0.191948</td>\n",
|
|
" <td>0.082724</td>\n",
|
|
" <td>0.091708</td>\n",
|
|
" <td>0.105624</td>\n",
|
|
" <td>0.263056</td>\n",
|
|
" <td>0.957378</td>\n",
|
|
" <td>0.802619</td>\n",
|
|
" <td>0.265634</td>\n",
|
|
" <td>0.704148</td>\n",
|
|
" <td>0.646486</td>\n",
|
|
" <td>0.246114</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>std</th>\n",
|
|
" <td>0.081537</td>\n",
|
|
" <td>0.104200</td>\n",
|
|
" <td>0.087541</td>\n",
|
|
" <td>0.072307</td>\n",
|
|
" <td>0.057338</td>\n",
|
|
" <td>0.061417</td>\n",
|
|
" <td>0.128459</td>\n",
|
|
" <td>0.002886</td>\n",
|
|
" <td>0.040097</td>\n",
|
|
" <td>0.074403</td>\n",
|
|
" <td>0.080048</td>\n",
|
|
" <td>0.097028</td>\n",
|
|
" <td>0.430779</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>min</th>\n",
|
|
" <td>0.238994</td>\n",
|
|
" <td>0.050633</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" <td>0.009119</td>\n",
|
|
" <td>0.014730</td>\n",
|
|
" <td>0.003460</td>\n",
|
|
" <td>0.013636</td>\n",
|
|
" <td>0.950076</td>\n",
|
|
" <td>0.678304</td>\n",
|
|
" <td>0.110000</td>\n",
|
|
" <td>0.536913</td>\n",
|
|
" <td>0.333333</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>25%</th>\n",
|
|
" <td>0.402516</td>\n",
|
|
" <td>0.145570</td>\n",
|
|
" <td>0.150602</td>\n",
|
|
" <td>0.027356</td>\n",
|
|
" <td>0.062193</td>\n",
|
|
" <td>0.058824</td>\n",
|
|
" <td>0.175000</td>\n",
|
|
" <td>0.955110</td>\n",
|
|
" <td>0.775561</td>\n",
|
|
" <td>0.215000</td>\n",
|
|
" <td>0.637584</td>\n",
|
|
" <td>0.555556</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>50%</th>\n",
|
|
" <td>0.440252</td>\n",
|
|
" <td>0.183544</td>\n",
|
|
" <td>0.186747</td>\n",
|
|
" <td>0.045593</td>\n",
|
|
" <td>0.076923</td>\n",
|
|
" <td>0.100346</td>\n",
|
|
" <td>0.268182</td>\n",
|
|
" <td>0.957564</td>\n",
|
|
" <td>0.800499</td>\n",
|
|
" <td>0.255000</td>\n",
|
|
" <td>0.691275</td>\n",
|
|
" <td>0.666667</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>75%</th>\n",
|
|
" <td>0.484277</td>\n",
|
|
" <td>0.253165</td>\n",
|
|
" <td>0.234940</td>\n",
|
|
" <td>0.123100</td>\n",
|
|
" <td>0.106383</td>\n",
|
|
" <td>0.141869</td>\n",
|
|
" <td>0.354545</td>\n",
|
|
" <td>0.959585</td>\n",
|
|
" <td>0.827930</td>\n",
|
|
" <td>0.300000</td>\n",
|
|
" <td>0.758389</td>\n",
|
|
" <td>0.666667</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>max</th>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" fixed_acidity volatile_acidity citric_acid residual_sugar \\\n",
|
|
"count 6497.000000 6497.000000 6497.000000 6497.000000 \n",
|
|
"mean 0.453793 0.214978 0.191948 0.082724 \n",
|
|
"std 0.081537 0.104200 0.087541 0.072307 \n",
|
|
"min 0.238994 0.050633 0.000000 0.009119 \n",
|
|
"25% 0.402516 0.145570 0.150602 0.027356 \n",
|
|
"50% 0.440252 0.183544 0.186747 0.045593 \n",
|
|
"75% 0.484277 0.253165 0.234940 0.123100 \n",
|
|
"max 1.000000 1.000000 1.000000 1.000000 \n",
|
|
"\n",
|
|
" chlorides free_sulfur_dioxide total_sulfur_dioxide density \\\n",
|
|
"count 6497.000000 6497.000000 6497.000000 6497.000000 \n",
|
|
"mean 0.091708 0.105624 0.263056 0.957378 \n",
|
|
"std 0.057338 0.061417 0.128459 0.002886 \n",
|
|
"min 0.014730 0.003460 0.013636 0.950076 \n",
|
|
"25% 0.062193 0.058824 0.175000 0.955110 \n",
|
|
"50% 0.076923 0.100346 0.268182 0.957564 \n",
|
|
"75% 0.106383 0.141869 0.354545 0.959585 \n",
|
|
"max 1.000000 1.000000 1.000000 1.000000 \n",
|
|
"\n",
|
|
" pH sulphates alcohol quality color \n",
|
|
"count 6497.000000 6497.000000 6497.000000 6497.000000 6497.000000 \n",
|
|
"mean 0.802619 0.265634 0.704148 0.646486 0.246114 \n",
|
|
"std 0.040097 0.074403 0.080048 0.097028 0.430779 \n",
|
|
"min 0.678304 0.110000 0.536913 0.333333 0.000000 \n",
|
|
"25% 0.775561 0.215000 0.637584 0.555556 0.000000 \n",
|
|
"50% 0.800499 0.255000 0.691275 0.666667 0.000000 \n",
|
|
"75% 0.827930 0.300000 0.758389 0.666667 0.000000 \n",
|
|
"max 1.000000 1.000000 1.000000 1.000000 1.000000 "
|
|
]
|
|
},
|
|
"execution_count": 66,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"wine_dataset.describe(include='all')  # check the values after normalization"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 67,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"652 1.000000\n",
|
|
"442 0.981132\n",
|
|
"557 0.981132\n",
|
|
"554 0.974843\n",
|
|
"555 0.974843\n",
|
|
"243 0.943396\n",
|
|
"244 0.943396\n",
|
|
"544 0.899371\n",
|
|
"3125 0.893082\n",
|
|
"374 0.880503\n",
|
|
"Name: fixed_acidity, dtype: float64"
|
|
]
|
|
},
|
|
"execution_count": 67,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"wine_dataset[\"fixed_acidity\"].nlargest(10)  # check whether the largest values look plausible"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 68,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.0 4408\n",
|
|
"1.0 1439\n",
|
|
"Name: color, dtype: int64"
|
|
]
|
|
},
|
|
"execution_count": 68,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"\n",
|
|
"# Hold out 10% of the rows. Stratifying on color keeps the red/white ratio the same in both parts,\n",
|
|
"# and the fixed random_state makes the split reproducible.\n",
|
|
"wine_train, wine_test = train_test_split(\n",
|
|
"    wine_dataset,\n",
|
|
"    test_size=0.1,\n",
|
|
"    random_state=1,\n",
|
|
"    stratify=wine_dataset[\"color\"],\n",
|
|
")\n",
|
|
"wine_train[\"color\"].value_counts()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 69,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.0 490\n",
|
|
"1.0 160\n",
|
|
"Name: color, dtype: int64"
|
|
]
|
|
},
|
|
"execution_count": 69,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"wine_test[\"color\"].value_counts()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 70,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Rebuild the 10% hold-out from the rows that are not in wine_train. Splitting wine_test in place\n",
|
|
"# would halve it again on every re-run of this cell.\n",
|
|
"wine_holdout = wine_dataset.drop(index=wine_train.index)\n",
|
|
"\n",
|
|
"# Split the hold-out evenly into test and validation, keeping the color ratio in both halves.\n",
|
|
"wine_test, wine_val = sklearn.model_selection.train_test_split(\n",
|
|
"    wine_holdout,\n",
|
|
"    test_size=0.5,\n",
|
|
"    random_state=1,\n",
|
|
"    stratify=wine_holdout[\"color\"],\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 71,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.0 245\n",
|
|
"1.0 80\n",
|
|
"Name: color, dtype: int64"
|
|
]
|
|
},
|
|
"execution_count": 71,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"wine_test[\"color\"].value_counts()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 72,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.0 245\n",
|
|
"1.0 80\n",
|
|
"Name: color, dtype: int64"
|
|
]
|
|
},
|
|
"execution_count": 72,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"wine_val[\"color\"].value_counts()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 73,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import seaborn as sns\n",
|
|
"sns.set_theme()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 74,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"13"
|
|
]
|
|
},
|
|
"execution_count": 74,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"len(wine_dataset.columns)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 75,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"#sns.pairplot(data=wine_dataset, hue=\"color\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 76,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>fixed_acidity</th>\n",
|
|
" <th>volatile_acidity</th>\n",
|
|
" <th>citric_acid</th>\n",
|
|
" <th>residual_sugar</th>\n",
|
|
" <th>chlorides</th>\n",
|
|
" <th>free_sulfur_dioxide</th>\n",
|
|
" <th>total_sulfur_dioxide</th>\n",
|
|
" <th>density</th>\n",
|
|
" <th>pH</th>\n",
|
|
" <th>sulphates</th>\n",
|
|
" <th>alcohol</th>\n",
|
|
" <th>quality</th>\n",
|
|
" <th>color</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>count</th>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>mean</th>\n",
|
|
" <td>0.460126</td>\n",
|
|
" <td>0.209883</td>\n",
|
|
" <td>0.197294</td>\n",
|
|
" <td>0.083839</td>\n",
|
|
" <td>0.096352</td>\n",
|
|
" <td>0.105307</td>\n",
|
|
" <td>0.272028</td>\n",
|
|
" <td>0.957685</td>\n",
|
|
" <td>0.799770</td>\n",
|
|
" <td>0.266477</td>\n",
|
|
" <td>0.691389</td>\n",
|
|
" <td>0.636239</td>\n",
|
|
" <td>0.246154</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>std</th>\n",
|
|
" <td>0.087321</td>\n",
|
|
" <td>0.100971</td>\n",
|
|
" <td>0.086532</td>\n",
|
|
" <td>0.072172</td>\n",
|
|
" <td>0.066017</td>\n",
|
|
" <td>0.061895</td>\n",
|
|
" <td>0.131981</td>\n",
|
|
" <td>0.002780</td>\n",
|
|
" <td>0.038640</td>\n",
|
|
" <td>0.082243</td>\n",
|
|
" <td>0.073293</td>\n",
|
|
" <td>0.088732</td>\n",
|
|
" <td>0.431433</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>min</th>\n",
|
|
" <td>0.308176</td>\n",
|
|
" <td>0.066456</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" <td>0.010638</td>\n",
|
|
" <td>0.026187</td>\n",
|
|
" <td>0.003460</td>\n",
|
|
" <td>0.020455</td>\n",
|
|
" <td>0.952030</td>\n",
|
|
" <td>0.698254</td>\n",
|
|
" <td>0.115000</td>\n",
|
|
" <td>0.577181</td>\n",
|
|
" <td>0.333333</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>25%</th>\n",
|
|
" <td>0.408805</td>\n",
|
|
" <td>0.139241</td>\n",
|
|
" <td>0.156627</td>\n",
|
|
" <td>0.027356</td>\n",
|
|
" <td>0.062193</td>\n",
|
|
" <td>0.058824</td>\n",
|
|
" <td>0.188636</td>\n",
|
|
" <td>0.955322</td>\n",
|
|
" <td>0.773067</td>\n",
|
|
" <td>0.215000</td>\n",
|
|
" <td>0.630872</td>\n",
|
|
" <td>0.555556</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>50%</th>\n",
|
|
" <td>0.440252</td>\n",
|
|
" <td>0.189873</td>\n",
|
|
" <td>0.186747</td>\n",
|
|
" <td>0.048632</td>\n",
|
|
" <td>0.078560</td>\n",
|
|
" <td>0.100346</td>\n",
|
|
" <td>0.275000</td>\n",
|
|
" <td>0.957978</td>\n",
|
|
" <td>0.795511</td>\n",
|
|
" <td>0.250000</td>\n",
|
|
" <td>0.671141</td>\n",
|
|
" <td>0.666667</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>75%</th>\n",
|
|
" <td>0.484277</td>\n",
|
|
" <td>0.240506</td>\n",
|
|
" <td>0.246988</td>\n",
|
|
" <td>0.121581</td>\n",
|
|
" <td>0.116203</td>\n",
|
|
" <td>0.145329</td>\n",
|
|
" <td>0.356818</td>\n",
|
|
" <td>0.959787</td>\n",
|
|
" <td>0.822943</td>\n",
|
|
" <td>0.305000</td>\n",
|
|
" <td>0.738255</td>\n",
|
|
" <td>0.666667</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>max</th>\n",
|
|
" <td>0.943396</td>\n",
|
|
" <td>0.715190</td>\n",
|
|
" <td>0.469880</td>\n",
|
|
" <td>0.303191</td>\n",
|
|
" <td>0.764321</td>\n",
|
|
" <td>0.479239</td>\n",
|
|
" <td>0.781818</td>\n",
|
|
" <td>0.966034</td>\n",
|
|
" <td>0.895262</td>\n",
|
|
" <td>0.975000</td>\n",
|
|
" <td>0.906040</td>\n",
|
|
" <td>0.888889</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" fixed_acidity volatile_acidity citric_acid residual_sugar \\\n",
|
|
"count 325.000000 325.000000 325.000000 325.000000 \n",
|
|
"mean 0.460126 0.209883 0.197294 0.083839 \n",
|
|
"std 0.087321 0.100971 0.086532 0.072172 \n",
|
|
"min 0.308176 0.066456 0.000000 0.010638 \n",
|
|
"25% 0.408805 0.139241 0.156627 0.027356 \n",
|
|
"50% 0.440252 0.189873 0.186747 0.048632 \n",
|
|
"75% 0.484277 0.240506 0.246988 0.121581 \n",
|
|
"max 0.943396 0.715190 0.469880 0.303191 \n",
|
|
"\n",
|
|
" chlorides free_sulfur_dioxide total_sulfur_dioxide density \\\n",
|
|
"count 325.000000 325.000000 325.000000 325.000000 \n",
|
|
"mean 0.096352 0.105307 0.272028 0.957685 \n",
|
|
"std 0.066017 0.061895 0.131981 0.002780 \n",
|
|
"min 0.026187 0.003460 0.020455 0.952030 \n",
|
|
"25% 0.062193 0.058824 0.188636 0.955322 \n",
|
|
"50% 0.078560 0.100346 0.275000 0.957978 \n",
|
|
"75% 0.116203 0.145329 0.356818 0.959787 \n",
|
|
"max 0.764321 0.479239 0.781818 0.966034 \n",
|
|
"\n",
|
|
" pH sulphates alcohol quality color \n",
|
|
"count 325.000000 325.000000 325.000000 325.000000 325.000000 \n",
|
|
"mean 0.799770 0.266477 0.691389 0.636239 0.246154 \n",
|
|
"std 0.038640 0.082243 0.073293 0.088732 0.431433 \n",
|
|
"min 0.698254 0.115000 0.577181 0.333333 0.000000 \n",
|
|
"25% 0.773067 0.215000 0.630872 0.555556 0.000000 \n",
|
|
"50% 0.795511 0.250000 0.671141 0.666667 0.000000 \n",
|
|
"75% 0.822943 0.305000 0.738255 0.666667 0.000000 \n",
|
|
"max 0.895262 0.975000 0.906040 0.888889 1.000000 "
|
|
]
|
|
},
|
|
"execution_count": 76,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"wine_test.describe()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 77,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>fixed_acidity</th>\n",
|
|
" <th>volatile_acidity</th>\n",
|
|
" <th>citric_acid</th>\n",
|
|
" <th>residual_sugar</th>\n",
|
|
" <th>chlorides</th>\n",
|
|
" <th>free_sulfur_dioxide</th>\n",
|
|
" <th>total_sulfur_dioxide</th>\n",
|
|
" <th>density</th>\n",
|
|
" <th>pH</th>\n",
|
|
" <th>sulphates</th>\n",
|
|
" <th>alcohol</th>\n",
|
|
" <th>quality</th>\n",
|
|
" <th>color</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>count</th>\n",
|
|
" <td>5847.000000</td>\n",
|
|
" <td>5847.000000</td>\n",
|
|
" <td>5847.000000</td>\n",
|
|
" <td>5847.000000</td>\n",
|
|
" <td>5847.000000</td>\n",
|
|
" <td>5847.000000</td>\n",
|
|
" <td>5847.000000</td>\n",
|
|
" <td>5847.000000</td>\n",
|
|
" <td>5847.000000</td>\n",
|
|
" <td>5847.000000</td>\n",
|
|
" <td>5847.000000</td>\n",
|
|
" <td>5847.000000</td>\n",
|
|
" <td>5847.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>mean</th>\n",
|
|
" <td>0.453724</td>\n",
|
|
" <td>0.215128</td>\n",
|
|
" <td>0.192091</td>\n",
|
|
" <td>0.082877</td>\n",
|
|
" <td>0.091656</td>\n",
|
|
" <td>0.105899</td>\n",
|
|
" <td>0.262834</td>\n",
|
|
" <td>0.957374</td>\n",
|
|
" <td>0.802637</td>\n",
|
|
" <td>0.265601</td>\n",
|
|
" <td>0.704572</td>\n",
|
|
" <td>0.646846</td>\n",
|
|
" <td>0.246109</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>std</th>\n",
|
|
" <td>0.081597</td>\n",
|
|
" <td>0.104319</td>\n",
|
|
" <td>0.087166</td>\n",
|
|
" <td>0.072487</td>\n",
|
|
" <td>0.057502</td>\n",
|
|
" <td>0.061908</td>\n",
|
|
" <td>0.128388</td>\n",
|
|
" <td>0.002899</td>\n",
|
|
" <td>0.040030</td>\n",
|
|
" <td>0.074400</td>\n",
|
|
" <td>0.080399</td>\n",
|
|
" <td>0.097212</td>\n",
|
|
" <td>0.430780</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>min</th>\n",
|
|
" <td>0.238994</td>\n",
|
|
" <td>0.050633</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" <td>0.009119</td>\n",
|
|
" <td>0.014730</td>\n",
|
|
" <td>0.003460</td>\n",
|
|
" <td>0.013636</td>\n",
|
|
" <td>0.950076</td>\n",
|
|
" <td>0.678304</td>\n",
|
|
" <td>0.110000</td>\n",
|
|
" <td>0.536913</td>\n",
|
|
" <td>0.333333</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>25%</th>\n",
|
|
" <td>0.402516</td>\n",
|
|
" <td>0.145570</td>\n",
|
|
" <td>0.150602</td>\n",
|
|
" <td>0.027356</td>\n",
|
|
" <td>0.062193</td>\n",
|
|
" <td>0.058824</td>\n",
|
|
" <td>0.175000</td>\n",
|
|
" <td>0.955110</td>\n",
|
|
" <td>0.775561</td>\n",
|
|
" <td>0.215000</td>\n",
|
|
" <td>0.637584</td>\n",
|
|
" <td>0.555556</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>50%</th>\n",
|
|
" <td>0.440252</td>\n",
|
|
" <td>0.183544</td>\n",
|
|
" <td>0.186747</td>\n",
|
|
" <td>0.045593</td>\n",
|
|
" <td>0.076923</td>\n",
|
|
" <td>0.100346</td>\n",
|
|
" <td>0.268182</td>\n",
|
|
" <td>0.957555</td>\n",
|
|
" <td>0.800499</td>\n",
|
|
" <td>0.255000</td>\n",
|
|
" <td>0.691275</td>\n",
|
|
" <td>0.666667</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>75%</th>\n",
|
|
" <td>0.484277</td>\n",
|
|
" <td>0.259494</td>\n",
|
|
" <td>0.234940</td>\n",
|
|
" <td>0.123100</td>\n",
|
|
" <td>0.106383</td>\n",
|
|
" <td>0.141869</td>\n",
|
|
" <td>0.354545</td>\n",
|
|
" <td>0.959585</td>\n",
|
|
" <td>0.827930</td>\n",
|
|
" <td>0.300000</td>\n",
|
|
" <td>0.758389</td>\n",
|
|
" <td>0.666667</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>max</th>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" fixed_acidity volatile_acidity citric_acid residual_sugar \\\n",
|
|
"count 5847.000000 5847.000000 5847.000000 5847.000000 \n",
|
|
"mean 0.453724 0.215128 0.192091 0.082877 \n",
|
|
"std 0.081597 0.104319 0.087166 0.072487 \n",
|
|
"min 0.238994 0.050633 0.000000 0.009119 \n",
|
|
"25% 0.402516 0.145570 0.150602 0.027356 \n",
|
|
"50% 0.440252 0.183544 0.186747 0.045593 \n",
|
|
"75% 0.484277 0.259494 0.234940 0.123100 \n",
|
|
"max 1.000000 1.000000 1.000000 1.000000 \n",
|
|
"\n",
|
|
" chlorides free_sulfur_dioxide total_sulfur_dioxide density \\\n",
|
|
"count 5847.000000 5847.000000 5847.000000 5847.000000 \n",
|
|
"mean 0.091656 0.105899 0.262834 0.957374 \n",
|
|
"std 0.057502 0.061908 0.128388 0.002899 \n",
|
|
"min 0.014730 0.003460 0.013636 0.950076 \n",
|
|
"25% 0.062193 0.058824 0.175000 0.955110 \n",
|
|
"50% 0.076923 0.100346 0.268182 0.957555 \n",
|
|
"75% 0.106383 0.141869 0.354545 0.959585 \n",
|
|
"max 1.000000 1.000000 1.000000 1.000000 \n",
|
|
"\n",
|
|
" pH sulphates alcohol quality color \n",
|
|
"count 5847.000000 5847.000000 5847.000000 5847.000000 5847.000000 \n",
|
|
"mean 0.802637 0.265601 0.704572 0.646846 0.246109 \n",
|
|
"std 0.040030 0.074400 0.080399 0.097212 0.430780 \n",
|
|
"min 0.678304 0.110000 0.536913 0.333333 0.000000 \n",
|
|
"25% 0.775561 0.215000 0.637584 0.555556 0.000000 \n",
|
|
"50% 0.800499 0.255000 0.691275 0.666667 0.000000 \n",
|
|
"75% 0.827930 0.300000 0.758389 0.666667 0.000000 \n",
|
|
"max 1.000000 1.000000 1.000000 1.000000 1.000000 "
|
|
]
|
|
},
|
|
"execution_count": 77,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"wine_train.describe()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 78,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>fixed_acidity</th>\n",
|
|
" <th>volatile_acidity</th>\n",
|
|
" <th>citric_acid</th>\n",
|
|
" <th>residual_sugar</th>\n",
|
|
" <th>chlorides</th>\n",
|
|
" <th>free_sulfur_dioxide</th>\n",
|
|
" <th>total_sulfur_dioxide</th>\n",
|
|
" <th>density</th>\n",
|
|
" <th>pH</th>\n",
|
|
" <th>sulphates</th>\n",
|
|
" <th>alcohol</th>\n",
|
|
" <th>quality</th>\n",
|
|
" <th>color</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>count</th>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" <td>325.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>mean</th>\n",
|
|
" <td>0.448708</td>\n",
|
|
" <td>0.217381</td>\n",
|
|
" <td>0.184022</td>\n",
|
|
" <td>0.078864</td>\n",
|
|
" <td>0.088017</td>\n",
|
|
" <td>0.100985</td>\n",
|
|
" <td>0.258073</td>\n",
|
|
" <td>0.957147</td>\n",
|
|
" <td>0.805141</td>\n",
|
|
" <td>0.265385</td>\n",
|
|
" <td>0.709269</td>\n",
|
|
" <td>0.650256</td>\n",
|
|
" <td>0.246154</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>std</th>\n",
|
|
" <td>0.073960</td>\n",
|
|
" <td>0.105388</td>\n",
|
|
" <td>0.094736</td>\n",
|
|
" <td>0.069232</td>\n",
|
|
" <td>0.043159</td>\n",
|
|
" <td>0.051174</td>\n",
|
|
" <td>0.126120</td>\n",
|
|
" <td>0.002746</td>\n",
|
|
" <td>0.042584</td>\n",
|
|
" <td>0.065946</td>\n",
|
|
" <td>0.079198</td>\n",
|
|
" <td>0.101225</td>\n",
|
|
" <td>0.431433</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>min</th>\n",
|
|
" <td>0.301887</td>\n",
|
|
" <td>0.075949</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" <td>0.012158</td>\n",
|
|
" <td>0.026187</td>\n",
|
|
" <td>0.006920</td>\n",
|
|
" <td>0.018182</td>\n",
|
|
" <td>0.950731</td>\n",
|
|
" <td>0.683292</td>\n",
|
|
" <td>0.150000</td>\n",
|
|
" <td>0.570470</td>\n",
|
|
" <td>0.333333</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>25%</th>\n",
|
|
" <td>0.402516</td>\n",
|
|
" <td>0.145570</td>\n",
|
|
" <td>0.138554</td>\n",
|
|
" <td>0.028875</td>\n",
|
|
" <td>0.062193</td>\n",
|
|
" <td>0.058824</td>\n",
|
|
" <td>0.179545</td>\n",
|
|
" <td>0.954879</td>\n",
|
|
" <td>0.775561</td>\n",
|
|
" <td>0.215000</td>\n",
|
|
" <td>0.637584</td>\n",
|
|
" <td>0.555556</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>50%</th>\n",
|
|
" <td>0.433962</td>\n",
|
|
" <td>0.177215</td>\n",
|
|
" <td>0.186747</td>\n",
|
|
" <td>0.042553</td>\n",
|
|
" <td>0.076923</td>\n",
|
|
" <td>0.100346</td>\n",
|
|
" <td>0.256818</td>\n",
|
|
" <td>0.957189</td>\n",
|
|
" <td>0.805486</td>\n",
|
|
" <td>0.260000</td>\n",
|
|
" <td>0.697987</td>\n",
|
|
" <td>0.666667</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>75%</th>\n",
|
|
" <td>0.484277</td>\n",
|
|
" <td>0.253165</td>\n",
|
|
" <td>0.234940</td>\n",
|
|
" <td>0.117021</td>\n",
|
|
" <td>0.101473</td>\n",
|
|
" <td>0.138408</td>\n",
|
|
" <td>0.356818</td>\n",
|
|
" <td>0.959306</td>\n",
|
|
" <td>0.830424</td>\n",
|
|
" <td>0.305000</td>\n",
|
|
" <td>0.758389</td>\n",
|
|
" <td>0.666667</td>\n",
|
|
" <td>0.000000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>max</th>\n",
|
|
" <td>0.798742</td>\n",
|
|
" <td>0.696203</td>\n",
|
|
" <td>0.602410</td>\n",
|
|
" <td>0.303191</td>\n",
|
|
" <td>0.436989</td>\n",
|
|
" <td>0.280277</td>\n",
|
|
" <td>0.575000</td>\n",
|
|
" <td>0.962935</td>\n",
|
|
" <td>0.935162</td>\n",
|
|
" <td>0.490000</td>\n",
|
|
" <td>0.953020</td>\n",
|
|
" <td>0.888889</td>\n",
|
|
" <td>1.000000</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" fixed_acidity volatile_acidity citric_acid residual_sugar \\\n",
|
|
"count 325.000000 325.000000 325.000000 325.000000 \n",
|
|
"mean 0.448708 0.217381 0.184022 0.078864 \n",
|
|
"std 0.073960 0.105388 0.094736 0.069232 \n",
|
|
"min 0.301887 0.075949 0.000000 0.012158 \n",
|
|
"25% 0.402516 0.145570 0.138554 0.028875 \n",
|
|
"50% 0.433962 0.177215 0.186747 0.042553 \n",
|
|
"75% 0.484277 0.253165 0.234940 0.117021 \n",
|
|
"max 0.798742 0.696203 0.602410 0.303191 \n",
|
|
"\n",
|
|
" chlorides free_sulfur_dioxide total_sulfur_dioxide density \\\n",
|
|
"count 325.000000 325.000000 325.000000 325.000000 \n",
|
|
"mean 0.088017 0.100985 0.258073 0.957147 \n",
|
|
"std 0.043159 0.051174 0.126120 0.002746 \n",
|
|
"min 0.026187 0.006920 0.018182 0.950731 \n",
|
|
"25% 0.062193 0.058824 0.179545 0.954879 \n",
|
|
"50% 0.076923 0.100346 0.256818 0.957189 \n",
|
|
"75% 0.101473 0.138408 0.356818 0.959306 \n",
|
|
"max 0.436989 0.280277 0.575000 0.962935 \n",
|
|
"\n",
|
|
" pH sulphates alcohol quality color \n",
|
|
"count 325.000000 325.000000 325.000000 325.000000 325.000000 \n",
|
|
"mean 0.805141 0.265385 0.709269 0.650256 0.246154 \n",
|
|
"std 0.042584 0.065946 0.079198 0.101225 0.431433 \n",
|
|
"min 0.683292 0.150000 0.570470 0.333333 0.000000 \n",
|
|
"25% 0.775561 0.215000 0.637584 0.555556 0.000000 \n",
|
|
"50% 0.805486 0.260000 0.697987 0.666667 0.000000 \n",
|
|
"75% 0.830424 0.305000 0.758389 0.666667 0.000000 \n",
|
|
"max 0.935162 0.490000 0.953020 0.888889 1.000000 "
|
|
]
|
|
},
|
|
"execution_count": 78,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"wine_val.describe()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 79,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"from torch import nn\n",
|
|
"from torch.utils.data import DataLoader, Dataset"
|
|
]
|
|
},
|
|
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [],
"source": [
"class TabularDataset(Dataset):\n",
"    \"\"\"Serve rows of a table as (features, target) tensor pairs.\n",
"\n",
"    All columns except the last are features; the last column is the target.\n",
"    The table is converted to one float32 array when the dataset is built.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, data):\n",
"        # `data` must expose `.values` (for example a pandas DataFrame).\n",
"        self.data = data.values.astype('float32')\n",
"\n",
"    def __len__(self):\n",
"        return len(self.data)\n",
"\n",
"    def __getitem__(self, index):\n",
"        features = self.data[index, :-1]\n",
"        target = self.data[index, -1]\n",
"        return torch.tensor(features), torch.tensor(target)"
]
},
|
|
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
"# Mini-batch size shared by both loaders.\n",
"batch_size = 64\n",
"\n",
"# Wrap the prepared frames as datasets.\n",
"train_dataset = TabularDataset(wine_train)\n",
"test_dataset = TabularDataset(wine_test)\n",
"\n",
"# Training batches are shuffled; evaluation keeps the stored row order.\n",
"train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)\n",
"test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)"
]
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
"class TabularModel(nn.Module):\n",
"    \"\"\"Two-layer perceptron classifier for tabular features.\n",
"\n",
"    ``forward`` returns raw, unnormalised class scores (logits). This notebook\n",
"    trains the model with ``nn.CrossEntropyLoss``, which applies log-softmax\n",
"    itself; feeding it softmax probabilities would normalise twice, which\n",
"    bounds the loss away from zero and weakens the gradients. Call\n",
"    ``predict_proba`` when actual probabilities are needed.\n",
"\n",
"    Args:\n",
"        input_dim: number of input features per sample.\n",
"        hidden_dim: width of the hidden layer.\n",
"        output_dim: number of classes.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_dim, hidden_dim, output_dim):\n",
"        super().__init__()\n",
"        self.fc1 = nn.Linear(input_dim, hidden_dim)\n",
"        self.relu = nn.ReLU()\n",
"        self.fc2 = nn.Linear(hidden_dim, output_dim)\n",
"        # Kept for probability output only; deliberately not applied in forward().\n",
"        self.softmax = nn.Softmax(dim=1)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return logits of shape (batch, output_dim) for a (batch, input_dim) input.\"\"\"\n",
"        return self.fc2(self.relu(self.fc1(x)))\n",
"\n",
"    def predict_proba(self, x):\n",
"        \"\"\"Return per-class probabilities; each row of the result sums to 1.\"\"\"\n",
"        return self.softmax(self.forward(x))"
]
},
|
|
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
"# Network dimensions. The model, loss and optimiser are built in the next\n",
"# cell; constructing them here as well only created objects that were\n",
"# overwritten straight away.\n",
"input_dim = wine_train.shape[1] - 1  # every column except the trailing target\n",
"hidden_dim = 32\n",
"output_dim = 2  # two target classes"
]
},
|
|
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
"# Seed PyTorch's global RNG so weight initialisation (and the shuffling done\n",
"# by the training DataLoader) is repeatable when the notebook is re-run.\n",
"torch.manual_seed(42)\n",
"\n",
"# Reuse the dimensions named in the previous cell instead of repeating them.\n",
"model = TabularModel(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim)\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.01)"
]
},
|
|
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1, loss: 0.4864\n",
"Epoch 3, loss: 0.3413\n",
"Epoch 5, loss: 0.3345\n",
"Epoch 7, loss: 0.3337\n",
"Epoch 9, loss: 0.3331\n",
"Finished Training\n"
]
}
],
"source": [
"num_epochs = 10\n",
"\n",
"# Put the model in training mode before optimising.\n",
"model.train()\n",
"for epoch in range(num_epochs):\n",
"    running_loss = 0.0\n",
"    for inputs, labels in train_dataloader:\n",
"        # The dataset yields float32 targets; CrossEntropyLoss expects int64\n",
"        # class indices. `.long()` keeps the tensor on its current device.\n",
"        labels = labels.long()\n",
"        optimizer.zero_grad()\n",
"        outputs = model(inputs)\n",
"        loss = criterion(outputs, labels)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        running_loss += loss.item()\n",
"\n",
"    # Report the mean per-batch loss on every other epoch (1, 3, 5, ...).\n",
"    if epoch % 2 == 0:\n",
"        print(f'Epoch {epoch + 1}, loss: {running_loss / len(train_dataloader):.4f}')\n",
"\n",
"print('Finished Training')"
]
},
|
|
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy on test set: 98 %\n"
]
}
],
"source": [
"correct = 0\n",
"total = 0\n",
"\n",
"# Switch to inference mode; note the model stays in eval mode afterwards.\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    for inputs, labels in test_dataloader:\n",
"        outputs = model(inputs.float())\n",
"        # The predicted class is the index of the largest score in each row.\n",
"        predicted = outputs.argmax(dim=1)\n",
"        total += labels.size(0)\n",
"        # Targets arrive as float32; compare them as integer class indices.\n",
"        correct += (predicted == labels.long()).sum().item()\n",
"\n",
"# %d truncates the percentage to a whole number.\n",
"print('Accuracy on test set: %d %%' % (100 * correct / total))"
]
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.2"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|