\n"
],
"text/plain": [
"interactive(children=(FloatSlider(value=0.0, description='$\\\\theta_0$', max=10.0, min=-10.0), FloatSlider(value=0.0, description='$\\\\theta_1$', max=5.0, min=-5.0), Button(description='Run Interact', style=ButtonStyle()), Output()), _dom_classes=('widget-interact',))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact_manual(slide1, theta0=sliderTheta01, theta1=sliderTheta11)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Cost function"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"We look for the parameters $\\theta = \\left[\\begin{array}{c}\\theta_0\\\\ \\theta_1\\end{array}\\right]$ that minimize the cost function $J(\\theta)$:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"$$ \\hat\\theta = \\mathop{\\arg\\min}_{\\theta} J(\\theta) $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"$$ \\theta \\in \\mathbb{R}^2, \\quad J \\colon \\mathbb{R}^2 \\to \\mathbb{R} $$"
]
},
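The factor $\frac{1}{2m}$ does not change where the minimum lies. For the linear hypothesis $h_\theta(x) = \theta_0 + \theta_1 x$, $\hat\theta$ is therefore the ordinary least-squares solution. The following is a minimal sketch, assuming NumPy and a made-up toy dataset (not the notebook's data). It finds $\hat\theta$ in closed form with `np.linalg.lstsq`. This is a swapped-in direct method and not necessarily how the notebook itself minimizes $J$. The helper names `fit_least_squares` and `cost` are hypothetical.

```python
import numpy as np

# Hypothetical toy data lying exactly on y = 2 + 3x (not the notebook's dataset)
x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
y = 2.0 + 3.0 * x

def fit_least_squares(x, y):
    """Return theta_hat = [theta_0, theta_1] minimizing the mean squared error.

    Builds the design matrix with columns [1, x] and solves the
    least-squares problem directly with numpy.linalg.lstsq.
    """
    X = np.column_stack([np.ones_like(x), x])  # shape (m, 2)
    theta_hat, *_ = np.linalg.lstsq(X, y, rcond=None)
    return theta_hat

def cost(theta, x, y):
    """J(theta) = 1/(2m) * sum((theta_0 + theta_1 * x - y)^2)."""
    m = len(y)
    return np.sum((theta[0] + theta[1] * x - y) ** 2) / (2 * m)

theta_hat = fit_least_squares(x, y)
print(theta_hat)              # approximately [2, 3]
print(cost(theta_hat, x, y))  # approximately 0
```

The toy points lie exactly on a line, so the recovered parameters match that line and the cost drops to numerically zero.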
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Mean squared error\n",
"#### (method of least squares)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"$$ J(\\theta) \\, = \\, \\frac{1}{2m} \\sum_{i = 1}^{m} \\left( h_{\\theta} \\left( x^{(i)} \\right) - y^{(i)} \\right) ^2 $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"$$ J(\\theta_0, \\theta_1) \\, = \\, \\frac{1}{2m} \\sum_{i = 1}^{m} \\left( \\theta_0 + \\theta_1 x^{(i)} - y^{(i)} \\right) ^2 $$"
]
},
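The sum in $J(\theta_0, \theta_1)$ can also be computed without an explicit loop over $i$. The following sketch assumes NumPy and uses the hypothetical function names `J_loop` and `J_vectorized`. It compares a term-by-term version of the formula with an array version on a three-point toy dataset.

```python
import numpy as np

def J_loop(theta, x, y):
    """Cost computed term by term, exactly as in the formula."""
    m = len(y)
    return sum((theta[0] + theta[1] * x[i] - y[i]) ** 2 for i in range(m)) / (2 * m)

def J_vectorized(theta, x, y):
    """Same cost using NumPy array operations instead of a Python loop."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    residuals = theta[0] + theta[1] * x - y  # h_theta(x^(i)) - y^(i) for every i
    return residuals @ residuals / (2 * len(y))

# Hypothetical toy data
x = [1.0, 2.0, 3.0]
y = [1.0, 2.0, 3.0]
print(J_vectorized([0.0, 1.0], x, y))  # 0.0: the line y = x fits every point
print(J_vectorized([0.0, 0.0], x, y))  # (1 + 4 + 9) / 6
```

Both versions compute the same quantity. The array form scales better to large $m$ because the loop runs inside NumPy rather than in Python.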
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true,
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [],
"source": [
"def J(h, theta, x, y):\n",
"    # Mean squared error cost (with the 1/2 factor) of hypothesis h on the data (x, y)\n",
"    m = len(y)\n",
"    return 1.0 / (2 * m) * sum((h(theta, x[i]) - y[i])**2 for i in range(m))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true,
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Compute the cost function value and show it on the plot\n",
"\n",
"def regline(fig, fun, theta, x, y):\n",
"    # Draw the line h_theta over the range of x; show its equation and J(theta) in the legend\n",
"    ax = fig.axes[0]\n",
"    x0, x1 = min(x), max(x)\n",
"    X = [x0, x1]\n",
"    Y = [fun(theta, xi) for xi in X]\n",
"    cost = J(fun, theta, x, y)\n",
"    ax.plot(X, Y, linewidth=2,\n",
"            label=(r'$y={theta0}{op}{theta1}x, \\; J(\\theta)={cost:.3}$'.format(\n",
"                theta0=theta[0],\n",
"                theta1=abs(theta[1]),\n",
"                op='+' if theta[1] >= 0 else '-',\n",
"                cost=cost)))\n",
"\n",
"sliderTheta02 = widgets.FloatSlider(min=-10, max=10, step=0.1, value=0, description=r'$\\theta_0$',\n",
"                                    layout=widgets.Layout(width='300px'))\n",
"sliderTheta12 = widgets.FloatSlider(min=-5, max=5, step=0.1, value=0, description=r'$\\theta_1$',\n",
"                                    layout=widgets.Layout(width='300px'))\n",
"\n",
"def slide2(theta0, theta1):\n",
"    fig = regdots(x, y)\n",
"    regline(fig, h, [theta0, theta1], x, y)\n",
"    legend(fig)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c380a0076fb4480a9f9a0eb4f5fba756",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=0.0, description='$\\\\theta_0$', max=10.0, min=-10.0), FloatSlider(value=0.0, description='$\\\\theta_1$', max=5.0, min=-5.0), Button(description='Run Interact', style=ButtonStyle()), Output()), _dom_classes=('widget-interact',))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact_manual(slide2, theta0=sliderTheta02, theta1=sliderTheta12)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### The cost function as a function of $\\theta$"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": true,
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Plot the cost function for a fixed theta_1 (1.0 by default)\n",
"\n",
"def costfun(fun, x, y):\n",
"    # Return J as a function of theta alone, with the data (x, y) fixed\n",
"    return lambda theta: J(fun, theta, x, y)\n",
"\n",
"def costplot(hypothesis, x, y, theta1=1.0):\n",
"    fig = pl.figure(figsize=(16*.6, 9*.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"    ax.set_xlabel(r'$\\theta_0$')\n",
"    ax.set_ylabel(r'$J(\\theta)$')\n",
"    j = costfun(hypothesis, x, y)\n",
"    fun = lambda theta0: j([theta0, theta1])\n",
"    X = np.arange(-10, 10, 0.1)\n",
"    Y = [fun(t0) for t0 in X]\n",
"    ax.plot(X, Y, linewidth=2, label=(r'$J(\\theta_0, {theta1})$'.format(theta1=theta1)))\n",
"    return fig\n",
"\n",
"def slide3(theta1):\n",
"    fig = costplot(h, x, y, theta1)\n",
"    legend(fig)\n",
"\n",
"sliderTheta13 = widgets.FloatSlider(min=-5, max=5, step=0.1, value=1.0, description=r'$\\theta_1$',\n",
"                                    layout=widgets.Layout(width='300px'))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1b675985144a43d0997c79ada18bb446",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=1.0, description='$\\\\theta_1$', max=5.0, min=-5.0), Button(description='Run Interact', style=ButtonStyle()), Output()), _dom_classes=('widget-interact',))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact_manual(slide3, theta1=sliderTheta13)"
]
},
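For a fixed $\theta_1$, the slice plotted by `costplot` is a parabola in $\theta_0$, so its minimum can be found by hand. Setting $\frac{\partial J}{\partial \theta_0} = \frac{1}{m}\sum_{i=1}^{m}\left(\theta_0 + \theta_1 x^{(i)} - y^{(i)}\right) = \theta_0 + \theta_1 \bar{x} - \bar{y}$ to zero gives $\theta_0 = \bar{y} - \theta_1 \bar{x}$. The following is a minimal sketch, assuming NumPy and hypothetical toy data. It checks this formula against a brute-force search over the same $\theta_0$ grid (`np.arange(-10, 10, 0.1)`) that `costplot` uses. `best_theta0` is a hypothetical helper name.

```python
import numpy as np

def J(theta0, theta1, x, y):
    """Mean squared error cost (with the 1/2 factor) of h(x) = theta0 + theta1 * x."""
    r = theta0 + theta1 * x - y
    return np.mean(r ** 2) / 2

def best_theta0(theta1, x, y):
    """Minimizer of J over theta0 for a fixed theta1 (from dJ/dtheta0 = 0)."""
    return np.mean(y) - theta1 * np.mean(x)

# Hypothetical toy data (not the notebook's dataset)
x = np.array([1.0, 2.0, 4.0, 5.0])
y = np.array([3.0, 2.0, 6.0, 9.0])
theta1 = 1.0

# Brute-force search over the theta0 grid used for the slice plot
grid = np.arange(-10, 10, 0.1)
costs = [J(t0, theta1, x, y) for t0 in grid]
print(best_theta0(theta1, x, y))    # ybar - theta1 * xbar = 5.0 - 3.0 = 2.0
print(grid[int(np.argmin(costs))])  # close to 2.0, up to the 0.1 grid step
```

The closed form gives the exact minimizer of the slice. The grid search can only be as accurate as its step size.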
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"collapsed": true,
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Plot the cost function with respect to theta_0 and theta_1\n",
"\n",
"from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection in older matplotlib\n",
"\n",
"%matplotlib inline\n",
"\n",
"def costplot3d(hypothesis, x, y, show_gradient=False):\n",
"    fig = pl.figure(figsize=(16*.6, 9*.6))\n",
"    ax = fig.add_subplot(111, projection='3d')\n",
"    fig.subplots_adjust(left=0.0, right=1.0, bottom=0.0, top=1.0)\n",
"    ax.set_xlabel(r'$\\theta_0$')\n",
"    ax.set_ylabel(r'$\\theta_1$')\n",
"    ax.set_zlabel(r'$J(\\theta)$')\n",
"\n",
"    # Evaluate J on a grid of (theta_0, theta_1) values\n",
"    X = np.arange(-10, 10.1, 0.1)\n",
"    Y = np.arange(-1, 4.1, 0.1)\n",
"    X, Y = np.meshgrid(X, Y)\n",
"    Z = np.array([[J(hypothesis, [theta0, theta1], x, y)\n",
"                   for theta0, theta1 in zip(xRow, yRow)]\n",
"                  for xRow, yRow in zip(X, Y)])\n",
"\n",
"    ax.plot_surface(X, Y, Z, rstride=2, cstride=8, linewidth=0.5,\n",
"                    alpha=0.5, cmap='jet', zorder=0,\n",
"                    label=r'$J(\\theta)$')\n",
"    ax.view_init(elev=20., azim=-150)\n",
"\n",
"    ax.set_xlim3d(-10, 10)\n",
"    ax.set_ylim3d(-1, 4)\n",
"    ax.set_zlim3d(-100, 800)\n",
"\n",
"    # Contour lines of J projected onto the plane z = -100\n",
"    N = range(0, 800, 20)\n",
"    ax.contour(X, Y, Z, N, zdir='z', offset=-100, cmap='coolwarm', alpha=1)\n",
"\n",
"    # Hard-coded minimum of J for this dataset: a dashed drop line and two markers\n",
"    ax.plot([-3.89578088] * 2,\n",
"            [1.19303364] * 2,\n",
"            [-100, 4.47697137598],\n",
"            color='red', alpha=1, linewidth=1.3, zorder=100, linestyle='dashed',\n",
"            label=r'minimum: $J(-3.90, 1.19) = 4.48$')\n",
"    ax.scatter([-3.89578088] * 2,\n",
"               [1.19303364] * 2,\n",
"               [-100, 4.47697137598],\n",
"               c='r', s=80, marker='x', alpha=1, linewidth=1.3, zorder=100,\n",
"               label=r'minimum: $J(-3.90, 1.19) = 4.48$')\n",
"\n",
"    if show_gradient:\n",
"        # Illustrative green segment starting at (theta_0, theta_1) = (3, 3)\n",
"        ax.plot([3.0, 1.1],\n",
"                [3.0, 2.4],\n",
"                [263.0, 125.0],\n",
"                color='green', alpha=1, linewidth=1.3, zorder=100)\n",
"        ax.scatter([3.0],\n",
"                   [3.0],\n",
"                   [263.0],\n",
"                   c='g', s=30, marker='D', alpha=1, linewidth=1.3, zorder=100)\n",
"\n",
"    ax.margins(0, 0, 0)\n",
"    fig.tight_layout()"
]
},
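The `show_gradient` option draws a hard-coded green segment starting at $(\theta_0, \theta_1) = (3, 3)$. For the linear hypothesis, the gradient of $J$ has two components:

- $\frac{\partial J}{\partial \theta_0} = \frac{1}{m}\sum_{i=1}^{m}\left(h_\theta(x^{(i)}) - y^{(i)}\right)$
- $\frac{\partial J}{\partial \theta_1} = \frac{1}{m}\sum_{i=1}^{m}\left(h_\theta(x^{(i)}) - y^{(i)}\right) x^{(i)}$

The following sketch assumes NumPy and hypothetical toy data. It evaluates this analytic gradient at $(3, 3)$ and checks it with central finite differences. `grad_J` and `numerical_grad` are hypothetical names.

```python
import numpy as np

def J(theta, x, y):
    """J(theta) = 1/(2m) * sum((theta_0 + theta_1 * x - y)^2)."""
    r = theta[0] + theta[1] * x - y
    return r @ r / (2 * len(y))

def grad_J(theta, x, y):
    """Analytic gradient [dJ/dtheta_0, dJ/dtheta_1] of the cost above."""
    r = theta[0] + theta[1] * x - y
    m = len(y)
    return np.array([r.sum() / m, (r * x).sum() / m])

def numerical_grad(f, theta, eps=1e-6):
    """Central finite-difference approximation of the gradient of f at theta."""
    theta = np.asarray(theta, dtype=float)
    g = np.zeros_like(theta)
    for k in range(theta.size):
        step = np.zeros_like(theta)
        step[k] = eps
        g[k] = (f(theta + step) - f(theta - step)) / (2 * eps)
    return g

# Hypothetical toy data (not the notebook's dataset)
x = np.array([1.0, 2.0, 4.0, 5.0])
y = np.array([3.0, 2.0, 6.0, 9.0])
theta = np.array([3.0, 3.0])

print(grad_J(theta, x, y))
print(numerical_grad(lambda t: J(t, x, y), theta))
```

$J$ is quadratic in $\theta$, so the central differences agree with the analytic gradient up to rounding error. For this toy data, both give approximately $[7, 24.5]$.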
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"