2021-03-02 08:32:40 +01:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Machine Learning: Applications\n",
"# 2. Linear regression"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 2.1. Cost function"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Task\n",
"Given $x$, the population of a city (in tens of thousands of inhabitants),\n",
"predict $y$, the income of a transport company (in tens of thousands of dollars).\n",
"\n",
"(The data come from the “Machine Learning” course by Andrew Ng on Coursera.)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Useful imports\n",
"\n",
"import numpy as np\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"import ipywidgets as widgets\n",
"import pandas as pd\n",
"\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format = \"svg\"\n",
"\n",
"from IPython.display import display, Math, Latex"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Loading the data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" x y\n",
"0 6.1101 17.5920\n",
"1 5.5277 9.1302\n",
"2 8.5186 13.6620\n",
"3 7.0032 11.8540\n",
"4 5.8598 6.8233\n",
"5 8.3829 11.8860\n",
"6 7.4764 4.3483\n",
"7 8.5781 12.0000\n",
"8 6.4862 6.5987\n",
"9 5.0546 3.8166\n"
]
}
],
"source": [
"data = pd.read_csv(\"data01_train.csv\", names=[\"x\", \"y\"])\n",
"print(data[:10])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [],
"source": [
"x = data[[\"x\"]].to_numpy().flatten()\n",
"y = data[[\"y\"]].to_numpy().flatten()"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Hypothesis and model parameters"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"How can we predict $y$ from a given $x$? To answer this question, we will look for a function $h(x)$ that best captures the relationship between $x$ and $y$, i.e. $y \\sim h(x)$.\n",
"\n",
"Let us start with the simplest case, in which $h(x)$ is simply a linear function."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"The general formula of a linear function is\n",
"$$ h(x) = a \\, x + b $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Keep in mind, however, that the coefficients $a$ and $b$ are not given in advance. Our task is precisely to find values for them such that $h(x)$ is “as close as possible” to $y$ (what exactly this phrase means will be explained later).\n",
"\n",
"We will call the function $h$ we are looking for the **hypothesis function**, and its coefficients the **model parameters**.\n",
"\n",
"In machine learning theory, model parameters are usually denoted by the Greek letter $\\theta$ with appropriate indices, so we will write the linear hypothesis function above as\n",
"$$ h(x) = \\theta_0 + \\theta_1 x $$\n",
"\n",
"The **model parameters** form a vector, which we will denote simply by $\\theta$:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"$$ \\theta = \\left[\\begin{array}{c}\\theta_0\\\\ \\theta_1\\end{array}\\right] $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"To emphasize that the hypothesis function depends on the model parameters, we will write $h_\\theta$ instead of $h$:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Let us now look at the data we are going to model:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"In the plot below, you can try to fit the model parameters $\\theta_0$ and $\\theta_1$ by hand so that they model the relationship between $x$ and $y$ as well as possible:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"# Functions that draw the scatter plot and the regression line\n",
"\n",
"def regdots(x, y):\n",
"    fig = plt.figure(figsize=(16*.6, 9*.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"    ax.scatter(x, y, c='r', s=50, label='Data')\n",
"\n",
"    ax.set_xlabel('City size [tens of thousands of inhabitants]')\n",
"    ax.set_ylabel('Company income [tens of thousands of dollars]')\n",
"    ax.margins(.05, .05)\n",
"    plt.ylim(min(y) - 1, max(y) + 1)\n",
"    plt.xlim(min(x) - 1, max(x) + 1)\n",
"    return fig\n",
"\n",
"def regline(fig, fun, theta, x):\n",
"    ax = fig.axes[0]\n",
"    x0, x1 = min(x), max(x)\n",
"    X = [x0, x1]\n",
"    Y = [fun(theta, x) for x in X]\n",
"    ax.plot(X, Y, linewidth=2,\n",
"            label=(r'$y={theta0}{op}{theta1}x$'.format(\n",
"                theta0=theta[0],\n",
"                theta1=(theta[1] if theta[1] >= 0 else -theta[1]),\n",
"                op='+' if theta[1] >= 0 else '-')))\n",
"\n",
"def legend(fig):\n",
"    ax = fig.axes[0]\n",
"    handles, labels = ax.get_legend_handles_labels()\n",
"    # try-except block is a fix for a bug in Poly3DCollection\n",
"    try:\n",
"        fig.legend(handles, labels, fontsize=15, loc='lower right')\n",
"    except AttributeError:\n",
"        pass"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/svg+xml": "<!-- SVG data omitted: scatter plot of the training data (city size vs. company income) -->",
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = regdots(x,y)\n",
"legend(fig)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Hypothesis: a linear function of one variable\n",
"\n",
"def h(theta, x):\n",
"    return theta[0] + theta[1] * x"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Set up the interactive plot\n",
"\n",
"sliderTheta01 = widgets.FloatSlider(min=-10, max=10, step=0.1, value=0, description=r'$\\theta_0$', width=300)\n",
"sliderTheta11 = widgets.FloatSlider(min=-5, max=5, step=0.1, value=0, description=r'$\\theta_1$', width=300)\n",
"\n",
"def slide1(theta0, theta1):\n",
"    fig = regdots(x, y)\n",
"    regline(fig, h, [theta0, theta1], x)\n",
"    legend(fig)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "325961a10ee1479cb0657468c6aaf42a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=0.0, description='$\\\\theta_0$', max=10.0, min=-10.0), FloatSlider(valu…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.slide1(theta0, theta1)>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact_manual(slide1, theta0=sliderTheta01, theta1=sliderTheta11)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"How do we know whether the model's predictions (the values of $h(x)$) agree with the observations (the values of $y$)?\n",
"\n",
"To measure this, we introduce the notion of a cost function."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Cost function"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"We will define the cost function so that it reflects the difference between the model's predictions and the observations.\n",
"\n",
"One option is to define the cost function as the **mean squared error** (*MSE*), the criterion used in the method of least squares.\n",
"\n",
"We will define the cost function as *half* of the mean squared error, which simplifies later calculations (computing the derivative of the cost function later in the lecture). We can do this because $\\frac{1}{2}$ is a positive constant, and multiplying a function by a positive constant does not change where it attains its minimum."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"$$ J(\\theta) \\, = \\, \\frac{1}{2m} \\sum_{i = 1}^{m} \\left( h_{\\theta} \\left( x^{(i)} \\right) - y^{(i)} \\right) ^2 $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"where $m$ is the number of all examples (observations), i.e. the size of the training set.\n",
"\n",
"In the formula above, we sum the squared differences between the model's predictions ($h_\\theta \\left( x^{(i)} \\right)$) and the observations ($y^{(i)}$) over all examples $i$."
]
},
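{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A brief sketch of why the factor $\\frac{1}{2}$ is convenient (a preview that assumes only the chain rule, not a replacement for the derivation given later in the lecture): differentiating the square produces a factor of $2$ that cancels the $\\frac{1}{2}$:\n",
"$$ \\frac{\\partial J(\\theta)}{\\partial \\theta_j} \\, = \\, \\frac{1}{2m} \\sum_{i = 1}^{m} 2 \\left( h_{\\theta} \\left( x^{(i)} \\right) - y^{(i)} \\right) \\frac{\\partial h_{\\theta} \\left( x^{(i)} \\right)}{\\partial \\theta_j} \\, = \\, \\frac{1}{m} \\sum_{i = 1}^{m} \\left( h_{\\theta} \\left( x^{(i)} \\right) - y^{(i)} \\right) \\frac{\\partial h_{\\theta} \\left( x^{(i)} \\right)}{\\partial \\theta_j} $$\n",
"\n",
"For the linear hypothesis $h_{\\theta}(x) = \\theta_0 + \\theta_1 x$ we have $\\partial h_{\\theta}(x) / \\partial \\theta_0 = 1$ and $\\partial h_{\\theta}(x) / \\partial \\theta_1 = x$, so\n",
"$$ \\frac{\\partial J}{\\partial \\theta_0} = \\frac{1}{m} \\sum_{i = 1}^{m} \\left( \\theta_0 + \\theta_1 x^{(i)} - y^{(i)} \\right), \\qquad \\frac{\\partial J}{\\partial \\theta_1} = \\frac{1}{m} \\sum_{i = 1}^{m} \\left( \\theta_0 + \\theta_1 x^{(i)} - y^{(i)} \\right) x^{(i)} $$"
]
},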
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Our task now reduces to finding the parameters $\\theta = \\left[\\begin{array}{c}\\theta_0\\\\ \\theta_1\\end{array}\\right]$ that minimize the cost function $J(\\theta)$:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"$$ \\hat\\theta = \\mathop{\\arg\\min}_{\\theta} J(\\theta) $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"$$ \\theta \\in \\mathbb{R}^2, \\quad J \\colon \\mathbb{R}^2 \\to \\mathbb{R} $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Note that the domain of the cost function is the set of all possible values of the parameters $\\theta$."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"$$ J(\\theta_0, \\theta_1) \\, = \\, \\frac{1}{2m} \\sum_{i = 1}^{m} \\left( \\theta_0 + \\theta_1 x^{(i)} - y^{(i)} \\right) ^2 $$"
]
},
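{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"To make the formula concrete, here is a small worked example on a hypothetical toy data set (three made-up points, not the data used in this lecture): $m = 3$ and $(x^{(i)}, y^{(i)}) \\in \\{(1, 1), (2, 2), (3, 3)\\}$.\n",
"\n",
"For $\\theta_0 = 0$, $\\theta_1 = 0$ every prediction is $0$, so\n",
"$$ J(0, 0) \\, = \\, \\frac{1}{2 \\cdot 3} \\left( (0 - 1)^2 + (0 - 2)^2 + (0 - 3)^2 \\right) \\, = \\, \\frac{14}{6} \\, \\approx \\, 2.33 $$\n",
"\n",
"For $\\theta_0 = 0$, $\\theta_1 = 0.5$ the predictions are $0.5$, $1$ and $1.5$, so\n",
"$$ J(0, 0.5) \\, = \\, \\frac{1}{6} \\left( (-0.5)^2 + (-1)^2 + (-1.5)^2 \\right) \\, = \\, \\frac{3.5}{6} \\, \\approx \\, 0.58 $$\n",
"\n",
"For $\\theta_0 = 0$, $\\theta_1 = 1$ the line passes through all three points, so $J(0, 1) = 0$, which is the smallest value the cost function can take."
]
},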
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [],
"source": [
"def J(h, theta, x, y):\n",
"    \"\"\"Cost function: half of the mean squared error\"\"\"\n",
"    m = len(y)\n",
"    return 1.0 / (2 * m) * sum((h(theta, x[i]) - y[i])**2 for i in range(m))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Compute the value of the cost function and show it on the plot\n",
"\n",
"def regline2(fig, fun, theta, xx, yy):\n",
"    \"\"\"Draw the regression line and report its cost in the legend\"\"\"\n",
"    ax = fig.axes[0]\n",
"    x0, x1 = min(xx), max(xx)\n",
"    X = [x0, x1]\n",
"    Y = [fun(theta, x) for x in X]\n",
"    cost = J(fun, theta, xx, yy)\n",
"    # Format the cost as a number; formatting str(cost) with '.3' would cut the text to 3 characters\n",
"    ax.plot(X, Y, linewidth=2,\n",
"            label=(r'$y={theta0}{op}{theta1}x, \\; J(\\theta)={cost:.3f}$'.format(\n",
"                theta0=theta[0],\n",
"                theta1=(theta[1] if theta[1] >= 0 else -theta[1]),\n",
"                op='+' if theta[1] >= 0 else '-',\n",
"                cost=float(cost))))\n",
"\n",
"sliderTheta02 = widgets.FloatSlider(min=-10, max=10, step=0.1, value=0, description=r'$\\theta_0$', width=300)\n",
"sliderTheta12 = widgets.FloatSlider(min=-5, max=5, step=0.1, value=0, description=r'$\\theta_1$', width=300)\n",
"\n",
"def slide2(theta0, theta1):\n",
"    fig = regdots(x, y)\n",
"    regline2(fig, h, [theta0, theta1], x, y)\n",
"    legend(fig)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The interactive plot below shows the value of the cost function $J(\\theta)$. Is it now easier to choose the model parameters?"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "912397d07efd4143a75de75ebe15fa1e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=0.0, description='$\\\\theta_0$', max=10.0, min=-10.0), FloatSlider(valu…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.slide2(theta0, theta1)>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact_manual(slide2, theta0=sliderTheta02, theta1=sliderTheta12)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### The cost function as a function of $\\theta$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The cost function defined via MSE is a function of the vector variable $\\theta$, i.e. a function of two real variables: $\\theta_0$ and $\\theta_1$.\n",
"\n",
"Let us see what its graph looks like."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Plot of the cost function for a fixed theta_1 (1.0 by default)\n",
"\n",
"def costfun(fun, x, y):\n",
"    return lambda theta: J(fun, theta, x, y)\n",
"\n",
"def costplot(hypothesis, x, y, theta1=1.0):\n",
"    fig = plt.figure(figsize=(16*.6, 9*.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"    ax.set_xlabel(r'$\\theta_0$')\n",
"    ax.set_ylabel(r'$J(\\theta)$')\n",
"    j = costfun(hypothesis, x, y)\n",
"    fun = lambda theta0: j([theta0, theta1])\n",
"    X = np.arange(-10, 10, 0.1)\n",
"    Y = [fun(x) for x in X]\n",
"    ax.plot(X, Y, linewidth=2, label=(r'$J(\\theta_0, {theta1})$'.format(theta1=theta1)))\n",
"    return fig\n",
"\n",
"def slide3(theta1):\n",
"    fig = costplot(h, x, y, theta1)\n",
"    legend(fig)\n",
"\n",
"sliderTheta13 = widgets.FloatSlider(min=-5, max=5, step=0.1, value=1.0, description=r'$\\theta_1$', width=300)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "da466a0090f044b7824dcd2976b276e7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=1.0, description='$\\\\theta_1$', max=5.0, min=-5.0), Button(description…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.slide3(theta1)>"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact_manual(slide3, theta1=sliderTheta13)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Plot of the cost function with respect to theta_0 and theta_1\n",
"\n",
"from mpl_toolkits.mplot3d import Axes3D\n",
"import pylab\n",
"\n",
"%matplotlib inline\n",
"\n",
"def costplot3d(hypothesis, x, y, show_gradient=False):\n",
"    fig = plt.figure(figsize=(16*.6, 9*.6))\n",
"    ax = fig.add_subplot(111, projection='3d')\n",
"    fig.subplots_adjust(left=0.0, right=1.0, bottom=0.0, top=1.0)\n",
"    ax.set_xlabel(r'$\\theta_0$')\n",
"    ax.set_ylabel(r'$\\theta_1$')\n",
"    ax.set_zlabel(r'$J(\\theta)$')\n",
"\n",
"    j = lambda theta0, theta1: costfun(hypothesis, x, y)([theta0, theta1])\n",
"    X = np.arange(-10, 10.1, 0.1)\n",
"    Y = np.arange(-1, 4.1, 0.1)\n",
"    X, Y = np.meshgrid(X, Y)\n",
"    # Evaluate the cost at every grid point (theta0, theta1)\n",
"    Z = np.array([[J(hypothesis, [theta0, theta1], x, y)\n",
"                   for theta0, theta1 in zip(xRow, yRow)]\n",
"                  for xRow, yRow in zip(X, Y)])\n",
"\n",
"    ax.plot_surface(X, Y, Z, rstride=2, cstride=8, linewidth=0.5,\n",
"                    alpha=0.5, cmap='jet', zorder=0,\n",
"                    label=r\"$J(\\theta)$\")\n",
"    ax.view_init(elev=20., azim=-150)\n",
"\n",
"    ax.set_xlim3d(-10, 10)\n",
"    ax.set_ylim3d(-1, 4)\n",
"    ax.set_zlim3d(-100, 800)\n",
"\n",
"    # Contour lines projected onto the bottom of the plot\n",
"    N = range(0, 800, 20)\n",
"    plt.contour(X, Y, Z, N, zdir='z', offset=-100, cmap='coolwarm', alpha=1)\n",
"\n",
"    # Mark the minimum of the cost function\n",
"    ax.plot([-3.89578088] * 2,\n",
"            [ 1.19303364] * 2,\n",
"            [-100, 4.47697137598],\n",
"            color='red', alpha=1, linewidth=1.3, zorder=100, linestyle='dashed',\n",
"            label=r'minimum: $J(-3.90, 1.19) = 4.48$')\n",
"    ax.scatter([-3.89578088] * 2,\n",
"               [ 1.19303364] * 2,\n",
"               [-100, 4.47697137598],\n",
"               c='r', s=80, marker='x', alpha=1, linewidth=1.3, zorder=100,\n",
"               label=r'minimum: $J(-3.90, 1.19) = 4.48$')\n",
"\n",
"    if show_gradient:\n",
"        ax.plot([3.0, 1.1],\n",
"                [3.0, 2.4],\n",
"                [263.0, 125.0],\n",
"                color='green', alpha=1, linewidth=1.3, zorder=100)\n",
"        ax.scatter([3.0],\n",
"                   [3.0],\n",
"                   [263.0],\n",
"                   c='g', s=30, marker='D', alpha=1, linewidth=1.3, zorder=100)\n",
"\n",
"    ax.margins(0,0,0)\n",
"    fig.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/svg+xml": "<!-- SVG data omitted: 3D surface plot of the cost function J(theta) over (theta_0, theta_1), with contour lines and the minimum marked -->",
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"costplot3d(h, x, y)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"In the plot above, the minimum of the cost function we are looking for is marked with a red cross.\n",
"\n",
"Below we can also see the projection of the three-dimensional plot above onto the $(\\theta_0, \\theta_1)$ plane:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def costplot2d(hypothesis, x, y, gradient_values=[], nohead=False):\n",
"    fig = plt.figure(figsize=(16*.6, 9*.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"    ax.set_xlabel(r'$\\theta_0$')\n",
"    ax.set_ylabel(r'$\\theta_1$')\n",
"\n",
"    j = lambda theta0, theta1: costfun(hypothesis, x, y)([theta0, theta1])\n",
"    X = np.arange(-10, 10.1, 0.1)\n",
"    Y = np.arange(-1, 4.1, 0.1)\n",
"    X, Y = np.meshgrid(X, Y)\n",
"    # Evaluate the cost at every grid point (theta0, theta1)\n",
"    Z = np.array([[J(hypothesis, [theta0, theta1], x, y)\n",
"                   for theta0, theta1 in zip(xRow, yRow)]\n",
"                  for xRow, yRow in zip(X, Y)])\n",
"\n",
"    N = range(0, 800, 20)\n",
"    plt.contour(X, Y, Z, N, cmap='coolwarm', alpha=1)\n",
"\n",
"    # Mark the minimum of the cost function\n",
"    ax.scatter([-3.89578088], [1.19303364], c='r', s=80, marker='x',\n",
"               label=r'minimum: $J(-3.90, 1.19) = 4.48$')\n",
"\n",
"    # Optionally draw successive parameter values as a chain of arrows\n",
"    if len(gradient_values) > 0:\n",
"        prev_theta = gradient_values[0][1]\n",
"        ax.scatter([prev_theta[0]], [prev_theta[1]],\n",
"                   c='g', s=30, marker='D', zorder=100)\n",
"        for cost, theta in gradient_values[1:]:\n",
"            dtheta = [theta[0] - prev_theta[0], theta[1] - prev_theta[1]]\n",
"            ax.arrow(prev_theta[0], prev_theta[1], dtheta[0], dtheta[1],\n",
"                     color='green',\n",
"                     head_width=(0.0 if nohead else 0.1),\n",
"                     head_length=(0.0 if nohead else 0.2),\n",
"                     zorder=100)\n",
"            prev_theta = theta\n",
"\n",
"    return fig"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/svg+xml": "<!-- SVG data omitted: contour plot of the cost function J(theta) in the (theta_0, theta_1) plane, with the minimum marked by a red cross -->",
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = costplot2d(h, x, y)\n",
"legend(fig)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Cechy funkcji kosztu\n",
"* $J(\\theta)$ jest funkcją wypukłą\n",
"* $J(\\theta)$ posiada tylko jedno minimum lokalne"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 2.2. Metoda gradientu prostego"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Metoda gradientu prostego\n",
"Metoda znajdowania minimów lokalnych."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Idea:\n",
" * Zacznijmy od dowolnego $\\theta$.\n",
" * Zmieniajmy powoli $\\theta$ tak, aby zmniejszać $J(\\theta)$, aż w końcu znajdziemy minimum."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/svg+xml": "<!-- SVG output omitted: 3D plot of the cost function J(theta) drawn by costplot3d(h, x, y, show_gradient=True) -->",
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"costplot3d(h, x, y, show_gradient=True)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Przykładowe wartości kolejnych przybliżeń (sztuczne)\n",
"\n",
"gv = [[_, [3.0, 3.0]], [_, [2.6, 2.4]], [_, [2.2, 2.0]], [_, [1.6, 1.6]], [_, [0.4, 1.2]]]\n",
"\n",
"# Przygotowanie interaktywnego wykresu\n",
"\n",
"sliderSteps1 = widgets.IntSlider(min=0, max=3, step=1, value=0, description='kroki', width=300)\n",
"\n",
"def slide4(steps):\n",
" costplot2d(h, x, y, gradient_values=gv[:steps+1])"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9dace5ce318949d1881ab3a6d5ccf7ce",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(IntSlider(value=0, description='kroki', max=3), Output()), _dom_classes=('widget-interac…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.slide4(steps)>"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact(slide4, steps=sliderSteps1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Metoda gradientu prostego\n",
"W każdym kroku będziemy aktualizować parametry $\\theta_j$:\n",
"\n",
"$$ \\theta_j := \\theta_j - \\alpha \\frac{\\partial}{\\partial \\theta_j} J(\\theta) \\quad \\mbox{ dla każdego } j $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"The coefficient $\\alpha$ is called the **step length** or the **learning rate**."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"$$ \\begin{array}{rcl}\n",
"\\dfrac{\\partial}{\\partial \\theta_j} J(\\theta)\n",
" & = & \\dfrac{\\partial}{\\partial \\theta_j} \\dfrac{1}{2m} \\displaystyle\\sum_{i = 1}^{m} \\left( h_{\\theta} \\left( x^{(i)} \\right) - y^{(i)} \\right) ^2 \\\\\n",
" & = & 2 \\cdot \\dfrac{1}{2m} \\displaystyle\\sum_{i=1}^m \\left( h_\\theta \\left( x^{(i)} \\right) - y^{(i)} \\right) \\cdot \\dfrac{\\partial}{\\partial\\theta_j} \\left( h_\\theta \\left( x^{(i)} \\right) - y^{(i)} \\right) \\\\\n",
" & = & \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_\\theta \\left( x^{(i)} \\right) - y^{(i)} \\right) \\cdot \\dfrac{\\partial}{\\partial\\theta_j} \\left( \\displaystyle\\sum_{i=0}^n \\theta_i x_i^{(i)} - y^{(i)} \\right)\\\\\n",
" & = & \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_\\theta \\left( x^{(i)} \\right) -y^{(i)} \\right) x_j^{(i)} \\\\\n",
"\\end{array} $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Czyli dla regresji liniowej jednej zmiennej:\n",
"\n",
"$$ h_\\theta(x) = \\theta_0 + \\theta_1x $$\n",
"\n",
"w każdym kroku będziemy aktualizować:\n",
"\n",
"$$\n",
"\\begin{array}{rcl}\n",
"\\theta_0 & := & \\theta_0 - \\alpha \\, \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_\\theta(x^{(i)})-y^{(i)} \\right) \\\\ \n",
"\\theta_1 & := & \\theta_1 - \\alpha \\, \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_\\theta(x^{(i)})-y^{(i)} \\right) x^{(i)}\\\\ \n",
"\\end{array}\n",
"$$"
]
},
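{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"As an illustration, here is a single update step worked out by hand. The data and the values of $\\theta$ and $\\alpha$ are made up for this example and do not come from the lecture data set: take $m = 2$ points $(1, 1)$ and $(2, 2)$, start from $\\theta_0 = \\theta_1 = 0$ and let $\\alpha = 0.1$. Then $h_\\theta(x^{(i)}) = 0$ for both points, so the errors $h_\\theta(x^{(i)}) - y^{(i)}$ are $-1$ and $-2$, and:\n",
"\n",
"$$\n",
"\\begin{array}{rcl}\n",
"\\theta_0 & := & 0 - 0.1 \\cdot \\dfrac{1}{2} \\left( (-1) + (-2) \\right) \\; = \\; 0.15 \\\\\n",
"\\theta_1 & := & 0 - 0.1 \\cdot \\dfrac{1}{2} \\left( (-1) \\cdot 1 + (-2) \\cdot 2 \\right) \\; = \\; 0.25 \\\\\n",
"\\end{array}\n",
"$$\n",
"\n",
"Both right-hand sides use the old values of $\\theta$. After this one step the cost drops from $J = \\frac{1}{4}(1 + 4) = 1.25$ to $J = \\frac{1}{4}(0.6^2 + 1.35^2) \\approx 0.546$."
]
},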
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"###### Uwaga!\n",
" * W każdym kroku aktualizujemy *jednocześnie* $\\theta_0$ i $\\theta_1$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * Kolejne kroki wykonujemy aż uzyskamy zbieżność"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Metoda gradientu prostego – implementacja"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Wyświetlanie macierzy w LaTeX-u\n",
"\n",
"def LatexMatrix(matrix):\n",
" ltx = r'\\left[\\begin{array}'\n",
" m, n = matrix.shape\n",
" ltx += '{' + (\"r\" * n) + '}'\n",
" for i in range(m):\n",
" ltx += r\" & \".join([('%.4f' % j.item()) for j in matrix[i]]) + r\" \\\\ \"\n",
" ltx += r'\\end{array}\\right]'\n",
" return ltx"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [],
"source": [
"def gradient_descent(h, cost_fun, theta, x, y, alpha, eps):\n",
" current_cost = cost_fun(h, theta, x, y)\n",
" log = [[current_cost, theta]] # log przechowuje wartości kosztu i parametrów\n",
" m = len(y)\n",
" while True:\n",
" new_theta = [\n",
" theta[0] - alpha/float(m) * sum(h(theta, x[i]) - y[i]\n",
" for i in range(m)), \n",
" theta[1] - alpha/float(m) * sum((h(theta, x[i]) - y[i]) * x[i]\n",
" for i in range(m))]\n",
2022-03-17 12:54:29 +01:00
" theta = new_theta # jednoczesna aktualizacja - używamy zmiennej tymczasowej\n",
2021-03-30 12:12:30 +02:00
" prev_cost = current_cost\n",
" current_cost = cost_fun(h, theta, x, y)\n",
" if current_cost > prev_cost:\n",
2021-03-17 10:44:12 +01:00
" print(\"Zbyt duża długość kroku!\")\n",
" break\n",
2021-03-02 08:32:40 +01:00
" if abs(prev_cost - current_cost) <= eps:\n",
" break \n",
" log.append([current_cost, theta])\n",
" return theta, log"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle \\large\\textrm{Result:}\\quad \\theta = \\left[\\begin{array}{r}-3.4894 \\\\ 1.1786 \\\\ \\end{array}\\right] \\quad J(\\theta) = 4.7371 \\quad \\textrm{after 22362 iterations}$"
],
"text/plain": [
"<IPython.core.display.Math object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_theta, log = gradient_descent(h, J, [0.0, 0.0], x, y, alpha=0.001, eps=0.0000001)\n",
"\n",
"display(Math(r'\\large\\textrm{Result:}\\quad \\theta = ' +\n",
"             LatexMatrix(np.matrix(best_theta).reshape(2,1)) +\n",
"             (r' \\quad J(\\theta) = %.4f' % log[-1][0])\n",
"             + r' \\quad \\textrm{after %d iterations}' % len(log)))"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Przygotowanie interaktywnego wykresu\n",
"\n",
"sliderSteps2 = widgets.IntSlider(min=0, max=500, step=1, value=1, description='kroki', width=300)\n",
"\n",
"def slide5(steps):\n",
" costplot2d(h, x, y, gradient_values=log[:steps+1], nohead=True)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2bb606269b6842d7b746811f66405f21",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(IntSlider(value=1, description='kroki', max=500), Button(description='Run Interact', sty…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.slide5(steps)>"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact_manual(slide5, steps=sliderSteps2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Współczynnik $\\alpha$ (długość kroku)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Tempo zbieżności metody gradientu prostego możemy regulować za pomocą parametru $\\alpha$, pamiętając, że:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * Jeżeli długość kroku jest zbyt mała, algorytm może działać zbyt wolno."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * Jeżeli długość kroku jest zbyt duża, algorytm może nie być zbieżny."
]
},
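{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A one-dimensional example shows both effects (it is chosen here purely for illustration and is not the cost function of our model). For $J(\\theta) = \\theta^2$ we have $\\frac{d}{d\\theta} J(\\theta) = 2\\theta$, so one gradient descent step is:\n",
"\n",
"$$ \\theta := \\theta - \\alpha \\cdot 2\\theta = (1 - 2\\alpha) \\, \\theta $$\n",
"\n",
"After $k$ steps the starting value has been multiplied by $(1 - 2\\alpha)^k$. For $0 < \\alpha < 1$ this tends to $0$, which is the minimum. When $\\alpha$ is very small, the factor $1 - 2\\alpha$ is close to $1$ and convergence is slow; when $\\alpha > 1$, we get $|1 - 2\\alpha| > 1$ and the successive values move away from the minimum."
]
},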
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 2.3. Predykcja wyników"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Zbudowaliśmy model, dzięki któremu wiemy, jaka jest zależność między dochodem firmy transportowej ($y$) a ludnością miasta ($x$).\n",
"\n",
"Wróćmy teraz do postawionego na początku wykładu pytania: jak przewidzieć dochód firmy transportowej w mieście o danej wielkości?\n",
"\n",
"Odpowiedź polega po prostu na zastosowaniu funkcji $h$ z wyznaczonymi w poprzednim kroku parametrami $\\theta$.\n",
"\n",
"Na przykład, jeżeli miasto ma $536\\,000$ ludności, to $x = 53.6$ (bo dane trenujące były wyrażone w dziesiątkach tysięcy mieszkańców, a $536\\,000 = 53.6 \\cdot 10\\,000$) i możemy użyć znalezionych parametrów $\\theta$, by wykonać następujące obliczenia:\n",
"$$ \\hat{y} \\, = \\, h_\\theta(x) \\, = \\, \\theta_0 + \\theta_1 \\, x \\, = \\, 0.0494 + 0.7591 \\cdot 53.6 \\, = \\, 40.7359 $$\n",
"\n",
"Czyli używając zdefiniowanych wcześniej funkcji:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"59.68111013077243\n"
]
}
],
"source": [
"example_x = 53.6\n",
"predicted_y = h(best_theta, example_x)\n",
"print(predicted_y) ## taki jest przewidywany dochód tej firmy transportowej w 536-tysięcznym mieście"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 2.4. Ewaluacja modelu"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Jak ocenić jakość stworzonego przez nas modelu?\n",
"\n",
" * Trzeba sprawdzić, jak przewidywania modelu zgadzają się z oczekiwaniami!"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Czy możemy w tym celu użyć danych, których użyliśmy do wytrenowania modelu?\n",
"**NIE!**\n",
"\n",
" * Istotą uczenia maszynowego jest budowanie modeli/algorytmów, które dają dobre przewidywania dla **nieznanych** danych – takich, z którymi algorytm nie miał jeszcze styczności! Nie sztuką jest przewidywać rzeczy, które juz sie zna.\n",
" * Dlatego testowanie/ewaluowanie modelu na zbiorze uczącym mija się z celem i jest nieprzydatne.\n",
" * Do ewaluacji modelu należy użyć oddzielnego zbioru danych.\n",
" * **Dane uczące i dane testowe zawsze powinny stanowić oddzielne zbiory!**"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Na wykładzie *5. Dobre praktyki w uczeniu maszynowym* dowiesz się, jak podzielić posiadane dane na zbiór uczący i zbiór testowy.\n",
"\n",
"Tutaj, na razie, do ewaluacji użyjemy specjalnie przygotowanego zbioru testowego."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"As the evaluation metric we will use the mean squared error (MSE), which we already know:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def mse(expected, predicted):\n",
"    \"\"\"Mean squared error (with the same 1/(2m) scaling as the cost function J)\"\"\"\n",
"    m = len(expected)\n",
"    if len(predicted) != m:\n",
"        raise Exception('The vectors have different lengths!')\n",
"    return 1.0 / (2 * m) * sum((expected[i] - predicted[i])**2 for i in range(m))"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.4988278621350606\n"
]
}
],
"source": [
"# Wczytwanie danych testowych z pliku za pomocą numpy\n",
"\n",
"test_data = np.loadtxt('data01_test.csv', delimiter=',')\n",
"x_test = test_data[:, 0]\n",
"y_test = test_data[:, 1]\n",
"\n",
"# Obliczenie przewidywań modelu\n",
"y_pred = h(best_theta, x_test)\n",
"\n",
2021-03-19 19:43:40 +01:00
"# Obliczenie MSE na zbiorze testowym (im mniejszy MSE, tym lepiej!)\n",
"evaluation_result = mse(y_test, y_pred)\n",
2021-03-02 08:32:40 +01:00
"\n",
"print(evaluation_result)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The resulting value tells us how good our model is.\n",
"\n",
"For the MSE metric, the lower the value, the better.\n",
"\n",
"In this way we can, for example, compare different models."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 2.5. Regresja liniowa wielu zmiennych"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Do przewidywania wartości $y$ możemy użyć więcej niż jednej cechy $x$:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Przykład – ceny mieszkań"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"y : price x1: isNew x2: rooms x3: floor x4: location x5: sqrMetres\n",
"476118.0 False 3 1 Centrum 78 \n",
"459531.0 False 3 2 Sołacz 62 \n",
"411557.0 False 3 0 Sołacz 15 \n",
"496416.0 False 4 0 Sołacz 14 \n",
"406032.0 False 3 0 Sołacz 15 \n",
"450026.0 False 3 1 Naramowice 80 \n",
"571229.15 False 2 4 Wilda 39 \n",
"325000.0 False 3 1 Grunwald 54 \n",
"268229.0 False 2 1 Grunwald 90 \n"
]
}
],
"source": [
"import csv\n",
"\n",
"reader = csv.reader(open('data02_train.tsv', encoding='utf-8'), delimiter='\\t')\n",
"# Print the header row followed by the first 9 data rows\n",
"for i, row in enumerate(list(reader)[:10]):\n",
"    if i == 0:\n",
"        print(' '.join(['{}: {:8}'.format('x' + str(j) if j > 0 else 'y ', entry)\n",
"                        for j, entry in enumerate(row)]))\n",
"    else:\n",
"        print(' '.join(['{:12}'.format(entry) for entry in row]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"$$ x^{(2)} = ({\\rm \"False\"}, 3, 2, {\\rm \"Sołacz\"}, 62), \\quad x_3^{(2)} = 2 $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Hipoteza"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"W naszym przypadku (wybraliśmy 5 cech):\n",
"\n",
"$$ h_\\theta(x) = \\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 + \\theta_3 x_3 + \\theta_4 x_4 + \\theta_5 x_5 $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"W ogólności ($n$ cech):\n",
"\n",
"$$ h_\\theta(x) = \\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 + \\ldots + \\theta_n x_n $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Jeżeli zdefiniujemy $x_0 = 1$, będziemy mogli powyższy wzór zapisać w bardziej kompaktowy sposób:\n",
"\n",
"$$\n",
"\\begin{array}{rcl}\n",
"h_\\theta(x)\n",
" & = & \\theta_0 x_0 + \\theta_1 x_1 + \\theta_2 x_2 + \\ldots + \\theta_n x_n \\\\\n",
" & = & \\displaystyle\\sum_{i=0}^{n} \\theta_i x_i \\\\\n",
" & = & \\theta^T \\, x \\\\\n",
" & = & x^T \\, \\theta \\\\\n",
"\\end{array}\n",
"$$"
]
},
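{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"For instance, with made-up values $n = 2$, $\\theta = (1, 2, 3)^T$ and a feature vector $(4, 5)$ extended by $x_0 = 1$ to $x = (1, 4, 5)^T$:\n",
"\n",
"$$ h_\\theta(x) = \\theta^T \\, x = 1 \\cdot 1 + 2 \\cdot 4 + 3 \\cdot 5 = 24 $$"
]
},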
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Metoda gradientu prostego – notacja macierzowa"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Metoda gradientu prostego przyjmie bardzo elegancką formę, jeżeli do jej zapisu użyjemy wektorów i macierzy."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"$$\n",
"X=\\left[\\begin{array}{cc}\n",
"1 & \\left( \\vec x^{(1)} \\right)^T \\\\\n",
"1 & \\left( \\vec x^{(2)} \\right)^T \\\\\n",
"\\vdots & \\vdots\\\\\n",
"1 & \\left( \\vec x^{(m)} \\right)^T \\\\\n",
"\\end{array}\\right] \n",
"= \\left[\\begin{array}{cccc}\n",
"1 & x_1^{(1)} & \\cdots & x_n^{(1)} \\\\\n",
"1 & x_1^{(2)} & \\cdots & x_n^{(2)} \\\\\n",
"\\vdots & \\vdots & \\ddots & \\vdots\\\\\n",
"1 & x_1^{(m)} & \\cdots & x_n^{(m)} \\\\\n",
"\\end{array}\\right]\n",
"\\quad\n",
"\\vec{y} = \n",
"\\left[\\begin{array}{c}\n",
"y^{(1)}\\\\\n",
"y^{(2)}\\\\\n",
"\\vdots\\\\\n",
"y^{(m)}\\\\\n",
"\\end{array}\\right]\n",
"\\quad\n",
"\\theta = \\left[\\begin{array}{c}\n",
"\\theta_0\\\\\n",
"\\theta_1\\\\\n",
"\\vdots\\\\\n",
"\\theta_n\\\\\n",
"\\end{array}\\right]\n",
"$$"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"# Wersje macierzowe funkcji rysowania wykresów punktowych oraz krzywej regresyjnej\n",
"\n",
"def hMx(theta, X):\n",
" return X * theta\n",
"\n",
"def regdotsMx(X, y): \n",
2022-03-04 08:14:16 +01:00
" fig = plt.figure(figsize=(16*.6, 9*.6))\n",
2021-03-02 08:32:40 +01:00
" ax = fig.add_subplot(111)\n",
" fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
" ax.scatter([X[:, 1]], [y], c='r', s=50, label='Dane')\n",
" \n",
" ax.set_xlabel('Populacja')\n",
" ax.set_ylabel('Zysk')\n",
" ax.margins(.05, .05)\n",
2022-03-04 08:14:16 +01:00
" plt.ylim(y.min() - 1, y.max() + 1)\n",
" plt.xlim(np.min(X[:, 1]) - 1, np.max(X[:, 1]) + 1)\n",
2021-03-02 08:32:40 +01:00
" return fig\n",
"\n",
"def reglineMx(fig, fun, theta, X):\n",
" ax = fig.axes[0]\n",
" x0, x1 = np.min(X[:, 1]), np.max(X[:, 1])\n",
" L = [x0, x1]\n",
" LX = np.matrix([1, x0, 1, x1]).reshape(2, 2)\n",
" ax.plot(L, fun(theta, LX), linewidth='2',\n",
" label=(r'$y={theta0:.2}{op}{theta1:.2}x$'.format(\n",
" theta0=float(theta[0][0]),\n",
" theta1=(float(theta[1][0]) if theta[1][0] >= 0 else float(-theta[1][0])),\n",
" op='+' if theta[1][0] >= 0 else '-')))"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 1. 3. 1. 78.]\n",
" [ 1. 3. 2. 62.]\n",
" [ 1. 3. 0. 15.]\n",
" [ 1. 4. 0. 14.]\n",
" [ 1. 3. 0. 15.]]\n",
"(1339, 4)\n",
"\n",
"[[476118.]\n",
" [459531.]\n",
" [411557.]\n",
" [496416.]\n",
" [406032.]]\n",
"(1339, 1)\n"
]
}
],
"source": [
"# Wczytwanie danych z pliku za pomocą numpy – regresja liniowa wielu zmiennych – notacja macierzowa\n",
"\n",
"import pandas\n",
"\n",
"data = pandas.read_csv('data02_train.tsv', delimiter='\\t', usecols=['price', 'rooms', 'floor', 'sqrMetres'])\n",
"m, n_plus_1 = data.values.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data.values[:, 1:].reshape(m, n)\n",
"\n",
"# Dodaj kolumnę jedynek do macierzy\n",
"XMx = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n_plus_1)\n",
"yMx = np.matrix(data.values[:, 0]).reshape(m, 1)\n",
"\n",
"print(XMx[:5])\n",
"print(XMx.shape)\n",
"\n",
"print()\n",
"\n",
"print(yMx[:5])\n",
"print(yMx.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Funkcja kosztu – notacja macierzowa"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"$$J(\\theta)=\\dfrac{1}{2|\\vec y|}\\left(X\\theta-\\vec{y}\\right)^T\\left(X\\theta-\\vec{y}\\right)$$ \n"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle \\Large J(\\theta) = 85104141370.9717$"
],
"text/plain": [
"<IPython.core.display.Math object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from IPython.display import display, Math, Latex\n",
"\n",
"def JMx(theta,X,y):\n",
" \"\"\"Wersja macierzowa funkcji kosztu\"\"\"\n",
" m = len(y)\n",
" J = 1.0 / (2.0 * m) * ((X * theta - y) . T * ( X * theta - y))\n",
" return J.item()\n",
"\n",
"thetaMx = np.matrix([10, 90, -1, 2.5]).reshape(4, 1) \n",
"\n",
"cost = JMx(thetaMx,XMx,yMx) \n",
"display(Math(r'\\Large J(\\theta) = %.4f' % cost))"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Gradient – notacja macierzowa"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"$$\\nabla J(\\theta) = \\frac{1}{|\\vec y|} X^T\\left(X\\theta-\\vec y\\right)$$"
]
},
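{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"This formula follows from the matrix form of the cost function. A short sketch of the derivation, writing $m = |\\vec y|$ and using the standard identity $\\nabla_\\theta \\left( r^T r \\right) = 2 X^T r$ for $r = X\\theta - \\vec y$:\n",
"\n",
"$$ \\nabla J(\\theta) = \\dfrac{1}{2m} \\, \\nabla_\\theta \\left[ \\left( X\\theta - \\vec y \\right)^T \\left( X\\theta - \\vec y \\right) \\right] = \\dfrac{1}{2m} \\cdot 2 X^T \\left( X\\theta - \\vec y \\right) = \\dfrac{1}{m} X^T \\left( X\\theta - \\vec y \\right) $$\n",
"\n",
"The $j$-th component of this vector is $\\frac{1}{m} \\sum_{i=1}^m \\left( h_\\theta(x^{(i)}) - y^{(i)} \\right) x_j^{(i)}$, which agrees with the component-wise partial derivative of $J(\\theta)$."
]
},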
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle \\large \\theta = \\left[\\begin{array}{r}10.0000 \\\\ 90.0000 \\\\ -1.0000 \\\\ 2.5000 \\\\ \\end{array}\\right]\\quad\\large \\nabla J(\\theta) = \\left[\\begin{array}{r}-373492.7442 \\\\ -1075656.5086 \\\\ -989554.4921 \\\\ -23806475.6561 \\\\ \\end{array}\\right]$"
],
"text/plain": [
"<IPython.core.display.Math object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from IPython.display import display, Math, Latex\n",
"\n",
"def dJMx(theta,X,y):\n",
" \"\"\"Wersja macierzowa gradientu funckji kosztu\"\"\"\n",
" return 1.0 / len(y) * (X.T * (X * theta - y)) \n",
"\n",
"thetaMx = np.matrix([10, 90, -1, 2.5]).reshape(4, 1) \n",
"\n",
"display(Math(r'\\large \\theta = ' + LatexMatrix(thetaMx) + \n",
" r'\\quad' + r'\\large \\nabla J(\\theta) = ' \n",
" + LatexMatrix(dJMx(thetaMx,XMx,yMx))))"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Algorytm gradientu prostego – notacja macierzowa"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"$$ \\theta := \\theta - \\alpha \\, \\nabla J(\\theta) $$"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle \\large\\textrm{Wynik:}\\quad \\theta = \\left[\\begin{array}{r}17446.2104 \\\\ 86476.7968 \\\\ -1374.8949 \\\\ 2165.0689 \\\\ \\end{array}\\right] \\quad J(\\theta) = 10324864803.0591 \\quad \\textrm{po 374576 iteracjach}$"
],
"text/plain": [
"<IPython.core.display.Math object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
2021-03-02 08:32:40 +01:00
"source": [
"# Implementacja algorytmu gradientu prostego za pomocą numpy i macierzy\n",
"\n",
"def GDMx(fJ, fdJ, theta, X, y, alpha, eps):\n",
" current_cost = fJ(theta, X, y)\n",
" log = [[current_cost, theta]]\n",
" while True:\n",
" theta = theta - alpha * fdJ(theta, X, y) # implementacja wzoru\n",
" current_cost, prev_cost = fJ(theta, X, y), current_cost\n",
" if abs(prev_cost - current_cost) <= eps:\n",
" break\n",
" if current_cost > prev_cost:\n",
" print('Długość kroku (alpha) jest zbyt duża!')\n",
" break\n",
" log.append([current_cost, theta])\n",
" return theta, log\n",
"\n",
"thetaStartMx = np.zeros((n + 1, 1))\n",
"\n",
"# Zmieniamy wartości alpha (rozmiar kroku) oraz eps (kryterium stopu)\n",
"thetaBestMx, log = GDMx(JMx, dJMx, thetaStartMx, \n",
" XMx, yMx, alpha=0.0001, eps=0.1)\n",
"\n",
"######################################################################\n",
"display(Math(r'\\large\\textrm{Wynik:}\\quad \\theta = ' + \n",
" LatexMatrix(thetaBestMx) + \n",
" (r' \\quad J(\\theta) = %.4f' % log[-1][0]) \n",
" + r' \\quad \\textrm{po %d iteracjach}' % len(log))) "
]
},
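{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The NumPy documentation no longer recommends `np.matrix` and suggests regular arrays instead. The cell below is a minimal sketch of the same update rule on plain `ndarray`s with the `@` operator. The names `cost`, `gradient` and `gradient_descent_np` are hypothetical helpers introduced only for this illustration, and the three data points are synthetic (generated from $y = 1 + 2x$), not taken from the lecture data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def cost(theta, X, y):\n",
"    # J(theta) = 1/(2m) * (X theta - y)^T (X theta - y)\n",
"    r = X @ theta - y\n",
"    return (r.T @ r).item() / (2 * len(y))\n",
"\n",
"def gradient(theta, X, y):\n",
"    # grad J(theta) = 1/m * X^T (X theta - y)\n",
"    return X.T @ (X @ theta - y) / len(y)\n",
"\n",
"def gradient_descent_np(theta, X, y, alpha, eps):\n",
"    # Repeat theta := theta - alpha * grad J(theta) until the cost changes by at most eps\n",
"    current = cost(theta, X, y)\n",
"    while True:\n",
"        theta = theta - alpha * gradient(theta, X, y)\n",
"        previous, current = current, cost(theta, X, y)\n",
"        if abs(previous - current) <= eps:\n",
"            return theta\n",
"\n",
"# Synthetic data generated from y = 1 + 2x (illustration only)\n",
"X_demo = np.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]])\n",
"y_demo = np.array([[1.0], [3.0], [5.0]])\n",
"theta_demo = gradient_descent_np(np.zeros((2, 1)), X_demo, y_demo, alpha=0.1, eps=1e-12)\n",
"print(theta_demo.ravel())  # should be close to [1. 2.]"
]
},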
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 2.6. Metoda gradientu prostego w praktyce"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Kryterium stopu"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Algorytm gradientu prostego polega na wykonywaniu określonych kroków w pętli. Pytanie brzmi: kiedy należy zatrzymać wykonywanie tej pętli?\n",
"\n",
"W każdej kolejnej iteracji wartość funkcji kosztu maleje o coraz mniejszą wartość.\n",
"Parametr `eps` określa, jaka wartość graniczna tej różnicy jest dla nas wystarczająca:\n",
"\n",
" * Im mniejsza wartość `eps`, tym dokładniejszy wynik, ale dłuższy czas działania algorytmu.\n",
" * Im większa wartość `eps`, tym krótszy czas działania algorytmu, ale mniej dokładny wynik."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Na wykresie zobaczymy porównanie regresji dla różnych wartości `eps`"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/svg+xml": "<!-- SVG output omitted: plot comparing the regression for different values of eps (data points drawn as red dots) -->",
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Load the data from a file with numpy (matrix version)\n",
"data = np.loadtxt('data01_train.csv', delimiter=',')\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data[:, 0:n].reshape(m, n)\n",
"\n",
"# Prepend a column of ones to the feature matrix\n",
"XMx = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n_plus_1)\n",
"yMx = np.matrix(data[:, 1]).reshape(m, 1)\n",
"\n",
"thetaStartMx = np.zeros((2, 1))\n",
"\n",
"fig = regdotsMx(XMx, yMx)\n",
"theta_e1, log1 = GDMx(JMx, dJMx, thetaStartMx, XMx, yMx, alpha=0.01, eps=0.01)  # blue line\n",
"reglineMx(fig, hMx, theta_e1, XMx)\n",
"theta_e2, log2 = GDMx(JMx, dJMx, thetaStartMx, XMx, yMx, alpha=0.01, eps=0.000001)  # orange line\n",
"reglineMx(fig, hMx, theta_e2, XMx)\n",
"legend(fig)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle \\theta_{10^{-2}} = \\left[\\begin{array}{r}0.0531 \\\\ 0.8365 \\\\ \\end{array}\\right]\\quad\\theta_{10^{-6}} = \\left[\\begin{array}{r}-3.4895 \\\\ 1.1786 \\\\ \\end{array}\\right]$"
],
"text/plain": [
"<IPython.core.display.Math object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"display(Math(r'\\theta_{10^{-2}} = ' + LatexMatrix(theta_e1) +\n",
" r'\\quad\\theta_{10^{-6}} = ' + LatexMatrix(theta_e2)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Step size ($\\alpha$)"
]
},
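{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"As a reminder, and assuming that `gradient_descent` and `GDMx` implement the standard update rule (their definitions are not repeated in this section), each step moves the parameters against the gradient of the cost, scaled by the step size $\\alpha$:\n",
"\n",
"$$ \\theta := \\theta - \\alpha \\, \\nabla J(\\theta) $$\n",
"\n",
"With a small $\\alpha$ each step changes $\\theta$ only slightly, so the cost decreases slowly; with a large $\\alpha$ a step can overshoot the minimum, and the cost may oscillate or even grow."
]
},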
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# How the cost changes over successive steps, depending on alpha\n",
"\n",
"def costchangeplot(logs):\n",
"    fig = plt.figure(figsize=(16*.6, 9*.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"    ax.set_xlabel('step')\n",
"    ax.set_ylabel(r'$J(\\theta)$')\n",
"\n",
"    # Plot at most the first 500 steps; the run may have stopped earlier\n",
"    X = np.arange(0, min(500, len(logs)), 1)\n",
"    Y = [logs[step][0] for step in X]\n",
"    ax.plot(X, Y, linewidth=2, label=(r'$J(\\theta)$'))\n",
"    return fig\n",
"\n",
"def slide7(alpha):\n",
"    best_theta, log = gradient_descent(h, J, [0.0, 0.0], x, y, alpha=alpha, eps=0.0001)\n",
"    fig = costchangeplot(log)\n",
"    legend(fig)\n",
"\n",
"sliderAlpha1 = widgets.FloatSlider(min=0.01, max=0.03, step=0.001, value=0.02, description=r'$\\alpha$', width=300)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "db68a57ec7514bb3ba51b0dd7729f575",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=0.02, description='$\\\\alpha$', max=0.03, min=0.01, step=0.001), Button…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.slide7(alpha)>"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact_manual(slide7, alpha=sliderAlpha1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 2.7. Data normalization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Data normalization is a preprocessing step that transforms the input features so that gradient descent converges more easily.\n",
"\n",
"Let us look at an example."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"We will use data from the „Gratka flats challenge 2017”.\n",
"\n",
"Consider the model $h(x) = \\theta_0 + \\theta_1 x_1 + \\theta_2 x_2$, which predicts the price of a flat from the number of rooms $x_1$ and the floor area $x_2$:"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>price</th>\n",
" <th>rooms</th>\n",
" <th>sqrMetres</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>476118.00</td>\n",
" <td>3</td>\n",
" <td>78</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>459531.00</td>\n",
" <td>3</td>\n",
" <td>62</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>411557.00</td>\n",
" <td>3</td>\n",
" <td>15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>496416.00</td>\n",
" <td>4</td>\n",
" <td>14</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>406032.00</td>\n",
" <td>3</td>\n",
" <td>15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>450026.00</td>\n",
" <td>3</td>\n",
" <td>80</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>571229.15</td>\n",
" <td>2</td>\n",
" <td>39</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>325000.00</td>\n",
" <td>3</td>\n",
" <td>54</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>268229.00</td>\n",
" <td>2</td>\n",
" <td>90</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>604836.00</td>\n",
" <td>4</td>\n",
" <td>40</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"       price  rooms  sqrMetres\n",
"0  476118.00      3         78\n",
"1  459531.00      3         62\n",
"2  411557.00      3         15\n",
"3  496416.00      4         14\n",
"4  406032.00      3         15\n",
"5  450026.00      3         80\n",
"6  571229.15      2         39\n",
"7  325000.00      3         54\n",
"8  268229.00      2         90\n",
"9  604836.00      4         40"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the data with the pandas library\n",
"import pandas\n",
"alldata = pandas.read_csv('data_flats.tsv', header=0, sep='\\t',\n",
"                          usecols=['price', 'rooms', 'sqrMetres'])\n",
"alldata[:10]"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Show the minimum and maximum value of each column of the matrix X\n",
"\n",
"def show_mins_and_maxs(XMx):\n",
"    mins = np.amin(XMx, axis=0).tolist()[0]  # column minima\n",
"    maxs = np.amax(XMx, axis=0).tolist()[0]  # column maxima\n",
"    for i, (xmin, xmax) in enumerate(zip(mins, maxs)):\n",
"        display(Math(\n",
"            r'${:.2F} \\leq x_{} \\leq {:.2F}$'.format(xmin, i, xmax)))"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Prepare the data\n",
"\n",
"import numpy as np\n",
"\n",
"%matplotlib inline\n",
"\n",
"data2 = np.matrix(alldata[['rooms', 'sqrMetres', 'price']])\n",
"\n",
"m, n_plus_1 = data2.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data2[:, 0:n]\n",
"\n",
"XMx2 = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n_plus_1)\n",
"yMx2 = np.matrix(data2[:, -1]).reshape(m, 1) / 1000.0"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"The features in the training data take values in the following ranges:"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 1.00 \\leq x_0 \\leq 1.00$"
],
"text/plain": [
"<IPython.core.display.Math object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/latex": [
"$\\displaystyle 2.00 \\leq x_1 \\leq 7.00$"
],
"text/plain": [
"<IPython.core.display.Math object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/latex": [
"$\\displaystyle 12.00 \\leq x_2 \\leq 196.00$"
],
"text/plain": [
"<IPython.core.display.Math object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_mins_and_maxs(XMx2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"As we can see, $x_2$ takes much larger values than $x_1$.\n",
"As a result, the plot of the cost function is strongly „flattened” along one of the axes:"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def contour_plot(X, y, rescale=10**8):\n",
"    # theta_0 is fixed at 1.0; the plot shows J as a function of theta_1 and theta_2\n",
"    theta1_vals = np.linspace(-100000, 100000, 100)\n",
"    theta2_vals = np.linspace(-100000, 100000, 100)\n",
"\n",
"    J_vals = np.zeros(shape=(theta1_vals.size, theta2_vals.size))\n",
"    for t1, element in enumerate(theta1_vals):\n",
"        for t2, element2 in enumerate(theta2_vals):\n",
"            thetaT = np.matrix([1.0, element, element2]).reshape(3,1)\n",
"            J_vals[t1, t2] = JMx(thetaT, X, y) / rescale\n",
"\n",
"    plt.figure()\n",
"    plt.contour(theta1_vals, theta2_vals, J_vals.T, np.logspace(-2, 3, 20))\n",
"    plt.xlabel(r'$\\theta_1$')\n",
"    plt.ylabel(r'$\\theta_2$')"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"contour_plot(XMx2, yMx2, rescale=10**10)"
]
},
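{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A rough way to see where this shape comes from, using the feature ranges reported earlier: in the model $h(x) = \\theta_0 + \\theta_1 x_1 + \\theta_2 x_2$, changing a parameter $\\theta_i$ by $\\Delta$ changes the prediction for an example by $\\Delta \\cdot x_i$. For the largest feature values in the data this gives\n",
"\n",
"$$ \\Delta \\cdot \\max x_1 = 7\\Delta, \\qquad \\Delta \\cdot \\max x_2 = 196\\Delta, $$\n",
"\n",
"so the same change in $\\theta_2$ can shift a prediction up to $196 / 7 = 28$ times more than a change in $\\theta_1$. The cost is therefore far more sensitive to $\\theta_2$ than to $\\theta_1$, which stretches the contour lines along the $\\theta_1$ axis. This is only an order-of-magnitude argument; it ignores how the two features are distributed and correlated."
]
},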
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"When the cost function has the shape shown in the plot above, finding the minimum with gradient descent is clearly difficult: the algorithm quickly reaches the „gutter”, but the descent along the „gutter” towards the minimum is very slow.\n",
"\n",
"How can we remedy this?\n",
"\n",
"We will try to transform the data so that the cost function has a „nice”, regular shape."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Scaling"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"We want every feature to take values in a similar range.\n",
"\n",
"To achieve this, we rescale each feature by dividing its values by the maximum value of that feature:\n",
"\n",
"$$ \\hat{x_i}^{(j)} := \\frac{x_i^{(j)}}{\\max_j x_i^{(j)}} $$"
]
},
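{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"As a worked example, take the first flat in the table shown earlier (3 rooms, 78 square metres) together with the maxima reported for the training data ($\\max x_1 = 7$, $\\max x_2 = 196$):\n",
"\n",
"$$ \\hat{x}_1 = \\frac{3}{7} \\approx 0.43, \\qquad \\hat{x}_2 = \\frac{78}{196} \\approx 0.40 $$\n",
"\n",
"For positive-valued features such as these, every scaled value lies in $(0, 1]$, whatever the original units were."
]
},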
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 1.00 \\leq x_0 \\leq 1.00$"
],
"text/plain": [
"<IPython.core.display.Math object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/latex": [
"$\\displaystyle 0.29 \\leq x_1 \\leq 1.00$"
],
"text/plain": [
"<IPython.core.display.Math object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/latex": [
"$\\displaystyle 0.06 \\leq x_2 \\leq 1.00$"
],
"text/plain": [
"<IPython.core.display.Math object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"XMx2_scaled = XMx2 / np.amax(XMx2, axis=0)\n",
"\n",
"show_mins_and_maxs(XMx2_scaled)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"contour_plot(XMx2_scaled, yMx2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Mean normalization"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"In addition, we want the mean of every feature to be close to $0$.\n",
"\n",
"To achieve this, besides rescaling, we subtract the mean $\\mu_i$ of each feature from its values:\n",
"\n",
"$$ \\hat{x_i}^{(j)} := \\frac{x_i^{(j)} - \\mu_i}{\\max_j x_i^{(j)}} $$"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 0.00 \\leq x_0 \\leq 0.00$"
],
"text/plain": [
"<IPython.core.display.Math object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/latex": [
"$\\displaystyle -0.10 \\leq x_1 \\leq 0.62$"
],
"text/plain": [
"<IPython.core.display.Math object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/latex": [
"$\\displaystyle -0.23 \\leq x_2 \\leq 0.70$"
],
"text/plain": [
"<IPython.core.display.Math object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"XMx2_norm = (XMx2 - np.mean(XMx2, axis=0)) / np.amax(XMx2, axis=0)\n",
"\n",
"show_mins_and_maxs(XMx2_norm)"
]
},
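{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"It is worth checking what this formula does to the constant column $x_0 = 1$ when that column is normalized together with the features, as in the cell above. Its mean and its maximum are both $1$, so\n",
"\n",
"$$ \\hat{x}_0 = \\frac{1 - 1}{1} = 0, $$\n",
"\n",
"which is why the output reports $0.00 \\leq x_0 \\leq 0.00$. A column of zeros means that $\\theta_0$ no longer affects the prediction, so a common alternative (not used in this notebook) is to normalize only the actual features $x_1, \\dots, x_n$ and leave $x_0$ unchanged."
]
},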
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"contour_plot(XMx2_norm, yMx2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Now the plot of the cost function has a very regular shape, so gradient descent applied in this case finds the minimum of the cost function very quickly."
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
},
"livereveal": {
"start_slideshow_at": "selected",
"theme": "white"
}
},
"nbformat": 4,
"nbformat_minor": 4
}