uczenie-maszynowe/wyk/02_Regresja_liniowa.ipynb


2022-10-14 11:34:46 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Machine learning\n",
"# 2. Linear regression, part 1"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 2.1. Cost function"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Task\n",
"Given the population $x$ of a town, predict the income $y$ of a transport company.\n",
"\n",
"(The data come from the course “Machine Learning” by Andrew Ng, Coursera)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"**Note**: To keep this example as simple as possible, the town population is given in tens of thousands of inhabitants and the company income in tens of thousands of dollars. As a result, the cost function computed later in this lecture takes values that are easy to show on a plot."
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import ipywidgets as widgets\n",
"\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format = \"svg\"\n",
"\n",
"from IPython.display import display, Math, Latex"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Data"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" x y\n",
"0 6.1101 17.59200\n",
"1 5.5277 9.13020\n",
"2 8.5186 13.66200\n",
"3 7.0032 11.85400\n",
"4 5.8598 6.82330\n",
".. ... ...\n",
"75 6.5479 0.29678\n",
"76 7.5386 3.88450\n",
"77 5.0365 5.70140\n",
"78 10.2740 6.75260\n",
"79 5.1077 2.05760\n",
"\n",
"[80 rows x 2 columns]\n"
]
}
],
"source": [
"import pandas as pd\n",
"\n",
"data = pd.read_csv(\"data1_train.csv\", names=[\"x\", \"y\"])\n",
"print(data)\n"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"x = data[\"x\"].to_numpy()\n",
"y = data[\"y\"].to_numpy()\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Hypothesis and model parameters"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"How can we predict $y$ from a given $x$? To answer this question, we will look for a function $h(x)$ that best captures the relationship between $x$ and $y$, i.e. $y \\sim h(x)$.\n",
"\n",
"Let us start with the simplest case, in which $h(x)$ is simply a linear function. The general formula of a linear function is"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"$$ h(x) = a \\, x + b $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Keep in mind, however, that the coefficients $a$ and $b$ are not given in advance; our task is precisely to find values for them such that $h(x)$ is “as close as possible” to $y$ (what exactly this phrase means will be explained later)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"We will call the function $h$ we are looking for the **hypothesis function**, and its coefficients the **model parameters**."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"In machine learning theory, model parameters are usually denoted by the Greek letter $\\theta$ with appropriate indices, so we will write the linear hypothesis function above as\n",
"$$ h(x) = \\theta_0 + \\theta_1 x $$\n",
"\n",
"The **model parameters** form a vector, which we will denote simply by $\\theta$:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"$$ \\theta = \\left[\\begin{array}{c}\\theta_0\\\\ \\theta_1\\end{array}\\right] $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"To emphasize that the hypothesis function depends on the model parameters, we will write $h_\\theta$ instead of $h$:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x $$"
]
},
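{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"As a quick illustration of the formula above, take arbitrary parameter values chosen only for this example, $\\theta_0 = 1$ and $\\theta_1 = 2$ (they are not fitted to anything). For the first row of the data, $x = 6.1101$, the hypothesis gives\n",
"$$ h_{\\theta}(6.1101) = 1 + 2 \\cdot 6.1101 = 13.2202 , $$\n",
"while the observed value is $y = 17.592$, so this particular choice of $\\theta$ underestimates that example by about $4.37$."
]
},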
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Let us now look at the data we are going to model:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"On the plot below you can try to fit the model parameters $\\theta_0$ and $\\theta_1$ by hand so that they model the relationship between $x$ and $y$ as well as possible:"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Functions that draw the scatter plot and the regression line\n",
"\n",
"\n",
"def regdots(x, y):\n",
"    fig = plt.figure(figsize=(16 * 0.6, 9 * 0.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"    ax.scatter(x, y, c=\"r\", label=\"Data\")\n",
"\n",
"    ax.set_xlabel(\"Town size\")\n",
"    ax.set_ylabel(\"Company income\")\n",
"    ax.margins(0.05, 0.05)\n",
"    plt.ylim(min(y) - 1, max(y) + 1)\n",
"    plt.xlim(min(x) - 1, max(x) + 1)\n",
"    return fig\n",
"\n",
"\n",
"def regline(fig, fun, theta, x):\n",
"    ax = fig.axes[0]\n",
"    x0, x1 = min(x), max(x)\n",
"    X = [x0, x1]\n",
"    Y = [fun(theta, x) for x in X]\n",
"    ax.plot(\n",
"        X,\n",
"        Y,\n",
"        linewidth=\"2\",\n",
"        label=(\n",
"            r\"$y={theta0}{op}{theta1}x$\".format(\n",
"                theta0=theta[0],\n",
"                theta1=(theta[1] if theta[1] >= 0 else -theta[1]),\n",
"                op=\"+\" if theta[1] >= 0 else \"-\",\n",
"            )\n",
"        ),\n",
"    )\n",
"\n",
"\n",
"def legend(fig):\n",
"    ax = fig.axes[0]\n",
"    handles, labels = ax.get_legend_handles_labels()\n",
"    # try-except block is a fix for a bug in Poly3DCollection\n",
"    try:\n",
"        fig.legend(handles, labels, fontsize=\"15\", loc=\"lower right\")\n",
"    except AttributeError:\n",
"        pass\n"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/svg+xml": "<svg xmlns=\"http://www.w3.org/2000/svg\"><title>Scatter plot of the training data: town size against company income</title></svg>",
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = regdots(x, y)\n",
"legend(fig)\n"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Hypothesis: a linear function of one variable\n",
"\n",
"\n",
"def h(theta, x):\n",
"    return theta[0] + theta[1] * x\n"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Set up the interactive plot\n",
"\n",
"sliderTheta01 = widgets.FloatSlider(\n",
"    min=-10, max=10, step=0.1, value=0, description=r\"$\\theta_0$\", width=300\n",
")\n",
"sliderTheta11 = widgets.FloatSlider(\n",
"    min=-5, max=5, step=0.1, value=0, description=r\"$\\theta_1$\", width=300\n",
")\n",
"\n",
"\n",
"def slide1(theta0, theta1):\n",
"    fig = regdots(x, y)\n",
"    regline(fig, h, [theta0, theta1], x)\n",
"    legend(fig)\n"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4880be41c52643798571f509b333a025",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=0.0, description='$\\\\theta_0$', max=10.0, min=-10.0), FloatSlider(valu…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.slide1(theta0, theta1)>"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact_manual(slide1, theta0=sliderTheta01, theta1=sliderTheta11)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"How do we know whether the model's predictions (values of the function $h(x)$) agree with the observations (values of $y$)?\n",
"\n",
"To measure this, we introduce the notion of a cost function."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Cost function"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"We define the cost function so that it reflects the difference between the model's predictions and the observations.\n",
"\n",
"One option is to define the cost function as the **mean squared error** (*MSE*), the quantity minimized by the method of least squares.\n",
"\n",
"We will define the cost function as *half* of the mean squared error, which simplifies later calculations (computing the derivative of the cost function later in this lecture). We can do this because $\\frac{1}{2}$ is a positive constant, and multiplying a function by a positive constant does not change where it attains its minimum."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"$$ J(\\theta) \\, = \\, \\frac{1}{2m} \\sum_{i = 1}^{m} \\left( h_{\\theta} \\left( x^{(i)} \\right) - y^{(i)} \\right) ^2 $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"where $m$ is the number of all examples (observations), i.e. the size of the training set.\n",
"\n",
"In the formula above, we sum the squared differences between the model's predictions ($h_\\theta \\left( x^{(i)} \\right)$) and the observations ($y^{(i)}$) over all examples $i$."
]
},
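{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A small worked example may make the formula for $J(\\theta)$ more concrete. The three data points used here are made up for illustration and do not come from the data set. Take $m = 3$ examples $(1, 2)$, $(2, 3)$, $(3, 5)$ and parameters $\\theta_0 = 0$, $\\theta_1 = 1$, so that $h_\\theta(x) = x$. Then\n",
"$$ J(\\theta) \\, = \\, \\frac{1}{2 \\cdot 3} \\left( (1 - 2)^2 + (2 - 3)^2 + (3 - 5)^2 \\right) \\, = \\, \\frac{1 + 1 + 4}{6} \\, = \\, 1 $$"
]
},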
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Our task now reduces to finding the parameters $\\theta = \\left[\\begin{array}{c}\\theta_0\\\\ \\theta_1\\end{array}\\right]$ that minimize the cost function $J(\\theta)$:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"$$ \\hat\\theta = \\mathop{\\arg\\min}_{\\theta} J(\\theta) $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"$$ \\theta \\in \\mathbb{R}^2, \\quad J \\colon \\mathbb{R}^2 \\to \\mathbb{R} $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Note that the domain of the cost function is the set of all possible values of the parameters $\\theta$."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"$$ J(\\theta_0, \\theta_1) \\, = \\, \\frac{1}{2m} \\sum_{i = 1}^{m} \\left( \\theta_0 + \\theta_1 x^{(i)} - y^{(i)} \\right) ^2 $$"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [],
"source": [
"def J(h, theta, x, y):\n",
"    \"\"\"Cost function: half of the mean squared error.\"\"\"\n",
"    m = len(y)\n",
"    return 1.0 / (2 * m) * sum((h(theta, x[i]) - y[i]) ** 2 for i in range(m))\n"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"# Compute the value of the cost function and show it on the plot\n",
"\n",
"\n",
"def regline2(fig, fun, theta, xx, yy):\n",
"    \"\"\"Draw the regression line and show the cost in the legend.\"\"\"\n",
"    ax = fig.axes[0]\n",
"    x0, x1 = min(xx), max(xx)\n",
"    X = [x0, x1]\n",
"    Y = [fun(theta, x) for x in X]\n",
"    cost = J(fun, theta, xx, yy)\n",
"    ax.plot(\n",
"        X,\n",
"        Y,\n",
"        linewidth=\"2\",\n",
"        label=(\n",
"            r\"$y={theta0}{op}{theta1}x, \\; J(\\theta)={cost:.3}$\".format(\n",
"                theta0=theta[0],\n",
"                theta1=(theta[1] if theta[1] >= 0 else -theta[1]),\n",
"                op=\"+\" if theta[1] >= 0 else \"-\",\n",
"                cost=cost,\n",
"            )\n",
"        ),\n",
"    )\n",
"\n",
"\n",
"sliderTheta02 = widgets.FloatSlider(\n",
"    min=-10, max=10, step=0.1, value=0, description=r\"$\\theta_0$\", width=300\n",
")\n",
"sliderTheta12 = widgets.FloatSlider(\n",
"    min=-5, max=5, step=0.1, value=0, description=r\"$\\theta_1$\", width=300\n",
")\n",
"\n",
"\n",
"def slide2(theta0, theta1):\n",
"    fig = regdots(x, y)\n",
"    regline2(fig, h, [theta0, theta1], x, y)\n",
"    legend(fig)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The interactive plot below shows the value of the cost function $J(\\theta)$. Is it easier to choose the model parameters now?"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c67ea652bba946cf83a86485848bb0b0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=0.0, description='$\\\\theta_0$', max=10.0, min=-10.0), FloatSlider(valu…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.slide2(theta0, theta1)>"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact_manual(slide2, theta0=sliderTheta02, theta1=sliderTheta12)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### The cost function as a function of $\\theta$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The cost function defined above (half of the MSE) is a function of the vector variable $\\theta$, that is, a function of two real variables: $\\theta_0$ and $\\theta_1$.\n",
"\n",
"Let us see what its graph looks like."
]
},
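{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"One property of this graph can be read off the formula before plotting it. This is a short derivation sketch; the bar notation for averages over the training set, e.g. $\\bar{x} = \\frac{1}{m} \\sum_{i=1}^{m} x^{(i)}$, is introduced only for this sketch. Expanding the square in $J$ gives a quadratic polynomial in $\\theta_0$ and $\\theta_1$:\n",
"$$ J(\\theta_0, \\theta_1) \\, = \\, \\tfrac{1}{2} \\theta_0^2 + \\bar{x} \\, \\theta_0 \\theta_1 + \\tfrac{1}{2} \\overline{x^2} \\, \\theta_1^2 - \\bar{y} \\, \\theta_0 - \\overline{xy} \\, \\theta_1 + \\tfrac{1}{2} \\overline{y^2} $$\n",
"\n",
"For a fixed $\\theta_1$ this is a parabola in $\\theta_0$ that opens upwards. Its vertex lies where $\\frac{\\partial J}{\\partial \\theta_0} = \\theta_0 + \\bar{x} \\, \\theta_1 - \\bar{y} = 0$, that is, at $\\theta_0 = \\bar{y} - \\theta_1 \\bar{x}$."
]
},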
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Plot of the cost function for a fixed theta_1 (1.0 by default)\n",
"\n",
"\n",
"def costfun(fun, x, y):\n",
"    return lambda theta: J(fun, theta, x, y)\n",
"\n",
"\n",
"def costplot(hypothesis, x, y, theta1=1.0):\n",
"    fig = plt.figure(figsize=(16 * 0.6, 9 * 0.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"    ax.set_xlabel(r\"$\\theta_0$\")\n",
"    ax.set_ylabel(r\"$J(\\theta)$\")\n",
"    j = costfun(hypothesis, x, y)\n",
"    fun = lambda theta0: j([theta0, theta1])\n",
"    X = np.arange(-10, 10, 0.1)\n",
"    Y = [fun(x) for x in X]\n",
"    ax.plot(\n",
"        X, Y, linewidth=\"2\", label=(r\"$J(\\theta_0, {theta1})$\".format(theta1=theta1))\n",
"    )\n",
"    return fig\n",
"\n",
"\n",
"def slide3(theta1):\n",
"    fig = costplot(h, x, y, theta1)\n",
"    legend(fig)\n",
"\n",
"\n",
"sliderTheta13 = widgets.FloatSlider(\n",
"    min=-5, max=5, step=0.1, value=1.0, description=r\"$\\theta_1$\", width=300\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f5ea28655cad4743b9e58a3ecd0b1fc3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=1.0, description='$\\\\theta_1$', max=5.0, min=-5.0), Button(description…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.slide3(theta1)>"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact_manual(slide3, theta1=sliderTheta13)\n"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Plot of the cost function with respect to theta_0 and theta_1\n",
"\n",
"from mpl_toolkits.mplot3d import Axes3D\n",
"import pylab\n",
"\n",
"%matplotlib inline\n",
"\n",
"def costplot3d(hypothesis, x, y, show_gradient=False):\n",
"    fig = plt.figure(figsize=(16*.6, 9*.6))\n",
"    ax = fig.add_subplot(111, projection='3d')\n",
"    fig.subplots_adjust(left=0.0, right=1.0, bottom=0.0, top=1.0)\n",
"    ax.set_xlabel(r'$\\theta_0$')\n",
"    ax.set_ylabel(r'$\\theta_1$')\n",
"    ax.set_zlabel(r'$J(\\theta)$')\n",
"\n",
"    j = lambda theta0, theta1: costfun(hypothesis, x, y)([theta0, theta1])\n",
"    X = np.arange(-10, 10.1, 0.1)\n",
"    Y = np.arange(-1, 4.1, 0.1)\n",
"    X, Y = np.meshgrid(X, Y)\n",
"    # Evaluate the cost at every point of the (theta_0, theta_1) grid\n",
"    Z = np.array([[J(hypothesis, [theta0, theta1], x, y)\n",
"                   for theta0, theta1 in zip(xRow, yRow)]\n",
"                  for xRow, yRow in zip(X, Y)])\n",
"\n",
"    ax.plot_surface(X, Y, Z, rstride=2, cstride=8, linewidth=0.5,\n",
"                    alpha=0.5, cmap='jet', zorder=0,\n",
"                    label=r\"$J(\\theta)$\")\n",
"    ax.view_init(elev=20., azim=-150)\n",
"\n",
"    ax.set_xlim3d(-10, 10)\n",
"    ax.set_ylim3d(-1, 4)\n",
"    ax.set_zlim3d(-100, 800)\n",
"\n",
"    # Contour lines drawn on the base plane z = -100\n",
"    N = range(0, 800, 20)\n",
"    plt.contour(X, Y, Z, N, zdir='z', offset=-100, cmap='coolwarm', alpha=1)\n",
"\n",
"    # Dashed vertical line and markers at the minimum of the cost function\n",
"    ax.plot([-3.89578088] * 2,\n",
"            [ 1.19303364] * 2,\n",
"            [-100, 4.47697137598],\n",
"            color='red', alpha=1, linewidth=1.3, zorder=100, linestyle='dashed',\n",
"            label=r'minimum: $J(-3.90, 1.19) = 4.48$')\n",
"    ax.scatter([-3.89578088] * 2,\n",
"               [ 1.19303364] * 2,\n",
"               [-100, 4.47697137598],\n",
"               c='r', s=80, marker='x', alpha=1, linewidth=1.3, zorder=100,\n",
"               label=r'minimum: $J(-3.90, 1.19) = 4.48$')\n",
"\n",
"    if show_gradient:\n",
"        ax.plot([3.0, 1.1],\n",
"                [3.0, 2.4],\n",
"                [263.0, 125.0],\n",
"                color='green', alpha=1, linewidth=1.3, zorder=100)\n",
"        ax.scatter([3.0],\n",
"                   [3.0],\n",
"                   [263.0],\n",
"                   c='g', s=30, marker='D', alpha=1, linewidth=1.3, zorder=100)\n",
"\n",
"    ax.margins(0,0,0)\n",
"    fig.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/svg+xml": "<svg xmlns=\"http://www.w3.org/2000/svg\"><title>Three-dimensional surface plot of the cost function J over theta_0 and theta_1, with contour lines on the base plane and the minimum marked</title></svg>",
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"costplot3d(h, x, y)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"In the plot above, the minimum of the cost function that we are looking for is marked with a red cross.\n",
"\n",
"Below we can also see the projection of the three-dimensional plot above onto the $(\\theta_0, \\theta_1)$ plane:"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def costplot2d(hypothesis, x, y, gradient_values=(), nohead=False):\n",
"    fig = plt.figure(figsize=(16 * 0.6, 9 * 0.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"    ax.set_xlabel(r\"$\\theta_0$\")\n",
"    ax.set_ylabel(r\"$\\theta_1$\")\n",
"\n",
"    j = lambda theta0, theta1: costfun(hypothesis, x, y)([theta0, theta1])\n",
"    X = np.arange(-10, 10.1, 0.1)\n",
"    Y = np.arange(-1, 4.1, 0.1)\n",
"    X, Y = np.meshgrid(X, Y)\n",
"    Z = np.array(\n",
"        [\n",
"            [\n",
"                J(hypothesis, [theta0, theta1], x, y)\n",
"                for theta0, theta1 in zip(xRow, yRow)\n",
"            ]\n",
"            for xRow, yRow in zip(X, Y)\n",
"        ]\n",
"    )\n",
"\n",
"    N = range(0, 800, 20)\n",
"    plt.contour(X, Y, Z, N, cmap=\"coolwarm\", alpha=1)\n",
"\n",
"    ax.scatter(\n",
"        [-3.89578088],\n",
"        [1.19303364],\n",
"        c=\"r\",\n",
"        s=80,\n",
"        marker=\"x\",\n",
"        label=r\"minimum: $J(-3.90, 1.19) = 4.48$\",\n",
"    )\n",
"\n",
"    # Optionally draw a path of (cost, theta) pairs as a chain of arrows\n",
"    if len(gradient_values) > 0:\n",
"        prev_theta = gradient_values[0][1]\n",
"        ax.scatter(\n",
"            [prev_theta[0]], [prev_theta[1]], c=\"g\", s=30, marker=\"D\", zorder=100\n",
"        )\n",
"        for cost, theta in gradient_values[1:]:\n",
"            dtheta = [theta[0] - prev_theta[0], theta[1] - prev_theta[1]]\n",
"            ax.arrow(\n",
"                prev_theta[0],\n",
"                prev_theta[1],\n",
"                dtheta[0],\n",
"                dtheta[1],\n",
"                color=\"green\",\n",
"                head_width=(0.0 if nohead else 0.1),\n",
"                head_length=(0.0 if nohead else 0.2),\n",
"                zorder=100,\n",
"            )\n",
"            prev_theta = theta\n",
"\n",
"    return fig\n"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/svg+xml": "<svg xmlns=\"http://www.w3.org/2000/svg\"><title>Contour plot of the cost function J in the (theta_0, theta_1) plane, with the minimum marked by a red cross</title></svg>",
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = costplot2d(h, x, y)\n",
"legend(fig)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Cechy funkcji kosztu\n",
"Funkcja kosztu $J(\\theta)$ zdefiniowana powyżej jest funkcją wypukłą, dlatego posiada tylko jedno minimum lokalne."
]
},
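{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A short justification, sketched here as a supplement for the one-variable hypothesis $h_\\theta(x) = \\theta_0 + \\theta_1 x$ and the cost $J(\\theta) = \\frac{1}{2m} \\sum_{i=1}^m \\left( h_\\theta(x^{(i)}) - y^{(i)} \\right)^2$: the matrix of second derivatives of $J$ does not depend on $\\theta$ and equals\n",
"\n",
"$$ \\nabla^2 J(\\theta) = \\dfrac{1}{m} \\left[\\begin{array}{cc} m & \\sum_{i=1}^m x^{(i)} \\\\ \\sum_{i=1}^m x^{(i)} & \\sum_{i=1}^m \\left( x^{(i)} \\right)^2 \\end{array}\\right] $$\n",
"\n",
"For any vector $v = (v_0, v_1)$ we get $v^T \\, \\nabla^2 J(\\theta) \\, v = \\dfrac{1}{m} \\sum_{i=1}^m \\left( v_0 + v_1 x^{(i)} \\right)^2 \\geq 0$, so this matrix is positive semidefinite, which is exactly the condition for a twice-differentiable function to be convex."
]
},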
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 2.2. Metoda gradientu prostego"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Metoda gradientu prostego\n",
"Metoda znajdowania minimów lokalnych."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Idea:\n",
" * Zacznijmy od dowolnego $\\theta$.\n",
" * Zmieniajmy powoli $\\theta$ tak, aby zmniejszać $J(\\theta)$, aż w końcu znajdziemy minimum."
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"costplot3d(h, x, y, show_gradient=True)\n"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Przykładowe wartości kolejnych przybliżeń (sztuczne)\n",
"\n",
"gv = [\n",
" [_, [3.0, 3.0]],\n",
" [_, [2.6, 2.4]],\n",
" [_, [2.2, 2.0]],\n",
" [_, [1.6, 1.6]],\n",
" [_, [0.4, 1.2]],\n",
"]\n",
"\n",
"# Przygotowanie interaktywnego wykresu\n",
"\n",
"sliderSteps1 = widgets.IntSlider(\n",
" min=0, max=3, step=1, value=0, description=\"kroki\", width=300\n",
")\n",
"\n",
"\n",
"def slide4(steps):\n",
" costplot2d(h, x, y, gradient_values=gv[: steps + 1])\n"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ba49ab01f3694550a13b124b599f9d17",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(IntSlider(value=0, description='kroki', max=3), Output()), _dom_classes=('widget-interac…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.slide4(steps)>"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact(slide4, steps=sliderSteps1)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Metoda gradientu prostego\n",
"W każdym kroku będziemy aktualizować parametry $\\theta_j$:\n",
"\n",
"$$ \\theta_j := \\theta_j - \\alpha \\frac{\\partial}{\\partial \\theta_j} J(\\theta) $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Współczynnik $\\alpha$ nazywamy **długością kroku** lub **współczynnikiem szybkości uczenia** (*learning rate*)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"$$ \\begin{array}{rcl}\n",
"\\dfrac{\\partial}{\\partial \\theta_j} J(\\theta)\n",
" & = & \\dfrac{\\partial}{\\partial \\theta_j} \\dfrac{1}{2m} \\displaystyle\\sum_{i = 1}^{m} \\left( h_{\\theta} \\left( x^{(i)} \\right) - y^{(i)} \\right) ^2 \\\\\n",
" & = & 2 \\cdot \\dfrac{1}{2m} \\displaystyle\\sum_{i=1}^m \\left( h_\\theta \\left( x^{(i)} \\right) - y^{(i)} \\right) \\cdot \\dfrac{\\partial}{\\partial\\theta_j} \\left( h_\\theta \\left( x^{(i)} \\right) - y^{(i)} \\right) \\\\\n",
" & = & \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_\\theta \\left( x^{(i)} \\right) - y^{(i)} \\right) \\cdot \\dfrac{\\partial}{\\partial\\theta_j} \\left( \\displaystyle\\sum_{i=0}^n \\theta_i x_i^{(i)} - y^{(i)} \\right)\\\\\n",
" & = & \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_\\theta \\left( x^{(i)} \\right) -y^{(i)} \\right) x_j^{(i)} \\\\\n",
"\\end{array} $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Czyli dla regresji liniowej jednej zmiennej:\n",
"\n",
"$$ h_\\theta(x) = \\theta_0 + \\theta_1x $$\n",
"\n",
"w każdym kroku będziemy aktualizować:\n",
"\n",
"$$\n",
"\\begin{array}{rcl}\n",
"\\theta_0 & := & \\theta_0 - \\alpha \\, \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_\\theta(x^{(i)})-y^{(i)} \\right) \\\\ \n",
"\\theta_1 & := & \\theta_1 - \\alpha \\, \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_\\theta(x^{(i)})-y^{(i)} \\right) x^{(i)}\\\\ \n",
"\\end{array}\n",
"$$"
]
},
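{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A worked single step, using made-up numbers that are not taken from the lecture data: let $m = 2$ with points $(x^{(1)}, y^{(1)}) = (1, 1)$ and $(x^{(2)}, y^{(2)}) = (2, 3)$, start from $\\theta = (0, 0)$ and take $\\alpha = 0.1$. Then $h_\\theta(x^{(i)}) = 0$ for both points, the differences $h_\\theta(x^{(i)}) - y^{(i)}$ are $-1$ and $-3$, and the update gives:\n",
"\n",
"$$\n",
"\\begin{array}{rcl}\n",
"\\theta_0 & := & 0 - 0.1 \\cdot \\dfrac{1}{2} \\left( (-1) + (-3) \\right) \\; = \\; 0.2 \\\\\n",
"\\theta_1 & := & 0 - 0.1 \\cdot \\dfrac{1}{2} \\left( (-1) \\cdot 1 + (-3) \\cdot 2 \\right) \\; = \\; 0.35 \\\\\n",
"\\end{array}\n",
"$$\n",
"\n",
"Both right-hand sides use the old value $\\theta = (0, 0)$."
]
},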
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"###### Uwaga!\n",
" * W każdym kroku aktualizujemy *jednocześnie* $\\theta_0$ i $\\theta_1$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * Kolejne kroki wykonujemy aż uzyskamy zbieżność"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Metoda gradientu prostego implementacja"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Wyświetlanie macierzy w LaTeX-u\n",
"\n",
"\n",
"def LatexMatrix(matrix):\n",
" ltx = r\"\\left[\\begin{array}\"\n",
" m, n = matrix.shape\n",
" ltx += \"{\" + (\"r\" * n) + \"}\"\n",
" for i in range(m):\n",
" ltx += r\" & \".join([(\"%.4f\" % j.item()) for j in matrix[i]]) + r\" \\\\ \"\n",
" ltx += r\"\\end{array}\\right]\"\n",
" return ltx\n"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [],
"source": [
"def gradient_descent(h, cost_fun, theta, x, y, alpha, eps):\n",
" current_cost = cost_fun(h, theta, x, y)\n",
" history = [\n",
" [current_cost, theta]\n",
" ] # zapiszmy wartości kosztu i parametrów, by potem zrobić wykres\n",
" m = len(y)\n",
" while True:\n",
" new_theta = [\n",
" theta[0] - alpha / float(m) * sum(h(theta, x[i]) - y[i] for i in range(m)),\n",
" theta[1]\n",
" - alpha / float(m) * sum((h(theta, x[i]) - y[i]) * x[i] for i in range(m)),\n",
" ]\n",
" theta = new_theta # jednoczesna aktualizacja - używamy zmiennej tymczasowej\n",
" try:\n",
" prev_cost = current_cost\n",
" current_cost = cost_fun(h, theta, x, y)\n",
" except OverflowError:\n",
" break\n",
" if abs(prev_cost - current_cost) <= eps:\n",
" break\n",
" history.append([current_cost, theta])\n",
" return theta, history\n"
]
},
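{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The same update can be written with NumPy arrays instead of Python loops. The cell below is only a sketch under stated assumptions, not the lecture's implementation: the name `gradient_descent_vectorized`, the `max_iter` safeguard and the toy data are introduced here for illustration, and the hypothesis $h_\\theta(x) = \\theta_0 + \\theta_1 x$ and the cost with the $\\frac{1}{2m}$ factor are hard-coded. On this toy data the printed parameters should come out close to $(1, 2)$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def gradient_descent_vectorized(x, y, alpha, eps, max_iter=100_000):\n",
"    # Hypothetical helper (not from the lecture): assumes h(x) = theta_0 + theta_1 * x\n",
"    x = np.asarray(x, dtype=float)\n",
"    y = np.asarray(y, dtype=float)\n",
"    m = len(y)\n",
"    X = np.column_stack([np.ones(m), x])  # column of ones multiplies theta_0\n",
"    theta = np.zeros(2)\n",
"    cost = np.sum((X @ theta - y) ** 2) / (2 * m)\n",
"    for step in range(max_iter):\n",
"        # All partial derivatives at once: (1/m) * X^T (X theta - y)\n",
"        gradient = X.T @ (X @ theta - y) / m\n",
"        theta = theta - alpha * gradient  # both parameters change simultaneously\n",
"        new_cost = np.sum((X @ theta - y) ** 2) / (2 * m)\n",
"        if abs(cost - new_cost) <= eps:\n",
"            break\n",
"        cost = new_cost\n",
"    return theta\n",
"\n",
"\n",
"# Toy data lying exactly on the line y = 1 + 2x (made up for this illustration)\n",
"toy_x = np.array([0.0, 1.0, 2.0, 3.0])\n",
"toy_y = 1.0 + 2.0 * toy_x\n",
"print(gradient_descent_vectorized(toy_x, toy_y, alpha=0.1, eps=1e-12))\n"
]
},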
{
"cell_type": "code",
"execution_count": 62,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle \\large\\textrm{Wynik:}\\quad \\theta = \\left[\\begin{array}{r}-1.8792 \\\\ 1.0231 \\\\ \\end{array}\\right] \\quad J(\\theta) = 5.0010 \\quad \\textrm{po 4114 iteracjach}$"
],
"text/plain": [
"<IPython.core.display.Math object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_theta, history = gradient_descent(h, J, [0.0, 0.0], x, y, alpha=0.001, eps=0.0001)\n",
"\n",
"display(\n",
" Math(\n",
" r\"\\large\\textrm{Wynik:}\\quad \\theta = \"\n",
" + LatexMatrix(np.matrix(best_theta).reshape(2, 1))\n",
" + (r\" \\quad J(\\theta) = %.4f\" % history[-1][0])\n",
" + r\" \\quad \\textrm{po %d iteracjach}\" % len(history)\n",
" )\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Przygotowanie interaktywnego wykresu\n",
"\n",
"sliderSteps2 = widgets.IntSlider(\n",
" min=0, max=500, step=1, value=1, description=\"kroki\", width=300\n",
")\n",
"\n",
"\n",
"def slide5(steps):\n",
" costplot2d(h, x, y, gradient_values=history[: steps + 1], nohead=True)\n"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "59091adc5a5f4d20bf2ad5e92c17b234",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(IntSlider(value=1, description='kroki', max=500), Button(description='Run Interact', sty…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.slide5(steps)>"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact_manual(slide5, steps=sliderSteps2)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Współczynnik szybkości uczenia $\\alpha$ (długość kroku)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Tempo zbieżności metody gradientu prostego możemy regulować za pomocą parametru $\\alpha$, pamiętając, że:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * Jeżeli długość kroku jest zbyt mała, algorytm może działać zbyt wolno."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * Jeżeli długość kroku jest zbyt duża, algorytm może nie być zbieżny."
]
},
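{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"To see where both effects come from, consider a deliberately simplified one-parameter cost $J(\\theta) = \\frac{a}{2} \\theta^2$ with $a > 0$ (an illustrative assumption, not the cost function of this lecture). Its derivative is $a \\theta$, so a gradient descent step reads\n",
"\n",
"$$ \\theta := \\theta - \\alpha \\, a \\, \\theta = (1 - \\alpha a) \\, \\theta $$\n",
"\n",
"and after $k$ steps the parameter equals $(1 - \\alpha a)^k$ times its initial value. This tends to the minimum at $0$ exactly when $|1 - \\alpha a| < 1$, that is for $0 < \\alpha < 2/a$. For very small $\\alpha$ the factor $1 - \\alpha a$ is close to $1$ and progress is slow; for $\\alpha > 2/a$ its absolute value exceeds $1$ and the iterates move away from the minimum."
]
},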
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 2.3. Predykcja wyników"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Zbudowaliśmy model, dzięki któremu wiemy, jaka jest zależność między dochodem firmy transportowej ($y$) a ludnością miasta ($x$).\n",
"\n",
"Wróćmy teraz do postawionego na początku wykładu pytania: jak przewidzieć dochód firmy transportowej w mieście o danej wielkości?\n",
"\n",
"Odpowiedź polega po prostu na zastosowaniu funkcji $h$ z wyznaczonymi w poprzednim kroku parametrami $\\theta$.\n",
"\n",
"Na przykład, jeżeli miasto ma $536\\,000$ ludności, to $x = 53.6$ (bo dane trenujące były wyrażone w dziesiątkach tysięcy mieszkańców, a $536\\,000 = 53.6 \\cdot 10\\,000$) i możemy użyć znalezionych parametrów $\\theta$, by wykonać następujące obliczenia:\n",
"$$ \\hat{y} \\, = \\, h_\\theta(x) \\, = \\, \\theta_0 + \\theta_1 \\, x \\, = \\, 0.0494 + 0.7591 \\cdot 53.6 \\, = \\, 40.7359 $$\n",
"\n",
"Czyli używając zdefiniowanych wcześniej funkcji:"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"52.96131370254696\n"
]
}
],
"source": [
"example_x = 53.6\n",
"predicted_y = h(best_theta, example_x)\n",
"print(\n",
" predicted_y\n",
") ## taki jest przewidywany dochód tej firmy transportowej w 536-tysięcznym mieście\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 2.4. Ewaluacja modelu"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Jak ocenić jakość stworzonego przez nas modelu?\n",
"\n",
" * Trzeba sprawdzić, jak przewidywania modelu zgadzają się z oczekiwaniami!"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Czy możemy w tym celu użyć danych, których użyliśmy do wytrenowania modelu?\n",
"**NIE!**\n",
"\n",
" * Istotą uczenia maszynowego jest budowanie modeli/algorytmów, które dają dobre przewidywania dla **nieznanych** danych takich, z którymi algorytm nie miał jeszcze styczności! Nie sztuką jest przewidywać rzeczy, które już sie zna.\n",
" * Dlatego testowanie/ewaluowanie modelu na zbiorze uczącym mija się z celem i jest nieprzydatne.\n",
" * Do ewaluacji modelu należy użyć oddzielnego zbioru danych.\n",
" * **Dane uczące i dane testowe zawsze powinny stanowić oddzielne zbiory!**"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Na wykładzie *5. Dobre praktyki w uczeniu maszynowym* dowiesz się, jak podzielić posiadane dane na zbiór uczący i zbiór testowy.\n",
"\n",
"Tutaj, na razie, do ewaluacji użyjemy specjalnie przygotowanego zbioru testowego."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Jako metrykę ewaluacji wykorzystamy znany nam już błąd średniokwadratowy (MSE):"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"def mse(expected, predicted):\n",
" \"\"\"Błąd średniokwadratowy\"\"\"\n",
" m = len(expected)\n",
" if len(predicted) != m:\n",
" raise Exception(\"Wektory mają różne długości!\")\n",
" return 1.0 / (2 * m) * sum((expected[i] - predicted[i]) ** 2 for i in range(m))\n"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.36540743711836\n"
]
}
],
"source": [
"# Wczytwanie danych testowych z pliku za pomocą numpy\n",
"\n",
"test_data = np.loadtxt(\"data1_test.csv\", delimiter=\",\")\n",
"x_test = test_data[:, 0]\n",
"y_test = test_data[:, 1]\n",
"\n",
"# Obliczenie przewidywań modelu\n",
"y_pred = h(best_theta, x_test)\n",
"\n",
"# Obliczenie MSE na zbiorze testowym (im mniejszy MSE, tym lepiej!)\n",
"evaluation_result = mse(y_test, y_pred)\n",
"\n",
"print(evaluation_result)\n"
]
},
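{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"One rough way to read the printed value, added here as an interpretation aid: the `mse` function above divides by $2m$ rather than $m$, so doubling the result and taking the square root gives a root-mean-square error in the same units as $y$:\n",
"\n",
"$$ \\sqrt{2 \\cdot 4.3654} \\, \\approx \\, 2.95 $$\n",
"\n",
"Since $y$ is expressed in tens of thousands of dollars, this suggests a typical prediction error on the test set of roughly 29,500 dollars."
]
},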
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Otrzymana wartość mówi nam o tym, jak dobry jest stworzony przez nas model.\n",
"\n",
"W przypadku metryki MSE im mniejsza wartość, tym lepiej.\n",
"\n",
"W ten sposób możemy np. porównywać różne modele."
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"livereveal": {
"start_slideshow_at": "selected",
"theme": "white"
},
"vscode": {
"interpreter": {
"hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}