{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\macty\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import pandas as pd\n",
"import sklearn.model_selection\n",
"from datasets import load_dataset"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found cached dataset wine (C:/Users/macty/.cache/huggingface/datasets/mstz___wine/wine/1.0.0/7c3844cac7ac7a22d5fbbaf60fc1d4e9c9deb1b9b9c4dbae6a7b1a962dbc96d8)\n",
"100%|██████████| 1/1 [00:00<00:00, 49.24it/s]\n"
]
}
],
"source": [
"dataset = load_dataset(\"mstz/wine\", \"wine\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['fixed_acidity', 'volatile_acidity', 'citric_acid', 'residual_sugar', 'chlorides', 'free_sulfur_dioxide', 'total_sulfur_dioxide', 'density', 'pH', 'sulphates', 'alcohol', 'quality', 'is_red'],\n",
" num_rows: 6497\n",
"})"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset[\"train\"]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"wine_dataset = pd.DataFrame(dataset[\"train\"])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" fixed_acidity | \n",
" volatile_acidity | \n",
" citric_acid | \n",
" residual_sugar | \n",
" chlorides | \n",
" free_sulfur_dioxide | \n",
" total_sulfur_dioxide | \n",
" density | \n",
" pH | \n",
" sulphates | \n",
" alcohol | \n",
" quality | \n",
" is_red | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 7.4 | \n",
" 0.70 | \n",
" 0.00 | \n",
" 1.9 | \n",
" 0.076 | \n",
" 11.0 | \n",
" 34.0 | \n",
" 0.9978 | \n",
" 3.51 | \n",
" 0.56 | \n",
" 9.4 | \n",
" 5 | \n",
" 0 | \n",
"
\n",
" \n",
" 1 | \n",
" 7.8 | \n",
" 0.88 | \n",
" 0.00 | \n",
" 2.6 | \n",
" 0.098 | \n",
" 25.0 | \n",
" 67.0 | \n",
" 0.9968 | \n",
" 3.20 | \n",
" 0.68 | \n",
" 9.8 | \n",
" 5 | \n",
" 0 | \n",
"
\n",
" \n",
" 2 | \n",
" 7.8 | \n",
" 0.76 | \n",
" 0.04 | \n",
" 2.3 | \n",
" 0.092 | \n",
" 15.0 | \n",
" 54.0 | \n",
" 0.9970 | \n",
" 3.26 | \n",
" 0.65 | \n",
" 9.8 | \n",
" 5 | \n",
" 0 | \n",
"
\n",
" \n",
" 3 | \n",
" 11.2 | \n",
" 0.28 | \n",
" 0.56 | \n",
" 1.9 | \n",
" 0.075 | \n",
" 17.0 | \n",
" 60.0 | \n",
" 0.9980 | \n",
" 3.16 | \n",
" 0.58 | \n",
" 9.8 | \n",
" 6 | \n",
" 0 | \n",
"
\n",
" \n",
" 4 | \n",
" 7.4 | \n",
" 0.70 | \n",
" 0.00 | \n",
" 1.9 | \n",
" 0.076 | \n",
" 11.0 | \n",
" 34.0 | \n",
" 0.9978 | \n",
" 3.51 | \n",
" 0.56 | \n",
" 9.4 | \n",
" 5 | \n",
" 0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" fixed_acidity volatile_acidity citric_acid residual_sugar chlorides \\\n",
"0 7.4 0.70 0.00 1.9 0.076 \n",
"1 7.8 0.88 0.00 2.6 0.098 \n",
"2 7.8 0.76 0.04 2.3 0.092 \n",
"3 11.2 0.28 0.56 1.9 0.075 \n",
"4 7.4 0.70 0.00 1.9 0.076 \n",
"\n",
" free_sulfur_dioxide total_sulfur_dioxide density pH sulphates \\\n",
"0 11.0 34.0 0.9978 3.51 0.56 \n",
"1 25.0 67.0 0.9968 3.20 0.68 \n",
"2 15.0 54.0 0.9970 3.26 0.65 \n",
"3 17.0 60.0 0.9980 3.16 0.58 \n",
"4 11.0 34.0 0.9978 3.51 0.56 \n",
"\n",
" alcohol quality is_red \n",
"0 9.4 5 0 \n",
"1 9.8 5 0 \n",
"2 9.8 5 0 \n",
"3 9.8 6 0 \n",
"4 9.4 5 0 "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"wine_dataset.head()# podgląd danych"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" fixed_acidity | \n",
" volatile_acidity | \n",
" citric_acid | \n",
" residual_sugar | \n",
" chlorides | \n",
" free_sulfur_dioxide | \n",
" total_sulfur_dioxide | \n",
" density | \n",
" pH | \n",
" sulphates | \n",
" alcohol | \n",
" quality | \n",
" is_red | \n",
"
\n",
" \n",
" \n",
" \n",
" count | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
"
\n",
" \n",
" mean | \n",
" 7.215307 | \n",
" 0.339666 | \n",
" 0.318633 | \n",
" 5.443235 | \n",
" 0.056034 | \n",
" 30.525319 | \n",
" 115.744574 | \n",
" 0.994697 | \n",
" 3.218501 | \n",
" 0.531268 | \n",
" 10.491801 | \n",
" 5.818378 | \n",
" 0.753886 | \n",
"
\n",
" \n",
" std | \n",
" 1.296434 | \n",
" 0.164636 | \n",
" 0.145318 | \n",
" 4.757804 | \n",
" 0.035034 | \n",
" 17.749400 | \n",
" 56.521855 | \n",
" 0.002999 | \n",
" 0.160787 | \n",
" 0.148806 | \n",
" 1.192712 | \n",
" 0.873255 | \n",
" 0.430779 | \n",
"
\n",
" \n",
" min | \n",
" 3.800000 | \n",
" 0.080000 | \n",
" 0.000000 | \n",
" 0.600000 | \n",
" 0.009000 | \n",
" 1.000000 | \n",
" 6.000000 | \n",
" 0.987110 | \n",
" 2.720000 | \n",
" 0.220000 | \n",
" 8.000000 | \n",
" 3.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 25% | \n",
" 6.400000 | \n",
" 0.230000 | \n",
" 0.250000 | \n",
" 1.800000 | \n",
" 0.038000 | \n",
" 17.000000 | \n",
" 77.000000 | \n",
" 0.992340 | \n",
" 3.110000 | \n",
" 0.430000 | \n",
" 9.500000 | \n",
" 5.000000 | \n",
" 1.000000 | \n",
"
\n",
" \n",
" 50% | \n",
" 7.000000 | \n",
" 0.290000 | \n",
" 0.310000 | \n",
" 3.000000 | \n",
" 0.047000 | \n",
" 29.000000 | \n",
" 118.000000 | \n",
" 0.994890 | \n",
" 3.210000 | \n",
" 0.510000 | \n",
" 10.300000 | \n",
" 6.000000 | \n",
" 1.000000 | \n",
"
\n",
" \n",
" 75% | \n",
" 7.700000 | \n",
" 0.400000 | \n",
" 0.390000 | \n",
" 8.100000 | \n",
" 0.065000 | \n",
" 41.000000 | \n",
" 156.000000 | \n",
" 0.996990 | \n",
" 3.320000 | \n",
" 0.600000 | \n",
" 11.300000 | \n",
" 6.000000 | \n",
" 1.000000 | \n",
"
\n",
" \n",
" max | \n",
" 15.900000 | \n",
" 1.580000 | \n",
" 1.660000 | \n",
" 65.800000 | \n",
" 0.611000 | \n",
" 289.000000 | \n",
" 440.000000 | \n",
" 1.038980 | \n",
" 4.010000 | \n",
" 2.000000 | \n",
" 14.900000 | \n",
" 9.000000 | \n",
" 1.000000 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" fixed_acidity volatile_acidity citric_acid residual_sugar \\\n",
"count 6497.000000 6497.000000 6497.000000 6497.000000 \n",
"mean 7.215307 0.339666 0.318633 5.443235 \n",
"std 1.296434 0.164636 0.145318 4.757804 \n",
"min 3.800000 0.080000 0.000000 0.600000 \n",
"25% 6.400000 0.230000 0.250000 1.800000 \n",
"50% 7.000000 0.290000 0.310000 3.000000 \n",
"75% 7.700000 0.400000 0.390000 8.100000 \n",
"max 15.900000 1.580000 1.660000 65.800000 \n",
"\n",
" chlorides free_sulfur_dioxide total_sulfur_dioxide density \\\n",
"count 6497.000000 6497.000000 6497.000000 6497.000000 \n",
"mean 0.056034 30.525319 115.744574 0.994697 \n",
"std 0.035034 17.749400 56.521855 0.002999 \n",
"min 0.009000 1.000000 6.000000 0.987110 \n",
"25% 0.038000 17.000000 77.000000 0.992340 \n",
"50% 0.047000 29.000000 118.000000 0.994890 \n",
"75% 0.065000 41.000000 156.000000 0.996990 \n",
"max 0.611000 289.000000 440.000000 1.038980 \n",
"\n",
" pH sulphates alcohol quality is_red \n",
"count 6497.000000 6497.000000 6497.000000 6497.000000 6497.000000 \n",
"mean 3.218501 0.531268 10.491801 5.818378 0.753886 \n",
"std 0.160787 0.148806 1.192712 0.873255 0.430779 \n",
"min 2.720000 0.220000 8.000000 3.000000 0.000000 \n",
"25% 3.110000 0.430000 9.500000 5.000000 1.000000 \n",
"50% 3.210000 0.510000 10.300000 6.000000 1.000000 \n",
"75% 3.320000 0.600000 11.300000 6.000000 1.000000 \n",
"max 4.010000 2.000000 14.900000 9.000000 1.000000 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"wine_dataset.describe(include='all')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjAAAAGYCAYAAABcVthxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAeAElEQVR4nO3df2xV9f3H8Vcp9PLz3sqP9tJQtAvRUgUcdaN3Uza048oui46SiWPIBDSQYtY2AjYhVdkSCP5AmGI3mSvLJArJdEIDtSlSt3HlR0214CBuYtql3hbnei/0C20p9/vH0hPuQKRQuH2X5yM5Cfd8Pvf0c4ylz5yee0iIRqNRAQAAGNIv3gsAAADoLgIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5vSP9wKulrNnz6qxsVHDhg1TQkJCvJcDAAAuQTQa1YkTJ5SWlqZ+/b76OkufDZjGxkalp6fHexkAAOAyNDQ0aMyYMV853mcDZtiwYZL++x/A7XbHeTUAAOBSRCIRpaenOz/Hv0qfDZiuXxu53W4CBgAAY77u9g9u4gUAAOYQMAAAwBwCBgAAmNOtgHnqqaeUkJAQs2VmZjrjp0+fVn5+vkaMGKGhQ4cqLy9PTU1NMceor69XIBDQ4MGDlZKSomXLlunMmTMxc/bs2aPJkyfL5XJp3LhxKisru/wzBAAAfU63r8Dceuut+vzzz53tr3/9qzNWWFio7du3a9u2baqurlZjY6NmzZrljHd2dioQCKi9vV179+7V5s2bVVZWppKSEmfOsWPHFAgENG3aNNXW1qqgoECLFi1SRUXFFZ4qAADoKxKi0Wj0Uic/9dRTeuutt1RbW3veWDgc1qhRo7RlyxbNnj1bknTkyBGNHz9ewWBQOTk52rlzp2bOnKnGxkalpqZKkkpLS7VixQodP35cSUlJWrFihcrLy3Xo0CHn2HPmzFFLS4t27dp1yScWiUTk8XgUDof5FBIAAEZc6s/vbl+B+eSTT5SWlqZvfOMbmjt3rurr6yVJNTU16ujoUG5urjM3MzNTY8eOVTAYlCQFg0FNmDDBiRdJ8vv9ikQiOnz4sDPn3GN0zek6xldpa2tTJBKJ2QAAQN/UrYCZMmWKysrKtGvXLr388ss6duyY7rrrLp04cUKhUEhJSUlKTk6OeU9qaqpCoZAkKRQKxcRL13jX2MXmRCIRnTp16ivXtnr1ank8HmfjKbwAAPRd3XqQ3YwZM5w/T5w4UVOmTNGNN96orVu3atCgQT2+uO4oLi5WUVGR87rrSX4AAKDvuaKPUScnJ+vmm2/WP/7xD3m9XrW3t6ulpSVmTlNTk7xeryTJ6/We96mkrtdfN8ftdl80klwul/PUXZ6+CwBA33ZFAXPy5En985//1OjRo5Wdna0BAwaoqqrKGT969Kjq6+vl8/kkST6fT3V1dWpubnbmVFZWyu12Kysry5lz7jG65nQdAwAAoFsB8/jjj6u6ulqfffaZ9u7dqx//+MdKTEzUgw8+KI/Ho4ULF6qoqEjvvvuuampq9PDDD8vn8yknJ0eSNH36dGVlZWnevHn68MMPVVFRoZUrVyo/P18ul0uStHjxYn366adavny5jhw5oo0bN2rr1q0qLCzs+bMHAAAmdesemH/961968MEH9e9//1ujRo3SnXfeqffff1+jRo2SJK1bt079+vVTXl6e2tra5Pf7tXHjRuf9iYmJ2rFjh5YsWSKfz6chQ4Zo/vz5WrVqlTMnIyND5eXlKiws1Pr16zVmzBht2rRJfr+/h04ZAABY163nwFjCc2AAALDnUn9+d+sKDGy46YnyeC8B19BnawLxXgIAXHP8Y44AAMAcAgYAAJhDwAAAAHMIGAAAYA4BAwAAzCFgAACAOQQMAAAwh4ABAADmEDAAAMAcAgYAAJhDwAAAAHMIGAAAYA4BAwAAzCFgAACAOQQMAAAwh4ABAADmEDAAAMAcAgYAAJhDwAAAAHMIGAAAYA4BAwAAzCFgAACAOQQMAAAwh4ABAADmEDAAAMAcAgYAAJhDwAAAAHMIGAAAYA4BAwAAzCFgAACAOQQMAAAwh4ABAADmEDAAAMAcAgYAAJhDwAAAAHMIGAAAYA4BAwAAzCFgAACAOQQMAAAwh4ABAADmEDAAAMAcAgYAAJhDwAAAAHMIGAAAYA4BAwAAzCFgAACAOQQMAAAwh4ABAADmEDAAAMAcAgYAAJhDwAAAAHMIGAAAYA4BAwAAzCFgAACAOQQMAAAwh4ABAADmEDAAAMAcAgYAAJhDwAAAAHMIGAAAYM4VBcyaNWuUkJCggoICZ9/p06eVn5+vESNGaOjQocrLy1NTU1PM++rr6xUIBDR48GClpKRo2bJlOnPmTMycPXv2aPLkyXK5XBo3bpzKysquZKkAAKAPueyAOXDggH7zm99o4sSJMfsLCwu1fft2bdu2TdXV1WpsbNSsWbOc8c7OTgUCAbW3t2vv3r3avHmzysrKVFJS4sw5duyYAoGApk2bptraWhUUFGjRokWqqKi43OUCAIA+5LIC5uTJk5o7d65eeeUV3XDDDc7+cDis3/3ud3r++ed19913Kzs7W7///e+1d+9evf/++5Kkd955Rx9//LH++Mc/6vbbb9eMGTP0y1/+Ui+99JLa29slSaWlpcrIyNBzzz2n8ePHa+nSpZo9e7bWrVvXA6cMAACsu6yAyc/PVyAQUG5ubsz+mpoadXR0xOzPzMzU2LFjFQwGJUnBYFATJkxQamqqM8fv9ysSiejw4cPOnP89tt/vd45xIW1tbYpEIjEbAADom/p39w2vv/66PvjgAx04cOC8sVAopKSkJCUnJ8fsT01NVSgUcuacGy9d411jF5sTiUR06tQpDRo06LyvvXr1aj399NPdPR0AAGBQt67ANDQ06Be/+IVee+01DRw48Gqt6bIUFxcrHA47W0NDQ7yXBAAArpJuBUxNTY2am5s1efJk9e/fX/3791d1dbU2bNig/v37KzU1Ve3t7WppaYl5X1NTk7xeryTJ6/We96mkrtdfN8ftdl/w6oskuVwuud3umA0AAPRN3QqYe+65R3V1daqtrXW2O+64Q3PnznX+PGDAAFVVVTnvOXr0qOrr6+Xz+SRJPp9PdXV1am5uduZUVlbK7XYrKyvLmXPuMbrmdB0DAABc37p1D8ywYcN02223xewbMmSIRowY4exfuHChioqKNHz4cLndbj322GPy+XzKycmRJE2fPl1ZWVmaN2+e1q5dq1AopJUrVyo/P18ul0uStHjxYr344otavny5FixYoN27d2vr1q0qLy/viXMGAADGdfsm3q+zbt069evXT3l5eWpra5Pf79fGjRud8cTERO3YsUNLliyRz+fTkCFDNH/+fK1atcqZk5GRofLychUWFmr9+vUaM2aMNm3aJL/f39PLBQAABiVEo9FovBdxNUQiEXk8HoXD4evufpibnuBK1fXkszWBeC8BAHrMpf785t9CAgAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGBOtwLm5Zdf1sSJE+V2u+V2u+Xz+bRz505n/PTp08rPz9eIESM0dOhQ5eXlqampKeYY9fX1CgQCGjx4sFJSUrRs2TKdOXMmZs6ePXs0efJkuVwujRs3TmVlZZd/hgAAoM/pVsCMGTNGa9asUU1NjQ4ePKi7775b9913nw4fPixJKiws1Pbt27Vt2zZVV1ersbFRs2bNct7f2dmpQCCg9vZ27d27V5s3b1ZZWZlKSkqcOceOHVMgENC0adNUW1urgoICLVq0SBUVFT10ygAAwLqEaDQavZIDDB8+XM8884xmz56tUaNGacuWLZo9e7Yk6ciRIxo/fryCwaBycnK0c+dOzZw5U42NjUpNTZUklZaWasWKFTp+/LiSkpK0YsUKlZeX69ChQ87XmDNnjlpaWrRr165LXlckEpHH41E4HJbb7b6SUzTnpifK470EXEOfrQnEewkA0GMu9ef3Zd8D09nZqddff12tra3y+XyqqalRR0eHcnNznTmZmZkaO3asgsGgJCkYDGrChAlOvEiS3+9XJBJxruIEg8GYY3TN6TrGV2lra1MkEonZAABA39TtgKmrq9PQoUPlcrm0ePFivfnmm8rKylIoFFJSUpKSk5Nj5qempioUCkmSQqFQTLx0jXeNXWxOJBLRqVOnvnJdq1evlsfjcbb09PTunhoAADCi2wFzyy23qLa2Vvv27dOSJUs0f/58ffzxx1djbd1SXFyscDjsbA0NDfFeEgAAuEr6d/cNSUlJGjdunCQpOztbBw4c0Pr16/XAAw+ovb1dLS0tMVdhmpqa5PV6JUler1f79++POV7Xp5TOnfO/n1xqamqS2+3WoEGDvnJdLpdLLperu6cDAAAMuuLnwJw9e1ZtbW3Kzs7WgAEDVFVV5YwdPXpU9fX18vl8kiSfz6e6ujo1Nzc7cyorK+V2u5WVleXMOfcYXXO6jgEAANCtKzDFxcWaMWOGxo4dqxMnTmjLli3as2ePKioq5PF4tHDhQhUVFWn48OFyu9167LHH5PP5lJOTI0maPn26srKyNG/ePK1du1ahUEgrV65Ufn6+c/Vk8eLFevHFF7V8+XItWLBAu3fv1tatW1VezidrAADAf3UrYJqbm/XQQw/p888/l8fj0cSJE1VRUaEf/OAHkqR169apX79+ysvLU1tbm/x+vzZu3Oi8PzExUTt27NCSJUvk8/k0ZMgQzZ8/X6tWrXLmZGRkqLy8XIWFhVq/fr3GjBmjTZs2ye/399ApAwAA6674OTC9Fc+BwfWC58AA6Euu+nNgAAAA4oWAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACY062AWb16tb71rW9p2LBhSklJ0f3336+jR4/GzDl9+rTy8/M1YsQIDR06VHl5eWpqaoqZU19fr0AgoMGDByslJUXLli3TmTNnYubs2bNHkydPlsvl0rhx41RWVnZ5ZwgAAPqcbgVMdXW18vPz9f7776uyslIdHR2aPn26WltbnTmFhYXavn27tm3bpurqajU2NmrWrFnOeGdnpwKBgNrb27V3715t3rxZZWVlKikpceYcO3ZMgUBA06ZNU21trQoKCrRo0SJVVFT0wCkDAADrEqLRaPRy33z8+HGlpKSourpaU6dOVTgc1qhRo7RlyxbNnj1bknTkyBGNHz9ewWBQOTk52rlzp2bOnKnGxkalpqZKkkpLS7VixQodP35cSUlJWrFihcrLy3Xo0CHna82ZM0ctLS3atWvXJa0tEonI4/EoHA7L7XZf7imadNMT5fFeAq6hz9YE4r0EAOgxl/rz+4rugQmHw5Kk4cOHS5JqamrU0dGh3NxcZ05mZqbGjh2rYDAoSQoGg5owYYITL5Lk9/sViUR0+PBhZ865x+ia03UMAABwfet/uW88e/asCgoK9N3vfle33XabJCkUCikpKUnJyckxc1NTUxUKhZw558ZL13jX2MXmRCIRnTp1SoMGDTpvPW1tbWpra3NeRyKRyz01AADQy132FZj8/HwdOnRIr7/+ek+u57KtXr1aHo/H2dLT0+O9JAAAcJVcVsAsXbpUO3bs0LvvvqsxY8Y4+71er9rb29XS0hIzv6mpSV6v15nzv59K6nr9dXPcbvcFr75IUnFxscLhsLM1NDRczqkBAAADuhUw0WhUS5cu1Ztvvqndu3crIyMjZjw7O1sDBgxQVVWVs+/o0aOqr6+Xz+eTJPl8PtXV1am5udmZU1lZKbfbraysLGfOucfomtN1jAtxuVxyu90xGwAA6Ju6dQ9Mfn6+tmzZoj//+c8aNmyYc8+Kx+PRoEGD5PF4tHDhQhUVFWn48OFyu9167LHH5PP5lJOTI0maPn26srKyNG/ePK1du1ahUEgrV65Ufn6+XC6XJGnx4sV68cUXtXz5ci1YsEC7d+/W1q1bVV7Op2sAAEA3r8C8/PLLCofD+v73v6/Ro0c72xtvvOHMWbdunWbOnKm8vDxNnTpVXq9Xf/rTn5zxxMRE7dixQ4mJifL5fPrZz36mhx56SKtWrXLmZGRkqLy8XJWVlZo0aZKee+45bdq0SX6/vwdOGQAAWHdFz4HpzXgODK4XPAcGQF9yTZ4DAwAAEA8EDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABz+sd7AQCAS3fTE+XxXgKuoc/WBOK9hF6LKzAAAMAcAgYAAJhDwAAAAHMIGAAAYA4BAwAAzCFgAACAOQQMAAAwh4ABAADmEDAAAMAcAgYAAJhDwAAAAHMIGAAAYA4BAwAAzCFgAACAOQQMAAAwh4ABAADmEDAAAMAcAgYAAJhDwAAAAHMIGAAAYA4BAwAAzCFgAACAOQQMAAAwh4ABAADmEDAAAMAcAgYAAJhDwAAAAHMIGAAAYA4BAwAAzCFgAACAOQQMAAAwh4ABAADmEDAAAMAcAgYAAJhDwAAAAHMIGAAAYA4BAwAAzCFgAACAOd0OmPfee08/+tGPlJaWpoSEBL311lsx49FoVCUlJRo9erQGDRqk3NxcffLJJzFzvvzyS82dO1dut1vJyclauHChTp48GTPno48+0l133aWBAwcqPT1da9eu7f7ZAQCAPqnbAdPa2qpJkybppZdeuuD42rVrtWHDBpWWlmrfvn0aMmSI/H6/Tp8+7cyZO3euDh8+rMrKSu3YsUPvvfeeHn30UWc8Eolo+vTpuvHGG1VTU6NnnnlGTz31lH77299exikCAIC+pn933zBjxgzNmDHjgmPRaFQvvPCCVq5cqfvuu0+S9Ic//EGpqal66623NGfOHP3973/Xrl27dODAAd1xxx2SpF//+tf64Q9/qGeffVZpaWl67bXX1N7erldffVVJSUm69dZbVVtbq+effz4mdAAAwPWpR++BOXbsmEKhkHJzc519Ho9HU6ZMUTAYlCQFg0ElJyc78SJJubm56tevn/bt2+fMmTp1qpKSkpw5fr9fR48e1X/+85+eXDIAADCo21dgLiYUCkmSUlNTY/anpqY6Y6FQSCkpKbGL6N9fw4cPj5mTkZFx3jG6xm644YbzvnZbW5va2tqc15FI5ArPBgAA9FZ95lNIq1evlsfjcbb09PR4LwkAAFwlPRowXq9XktTU1BSzv6mpyRnzer1qbm6OGT9z5oy+/PLLmDkXOsa5X+N/FRcXKxwOO1tDQ8OVnxAAAOiVejRgMjIy5PV6VVVV5eyLRCLat2+ffD6fJMnn86mlpUU1NTXOnN27d+vs2bOaMmWKM+e9995TR0eHM6eyslK33HLLBX99JEkul0tutztmAwAAfVO3A+bkyZOqra1VbW2tpP/euFtbW6v6+nolJCSooKBAv/rVr/T222+rrq5ODz30kNLS0nT//fdLksaPH697771XjzzyiPbv36+//e1vWrp0qebMmaO0tDRJ0k9/+lMlJSVp4cKFOnz4sN544w2tX79eRUVFPXbiAADArm7fxHvw4EFNmzbNed0VFfPnz1dZWZmWL1+u1tZWPfroo2ppadGdd96pXbt2aeDAgc57XnvtNS1dulT33HOP+vXrp7y8PG3YsMEZ93g8euedd5Sfn6/s7GyNHDlSJSUlfIQaAABIkhKi0Wg03ou4GiKRiDwej8Lh8HX366SbniiP9xJwDX22JhDvJeAa4vv7+nI9fn9f6s/vPvMpJAAAcP0gYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAHAIGAACYQ8AAAABzCBgAAGAOAQMAAMwhYAAAgDkEDAAAMIeAAQAA5hAwAADAnF4dMC+99JJuuukmDRw4UFOmTNH+/fvjvSQAANAL9NqAeeONN1RUVKQnn3xSH3zwgSZNmiS/36/m5uZ4Lw0AAMRZrw2Y559/Xo888ogefvhhZWVlqbS0VIMHD9arr74a76UBAIA46x/vBVxIe3u7ampqVFxc7Ozr16+fcnNzFQwGL/ietrY2tbW1Oa/D4bAkKRKJXN3F9kJn2/4v3kvANXQ9/j9+PeP7+/pyPX5/d51zNBq96LxeGTBffPGFOjs7lZqaGrM/NTVVR44cueB7Vq9eraeffvq8/enp6VdljUBv4Xkh3isAcLVcz9/fJ06ckMfj+crxXhkwl6O4uFhFRUXO67Nnz+rLL7/UiBEjlJCQEMeV4VqIRCJKT09XQ0OD3G53vJcDoAfx/X19iUajOnHihNLS0i46r1cGzMiRI5WYmKimpqaY/U1NTfJ6vRd8j8vlksvlitmXnJx8tZaIXsrtdvMXHNBH8f19/bjYlZcuvfIm3qSkJGVnZ6uqqsrZd/bsWVVVVcnn88VxZQAAoDfolVdgJKmoqEjz58/XHXfcoW9/+9t64YUX1NraqocffjjeSwMAAHHWawPmgQce0PHjx1VSUqJQKKTbb79du3btOu/GXkD6768Qn3zyyfN+jQjAPr6/cSEJ0a/7nBIAAEAv0yvvgQEAALgYAgYAAJhDwAAAAHMIGAAAYA4BAwAAzOm1H6MGAFyfvvjiC7366qsKBoMKhUKSJK/Xq+985zv6+c9/rlGjRsV5hegNuAKDPqehoUELFiyI9zIAXIYDBw7o5ptv1oYNG+TxeDR16lRNnTpVHo9HGzZsUGZmpg4ePBjvZaIX4Dkw6HM+/PBDTZ48WZ2dnfFeCoBuysnJ0aRJk1RaWnreP8QbjUa1ePFiffTRRwoGg3FaIXoLfoUEc95+++2Ljn/66afXaCUAetqHH36osrKy8+JFkhISElRYWKhvfvObcVgZehsCBubcf//9SkhI0MUuHl7oLz8AvZ/X69X+/fuVmZl5wfH9+/fzT8pAEgEDg0aPHq2NGzfqvvvuu+B4bW2tsrOzr/GqAPSExx9/XI8++qhqamp0zz33OLHS1NSkqqoqvfLKK3r22WfjvEr0BgQMzMnOzlZNTc1XBszXXZ0B0Hvl5+dr5MiRWrdunTZu3Ojcy5aYmKjs7GyVlZXpJz/5SZxXid6Am3hhzl/+8he1trbq3nvvveB4a2urDh48qO9973vXeGUAelJHR4e++OILSdLIkSM1YMCAOK8IvQkBAwAAzOE5MAAAwBwCBgAAmEPAAAAAcwgYAABgDgEDAADMIWAAAIA5BAwAADCHgAEAAOb8P8GPNhrybRfSAAAAAElFTkSuQmCC",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"wine_dataset[\"is_red\"].value_counts().plot(kind=\"bar\")\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.2964337577998153"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"wine_dataset[\"fixed_acidity\"].std()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([], dtype=int64), array([], dtype=int64))"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"np.where(pd.isnull(wine_dataset))## sprawdzanie czy istnieją puste wartości"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"for column in wine_dataset.columns:\n",
" wine_dataset[column] = wine_dataset[column] / wine_dataset[column].abs().max() # normalizacja"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" fixed_acidity | \n",
" volatile_acidity | \n",
" citric_acid | \n",
" residual_sugar | \n",
" chlorides | \n",
" free_sulfur_dioxide | \n",
" total_sulfur_dioxide | \n",
" density | \n",
" pH | \n",
" sulphates | \n",
" alcohol | \n",
" quality | \n",
" is_red | \n",
"
\n",
" \n",
" \n",
" \n",
" count | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
" 6497.000000 | \n",
"
\n",
" \n",
" mean | \n",
" 0.453793 | \n",
" 0.214978 | \n",
" 0.191948 | \n",
" 0.082724 | \n",
" 0.091708 | \n",
" 0.105624 | \n",
" 0.263056 | \n",
" 0.957378 | \n",
" 0.802619 | \n",
" 0.265634 | \n",
" 0.704148 | \n",
" 0.646486 | \n",
" 0.753886 | \n",
"
\n",
" \n",
" std | \n",
" 0.081537 | \n",
" 0.104200 | \n",
" 0.087541 | \n",
" 0.072307 | \n",
" 0.057338 | \n",
" 0.061417 | \n",
" 0.128459 | \n",
" 0.002886 | \n",
" 0.040097 | \n",
" 0.074403 | \n",
" 0.080048 | \n",
" 0.097028 | \n",
" 0.430779 | \n",
"
\n",
" \n",
" min | \n",
" 0.238994 | \n",
" 0.050633 | \n",
" 0.000000 | \n",
" 0.009119 | \n",
" 0.014730 | \n",
" 0.003460 | \n",
" 0.013636 | \n",
" 0.950076 | \n",
" 0.678304 | \n",
" 0.110000 | \n",
" 0.536913 | \n",
" 0.333333 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 25% | \n",
" 0.402516 | \n",
" 0.145570 | \n",
" 0.150602 | \n",
" 0.027356 | \n",
" 0.062193 | \n",
" 0.058824 | \n",
" 0.175000 | \n",
" 0.955110 | \n",
" 0.775561 | \n",
" 0.215000 | \n",
" 0.637584 | \n",
" 0.555556 | \n",
" 1.000000 | \n",
"
\n",
" \n",
" 50% | \n",
" 0.440252 | \n",
" 0.183544 | \n",
" 0.186747 | \n",
" 0.045593 | \n",
" 0.076923 | \n",
" 0.100346 | \n",
" 0.268182 | \n",
" 0.957564 | \n",
" 0.800499 | \n",
" 0.255000 | \n",
" 0.691275 | \n",
" 0.666667 | \n",
" 1.000000 | \n",
"
\n",
" \n",
" 75% | \n",
" 0.484277 | \n",
" 0.253165 | \n",
" 0.234940 | \n",
" 0.123100 | \n",
" 0.106383 | \n",
" 0.141869 | \n",
" 0.354545 | \n",
" 0.959585 | \n",
" 0.827930 | \n",
" 0.300000 | \n",
" 0.758389 | \n",
" 0.666667 | \n",
" 1.000000 | \n",
"
\n",
" \n",
" max | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" fixed_acidity volatile_acidity citric_acid residual_sugar \\\n",
"count 6497.000000 6497.000000 6497.000000 6497.000000 \n",
"mean 0.453793 0.214978 0.191948 0.082724 \n",
"std 0.081537 0.104200 0.087541 0.072307 \n",
"min 0.238994 0.050633 0.000000 0.009119 \n",
"25% 0.402516 0.145570 0.150602 0.027356 \n",
"50% 0.440252 0.183544 0.186747 0.045593 \n",
"75% 0.484277 0.253165 0.234940 0.123100 \n",
"max 1.000000 1.000000 1.000000 1.000000 \n",
"\n",
" chlorides free_sulfur_dioxide total_sulfur_dioxide density \\\n",
"count 6497.000000 6497.000000 6497.000000 6497.000000 \n",
"mean 0.091708 0.105624 0.263056 0.957378 \n",
"std 0.057338 0.061417 0.128459 0.002886 \n",
"min 0.014730 0.003460 0.013636 0.950076 \n",
"25% 0.062193 0.058824 0.175000 0.955110 \n",
"50% 0.076923 0.100346 0.268182 0.957564 \n",
"75% 0.106383 0.141869 0.354545 0.959585 \n",
"max 1.000000 1.000000 1.000000 1.000000 \n",
"\n",
" pH sulphates alcohol quality is_red \n",
"count 6497.000000 6497.000000 6497.000000 6497.000000 6497.000000 \n",
"mean 0.802619 0.265634 0.704148 0.646486 0.753886 \n",
"std 0.040097 0.074403 0.080048 0.097028 0.430779 \n",
"min 0.678304 0.110000 0.536913 0.333333 0.000000 \n",
"25% 0.775561 0.215000 0.637584 0.555556 1.000000 \n",
"50% 0.800499 0.255000 0.691275 0.666667 1.000000 \n",
"75% 0.827930 0.300000 0.758389 0.666667 1.000000 \n",
"max 1.000000 1.000000 1.000000 1.000000 1.000000 "
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"wine_dataset.describe(include='all') # sprawdzanie wartości po znormalizowaniu"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"652 1.000000\n",
"442 0.981132\n",
"557 0.981132\n",
"554 0.974843\n",
"555 0.974843\n",
"243 0.943396\n",
"244 0.943396\n",
"544 0.899371\n",
"3125 0.893082\n",
"374 0.880503\n",
"Name: fixed_acidity, dtype: float64"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"wine_dataset[\"fixed_acidity\"].nlargest(10) #sprawdza czy najwyższe wartości mają sens"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0 4408\n",
"0.0 1439\n",
"Name: is_red, dtype: int64"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"wine_train, wine_test = sklearn.model_selection.train_test_split(wine_dataset, test_size=0.1, random_state=1, stratify=wine_dataset[\"is_red\"])\n",
"wine_train[\"is_red\"].value_counts() \n",
"# podzielenie na train i test"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0 490\n",
"0.0 160\n",
"Name: is_red, dtype: int64"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"wine_test[\"is_red\"].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"wine_test, wine_val = sklearn.model_selection.train_test_split(wine_test, test_size=0.5, random_state=1, stratify=wine_test[\"is_red\"]) # podzielenie na test i validation"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0 245\n",
"0.0 80\n",
"Name: is_red, dtype: int64"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"wine_test[\"is_red\"].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0 245\n",
"0.0 80\n",
"Name: is_red, dtype: int64"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"wine_val[\"is_red\"].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"import seaborn as sns\n",
"sns.set_theme()"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"13"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(wine_dataset.columns)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"#sns.pairplot(data=wine_dataset, hue=\"is_red\")"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" fixed_acidity | \n",
" volatile_acidity | \n",
" citric_acid | \n",
" residual_sugar | \n",
" chlorides | \n",
" free_sulfur_dioxide | \n",
" total_sulfur_dioxide | \n",
" density | \n",
" pH | \n",
" sulphates | \n",
" alcohol | \n",
" quality | \n",
" is_red | \n",
"
\n",
" \n",
" \n",
" \n",
" count | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
"
\n",
" \n",
" mean | \n",
" 0.448244 | \n",
" 0.217069 | \n",
" 0.180630 | \n",
" 0.078990 | \n",
" 0.088742 | \n",
" 0.103024 | \n",
" 0.257462 | \n",
" 0.957255 | \n",
" 0.803553 | \n",
" 0.263877 | \n",
" 0.703930 | \n",
" 0.646154 | \n",
" 0.753846 | \n",
"
\n",
" \n",
" std | \n",
" 0.074301 | \n",
" 0.107627 | \n",
" 0.078046 | \n",
" 0.070045 | \n",
" 0.051400 | \n",
" 0.054750 | \n",
" 0.125165 | \n",
" 0.002786 | \n",
" 0.039808 | \n",
" 0.072275 | \n",
" 0.078704 | \n",
" 0.095014 | \n",
" 0.431433 | \n",
"
\n",
" \n",
" min | \n",
" 0.314465 | \n",
" 0.063291 | \n",
" 0.000000 | \n",
" 0.012158 | \n",
" 0.031097 | \n",
" 0.010381 | \n",
" 0.020455 | \n",
" 0.951116 | \n",
" 0.713217 | \n",
" 0.130000 | \n",
" 0.570470 | \n",
" 0.333333 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 25% | \n",
" 0.402516 | \n",
" 0.145570 | \n",
" 0.144578 | \n",
" 0.027356 | \n",
" 0.060556 | \n",
" 0.058824 | \n",
" 0.168182 | \n",
" 0.955168 | \n",
" 0.775561 | \n",
" 0.210000 | \n",
" 0.637584 | \n",
" 0.555556 | \n",
" 1.000000 | \n",
"
\n",
" \n",
" 50% | \n",
" 0.433962 | \n",
" 0.177215 | \n",
" 0.180723 | \n",
" 0.042553 | \n",
" 0.078560 | \n",
" 0.100346 | \n",
" 0.261364 | \n",
" 0.957478 | \n",
" 0.800499 | \n",
" 0.250000 | \n",
" 0.691275 | \n",
" 0.666667 | \n",
" 1.000000 | \n",
"
\n",
" \n",
" 75% | \n",
" 0.471698 | \n",
" 0.253165 | \n",
" 0.222892 | \n",
" 0.113982 | \n",
" 0.101473 | \n",
" 0.141869 | \n",
" 0.343182 | \n",
" 0.959354 | \n",
" 0.827930 | \n",
" 0.300000 | \n",
" 0.758389 | \n",
" 0.666667 | \n",
" 1.000000 | \n",
"
\n",
" \n",
" max | \n",
" 0.817610 | \n",
" 0.569620 | \n",
" 0.445783 | \n",
" 0.334347 | \n",
" 0.679214 | \n",
" 0.231834 | \n",
" 0.575000 | \n",
" 0.965264 | \n",
" 0.917706 | \n",
" 0.585000 | \n",
" 0.939597 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" fixed_acidity volatile_acidity citric_acid residual_sugar \\\n",
"count 325.000000 325.000000 325.000000 325.000000 \n",
"mean 0.448244 0.217069 0.180630 0.078990 \n",
"std 0.074301 0.107627 0.078046 0.070045 \n",
"min 0.314465 0.063291 0.000000 0.012158 \n",
"25% 0.402516 0.145570 0.144578 0.027356 \n",
"50% 0.433962 0.177215 0.180723 0.042553 \n",
"75% 0.471698 0.253165 0.222892 0.113982 \n",
"max 0.817610 0.569620 0.445783 0.334347 \n",
"\n",
" chlorides free_sulfur_dioxide total_sulfur_dioxide density \\\n",
"count 325.000000 325.000000 325.000000 325.000000 \n",
"mean 0.088742 0.103024 0.257462 0.957255 \n",
"std 0.051400 0.054750 0.125165 0.002786 \n",
"min 0.031097 0.010381 0.020455 0.951116 \n",
"25% 0.060556 0.058824 0.168182 0.955168 \n",
"50% 0.078560 0.100346 0.261364 0.957478 \n",
"75% 0.101473 0.141869 0.343182 0.959354 \n",
"max 0.679214 0.231834 0.575000 0.965264 \n",
"\n",
" pH sulphates alcohol quality is_red \n",
"count 325.000000 325.000000 325.000000 325.000000 325.000000 \n",
"mean 0.803553 0.263877 0.703930 0.646154 0.753846 \n",
"std 0.039808 0.072275 0.078704 0.095014 0.431433 \n",
"min 0.713217 0.130000 0.570470 0.333333 0.000000 \n",
"25% 0.775561 0.210000 0.637584 0.555556 1.000000 \n",
"50% 0.800499 0.250000 0.691275 0.666667 1.000000 \n",
"75% 0.827930 0.300000 0.758389 0.666667 1.000000 \n",
"max 0.917706 0.585000 0.939597 1.000000 1.000000 "
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"wine_test.describe()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" fixed_acidity | \n",
" volatile_acidity | \n",
" citric_acid | \n",
" residual_sugar | \n",
" chlorides | \n",
" free_sulfur_dioxide | \n",
" total_sulfur_dioxide | \n",
" density | \n",
" pH | \n",
" sulphates | \n",
" alcohol | \n",
" quality | \n",
" is_red | \n",
"
\n",
" \n",
" \n",
" \n",
" count | \n",
" 5847.000000 | \n",
" 5847.000000 | \n",
" 5847.000000 | \n",
" 5847.000000 | \n",
" 5847.000000 | \n",
" 5847.000000 | \n",
" 5847.000000 | \n",
" 5847.000000 | \n",
" 5847.000000 | \n",
" 5847.000000 | \n",
" 5847.000000 | \n",
" 5847.000000 | \n",
" 5847.000000 | \n",
"
\n",
" \n",
" mean | \n",
" 0.453848 | \n",
" 0.215061 | \n",
" 0.192235 | \n",
" 0.082331 | \n",
" 0.092161 | \n",
" 0.105659 | \n",
" 0.262894 | \n",
" 0.957364 | \n",
" 0.802569 | \n",
" 0.265798 | \n",
" 0.704326 | \n",
" 0.646732 | \n",
" 0.753891 | \n",
"
\n",
" \n",
" std | \n",
" 0.081742 | \n",
" 0.104315 | \n",
" 0.088036 | \n",
" 0.071982 | \n",
" 0.058619 | \n",
" 0.061749 | \n",
" 0.128256 | \n",
" 0.002882 | \n",
" 0.039880 | \n",
" 0.074864 | \n",
" 0.079852 | \n",
" 0.096928 | \n",
" 0.430780 | \n",
"
\n",
" \n",
" min | \n",
" 0.238994 | \n",
" 0.050633 | \n",
" 0.000000 | \n",
" 0.009119 | \n",
" 0.014730 | \n",
" 0.003460 | \n",
" 0.013636 | \n",
" 0.950076 | \n",
" 0.678304 | \n",
" 0.110000 | \n",
" 0.536913 | \n",
" 0.333333 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 25% | \n",
" 0.402516 | \n",
" 0.145570 | \n",
" 0.150602 | \n",
" 0.027356 | \n",
" 0.062193 | \n",
" 0.058824 | \n",
" 0.176136 | \n",
" 0.955071 | \n",
" 0.775561 | \n",
" 0.215000 | \n",
" 0.637584 | \n",
" 0.555556 | \n",
" 1.000000 | \n",
"
\n",
" \n",
" 50% | \n",
" 0.440252 | \n",
" 0.183544 | \n",
" 0.186747 | \n",
" 0.045593 | \n",
" 0.076923 | \n",
" 0.100346 | \n",
" 0.268182 | \n",
" 0.957516 | \n",
" 0.800499 | \n",
" 0.255000 | \n",
" 0.691275 | \n",
" 0.666667 | \n",
" 1.000000 | \n",
"
\n",
" \n",
" 75% | \n",
" 0.484277 | \n",
" 0.253165 | \n",
" 0.234940 | \n",
" 0.123100 | \n",
" 0.106383 | \n",
" 0.141869 | \n",
" 0.353409 | \n",
" 0.959581 | \n",
" 0.827930 | \n",
" 0.300000 | \n",
" 0.758389 | \n",
" 0.666667 | \n",
" 1.000000 | \n",
"
\n",
" \n",
" max | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" fixed_acidity volatile_acidity citric_acid residual_sugar \\\n",
"count 5847.000000 5847.000000 5847.000000 5847.000000 \n",
"mean 0.453848 0.215061 0.192235 0.082331 \n",
"std 0.081742 0.104315 0.088036 0.071982 \n",
"min 0.238994 0.050633 0.000000 0.009119 \n",
"25% 0.402516 0.145570 0.150602 0.027356 \n",
"50% 0.440252 0.183544 0.186747 0.045593 \n",
"75% 0.484277 0.253165 0.234940 0.123100 \n",
"max 1.000000 1.000000 1.000000 1.000000 \n",
"\n",
" chlorides free_sulfur_dioxide total_sulfur_dioxide density \\\n",
"count 5847.000000 5847.000000 5847.000000 5847.000000 \n",
"mean 0.092161 0.105659 0.262894 0.957364 \n",
"std 0.058619 0.061749 0.128256 0.002882 \n",
"min 0.014730 0.003460 0.013636 0.950076 \n",
"25% 0.062193 0.058824 0.176136 0.955071 \n",
"50% 0.076923 0.100346 0.268182 0.957516 \n",
"75% 0.106383 0.141869 0.353409 0.959581 \n",
"max 1.000000 1.000000 1.000000 1.000000 \n",
"\n",
" pH sulphates alcohol quality is_red \n",
"count 5847.000000 5847.000000 5847.000000 5847.000000 5847.000000 \n",
"mean 0.802569 0.265798 0.704326 0.646732 0.753891 \n",
"std 0.039880 0.074864 0.079852 0.096928 0.430780 \n",
"min 0.678304 0.110000 0.536913 0.333333 0.000000 \n",
"25% 0.775561 0.215000 0.637584 0.555556 1.000000 \n",
"50% 0.800499 0.255000 0.691275 0.666667 1.000000 \n",
"75% 0.827930 0.300000 0.758389 0.666667 1.000000 \n",
"max 1.000000 1.000000 1.000000 1.000000 1.000000 "
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"wine_train.describe()"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" fixed_acidity | \n",
" volatile_acidity | \n",
" citric_acid | \n",
" residual_sugar | \n",
" chlorides | \n",
" free_sulfur_dioxide | \n",
" total_sulfur_dioxide | \n",
" density | \n",
" pH | \n",
" sulphates | \n",
" alcohol | \n",
" quality | \n",
" is_red | \n",
"
\n",
" \n",
" \n",
" \n",
" count | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
" 325.000000 | \n",
"
\n",
" \n",
" mean | \n",
" 0.458355 | \n",
" 0.211412 | \n",
" 0.198091 | \n",
" 0.093521 | \n",
" 0.086537 | \n",
" 0.107596 | \n",
" 0.271556 | \n",
" 0.957757 | \n",
" 0.802570 | \n",
" 0.264446 | \n",
" 0.701160 | \n",
" 0.642393 | \n",
" 0.753846 | \n",
"
\n",
" \n",
" std | \n",
" 0.084621 | \n",
" 0.098749 | \n",
" 0.086862 | \n",
" 0.079346 | \n",
" 0.035141 | \n",
" 0.061805 | \n",
" 0.135185 | \n",
" 0.003031 | \n",
" 0.044183 | \n",
" 0.068086 | \n",
" 0.084939 | \n",
" 0.100957 | \n",
" 0.431433 | \n",
"
\n",
" \n",
" min | \n",
" 0.295597 | \n",
" 0.056962 | \n",
" 0.000000 | \n",
" 0.012158 | \n",
" 0.019640 | \n",
" 0.010381 | \n",
" 0.018182 | \n",
" 0.950413 | \n",
" 0.715711 | \n",
" 0.140000 | \n",
" 0.563758 | \n",
" 0.333333 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 25% | \n",
" 0.402516 | \n",
" 0.145570 | \n",
" 0.156627 | \n",
" 0.030395 | \n",
" 0.063830 | \n",
" 0.055363 | \n",
" 0.179545 | \n",
" 0.955456 | \n",
" 0.773067 | \n",
" 0.215000 | \n",
" 0.630872 | \n",
" 0.555556 | \n",
" 1.000000 | \n",
"
\n",
" \n",
" 50% | \n",
" 0.446541 | \n",
" 0.183544 | \n",
" 0.186747 | \n",
" 0.069149 | \n",
" 0.078560 | \n",
" 0.100346 | \n",
" 0.284091 | \n",
" 0.957978 | \n",
" 0.800499 | \n",
" 0.250000 | \n",
" 0.684564 | \n",
" 0.666667 | \n",
" 1.000000 | \n",
"
\n",
" \n",
" 75% | \n",
" 0.490566 | \n",
" 0.253165 | \n",
" 0.240964 | \n",
" 0.133739 | \n",
" 0.098200 | \n",
" 0.155709 | \n",
" 0.370455 | \n",
" 0.960028 | \n",
" 0.827930 | \n",
" 0.305000 | \n",
" 0.758389 | \n",
" 0.666667 | \n",
" 1.000000 | \n",
"
\n",
" \n",
" max | \n",
" 0.943396 | \n",
" 0.746835 | \n",
" 0.445783 | \n",
" 0.480243 | \n",
" 0.278232 | \n",
" 0.266436 | \n",
" 0.570455 | \n",
" 0.972396 | \n",
" 1.000000 | \n",
" 0.570000 | \n",
" 0.939597 | \n",
" 0.888889 | \n",
" 1.000000 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" fixed_acidity volatile_acidity citric_acid residual_sugar \\\n",
"count 325.000000 325.000000 325.000000 325.000000 \n",
"mean 0.458355 0.211412 0.198091 0.093521 \n",
"std 0.084621 0.098749 0.086862 0.079346 \n",
"min 0.295597 0.056962 0.000000 0.012158 \n",
"25% 0.402516 0.145570 0.156627 0.030395 \n",
"50% 0.446541 0.183544 0.186747 0.069149 \n",
"75% 0.490566 0.253165 0.240964 0.133739 \n",
"max 0.943396 0.746835 0.445783 0.480243 \n",
"\n",
" chlorides free_sulfur_dioxide total_sulfur_dioxide density \\\n",
"count 325.000000 325.000000 325.000000 325.000000 \n",
"mean 0.086537 0.107596 0.271556 0.957757 \n",
"std 0.035141 0.061805 0.135185 0.003031 \n",
"min 0.019640 0.010381 0.018182 0.950413 \n",
"25% 0.063830 0.055363 0.179545 0.955456 \n",
"50% 0.078560 0.100346 0.284091 0.957978 \n",
"75% 0.098200 0.155709 0.370455 0.960028 \n",
"max 0.278232 0.266436 0.570455 0.972396 \n",
"\n",
" pH sulphates alcohol quality is_red \n",
"count 325.000000 325.000000 325.000000 325.000000 325.000000 \n",
"mean 0.802570 0.264446 0.701160 0.642393 0.753846 \n",
"std 0.044183 0.068086 0.084939 0.100957 0.431433 \n",
"min 0.715711 0.140000 0.563758 0.333333 0.000000 \n",
"25% 0.773067 0.215000 0.630872 0.555556 1.000000 \n",
"50% 0.800499 0.250000 0.684564 0.666667 1.000000 \n",
"75% 0.827930 0.305000 0.758389 0.666667 1.000000 \n",
"max 1.000000 0.570000 0.939597 0.888889 1.000000 "
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"wine_val.describe()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.utils.data import DataLoader, Dataset"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"class TabularDataset(Dataset):\n",
" def __init__(self, data):\n",
" self.data = data.values.astype('float32')\n",
"\n",
" def __getitem__(self, index):\n",
" x = torch.tensor(self.data[index, :-1])\n",
" y = torch.tensor(self.data[index, -1])\n",
" return x, y\n",
"\n",
" def __len__(self):\n",
" return len(self.data)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 64\n",
"train_dataset = TabularDataset(wine_train)\n",
"train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
"test_dataset = TabularDataset(wine_test)\n",
"test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"class TabularModel(nn.Module):\n",
" def __init__(self, input_dim, hidden_dim, output_dim):\n",
" super(TabularModel, self).__init__()\n",
" self.fc1 = nn.Linear(input_dim, hidden_dim)\n",
" self.relu = nn.ReLU()\n",
" self.fc2 = nn.Linear(hidden_dim, output_dim)\n",
" self.softmax = nn.Softmax(dim=1)\n",
" \n",
" def forward(self, x):\n",
" out = self.fc1(x)\n",
" out = self.relu(out)\n",
" out = self.fc2(out)\n",
" out = self.softmax(out)\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"input_dim = wine_train.shape[1] - 1\n",
"hidden_dim = 32\n",
"output_dim = 2\n",
"model = TabularModel(input_dim, hidden_dim, output_dim)\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(model.parameters())"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"model = TabularModel(input_dim=len(wine_train.columns)-1, hidden_dim=32, output_dim=2)\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.01)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1, loss: 0.5358\n",
"Epoch 3, loss: 0.3417\n",
"Epoch 5, loss: 0.3344\n",
"Epoch 7, loss: 0.3338\n",
"Epoch 9, loss: 0.3318\n",
"Finished Training\n"
]
}
],
"source": [
"num_epochs = 10\n",
"for epoch in range(num_epochs):\n",
" running_loss = 0.0\n",
" for i, data in enumerate(train_dataloader, 0):\n",
" inputs, labels = data\n",
" labels = labels.type(torch.LongTensor)\n",
" optimizer.zero_grad()\n",
" outputs = model(inputs)\n",
" loss = criterion(outputs, labels)\n",
" loss.backward()\n",
" optimizer.step()\n",
" running_loss += loss.item()\n",
"\n",
" # Print the loss every 1000 mini-batches\n",
" if (epoch%2) == 0:\n",
" print(f'Epoch {epoch + 1}, loss: {running_loss / len(train_dataloader):.4f}')\n",
"\n",
"print('Finished Training')"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy on test set: 98 %\n"
]
}
],
"source": [
"correct = 0\n",
"total = 0\n",
"with torch.no_grad():\n",
" for data in test_dataloader:\n",
" inputs, labels = data\n",
" outputs = model(inputs.float())\n",
" _, predicted = torch.max(outputs.data, 1)\n",
" total += labels.size(0)\n",
" correct += (predicted == labels).sum().item()\n",
"print('Accuracy on test set: %d %%' % (100 * correct / total))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}