{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Uczenie maszynowe\n", "# 2. Regresja liniowa – część 1" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## 2.1. Funkcja kosztu" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Zadanie\n", "Znając $x$ – ludność miasta, należy przewidzieć $y$ – dochód firmy transportowej.\n", "\n", "(Dane pochodzą z kursu „Machine Learning”, Andrew Ng, Coursera)." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "**Uwaga**: Ponieważ ten przykład ma być tak prosty, jak to tylko możliwe, ludność miasta podana jest w dziesiątkach tysięcy mieszkańców, a dochód firmy w dziesiątkach tysięcy dolarów. Dzięki temu funkcja kosztu obliczona w dalszej części wykładu będzie osiągać wartości, które łatwo przedstawić na wykresie." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import ipywidgets as widgets\n", "\n", "%matplotlib inline\n", "%config InlineBackend.figure_format = \"svg\"\n", "\n", "from IPython.display import display, Math, Latex" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Dane" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " x y\n", "0 6.1101 17.59200\n", "1 5.5277 9.13020\n", "2 8.5186 13.66200\n", "3 7.0032 11.85400\n", "4 5.8598 6.82330\n", ".. ... ...\n", "75 6.5479 0.29678\n", "76 7.5386 3.88450\n", "77 5.0365 5.70140\n", "78 10.2740 6.75260\n", "79 5.1077 2.05760\n", "\n", "[80 rows x 2 columns]\n" ] } ], "source": [ "import pandas as pd\n", "\n", "data = pd.read_csv(\"data1_train.csv\", names=[\"x\", \"y\"])\n", "print(data)\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "x = data[\"x\"].to_numpy()\n", "y = data[\"y\"].to_numpy()\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Hipoteza i parametry modelu" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Jak przewidzieć $y$ na podstawie danego $x$? W celu odpowiedzi na to pytanie będziemy starać się znaleźć taką funkcję $h(x)$, która będzie najlepiej obrazować zależność między $x$ a $y$, tj. $y \\sim h(x)$.\n", "\n", "Zacznijmy od najprostszego przypadku, kiedy $h(x)$ jest po prostu funkcją liniową. Ogólny wzór funkcji liniowej to" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "$$ h(x) = a \\, x + b $$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Pamiętajmy jednak, że współczynniki $a$ i $b$ nie są w tej chwili dane z góry – naszym zadaniem właśnie będzie znalezienie takich ich wartości, żeby $h(x)$ było „możliwie jak najbliżej” $y$ (co właściwie oznacza to sformułowanie, wyjaśnię potem)." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "Poszukiwaną funkcję $h$ będziemy nazywać **funkcją hipotezy**, a jej współczynniki – **parametrami modelu**." 
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "W teorii uczenia maszynowego parametry modelu oznacza się na ogół grecką literą $\\theta$ z odpowiednimi indeksami, dlatego powyższy wzór opisujący liniową funkcję hipotezy zapiszemy jako\n", "$$ h(x) = \\theta_0 + \\theta_1 x $$\n", "\n", "**Parametry modelu** tworzą wektor, który oznaczymy po prostu przez $\\theta$:" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "$$ \\theta = \\left[\\begin{array}{c}\\theta_0\\\\ \\theta_1\\end{array}\\right] $$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Żeby podkreślić fakt, że funkcja hipotezy zależy od parametrów modelu, będziemy pisać $h_\\theta$ zamiast $h$:" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x $$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Przyjrzyjmy się teraz, jak wyglądają dane, które mamy modelować:" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Na poniższym wykresie możesz spróbować ręcznie dopasować parametry modelu $\\theta_0$ i $\\theta_1$ tak, aby jak najlepiej modelowały zależność między $x$ a $y$:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Funkcje rysujące wykres kropkowy oraz prostą regresyjną\n", "\n", "\n", "def regdots(x, y):\n", " fig = plt.figure(figsize=(16 * 0.6, 9 * 0.6))\n", " ax = fig.add_subplot(111)\n", " fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n", " ax.scatter(x, y, c=\"r\", label=\"Dane\")\n", "\n", " ax.set_xlabel(\"Wielkość miejscowości\")\n", " ax.set_ylabel(\"Dochód firmy\")\n", " ax.margins(0.05, 0.05)\n", " plt.ylim(min(y) - 1, max(y) + 1)\n", " plt.xlim(min(x) - 1, max(x) + 1)\n", " return fig\n", "\n", "\n", "def regline(fig, fun, theta, x):\n", " ax = fig.axes[0]\n", " x0, x1 = min(x), max(x)\n", " X = [x0, x1]\n", " Y = [fun(theta, x) for x in X]\n", " ax.plot(\n", " X,\n", " Y,\n", " linewidth=\"2\",\n", " label=(\n", " r\"$y={theta0}{op}{theta1}x$\".format(\n", " theta0=theta[0],\n", " theta1=(theta[1] if theta[1] >= 0 else -theta[1]),\n", " op=\"+\" if theta[1] >= 0 else \"-\",\n", " )\n", " ),\n", " )\n", "\n", "\n", "def legend(fig):\n", " ax = fig.axes[0]\n", " handles, labels = ax.get_legend_handles_labels()\n", " # try-except block is a fix for a bug in Poly3DCollection\n", " try:\n", " fig.legend(handles, labels, fontsize=\"15\", loc=\"lower right\")\n", " except AttributeError:\n", " pass\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2022-10-28T11:24:23.297410\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.6.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig = regdots(x, y)\n", "legend(fig)\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Hipoteza: funkcja liniowa jednej zmiennej\n", "\n", "\n", "def h(theta, x):\n", " return theta[0] + theta[1] * x\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Przygotowanie interaktywnego wykresu\n", "\n", "sliderTheta01 = widgets.FloatSlider(\n", " min=-10, max=10, step=0.1, value=0, description=r\"$\\theta_0$\", width=300\n", ")\n", "sliderTheta11 = widgets.FloatSlider(\n", " min=-5, max=5, step=0.1, value=0, description=r\"$\\theta_1$\", width=300\n", ")\n", "\n", "\n", "def slide1(theta0, theta1):\n", " fig = regdots(x, y)\n", " regline(fig, h, [theta0, theta1], x)\n", " legend(fig)\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0672501afb5c4fca818b52c3fdd274fd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "interactive(children=(FloatSlider(value=0.0, description='$\\\\theta_0$', max=10.0, min=-10.0), FloatSlider(valu…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "widgets.interact_manual(slide1, theta0=sliderTheta01, theta1=sliderTheta11)\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Skąd wiadomo, że przewidywania modelu (wartości funkcji $h(x)$) zgadzaja się z obserwacjami (wartości $y$)?\n", "\n", "Aby to zmierzyć wprowadzimy pojęcie funkcji kosztu." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Funkcja kosztu" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Funkcję kosztu zdefiniujemy w taki sposób, żeby odzwierciedlała ona różnicę między przewidywaniami modelu a obserwacjami.\n", "\n", "Jedną z możliwosci jest zdefiniowanie funkcji kosztu jako wartość **błędu średniokwadratowego** (metoda najmniejszych kwadratów, *mean-square error, MSE*).\n", "\n", "My zdefiniujemy funkcję kosztu jako *połowę* błędu średniokwadratowego w celu ułatwienia późniejszych obliczeń (obliczenie pochodnej funkcji kosztu w dalszej części wykładu). Możemy tak zrobić, ponieważ $\\frac{1}{2}$ jest stałą, a pomnożenie przez stałą nie wpływa na przebieg zmienności funkcji." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "$$ J(\\theta) \\, = \\, \\frac{1}{2m} \\sum_{i = 1}^{m} \\left( h_{\\theta} \\left( x^{(i)} \\right) - y^{(i)} \\right) ^2 $$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "gdzie $m$ jest liczbą wszystkich przykładów (obserwacji), czyli wielkością zbioru danych uczących.\n", "\n", "W powyższym wzorze sumujemy kwadraty różnic między przewidywaniami modelu ($h_\\theta \\left( x^{(i)} \\right)$) a obserwacjami ($y^{(i)}$) po wszystkich przykładach $i$." 
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Teraz nasze zadanie sprowadza się do tego, że będziemy szukać takich parametrów $\\theta = \\left[\\begin{array}{c}\\theta_0\\\\ \\theta_1\\end{array}\\right]$, które minimalizują fukcję kosztu $J(\\theta)$:" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "$$ \\hat\\theta = \\mathop{\\arg\\min}_{\\theta} J(\\theta) $$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "$$ \\theta \\in \\mathbb{R}^2, \\quad J \\colon \\mathbb{R}^2 \\to \\mathbb{R} $$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Proszę zwrócić uwagę, że dziedziną funkcji kosztu jest zbiór wszystkich możliwych wartości parametrów $\\theta$." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "$$ J(\\theta_0, \\theta_1) \\, = \\, \\frac{1}{2m} \\sum_{i = 1}^{m} \\left( \\theta_0 + \\theta_1 x^{(i)} - y^{(i)} \\right) ^2 $$" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "def J(h, theta, x, y):\n", " \"\"\"Funkcja kosztu\"\"\"\n", " m = len(y)\n", " return 1.0 / (2 * m) * sum((h(theta, x[i]) - y[i]) ** 2 for i in range(m))\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "# Oblicz wartość funkcji kosztu i pokaż na wykresie\n", "\n", "\n", "def regline2(fig, fun, theta, xx, yy):\n", " \"\"\"Rysuj regresję liniową\"\"\"\n", " ax = fig.axes[0]\n", " x0, x1 = min(xx), max(xx)\n", " X = [x0, x1]\n", " Y = [fun(theta, x) for x in X]\n", " cost = J(fun, theta, xx, yy)\n", " ax.plot(\n", " X,\n", " Y,\n", " linewidth=\"2\",\n", " label=(\n", " r\"$y={theta0}{op}{theta1}x, \\; J(\\theta)={cost:.3}$\".format(\n", " theta0=theta[0],\n", " theta1=(theta[1] if theta[1] >= 0 else -theta[1]),\n", " op=\"+\" if theta[1] >= 0 else \"-\",\n", " cost=cost,\n", " )\n", " ),\n", " )\n", "\n", "\n", "sliderTheta02 = widgets.FloatSlider(\n", " min=-10, max=10, step=0.1, value=0, description=r\"$\\theta_0$\", width=300\n", ")\n", "sliderTheta12 = widgets.FloatSlider(\n", " min=-5, max=5, step=0.1, value=0, description=r\"$\\theta_1$\", width=300\n", ")\n", "\n", "\n", "def slide2(theta0, theta1):\n", " fig = regdots(x, y)\n", " regline2(fig, h, [theta0, theta1], x, y)\n", " legend(fig)\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Poniższy interaktywny wykres pokazuje wartość funkcji kosztu $J(\\theta)$. Czy teraz łatwiej jest dobrać parametry modelu?" 
] }, { "cell_type": "code", "execution_count": 11, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0e9e869c222b4f928c8b7c358ccdd973", "version_major": 2, "version_minor": 0 }, "text/plain": [ "interactive(children=(FloatSlider(value=0.0, description='$\\\\theta_0$', max=10.0, min=-10.0), FloatSlider(valu…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "widgets.interact_manual(slide2, theta0=sliderTheta02, theta1=sliderTheta12)\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Funkcja kosztu jako funkcja zmiennej $\\theta$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Funkcja kosztu zdefiniowana jako MSE jest funkcją zmiennej wektorowej $\\theta$, czyli funkcją dwóch zmiennych rzeczywistych: $\\theta_0$ i $\\theta_1$.\n", " \n", "Zobaczmy, jak wygląda jej wykres." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Wykres funkcji kosztu dla ustalonego theta_1=1.0\n", "\n", "\n", "def costfun(fun, x, y):\n", " return lambda theta: J(fun, theta, x, y)\n", "\n", "\n", "def costplot(hypothesis, x, y, theta1=1.0):\n", " fig = plt.figure(figsize=(16 * 0.6, 9 * 0.6))\n", " ax = fig.add_subplot(111)\n", " fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n", " ax.set_xlabel(r\"$\\theta_0$\")\n", " ax.set_ylabel(r\"$J(\\theta)$\")\n", " j = costfun(hypothesis, x, y)\n", " fun = lambda theta0: j([theta0, theta1])\n", " X = np.arange(-10, 10, 0.1)\n", " Y = [fun(x) for x in X]\n", " ax.plot(\n", " X, Y, linewidth=\"2\", label=(r\"$J(\\theta_0, {theta1})$\".format(theta1=theta1))\n", " )\n", " return fig\n", "\n", "\n", "def slide3(theta1):\n", " fig = costplot(h, x, y, theta1)\n", " legend(fig)\n", "\n", "\n", "sliderTheta13 = widgets.FloatSlider(\n", " min=-5, max=5, step=0.1, value=1.0, description=r\"$\\theta_1$\", width=300\n", ")\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "bbfe8cfd8e274970904cd2c9e12fce6f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "interactive(children=(FloatSlider(value=1.0, description='$\\\\theta_1$', max=5.0, min=-5.0), Button(description…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "widgets.interact_manual(slide3, theta1=sliderTheta13)\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Wykres funkcji kosztu względem theta_0 i theta_1\n", "\n", "from mpl_toolkits.mplot3d import Axes3D\n", "import pylab\n", "\n", "%matplotlib inline\n", "\n", "def costplot3d(hypothesis, x, y, show_gradient=False):\n", " fig = plt.figure(figsize=(16*.6, 9*.6))\n", " ax = fig.add_subplot(111, projection='3d')\n", " fig.subplots_adjust(left=0.0, right=1.0, bottom=0.0, top=1.0)\n", " ax.set_xlabel(r'$\\theta_0$')\n", " ax.set_ylabel(r'$\\theta_1$')\n", " ax.set_zlabel(r'$J(\\theta)$')\n", " \n", " j = lambda theta0, theta1: costfun(hypothesis, x, y)([theta0, theta1])\n", " X = 
 "    X = np.arange(-10, 10.1, 0.1)\n", "    Y = np.arange(-1, 4.1, 0.1)\n", "    X, Y = np.meshgrid(X, Y)\n", "    Z = np.array([[J(hypothesis, [theta0, theta1], x, y)\n", "                   for theta0, theta1 in zip(xRow, yRow)]\n", "                  for xRow, yRow in zip(X, Y)])\n", "    \n", "    ax.plot_surface(X, Y, Z, rstride=2, cstride=8, linewidth=0.5,\n", "                    alpha=0.5, cmap='jet', zorder=0,\n", "                    label=r\"$J(\\theta)$\")\n", "    ax.view_init(elev=20., azim=-150)\n", "\n", "    ax.set_xlim3d(-10, 10);\n", "    ax.set_ylim3d(-1, 4);\n", "    ax.set_zlim3d(-100, 800);\n", "\n", "    N = range(0, 800, 20)\n", "    plt.contour(X, Y, Z, N, zdir='z', offset=-100, cmap='coolwarm', alpha=1)\n", "    \n", "    ax.plot([-3.89578088] * 2,\n", "            [ 1.19303364] * 2,\n", "            [-100, 4.47697137598],\n", "            color='red', alpha=1, linewidth=1.3, zorder=100, linestyle='dashed',\n", "            label=r'minimum: $J(-3.90, 1.19) = 4.48$')\n", "    ax.scatter([-3.89578088] * 2,\n", "               [ 1.19303364] * 2,\n", "               [-100, 4.47697137598],\n", "               c='r', s=80, marker='x', alpha=1, linewidth=1.3, zorder=100,\n", "               label=r'minimum: $J(-3.90, 1.19) = 4.48$')\n", "    \n", "    if show_gradient:\n", "        ax.plot([3.0, 1.1],\n", "                [3.0, 2.4],\n", "                [263.0, 125.0],\n", "                color='green', alpha=1, linewidth=1.3, zorder=100)\n", "        ax.scatter([3.0],\n", "                   [3.0],\n", "                   [263.0],\n", "                   c='g', s=30, marker='D', alpha=1, linewidth=1.3, zorder=100)\n", "\n", "    ax.margins(0,0,0)\n", "    fig.tight_layout()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "image/svg+xml": [ "[SVG output removed: 3D surface plot of the cost function J(theta) over (theta_0, theta_1), with the minimum J(-3.90, 1.19) = 4.48 marked]" ], "text/plain": [ "" ] },
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "costplot3d(h, x, y)\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Na powyższym wykresie poszukiwane minimum funkcji kosztu oznaczone jest czerwonym krzyżykiem.\n", "\n", "Możemy też zobaczyć rzut powyższego trójwymiarowego wykresu na płaszczyznę $(\\theta_0, \\theta_1)$ poniżej:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "def costplot2d(hypothesis, x, y, gradient_values=[], nohead=False):\n", " fig = plt.figure(figsize=(16 * 0.6, 9 * 0.6))\n", " ax = fig.add_subplot(111)\n", " fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n", " ax.set_xlabel(r\"$\\theta_0$\")\n", " ax.set_ylabel(r\"$\\theta_1$\")\n", "\n", " j = lambda theta0, theta1: costfun(hypothesis, x, y)([theta0, theta1])\n", " X = np.arange(-10, 10.1, 0.1)\n", " Y = np.arange(-1, 4.1, 0.1)\n", " X, Y = np.meshgrid(X, Y)\n", " Z = np.array(\n", " [\n", " [\n", " J(hypothesis, [theta0, theta1], x, y)\n", " for theta0, theta1 in zip(xRow, yRow)\n", " ]\n", " for xRow, yRow in zip(X, Y)\n", " ]\n", " )\n", "\n", " N = range(0, 800, 20)\n", " plt.contour(X, Y, Z, N, cmap=\"coolwarm\", alpha=1)\n", "\n", " ax.scatter(\n", " [-3.89578088],\n", " [1.19303364],\n", " c=\"r\",\n", " s=80,\n", " marker=\"x\",\n", " label=r\"minimum: $J(-3.90, 1.19) = 4.48$\",\n", " )\n", "\n", " if len(gradient_values) > 0:\n", " prev_theta = gradient_values[0][1]\n", " ax.scatter(\n", " [prev_theta[0]], [prev_theta[1]], c=\"g\", s=30, marker=\"D\", zorder=100\n", " )\n", " for cost, theta in gradient_values[1:]:\n", " dtheta = [theta[0] - prev_theta[0], theta[1] - prev_theta[1]]\n", " ax.arrow(\n", " prev_theta[0],\n", " prev_theta[1],\n", " dtheta[0],\n", " dtheta[1],\n", " color=\"green\",\n", " head_width=(0.0 if nohead else 0.1),\n", " head_length=(0.0 if nohead else 0.2),\n", " zorder=100,\n", " )\n", " prev_theta = theta\n", "\n", " return fig\n" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2022-10-28T11:24:28.270713\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.6.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig = costplot2d(h, x, y)\n", "legend(fig)\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Cechy funkcji kosztu\n", "Funkcja kosztu $J(\\theta)$ zdefiniowana powyżej jest funkcją wypukłą, dlatego posiada tylko jedno minimum lokalne." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## 2.2. Metoda gradientu prostego" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Metoda gradientu prostego\n", "Metoda znajdowania minimów lokalnych." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "Idea:\n", " * Zacznijmy od dowolnego $\\theta$.\n", " * Zmieniajmy powoli $\\theta$ tak, aby zmniejszać $J(\\theta)$, aż w końcu znajdziemy minimum." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2022-10-28T11:24:32.960329\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.6.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "costplot3d(h, x, y, show_gradient=True)\n" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Przykładowe wartości kolejnych przybliżeń (sztuczne)\n", "\n", "gv = [\n", " [_, [3.0, 3.0]],\n", " [_, [2.6, 2.4]],\n", " [_, [2.2, 2.0]],\n", " [_, [1.6, 1.6]],\n", " [_, [0.4, 1.2]],\n", "]\n", "\n", "# Przygotowanie interaktywnego wykresu\n", "\n", "sliderSteps1 = widgets.IntSlider(\n", " min=0, max=3, step=1, value=0, description=\"kroki\", width=300\n", ")\n", "\n", "\n", "def slide4(steps):\n", " costplot2d(h, x, y, gradient_values=gv[: steps + 1])\n" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4b3e33d96ad447e08dace02f0ea543d2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "interactive(children=(IntSlider(value=0, description='kroki', max=3), Output()), _dom_classes=('widget-interac…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "widgets.interact(slide4, steps=sliderSteps1)\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Metoda gradientu prostego\n", "W każdym kroku będziemy aktualizować parametry $\\theta_j$:\n", "\n", "$$ \\theta_j := \\theta_j - \\alpha \\frac{\\partial}{\\partial \\theta_j} J(\\theta) $$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "Współczynnik $\\alpha$ nazywamy **długością kroku** lub **współczynnikiem szybkości uczenia** (*learning rate*)." 
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "$$ \\begin{array}{rcl}\n", "\\dfrac{\\partial}{\\partial \\theta_j} J(\\theta)\n", " & = & \\dfrac{\\partial}{\\partial \\theta_j} \\dfrac{1}{2m} \\displaystyle\\sum_{i = 1}^{m} \\left( h_{\\theta} \\left( x^{(i)} \\right) - y^{(i)} \\right) ^2 \\\\\n", " & = & 2 \\cdot \\dfrac{1}{2m} \\displaystyle\\sum_{i=1}^m \\left( h_\\theta \\left( x^{(i)} \\right) - y^{(i)} \\right) \\cdot \\dfrac{\\partial}{\\partial\\theta_j} \\left( h_\\theta \\left( x^{(i)} \\right) - y^{(i)} \\right) \\\\\n", " & = & \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_\\theta \\left( x^{(i)} \\right) - y^{(i)} \\right) \\cdot \\dfrac{\\partial}{\\partial\\theta_j} \\left( \\displaystyle\\sum_{i=0}^n \\theta_i x_i^{(i)} - y^{(i)} \\right)\\\\\n", " & = & \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_\\theta \\left( x^{(i)} \\right) -y^{(i)} \\right) x_j^{(i)} \\\\\n", "\\end{array} $$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Czyli dla regresji liniowej jednej zmiennej:\n", "\n", "$$ h_\\theta(x) = \\theta_0 + \\theta_1x $$\n", "\n", "w każdym kroku będziemy aktualizować:\n", "\n", "$$\n", "\\begin{array}{rcl}\n", "\\theta_0 & := & \\theta_0 - \\alpha \\, \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_\\theta(x^{(i)})-y^{(i)} \\right) \\\\ \n", "\\theta_1 & := & \\theta_1 - \\alpha \\, \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_\\theta(x^{(i)})-y^{(i)} \\right) x^{(i)}\\\\ \n", "\\end{array}\n", "$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "###### Uwaga!\n", " * W każdym kroku aktualizujemy *jednocześnie* $\\theta_0$ i $\\theta_1$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ " * Kolejne kroki wykonujemy aż uzyskamy zbieżność" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Metoda gradientu prostego – implementacja" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Wyświetlanie macierzy w LaTeX-u\n", "\n", "\n", "def LatexMatrix(matrix):\n", " ltx = r\"\\left[\\begin{array}\"\n", " m, n = matrix.shape\n", " ltx += \"{\" + (\"r\" * n) + \"}\"\n", " for i in range(m):\n", " ltx += r\" & \".join([(\"%.4f\" % j.item()) for j in matrix[i]]) + r\" \\\\ \"\n", " ltx += r\"\\end{array}\\right]\"\n", " return ltx\n" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "def gradient_descent(h, cost_fun, theta, x, y, alpha, eps):\n", " current_cost = cost_fun(h, theta, x, y)\n", " history = [\n", " [current_cost, theta]\n", " ] # zapiszmy wartości kosztu i parametrów, by potem zrobić wykres\n", " m = len(y)\n", " while True:\n", " new_theta = [\n", " theta[0] - alpha / float(m) * sum(h(theta, x[i]) - y[i] for i in range(m)),\n", " theta[1]\n", " - alpha / float(m) * sum((h(theta, x[i]) - y[i]) * x[i] for i in range(m)),\n", " ]\n", " theta = new_theta # jednoczesna aktualizacja - używamy zmiennej tymczasowej\n", " try:\n", " prev_cost = current_cost\n", " current_cost = cost_fun(h, theta, x, y)\n", " except OverflowError:\n", " break\n", " if abs(prev_cost - current_cost) <= eps:\n", " break\n", " history.append([current_cost, theta])\n", " return theta, history\n" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "best_theta, history = gradient_descent(h, J, [0.0, 0.0], x, y, alpha=0.05, eps=0.0001)\n", "\n", "display(\n", " Math(\n", " r\"\\large\\textrm{Wynik:}\\quad \\theta = \"\n", " + LatexMatrix(np.matrix(best_theta).reshape(2, 1))\n", " + (r\" \\quad J(\\theta) = %.4f\" % history[-1][0])\n", " + r\" \\quad \\textrm{po %d iteracjach}\" % len(history)\n", " )\n", ")\n" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Przygotowanie interaktywnego wykresu\n", "\n", "sliderSteps2 = widgets.IntSlider(\n", " min=0, max=500, step=1, value=1, description=\"kroki\", width=300\n", ")\n", "\n", "\n", "def slide5(steps):\n", " costplot2d(h, x, y, gradient_values=history[: steps + 1], nohead=True)\n" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "scrolled": true, "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e2f6797395434b26b93f0392c7b58207", "version_major": 2, "version_minor": 0 }, "text/plain": [ "interactive(children=(IntSlider(value=1, description='kroki', max=500), Button(description='Run Interact', sty…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "widgets.interact_manual(slide5, steps=sliderSteps2)\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Współczynnik szybkości uczenia $\\alpha$ (długość kroku)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Tempo zbieżności metody gradientu prostego możemy regulować za pomocą parametru $\\alpha$, pamiętając, że:" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ " * Jeżeli długość kroku jest zbyt mała, algorytm może działać zbyt wolno." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ " * Jeżeli długość kroku jest zbyt duża, algorytm może nie być zbieżny." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## 2.3. 
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## 2.3. Making predictions" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "We have built a model that tells us how the profit of a transport company ($y$) depends on the population of a city ($x$).\n", "\n", "Let us now return to the question posed at the beginning of the lecture: how do we predict the profit of a transport company in a city of a given size?\n", "\n", "The answer is simply to apply the function $h$ with the parameters $\\theta$ determined in the previous step.\n", "\n", "For example, if a city has a population of $536\\,000$, then $x = 53.6$ (because the training data were expressed in tens of thousands of inhabitants, and $536\\,000 = 53.6 \\cdot 10\\,000$), and we can use the parameters $\\theta$ we found to carry out the following calculation:\n", "$$ \\hat{y} \\, = \\, h_\\theta(x) \\, = \\, \\theta_0 + \\theta_1 \\, x \\, = \\, 0.0494 + 0.7591 \\cdot 53.6 \\, = \\, 40.7359 $$\n", "\n", "Or, using the functions defined earlier:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "52.96131370254696\n" ] } ], "source": [ "example_x = 53.6\n", "predicted_y = h(best_theta, example_x)\n", "print(\n", "    predicted_y\n", ")  # this is the predicted profit of the transport company in a city of 536,000 inhabitants\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## 2.4. Model evaluation" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "How can we assess the quality of the model we have built?\n", "\n", " * We have to check how well the model's predictions agree with the expected values!" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Can we use the data we trained the model on for this purpose?\n", "**NO!**\n", "\n", " * The essence of machine learning is building models/algorithms that give good predictions for **unseen** data – data the algorithm has not come into contact with yet! There is no art in predicting things one already knows.\n", " * That is why testing/evaluating a model on the training set misses the point and is of no use.\n", " * A separate data set should be used to evaluate the model.\n", " * **The training data and the test data should always be separate sets!**" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "In lecture *5. Good practices in machine learning* you will learn how to split the available data into a training set and a test set.\n", "\n", "Here, for now, we will use a separately prepared test set for evaluation." ] },
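{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "(As a small preview of that lecture – only a sketch, and not the procedure we follow below – one simple way to obtain such separate sets is to shuffle the indices of the examples and reserve a fraction of them for testing. The variable names below are chosen just for this illustration.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Sketch: a simple random train/test split of the data loaded earlier (illustration only)\n", "rng = np.random.default_rng(42)  # fixed seed so that the split is reproducible\n", "indices = rng.permutation(len(y))\n", "split_point = int(0.8 * len(y))  # e.g. 80% for training, 20% for testing\n", "train_idx, test_idx = indices[:split_point], indices[split_point:]\n", "x_train_demo, y_train_demo = x[train_idx], y[train_idx]\n", "x_test_demo, y_test_demo = x[test_idx], y[test_idx]\n", "print(len(x_train_demo), len(x_test_demo))\n" ] },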
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Jako metrykę ewaluacji wykorzystamy znany nam już błąd średniokwadratowy (MSE):" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "def mse(expected, predicted):\n", " \"\"\"Błąd średniokwadratowy\"\"\"\n", " m = len(expected)\n", " if len(predicted) != m:\n", " raise Exception(\"Wektory mają różne długości!\")\n", " return 1.0 / (2 * m) * sum((expected[i] - predicted[i]) ** 2 for i in range(m))\n" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4.36540743711836\n" ] } ], "source": [ "# Wczytwanie danych testowych z pliku za pomocą numpy\n", "\n", "test_data = np.loadtxt(\"data1_test.csv\", delimiter=\",\")\n", "x_test = test_data[:, 0]\n", "y_test = test_data[:, 1]\n", "\n", "# Obliczenie przewidywań modelu\n", "y_pred = h(best_theta, x_test)\n", "\n", "# Obliczenie MSE na zbiorze testowym (im mniejszy MSE, tym lepiej!)\n", "evaluation_result = mse(y_test, y_pred)\n", "\n", "print(evaluation_result)\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Otrzymana wartość mówi nam o tym, jak dobry jest stworzony przez nas model.\n", "\n", "W przypadku metryki MSE im mniejsza wartość, tym lepiej.\n", "\n", "W ten sposób możemy np. porównywać różne modele." ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "livereveal": { "start_slideshow_at": "selected", "theme": "white" }, "vscode": { "interpreter": { "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" } } }, "nbformat": 4, "nbformat_minor": 4 }