{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Machine Learning UMZ 2017/2018\n", "# 1. Introduction. Linear regression" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## 1.4 Cost function" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Task\n", "Given $x$, the population of a city (in tens of thousands of inhabitants),\n", "predict $y$, the revenue of a transport company (in tens of thousands of dollars).\n", "\n", "(The data come from the “Machine Learning” course by Andrew Ng, Coursera.)" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true, "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Imports and notebook setup; can be skipped\n", "\n", "import numpy as np\n", "import matplotlib\n", "import matplotlib.pyplot as pl\n", "import ipywidgets as widgets\n", "\n", "%matplotlib inline\n", "%config InlineBackend.figure_format = 'svg'\n", "\n", "from IPython.display import display, Math, Latex" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Data" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "6.1101,17.592\n", "5.5277,9.1302\n", "8.5186,13.662\n", "7.0032,11.854\n", "5.8598,6.8233\n", "8.3829,11.886\n", "7.4764,4.3483\n", "8.5781,12\n", "6.4862,6.5987\n", "5.0546,3.8166\n" ] } ], "source": [ "# Print the first 10 lines of the raw CSV file\n", "with open('data01.csv') as data:\n", "    for line in data.readlines()[:10]:\n", "        print(line.rstrip())" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Loading the data" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x = [6.1101, 5.5277, 8.5186, 7.0032, 5.8598, 8.3829, 7.4764, 8.5781, 6.4862, 5.0546]\n", "y = [17.592, 9.1302, 13.662, 11.854, 6.8233, 11.886, 4.3483, 12.0, 6.5987, 3.8166]\n" ] } ], "source": [ "import csv\n", "\n", "x = list()\n", "y = list()\n", "\n", "# Each row holds one example: x (population), y (revenue)\n", "with open('data01.csv') as data:\n", "    reader = csv.reader(data, delimiter=',')\n", "    for xi, yi in reader:\n", "        x.append(float(xi))\n", "        y.append(float(yi))\n", "\n", "print('x = {}'.format(x[:10]))\n", "print('y = {}'.format(y[:10]))" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Hypothesis and model parameters" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x $$\n", "$$ \\theta = \\left[\\begin{array}{c}\\theta_0\\\\ \\theta_1\\end{array}\\right] $$\n" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [ { "data": { "text/plain": [ "[Figure: scatter plot of the data; x axis: city population, y axis: company revenue]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Helper functions: scatter plot of the data and a regression line\n", "\n", "def regdots(x, y):\n", "    fig = pl.figure(figsize=(16*.6, 9*.6))\n", "    ax = fig.add_subplot(111)\n", "    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n", "    ax.scatter(x, y, c='r', s=50, label='Data')\n", "\n", "    ax.set_xlabel(u'City population [tens of thousands]')\n", "    ax.set_ylabel(u'Company revenue [tens of thousands of dollars]')\n", "    ax.margins(.05, .05)\n", "    pl.ylim(min(y) - 1, max(y) + 1)\n", "    pl.xlim(min(x) - 1, max(x) + 1)\n", "    return fig\n", "\n", "def regline(fig, fun, theta, x):\n", "    ax = fig.axes[0]\n", "    x0, x1 = min(x), max(x)\n", "    X = [x0, x1]\n", "    Y = [fun(theta, xi) for xi in X]\n", "    ax.plot(X, Y, linewidth=2,\n", "            label=(r'$y={theta0}{op}{theta1}x$'.format(\n", "                theta0=theta[0],\n", "                theta1=(theta[1] if theta[1] >= 0 else -theta[1]),\n", "                op='+' if theta[1] >= 0 else '-')))\n", "\n", "def legend(fig):\n", "    ax = fig.axes[0]\n", "    handles, labels = ax.get_legend_handles_labels()\n", "    # try-except block is a fix for a bug in Poly3DCollection\n", "    try:\n", "        fig.legend(handles, labels, fontsize='15', loc='lower right')\n", "    except AttributeError:\n", "        pass\n", "\n", "fig = regdots(x, y)\n", "legend(fig)" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true, "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Hypothesis: a linear function of one variable\n", "\n", "def h(theta, x):\n", "    return theta[0] + theta[1] * x" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true, "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Interactive plot setup\n", "\n", "sliderTheta01 = widgets.FloatSlider(min=-10, max=10, step=0.1, value=0, description=r'$\\theta_0$', width=300)\n", "sliderTheta11 = widgets.FloatSlider(min=-5, max=5, step=0.1, value=0, description=r'$\\theta_1$', width=300)\n", "\n", "def slide1(theta0, theta1):\n", "    fig = regdots(x, y)\n", "    regline(fig, h, [theta0, theta1], x)\n", "    legend(fig)" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "58798d3c3e2c4f14b9fc02d62f47f1da", "version_major": 2, "version_minor": 0 }, "text/plain": [ "interactive(children=(FloatSlider(value=0.0, description='$\\\\theta_0$', max=10.0, min=-10.0), FloatSlider(value=0.0, description='$\\\\theta_1$', max=5.0, min=-5.0), Button(description='Run Interact', style=ButtonStyle()), Output()), _dom_classes=('widget-interact',))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "widgets.interact_manual(slide1, theta0=sliderTheta01, theta1=sliderTheta11)" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Cost function" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "We look for the parameters $\\theta = \\left[\\begin{array}{c}\\theta_0\\\\ \\theta_1\\end{array}\\right]$ that minimize the cost function $J(\\theta)$:" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "$$ \\hat\\theta = \\mathop{\\arg\\min}_{\\theta} J(\\theta) $$" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "$$ \\theta \\in \\mathbb{R}^2, \\quad J \\colon \\mathbb{R}^2 \\to \\mathbb{R} $$" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Mean squared error\n", "#### (least squares method)" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "$$ J(\\theta) \\, = \\, \\frac{1}{2m} \\sum_{i = 1}^{m} \\left( h_{\\theta} \\left( x^{(i)} \\right) - y^{(i)} \\right) ^2 $$\n", "\n", "where $m$ is the number of training examples." ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "$$ J(\\theta_0, \\theta_1) \\, = \\, \\frac{1}{2m} \\sum_{i = 1}^{m} \\left( \\theta_0 + \\theta_1 x^{(i)} - y^{(i)} \\right) ^2 $$" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": 
{ "collapsed": true, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "def J(h, theta, x, y):\n", "    # Half of the mean squared error of hypothesis h on the data (x, y)\n", "    m = len(y)\n", "    return 1.0 / (2 * m) * sum((h(theta, x[i]) - y[i])**2 for i in range(m))" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": true, "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Compute the value of the cost function and show it on the plot\n", "\n", "def regline(fig, fun, theta, x, y):\n", "    ax = fig.axes[0]\n", "    x0, x1 = min(x), max(x)\n", "    X = [x0, x1]\n", "    Y = [fun(theta, xi) for xi in X]\n", "    cost = J(fun, theta, x, y)\n", "    ax.plot(X, Y, linewidth=2,\n", "            label=(r'$y={theta0}{op}{theta1}x, \\; J(\\theta)={cost:.3}$'.format(\n", "                theta0=theta[0],\n", "                theta1=(theta[1] if theta[1] >= 0 else -theta[1]),\n", "                op='+' if theta[1] >= 0 else '-',\n", "                cost=cost)))\n", "\n", "sliderTheta02 = widgets.FloatSlider(min=-10, max=10, step=0.1, value=0, description=r'$\\theta_0$', width=300)\n", "sliderTheta12 = widgets.FloatSlider(min=-5, max=5, step=0.1, value=0, description=r'$\\theta_1$', width=300)\n", "\n", "def slide2(theta0, theta1):\n", "    fig = regdots(x, y)\n", "    regline(fig, h, [theta0, theta1], x, y)\n", "    legend(fig)" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c380a0076fb4480a9f9a0eb4f5fba756", "version_major": 2, "version_minor": 0 }, "text/plain": [ "interactive(children=(FloatSlider(value=0.0, description='$\\\\theta_0$', max=10.0, min=-10.0), FloatSlider(value=0.0, description='$\\\\theta_1$', max=5.0, min=-5.0), Button(description='Run Interact', style=ButtonStyle()), Output()), _dom_classes=('widget-interact',))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "widgets.interact_manual(slide2, theta0=sliderTheta02, theta1=sliderTheta12)" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### The cost function as a function of $\\theta$" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true, "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Plot of the cost function over theta_0 for a fixed theta_1 (1.0 by default)\n", "\n", "def costfun(fun, x, y):\n", "    return lambda theta: J(fun, theta, x, y)\n", "\n", "def costplot(hypothesis, x, y, theta1=1.0):\n", "    fig = pl.figure(figsize=(16*.6, 9*.6))\n", "    ax = fig.add_subplot(111)\n", "    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n", "    ax.set_xlabel(r'$\\theta_0$')\n", "    ax.set_ylabel(r'$J(\\theta)$')\n", "    j = costfun(hypothesis, x, y)\n", "    fun = lambda theta0: j([theta0, theta1])\n", "    X = np.arange(-10, 10, 0.1)\n", "    Y = [fun(theta0) for theta0 in X]\n", "    ax.plot(X, Y, linewidth=2, label=(r'$J(\\theta_0, {theta1})$'.format(theta1=theta1)))\n", "    return fig\n", "\n", "def slide3(theta1):\n", "    fig = costplot(h, x, y, theta1)\n", "    legend(fig)\n", "\n", "sliderTheta13 = widgets.FloatSlider(min=-5, max=5, step=0.1, value=1.0, description=r'$\\theta_1$', width=300)" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1b675985144a43d0997c79ada18bb446", "version_major": 2, "version_minor": 
0 }, "text/plain": [ "interactive(children=(FloatSlider(value=1.0, description='$\\\\theta_1$', max=5.0, min=-5.0), Button(description='Run Interact', style=ButtonStyle()), Output()), _dom_classes=('widget-interact',))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "widgets.interact_manual(slide3, theta1=sliderTheta13)" ] },
{ "cell_type": "code", "execution_count": 27, "metadata": { "collapsed": true, "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Plot of the cost function over theta_0 and theta_1\n", "\n", "from mpl_toolkits.mplot3d import Axes3D  # needed on older matplotlib to register the '3d' projection\n", "\n", "def costplot3d(hypothesis, x, y, show_gradient=False):\n", "    fig = pl.figure(figsize=(16*.6, 9*.6))\n", "    ax = fig.add_subplot(111, projection='3d')\n", "    fig.subplots_adjust(left=0.0, right=1.0, bottom=0.0, top=1.0)\n", "    ax.set_xlabel(r'$\\theta_0$')\n", "    ax.set_ylabel(r'$\\theta_1$')\n", "    ax.set_zlabel(r'$J(\\theta)$')\n", "\n", "    X = np.arange(-10, 10.1, 0.1)\n", "    Y = np.arange(-1, 4.1, 0.1)\n", "    X, Y = np.meshgrid(X, Y)\n", "    # Cost at every grid point (theta_0, theta_1)\n", "    Z = np.array([[J(hypothesis, [theta0, theta1], x, y)\n", "                   for theta0, theta1 in zip(xRow, yRow)]\n", "                  for xRow, yRow in zip(X, Y)])\n", "\n", "    ax.plot_surface(X, Y, Z, rstride=2, cstride=8, linewidth=0.5,\n", "                    alpha=0.5, cmap='jet', zorder=0,\n", "                    label=r\"$J(\\theta)$\")\n", "    ax.view_init(elev=20., azim=-150)\n", "\n", "    ax.set_xlim3d(-10, 10)\n", "    ax.set_ylim3d(-1, 4)\n", "    ax.set_zlim3d(-100, 800)\n", "\n", "    N = range(0, 800, 20)\n", "    pl.contour(X, Y, Z, N, zdir='z', offset=-100, cmap='coolwarm', alpha=1)\n", "\n", "    ax.plot([-3.89578088] * 2,\n", "            [ 1.19303364] * 2,\n", "            [-100, 4.47697137598],\n", "            color='red', alpha=1, linewidth=1.3, zorder=100, linestyle='dashed',\n", "            label=r'minimum: $J(-3.90, 1.19) = 4.48$')\n", "    ax.scatter([-3.89578088] * 2,\n", "               [ 1.19303364] * 2,\n", "               [-100, 4.47697137598],\n", "               c='r', s=80, marker='x', alpha=1, linewidth=1.3, zorder=100,\n", "               label=r'minimum: $J(-3.90, 1.19) = 4.48$')\n", "\n", "    if show_gradient:\n", "        ax.plot([3.0, 1.1],\n", "                [3.0, 2.4],\n", "                [263.0, 125.0],\n", "                color='green', alpha=1, linewidth=1.3, zorder=100)\n", "        ax.scatter([3.0],\n", "                   [3.0],\n", "                   [263.0],\n", "                   c='g', s=30, marker='D', alpha=1, linewidth=1.3, zorder=100)\n", "\n", "    ax.margins(0, 0, 0)\n", "    fig.tight_layout()" ] },
{ "cell_type": "code", "execution_count": 28, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "text/plain": [ "[Figure: 3D surface and contour plot of J(theta) over theta_0 and theta_1, with the minimum marked]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "costplot3d(h, x, y)" ] },
{ "cell_type": "code", "execution_count": 29, "metadata": { "collapsed": true, "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "def costplot2d(hypothesis, x, y, gradient_values=(), nohead=False):\n", "    fig = pl.figure(figsize=(16*.6, 9*.6))\n", "    ax = fig.add_subplot(111)\n", "    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n", "    ax.set_xlabel(r'$\\theta_0$')\n", "    ax.set_ylabel(r'$\\theta_1$')\n", "\n", "    X = np.arange(-10, 10.1, 0.1)\n", "    Y = np.arange(-1, 4.1, 0.1)\n", "    X, Y = np.meshgrid(X, Y)\n", "    Z = np.array([[J(hypothesis, [theta0, theta1], x, y)\n", "                   for theta0, theta1 in zip(xRow, yRow)]\n", "                  for xRow, yRow in zip(X, Y)])\n", "\n", "    N = range(0, 800, 20)\n", "    pl.contour(X, Y, Z, N, cmap='coolwarm', alpha=1)\n", "\n", "    ax.scatter([-3.89578088], [1.19303364], c='r', s=80, marker='x',\n", "               label=r'minimum: $J(-3.90, 1.19) = 4.48$')\n", "\n", "    # Optionally draw a descent path given as a sequence of (cost, theta) pairs\n", "    if len(gradient_values) > 0:\n", "        prev_theta = gradient_values[0][1]\n", "        ax.scatter([prev_theta[0]], [prev_theta[1]],\n", "                   c='g', s=30, marker='D', zorder=100)\n", "        for cost, theta in gradient_values[1:]:\n", "            dtheta = [theta[0] - prev_theta[0], theta[1] - prev_theta[1]]\n", "            ax.arrow(prev_theta[0], prev_theta[1], dtheta[0], dtheta[1],\n", "                     color='green',\n", "                     head_width=(0.0 if nohead else 0.1),\n", "                     head_length=(0.0 if nohead else 0.2),\n", "                     zorder=100)\n", "            prev_theta = theta\n", "\n", "    return fig" ] },
{ "cell_type": "code", "execution_count": 31, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "text/plain": [ "[Figure: contour plot of J(theta) over theta_0 and theta_1, with the minimum marked]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig = costplot2d(h, x, y)\n", "legend(fig)" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Properties of the cost function\n", "* $J(\\theta)$ is a convex function\n", "* $J(\\theta)$ has only one local minimum, which is therefore also the global minimum" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## 1.5 Gradient descent" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Gradient descent\n", "A method for finding local minima." ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "Idea:\n", " * Start from an arbitrary $\\theta$.\n", " * Change $\\theta$ gradually so as to decrease $J(\\theta)$ until we reach a minimum." 
] }, { "cell_type": "code", "execution_count": 33, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "costplot3d(h, x, y, show_gradient=True)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "collapsed": true, "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Example values of successive approximations (artificial; the cost entries are placeholders)\n", "\n", "gv = [[_, [3.0, 3.0]], [_, [2.6, 2.4]], [_, [2.2, 2.0]], [_, [1.6, 1.6]], [_, [0.4, 1.2]]]\n", "\n", "# Prepare the interactive plot\n", "\n", "sliderSteps1 = widgets.IntSlider(min=0, max=3, step=1, value=0, description='kroki', width=300)\n", "\n", "def slide4(steps):\n", "    costplot2d(h, x, y, gradient_values=gv[:steps+1])" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e9b6665115b748edb457981078f0042c", "version_major": 2, "version_minor": 0 }, "text/html": [ "
\n" ], "text/plain": [ "interactive(children=(IntSlider(value=0, description='kroki', max=3), Output()), _dom_classes=('widget-interact',))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "widgets.interact(slide4, steps=sliderSteps1)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Gradient descent\n", "In each step we update the parameters $\\theta_j$:\n", "\n", "$$ \\theta_j := \\theta_j - \\alpha \\frac{\\partial}{\\partial \\theta_j} J(\\theta) \\quad \\mbox{ for every } j $$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "The coefficient $\\alpha$ is called the *step size* or the *learning rate*." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "$$ \\begin{array}{rcl}\n", "\\dfrac{\\partial}{\\partial \\theta_j} J(\\theta)\n", " & = & \\dfrac{\\partial}{\\partial \\theta_j} \\dfrac{1}{2m} \\displaystyle\\sum_{i = 1}^{m} \\left( h_{\\theta} \\left( x^{(i)} \\right) - y^{(i)} \\right) ^2 \\\\\n", " & = & 2 \\cdot \\dfrac{1}{2m} \\displaystyle\\sum_{i=1}^m \\left( h_\\theta \\left( x^{(i)} \\right) - y^{(i)} \\right) \\cdot \\dfrac{\\partial}{\\partial\\theta_j} \\left( h_\\theta \\left( x^{(i)} \\right) - y^{(i)} \\right) \\\\\n", " & = & \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_\\theta \\left( x^{(i)} \\right) - y^{(i)} \\right) \\cdot \\dfrac{\\partial}{\\partial\\theta_j} \\left( \\displaystyle\\sum_{k=0}^n \\theta_k x_k^{(i)} - y^{(i)} \\right)\\\\\n", " & = & \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_\\theta \\left( x^{(i)} \\right) -y^{(i)} \\right) x_j^{(i)} \\\\\n", "\\end{array} $$\n", "\n", "(with $x_0^{(i)} = 1$ for every $i$)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ 
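"A minimal sketch that checks the derived formula numerically. It compares the formula with a central finite-difference approximation of $\\frac{\\partial}{\\partial \\theta_j} J(\\theta)$. It uses a small hypothetical data set rather than the course data. In `X_toy` and `y_toy`, the first column of ones plays the role of $x_0 = 1$. `cost_toy`, `grad_toy` and `delta` are illustrative names. If the derivation is right, the last line should print `True`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Hypothetical toy data; the first column is x_0 = 1\n", "X_toy = np.array([[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])\n", "y_toy = np.array([2.0, 2.5, 4.0])\n", "\n", "def cost_toy(theta):\n", "    # J(theta) = 1/(2m) * sum of squared residuals\n", "    r = X_toy.dot(theta) - y_toy\n", "    return r.dot(r) / (2.0 * len(y_toy))\n", "\n", "def grad_toy(theta):\n", "    # derived formula: (1/m) * sum_i (h(x^(i)) - y^(i)) * x_j^(i), for all j at once\n", "    return X_toy.T.dot(X_toy.dot(theta) - y_toy) / len(y_toy)\n", "\n", "theta_test = np.array([0.5, -1.0])\n", "delta = 1e-6\n", "# central differences along each coordinate direction e (rows of the identity matrix)\n", "numeric = np.array([(cost_toy(theta_test + delta * e) - cost_toy(theta_test - delta * e)) / (2 * delta)\n", "                    for e in np.eye(2)])\n", "\n", "print(grad_toy(theta_test))\n", "print(numeric)\n", "print(np.allclose(grad_toy(theta_test), numeric))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ 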
"Czyli dla regresji liniowej jednej zmiennej:\n", "\n", "$$ h_\\theta(x) = \\theta_0 + \\theta_1x $$\n", "\n", "w każdym kroku będziemy aktualizować:\n", "\n", "$$\n", "\\begin{array}{rcl}\n", "\\theta_0 & := & \\theta_0 - \\alpha \\, \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_\\theta(x^{(i)})-y^{(i)} \\right) \\\\ \n", "\\theta_1 & := & \\theta_1 - \\alpha \\, \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_\\theta(x^{(i)})-y^{(i)} \\right) x^{(i)}\\\\ \n", "\\end{array}\n", "$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "###### Uwaga!\n", " * W każdym kroku aktualizujemy *jednocześnie* $\\theta_0$ i $\\theta_1$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ " * Kolejne kroki wykonujemy aż uzyskamy zbieżność" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Metoda gradientu prostego – implementacja" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "collapsed": true, "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Wyświetlanie macierzy w LaTeX-u\n", "\n", "def LatexMatrix(matrix):\n", " ltx = r'\\left[\\begin{array}'\n", " m, n = matrix.shape\n", " ltx += '{' + (\"r\" * n) + '}'\n", " for i in range(m):\n", " ltx += r\" & \".join([('%.4f' % j.item()) for j in matrix[i]]) + r\" \\\\ \"\n", " ltx += r'\\end{array}\\right]'\n", " return ltx" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "collapsed": true, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "def gradient_descent(h, cost_fun, theta, x, y, alpha, eps):\n", " current_cost = cost_fun(h, theta, x, y)\n", " log = [[current_cost, theta]] # log przechowuje wartości kosztu i parametrów\n", " m = len(y)\n", " while True:\n", " new_theta = [\n", " theta[0] - alpha/float(m) * sum(h(theta, x[i]) - y[i]\n", " for i in range(m)), \n", " theta[1] - 
alpha/float(m) * sum((h(theta, x[i]) - y[i]) * x[i]\n", " for i in range(m))]\n", " theta = new_theta # jednoczesna aktualizacja - używamy zmiennej tymaczasowej\n", " try:\n", " current_cost, prev_cost = cost_fun(h, theta, x, y), current_cost\n", " except OverflowError:\n", " break \n", " if abs(prev_cost - current_cost) <= eps:\n", " break \n", " log.append([current_cost, theta])\n", " return theta, log" ] }, { "cell_type": "code", "execution_count": 67, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "text/latex": [ "$$\\large\\textrm{Wynik:}\\quad \\theta = \\left[\\begin{array}{r}-3.5074 \\\\ 1.1540 \\\\ \\end{array}\\right] \\quad J(\\theta) = 4.4908 \\quad \\textrm{po 644 iteracjach}$$" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "best_theta, log = gradient_descent(h, J, [0.0, 0.0], x, y, alpha=0.02, eps=0.0001)\n", "\n", "display(Math(r'\\large\\textrm{Wynik:}\\quad \\theta = ' + \n", " LatexMatrix(np.matrix(best_theta).reshape(2,1)) + \n", " (r' \\quad J(\\theta) = %.4f' % log[-1][0]) \n", " + r' \\quad \\textrm{po %d iteracjach}' % len(log))) " ] }, { "cell_type": "code", "execution_count": 68, "metadata": { "collapsed": true, "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Przygotowanie interaktywnego wykresu\n", "\n", "sliderSteps2 = widgets.IntSlider(min=0, max=500, step=1, value=1, description='kroki', width=300)\n", "\n", "def slide5(steps):\n", " costplot2d(h, x, y, gradient_values=log[:steps+1], nohead=True)" ] }, { "cell_type": "code", "execution_count": 69, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fc6dd6fa531b4d8eac47c456831f61a4", "version_major": 2, "version_minor": 0 }, "text/html": [ "
\n" ], "text/plain": [ "interactive(children=(IntSlider(value=1, description='kroki', max=500), Button(description='Run Interact', style=ButtonStyle()), Output()), _dom_classes=('widget-interact',))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 69, "metadata": {}, "output_type": "execute_result" } ], "source": [ "widgets.interact_manual(slide5, steps=sliderSteps2)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### The coefficient $\\alpha$ (step size)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ " * If the step size is too small, the algorithm may need a very large number of steps to converge." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ " * If the step size is too large, the algorithm may overshoot the minimum and fail to converge." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## 1.6. Linear regression with multiple variables" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Example – apartment prices" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y : price x1: isNew x2: rooms x3: floor x4: location x5: sqrMetres\n", "476118.0 False 3 1 Centrum 78 \n", "459531.0 False 3 2 Sołacz 62 \n", "411557.0 False 3 0 Sołacz 15 \n", "496416.0 False 4 0 Sołacz 14 \n", "406032.0 False 3 0 Sołacz 15 \n", "450026.0 False 3 1 Naramowice 80 \n", "571229.15 False 2 4 Wilda 39 \n", "325000.0 False 3 1 Grunwald 54 \n", "268229.0 False 2 1 Grunwald 90 \n" ] } ], "source": [ "reader = csv.reader(open('data02.tsv'), delimiter='\\t')\n", "for i, row in enumerate(list(reader)[:10]):\n", "    if i == 0:\n", "        print(' '.join(['{}: {:8}'.format('x' + str(j) if j > 0 else 'y ', entry)\n", "                        for j, entry in enumerate(row)]))\n", "    else:\n", "        print(' '.join(['{:12}'.format(entry) for entry in row]))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "$$ x^{(2)} = ({\\rm \"False\"}, 3, 2, {\\rm \"Sołacz\"}, 62), \\quad x_3^{(2)} = 2 $$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Hypothesis" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "In our case:\n", "\n", "$$ h_\\theta(x) = \\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 + \\theta_3 x_3 + \\theta_4 x_4 + \\theta_5 x_5 $$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "In general:\n", "\n", "$$ h_\\theta(x) = \\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 + \\ldots + \\theta_n x_n $$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "If we define $x_0 = 1$:\n", "\n", 
"$$\n", "\\begin{array}{rcl}\n", "h_\\theta(x)\n", " & = & \\theta_0 x_0 + \\theta_1 x_1 + \\theta_2 x_2 + \\ldots + \\theta_n x_n \\\\\n", " & = & \\displaystyle\\sum_{i=0}^{n} \\theta_i x_i \\\\\n", " & = & \\theta^T \\, x \\\\\n", " & = & x^T \\, \\theta \\\\\n", "\\end{array}\n", "$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Metoda gradientu prostego – notacja macierzowa" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "$$\n", "X=\\left[\\begin{array}{cc}\n", "1 & \\left( \\vec x^{(1)} \\right)^T \\\\\n", "1 & \\left( \\vec x^{(2)} \\right)^T \\\\\n", "\\vdots & \\vdots\\\\\n", "1 & \\left( \\vec x^{(m)} \\right)^T \\\\\n", "\\end{array}\\right] \n", "= \\left[\\begin{array}{cccc}\n", "1 & x_1^{(1)} & \\cdots & x_n^{(1)} \\\\\n", "1 & x_1^{(2)} & \\cdots & x_n^{(2)} \\\\\n", "\\vdots & \\vdots & \\ddots & \\vdots\\\\\n", "1 & x_1^{(m)} & \\cdots & x_n^{(m)} \\\\\n", "\\end{array}\\right]\n", "\\quad\n", "\\vec{y} = \n", "\\left[\\begin{array}{c}\n", "y^{(1)}\\\\\n", "y^{(2)}\\\\\n", "\\vdots\\\\\n", "y^{(m)}\\\\\n", "\\end{array}\\right]\n", "\\quad\n", "\\theta = \\left[\\begin{array}{c}\n", "\\theta_0\\\\\n", "\\theta_1\\\\\n", "\\vdots\\\\\n", "\\theta_n\\\\\n", "\\end{array}\\right]\n", "$$" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "collapsed": true, "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Wczytwanie danych z pliku za pomocą numpy\n", "# Wersje macierzowe funkcji rysowania wykresów punktowych oraz krzywej regresyjnej\n", "\n", "data = np.loadtxt('data01.csv', delimiter=',')\n", "m, n_plus_1 = data.shape\n", "n = n_plus_1 - 1\n", "Xn = data[:, 0:n].reshape(m, n)\n", "\n", "# Dodaj kolumnę jedynek do macierzy\n", "XMx = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n_plus_1)\n", "yMx = np.matrix(data[:, 1]).reshape(m, 1)\n", "\n", "def hMx(theta, X):\n", " return X 
* theta\n", "\n", "def regdotsMx(X, y): \n", " fig = pl.figure(figsize=(16*.6, 9*.6))\n", " ax = fig.add_subplot(111)\n", " fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n", " ax.scatter([X[:, 1]], [y], c='r', s=50, label='Dane')\n", " \n", " ax.set_xlabel('Populacja')\n", " ax.set_ylabel('Zysk')\n", " ax.margins(.05, .05)\n", " pl.ylim(y.min() - 1, y.max() + 1)\n", " pl.xlim(np.min(X[:, 1]) - 1, np.max(X[:, 1]) + 1)\n", " return fig\n", "\n", "def reglineMx(fig, fun, theta, X):\n", " ax = fig.axes[0]\n", " x0, x1 = np.min(X[:, 1]), np.max(X[:, 1])\n", " L = [x0, x1]\n", " LX = np.matrix([1, x0, 1, x1]).reshape(2, 2)\n", " ax.plot(L, fun(theta, LX), linewidth='2',\n", " label=(r'$y={theta0:.2}{op}{theta1:.2}x$'.format(\n", " theta0=float(theta[0][0]),\n", " theta1=(float(theta[1][0]) if theta[1][0] >= 0 else float(-theta[1][0])),\n", " op='+' if theta[1][0] >= 0 else '-')))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Funkcja kosztu – notacja macierzowa" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "$$J(\\theta)=\\dfrac{1}{2|\\vec y|}\\left(X\\theta-\\vec{y}\\right)^T\\left(X\\theta-\\vec{y}\\right)$$ \n" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [ { "data": { "text/latex": [ "$$\\Large J(\\theta) = 4.5885$$" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Wersja macierzowa funkcji kosztu\n", "\n", "def JMx(theta,X,y):\n", " m = len(y)\n", " J = 1.0 / (2.0 * m) * ((X * theta - y) . 
T * ( X * theta - y))\n", " return J.item()\n", "\n", "thetaMx = np.matrix([-5, 1.3]).reshape(2, 1) \n", "\n", "cost = JMx(thetaMx,XMx,yMx) \n", "display(Math(r'\\Large J(\\theta) = %.4f' % cost))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Gradient – notacja macierzowa" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "$$\\nabla J(\\theta) = \\frac{1}{|\\vec y|} X^T\\left(X\\theta-\\vec y\\right)$$" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [ { "data": { "text/latex": [ "$$\\large \\theta = \\left[\\begin{array}{r}-5.0000 \\\\ 1.3000 \\\\ \\end{array}\\right]\\quad\\large \\nabla J(\\theta) = \\left[\\begin{array}{r}-0.2314 \\\\ -0.3027 \\\\ \\end{array}\\right]$$" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Wersja macierzowa gradientu funkcji kosztu\n", "\n", "def dJMx(theta,X,y):\n", " return 1.0 / len(y) * (X.T * (X * theta - y)) \n", "\n", "thetaMx2 = np.matrix([-5, 1.3]).reshape(2, 1)\n", "\n", "display(Math(r'\\large \\theta = ' + LatexMatrix(thetaMx2) + \n", " r'\\quad' + r'\\large \\nabla J(\\theta) = ' \n", " + LatexMatrix(dJMx(thetaMx2,XMx,yMx))))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Algorytm gradientu prostego – notacja macierzowa" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "$$ \\theta := \\theta - \\alpha \\, \\nabla J(\\theta) $$" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [ { "data": { "text/latex": [ "$$\\large\\textrm{Wynik:}\\quad \\theta = \\left[\\begin{array}{r}-3.7217 \\\\ 1.1755 \\\\ \\end{array}\\right] \\quad J(\\theta) = 4.4797 \\quad \\textrm{po 1734 iteracjach}$$" ], "text/plain": [ "" ] }, 
"metadata": {}, "output_type": "display_data" } ], "source": [ "# Implementacja algorytmu gradientu prostego za pomocą numpy i macierzy\n", "\n", "def GDMx(fJ, fdJ, theta, X, y, alpha=0.1, eps=10**-3):\n", " current_cost = fJ(theta, X, y)\n", " log = [[current_cost, theta]]\n", " while True:\n", " theta = theta - alpha * fdJ(theta, X, y) # implementacja wzoru\n", " current_cost, prev_cost = fJ(theta, X, y), current_cost\n", " if abs(prev_cost - current_cost) <= eps:\n", " break\n", " log.append([current_cost, theta]) \n", " return theta, log\n", "\n", "thetaStartMx = np.matrix([0, 0]).reshape(2, 1)\n", "\n", "# Zmieniamy wartości alpha (rozmiar kroku) oraz eps (kryterium stopu)\n", "thetaBestMx, log = GDMx(JMx, dJMx, thetaStartMx, \n", " XMx, yMx, alpha=0.01, eps=0.00001)\n", "\n", "######################################################################\n", "display(Math(r'\\large\\textrm{Wynik:}\\quad \\theta = ' + \n", " LatexMatrix(thetaBestMx) + \n", " (r' \\quad J(\\theta) = %.4f' % log[-1][0]) \n", " + r' \\quad \\textrm{po %d iteracjach}' % len(log))) " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## 1.7. 
Metoda gradientu prostego w praktyce" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Kryterium stopu" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "Na wykresie zobaczymy porównanie regresji dla różnych wartości `eps`" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": 
"display_data" } ], "source": [ "fig = regdotsMx(XMx, yMx)\n", "theta_e1, log1 = GDMx(JMx, dJMx, thetaStartMx, XMx, yMx, alpha=0.01, eps=0.01)\n", "reglineMx(fig, hMx, theta_e1, XMx)\n", "theta_e2, log2 = GDMx(JMx, dJMx, thetaStartMx, XMx, yMx, alpha=0.01, eps=0.000001)\n", "reglineMx(fig, hMx, theta_e2, XMx)\n", "legend(fig)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "text/latex": [ "$$\\theta_{10^{-2}} = \\left[\\begin{array}{r}0.0511 \\\\ 0.7957 \\\\ \\end{array}\\right]\\quad\\theta_{10^{-6}} = \\left[\\begin{array}{r}-3.8407 \\\\ 1.1875 \\\\ \\end{array}\\right]$$" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "display(Math(r'\\theta_{10^{-2}} = ' + LatexMatrix(theta_e1) +\n", " r'\\quad\\theta_{10^{-6}} = ' + LatexMatrix(theta_e2)))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Step size ($\\alpha$)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "collapsed": true, "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# How the cost J(theta) changes over successive steps, depending on alpha\n", "\n", "def costchangeplot(logs):\n", "    fig = pl.figure(figsize=(16*.6, 9*.6))\n", "    ax = fig.add_subplot(111)\n", "    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n", "    ax.set_xlabel('step')\n", "    ax.set_ylabel(r'$J(\\theta)$')\n", "\n", "    # Plot at most the first 500 steps; the run may stop earlier than that\n", "    X = np.arange(min(len(logs), 500))\n", "    Y = [logs[step][0] for step in X]  # first entry of each log record is the cost\n", "    ax.plot(X, Y, linewidth=2, label=r'$J(\\theta)$')\n", "    return fig\n", "\n", "def slide7(alpha):\n", "    best_theta, log = gradient_descent(h, J, [0.0, 0.0], x, y, alpha=alpha, eps=0.0001)\n", "    fig = costchangeplot(log)\n", "    legend(fig)\n", "\n", "sliderAlpha1 = widgets.FloatSlider(min=0.01, max=0.03, step=0.001, value=0.02, description=r'$\\alpha$',\n", "                                   layout=widgets.Layout(width='300px'))" ] }, { "cell_type": "code", "execution_count": 33, "metadata":
{ "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8f617b06eebc49168e4b8daac54528cd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "interactive(children=(FloatSlider(value=0.02, description='$\\\\alpha$', max=0.03, min=0.01, step=0.001), Button(description='Run Interact', style=ButtonStyle()), Output()), _dom_classes=('widget-interact',))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "widgets.interact_manual(slide7, alpha=sliderAlpha1)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## 1.8 Linear regression: supplement\n", "### Linear regression using the normal equation" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "Instead of running gradient descent, for linear regression we can compute the minimizer of $J(\\theta)$ directly from the closed-form formula below (the normal equation), provided that $X^T X$ is invertible:\n", "\n", "$$ \\theta = \\left( X^T X \\right)^{-1} \\, X^T \\, \\vec y $$" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.15rc1" }, "livereveal": { "start_slideshow_at": "selected", "theme": "amu" } }, "nbformat": 4, "nbformat_minor": 2 }