From dbc0fcfc54630c13a292587e50fa812ee16390a6 Mon Sep 17 00:00:00 2001 From: skorzewski Date: Sat, 19 Dec 2020 13:09:41 +0100 Subject: [PATCH] =?UTF-8?q?Modu=C5=82y=202=20i=203=20-=20szkielet?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 1.ipynb => 1. Jupyter - podstawy.ipynb | 0 ... wykresy i prezentacje multimedialne.ipynb | 86 + ...przykład materiałów dydaktycznych.ipynb | 42666 ++++++++++++++++ data01_test.csv | 17 + data01_train.csv | 80 + 5 files changed, 42849 insertions(+) rename 1.ipynb => 1. Jupyter - podstawy.ipynb (100%) create mode 100644 2. Jupyter - interaktywne wykresy i prezentacje multimedialne.ipynb create mode 100644 3. Jupyter - przykład materiałów dydaktycznych.ipynb create mode 100644 data01_test.csv create mode 100644 data01_train.csv diff --git a/1.ipynb b/1. Jupyter - podstawy.ipynb similarity index 100% rename from 1.ipynb rename to 1. Jupyter - podstawy.ipynb diff --git a/2. Jupyter - interaktywne wykresy i prezentacje multimedialne.ipynb b/2. Jupyter - interaktywne wykresy i prezentacje multimedialne.ipynb new file mode 100644 index 0000000..851db08 --- /dev/null +++ b/2. Jupyter - interaktywne wykresy i prezentacje multimedialne.ipynb @@ -0,0 +1,86 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Przygotowanie innowacyjnych materiałów szkoleniowych i dokumentacji wewnętrznych w obszarze IT\n", + "# 2. Tworzenie materiałów szkoleniowych w Jupyter Notebook - interaktywne wykresy i prezentacje multimedialne" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2.1. Umieszczanie wzorów matematycznych" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2.2. Tworzenie wykresów" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tworzenie statycznych wykresów" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tworzenie interaktywnych wykresów" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2.3. Wyświetlanie materiałów Jupyter Notebook w formie prezentacji multimedialnej" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Materiały przygotowane w formacie `.ipynb` można wyświetlać w formie prezentacji multimedialnej.\n", + "\n", + "Jest to możliwe dzięki narzędziu [**RISE**](https://rise.readthedocs.io).\n", + "\n", + "RISE (*Reveal.js Ipython Slideshow Extension*) jest rozszerzeniem do Jupytera umożliwiającym wyświetlanie notatników w trybie prezentacji w oparciu o framework [**Reveal.js**](https://revealjs.com/)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/3. Jupyter - przykład materiałów dydaktycznych.ipynb b/3. Jupyter - przykład materiałów dydaktycznych.ipynb new file mode 100644 index 0000000..79f636a --- /dev/null +++ b/3. Jupyter - przykład materiałów dydaktycznych.ipynb @@ -0,0 +1,42666 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Przygotowanie innowacyjnych materiałów szkoleniowych i dokumentacji wewnętrznych w obszarze IT\n", + "# 3. Przykład materiałów szkoleniowych w Jupyter Notebook" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Poniższe materiały pochodzą z kursu \"Uczenie maszynowe i big data\"." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "## Funkcja kosztu" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "### Zadanie\n", + "Znając $x$ – ludność miasta (w dziesiątkach tysięcy mieszkańców),\n", + "należy przewidzieć $y$ – dochód firmy transportowej (w dziesiątkach tysięcy dolarów).\n", + "\n", + "(Dane pochodzą z kursu „Machine Learning”, Andrew Ng, Coursera)." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib\n", + "import matplotlib.pyplot as pl\n", + "import ipywidgets as widgets\n", + "\n", + "%matplotlib inline\n", + "%config InlineBackend.figure_format = 'svg'\n", + "\n", + "from IPython.display import display, Math, Latex" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Dane" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6.1101,17.592\n", + "\n", + "5.5277,9.1302\n", + "\n", + "8.5186,13.662\n", + "\n", + "7.0032,11.854\n", + "\n", + "5.8598,6.8233\n", + "\n", + "8.3829,11.886\n", + "\n", + "7.4764,4.3483\n", + "\n", + "8.5781,12\n", + "\n", + "6.4862,6.5987\n", + "\n", + "5.0546,3.8166\n", + "\n" + ] + } + ], + "source": [ + "with open('data01_train.csv') as data:\n", + " for line in data.readlines()[:10]:\n", + " print(line)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "### Wczytanie danych" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x = [6.1101, 5.5277, 8.5186, 7.0032, 5.8598, 8.3829, 7.4764, 8.5781, 6.4862, 5.0546]\n", + "y = [17.592, 9.1302, 13.662, 11.854, 6.8233, 11.886, 4.3483, 12.0, 6.5987, 3.8166]\n" + ] + } + ], + "source": [ + "import csv\n", + "\n", + "reader = csv.reader(open('data01_train.csv'), delimiter=',')\n", + "\n", + "x = list()\n", + "y = list()\n", + "for xi, yi in reader:\n", + " x.append(float(xi))\n", + " y.append(float(yi)) \n", + " \n", + "print('x = {}'.format(x[:10])) \n", + "print('y = {}'.format(y[:10]))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Hipoteza i parametry modelu" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Jak przewidzieć $y$ na podstawie danego $x$? W celu odpowiedzi na to pytanie będziemy starać się znaleźć taką funkcję $h(x)$, która będzie najlepiej obrazować zależność między $x$ a $y$, tj. $y \\sim h(x)$.\n", + "\n", + "Zacznijmy od najprostszego przypadku, kiedy $h(x)$ jest po prostu funkcją liniową. Ogólny wzór funkcji liniowej to\n", + "$$ h(x) = a \\, x + b $$\n", + "\n", + "Pamiętajmy jednak, że współczynniki $a$ i $b$ nie są w tej chwili dane z góry – naszym zadaniem właśnie będzie znalezienie takich ich wartości, żeby $h(x)$ było „możliwie jak najbliżej” $y$ (co właściwie oznacza to sformułowanie, wyjaśnię potem).\n", + "\n", + "Poszukiwaną funkcję $h$ będziemy nazywać **funkcją hipotezy**, a jej współczynniki – **parametrami modelu**.\n", + "\n", + "W teorii uczenia maszynowego parametry modelu oznacza się na ogół grecką literą $\\theta$ z odpowiednimi indeksami, dlatego powyższy wzór opisujący liniową funkcję hipotezy zapiszemy jako\n", + "$$ h(x) = \\theta_0 + \\theta_1 x $$\n", + "\n", + "**Parametry modelu** tworzą wektor, który oznaczymy po prostu przez $\\theta$:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "$$ \\theta = \\left[\\begin{array}{c}\\theta_0\\\\ \\theta_1\\end{array}\\right] $$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Żeby podkreślić fakt, że funkcja hipotezy zależy od parametrów modelu, będziemy pisać $h_\\theta$ zamiast $h$:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x $$" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Przyjrzyjmy się teraz, jak wyglądają dane, które mamy modelować:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\r\n", + "\r\n", + "\r\n", + "\r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + "\r\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# Funkcje rysujące wykres kropkowy oraz prostą regresyjną\n", + "\n", + "def regdots(x, y): \n", + " fig = pl.figure(figsize=(16*.6, 9*.6))\n", + " ax = fig.add_subplot(111)\n", + " fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n", + " ax.scatter(x, y, c='r', s=50, label='Dane')\n", + " \n", + " ax.set_xlabel(u'Wielkość miejscowości [dzies. tys. mieszk.]')\n", + " ax.set_ylabel(u'Dochód firmy [dzies. tys. dolarów]')\n", + " ax.margins(.05, .05)\n", + " pl.ylim(min(y) - 1, max(y) + 1)\n", + " pl.xlim(min(x) - 1, max(x) + 1)\n", + " return fig\n", + "\n", + "def regline(fig, fun, theta, x):\n", + " ax = fig.axes[0]\n", + " x0, x1 = min(x), max(x)\n", + " X = [x0, x1]\n", + " Y = [fun(theta, x) for x in X]\n", + " ax.plot(X, Y, linewidth='2',\n", + " label=(r'$y={theta0}{op}{theta1}x$'.format(\n", + " theta0=theta[0],\n", + " theta1=(theta[1] if theta[1] >= 0 else -theta[1]),\n", + " op='+' if theta[1] >= 0 else '-')))\n", + "\n", + "def legend(fig):\n", + " ax = fig.axes[0]\n", + " handles, labels = ax.get_legend_handles_labels()\n", + " # try-except block is a fix for a bug in Poly3DCollection\n", + " try:\n", + " fig.legend(handles, labels, fontsize='15', loc='lower right')\n", + " except AttributeError:\n", + " pass\n", + "\n", + "fig = regdots(x,y)\n", + "legend(fig)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "# Hipoteza: funkcja liniowa jednej zmiennej\n", + "\n", + "def h(theta, x):\n", + " return theta[0] + theta[1] * x" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "# Przygotowanie interaktywnego wykresu\n", + "\n", + "sliderTheta01 = widgets.FloatSlider(min=-10, max=10, step=0.1, value=0, description=r'$\\theta_0$', width=300)\n", + "sliderTheta11 = widgets.FloatSlider(min=-5, max=5, step=0.1, value=0, description=r'$\\theta_1$', width=300)\n", + "\n", + "def slide1(theta0, theta1):\n", + " fig = regdots(x, y)\n", + " regline(fig, h, [theta0, theta1], x)\n", + " legend(fig)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Na poniższym wykresie możesz spróbować ręcznie dopasować parametry modelu $\\theta_0$ i $\\theta_1$ tak, aby jak najlepiej modelowały zależność między $x$ a $y$:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "64432dd82e97458f86f45f4ed72c6616", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "interactive(children=(FloatSlider(value=0.0, description='$\\\\theta_0$', max=10.0, min=-10.0), FloatSlider(valu…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "widgets.interact_manual(slide1, theta0=sliderTheta01, theta1=sliderTheta11)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Skąd wiadomo, że przewidywania modelu (wartości funkcji $h(x)$) zgadzaja się z obserwacjami (wartości $y$)?\n", + "\n", + "Aby to zmierzyć wprowadzimy pojęcie funkcji kosztu." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Funkcja kosztu" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Funkcję kosztu zdefiniujemy w taki sposób, żeby odzwierciedlała ona różnicę między przewidywaniami modelu a obserwacjami.\n", + "\n", + "Jedną z możliwosci jest zdefiniowanie funkcji kosztu jako wartość **błędu średniokwadratowego** (metoda najmniejszych kwadratów, *root-mean-square error, RMSE*):" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "$$ J(\\theta) \\, = \\, \\frac{1}{2m} \\sum_{i = 1}^{m} \\left( h_{\\theta} \\left( x^{(i)} \\right) - y^{(i)} \\right) ^2 $$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "gdzie $m$ jest liczbą wszystkich przykładów (obserwacji), czyli wielkością zbioru danych uczących.\n", + "\n", + "W powyższym wzorze sumujemy kwadraty różnic między przewidywaniami modelu ($h_\\theta \\left( x^{(i)} \\right)$) a obserwacjami ($y^{(i)}$) po wszystkich przykładach $i$." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "Teraz nasze zadanie sprowadza się do tego, że będziemy szukać takich parametrów $\\theta = \\left[\\begin{array}{c}\\theta_0\\\\ \\theta_1\\end{array}\\right]$, które minimalizują fukcję kosztu $J(\\theta)$:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "$$ \\hat\\theta = \\mathop{\\arg\\min}_{\\theta} J(\\theta) $$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "$$ \\theta \\in \\mathbb{R}^2, \\quad J \\colon \\mathbb{R}^2 \\to \\mathbb{R} $$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Proszę zwrócić uwagę, że dziedziną funkcji kosztu jest zbiór wszystkich możliwych wartości parametrów $\\theta$." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "$$ J(\\theta_0, \\theta_1) \\, = \\, \\frac{1}{2m} \\sum_{i = 1}^{m} \\left( \\theta_0 + \\theta_1 x^{(i)} - y^{(i)} \\right) ^2 $$" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "outputs": [], + "source": [ + "def J(h, theta, x, y):\n", + " \"\"\"Funkcja kosztu\"\"\"\n", + " m = len(y)\n", + " return 1.0 / (2 * m) * sum((h(theta, x[i]) - y[i])**2 for i in range(m))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "# Oblicz wartość funkcji kosztu i pokaż na wykresie\n", + "\n", + "def regline2(fig, fun, theta, xx, yy):\n", + " \"\"\"Rysuj regresję liniową\"\"\"\n", + " ax = fig.axes[0]\n", + " x0, x1 = min(xx), max(xx)\n", + " X = [x0, x1]\n", + " Y = [fun(theta, x) for x in X]\n", + " cost = J(fun, theta, xx, yy)\n", + " ax.plot(X, Y, linewidth='2',\n", + " label=(r'$y={theta0}{op}{theta1}x, \\; J(\\theta)={cost:.3}$'.format(\n", + " theta0=theta[0],\n", + " theta1=(theta[1] if theta[1] >= 0 else -theta[1]),\n", + " op='+' if theta[1] >= 0 else '-',\n", + " cost=cost)))\n", + "\n", + "sliderTheta02 = widgets.FloatSlider(min=-10, max=10, step=0.1, value=0, description=r'$\\theta_0$', width=300)\n", + "sliderTheta12 = widgets.FloatSlider(min=-5, max=5, step=0.1, value=0, description=r'$\\theta_1$', width=300)\n", + "\n", + "def slide2(theta0, theta1):\n", + " fig = regdots(x, y)\n", + " regline2(fig, h, [theta0, theta1], x, y)\n", + " legend(fig)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Poniższy interaktywny wykres pokazuje wartość funkcji kosztu $J(\\theta)$. Czy teraz łatwiej jest dobrać parametry modelu?" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1d46c4b0e0234d1d9e95f9ef7aee61cc", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "interactive(children=(FloatSlider(value=0.0, description='$\\\\theta_0$', max=10.0, min=-10.0), FloatSlider(valu…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "widgets.interact_manual(slide2, theta0=sliderTheta02, theta1=sliderTheta12)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Funkcja kosztu jako funkcja zmiennej $\\theta$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Funkcja kosztu zdefiniowana jako RMSE jest funkcją zmiennej wektorowej $\\theta$, czyli funkcją dwóch zmiennych rzeczywistych: $\\theta_0$ i $\\theta_1$.\n", + " \n", + "Zobaczmy, jak wygląda jej wykres." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "# Wykres funkcji kosztu dla ustalonego theta_1=1.0\n", + "\n", + "def costfun(fun, x, y):\n", + " return lambda theta: J(fun, theta, x, y)\n", + "\n", + "def costplot(hypothesis, x, y, theta1=1.0):\n", + " fig = pl.figure(figsize=(16*.6, 9*.6))\n", + " ax = fig.add_subplot(111)\n", + " fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n", + " ax.set_xlabel(r'$\\theta_0$')\n", + " ax.set_ylabel(r'$J(\\theta)$')\n", + " j = costfun(hypothesis, x, y)\n", + " fun = lambda theta0: j([theta0, theta1])\n", + " X = np.arange(-10, 10, 0.1)\n", + " Y = [fun(x) for x in X]\n", + " ax.plot(X, Y, linewidth='2', label=(r'$J(\\theta_0, {theta1})$'.format(theta1=theta1)))\n", + " return fig\n", + "\n", + "def slide3(theta1):\n", + " fig = costplot(h, x, y, theta1)\n", + " legend(fig)\n", + "\n", + "sliderTheta13 = widgets.FloatSlider(min=-5, max=5, step=0.1, value=1.0, description=r'$\\theta_1$', width=300)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a5409cb710364e6691dd67525cb971f4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "interactive(children=(FloatSlider(value=1.0, description='$\\\\theta_1$', max=5.0, min=-5.0), Button(description…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "widgets.interact_manual(slide3, theta1=sliderTheta13)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "# Wykres funkcji kosztu względem theta_0 i theta_1\n", + "\n", + "from mpl_toolkits.mplot3d import Axes3D\n", + "import pylab\n", + "\n", + "%matplotlib inline\n", + "\n", + "def costplot3d(hypothesis, x, y, show_gradient=False):\n", + " fig = pl.figure(figsize=(16*.6, 9*.6))\n", + " ax = fig.add_subplot(111, projection='3d')\n", + " fig.subplots_adjust(left=0.0, right=1.0, bottom=0.0, top=1.0)\n", + " ax.set_xlabel(r'$\\theta_0$')\n", + " ax.set_ylabel(r'$\\theta_1$')\n", + " ax.set_zlabel(r'$J(\\theta)$')\n", + " \n", + " j = lambda theta0, theta1: costfun(hypothesis, x, y)([theta0, theta1])\n", + " X = np.arange(-10, 10.1, 0.1)\n", + " Y = np.arange(-1, 4.1, 0.1)\n", + " X, Y = np.meshgrid(X, Y)\n", + " Z = np.array([[J(hypothesis, [theta0, theta1], x, y) \n", + " for theta0, theta1 in zip(xRow, yRow)] \n", + " for xRow, yRow in zip(X, Y)])\n", + " \n", + " ax.plot_surface(X, Y, Z, rstride=2, cstride=8, linewidth=0.5,\n", + " alpha=0.5, cmap='jet', zorder=0,\n", + " label=r\"$J(\\theta)$\")\n", + " ax.view_init(elev=20., azim=-150)\n", + "\n", + " ax.set_xlim3d(-10, 10);\n", + " ax.set_ylim3d(-1, 4);\n", + " ax.set_zlim3d(-100, 800);\n", + "\n", + " N = range(0, 800, 20)\n", + " pl.contour(X, Y, Z, N, zdir='z', offset=-100, cmap='coolwarm', alpha=1)\n", + " \n", + " ax.plot([-3.89578088] * 2,\n", + " [ 1.19303364] * 2,\n", + " [-100, 4.47697137598], \n", + " color='red', alpha=1, linewidth=1.3, zorder=100, linestyle='dashed',\n", + " label=r'minimum: $J(-3.90, 1.19) = 4.48$')\n", + " ax.scatter([-3.89578088] * 2,\n", + " [ 1.19303364] * 2,\n", + " [-100, 4.47697137598], \n", + " c='r', s=80, marker='x', alpha=1, linewidth=1.3, zorder=100, \n", + " label=r'minimum: $J(-3.90, 1.19) = 4.48$')\n", + " \n", + " if show_gradient:\n", + " ax.plot([3.0, 1.1],\n", + " [3.0, 2.4],\n", + " [263.0, 125.0], \n", + " color='green', alpha=1, linewidth=1.3, zorder=100)\n", + " ax.scatter([3.0],\n", + " [3.0],\n", + " [263.0], \n", + " c='g', s=30, marker='D', alpha=1, linewidth=1.3, zorder=100)\n", + "\n", + " ax.margins(0,0,0)\n", + " fig.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\r\n", + "\r\n", + "\r\n", + "\r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + "\r\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "costplot3d(h, x, y)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Na powyższym wykresie poszukiwane minimum funkcji kosztu oznaczone jest czerwonym krzyżykiem.\n", + "\n", + "Możemy też zobaczyć rzut powyższego trójwymiarowego wykresu na płaszczyznę $(\\theta_0, \\theta_1)$ poniżej:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "def costplot2d(hypothesis, x, y, gradient_values=[], nohead=False):\n", + " fig = pl.figure(figsize=(16*.6, 9*.6))\n", + " ax = fig.add_subplot(111)\n", + " fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n", + " ax.set_xlabel(r'$\\theta_0$')\n", + " ax.set_ylabel(r'$\\theta_1$')\n", + " \n", + " j = lambda theta0, theta1: costfun(hypothesis, x, y)([theta0, theta1])\n", + " X = np.arange(-10, 10.1, 0.1)\n", + " Y = np.arange(-1, 4.1, 0.1)\n", + " X, Y = np.meshgrid(X, Y)\n", + " Z = np.array([[J(hypothesis, [theta0, theta1], x, y) \n", + " for theta0, theta1 in zip(xRow, yRow)] \n", + " for xRow, yRow in zip(X, Y)])\n", + " \n", + " N = range(0, 800, 20)\n", + " pl.contour(X, Y, Z, N, cmap='coolwarm', alpha=1)\n", + "\n", + " ax.scatter([-3.89578088], [1.19303364], c='r', s=80, marker='x',\n", + " label=r'minimum: $J(-3.90, 1.19) = 4.48$')\n", + " \n", + " if len(gradient_values) > 0:\n", + " prev_theta = gradient_values[0][1]\n", + " ax.scatter([prev_theta[0]], [prev_theta[1]],\n", + " c='g', s=30, marker='D', zorder=100)\n", + " for cost, theta in gradient_values[1:]:\n", + " dtheta = [theta[0] - prev_theta[0], theta[1] - prev_theta[1]]\n", + " ax.arrow(prev_theta[0], prev_theta[1], dtheta[0], dtheta[1], \n", + " color='green', \n", + " head_width=(0.0 if nohead else 0.1), \n", + " head_length=(0.0 if nohead else 0.2),\n", + " zorder=100)\n", + " prev_theta = theta\n", + " \n", + " return fig" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\r\n", + "\r\n", + "\r\n", + "\r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + "\r\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig = costplot2d(h, x, y)\n", + "legend(fig)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "### Cechy funkcji kosztu\n", + "* $J(\\theta)$ jest funkcją wypukłą\n", + "* $J(\\theta)$ posiada tylko jedno minimum lokalne" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "## Metoda gradientu prostego" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "### Metoda gradientu prostego\n", + "Metoda znajdowania minimów lokalnych." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "Idea:\n", + " * Zacznijmy od dowolnego $\\theta$.\n", + " * Zmieniajmy powoli $\\theta$ tak, aby zmniejszać $J(\\theta)$, aż w końcu znajdziemy minimum." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\r\n", + "\r\n", + "\r\n", + "\r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + "\r\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "costplot3d(h, x, y, show_gradient=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "# Przykładowe wartości kolejnych przybliżeń (sztuczne)\n", + "\n", + "gv = [[_, [3.0, 3.0]], [_, [2.6, 2.4]], [_, [2.2, 2.0]], [_, [1.6, 1.6]], [_, [0.4, 1.2]]]\n", + "\n", + "# Przygotowanie interaktywnego wykresu\n", + "\n", + "sliderSteps1 = widgets.IntSlider(min=0, max=3, step=1, value=0, description='kroki', width=300)\n", + "\n", + "def slide4(steps):\n", + " costplot2d(h, x, y, gradient_values=gv[:steps+1])" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\r\n", + "\r\n", + "\r\n", + "\r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + "\r\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "widgets.interact(slide4, steps=sliderSteps1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Metoda gradientu prostego\n", + "W każdym kroku będziemy aktualizować parametry $\\theta_j$:\n", + "\n", + "$$ \\theta_j := \\theta_j - \\alpha \\frac{\\partial}{\\partial \\theta_j} J(\\theta) \\quad \\mbox{ dla każdego } j $$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "Współczynnik $\\alpha$ nazywamy *długością kroku* lub *współczynnikiem szybkości uczenia* (_learning rate_)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "$$ \\begin{array}{rcl}\n", + "\\dfrac{\\partial}{\\partial \\theta_j} J(\\theta)\n", + " & = & \\dfrac{\\partial}{\\partial \\theta_j} \\dfrac{1}{2m} \\displaystyle\\sum_{i = 1}^{m} \\left( h_{\\theta} \\left( x^{(i)} \\right) - y^{(i)} \\right) ^2 \\\\\n", + " & = & 2 \\cdot \\dfrac{1}{2m} \\displaystyle\\sum_{i=1}^m \\left( h_\\theta \\left( x^{(i)} \\right) - y^{(i)} \\right) \\cdot \\dfrac{\\partial}{\\partial\\theta_j} \\left( h_\\theta \\left( x^{(i)} \\right) - y^{(i)} \\right) \\\\\n", + " & = & \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_\\theta \\left( x^{(i)} \\right) - y^{(i)} \\right) \\cdot \\dfrac{\\partial}{\\partial\\theta_j} \\left( \\displaystyle\\sum_{i=0}^n \\theta_i x_i^{(i)} - y^{(i)} \\right)\\\\\n", + " & = & \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_\\theta \\left( x^{(i)} \\right) -y^{(i)} \\right) x_j^{(i)} \\\\\n", + "\\end{array} $$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Czyli dla regresji liniowej jednej zmiennej:\n", + "\n", + "$$ h_\\theta(x) = \\theta_0 + \\theta_1x $$\n", + "\n", + "w każdym kroku będziemy aktualizować:\n", + "\n", + "$$\n", + "\\begin{array}{rcl}\n", + "\\theta_0 & := & \\theta_0 - \\alpha \\, \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_\\theta(x^{(i)})-y^{(i)} \\right) \\\\ \n", + "\\theta_1 & := & \\theta_1 - \\alpha \\, \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_\\theta(x^{(i)})-y^{(i)} \\right) x^{(i)}\\\\ \n", + "\\end{array}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "###### Uwaga!\n", + " * W każdym kroku aktualizujemy *jednocześnie* $\\theta_0$ i $\\theta_1$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + " * Kolejne kroki wykonujemy aż uzyskamy zbieżność" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Metoda gradientu prostego – implementacja" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "# Wyświetlanie macierzy w LaTeX-u\n", + "\n", + "def LatexMatrix(matrix):\n", + " ltx = r'\\left[\\begin{array}'\n", + " m, n = matrix.shape\n", + " ltx += '{' + (\"r\" * n) + '}'\n", + " for i in range(m):\n", + " ltx += r\" & \".join([('%.4f' % j.item()) for j in matrix[i]]) + r\" \\\\ \"\n", + " ltx += r'\\end{array}\\right]'\n", + " return ltx" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "outputs": [], + "source": [ + "def gradient_descent(h, cost_fun, theta, x, y, alpha, eps):\n", + " current_cost = cost_fun(h, theta, x, y)\n", + " log = [[current_cost, theta]] # log przechowuje wartości kosztu i parametrów\n", + " m = len(y)\n", + " while True:\n", + " new_theta = [\n", + " theta[0] - alpha/float(m) * sum(h(theta, x[i]) - y[i]\n", + " for i in range(m)), \n", + " theta[1] - alpha/float(m) * sum((h(theta, x[i]) - y[i]) * x[i]\n", + " for i in range(m))]\n", + " theta = new_theta # jednoczesna aktualizacja - używamy zmiennej tymaczasowej\n", + " try:\n", + " current_cost, prev_cost = cost_fun(h, theta, x, y), current_cost\n", + " except OverflowError:\n", + " break \n", + " if abs(prev_cost - current_cost) <= eps:\n", + " break \n", + " log.append([current_cost, theta])\n", + " return theta, log" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "data": { + "text/latex": [ + "$\\displaystyle \\large\\textrm{Wynik:}\\quad \\theta = \\left[\\begin{array}{r}0.0525 \\\\ 0.8015 \\\\ \\end{array}\\right] \\quad J(\\theta) = 6.0273 \\quad \\textrm{po 3604 iteracjach}$" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "best_theta, log = gradient_descent(h, J, [0.0, 0.0], x, y, alpha=0.00001, eps=0.0001)\n", + "\n", + "display(Math(r'\\large\\textrm{Wynik:}\\quad \\theta = ' + \n", + " LatexMatrix(np.matrix(best_theta).reshape(2,1)) + \n", + " (r' \\quad J(\\theta) = %.4f' % log[-1][0]) \n", + " + r' \\quad \\textrm{po %d iteracjach}' % len(log))) " + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "# Przygotowanie interaktywnego wykresu\n", + "\n", + "sliderSteps2 = widgets.IntSlider(min=0, max=500, step=1, value=1, description='kroki', width=300)\n", + "\n", + "def slide5(steps):\n", + " costplot2d(h, x, y, gradient_values=log[:steps+1], nohead=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "74447996dfc8460cb3230139ddb584c4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "interactive(children=(IntSlider(value=1, description='kroki', max=500), Button(description='Run Interact', sty…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "widgets.interact_manual(slide5, steps=sliderSteps2)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Współczynnik $\\alpha$ (długość kroku)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Tempo zbieżności metody gradientu prostego możemy regulować za pomocą parametru $\\alpha$, pamiętając, że:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + " * Jeżeli długość kroku jest zbyt mała, algorytm może działać zbyt wolno." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + " * Jeżeli długość kroku jest zbyt duża, algorytm może nie być zbieżny." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "## Predykcja wyników" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Zbudowaliśmy model, dzięki któremu wiemy, jaka jest zależność między dochodem firmy transportowej ($y$) a ludnością miasta ($x$).\n", + "\n", + "Wróćmy teraz do postawionego na początku wykładu pytania: jak przewidzieć dochód firmy transportowej w mieście o danej wielkości?\n", + "\n", + "Odpowiedź polega po prostu na zastosowaniu funkcji $h$ z wyznaczonymi w poprzednim kroku parametrami $\\theta$.\n", + "\n", + "Na przykład, jeżeli miasto ma $536\\,000$ ludności, to $x = 53.6$ (bo dane trenujące były wyrażone w dziesiątkach tysięcy mieszkańców, a $536\\,000 = 53.6 \\cdot 10\\,000$) i możemy użyć znalezionych parametrów $\\theta$, by wykonać następujące obliczenia:\n", + "$$ \\hat{y} \\, = \\, h_\\theta(x) \\, = \\, \\theta_0 + \\theta_1 \\, x \\, = \\, 0.0494 + 0.7591 \\cdot 53.6 \\, = \\, 40.7359 $$\n", + "\n", + "Czyli używając zdefiniowanych wcześniej funkcji:" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "43.01265462001823\n" + ] + } + ], + "source": [ + "example_x = 53.6\n", + "predicted_y = h(best_theta, example_x)\n", + "print(predicted_y) ## taki jest przewidywany dochód tej firmy transportowej w 536-tysięcznym mieście" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "## Ewaluacja modelu" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Jak ocenić jakość stworzonego przez nas modelu?\n", + "\n", + " * Trzeba sprawdzić, jak przewidywania modelu zgadzają się z oczekiwaniami!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Czy możemy w tym celu użyć danych, których użyliśmy do wytrenowania modelu?\n", + "**NIE!**\n", + "\n", + " * Istotą uczenia maszynowego jest budowanie modeli/algorytmów, które dają dobre przewidywania dla **nieznanych** danych – takich, z którymi algorytm nie miał jeszcze styczności! Nie sztuką jest przewidywać rzeczy, które juz sie zna.\n", + " * Dlatego testowanie/ewaluowanie modelu na zbiorze uczącym mija się z celem i jest nieprzydatne.\n", + " * Do ewaluacji modelu należy użyć oddzielnego zbioru danych.\n", + " * **Dane uczące i dane testowe zawsze powinny stanowić oddzielne zbiory!**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Na wykładzie *5. Dobre praktyki w uczeniu maszynowym* dowiesz się, jak podzielić posiadane dane na zbiór uczący i zbiór testowy.\n", + "\n", + "Tutaj, na razie, do ewaluacji użyjemy specjalnie przygotowanego zbioru testowego." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Jako metrykę ewaluacji wykorzystamy znany nam już błąd średniokwadratowy (RMSE):" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "def rmse(expected, predicted):\n", + " \"\"\"Błąd średniokwadratowy\"\"\"\n", + " m = len(expected)\n", + " if len(predicted) != m:\n", + " raise Exception('Wektory mają różne długości!')\n", + " return 1.0 / (2 * m) * sum((expected[i] - predicted[i])**2 for i in range(m))" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "5.283571974307861\n" + ] + } + ], + "source": [ + "# Wczytwanie danych testowych z pliku za pomocą numpy\n", + "\n", + "test_data = np.loadtxt('data01_test.csv', delimiter=',')\n", + "x_test = test_data[:, 0]\n", + "y_test = test_data[:, 1]\n", + "\n", + "# Obliczenie przewidywań modelu\n", + "y_pred = h(best_theta, x_test)\n", + "\n", + "# Obliczenie RMSE na zbiorze testowym (im mniejszy RMSE, tym lepiej!)\n", + "evaluation_result = rmse(y_test, y_pred)\n", + "\n", + "print(evaluation_result)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Otrzymana wartość mówi nam o tym, jak dobry jest stworzony przez nas model.\n", + "\n", + "W przypadku metryki RMSE im mniejsz wartość, tym lepiej.\n", + "\n", + "W ten sposób możemy np. porównywać różne modele." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "## Regresja liniowa wielu zmiennych" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Do przewidywania wartości $y$ możemy użyć więcej niż jednej cechy $x$:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "### Przykład – ceny mieszkań" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "ename": "FileNotFoundError", + "evalue": "[Errno 2] No such file or directory: 'data02_train.tsv'", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mcsv\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m \u001b[0mreader\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcsv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mreader\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mopen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'data02_train.tsv'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdelimiter\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'\\t'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 4\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mrow\u001b[0m \u001b[1;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mreader\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;36m10\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mi\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'data02_train.tsv'" + ] + } + ], + "source": [ + "import csv\n", + "\n", + "reader = csv.reader(open('data02_train.tsv'), delimiter='\\t')\n", + "for i, row in enumerate(list(reader)[:10]):\n", + " if i == 0:\n", + " print(' '.join(['{}: {:8}'.format('x' + str(j) if j > 0 else 'y ', entry)\n", + " for j, entry in enumerate(row)]))\n", + " else:\n", + " print(' '.join(['{:12}'.format(entry) for entry in row]))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "$$ x^{(2)} = ({\\rm \"False\"}, 3, 2, {\\rm \"Sołacz\"}, 62), \\quad x_3^{(2)} = 2 $$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Hipoteza" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "W naszym przypadku (wybraliśmy 5 cech):\n", + "\n", + "$$ h_\\theta(x) = \\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 + \\theta_3 x_3 + \\theta_4 x_4 + \\theta_5 x_5 $$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "W ogólności ($n$ cech):\n", + "\n", + "$$ h_\\theta(x) = \\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 + \\ldots + \\theta_n x_n $$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Jeżeli zdefiniujemy $x_0 = 1$, będziemy mogli powyższy wzór zapisać w bardziej kompaktowy sposób:\n", + "\n", + "$$\n", + "\\begin{array}{rcl}\n", + "h_\\theta(x)\n", + " & = & \\theta_0 x_0 + \\theta_1 x_1 + \\theta_2 x_2 + \\ldots + \\theta_n x_n \\\\\n", + " & = & \\displaystyle\\sum_{i=0}^{n} \\theta_i x_i \\\\\n", + " & = & \\theta^T \\, x \\\\\n", + " & = & x^T \\, \\theta \\\\\n", + "\\end{array}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Metoda gradientu prostego – notacja macierzowa" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Metoda gradientu prostego przyjmie bardzo elegancką formę, jeżeli do jej zapisu użyjemy wektorów i macierzy." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "$$\n", + "X=\\left[\\begin{array}{cc}\n", + "1 & \\left( \\vec x^{(1)} \\right)^T \\\\\n", + "1 & \\left( \\vec x^{(2)} \\right)^T \\\\\n", + "\\vdots & \\vdots\\\\\n", + "1 & \\left( \\vec x^{(m)} \\right)^T \\\\\n", + "\\end{array}\\right] \n", + "= \\left[\\begin{array}{cccc}\n", + "1 & x_1^{(1)} & \\cdots & x_n^{(1)} \\\\\n", + "1 & x_1^{(2)} & \\cdots & x_n^{(2)} \\\\\n", + "\\vdots & \\vdots & \\ddots & \\vdots\\\\\n", + "1 & x_1^{(m)} & \\cdots & x_n^{(m)} \\\\\n", + "\\end{array}\\right]\n", + "\\quad\n", + "\\vec{y} = \n", + "\\left[\\begin{array}{c}\n", + "y^{(1)}\\\\\n", + "y^{(2)}\\\\\n", + "\\vdots\\\\\n", + "y^{(m)}\\\\\n", + "\\end{array}\\right]\n", + "\\quad\n", + "\\theta = \\left[\\begin{array}{c}\n", + "\\theta_0\\\\\n", + "\\theta_1\\\\\n", + "\\vdots\\\\\n", + "\\theta_n\\\\\n", + "\\end{array}\\right]\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Wersje macierzowe funkcji rysowania wykresów punktowych oraz krzywej regresyjnej\n", + "\n", + "def hMx(theta, X):\n", + " return X * theta\n", + "\n", + "def regdotsMx(X, y): \n", + " fig = pl.figure(figsize=(16*.6, 9*.6))\n", + " ax = fig.add_subplot(111)\n", + " fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n", + " ax.scatter([X[:, 1]], [y], c='r', s=50, label='Dane')\n", + " \n", + " ax.set_xlabel('Populacja')\n", + " ax.set_ylabel('Zysk')\n", + " ax.margins(.05, .05)\n", + " pl.ylim(y.min() - 1, y.max() + 1)\n", + " pl.xlim(np.min(X[:, 1]) - 1, np.max(X[:, 1]) + 1)\n", + " return fig\n", + "\n", + "def reglineMx(fig, fun, theta, X):\n", + " ax = fig.axes[0]\n", + " x0, x1 = np.min(X[:, 1]), np.max(X[:, 1])\n", + " L = [x0, x1]\n", + " LX = np.matrix([1, x0, 1, x1]).reshape(2, 2)\n", + " ax.plot(L, fun(theta, LX), linewidth='2',\n", + " label=(r'$y={theta0:.2}{op}{theta1:.2}x$'.format(\n", + " theta0=float(theta[0][0]),\n", + " theta1=(float(theta[1][0]) if theta[1][0] >= 0 else float(-theta[1][0])),\n", + " op='+' if theta[1][0] >= 0 else '-')))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "# Wczytwanie danych z pliku za pomocą numpy – regresja liniowa wielu zmiennych – notacja macierzowa\n", + "\n", + "import pandas\n", + "\n", + "data = pandas.read_csv('data02_train.tsv', delimiter='\\t', usecols=['price', 'rooms', 'floor', 'sqrMetres'])\n", + "m, n_plus_1 = data.values.shape\n", + "n = n_plus_1 - 1\n", + "Xn = data.values[:, 1:].reshape(m, n)\n", + "\n", + "# Dodaj kolumnę jedynek do macierzy\n", + "XMx = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n_plus_1)\n", + "yMx = np.matrix(data.values[:, 0]).reshape(m, 1)\n", + "\n", + "print(XMx[:5])\n", + "print(XMx.shape)\n", + "\n", + "print()\n", + "\n", + "print(yMx[:5])\n", + "print(yMx.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "### Funkcja kosztu – notacja macierzowa" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "$$J(\\theta)=\\dfrac{1}{2|\\vec y|}\\left(X\\theta-\\vec{y}\\right)^T\\left(X\\theta-\\vec{y}\\right)$$ \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "from IPython.display import display, Math, Latex\n", + "\n", + "def JMx(theta,X,y):\n", + " \"\"\"Wersja macierzowa funkcji kosztu\"\"\"\n", + " m = len(y)\n", + " J = 1.0 / (2.0 * m) * ((X * theta - y) . T * ( X * theta - y))\n", + " return J.item()\n", + "\n", + "thetaMx = np.matrix([10, 90, -1, 2.5]).reshape(4, 1) \n", + "\n", + "cost = JMx(thetaMx,XMx,yMx) \n", + "display(Math(r'\\Large J(\\theta) = %.4f' % cost))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "### Gradient – notacja macierzowa" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "$$\\nabla J(\\theta) = \\frac{1}{|\\vec y|} X^T\\left(X\\theta-\\vec y\\right)$$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "from IPython.display import display, Math, Latex\n", + "\n", + "def dJMx(theta,X,y):\n", + " \"\"\"Wersja macierzowa gradientu funckji kosztu\"\"\"\n", + " return 1.0 / len(y) * (X.T * (X * theta - y)) \n", + "\n", + "thetaMx = np.matrix([10, 90, -1, 2.5]).reshape(4, 1) \n", + "\n", + "display(Math(r'\\large \\theta = ' + LatexMatrix(thetaMx) + \n", + " r'\\quad' + r'\\large \\nabla J(\\theta) = ' \n", + " + LatexMatrix(dJMx(thetaMx,XMx,yMx))))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "### Algorytm gradientu prostego – notacja macierzowa" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "$$ \\theta := \\theta - \\alpha \\, \\nabla J(\\theta) $$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "# Implementacja algorytmu gradientu prostego za pomocą numpy i macierzy\n", + "\n", + "def GDMx(fJ, fdJ, theta, X, y, alpha, eps):\n", + " current_cost = fJ(theta, X, y)\n", + " log = [[current_cost, theta]]\n", + " while True:\n", + " theta = theta - alpha * fdJ(theta, X, y) # implementacja wzoru\n", + " current_cost, prev_cost = fJ(theta, X, y), current_cost\n", + " if abs(prev_cost - current_cost) <= eps:\n", + " break\n", + " if current_cost > prev_cost:\n", + " print('Długość kroku (alpha) jest zbyt duża!')\n", + " break\n", + " log.append([current_cost, theta])\n", + " return theta, log\n", + "\n", + "thetaStartMx = np.zeros((n + 1, 1))\n", + "\n", + "# Zmieniamy wartości alpha (rozmiar kroku) oraz eps (kryterium stopu)\n", + "thetaBestMx, log = GDMx(JMx, dJMx, thetaStartMx, \n", + " XMx, yMx, alpha=0.0001, eps=0.1)\n", + "\n", + "######################################################################\n", + "display(Math(r'\\large\\textrm{Wynik:}\\quad \\theta = ' + \n", + " LatexMatrix(thetaBestMx) + \n", + " (r' \\quad J(\\theta) = %.4f' % log[-1][0]) \n", + " + r' \\quad \\textrm{po %d iteracjach}' % len(log))) " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "## Metoda gradientu prostego w praktyce" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "### Kryterium stopu" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Algorytm gradientu prostego polega na wykonywaniu określonych kroków w pętli. Pytanie brzmi: kiedy należy zatrzymać wykonywanie tej pętli?\n", + "\n", + "W każdej kolejnej iteracji wartość funkcji kosztu maleje o coraz mniejszą wartość.\n", + "Parametr `eps` określa, jaka wartość graniczna tej różnicy jest dla nas wystarczająca:\n", + "\n", + " * Im mniejsza wartość `eps`, tym dokładniejszy wynik, ale dłuższy czas działania algorytmu.\n", + " * Im większa wartość `eps`, tym krótszy czas działania algorytmu, ale mniej dokładny wynik." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "Na wykresie zobaczymy porównanie regresji dla różnych wartości `eps`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [], + "source": [ + "# Wczytwanie danych z pliku za pomocą numpy – wersja macierzowa\n", + "data = np.loadtxt('data01_train.csv', delimiter=',')\n", + "m, n_plus_1 = data.shape\n", + "n = n_plus_1 - 1\n", + "Xn = data[:, 0:n].reshape(m, n)\n", + "\n", + "# Dodaj kolumnę jedynek do macierzy\n", + "XMx = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n_plus_1)\n", + "yMx = np.matrix(data[:, 1]).reshape(m, 1)\n", + "\n", + "thetaStartMx = np.zeros((2, 1))\n", + "\n", + "fig = regdotsMx(XMx, yMx)\n", + "theta_e1, log1 = GDMx(JMx, dJMx, thetaStartMx, XMx, yMx, alpha=0.01, eps=0.01) # niebieska linia\n", + "reglineMx(fig, hMx, theta_e1, XMx)\n", + "theta_e2, log2 = GDMx(JMx, dJMx, thetaStartMx, XMx, yMx, alpha=0.01, eps=0.000001) # pomarańczowa linia\n", + "reglineMx(fig, hMx, theta_e2, XMx)\n", + "legend(fig)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [], + "source": [ + "display(Math(r'\\theta_{10^{-2}} = ' + LatexMatrix(theta_e1) +\n", + " r'\\quad\\theta_{10^{-6}} = ' + LatexMatrix(theta_e2)))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Długość kroku ($\\alpha$)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "# Jak zmienia się koszt w kolejnych krokach w zależności od alfa\n", + "\n", + "def costchangeplot(logs):\n", + " fig = pl.figure(figsize=(16*.6, 9*.6))\n", + " ax = fig.add_subplot(111)\n", + " fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n", + " ax.set_xlabel('krok')\n", + " ax.set_ylabel(r'$J(\\theta)$')\n", + "\n", + " X = np.arange(0, 500, 1)\n", + " Y = [logs[step][0] for step in X]\n", + " ax.plot(X, Y, linewidth='2', label=(r'$J(\\theta)$'))\n", + " return fig\n", + "\n", + "def slide7(alpha):\n", + " best_theta, log = gradient_descent(h, J, [0.0, 0.0], x, y, alpha=alpha, eps=0.0001)\n", + " fig = costchangeplot(log)\n", + " legend(fig)\n", + "\n", + "sliderAlpha1 = widgets.FloatSlider(min=0.01, max=0.03, step=0.001, value=0.02, description=r'$\\alpha$', width=300)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [], + "source": [ + "widgets.interact_manual(slide7, alpha=sliderAlpha1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "## Normalizacja danych" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Normalizacja danych to proces, który polega na dostosowaniu danych wejściowych w taki sposób, żeby ułatwić działanie algorytmowi gradientu prostego.\n", + "\n", + "Wyjaśnię to na przykladzie." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "Użyjemy danych z „Gratka flats challenge 2017”.\n", + "\n", + "Rozważmy model $h(x) = \\theta_0 + \\theta_1 x_1 + \\theta_2 x_2$, w którym cena mieszkania prognozowana jest na podstawie liczby pokoi $x_1$ i metrażu $x_2$:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "outputs": [], + "source": [ + "# Wczytanie danych przy pomocy biblioteki pandas\n", + "import pandas\n", + "alldata = pandas.read_csv('data_flats.tsv', header=0, sep='\\t',\n", + " usecols=['price', 'rooms', 'sqrMetres'])\n", + "alldata[:10]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "# Funkcja, która pokazuje wartości minimalne i maksymalne w macierzy X\n", + "\n", + "def show_mins_and_maxs(XMx):\n", + " mins = np.amin(XMx, axis=0).tolist()[0] # wartości minimalne\n", + " maxs = np.amax(XMx, axis=0).tolist()[0] # wartości maksymalne\n", + " for i, (xmin, xmax) in enumerate(zip(mins, maxs)):\n", + " display(Math(\n", + " r'${:.2F} \\leq x_{} \\leq {:.2F}$'.format(xmin, i, xmax)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "# Przygotowanie danych\n", + "\n", + "import numpy as np\n", + "\n", + "%matplotlib inline\n", + "\n", + "data2 = np.matrix(alldata[['rooms', 'sqrMetres', 'price']])\n", + "\n", + "m, n_plus_1 = data2.shape\n", + "n = n_plus_1 - 1\n", + "Xn = data2[:, 0:n]\n", + "\n", + "XMx2 = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n_plus_1)\n", + "yMx2 = np.matrix(data2[:, -1]).reshape(m, 1) / 1000.0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Cechy w danych treningowych przyjmują wartości z zakresu:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "outputs": [], + "source": [ + "show_mins_and_maxs(XMx2)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Jak widzimy, $x_2$ przyjmuje wartości dużo większe niż $x_1$.\n", + "Powoduje to, że wykres funkcji kosztu jest bardzo „spłaszczony” wzdłuż jednej z osi:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "def contour_plot(X, y, rescale=10**8):\n", + " theta0_vals = np.linspace(-100000, 100000, 100)\n", + " theta1_vals = np.linspace(-100000, 100000, 100)\n", + "\n", + " J_vals = np.zeros(shape=(theta0_vals.size, theta1_vals.size))\n", + " for t1, element in enumerate(theta0_vals):\n", + " for t2, element2 in enumerate(theta1_vals):\n", + " thetaT = np.matrix([1.0, element, element2]).reshape(3,1)\n", + " J_vals[t1, t2] = JMx(thetaT, X, y) / rescale\n", + " \n", + " pl.figure()\n", + " pl.contour(theta0_vals, theta1_vals, J_vals.T, np.logspace(-2, 3, 20))\n", + " pl.xlabel(r'$\\theta_1$')\n", + " pl.ylabel(r'$\\theta_2$')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "outputs": [], + "source": [ + "contour_plot(XMx2, yMx2, rescale=10**10)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Jeżeli funkcja kosztu ma kształt taki, jak na powyższym wykresie, to łatwo sobie wyobrazić, że znalezienie minimum lokalnego przy użyciu metody gradientu prostego musi stanowć nie lada wyzwanie: algorytm szybko znajdzie „rynnę”, ale „zjazd” wzdłuż „rynny” w poszukiwaniu minimum będzie odbywał się bardzo powoli.\n", + "\n", + "Jak temu zaradzić?\n", + "\n", + "Spróbujemy przekształcić dane tak, żeby funkcja kosztu miała „ładny”, regularny kształt." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Skalowanie" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "Będziemy dążyć do tego, żeby każda z cech przyjmowała wartości w podobnym zakresie.\n", + "\n", + "W tym celu przeskalujemy wartości każdej z cech, dzieląc je przez wartość maksymalną:\n", + "\n", + "$$ \\hat{x_i}^{(j)} := \\frac{x_i^{(j)}}{\\max_j x_i^{(j)}} $$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [], + "source": [ + "XMx2_scaled = XMx2 / np.amax(XMx2, axis=0)\n", + "\n", + "show_mins_and_maxs(XMx2_scaled)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [], + "source": [ + "contour_plot(XMx2_scaled, yMx2)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Normalizacja średniej" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "Będziemy dążyć do tego, żeby dodatkowo średnia wartość każdej z cech była w okolicach $0$.\n", + "\n", + "W tym celu oprócz przeskalowania odejmiemy wartość średniej od wartości każdej z cech:\n", + "\n", + "$$ \\hat{x_i}^{(j)} := \\frac{x_i^{(j)} - \\mu_i}{\\max_j x_i^{(j)}} $$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [], + "source": [ + "XMx2_norm = (XMx2 - np.mean(XMx2, axis=0)) / np.amax(XMx2, axis=0)\n", + "\n", + "show_mins_and_maxs(XMx2_norm)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [], + "source": [ + "contour_plot(XMx2_norm, yMx2)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Teraz funkcja kosztu ma wykres o bardzo regularnym kształcie – algorytm gradientu prostego zastosowany w takim przypadku bardzo szybko znajdzie minimum funkcji kosztu." + ] + } + ], + "metadata": { + "celltoolbar": "Slideshow", + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + }, + "livereveal": { + "start_slideshow_at": "selected", + "theme": "amu" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/data01_test.csv b/data01_test.csv new file mode 100644 index 0000000..72c4017 --- /dev/null +++ b/data01_test.csv @@ -0,0 +1,17 @@ +5.7292,0.47953 +5.1884,0.20421 +6.3557,0.67861 +9.7687,7.5435 +6.5159,5.3436 +8.5172,4.2415 +9.1802,6.7981 +6.002,0.92695 +5.5204,0.152 +5.0594,2.8214 +5.7077,1.8451 +7.6366,4.2959 +5.8707,7.2029 +5.3054,1.9869 +8.2934,0.14454 +13.394,9.0551 +5.4369,0.61705 diff --git a/data01_train.csv b/data01_train.csv new file mode 100644 index 0000000..ab04fdf --- /dev/null +++ b/data01_train.csv @@ -0,0 +1,80 @@ +6.1101,17.592 +5.5277,9.1302 +8.5186,13.662 +7.0032,11.854 +5.8598,6.8233 +8.3829,11.886 +7.4764,4.3483 +8.5781,12 +6.4862,6.5987 +5.0546,3.8166 +5.7107,3.2522 +14.164,15.505 +5.734,3.1551 +8.4084,7.2258 +5.6407,0.71618 +5.3794,3.5129 +6.3654,5.3048 +5.1301,0.56077 +6.4296,3.6518 +7.0708,5.3893 +6.1891,3.1386 +20.27,21.767 +5.4901,4.263 +6.3261,5.1875 +5.5649,3.0825 +18.945,22.638 +12.828,13.501 +10.957,7.0467 +13.176,14.692 +22.203,24.147 +5.2524,-1.22 +6.5894,5.9966 +9.2482,12.134 +5.8918,1.8495 +8.2111,6.5426 +7.9334,4.5623 +8.0959,4.1164 +5.6063,3.3928 +12.836,10.117 +6.3534,5.4974 +5.4069,0.55657 +6.8825,3.9115 +11.708,5.3854 +5.7737,2.4406 +7.8247,6.7318 +7.0931,1.0463 +5.0702,5.1337 +5.8014,1.844 +11.7,8.0043 +5.5416,1.0179 +7.5402,6.7504 +5.3077,1.8396 +7.4239,4.2885 +7.6031,4.9981 +6.3328,1.4233 +6.3589,-1.4211 +6.2742,2.4756 +5.6397,4.6042 +9.3102,3.9624 +9.4536,5.4141 +8.8254,5.1694 +5.1793,-0.74279 +21.279,17.929 +14.908,12.054 +18.959,17.054 +7.2182,4.8852 +8.2951,5.7442 +10.236,7.7754 +5.4994,1.0173 +20.341,20.992 +10.136,6.6799 +7.3345,4.0259 +6.0062,1.2784 +7.2259,3.3411 +5.0269,-2.6807 +6.5479,0.29678 +7.5386,3.8845 +5.0365,5.7014 +10.274,6.7526 +5.1077,2.0576