polynomial_regression/05_Regresja_wielomianowa.ipynb

225 lines
239 KiB
Plaintext
Raw Normal View History

2022-06-12 17:45:18 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Regresja wielomianowa"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import ipywidgets as widgets\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Przydatne funkcje\n",
"\n",
"def cost(theta, X, y):\n",
"    \"\"\"Least-squares cost J(theta), matrix version.\n",
"\n",
"    Relies on `*` being the matrix product, i.e. assumes `theta`, `X` and\n",
"    `y` are `np.matrix` objects. Returns the cost as a plain scalar.\n",
"    \"\"\"\n",
"    residual = X * theta - y\n",
"    squared_error = residual.T * residual\n",
"    scale = 1.0 / (2.0 * len(y))\n",
"    return (scale * squared_error).item()\n",
"\n",
"def gradient(theta, X, y):\n",
"    \"\"\"Gradient of the least-squares cost, matrix version.\n",
"\n",
"    Relies on `*` being the matrix product, i.e. assumes `theta`, `X` and\n",
"    `y` are `np.matrix` objects. The result has the same shape as `theta`.\n",
"    \"\"\"\n",
"    residual = X * theta - y\n",
"    return 1.0 / len(y) * (X.T * residual)\n",
"\n",
"def gradient_descent(fJ, fdJ, theta, X, y, alpha=0.1, eps=10**-7):\n",
"    \"\"\"Batch gradient descent (matrix version).\n",
"\n",
"    Repeats `theta = theta - alpha * fdJ(theta, X, y)` until the cost changes\n",
"    by at most `eps` between two consecutive steps.\n",
"\n",
"    Args:\n",
"        fJ: cost function, called as `fJ(theta, X, y)`.\n",
"        fdJ: gradient function, called as `fdJ(theta, X, y)`.\n",
"        theta: starting parameters.\n",
"        X, y: passed unchanged to `fJ` and `fdJ`.\n",
"        alpha: step size.\n",
"        eps: convergence threshold on the change in cost.\n",
"\n",
"    Returns:\n",
"        `(theta, logs)` where `logs` is a list of `[cost, theta]` pairs: the\n",
"        starting point plus every step that did not end the loop.\n",
"    \"\"\"\n",
"    current_cost = fJ(theta, X, y)\n",
"    logs = [[current_cost, theta]]\n",
"    while True:\n",
"        theta = theta - alpha * fdJ(theta, X, y)\n",
"        current_cost, prev_cost = fJ(theta, X, y), current_cost\n",
"        cost_change = abs(prev_cost - current_cost)\n",
"        # A NaN change (NaN cost, or inf - inf) makes both comparisons below\n",
"        # False, which used to loop forever; treat it as divergence.\n",
"        if not np.isfinite(cost_change) or cost_change > 10**15:\n",
"            print('Algorithm does not converge!')\n",
"            break\n",
"        if cost_change <= eps:\n",
"            break\n",
"        logs.append([current_cost, theta])\n",
"    return theta, logs\n",
"\n",
"def plot_data(X, y, xlabel, ylabel):\n",
"    \"\"\"Scatter plot of the data (matrix version).\n",
"\n",
"    Draws column 1 of `X` against `y` as red points on a new figure, labels\n",
"    both axes, pads the axis limits by 1 on each side and returns the figure.\n",
"    \"\"\"\n",
"    fig = plt.figure(figsize=(16*.6, 9*.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"\n",
"    feature = X[:, 1]\n",
"    ax.scatter([feature], [y], c='r', s=50, label='Dane')\n",
"\n",
"    ax.set_xlabel(xlabel)\n",
"    ax.set_ylabel(ylabel)\n",
"    ax.margins(.05, .05)\n",
"    # `ax` is the only axes of the figure just created, so setting the limits\n",
"    # on it directly matches going through the pyplot current-axes state.\n",
"    ax.set_ylim(y.min() - 1, y.max() + 1)\n",
"    ax.set_xlim(np.min(feature) - 1, np.max(feature) + 1)\n",
"    return fig\n",
"\n",
"def plot_fun(fig, fun, X):\n",
"    \"\"\"Draw the function `fun` on the first axes of `fig`.\n",
"\n",
"    `fun` is sampled with step 0.1 between the minimum of column 1 of `X`\n",
"    minus 1 and its maximum plus 1. Returns whatever `ax.plot` returns.\n",
"    \"\"\"\n",
"    feature = X[:, 1]\n",
"    lower = np.min(feature) - 1.0\n",
"    upper = np.max(feature) + 1.0\n",
"    grid = np.arange(lower, upper, 0.1)\n",
"    return fig.axes[0].plot(grid, fun(grid), linewidth='2')"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Load the flats data set with pandas\n",
"\n",
"alldata = pandas.read_csv(\n",
"    'data_flats.tsv',\n",
"    sep='\\t',\n",
"    header=0,\n",
"    usecols=['price', 'rooms', 'sqrMetres'],\n",
")\n",
"# Feature column (sqrMetres) first, target column (price) last\n",
"data = np.matrix(alldata.loc[:, ['sqrMetres', 'price']])"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [],
"source": [
"# Funkcja regresji wielomianowej\n",
"\n",
"def h_poly(Theta, x):\n",
"    \"\"\"Evaluate the polynomial whose coefficients are stored in `Theta`.\n",
"\n",
"    Entry `i` of `Theta.tolist()` is used as the coefficient of `x**i`.\n",
"    \"\"\"\n",
"    total = 0\n",
"    for power, coefficient in enumerate(Theta.tolist()):\n",
"        total = total + coefficient * np.power(x, power)\n",
"    return total\n",
"\n",
"def get_poly_data(data, deg):\n",
"    \"\"\"Build the polynomial design matrix and the target vector.\n",
"\n",
"    `data` is an (m, n + 1) matrix: the first n columns are features and the\n",
"    last column is the target. Each feature column is divided by its column\n",
"    maximum. Powers 2..deg of the scaled features are then appended, each\n",
"    again divided by its column maximum, after a leading column of ones.\n",
"\n",
"    `data` itself is not modified.\n",
"\n",
"    Returns:\n",
"        `(X, y)`: `X` is an `np.matrix` of shape (m, deg * n + 1) and `y` an\n",
"        `np.matrix` of shape (m, 1).\n",
"    \"\"\"\n",
"    m, n_plus_1 = data.shape\n",
"    n = n_plus_1 - 1\n",
"\n",
"    # Divide out of place: `data[:, 0:n]` is a view, so `/=` would overwrite\n",
"    # the caller's data (and raise on integer input).\n",
"    features = data[:, 0:n]\n",
"    X1 = features / np.amax(features, axis=0)\n",
"\n",
"    Xs = [np.ones((m, 1)), X1]\n",
"\n",
"    for i in range(2, deg + 1):\n",
"        Xn = np.power(X1, i)\n",
"        Xs.append(Xn / np.amax(Xn, axis=0))\n",
"\n",
"    X = np.matrix(np.concatenate(Xs, axis=1)).reshape(m, deg * n + 1)\n",
"\n",
"    y = np.matrix(data[:, -1]).reshape(m, 1)\n",
"\n",
"    return X, y\n",
"\n",
"\n",
"def polynomial_regression(theta, X, y, n):\n",
"    \"\"\"Fit a polynomial by gradient descent and return it as a callable.\n",
"\n",
"    NOTE: the `theta` argument is not used; training always starts from an\n",
"    (n + 1, 1) vector of zeros. The returned function evaluates the fitted\n",
"    polynomial through `h_poly`.\n",
"    \"\"\"\n",
"    start = np.matrix([0] * (n + 1)).reshape(n + 1, 1)\n",
"    fitted_theta, _ = gradient_descent(cost, gradient, start, X, y)\n",
"\n",
"    def fitted_polynomial(x):\n",
"        return h_poly(fitted_theta, x)\n",
"\n",
"    return fitted_polynomial"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x23797dfe880>]"
]
},
"metadata": {},
"execution_count": 85
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 691.2x388.8 with 1 Axes>",
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\r\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n<!-- Created with matplotlib (https://matplotlib.org/) -->\r\n<svg height=\"366.394687pt\" version=\"1.1\" viewBox=\"0 0 611.892812 366.394687\" width=\"611.892812pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n <metadata>\r\n <rdf:RDF xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\r\n <cc:Work>\r\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\r\n <dc:date>2022-06-12T17:44:07.060976</dc:date>\r\n <dc:format>image/svg+xml</dc:format>\r\n <dc:creator>\r\n <cc:Agent>\r\n <dc:title>Matplotlib v3.3.4, https://matplotlib.org/</dc:title>\r\n </cc:Agent>\r\n </dc:creator>\r\n </cc:Work>\r\n </rdf:RDF>\r\n </metadata>\r\n <defs>\r\n <style type=\"text/css\">*{stroke-linecap:butt;stroke-linejoin:round;}</style>\r\n </defs>\r\n <g id=\"figure_1\">\r\n <g id=\"patch_1\">\r\n <path d=\"M 0 366.394687 \r\nL 611.892812 366.394687 \r\nL 611.892812 0 \r\nL 0 0 \r\nz\r\n\" style=\"fill:none;\"/>\r\n </g>\r\n <g id=\"axes_1\">\r\n <g id=\"patch_2\">\r\n <path d=\"M 43.78125 328.838437 \r\nL 596.74125 328.838437 \r\nL 596.74125 17.798437 \r\nL 43.78125 17.798437 \r\nz\r\n\" style=\"fill:#ffffff;\"/>\r\n </g>\r\n <g id=\"PathCollection_1\">\r\n <defs>\r\n <path d=\"M 0 3.535534 \r\nC 0.937635 3.535534 1.836992 3.163008 2.5 2.5 \r\nC 3.163008 1.836992 3.535534 0.937635 3.535534 0 \r\nC 3.535534 -0.937635 3.163008 -1.836992 2.5 -2.5 \r\nC 1.836992 -3.163008 0.937635 -3.535534 0 -3.535534 \r\nC -0.937635 -3.535534 -1.836992 -3.163008 -2.5 -2.5 \r\nC -3.163008 -1.836992 -3.535534 -0.937635 -3.535534 0 \r\nC -3.535534 0.937635 -3.163008 1.836992 -2.5 2.5 \r\nC -1.836992 3.163008 -0.937635 3.535534 0 3.535534 \r\nz\r\n\" id=\"m9649fa14e8\" 
style=\"stroke:#ff0000;\"/>\r\n </defs>\r\n <g clip-path=\"url(#pdadd225bfd)\">\r\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"295.30125\" xlink:href=\"#m9649fa14e8\" y=\"283.843067\"/>\r\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"279.94125\" xlink:href=\"#m9649fa14e8\" y=\"285.827381\"/>\r\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"234.82125\" xlink:href=\"#m9649fa14e8\" y=\"291.566543\"/>\r\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"233.86125\" xlink:href=\"#m9649fa14e8\" y=\"281.414803\"/>\r\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"234.82125\" xlink:href=\"#m9649fa14e8\" y=\"292.227502\"/>\r\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"297.22125\" xlink:href=\"#m9649fa14e8\" y=\"286.96447\"/>\r\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"257.86125\" xlink:href=\"#m9649fa14e8\" y=\"272.464856\"/>\r\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"272.26125\" xlink:href=\"#m9649fa14e8\" y=\"301.921415\"/>\r\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"306.82125\" xlink:href=\"#m9649fa14e8\" y=\"308.712969\"/>\r\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"258.82125\" xlink:href=\"#m9649fa14e8\" y=\"268.444445\"/>\r\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"259.78125\" xlink:href=\"#m9649fa14e8\" y=\"313.041087\"/>\r\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"305.86125\" xlink:href=\"#m9649fa14e8\" y=\"293.020175\"/>\r\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"271.30125\" xlink:href=\"#m9649fa14e8\" y=\"304.225622\"/>\r\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"256.90125\" xlink:href=\"#m9649fa14e8\" y=\"277.160655\"/>\r\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"294.34125\" xlink:href=\"#m9649fa14e8\" y=\"306.292122\"/>\r\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"265.54125\" xlink:href=\"#m9649fa14e8\" y=\"304.314029\"/>\r\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"249.22125\" xlink:href=\"#m9649fa14e8\" y=\"291.752
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmQAAAFvCAYAAADkPtfiAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAABWIklEQVR4nO3dd3hUZfr/8fcz6Y2eAAm9CAIKSBNFRdeKunbBXti17y7ruj91d7/bXV23uNa1YMOKvazYG6J0BCS0hB5CCYQySUid8/vjZEgIM8lMMpMzmfm8rotrwjknZ24mQ3LnKfdtLMtCRERERJzjcjoAERERkVinhExERETEYUrIRERERBymhExERETEYUrIRERERBymhExERETEYW0yITPGPGOM2WmMWRHg9ZcaY1YaY3KNMS+HOz4RERGRYJi2WIfMGHMiUALMsCxrWBPXDgReA06xLGuPMSbLsqydrRGniIiISCDa5AiZZVmzgeL6x4wx/Y0xHxljFhtjvjHGDK499VPgUcuy9tR+rpIxERERiShtMiHz40ngZ5ZljQLuAB6rPX4EcIQx5ltjzDxjzJmORSgiIiLiQ7zTAYSCMSYdOA543RjjPZxU+xgPDAQmAj2Ab4wxwyzL2tvKYYqIiIj4FBUJGfZI317Lskb4OFcAzLMsqwrYYIxZg52gLWzF+ERERET8ioopS8uy9mMnW5cAGNvw2tPvACfXHu+CPYW53ok4RURERHxpkwmZMeYVYC4wyBhTYIyZClwBTDXGLANygfNqL/8Y2G2MWQl8CfzasqzdTsQtIiIi4kubLHshIiIiEk3a5AiZiIiISDRRQiYiIiLisDa3y7JLly5Wnz59ACjce4DdpZV0Tk8ku32Ks4GJSOsqKIAdO/yf79YNcnJaLx6JGVU1Fqu378cAR3ZvR5zLNPk5EhsWL168y7KszOZ8bptLyPr06cOiRYsA+KFgH+c+ModOaYl8d/ePSIzXgJ9IzJg+HaZNg9LSw8+lpcFf/wpTp7Z6WBL9npq9nntmreLMod14/KpRTocjEcQYs6m5n9umM5hhOe0Y1DWD4tJKvlyjjkgiMWXyZHD5+RbmctnnRcLg7e+3AnD+SI3ASui06YTMGMPFo3oA8MbiAoejEZFWlZEBs2bZj2lp9rG0tLrj6enOxidRac12Nyu37addcjwnD27WzJSIT21uyrKh80Zmc99Hq/ly9U52lVTQJT2p6U8SkegwYQIUFsLMmZCfDwMG2CNjSsYkTN5Zao+OnX10NknxcQ5HI9GkzSdkWRnJTDwik89X7+TdpYVMndDX6ZBEpDWlp2utmLQKj8fi3drpygs0XSkh1qanLL00bSkiIuE2f0MxhfvKyemQwujeHZ0OR6JMVCRkpxyZRYfUBFZt209u4T6nwxERkSj0Tr3RMZdKXUiIRUVClhQfx3nDswGNkomISOiVV9Uw64dtAJw/MtvhaCQaRUVCBnDxqJ4AvLu0kMpqj8PRiIhINPli9U7cFdUcldOeAVkZTocjUShqEjLVJBMRkXBR7TEJt6hJyFSTTEREwmFPaSVfrdmJy8C5w7s7HY5EqahJyMCuSRbnMgdrkomIiLTU/37YRlWNxYSBmWRlJDsdjkSpqErIvDXJqj0W7y4tdDocERGJAt7dlRdqulLCKKoSMoBLRmvaUkREQmPz7jIWb9pDamIcpw/t6nQ4EsWiLiE7ZXBXOqommYiIhIC3VdIZQ7uRmtjmm9tIBIu6hCwx3sV5I+xhZY2SiYhIc1mWxVtL7J8j2l0p4RZ1CRnUtVJSTTIREWmueeuL2bi7jG7tkjm+f2enw5EoF5UJ2dDsdgzuZtck+2K1apKJiEjwXlmwGYBLR/cgPi4qf1xKBInKd5hqkomISEvsKa3koxXbMQYuHdPT6XAkBoQtITPG9DTGfGmMWWWMyTXG/MLHNRONMfuMMUtr//w+VM9/3ogcuybZmp0UuVWTTEREAvfW91uprP
FwwsBMenRMdTociQHhHCGrBn5lWdaRwLHArcaYIT6u+8ayrBG1f/4cqifPzEji5EGZ1Hgs3q3dJSMiItIUy7J4tXa68vKxGh2T1hG2hMyyrG2WZS2p/dgNrAJadZtK/WlLy7Ja86lFRKSNWrJ5D3k7S+iSnsSPjlTtMWkdrbKGzBjTBxgJzPdxerwxZpkx5kNjzNBQPq+3Jtnq7W5yC/eH8tYiIhKlXlmwBbB/qU/QYn5pJWF/pxlj0oE3gWmWZTXMipYAvS3LGg48DLzj5x43GGMWGWMWFRUVBfzcqkkmIiLB2F9exf+W2633pmgxv7SisCZkxpgE7GTsJcuy3mp43rKs/ZZlldR+PAtIMMZ08XHdk5ZljbYsa3RmZmZQMdTVJNuqmmQiItKod5cWUl7lYXy/zvTpkuZ0OBJDwrnL0gBPA6ssy/q3n2u61V6HMWZsbTy7QxmHtybZnrIq1SQTERG/LMvilfn2Yv4pWswvrSycI2THA1cBp9QrazHJGHOTMeam2msuBlYYY5YBDwFTrBCvvldNMhERCcQPW/exctt+OqQmcMbQbk6HIzEmbJ1SLcuaA5gmrnkEeCRcMXidPzKH+z5cfbAmWWZGUrifUkRE2hjvYv4LR/YgOSHO4Wgk1sTE9pEu6UlMHJSlmmQiIuJTaUU179X+fLhM05XigJhIyEA1yURExL//LS+ktLKG0b07MrBrhtPhSAyKmYTslMFZqkkmIiI+eacrp4zt5XAkEqtiJiFTTTIREfFl9fb9LN2yl4zkeM4+qrvT4UiMipmEDFSTTEREDvdq7ejY+SNySEnUYn5xRkwlZKpJJiIi9ZVX1fDWEnvWRLXHxEkxlZCpJpmIiNT34Ypt7C+v5uge7Rma3d7pcCSGxVRCBnZNsniXOViTTEREYtcr82sX84/RYn5xVswlZKpJJiIiAPk7S1iwsZjUxDh+PCLb6XAkxsVcQgZ1i/tfX6SaZCIisWrmQrtv5Y+HZ5OeFLbGNSIBicmE7JTBWXROS2TNDjffb9nrdDgiItLKKqpreHOJPUui2mMSCWIyIUuMd3HpGHs3zYtzNzkcjYiItLZPV+6guLSSwd0yGN5Di/nFeTGZkAFcPrYXxsD/lm+juLTS6XBERKQVeWuPXTa2F8YYh6MRieGErGenVE4elEVljYfXFm1xOhwREWklm3eXMSd/F0nxLs6v7eAi4rSYTcgArjq2NwAvzd9EjUeL+0VEYsHMRfZi/rOP6k771ASHoxGxxXRCduIRmfTslMKW4gPMXlvkdDgiIhJmVTUeXl/krcyvxfwSOWI6IYtzGa4YZ4+SvTBPi/tFRKLdF6t3stNdQf/MNMb06eh0OCIHxXRCBnDp6J4kxrv4cs1OthSXOR2OiIiE0asL7OlKLeaXSBPzCVmntETOOao7lgUvzd/sdDgiIhImhXsP8PXaIhLjXFx4TA+nwxE5RMwnZABXjrenLV9btIWK6hqHoxERkXB4bdEWPBacPrQrndISnQ5H5BBKyICRPTswNLsdxaWVfPjDdqfDERGREKvxWLy2sK72mEikUUIGGGO48lgt7hcRiVaz84oo3FdOr06pjO/X2elwRA6jhKzWeSOyyUiKZ/GmPeQW7nM6HBERCSHvYv7JY3ricmkxv0QeJWS1UhPjuWiUvcjzxXla3C8iEi12usv5fNVO4lyGS0ZpMb9EJiVk9XinLd/5fiv7y6scjkZERELhjcUFVHssTj0yi6x2yU6HI+KTErJ6BmSlc1z/zhyoquGtxQVOhyMiIi3k8VgHG4mrMr9EMiVkDVxVb3G/Zam/pYhIWzZ3/W42F5eR0yGFEwdmOh2OiF9KyBo4dUhXurZLYl1RKXPX73Y6HBERaYFXahfzXzK6B3FazC8RTAlZAwlxroM1al5UCQwRkTaruLSST3J34DJ2mzyRSKaEzIfLxvYizmX4OHcHO/aXOx2OiIg0w1tLCqis8XDSEZlkd0hxOhyRRikh86Fru2ROH9
KVGo91cLhbRETaDsuq+/6txfzSFigh88O7uP+VBZupqvE4HI2IiARj0aY9rCsqJTMjiVMGZzkdjkiTlJD5Mb5/Z/p
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"n = 2\n",
"x, y = get_poly_data(data, n)\n",
"# No earlier cell in this notebook defines `theta`; create the start vector\n",
"# here so the cell also runs on a fresh kernel (Restart & Run All).\n",
"theta = np.matrix(np.zeros((n + 1, 1)))\n",
"fig = plot_data(x, y, xlabel='x', ylabel='y')\n",
"plot_fun(fig, polynomial_regression(theta, x, y, n), x)"
]
}
],
"metadata": {
"author": "Paweł Skórzewski",
"celltoolbar": "Slideshow",
"email": "pawel.skorzewski@amu.edu.pl",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"lang": "pl",
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8-final"
},
"livereveal": {
"start_slideshow_at": "selected",
"theme": "white"
},
"subtitle": "5.Regresja wielomianowa. Problem nadmiernego dopasowania[wykład]",
"title": "Uczenie maszynowe",
"year": "2021"
},
"nbformat": 4,
"nbformat_minor": 4
}