2021-04-06 11:16:04 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Uczenie maszynowe\n",
"# 3. Ewaluacja, regularyzacja, optymalizacja"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 3.1. Metodologia testowania"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"W uczeniu maszynowym bardzo ważna jest ewaluacja budowanego modelu. Dlatego dobrze jest podzielić posiadane dane na odrębne zbiory – osobny zbiór danych do uczenia i osobny do testowania. W niektórych przypadkach potrzeba będzie dodatkowo wyodrębnić tzw. zbiór walidacyjny."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Zbiór uczący a zbiór testowy"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"* Na zbiorze uczącym (treningowym) uczymy algorytmy, a na zbiorze testowym sprawdzamy ich poprawność.\n",
"* Zbiór uczący powinien być kilkukrotnie większy od testowego (np. 4:1, 9:1 itp.).\n",
"* Zbiór testowy często jest nieznany.\n",
"* Należy unikać mieszania danych testowych i treningowych – nie wolno „zanieczyszczać” danych treningowych danymi testowymi!"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Czasami potrzebujemy dobrać parametry modelu, np. $\\alpha$ – który zbiór wykorzystać do tego celu?"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Zbiór walidacyjny"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Do doboru parametrów najlepiej użyć jeszcze innego zbioru – jest to tzw. **zbiór walidacyjny**"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * Zbiór walidacyjny powinien mieć wielkość zbliżoną do wielkości zbioru testowego, czyli np. dane można podzielić na te trzy zbiory w proporcjach 3:1:1, 8:1:1 itp."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Walidacja krzyżowa"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Którą część danych wydzielić jako zbiór walidacyjny tak, żeby było „najlepiej”?"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * Niech każda partia danych pełni tę rolę naprzemiennie!"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img width=\"100%\" src=\"https://chrisjmccormick.files.wordpress.com/2013/07/10_fold_cv.png\"/>\n",
"Żródło: https://chrisjmccormick.wordpress.com/2013/07/31/k-fold-cross-validation-with-matlab-code/"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Walidacja krzyżowa\n",
"\n",
"* Podziel dane $D = \\left\\{ (x^{(1)}, y^{(1)}), \\ldots, (x^{(m)}, y^{(m)})\\right\\} $ na $N$ rozłącznych zbiorów $T_1,\\ldots,T_N$\n",
"* Dla $i=1,\\ldots,N$, wykonaj:\n",
" * Użyj $T_i$ do walidacji i zbiór $S_i$ do trenowania, gdzie $S_i = D \\smallsetminus T_i$. \n",
" * Zapisz model $\\theta_i$.\n",
"* Akumuluj wyniki dla modeli $\\theta_i$ dla zbiorów $T_i$.\n",
"* Ustalaj parametry uczenia na akumulowanych wynikach."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Walidacja krzyżowa – wskazówki\n",
"\n",
"* Zazwyczaj ustala się $N$ w przedziale od $4$ do $10$, tzw. $N$-krotna walidacja krzyżowa (*$N$-fold cross validation*). \n",
"* Zbiór $D$ warto zrandomizować przed podziałem.\n",
"* W jaki sposób akumulować wyniki dla wszystkich zbiórow $T_i$?\n",
"* Po ustaleniu parametrów dla każdego $T_i$, trenujemy model na całych danych treningowych z ustalonymi parametrami.\n",
"* Testujemy na zbiorze testowym (jeśli nim dysponujemy)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### _Leave-one-out_\n",
"\n",
"Jest to szczególny przypadek walidacji krzyżowej, w której $N = m$."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"* Jaki jest rozmiar pojedynczego zbioru $T_i$?\n",
"* Jakie są zalety i wady tej metody?\n",
"* Kiedy może być przydatna?"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Zbiór walidujący a algorytmy optymalizacji\n",
"\n",
"* Gdy błąd rośnie na zbiorze uczącym, mamy źle dobrany parametr $\\alpha$. Należy go wtedy zmniejszyć.\n",
"* Gdy błąd zmniejsza się na zbiorze trenującym, ale rośnie na zbiorze walidującym, mamy do czynienia ze zjawiskiem **nadmiernego dopasowania** (*overfitting*).\n",
"* Należy wtedy przerwać optymalizację. Automatyzacja tego procesu to _early stopping_."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 3.2. Miary jakości"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Aby przeprowadzić ewaluację modelu, musimy wybrać **miarę** (**metrykę**), jakiej będziemy używać.\n",
"\n",
"Jakiej miary użyc najlepiej?\n",
" * To zależy od rodzaju zadania.\n",
" * Innych metryk używa się do regresji, a innych do klasyfikacji"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Metryki dla zadań regresji\n",
"\n",
"Dla zadań regresji możemy zastosować np.:\n",
" * błąd średniokwadratowy (*root-mean-square error*, RMSE):\n",
" $$ \\mathrm{RMSE} \\, = \\, \\sqrt{ \\frac{1}{m} \\sum_{i=1}^{m} \\left( \\hat{y}^{(i)} - y^{(i)} \\right)^2 } $$\n",
" * średni błąd bezwzględny (*mean absolute error*, MAE):\n",
" $$ \\mathrm{MAE} \\, = \\, \\frac{1}{m} \\sum_{i=1}^{m} \\left| \\hat{y}^{(i)} - y^{(i)} \\right| $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"W powyższych wzorach $y^{(i)}$ oznacza **oczekiwaną** wartości zmiennej $y$ w $i$-tym przykładzie, a $\\hat{y}^{(i)}$ oznacza wartość zmiennej $y$ w $i$-tym przykładzie wyliczoną (**przewidzianą**) przez nasz model."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Metryki dla zadań klasyfikacji\n",
"\n",
"Aby przedstawić kilka najpopularniejszych metryk stosowanych dla zadań klasyfikacyjnych, posłużmy się następującym przykładem:"
]
},
{
"cell_type": "code",
2021-04-16 10:29:03 +02:00
"execution_count": 43,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Przydatne importy\n",
"\n",
"import ipywidgets as widgets\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas\n",
"import random\n",
"import seaborn\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
2021-04-16 10:29:03 +02:00
"execution_count": 44,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def powerme(x1,x2,n):\n",
" \"\"\"Funkcja, która generuje n potęg dla zmiennych x1 i x2 oraz ich iloczynów\"\"\"\n",
" X = []\n",
" for m in range(n+1):\n",
" for i in range(m+1):\n",
" X.append(np.multiply(np.power(x1,i),np.power(x2,(m-i))))\n",
" return np.hstack(X)"
]
},
{
"cell_type": "code",
2021-04-16 10:29:03 +02:00
"execution_count": 45,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def plot_data_for_classification(X, Y, xlabel=None, ylabel=None, Y_predicted=[], highlight=None):\n",
" \"\"\"Wykres danych dla zadania klasyfikacji\"\"\"\n",
" fig = plt.figure(figsize=(16*.6, 9*.6))\n",
" ax = fig.add_subplot(111)\n",
" fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
" X = X.tolist()\n",
" Y = Y.tolist()\n",
" X1n = [x[1] for x, y in zip(X, Y) if y[0] == 0]\n",
" X1p = [x[1] for x, y in zip(X, Y) if y[0] == 1]\n",
" X2n = [x[2] for x, y in zip(X, Y) if y[0] == 0]\n",
" X2p = [x[2] for x, y in zip(X, Y) if y[0] == 1]\n",
" \n",
" if len(Y_predicted) > 0:\n",
" Y_predicted = Y_predicted.tolist()\n",
" X1tn = [x[1] for x, y, yp in zip(X, Y, Y_predicted) if y[0] == 0 and yp[0] == 0]\n",
" X1fn = [x[1] for x, y, yp in zip(X, Y, Y_predicted) if y[0] == 1 and yp[0] == 0]\n",
" X1tp = [x[1] for x, y, yp in zip(X, Y, Y_predicted) if y[0] == 1 and yp[0] == 1]\n",
" X1fp = [x[1] for x, y, yp in zip(X, Y, Y_predicted) if y[0] == 0 and yp[0] == 1]\n",
" X2tn = [x[2] for x, y, yp in zip(X, Y, Y_predicted) if y[0] == 0 and yp[0] == 0]\n",
" X2fn = [x[2] for x, y, yp in zip(X, Y, Y_predicted) if y[0] == 1 and yp[0] == 0]\n",
" X2tp = [x[2] for x, y, yp in zip(X, Y, Y_predicted) if y[0] == 1 and yp[0] == 1]\n",
" X2fp = [x[2] for x, y, yp in zip(X, Y, Y_predicted) if y[0] == 0 and yp[0] == 1]\n",
" \n",
" if highlight == 'tn':\n",
" ax.scatter(X1tn, X2tn, c='r', marker='x', s=100, label='Dane')\n",
" ax.scatter(X1fn, X2fn, c='k', marker='o', s=50, label='Dane')\n",
" ax.scatter(X1tp, X2tp, c='k', marker='o', s=50, label='Dane')\n",
" ax.scatter(X1fp, X2fp, c='k', marker='x', s=50, label='Dane')\n",
" elif highlight == 'fn':\n",
" ax.scatter(X1tn, X2tn, c='k', marker='x', s=50, label='Dane')\n",
" ax.scatter(X1fn, X2fn, c='g', marker='o', s=100, label='Dane')\n",
" ax.scatter(X1tp, X2tp, c='k', marker='o', s=50, label='Dane')\n",
" ax.scatter(X1fp, X2fp, c='k', marker='x', s=50, label='Dane')\n",
" elif highlight == 'tp':\n",
" ax.scatter(X1tn, X2tn, c='k', marker='x', s=50, label='Dane')\n",
" ax.scatter(X1fn, X2fn, c='k', marker='o', s=50, label='Dane')\n",
" ax.scatter(X1tp, X2tp, c='g', marker='o', s=100, label='Dane')\n",
" ax.scatter(X1fp, X2fp, c='k', marker='x', s=50, label='Dane')\n",
" elif highlight == 'fp':\n",
" ax.scatter(X1tn, X2tn, c='k', marker='x', s=50, label='Dane')\n",
" ax.scatter(X1fn, X2fn, c='k', marker='o', s=50, label='Dane')\n",
" ax.scatter(X1tp, X2tp, c='k', marker='o', s=50, label='Dane')\n",
" ax.scatter(X1fp, X2fp, c='r', marker='x', s=100, label='Dane')\n",
" else:\n",
" ax.scatter(X1tn, X2tn, c='r', marker='x', s=50, label='Dane')\n",
" ax.scatter(X1fn, X2fn, c='g', marker='o', s=50, label='Dane')\n",
" ax.scatter(X1tp, X2tp, c='g', marker='o', s=50, label='Dane')\n",
" ax.scatter(X1fp, X2fp, c='r', marker='x', s=50, label='Dane')\n",
"\n",
" else:\n",
" ax.scatter(X1n, X2n, c='r', marker='x', s=50, label='Dane')\n",
" ax.scatter(X1p, X2p, c='g', marker='o', s=50, label='Dane')\n",
" \n",
" if xlabel:\n",
" ax.set_xlabel(xlabel)\n",
" if ylabel:\n",
" ax.set_ylabel(ylabel)\n",
" \n",
" ax.margins(.05, .05)\n",
" return fig"
]
},
{
"cell_type": "code",
2021-04-16 10:29:03 +02:00
"execution_count": 46,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Wczytanie danych\n",
"import pandas\n",
"import numpy as np\n",
"\n",
"alldata = pandas.read_csv('data-metrics.tsv', sep='\\t')\n",
"data = np.matrix(alldata)\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"\n",
"X2 = powerme(data[:, 1], data[:, 2], n)\n",
"Y2 = np.matrix(data[:, 0]).reshape(m, 1)"
]
},
{
"cell_type": "code",
2021-04-16 10:29:03 +02:00
"execution_count": 47,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAm8AAAFmCAYAAAA70X3dAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3df3Dc913n8ddbieS5rndo7bjFdWKSYk2H2NCS6kKhmiqFJpeIay2LFiVnIDeXOZOjmXHtwMR33EGHH3PAQHzuEcqlpkPL+JrNTSXFUBU3DZRieoUomSS160ulhpC4yiWqXdq1OCQn+74/vt+v/dVqV/pK2t3v97v7fMzsaPfz/X7Xn/16V/vS9/PL3F0AAADIh660KwAAAIDkCG8AAAA5QngDAADIEcIbAABAjhDeAAAAcoTwBgAAkCNXpl2BNFx11VV+7bXXpl0NAACARZ544olvufuW5fbpyPB27bXXanJyMu1qAAAALGJm/7DSPjSbAgAA5AjhDQAAIEdSD29m9gkze8XMTtXZbmb2UTObNrNnzOyG2LZbzezZcNuh1tUaAAAgHamHN0l/LOnWZbbfJqk3vO2T9DFJMrMrJD0Qbr9e0h1mdn1TawoAAJCy1MObu39J0vlldtkt6VMe+Iqk15vZVkk3Spp29+fcfUHSQ+G+AAAAbSv18JbANkkvxh6fDcvqlQMAALStPIQ3q1Hmy5TXfhKzfWY2aWaTs7OzDascAABAK+UhvJ2VdE3s8dWSZpYpr8ndH3T3Pnfv27Jl2bnvAAAAMisP4e24pJ8LR52+U9J33P0lSY9L6jWz68ysR9Lt4b4AqrlLY2PBzyTlAIDMSj28mdmnJf1vSW81s7NmdpeZ3W1md4e7TEh6TtK0pI9L+gVJcvdXJd0j6YSkM5IedvfTLX8BQB6Mj0vDw9KBA5eDmnvweHg42A4AyIXUl8dy9ztW2O6SPlRn24SCcAdgOUND0v790pEjwePDh4PgduRIUD40lG79AACJpR7eALSAWRDYpCCwRSFu//6g3GqN/wEAZJF5B/Z16evrcxamR0dyl7pivSUqFYIbAGSImT3h7n3L7ZN6nzcALRL1cYuL94EDAOQC4Q3oBFFwi/q4VSqX+8AR4AAgV+jzBnSC8fHLwS3q4xbvAzcwIO3Zk24dAQCJEN6ATjA0JI2OBj+jPm5RgBsYYLQpAOQI4Q3oBGa1r6zVKwcAZBZ93gAAAHKE8AYAAJAjhDcAAIAcIbwBAADkCOENAAAgRwhvAAAAOUJ4AwAAyBHCGwAAQI4wSS8AADWU58sqnS5p6tyUejf3amTniIobimlXCyC8AQBQ7eQLJzV4bFAVr2ju4pwK3QUdPHFQE3sn1L+9P+3qocPRbAoAQEx5vqzBY4MqL5Q1d3FOkjR3cU7lhaD8wsKFlGuITkd4AwAgpnS6pIpXam6reEWlU6UW1whYjPAGAEDM1LmpS1fcqs1dnNP0+ekW1whYjPAGAEBM76YdKtiGmtsKtkE7Nn1/i2sELEZ4AwAgZuTvC+r65/ma27r+eV4jz72uxTUCFiO8AQAQU/ypOzQx/wEV56WCB5MyFPxKFeelifkPaONP3ZFyDdHpmCoEAIA4M/X/7sOaOfAhlf78Y5reJO04/6pG3v0ftPHwA5JZ2jVEhzN3T7sOLdfX1+eTk5NpVwMAkGXuUlesgapSIbih6czsCXfvW24fmk0BAKjmLh04sLjswIGgHEgZ4Q0AgLgouB05Iu3fH1xx278/eEyAQwbQ5w0AgLjx8cvB7fDhoKn08OFg25Ej0sCAtGdPunVER8tEeDOzWyUdkXSFpKPu/ltV239J0t7w4ZWSfkDSFnc/b2bPSypLek3Sqyu1EwMAsKyhIWl0NPgZ9XGLAtzAQFAOpCj1AQtmdoWkr0u6WdJZSY9LusPdv1Zn//dJOuDuPx4+fl5Sn7t/K+m/yYAFAACQRXkZsHCjpGl3f87dFyQ9JGn3MvvfIenTLakZAABAxmQhvG2T9GLs8dmwbAkze52kWyV9Jlbskj5vZk+Y2b56/4iZ7TOzSTObnJ2dbUC1AXQ8d2lsbGkH9nrlANAAWQhvtSbNqfcb732S/sbdz8fK3uXuN0i6TdKHzOzdtQ509wfdvc/d+7Zs2bK+GgOAFHRsHx5ePAIxGqk4PBxsB4AGy0J4OyvpmtjjqyXN1Nn3dlU1mbr7TPjzFUljCpphAaD5hoaWTiERn2KCju0AmiALo00fl9RrZtdJ+qaCgPZvqncys++RNCDpZ2JlBUld7l4O798i6ddaUmsAqJ5C4siR4H58igkAaLDUr7y5+6uS7pF0QtIZSQ+7+2kzu9vM7o7tukfS5919Llb2JkknzexpSX8n6bPu/uetqjsALApwEYIbgCbKwpU3ufuEpImqsj+sevzHkv64quw5SW9rcvUAoL56yygR4AA0SepX3gAgt1hGCUAKMnHlDQByiWWUAKSA8AYAa8UySgBSQHgDgLUyq31lrV45ADQAfd4AAAByhPAGAACQI4Q3AACAHCG8AQAA5AjhDQAAIEcIbwAAADlCeAMAAMgRwhsAAECOEN4AAAByhPAGAACQI4Q3AACAHCG8AQAA5AjhDQAAIEcIbwAAADlCeAMAAMgRwhsAAECOEN4AAAByhPAGAACQI4Q3AACAHCG8AQAA5AjhDQAAIEcIbwAAADlCeAMAAMiRTIQ3M7vVzJ41s2kzO1Rj+01m9h0zeyq8/UrSYwEAANrJlWlXwMyukPSApJslnZX0uJkdd/evVe361+7+r9d4LAAAQFvIwpW3GyVNu/tz7r4g6SFJu1twLAAAQO5kIbxtk/Ri7PHZsKzaj5rZ02b2OTPbucpjZWb7zGzSzCZnZ2cbUW8AAICWS73ZVJLVKPOqx09K+j53v2Bmg5LGJfUmPDYodH9Q0oOS1NfXV3MfIOvK82WVTpc0dW5KvZt7NbJzRMUNxbSrBQBooSyEt7OSrok9vlrSTHwHd/9u7P6Emf2BmV2V5FigXZx84aQGjw2q4hXNXZxTobuggycOamLvhPq396ddPQBAi2Sh2fRxSb1mdp2Z9Ui6XdLx+A5m9r1mZuH9GxXU+1ySY4F2UJ4va/DYoMoLZc1dnJMkzV2cU3khKL+wcCHlGgIAWiX18Obur0q6R9IJSWckPezup83sbjO7O9ztA5JOmdnTkj4q6XYP1Dy29a8CaK7S6ZIqXqm5reIVlU6VWlwjAEBastBsKnefkDRRVfaHsfu/L+n3kx4LtJupc1OXrrhVm7s4p+nz0y2uEQAgLalfeQOwst7NvSp0F2puK3QXtGPTjhbXCACQFsIbkAMjO0fUZbU/rl3WpZFdIy2uEQAgLYQ3IAeKG4qa2DuhYk/x0hW4QndBxZ6gfGPPxpRrGHKXxsaCn0nKAQCrlok+bwBW1r+9XzP3zqh0qqTp89PasWmHRnaNZCe4SdL4uDQ8LO3fLx0+LJkFge3AAenIEWl0VNqzJ+1aAkCuEd6AHNnYs1F33XBX2tWob2goCG5HjgSPDx++HNz27w+2AwDWhfAGoHHMgsAmBYEtCnHxK3EAgHWhz1u7ou8R0hIPcBGCGwA0DOGtXUV9jw4cuBzUor5Hw8PBdqAZovdZXPx9CABYF8Jbu4r3PYq+OOl7hGarfp9VKkvfhwCAdaHPW7ui7xHSMD5+ObhF77P4+3BggNGmALBO5h34l3BfX59PTk6mXY3WcJe6YhdYKxWCG5rHPQhwQ0OL32f1ygEAi5jZE+7et9w+NJu2M/oeodXMgitr1QGtXjkAYNUIb+2KvkcAALQl+ry1K/oeAQDQlghv7WpoKFiKKN7HKApwAwOMNgUAIKdoNm1X9D3CajGxMwDkAuENQICJnTtWeb6so08e1X2P3qejTx5Veb6cdpUALINmUwAqz5dV2v4tTd33DvU+ekQjBxZUPPwAEzt3gJMvnNTgsUFVvKK5i3MqdBd08MRBTeydUP/2/rSr11xMbYOcYp43oMMt+fL2K9W18Komjkn9L4iJndtYeb6sbfdvU3lh6ZW2Yk9RM/fOaGPPxhRq1iJjY8F
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data_for_classification(X2, Y2, xlabel=r'$x_1$', ylabel=r'$x_2$')"
]
},
{
"cell_type": "code",
2021-04-16 10:29:03 +02:00
"execution_count": 48,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def safeSigmoid(x, eps=0):\n",
" \"\"\"Funkcja sigmoidalna zmodyfikowana w taki sposób, \n",
" żeby wartości zawsze były odległe od asymptot o co najmniej eps\n",
" \"\"\"\n",
" y = 1.0/(1.0 + np.exp(-x))\n",
" if eps > 0:\n",
" y[y < eps] = eps\n",
" y[y > 1 - eps] = 1 - eps\n",
" return y\n",
"\n",
"def h(theta, X, eps=0.0):\n",
" \"\"\"Funkcja hipotezy (regresja logistyczna)\"\"\"\n",
" return safeSigmoid(X*theta, eps)\n",
"\n",
"def J(h,theta,X,y, lamb=0):\n",
" \"\"\"Funkcja kosztu dla regresji logistycznej\"\"\"\n",
" m = len(y)\n",
" f = h(theta, X, eps=10**-7)\n",
" j = -np.sum(np.multiply(y, np.log(f)) + \n",
" np.multiply(1 - y, np.log(1 - f)), axis=0)/m\n",
" if lamb > 0:\n",
" j += lamb/(2*m) * np.sum(np.power(theta[1:],2))\n",
" return j\n",
"\n",
"def dJ(h,theta,X,y,lamb=0):\n",
" \"\"\"Gradient funkcji kosztu\"\"\"\n",
" g = 1.0/y.shape[0]*(X.T*(h(theta,X)-y))\n",
" if lamb > 0:\n",
" g[1:] += lamb/float(y.shape[0]) * theta[1:] \n",
" return g\n",
"\n",
"def classifyBi(theta, X):\n",
" \"\"\"Funkcja predykcji - klasyfikacja dwuklasowa\"\"\"\n",
" prob = h(theta, X)\n",
" return prob"
]
},
{
"cell_type": "code",
2021-04-16 10:29:03 +02:00
"execution_count": 49,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def GD(h, fJ, fdJ, theta, X, y, alpha=0.01, eps=10**-3, maxSteps=10000):\n",
" \"\"\"Metoda gradientu prostego dla regresji logistycznej\"\"\"\n",
" errorCurr = fJ(h, theta, X, y)\n",
" errors = [[errorCurr, theta]]\n",
" while True:\n",
" # oblicz nowe theta\n",
" theta = theta - alpha * fdJ(h, theta, X, y)\n",
" # raportuj poziom błędu\n",
" errorCurr, errorPrev = fJ(h, theta, X, y), errorCurr\n",
" # kryteria stopu\n",
" if abs(errorPrev - errorCurr) <= eps:\n",
" break\n",
" if len(errors) > maxSteps:\n",
" break\n",
" errors.append([errorCurr, theta]) \n",
" return theta, errors"
]
},
{
"cell_type": "code",
2021-04-16 10:29:03 +02:00
"execution_count": 50,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"theta = [[ 1.37136167]\n",
" [ 0.90128948]\n",
" [ 0.54708112]\n",
" [-5.9929264 ]\n",
" [ 2.64435168]\n",
" [-4.27978238]]\n"
]
}
],
"source": [
"# Uruchomienie metody gradientu prostego dla regresji logistycznej\n",
"theta_start = np.matrix(np.zeros(X2.shape[1])).reshape(X2.shape[1],1)\n",
"theta, errors = GD(h, J, dJ, theta_start, X2, Y2, \n",
" alpha=0.1, eps=10**-7, maxSteps=10000)\n",
"print('theta = {}'.format(theta))"
]
},
{
"cell_type": "code",
2021-04-16 10:29:03 +02:00
"execution_count": 51,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def plot_decision_boundary(fig, theta, X):\n",
" \"\"\"Wykres granicy klas\"\"\"\n",
" ax = fig.axes[0]\n",
" xx, yy = np.meshgrid(np.arange(-1.0, 1.0, 0.02),\n",
" np.arange(-1.0, 1.0, 0.02))\n",
" l = len(xx.ravel())\n",
" C = powerme(xx.reshape(l, 1), yy.reshape(l, 1), n)\n",
" z = classifyBi(theta, C).reshape(int(np.sqrt(l)), int(np.sqrt(l)))\n",
"\n",
" plt.contour(xx, yy, z, levels=[0.5], lw=3);"
]
},
{
"cell_type": "code",
2021-04-16 10:29:03 +02:00
"execution_count": 52,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"Y_expected = Y2.astype(int)\n",
"Y_predicted = (classifyBi(theta, X2) > 0.5).astype(int)"
]
},
{
"cell_type": "code",
2021-04-16 10:29:03 +02:00
"execution_count": 53,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Przygotowanie interaktywnego wykresu\n",
"\n",
"dropdown_highlight = widgets.Dropdown(options=['all', 'tp', 'fp', 'tn', 'fn'], value='all', description='highlight')\n",
"\n",
"def interactive_classification(highlight):\n",
" fig = plot_data_for_classification(X2, Y2, xlabel=r'$x_1$', ylabel=r'$x_2$',\n",
" Y_predicted=Y_predicted, highlight=highlight)\n",
" plot_decision_boundary(fig, theta, X2)"
]
},
{
"cell_type": "code",
2021-04-16 10:29:03 +02:00
"execution_count": 54,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2021-04-16 10:29:03 +02:00
"model_id": "3f6877198a304e40b7c159f189c4277a",
2021-04-06 11:16:04 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(Dropdown(description='highlight', options=('all', 'tp', 'fp', 'tn', 'fn'), value='all'),…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.interactive_classification(highlight)>"
]
},
2021-04-16 10:29:03 +02:00
"execution_count": 54,
2021-04-06 11:16:04 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact(interactive_classification, highlight=dropdown_highlight)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Zadanie klasyfikacyjne z powyższego przykładu polega na przypisaniu punktów do jednej z dwóch kategorii:\n",
" 0. <font color=\"red\">czerwone krzyżyki</font>\n",
" 1. <font color=\"green\">zielone kółka</font>\n",
"\n",
"W tym celu zastosowano regresję logistyczną."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"W rezultacie otrzymano model, który dzieli płaszczyznę na dwa obszary:\n",
" 0. <font color=\"red\">na zewnątrz granatowej krzywej</font>\n",
" 1. <font color=\"green\">wewnątrz granatowej krzywej</font>\n",
" \n",
"Model przewiduje klasę <font color=\"red\">0 („czerwoną”)</font> dla punktów znajdujący się w obszarze na zewnątrz krzywej, natomiast klasę <font color=\"green\">1 („zieloną”)</font> dla punktów znajdujących sie w obszarze wewnąrz krzywej."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Wszysktie obserwacje możemy podzielić zatem na cztery grupy:\n",
" * **true positives (TP)** – prawidłowo sklasyfikowane pozytywne przykłady (<font color=\"green\">zielone kółka</font> w <font color=\"green\">wewnętrznym obszarze</font>)\n",
" * **true negatives (TN)** – prawidłowo sklasyfikowane negatywne przykłady (<font color=\"red\">czerwone krzyżyki</font> w <font color=\"red\">zewnętrznym obszarze</font>)\n",
" * **false positives (FP)** – negatywne przykłady sklasyfikowane jako pozytywne (<font color=\"red\">czerwone krzyżyki</font> w <font color=\"green\">wewnętrznym obszarze</font>)\n",
" * **false negatives (FN)** – pozytywne przykłady sklasyfikowane jako negatywne (<font color=\"green\">zielone kółka</font> w <font color=\"red\">zewnętrznym obszarze</font>)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Innymi słowy:\n",
"\n",
"<img width=\"50%\" src=\"https://blog.aimultiple.com/wp-content/uploads/2019/07/positive-negative-true-false-matrix.png\">"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TP = 5\n",
"TN = 35\n",
"FP = 3\n",
"FN = 6\n"
]
}
],
"source": [
"# Obliczmy TP, TN, FP i FN\n",
"\n",
"tp = 0\n",
"tn = 0\n",
"fp = 0\n",
"fn = 0\n",
"\n",
"for i in range(len(Y_expected)):\n",
" if Y_expected[i] == 1 and Y_predicted[i] == 1:\n",
" tp += 1\n",
" elif Y_expected[i] == 0 and Y_predicted[i] == 0:\n",
" tn += 1\n",
" elif Y_expected[i] == 0 and Y_predicted[i] == 1:\n",
" fp += 1\n",
" elif Y_expected[i] == 1 and Y_predicted[i] == 0:\n",
" fn += 1\n",
" \n",
"print('TP =', tp)\n",
"print('TN =', tn)\n",
"print('FP =', fp)\n",
"print('FN =', fn)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"source": [
"Możemy teraz zdefiniować następujące metryki:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"#### Dokładność (*accuracy*)\n",
"$$ \\mbox{accuracy} = \\frac{\\mbox{przypadki poprawnie sklasyfikowane}}{\\mbox{wszystkie przypadki}} = \\frac{TP + TN}{TP + TN + FP + FN} $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Dokładność otrzymujemy przez podzielenie liczby przypadków poprawnie sklasyfikowanych przez liczbę wszystkich przypadków:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.8163265306122449\n"
]
}
],
"source": [
"accuracy = (tp + tn) / (tp + tn + fp + fn)\n",
"print('Accuracy:', accuracy)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"**Uwaga:** Nie zawsze dokładność będzie dobrą miarą, zwłaszcza gdy klasy są bardzo asymetryczne!\n",
"\n",
"*Przykład:* Wyobraźmy sobie test na koronawirusa, który **zawsze** zwraca wynik negatywny. Jaką przydatność będzie miał taki test w praktyce? Żadną. A jaka będzie jego *dokładność*? Policzmy:\n",
"$$ \\mbox{accuracy} \\, = \\, \\frac{\\mbox{szacowana liczba osób zdrowych na świecie}}{\\mbox{populacja Ziemi}} \\, \\approx \\, \\frac{7\\,700\\,000\\,000 - 600\\,000}{7\\,700\\,000\\,000} \\, \\approx \\, 0.99992 $$\n",
"(zaokrąglone dane z 27 marca 2020)\n",
"\n",
"Powyższy wynik jest tak wysoki, ponieważ zdecydowana większość osób na świecie nie jest zakażona, więc biorąc losowego Ziemianina możemy w ciemno strzelać, że nie ma koronawirusa.\n",
"\n",
"W tym przypadku duża różnica w liczności obu zbiorów (zakażeni/niezakażeni) powoduje, że *accuracy* nie jest dobrą metryką.\n",
"\n",
"Dlatego dysponujemy również innymi metrykami:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"#### Precyzja (*precision*)\n",
"$$ \\mbox{precision} = \\frac{TP}{TP + FP} $$"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Precision: 0.625\n"
]
}
],
"source": [
"precision = tp / (tp + fp)\n",
"print('Precision:', precision)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Precyzja określa, jaka część przykładów sklasyfikowanych jako pozytywne to faktycznie przykłady pozytywne."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"#### Pokrycie (czułość, *recall*)\n",
"$$ \\mbox{recall} = \\frac{TP}{TP + FN} $$"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Recall: 0.45454545454545453\n"
]
}
],
"source": [
"recall = tp / (tp + fn)\n",
"print('Recall:', recall)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Pokrycie mówi nam, jaka część przykładów pozytywnych została poprawnie sklasyfikowana."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"#### *$F$-measure* (*$F$-score*)\n",
"$$ F = \\frac{2 \\cdot \\mbox{precision} \\cdot \\mbox{recall}}{\\mbox{precision} + \\mbox{recall}} $$"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"F-score: 0.5263157894736842\n"
]
}
],
"source": [
"fscore = (2 * precision * recall) / (precision + recall)\n",
"print('F-score:', fscore)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"$F$-_measure_ jest kompromisem między precyzją a pokryciem (a ściślej: jest średnią harmoniczną precyzji i pokrycia)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"$F$-_measure_ jest szczególnym przypadkiem ogólniejszej miary:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"*$F_\\beta$-measure*:\n",
"$$ F_\\beta = \\frac{(1 + \\beta) \\cdot \\mbox{precision} \\cdot \\mbox{recall}}{\\beta^2 \\cdot \\mbox{precision} + \\mbox{recall}} $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Dla $\\beta = 1$ otrzymujemy:\n",
"$$ F_1 \\, = \\, \\frac{(1 + 1) \\cdot \\mbox{precision} \\cdot \\mbox{recall}}{1^2 \\cdot \\mbox{precision} + \\mbox{recall}} \\, = \\, \\frac{2 \\cdot \\mbox{precision} \\cdot \\mbox{recall}}{\\mbox{precision} + \\mbox{recall}} \\, = \\, F $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 3.3. Obserwacje odstające"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"**Obserwacje odstające** (*outliers*) – to wszelkie obserwacje posiadające nietypową wartość.\n",
"\n",
"Mogą być na przykład rezultatem błędnego pomiaru albo pomyłki przy wprowadzaniu danych do bazy, ale nie tylko.\n",
"\n",
"Obserwacje odstające mogą niekiedy znacząco wpłynąć na parametry modelu, dlatego ważne jest, żeby takie obserwacje odrzucić zanim przystąpi się do tworzenia modelu."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"W poniższym przykładzie można zobaczyć wpływ obserwacji odstających na wynik modelowania na przykładzie danych dotyczących cen mieszkań zebranych z ogłoszeń na portalu Gratka.pl: tutaj przykładem obserwacji odstającej może być ogłoszenie, w którym podano cenę w tys. zł zamiast ceny w zł."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Przydatne funkcje\n",
"\n",
"def h_linear(Theta, x):\n",
" \"\"\"Funkcja regresji liniowej\"\"\"\n",
" return x * Theta\n",
"\n",
"def linear_regression(theta):\n",
" \"\"\"Ta funkcja zwraca funkcję regresji liniowej dla danego wektora parametrów theta\"\"\"\n",
" return lambda x: h_linear(theta, x)\n",
"\n",
"def cost(theta, X, y):\n",
" \"\"\"Wersja macierzowa funkcji kosztu\"\"\"\n",
" m = len(y)\n",
" J = 1.0 / (2.0 * m) * ((X * theta - y).T * (X * theta - y))\n",
" return J.item()\n",
"\n",
"def gradient(theta, X, y):\n",
" \"\"\"Wersja macierzowa gradientu funkcji kosztu\"\"\"\n",
" return 1.0 / len(y) * (X.T * (X * theta - y)) \n",
"\n",
"def gradient_descent(fJ, fdJ, theta, X, y, alpha=0.1, eps=10**-5):\n",
" \"\"\"Algorytm gradientu prostego (wersja macierzowa)\"\"\"\n",
" current_cost = fJ(theta, X, y)\n",
" logs = [[current_cost, theta]]\n",
" while True:\n",
" theta = theta - alpha * fdJ(theta, X, y)\n",
" current_cost, prev_cost = fJ(theta, X, y), current_cost\n",
" if abs(prev_cost - current_cost) > 10**15:\n",
" print('Algorithm does not converge!')\n",
" break\n",
" if abs(prev_cost - current_cost) <= eps:\n",
" break\n",
" logs.append([current_cost, theta]) \n",
" return theta, logs\n",
"\n",
"def plot_data(X, y, xlabel, ylabel):\n",
" \"\"\"Wykres danych (wersja macierzowa)\"\"\"\n",
" fig = plt.figure(figsize=(16*.6, 9*.6))\n",
" ax = fig.add_subplot(111)\n",
" fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
" ax.scatter([X[:, 1]], [y], c='r', s=50, label='Dane')\n",
" \n",
" ax.set_xlabel(xlabel)\n",
" ax.set_ylabel(ylabel)\n",
" ax.margins(.05, .05)\n",
" plt.ylim(y.min() - 1, y.max() + 1)\n",
" plt.xlim(np.min(X[:, 1]) - 1, np.max(X[:, 1]) + 1)\n",
" return fig\n",
"\n",
"def plot_regression(fig, fun, theta, X):\n",
" \"\"\"Wykres krzywej regresji (wersja macierzowa)\"\"\"\n",
" ax = fig.axes[0]\n",
" x0 = np.min(X[:, 1]) - 1.0\n",
" x1 = np.max(X[:, 1]) + 1.0\n",
" L = [x0, x1]\n",
" LX = np.matrix([1, x0, 1, x1]).reshape(2, 2)\n",
" ax.plot(L, fun(theta, LX), linewidth='2',\n",
" label=(r'$y={theta0:.2}{op}{theta1:.2}x$'.format(\n",
" theta0=float(theta[0][0]),\n",
" theta1=(float(theta[1][0]) if theta[1][0] >= 0 else float(-theta[1][0])),\n",
" op='+' if theta[1][0] >= 0 else '-')))"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Wczytanie danych (mieszkania) przy pomocy biblioteki pandas\n",
"\n",
"alldata = pandas.read_csv('data_flats_with_outliers.tsv', sep='\\t',\n",
" names=['price', 'isNew', 'rooms', 'floor', 'location', 'sqrMetres'])\n",
"data = np.matrix(alldata[['price', 'sqrMetres']])\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data[:, 0:n]\n",
"\n",
"Xo = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n + 1)\n",
"yo = np.matrix(data[:, -1]).reshape(m, 1)\n",
"\n",
"Xo /= np.amax(Xo, axis=0)\n",
"yo /= np.amax(yo, axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmwAAAFoCAYAAADq7KeuAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAcJUlEQVR4nO3dfbBkZ10n8O9vJkOik9EISQjkBbQyC6IrAa9JkNRuYMWFKZboLu7E3ZKIWQIKFi/KGrVKFP9YastXJJCKvGZXcXwDUsUARgqNKTaGSUyAvOBMAco4kYSXDTcDiZnMs390z2a46Zn0vbdv93N7Pp+qW919znNO/+65p3u+85xznlOttQAA0K8Nsy4AAICjE9gAADonsAEAdE5gAwDonMAGANA5gQ0AoHMzC2xVdWZVfayq7qiq26rq1SPaVFW9uar2VNUnq+qZs6gVAGCWjpvhex9I8nOttZurakuSm6rq2tba7Ye1eUGSrcOf85K8bfgIAHDMmFkPW2vtrtbazcPni0nuSHL6kmYXJbm6DdyQ5KSqesKUSwUAmKkuzmGrqicneUaSv10y6/QkXzjs9d48MtQBAMy1WR4STZJU1YlJ/izJa1prX1s6e8QiI++lVVWXJbksSTZv3vz9T33qUydaJwDAat10001faq2dstzlZhrYqmpTBmHtD1prfz6iyd4kZx72+owk+0atq7V2VZKrkmRhYaHt2rVrwtUCAKxOVf3DSpab5VWileQdSe5orf3WEZpdk+Qlw6tFz09yb2vtrqkVCQDQgVn2sD07yU8k+VRV3TKc9ktJzkqS1tqVSXYm2ZZkT5KvJ3npDOoEAJipmQW21tr1GX2O2uFtWpJXTqciAIA+dXGVKAAARyawAQB0TmADAOicwAYA0DmBDQCgcwIbAEDnBDYAgM4JbAAAnRPYAAA6J7ABAHROYAMA6JzABgDQOYENAKBzAhsAQOcENgCAzglsAACdE9gAADonsAEAdE5gAwDonMAGANA5gQ0AoHMCGwBA5wQ2AIDOCWwAAJ0T2AAAOiewAQB0TmADAOicwAYA0DmBDQCgcwIbAEDnBDYAgM4JbAAAnRPYAAA6J7ABAHROYAMA6JzABgDQOYENAKBzAhsAQOcENgCAzglsAACdE9gAADonsAEAdE5gAwDo3EwDW1W9s6rurqpPH2H+hVV1b1XdMvz5lWnXCAAwa8fN+P3fneQtSa4+Spu/aa29cDrlAAD0Z6Y9bK2165J8ZZY1AAD0bj2cw/asqrq1qj5UVd8z62IAAKZt1odEH83NSZ7UWruvqrYleX+SraMaVtVlSS5LkrPOOmt6FQIArLGue9haa19rrd03fL4zyaaqOvkIba9qrS201hZOOeWUqdYJALCWug5sVXVaVdXw+bkZ1Pvl2VYFADBdMz0kWlXvTXJhkpOram+SNyTZlCSttSuTvDjJT1fVgSTfSHJxa63NqFwAgJmYaWBrrf34o8x/SwbDfgAAHLO6PiQKAIDABgDQPYENAKBzAhsAQOcENgCAzglsAACdE9gAADonsAEAdE5gAwDonMAGANA5gQ0AoHMCGwBA5wQ2AIDOCWwAAJ0T2AAAOiewAQB0TmADAOicwAYA0DmBDQCgcwIbAEDnBDYAgM4JbAAAnRPYAAA6J7ABAHROYAMA6JzABgDQOYENAKBzAhsAQOcENgCAzglsAACdE9gAADonsAEAdE5gAwDonMAGANA5gQ0AoHMCGwBA5wQ2AIDOCWwAAJ0T2AAAOiewAQB0TmADAOicwAYA0DmBDQCgczMNbFX1zqq6u6o+fYT5VVVvrqo9VfXJqnrmtGuEri0uJm9/e/ILvzB4XFycdUUArIHjZvz+707yliRXH2H+C5JsHf6cl+Rtw0fg+uuTbduSgweT/fuTzZuT170u2bkzueCCWVcHwATNtIettXZdkq8cpclFSa5uAzckOamqnjCd6qBji4uDsLa4OAhryeDx0PT77pttfQBMVO/nsJ2e5AuHvd47nAbHth07Bj1roxw8OJgPwNzoPbDViGltZMOqy6pqV1Xtuueee9a4LJix3bsf7llbav/+ZM+e6dYDwJrqPbDtTXLmYa/PSLJvVMPW2lWttYXW2sIpp5wyleJgZrZuHZyzNsrmzcnZZ0+3HgDWVO+B7ZokLxleLXp+kntba3fNuiiYue3bkw1H+Phu2DCYD8DcmOlVolX13iQXJjm5qvYmeUOSTUnSWrsyyc4k25LsSfL1JC+dTaXQmS1bBleDLr1KdMOGwfQTT5x1hQBM0EwDW2vtxx9lfkvyyimVA+vLBRck+/YNLjDYs2dwGHT7dmENYA7Nehw2YDVOPDG59NJZVwHAGuv9HDYAgGOewAYA0DmBDQCgcwIbAEDnBDYAgM4JbAAAnRPYAAA6J7ABAHROYAMA6JzABgDQOYENAKBzAhsAQOcENgCAzglsAACdE9gAADonsAEAdE5gAwDonMAGANA5gQ0AoHMCGwBA546bdQHAKi0uJjt2JLt3J1u3Jtu3J1u2zLoqACZIYIP17Prrk23bkoMHk/37k82bk9e9Ltm5M7nggllXB8CEOCQK69Xi4iCsLS4OwloyeDw0/b77ZlsfABMjsMF6tWPHoGdtlIMHB/MBmAsCG6xXu3c/3LO21P79yZ49060HgDUjsMF6tXXr4Jy1UTZvTs4+e7r1ALBmBDZYr7ZvTzYc4SO8YcNgPgBzQWCD9WrLlsHVoFu2PNzTtnnzw9NPPHG29QEwMYb1gPXsgguSffsGFxjs2TM4DLp9u7AGMGcENljvTjwxufTSh18vLiZvf7uBdAHmiMAG88RAugBzyTlsMC8MpAswtwQ2mBcG0gWYWwIbzAsD6QLMLYEN5oWBdAHmlsAG88JAugBza+yrRKvqO5JsTXLCoWmttevWoihgBQ4NmLv0KtENGwykC7DOjRXYquq/JXl1kjOS3JLk/CT/J8lz1640YNkMpAswl8btYXt1kh9IckNr7TlV9dQkv7Z2ZQErtnQgXQDWvXHPYbu/tXZ/klTV8a21O5M8Ze3KAgDgkHF72PZW1UlJ3p/k2qr6apJ9a1cWsGKLi4NDom5NBTA3qrW2vAWq/m2Sb0/y4dbav6xJVau0sLDQdu3aNesyYPpG3Zrq0EUHbk0FMHNVdVNrbWG5y409rEdVbayqJyb5XAYXHpy23Dcbsc7nV9VnqmpPVV0+Yv6FVXVvVd0y/PmV1b4nzC23pgKYW+NeJfqzSd6Q5ItJDt37piX5vpW+cVVtTHJFkucl2ZvkE1V1TWvt9iVN/6a19sKVvg/MtcMPf/7zPycPPji63YMPDtq5GAFgXVrOVaJPaa19eYLvfW6SPa21zyZJVf1RkouSLA1swChLD38ed1xy4MDotvffn9zuowWwXo17SPQLSe6d8HufPlzvIXuH05Z6VlXdWlUfqqrvOdLKquqyqtpVVbvuueeeCZcKnRl1+PNIYe2QL0/y/1sATNO4PWyfTfJXVfXBJA8cmtha+61VvHeNmLb0CoibkzyptXZfVW3L4CrVraNW1lq7KslVyeCig1XUBf3bsWPQs7Ycj3vc2tQCwJobt4ftH5Ncm+QxSbYc9rMae5OcedjrM7JkqJDW2tdaa/cNn+9MsqmqTl7l+8L6t3v3wz1r4zjhhORpT1u7egBYU2P1sLXWfi1Jqmpza20Z/0oc1SeSbK2q70zyT0kuTvJfDm9QVacl+WJrrVXVuRkETMd1YOvWZOPG5KGHxmu/aZObvwOsY2P1sFXVs6rq9iR3DF8/vareupo3bq0dSPKqJB8ZrvePW2u3VdUrquoVw2YvTvLpqro1yZuTXNyWO3AczKNt28YLa5s3P3xTePcTBVi3xj2H7XeS/Psk1yRJa+3Wqvo3q33z4WHOnUumXXnY87ckectq3wfmzu/+7qO3qUouuyx54xuFNYB1buyBc1trX1gyacxjMcBELS4mv/M7j96uteS3fzv5+MfXviYA1tTYw3pU1Q8maVX1mKr6+QwPjwJTtmPH+OeuJcmLXuQuBwDr3LiB7RVJXpnBOGl7k5yT5Gf
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(Xo, yo, xlabel=u'metraż', ylabel=u'cena')\n",
"theta_start = np.matrix([0.0, 0.0]).reshape(2, 1)\n",
"theta, logs = gradient_descent(cost, gradient, theta_start, Xo, yo, alpha=0.01)\n",
"plot_regression(fig, h_linear, theta, Xo)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Na powyższym przykładzie obserwacja odstająca jawi sie jako pojedynczy punkt po prawej stronie wykresu. Widzimy, że otrzymana krzywa regresji zamiast odwzorowywać ogólny trend, próbuje „dopasować się” do tej pojedynczej obserwacji.\n",
"\n",
"Dlatego taką obserwację należy usunąć ze zbioru danych (zobacz ponizej)."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"# Odrzućmy obserwacje odstające\n",
"alldata_no_outliers = [\n",
" (index, item) for index, item in alldata.iterrows() \n",
" if item.price > 100 and item.sqrMetres > 10]\n",
"\n",
"alldata_no_outliers = alldata.loc[(alldata['price'] > 100) & (alldata['sqrMetres'] > 100)]"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"data = np.matrix(alldata_no_outliers[['price', 'sqrMetres']])\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data[:, 0:n]\n",
"\n",
"Xo = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n + 1)\n",
"yo = np.matrix(data[:, -1]).reshape(m, 1)\n",
"\n",
"Xo /= np.amax(Xo, axis=0)\n",
"yo /= np.amax(yo, axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmwAAAFoCAYAAADq7KeuAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3de5BcZ3nn8e8zmtFtZmSbsoxsWeayaO2YXWxYxZjgWi4JBKvYOAvEgs3GCsuW17GcNbdUTLIFhCSLUxtMAsgYr+PY3k3AW4GAKwgIIUk5TkJAJjbgu+IkWEjYigHP6K6Rnv3j9Eg9Mz0zPZfu82rm+6nq6tPnnO5+Z1ot/fSe533fyEwkSZJUrp66GyBJkqSpGdgkSZIKZ2CTJEkqnIFNkiSpcAY2SZKkwhnYJEmSCldbYIuIdRHxFxHxUEQ8EBHXtjgnIuIjEbEjIr4ZES+po62SJEl16q3xvUeAd2XmNyJiELg3Ir6cmQ82nXMpsL5xeynw8ca9JEnSolFbD1tm7s7MbzS2h4GHgLXjTrsMuCMrXwVOjYgzu9xUSZKkWhVRwxYRzwVeDPzduENrgSeaHu9kYqiTJEla0Oq8JApARAwAnwbenplD4w+3eErLtbQi4krgSoD+/v5/d955581rOyVJkubq3nvv/ZfMXD3T59Ua2CKijyqs/UFmfqbFKTuBdU2PzwZ2tXqtzLwZuBlgw4YNuX379nlurSRJ0txExD/P5nl1jhIN4PeAhzLzhklOuwu4ojFa9GLgmczc3bVGSpIkFaDOHraXAz8HfCsi7mvs+xXgHIDMvAnYBmwEdgD7gbfW0E5JkqRa1RbYMvMeWteoNZ+TwJbutEiSJKlMRYwSlSRJ0uQMbJIkSYUzsEmSJBXOwCZJklQ4A5skSVLhDGySJEmFM7BJkiQVzsAmSZJUOAObJElS4QxskiRJhTOwSZIkFc7AJkmSVDgDmyRJUuEMbJIkSYUzsEmSJBXOwCZJklQ4A5skSVLhDGySJEmFM7BJkiQVzsAmSZJUOAObJElS4QxskiRJhTOwSZIkFc7AJkmSVDgDmyRJUuEMbJIkSYUzsEmSJBXOwCZJklQ4A5skSVLhDGySJEmFM7BJkiQVzsAmSZJUOAObJElS4QxskiRJhTOwSZIkFc7AJkmSVDgDmyRJUuEMbJIkSYUzsEmSJBXOwCZJklQ4A5skSVLhDGySJEmFM7BJkiQVzsAmSZJUOAObJElS4QxskiRJhTOwSZIkFc7AJkmSVLhaA1tE3BoRT0XEtyc5/sqIeCYi7mvc3tvtNkqSJNWtt+b3vw34GHDHFOf8VWa+vjvNkSRJKk+tPWyZeTfw/TrbIEmSVLqToYbtZRFxf0R8ISJeONlJEXFlRGyPiO179uzpZvskSZI6qvTA9g3gOZl5AfBR4LOTnZiZN2fmhszcsHr16q41UJIkqdOKDmyZOZSZexvb24C+iDi95mZJkiR1VdGBLSLWREQ0ti+iau/T9bZKkiSpu2odJRoRnwReCZweETuB9wF9AJl5E/Am4BciYgQ4ALw5M7Om5kqSJNWi1sCWmW+Z5vjHqKb9kCRJWrSKviQqSZIkA5skSVLxDGySJEmFM7BJkiQVzsAmSZJUOAObJElS4QxskiRJhTOwSZIkFc7AJkmSVDgDmyRJUuEMbJIkSYUzsEmSJBXOwCZJklQ4A5skSVLhDGySJEmFM7BJkiQVzsAmSZJUOAObJElS4QxskiRJhTOwSZIkFc7AJkmSVDgDmyRJUuEMbJIkSYUzsEmSJBXOwCZJklQ4A5skSVLhDGySJEmFM7BJkiQVzsAmSZJUOAObJElS4QxskiRJhTOwSZIkFc7AJkmSVDgDmyRJUuEMbJIkSYUzsEmSJBXOwCZJklQ4A5skSVLhDGySJEmFM7BJkiQVzsAmSZJUOAObJElS4QxskiRJhTOwSZIkFc7AJkmSVDgDmyRJUuEMbJIkSYWrNbBFxK0R8VREfHuS4xERH4mIHRHxzYh4SbfbKKlmw8Nwyy3wy79c3Q8P190iSeq63prf/zbgY8Adkxy/FFjfuL0U+HjjXtJicM89sHEjHDsG+/ZBfz+8852wbRtcckndrZOkrqm1hy0z7wa+P8UplwF3ZOWrwKkRcWZ3WiepVsPDVVgbHq7CGlT3o/v37q23fZLURaXXsK0Fnmh6vLOxT9JCd+edVc9aK8eOVcclaZEoPbBFi33Z8sSIKyNie0Rs37NnT4ebJanjHnvsRM/aePv2wY4d3W2PJNWo9MC2E1jX9PhsYFerEzPz5szckJkbVq9e3ZXGSeqg9eurmrVW+vvhBS/obnskqUalB7a7gCsao0UvBp7JzN11N0pSF2zaBD2T/BXV01Mdl6RFotZRohHxSeCVwOkRsRN4H9AHkJk3AduAjcAOYD/w1npaKqnrBger0aDjR4n29FT7BwbqbqEkdU2tgS0z3zLN8QS2dKk5kkpzySWwa1c1wGDHjuoy6KZNhjVJi07d87BJ0tQGBuBtb6u7FZJUq9Jr2CRJkhY9A5skSVLhDGySJEmFM7BJkiQVzsAmSZJUOAObJElS4QxskiRJhTOwSZIkFc7AJkmSVDgDmyRJUuEMbJIkSYUzsEmSJBXOwCZJklQ4A5skSVLhDGySJEmFM7BJkiQVzsAmSZJUOAObJElS4QxskiRJhTOwSZIkFc7AJkmSVDgDmyRJUuEMbJIkSYXrrbsBknRSGB6GO++Exx6D9eth0yYYHKy7VZIWCQObJE3nnntg40Y4dgz27YP+fnjnO2HbNrjkkrpbJ2kR8JKoJE1leLgKa8PDVViD6n50/9699bZP0qJgYJOkqdx5Z9Wz1sqxY9VxSeqwti+JRsRpwHpg+ei+zLy7E42SpGI89tiJnrXx9u2DHTu62x5Ji1JbgS0i/itwLXA2cB9wMfC3wKs71zRJKsD69VXNWqvQ1t8PL3hB99skadFp95LotcCPAv+cma8CXgzs6VirJKkUmzZBzyR/Vfb0VMclqcPaDWwHM/MgQEQsy8yHgXM71yxJKsTgYDUadHCw6lGD6n50/8BAve2TtCi0W8O2MyJOBT4LfDkifgDs6lyzJKkgl1wCu3ZVAwx27Kgug27aZFiT1DWRmTN7QsQrgFOAL2bm4Y60ao42bNiQ27dvr7sZkiRJY0TEvZm5YabPm8ko0SXAs4F/bOxaA3xnpm8oSZKkmWl3lOgvAu8DngRGJyRK4EUdapckqRQuyyXVrt0etmuBczPz6U42RpJUGJflkorQ7ijRJ4BnOtkQSVJhXJZLKka7PWyPA38ZEZ8HDo3uzMwbOtIqSVL92lmW621v626bpEWq3cD2ncZtaeMmSVroXJZLKkZbgS0zfw0gIvozc5JvryRpQXFZLqkYbdWwRcTLIuJB4KHG4wsi4saOtkySVC+X5ZKK0e6gg98BfhJ4GiAz7wf+facaJUkqgMtyScVoe+LczHwiIpp3HZ3/5kiSiuKyXFIR2g1sT0TEjwEZEUuB/07j8qgkaYEbGHA0qFSzdi+JXgVsAdYCO4ELgas71ShJkiSd0G4P24eAazLzBwARcVpj33/pVMMkSQVxeSqpVu0GtheNhjWAzPxBRLy4Q22SJJXE5amk2rV7SbSn0asGQEQ8ixkMWJhMRLwuIh6JiB0RcV2L46+MiGci4r7G7b1zfU9J0gy4PJVUhJlcEv2biPgjIIHLgd+cyxtHxBJgK/Aaqrq4r0fEXZn54LhT/yozXz+X95IkzZLLU0lFaHelgzsiYjvwaiCAN7QIVjN1EbAjMx8HiIhPAZcBc31dSYuZtVbzy+WppCLMZB62B5nfMLUWeKLp8U7gpS3Oe1lE3A/sAt6dmQ/MYxskLSSdrLXatQve8x54+GE47zz44AfhrLPmp90lW7du6uNnn92ddkiLXLs1bJ0QLfbluMffAJ6TmRcAHwU+O+mLRVwZEdsjYvuePXvmsZmSTgqdrLW68UZYuxbuuAO+9rXqfu3aar8kdUGdgW0n0Pxft7OpetGOy8yhzNzb2N4G9EXE6a1eLDNvzswNmblh9erVnWq
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(Xo, yo, xlabel=u'metraż', ylabel=u'cena')\n",
"theta_start = np.matrix([0.0, 0.0]).reshape(2, 1)\n",
"theta, logs = gradient_descent(cost, gradient, theta_start, Xo, yo, alpha=0.01)\n",
"plot_regression(fig, h_linear, theta, Xo)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Na powyższym wykresie widać, że po odrzuceniu obserwacji odstających otrzymujemy dużo bardziej „wiarygodną” krzywą regresji."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 3.4. Problem nadmiernego dopasowania"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Obciążenie a wariancja"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Dane do prostego przykładu\n",
"\n",
"data = np.matrix([\n",
" [0.0, 0.0],\n",
" [0.5, 1.8],\n",
" [1.0, 4.8],\n",
" [1.6, 7.2],\n",
" [2.6, 8.8],\n",
" [3.0, 9.0],\n",
" ])\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn1 = data[:, 0:n]\n",
"Xn1 /= np.amax(Xn1, axis=0)\n",
"Xn2 = np.power(Xn1, 2) \n",
"Xn2 /= np.amax(Xn2, axis=0)\n",
"Xn3 = np.power(Xn1, 3) \n",
"Xn3 /= np.amax(Xn3, axis=0)\n",
"Xn4 = np.power(Xn1, 4) \n",
"Xn4 /= np.amax(Xn4, axis=0)\n",
"Xn5 = np.power(Xn1, 5) \n",
"Xn5 /= np.amax(Xn5, axis=0)\n",
"\n",
"X1 = np.matrix(np.concatenate((np.ones((m, 1)), Xn1), axis=1)).reshape(m, n + 1)\n",
"X2 = np.matrix(np.concatenate((np.ones((m, 1)), Xn1, Xn2), axis=1)).reshape(m, 2 * n + 1)\n",
"X5 = np.matrix(np.concatenate((np.ones((m, 1)), Xn1, Xn2, Xn3, Xn4, Xn5), axis=1)).reshape(m, 5 * n + 1)\n",
"y = np.matrix(data[:, -1]).reshape(m, 1)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmEAAAFoCAYAAAAfEiweAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAT+0lEQVR4nO3df4zteV3f8dd79kKQmWlcwgXWhRbqnYCWP8TeEpRJQ0Xa9bZxW6OZNVFXc5NNm1Kx17RS20jSNC1pGlPbWJvNQtEUYQhi3dhblaJEb7Rk765bYbmSmVCF27u6lzbB2Wkb3M6nf5y5vdfLvXtnl5nve+6cxyPZnJnzPXPOO9987/Dk+2tqjBEAAKa10D0AAMA8EmEAAA1EGABAAxEGANBAhAEANBBhAAANDizCquq9VfVUVX3qmudeUlUfraqN3cc7D+rzAQAOs4PcE/a+JPdc99w7k3xsjLGS5GO73wMAzJ06yJu1VtWrk/ziGOP1u99/JslbxhhPVtVdST4+xnjtgQ0AAHBITX1O2MvHGE8mye7jyyb+fACAQ+FY9wA3U1UPJHkgSRYXF//86173uuaJAAD+pEcfffQLY4zjz+dnp46wP6yqu645HPnUzV44xngwyYNJcvLkyXH+/PmpZgQA2JOq+v3n+7NTH458OMn9u1/fn+QXJv58AIBD4SBvUfGBJL+V5LVVdbGqTid5d5K3VdVGkrftfg8AMHcO7HDkGOO7b7LorQf1mQAAtwt3zAcAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGIgwAoIEIAwBoIMIAABqIMACABiIMAKCBCAMAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGx7oHAOCI2tpK1teTjY1kZSVZW0uWl7ungkNDhAGw/86dS06dSnZ2ku3tZHExOXMmOXs2WV3tng4OBYcjAdhfW1uzANvamgVYMnu88vzTT/fOB4eECANgf62vz/aA3cjOzmw5IMIA2GcbG1f3gF1vezvZ3Jx2HjikRBgA+2tlZXYO2I0sLiYnTkw7DxxSIgxgHm1tJQ89lPzIj8wet7b2773X1pKFm/zPy8LCbDng6kiAuXPQVy4uL8/e6/rPWFiYPb+09JV/BhwBIgxgnlx75eIVV87fOnUquXRpfyJpdXX2Xuvrs3PATpyY7QETYPD/iTCAebKXKxdPn96fz1pa2r/3giPIOWEA88SVi3BoiDCAeeLKRTg0RBjAPHHlIhwaIgxgnly5cnF5+eoescXFq887cR4m48R8gHnjykU4FEQYwDxy5SK0czgSAKCBCAMAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGIgwAoIEIAwBoIMIAABqIMACABiIMAKBBS4RV1d+tqieq6lNV9YGqelHHHAAAXSaPsKq6O8kPJjk5xnh9kjuS3Df1HAAAnboORx5L8lVVdSzJi5NcapoDAKDF5BE2xvjvSf5Fks8leTLJF8cYv3L966rqgao6X1XnL1++PPWYAAAHquNw5J1J7k3ymiRfk2Sxqr7n+teNMR4cY5wcY5w8fvz41GMCAByojsOR35rkv40xLo8x/jjJR5J8c8McAABtOiLsc0neVFUvrqpK8tYkFxrmAABo03FO2CeSfDjJY0k+uTvDg1PPAQDQ6VjHh44x3pXkXR2fDQBwGLhjPgBAAxEGANBAhAEANBBhAAANRBgAQAMRBgDQQIQBADQQYQAADUQYAEADEQYA0ECEAQA0EGEAAA1EGABAg2PdAwC029pK1teTjY1kZSVZW0uWl7unAo44EQbMt3PnklOnkp2dZHs7WVxMzpxJzp5NVle7pwOOMIcjgfm1tTULsK2tWYAls8crzz/9dO98wJEmwoD5tb4+2wN2Izs7s+UAB0SEAfNrY+PqHrDrbW8nm5vTzgPMFREGzK+Vldk5YDeyuJicODHtPMBcEWHA/FpbSxZu8mtwYWG2HOCAiDBgfi0vz66CXF6+ukdscfHq80tLvfMBR5pbVADzbXU1uXRpdhL+5ubsEOTamgADDpwIA1haSk6f7p4CmDMORwIANBBhAAANRBgAQAMRBgDQQIQBADQQYQAADUQYAEADEQYA0ECEAQA0EGEAAA1EGABAAxEGANBAhAEANBBhAAANRBgAQAMRBgDQQIQBADQQYQAADUQYAEADEQYA0ECEAQA0EGEAAA1EGABAAxEGANCgJcKq6qur6sNV9btVdaGqvqljDgCALseaPvcnkvzSGOM7q+qFSV7cNAcAQIvJI6yq/lSSv5jk+5NkjPGlJF+aeg4AgE4dhyP/bJLLSf5dVf12VT1UVYsNcwAAtOmIsGNJvjHJT40x3pBkO8k7r39RVT1QVeer6vzly5ennhEA4EB1RNjFJBfHGJ/Y/f7DmUXZnzDGeHCMcXKMcfL48eOTDggAcNAmj7Axxh8k+XxVvXb3qbcm+fTUcwAAdOq6OvLvJHn/7pWRn03yA01zAAC0aImwMcbjSU52fDYAwGHgjvkAAA1EGABAAxEGANBAhAEANBBhAAANRBgAQAMRBgDQQIQBADQQYQAADUQYAEADEQYA0ECEAQA0EGEAAA1EGABAAxEGANBAhAEANBBhAAANRBgAQAMRBgDQQIQBADQQYQAADUQYAEADEQYA0ECEAQA0EGEAAA2OdQ8AzJmtrWR9PdnYSFZWkrW1ZHm5eyqAyYkwYDrnziWnTiU7O8n2drK4mJw5k5w9m6yudk8HMCmHI4FpbG3NAmxraxZgyezxyvNPP907H8DERBgwjfX12R6wG9nZmS0HmCMiDJjGxsbVPWDX295ONjennQegmQgDprGyMjsH7EYWF5MTJ6adB6CZCAOmsbaWLNzkV87Cwmw5wBwRYcA0lpdnV0EuL1/dI7a4ePX5paXe+QAm5hYVwHRWV5NLl2Yn4W9uzg5Brq0JMGAuiTBgWktLyenT3VMAtHM4EgCgwS0jrKreXlV3TjEMAMC82MuesFckeaSqPlRV91RVHfRQAABH3S0jbIzxj5KsJHlPku9PslFV/7SqvvaAZwMAOLL2dE7YGGMk+YPd/55JcmeSD1fVPz/A2QAAjqxbXh1ZVT+Y5P4kX0jyUJK/N8b446paSLKR5O8f7IgAAEfPXm5R8dIk3zHG+P1rnxxj7FTVXzuYsQAAjrZbRtgY48eeZdmF/R0HAGA+uE8YAEADEQYA0ECEAQA0EGEAAA1EGABAg7YIq6o7quq3q+oXu2YAAOjSuSfsHUnc4gIAmEstEVZVr0zyVzO7Az8AwNzp2hP2LzP7c0c7N3tBVT1QVeer6vzly5enmwwAYAKTR9junzp6aozx6LO9bozx4Bjj5Bjj5PHjxyeaDgBgGh17wt6c5Nur6veSfDDJt1TVv2+YAwCgzeQRNsb4B2OMV44xXp3kviS/Osb4nqnnAADo5D5hAAANjnV++Bjj40k+3jkDAEAHe8IAABqIMACABiIMAKCBCAMAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGIgwAoIEIAwBoIMIAABqIMACABiIMAKCBCAMAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGIgwAoIEIAwBoIMIAABqIMACABiIMAKCBCAMAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGIgwAoIEIAwBoIMIAABqIMACABiIMAKCBCAMAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGIgwAoMHkEVZVr6qqX6uqC1X1RFW9Y+oZAAC6HWv4zGeS/PAY47GqWk7yaFV9dIzx6YZZAABaTL4nbIzx5Bjjsd2vt5JcSHL31HMAAHRqPSesql6d5A1JPtE5BwDA1NoirKqWkvxckh8aY/zRDZY/UFXnq+r85cuXpx8QAOAAtURYVb0gswB7/xjjIzd6zRjjwTHGyTHGyePHj087IADAAZv8xPyqqiTvSXJhjPHjU38+kGRrK1lfTzY2kpWVZG0tWV7ungpgrnRcHfnmJN+b5JNV9fjucz86xjjbMAv
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X1, y, xlabel='x', ylabel='y')"
]
},
{
"cell_type": "code",
2021-04-09 15:58:14 +02:00
"execution_count": 26,
2021-04-06 11:16:04 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Funkcja regresji wielomianowej\n",
"\n",
"def h_poly(Theta, x):\n",
" \"\"\"Funkcja wielomianowa\"\"\"\n",
" return sum(theta * np.power(x, i) for i, theta in enumerate(Theta.tolist()))\n",
"\n",
"def polynomial_regression(theta):\n",
" \"\"\"Funkcja regresji wielomianowej\"\"\"\n",
" return lambda x: h_poly(theta, x)"
]
},
{
"cell_type": "code",
2021-04-09 15:58:14 +02:00
"execution_count": 27,
2021-04-06 11:16:04 +02:00
"metadata": {},
"outputs": [],
"source": [
"def plot_fun(fig, fun, X):\n",
" \"\"\"Wykres funkcji `fun`\"\"\"\n",
" ax = fig.axes[0]\n",
" x0 = np.min(X[:, 1]) - 1.0\n",
" x1 = np.max(X[:, 1]) + 1.0\n",
" Arg = np.arange(x0, x1, 0.1)\n",
" Val = fun(Arg)\n",
" return ax.plot(Arg, Val, linewidth='2')"
]
},
{
"cell_type": "code",
2021-04-09 15:58:14 +02:00
"execution_count": 28,
2021-04-06 11:16:04 +02:00
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
2021-04-09 15:58:14 +02:00
"[<matplotlib.lines.Line2D at 0x22ab32eac70>]"
2021-04-06 11:16:04 +02:00
]
},
2021-04-09 15:58:14 +02:00
"execution_count": 28,
2021-04-06 11:16:04 +02:00
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmEAAAFoCAYAAAAfEiweAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXiU9b3+8fubHSZhD/sOkR0SS13Ro3VHLXXBhJ722Nbfse2pBY0b7latlbpC29MeThfb05YEXFHRui9oXdCEsJOwh7AkLGGyL/P9/TGxUJpAEmbmO8v7dV1eSeaZzHNfz/VkvHmWzxhrrQAAABBaca4DAAAAxCJKGAAAgAOUMAAAAAcoYQAAAA5QwgAAAByghAEAADgQtBJmjPm9MWavMWb1EY/1Msa8YYwpbvnaM1jrBwAACGfBPBL2tKSLj3psrqS3rLUZkt5q+RkAACDmmGAOazXGDJf0srV2YsvPGySdY63dZYwZIOlda+2YoAUAAAAIU6G+JqyftXaXJLV87Rvi9QMAAISFBNcB2mKMuV7S9ZLk8Xi+MnbsWMeJAAAdUVZZq31VDUpKiFNG31TFGeM6EhBwn3/+eYW1Nr0zvxvqErbHGDPgiNORe9t6orV2oaSFkjR16lS7YsWKUGUEAJygdzbs1Xf/8JmGxBk9+8MzNGVID9eRgKAwxmzr7O+G+nTkUknXtnx/raQXQ7x+AECQlXvrdeuSlZKkmy8cQwED2hDMERWLJP1d0hhjTKkx5jpJj0i6wBhTLOmClp8BAFHCWqtbn1mpiqoGnT6yt75/9kjXkYCwFbTTkdbaWW0sOi9Y6wQAuPX0R1v17oZy9eiaqCeypygujuvAgLYwMR8AEBDrdh3Sz15dL0l65MrJGtC9i+NEQHijhAEATlhdY7NmLypQQ5NPs04Zqosn9ncdCQh7lDAAwAn76SvrVLy3SiPTPbrnsnGu4wARgRIGADghb67do//7eJsS440W5GSpa1LYjqAEwgolDADQaXsP1em2Z4skSbddNFYTB3V3nAiIHJQwAECn+HxWNy9Zqf3VDToro4+umzbCdSQgolDCAACd8rvlW/RBcYV6eZL0+EzGUQAdRQkDAHTY6p2V+vnf/OMo5l01WX27pThOBEQeShgAoENqGpo0O69Ajc1W3z5tmC4Y3891JCAiUcIAAB3y4MvrtLm8Whl9U3XXpYyjADqLEgYAaLfXVu/Wok+3KykhTgtmZSklMd51JCBiUcIAAO2yq7JWc5/zj6OYe/FYjRvQzXEiILJRwgAAx9Xss8rNX6mDNY06Z0y6vnvmcNeRgIhHCQMAHNfC9zfr75v3qU9qkh69eoqMYRwFcKIoYQCAY1q546Aef32DJOnRq6coPS3ZcSIgOlDCAABtqq5v0py8AjX5rL5zxnCdO7av60hA1KCEAQDadP/SNdq6r0Zj+6dp7iVjXccBogolDADQqpeLyrTk81IlM44CCApKGADgX5QeqNEdz62SJN196Tid1C/NcSIg+iS4DgAACC9fjqPw1jXp/HF99a3ThnXuhbxeKT9fKi6WMjKk7GwpjTIHfIkSBgD4J//9Tok+3bpf6WnJmnfV5M6No1i+XJo+XfL5pOpqyeORcnOlZcukadMCHxqIQJyOBAD8w+fbDuipt4olSY/PnKLeqZ0YR+H1+guY1+svYJL/65ePV1UFMDEQuShhAABJkreuUTfmF6jZZ/WfZ43Q2Seld+6F8vP9R8Ba4/P5lwOghAEA/O59cY127K/V+AHddMtFYzr/QsXFh4+AHa26Wiop6fxrA1GEEgYA0AsFO/V8wU6lJPrHUSQnnMA4iowM/zVgrfF4pNGjO//aQBThwnwAiEVH3Lm4Y/hY3b1rgCTp3ssmaHTf1BN77exs/0X4rYmL8y8HQAkDgJhzxJ2LTTW1mvPtR1U1oK8u6p+gWacMOfHXT0vz3wV59N2RcXH+x1NPsOQBUYISBgCx5Mg7FyUtmPZNfTFgjPp7K/TI7+6Q+X8bA1OSpk2Tysr8R9tKSvynILOzKWDAEShhABBLjrhz8bNB4/XL07NlrE9PvPyEetYe8i+/7rrArCs1NXCvBUQhShgAxJKWOxcrkz268fJb5IuL1w8+XqIzthf5l3PnIhAylDAAiCUZGbIej+4670fa2b2vJu/aqNwP/uJfxp2LQEgxogIAYkl2tp4dd45eHne2ujbUav5LjynJ1+Rfxp2LQEhxJAwAYsjW+jjdd+EPpWbp/g+e1ogDZdy5CDhCCQOAGNHY7NOc/EJVN0uXjkvXzNFXSOdP4s5FwBFKGADEiKfe3KiVOw5qYPcUPTwzS6brKa4jATGNa8IAIAZ8vHmf/vvdTYoz0pPZmereNdF1JCDmUcIAIModrGnQTfmFslb60bmjderI3q4jARAlDACimrVWdz6/Srsq65Q5pIdmn5fhOhKAFpQwAIhii1fs0LJVu5WanKAFOVlKjOdtHwgX/DUCQJTaVF6l+5eulSQ9MGOChvbu6jgRgCNRwgAgCjU0+XRjXqFqG5s1I3Ogrsga5DoSgKNQwgAgCj3+xgat2lmpwT276MFvTJQxxnUkAEehhAFAlPmwpEL/895mxRlpfk6muqUwjgIIR5QwAIgi+6sblLu4UJI0+7wMfWVYL8eJALSFEgYAUcJaq9ufLdKeQ/WaOqynbjh3tOtIAI6BEgYAUeKvn27XG2v3KC05QU9mZyqBcRRAWOMvFACiQMlerx582T+O4qdXTtKQXoyjAMIdJQwAIlx9U7N+vKhQdY0+XXnyIH19ykDXkQC0g5MSZoy5yRizxhiz2hizyBiT4iIHAESDn7+2Qet2HdLQXl31wIyJruMAaKeQlzBjzCBJsyVNtdZOlBQvKSfUOQAgGry3sVy/W75F8XFG83MylZqc4DoSgHZydToyQVIXY0yCpK6SyhzlAICIVVFVr5sXr5Qk5V5wkrKG9nScCEBHhLyEWWt3SnpM0nZJuyRVWmtfP/p5xpjrjTErjDErysvLQx0TAMKatVa3PVOkiqp6nTqil37wb6NcRwLQQS5OR/aUNEPSCEkDJXmMMd86+nnW2oXW2qnW2qnp6emhjgkAYe1Pf9+mt9fvVbcU/ziK+Dg+lgiINC5OR54vaYu1ttxa2yjpOUlnOMgBABFpw26vfrpsnSTpkasma2CPLo4TAegMFyVsu6TTjDFdjf8TZc+TtM5BDgCIOHWNzZq9qEANTT5lTx2i6ZMGuI4EoJNcXBP2iaRnJH0haVVLhoWhzgEAkeiRV9drwx6vRvTx6N7Lx7uOA+AEOLmX2Vp7n6T7XKwbACLV2+v36OmPtiox3mhBTpY8jKMAIhoT8wEgAuz11unWJUWSpJsvHKNJg7s7TgTgRFHCACDM+XxWtywp0r7qBp0xqreuP2uk60gAAoASBgBh7g8fbdX7G8vVo2uinrgmU3GMowCiAiUMAMLYmrJKzXt1vSRp3lWT1b87H7ULRAtKGACEqdqGZs3JK1RDs0/fPHWoLprQ33UkAAFECQOAMPXQK2tVsrdKo9I9uudSxlEA0YYSBgBh6PU1u/WXT7YrKT5OC2ZlqUtSvOtIAAKMEgYAYWbPoTrd/qx/HMVtF4/RhIGMowCiESUMAMKIz2eVu7hQB2oadVZGH33vzBGuIwEIEkoYAISR//1gsz4s2afeniQ9fs0UxlEAUYwSBgBhYlVppR57fYMk6edXT1bfNMZRANGMEgYAYaCmoUlz8grU2Gx17enDdN64fq4jAQgyShgAhIEHXlqrzRXVGtMvTXdMH+c6DoAQSHAdAACc83ql/HypuFjKyJCys6W0tJCt/tVVu5T32Q4lJcRp/qxMpSQyjgKIBZQwALFt+XJp+nTJ55OqqyWPR8rNlZYtk6ZNC/rqyw7Wau5zqyRJd14yVmP7dwv6OgGEB05HAohdXq+/gHm9/gIm+b9++XhVVVBX3+yzuim/UJW1jTp3TLquPWN4UNcHILxQwgDErvx8/xGw1vh8/uVB9Jv3NumTLfvVJzVZj86cImMYRwHEEkoYgNhVXHz4CNjRqqulkpKgrbp
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X1, y, xlabel='x', ylabel='y')\n",
"theta_start = np.matrix([0, 0]).reshape(2, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X1, y, eps=0.00001)\n",
"plot_fun(fig, polynomial_regression(theta), X1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Ten model ma duże **obciążenie** (**błąd systematyczny**, *bias*) – zachodzi **niedostateczne dopasowanie** (*underfitting*)."
]
},
{
"cell_type": "code",
2021-04-09 15:58:14 +02:00
"execution_count": 29,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
2021-04-09 15:58:14 +02:00
"[<matplotlib.lines.Line2D at 0x22ab32723d0>]"
2021-04-06 11:16:04 +02:00
]
},
2021-04-09 15:58:14 +02:00
"execution_count": 29,
2021-04-06 11:16:04 +02:00
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmEAAAFoCAYAAAAfEiweAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dd3yV5cH/8e91ssmAACFA2CTsKWGDIm4caBVxVlstWldt+1StHc/z9Pm1tba1tVatVq0LFQsOFMRRQUWGhr1JGBkEQoCQBZnn+v1xAiKCBEjOdcbn/XrxSnLOCfl6v26O31zXfV+XsdYKAAAA/uVxHQAAACAcUcIAAAAcoIQBAAA4QAkDAABwgBIGAADgACUMAADAgWYrYcaY54wxu40xa494rLUx5kNjTHbDx+Tm+vkAAACBrDlHwp6XdOFRjz0g6T/W2gxJ/2n4GgAAIOyY5lys1RjTTdK71toBDV9vkjTBWrvTGNNB0gJrbe9mCwAAABCg/H1NWKq1dqckNXxs5+efDwAAEBAiXQc4HmPMNEnTJCk+Pn5Ynz59HCcCAAD4umXLlu2x1qacyvf6u4QVGWM6HDEduft4L7TWPi3paUnKzMy0WVlZ/soIAADQKMaY3FP9Xn9PR86WdFPD5zdJetvPPx8AACAgNOcSFa9KWiyptzGmwBhzi6SHJJ1njMmWdF7D1wAAAGGn2aYjrbXXHuepc5rrZwIAAAQLVswHAABwgBIGAADgACUMAADAAUoYAACAA5QwAAAAByhhAAAADlDCAAAAHKCEAQAAOEAJAwAAcIASBgAA4AAlDAAAwAFKGAAAgAOUMAAAAAcoYQAAAA5QwgAAAByghAEAADhACQMAAHCAEgYAAOBApOsAAIDwYa1Vndeqps7r+1Pv+1h91Ne+z+u/9lyrFtE6o0srtUmIcf2fATQJShgAoEmVHqzV5zl79MnaQn2xvkDlNV5VR0SqxhOpmnqvrD29v79H23id0TVZmV2TNaxrsnqmJMjjMU0THvAjShgA4LRYa7WusEyfbC7WJ5uKtSyvRPXeQ00rUjKSvJK8XklShMcoOsKj6MiGPxEexUR+/etjfb6j5KBWFezX1j2V2rqnUjOXFUiSWsZFaVhDIRvWNVmDO7VSXHSEk2MBnAxKGADgpJUeqNVnOcVasKlYn2wuVnF59eHnIow0YscGTcheqjO3r1C7in2Krq/1/YmLVeSOAikh4ZR+bm29V+sLy7Qst0TLckuUlbtPRWXV+njjbn28cbckKdJj1L9jkoZ1ba3Mbr5ilpoU2yT/3UBTMvZ0x4X9IDMz02ZlZbmOAQBhy+v1jXYt2LRbCzYXa0VeibxH/O+jfVKsJvRO0Vm9UjR20Vwl/fRHUmXlN/+i+Hjp0UelW25pklzWWu3Yf/CrUra9RBt3lX0tmyR1So47PH05vHtr9U5NlDFMYeL0GWOWWWszT+V7GQkDABxTSWWNPs32jXR9urlYeypqDj8X6TEa2b21zuqdogm9U75eal7efOwCJvkez8lpsozGGHVKbqFOyS00eUiaJKmiuk4r8/YrK3efluWWaEXefhWUHFRByUG9tbJQknRGl1a6e2KGJvROoYzBGUoYAOCwmjqvZmTl643lBVqVv/9rI0odW8bqrN7tNKF3isamt1VCzHH+F5KR4RvxOt5IWHp684RvkBATqXEZbTUuo60kqd5rtbmoXFm5JVqeW6L5m3Zred5+fe/5LzUgLUl3nZ2h8/ulcnE//I7pSAAIR+Xl0owZUna2lJEh79VXa862Cv3pg03K3XtAkhQVYTSie2tN6OUrXuntEho3alReLqWl+T4eLTFRKiw85WvCmkJldZ1eWZqnpz7dqj0VvmvZeqcm6s6J6bp4YAdFUMZwEk5nOpISBgDhZuFCadIk392KlZVa1GuEHhp7vVa36ylJ6pkSrx+d20vn9Gmn+OONdp3kz1B8vOTxSHPnSuPGNeF/zKmrqq3XjC/z9Y9PtmhnaZUk3/IXd5ydrslDOioqgvXMcWKUMABA4xwxSrU+pbv+MOEmfdLD9/+PdpUl+vGUkZoytqcim6KAVFT4RttycnxTkFOnOh0BO56aOq9mLS/QEwtylL/voCTfhfw/nNBTVw3rpJhIlrvA8VHCAACN88wzyv/Vb/XIsO/orf4TZI1HidWVun3JTH1vw3/U4s8PN9mdi8Gmtt6r2SsL9fiCHG0t9l3P1j4pVred1UPXDO/C2mM4Ju6OBACcUElljR7PrteL1/9FNZFRiqqv1Y3LZ+uuxa+r9cEy34ua8M7FYBMV4dGVwzrp8qFpmrtmp/7+cY42FZXrf99Zr8fn5+gH43vo+lFdj39DAnCSOJMAIMQdrKnXc59v0z8WbFG56SRFSpevm6+ffvayOpcWffVCP9y5GAwiPEaXDu6oiwd20EcbivTYxzlas6NUv39vo578ZIu+P7a7bhrTTS3jolxHRZBjOhIAQlRdvVczlxXoLx9tVlGZ7y7A8T2Sdf9Dt2vAtjXf/IYAuHMxEFlr9cnmYj32cY6W5ZZIkhJjInXTmG66ZVx3JcdHO04Il7gmDABwmLVWH64v0sPvb1LO7gpJ0oC0JD1wYV/f2llBcOdiILLWavHWvfr7xzlatGWvJKltQrT+MnWIxmekOE4HVyhhAABJUtb2fXrovY3Kahix6dw6Tv91fm9dOqjj1xcjDZI7FwPVstx9+sO8Tfpi2z4ZI90xoad+fG6vprmrFEGFEgYAYa64vFq/fGuN3l/nu8ardXy07p6YrutHdlV0JMWgOdR7rR6fn6O/frRZXisN75asv107VB1axrmOBj+ihAFAGFuVv1+3v7xMO0urFBcVoVvHd9e0M3soMZYLx/1hyda9+tFrK1RUVq3kFlH689WDNbFPqutY8JPTKWH8egQAQez1rHxNeWqxdpZW6YwurfTxf52ln57fmwLmR6N6tNHce8ZrQu8UlRyo1fefz9Jv56xXTZ3XdTQEOEoYAAShmjqvfv32Wt03c7Vq6ry6fmQXvTZtNFNhjrRJiNFzNw3Xzy/qowiP0T8/26YpTy1W/r4DrqMhgFHCACDI7C6v0vXPLNGLi3MVHeHRQ98ZqN9eMZBrvxzzeIxuO6unXr9ttNJaxWlV/n5N+ttnem/NTtfREKD4FwsAQWRFXokufWyhvtxeotSkGM24bZSuGdHFdSwcYVjXZM25Z5zO75eq8qo6/XD6cv367bWqqq13HQ0BhhIGAEFixpd5mvrUEhWVVWt4t2S9c/c4De2S7DoWjqFVi2g9deMw/c+l/RQd4dGLi3P1nScWaWtxhetoCCCUMAAIcDV1Xv3izTW6f9Ya1dR79d3RXTX91lFqlxjrOhq+hTFGN4/trlk/HKOubVpo/c4yXfrYQr29cofraAgQlDAACGC7y6p07T+XaPrSPEVHevTwVYP0m8kDuP4riAzs1FLv3j1OlwzqoMqaev3otZW6f+ZqHaxhejLc8a8YAALUstx9uuSxhVqWW6IOLWP179tG6+rMzq5j4RQkxkbpsWuH6ndXDFRMpEczsvI1+fGF2lxU7joaHKKEAUAAemVpnq55eol2l1drRPfWeufucRrcuZXrWDgNxhhdN7KL3rpzrHqmxGtzUYUu+/tCvf5lvoJh4XQ0PUoYAASQ6rp6/fyN1XrwzTWqrbe6eUw3Tb91pNomxLiOhibSt0OSZt81Tlee0UlVtV7dN2u1fjxjJXdPhqFI1wEAAD5FZVW6/eVlWpG3X9GRHv3uioG6algn17HQDOJjIvXnqwdrdM82+tVba/XWykLtqajRMzdlKjYqwnU8+AkjYQAQALK2+67/WpG3Xx1bxmrW7WMoYGHgqmGd9NadY9U2IUYLc/bo1heyGBELI05KmDHmx8aYdcaYtcaYV40x3GcNICxZa/XSklxd8/QSFZdXa1QP3/VfAzu1dB0NftK7faJe/cFIilgY8nsJM8akSbpHUqa1doCkCEnX+DsHALhmrdVv3l2vX721VnVeq1vGddfLt4xUG67/CjsZqRSxcORqOjJSUpwxJlJSC0mFjnIAgDM
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X2, y, xlabel='x', ylabel='y')\n",
"theta_start = np.matrix([0, 0, 0]).reshape(3, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X2, y, eps=0.000001)\n",
"plot_fun(fig, polynomial_regression(theta), X1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Ten model jest odpowiednio dopasowany."
]
},
{
"cell_type": "code",
2021-04-09 15:58:14 +02:00
"execution_count": 30,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
2021-04-09 15:58:14 +02:00
"[<matplotlib.lines.Line2D at 0x22ab32cebe0>]"
2021-04-06 11:16:04 +02:00
]
},
2021-04-09 15:58:14 +02:00
"execution_count": 30,
2021-04-06 11:16:04 +02:00
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmEAAAFoCAYAAAAfEiweAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXxU1eH///fJvhMgCzsBEkABFUFEBeqCG221tSK471pbq1Zr66efbp/+PvXbRWu1WvtRUNGqgFatVeuKG4LsyCokQIAQyAYkkz0zc35/TLBUWRLIzJnl9Xw8fCSZmcy8vY/L5c2595xrrLUCAABAaMW5DgAAABCLKGEAAAAOUMIAAAAcoIQBAAA4QAkDAABwgBIGAADgQNBKmDHmCWNMpTFmzX6P9TDGvGOMKW7/2j1Ynw8AABDOgjkS9pSk87702D2S3rPWFkl6r/1nAACAmGOCuVirMaZA0mvW2pHtP2+QdLq1dqcxprekD6y1w4IWAAAAIEyF+pqwfGvtTklq/5oX4s8HAAAICwmuAxyMMeYmSTdJUnp6+pjhw4c7TgQACFdev9WGXR75rdWgnHRlJIftX2+IMsuWLau21uYeye+Gei+tMMb03u90ZOXBXmitfUzSY5I0duxYu3Tp0lBlBABEmP9+ebX2Ltqms4bnaeY1J7mOgxhijNl6pL8b6tORr0q6uv37qyX9I8SfDwCIMiWV9Zq9ZLvijHTP+Zw1QeQI5hIVz0taKGmYMabMGHO9pN9KOtsYUyzp7PafAQA4Yr/91+fy+a2mjxugovxM13GADgva6Uhr7aUHeeqsYH0mACC2fLq5Ru+ur1BaUrzumFzkOg7QKayYDwCISH6/1b1vrJck3TxpiPIyUxwnAjqHEgYAiEj/XFWuVWW1ystM1o2TBrmOA3QaJQwAEHGa23z6/ZsbJEl3nTNUaUksSYHIQwkDAEScpxeWasfeJg3Lz9TFY/q7jgMcEUoYACCi7G1s1cPzSiRJ90wZrvg44zgRcGQoYQCAiPLneSWqa/ZqQmGOTh96RAuVA2GBEgYAiBhbaxr09MJSGSP915ThMoZRMEQuShgAIGL8/q0NavNZfXt0X43o0811HOCoUMIAABFh+bY9en3VTiUnxOlH5wxzHQc4apQwAEDYs9bq3tcDC7NeP2GQ+mSnOk4EHD1KGAAg7L21tkJLt+5Rj/Qkfff0Ia7jAF2CEgYACGsLSqr1q1fXSpLumFykrJREx4mArsESwwCAsLS3sVW/eX29XlhWJkk6oX+2Lh03wHEqoOtQwgAAYcVaq3+u2qlf/3OtqutblRQfpx+cWaibvzZEifGcwEH0oIQBAMLGjr1N+vkrazTv80pJ0rhBPfT/LhqlIbkZjpMBXY8SBgBwzue3enphqf7w1gY1tvqUmZKgn045RtPG9lcctyVClKKEAQCc+nxXne75+2qt3L5XknT+yF76nwtGKC8rxXEyILgoYQAAJ5rbfHp4Xon++uEmef1W+VnJ+v8uHKlzRvRyHQ0ICUoYACDkPt1co5++tFqbqxskSVeMH6Afnzec5ScQUyhhAIDg8HikOXOk4mKpqEiaNk21CSn67b/W6/nF2yVJhXkZ+u1FozS2oIfjsEDoUcIAAF1v/nxpyhTJ75caGmTT0/Wvh2frlxfeqaoWq8R4o++fUahbTh+i5IR412kBJyhhAICu5fEECpjHI0naldFTPz/7u3pn6ClSi9WYfln67dQTVJSf6Tgo4BYlDADQtebMkd9vVZ6Vq3cLT9Z9k65SfXKaMloa9ZOFz+nyG76huPyJrlMCzlHCAABHrMXr05bqBm2qbNCmqnqVVNZrU3GGNt/0lJqS/r3ExNkbF+rX7/5VvT010hnDHSYGwgclDABwWLWNbSqp8mhTZYNKquq1qbJeJVX12r67UX77pRebTClJyq3frcKa7bpq+es6b+MCGUlKT5cKCx38HwDhhxIGALHoADMXbUaGKj0t2rDLo5L2krWpsl6bqupVXd96wLeJM9KgnHQNyU3XkLwMDcnNUGFGnIZMHKtu1bsO8Atx0rRpQf6fAyIDJQwAYoz9+GNVX3yZNnbvq40Zedq4tlXFHz+mjQOGq67twL+TmhivIXnpgZKVm6EheRkqzMvQwJ5pB57d+PIL/zE7UunpgQL2xhtSBveBBCRKGABEtd0NrdpY4fn3f+W1Kt6wQ3uu/etXX9wmdUtJ0LBeWSrMD5StwrxA4eqdldK5ezhOmCCVlwdG20pKAqcgp02jgAH7oYQBQBSw1mr1jlqtKqtVcYVHGyvqVVzpOfBpxNRMZbY0qKh6m4ZVbVVR9TYNrd6qoY3Vyr33f2RuuL5rQmVkSNd30XsBUYgSBgARzFqrj4qr9dB7xVq2dc9Xnk9PildhfqaG5mVoWK9MFb02V0MfvU+9PDU64LjWppKgZwYQQAkDgAhkrdX7Gyr14Hsl+mz7XklSdlqizhyep6H5mRqWn6mi/Az16Zb6n6cRN+RI/uYDvykzF4GQooQBQASx1uqddRV6aF6x1uyokyT1TE/SjZMG64rxA5WRfJjD+rRp0p13Hvg5Zi4CIUUJA4AI4Pdbvbl2l/48r0TrdwbKV05Gsr77tcG67OQBSkvq4OE8MzMwQ5GZi4BzlDAACGM+v9Xrq3fq4XnF2lhRL0nKz0rWd782RJeOG6CUxCO4+TUzF4GwQAkDgDDk9fn1z1Xl+vO8Em2uapAk9emWolvOKNTUMf2OrHztj5mLgHOUMAAII20+v15ZsUOPvF+i0ppGSVK/7qn6/hmF+s6J/ZSUEOc4IYCuQgkDgDDQ6vXrpeVleuSDEm3f3SRJGtgzTd8/o1DfHt1XifGULyDaUMIAwKEWr08vLC3Tox9s0o69gfI1OCddt55ZqAuO76MEyhcQtShhAOBIXXObrntyiZa2L7JamJehH5xZqG8c10fxnblFEICIRAkDAAdqG9t01ZOL9dn2verdLUU/+/qxOn9kr87dnxFARKOEAUCI7Wlo1RUzF2lteZ36dU/V8zeOV/8eaa5jAQgxShgAhFB1fYuumLFIn+/yaGDPND1/43j1yU51HQuAA5QwAAiRyrpmXT5jkYor6zU4N13P3zhe+VkprmMBcIQSBgAhsKu2WZc9/qk2VzdoaH6Gnr1hvHIzk13HAuAQJQwAgqxsT6Mue3yRtu1u1DG9s/S368epZwYFDIh1lDAACKJtNY269PFPtWNvk0b17aZnrh+n7LQk17EAhAFKGAAEyZbqBl32+KfaWdus0QOy9dS149QtNdF1LABhghIGAEFQUunRZY8vUqWnRScVdNcT15ykzBQKGIB/c3I/DGPMD40xa40xa4wxzxtjmB4EIGps2OXR9Mc+VaWnRacM7qmnrh1HAQPwFSEvYcaYvpJukzTWWjtSUryk6aHOAQDBsLa8VtMfW6jq+lZNLMrRE9ecpPRkTjoA+CpXR4YESanGmDZJaZLKHeUAgC6zqmyvrpy5WLVNbTpzeJ7+cvmJSkmMdx0LQJgK+UiYtXaHpPskbZO0U1KttfbtL7/OGHOTMWapMWZpVVVVqGMCQKcs27pHlz++SLVNbTrn2Hz99YoxFDAAh+TidGR3SRdKGiSpj6R0Y8wVX36dtfYxa+1Ya+3Y3NzcUMcEgA5bvGW3rpq5SJ4Wr74+qrceufxEJSU4ueQWQARxcZSYLGmLtbbKWtsm6SVJpzrIAQBHbUFJta5+YrEaWn361gl99OD0E5QYTwEDcHgujhTbJI03xqQZY4yksyStd5ADAI7KRxurdO1TS9TU5tPFY/rp/ktOUAIFDEAHubgmbJGkFyUtl7S6PcNjoc4BAEdj3ucVumHWUrV4/bp03AD9/jvHKT7OuI4FIII4mR1prf2lpF+6+GwAOFrryuv03WeWq9Xn19WnDNSvLhihwMA+AHQc4+YA0AmtXr9+9MJnavX5NW1sfwoYgCNGCQOATnjk/RKt21mnAT3S9ItvHksBA3DEKGEA0EFrdtTqkfdLJEm/v/g4VsIHcFQoYQDQAftOQ3r9VtecWqDxg3u6jgQgwlHCAKAD/jy
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X5, y, xlabel='x', ylabel='y')\n",
"theta_start = np.matrix([0, 0, 0, 0, 0, 0]).reshape(6, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X5, y, alpha=0.5, eps=10**-7)\n",
"plot_fun(fig, polynomial_regression(theta), X1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Ten model ma dużą **wariancję** (*variance*) – zachodzi **nadmierne dopasowanie** (*overfitting*)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"(Zwróć uwagę na dziwny kształt krzywej w lewej części wykresu – to m.in. efekt nadmiernego dopasowania)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Nadmierne dopasowanie występuje, gdy model ma zbyt dużo stopni swobody w stosunku do ilości danych wejściowych.\n",
"\n",
"Jest to zjawisko niepożądane.\n",
"\n",
"Możemy obrazowo powiedzieć, że nadmierne dopasowanie występuje, gdy model zaczyna modelować szum/zakłócenia w danych zamiast ich „głównego nurtu”. "
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Zobacz też: https://pl.wikipedia.org/wiki/Nadmierne_dopasowanie"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"90%\" src=\"fit.png\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Obciążenie (błąd systematyczny, *bias*)\n",
"\n",
"* Wynika z błędnych założeń co do algorytmu uczącego się.\n",
"* Duże obciążenie powoduje niedostateczne dopasowanie."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Wariancja (*variance*)\n",
"\n",
"* Wynika z nadwrażliwości na niewielkie fluktuacje w zbiorze uczącym.\n",
"* Wysoka wariancja może spowodować nadmierne dopasowanie (modelując szum zamiast sygnału)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"60%\" src=\"bias2.png\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"60%\" src=\"curves.jpg\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 3.5. Regularyzacja"
]
},
{
"cell_type": "code",
2021-04-16 10:29:03 +02:00
"execution_count": 55,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def SGD(h, fJ, fdJ, theta, X, Y, \n",
" alpha=0.001, maxEpochs=1.0, batchSize=100, \n",
" adaGrad=False, logError=False, validate=0.0, valStep=100, lamb=0, trainsetsize=1.0):\n",
" \"\"\"Stochastic Gradient Descent - stochastyczna wersja metody gradientu prostego\n",
" \"\"\"\n",
" errorsX, errorsY = [], []\n",
" errorsVX, errorsVY = [], []\n",
" \n",
" XT, YT = X, Y\n",
" \n",
" m_end=int(trainsetsize*len(X))\n",
" \n",
" if validate > 0:\n",
" mv = int(X.shape[0] * validate)\n",
" XV, YV = X[:mv], Y[:mv] \n",
" XT, YT = X[mv:m_end], Y[mv:m_end] \n",
" m, n = XT.shape\n",
"\n",
" start, end = 0, batchSize\n",
" maxSteps = (m * float(maxEpochs)) / batchSize\n",
" \n",
" if adaGrad:\n",
" hgrad = np.matrix(np.zeros(n)).reshape(n,1)\n",
" \n",
" for i in range(int(maxSteps)):\n",
" XBatch, YBatch = XT[start:end,:], YT[start:end,:]\n",
"\n",
" grad = fdJ(h, theta, XBatch, YBatch, lamb=lamb)\n",
" if adaGrad:\n",
" hgrad += np.multiply(grad, grad)\n",
" Gt = 1.0 / (10**-7 + np.sqrt(hgrad))\n",
" theta = theta - np.multiply(alpha * Gt, grad)\n",
" else:\n",
" theta = theta - alpha * grad\n",
" \n",
" if logError:\n",
" errorsX.append(float(i*batchSize)/m)\n",
" errorsY.append(fJ(h, theta, XBatch, YBatch).item())\n",
" if validate > 0 and i % valStep == 0:\n",
" errorsVX.append(float(i*batchSize)/m)\n",
" errorsVY.append(fJ(h, theta, XV, YV).item())\n",
" \n",
" if start + batchSize < m:\n",
" start += batchSize\n",
" else:\n",
" start = 0\n",
" end = min(start + batchSize, m)\n",
" return theta, (errorsX, errorsY, errorsVX, errorsVY)"
]
},
{
"cell_type": "code",
2021-04-16 10:29:03 +02:00
"execution_count": 56,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Przygotowanie danych do przykładu regularyzacji\n",
"\n",
"n = 6\n",
"\n",
"data = np.matrix(np.loadtxt(\"ex2data2.txt\", delimiter=\",\"))\n",
"np.random.shuffle(data)\n",
"\n",
"X = powerme(data[:,0], data[:,1], n)\n",
"Y = data[:,2]"
]
},
{
"cell_type": "code",
2021-04-16 10:29:03 +02:00
"execution_count": 57,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def draw_regularization_example(X, Y, lamb=0, alpha=1, adaGrad=True, maxEpochs=2500, validate=0.25):\n",
" \"\"\"Rusuje przykład regularyzacji\"\"\"\n",
" plt.figure(figsize=(16,8))\n",
" plt.subplot(121)\n",
" plt.scatter(X[:, 2].tolist(), X[:, 1].tolist(),\n",
" c=Y.tolist(),\n",
" s=100, cmap=plt.cm.get_cmap('prism'));\n",
"\n",
" theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1],1)\n",
" thetaBest, err = SGD(h, J, dJ, theta, X, Y, alpha=alpha, adaGrad=adaGrad, maxEpochs=maxEpochs, batchSize=100, \n",
" logError=True, validate=validate, valStep=1, lamb=lamb)\n",
"\n",
" xx, yy = np.meshgrid(np.arange(-1.5, 1.5, 0.02),\n",
" np.arange(-1.5, 1.5, 0.02))\n",
" l = len(xx.ravel())\n",
" C = powerme(xx.reshape(l, 1),yy.reshape(l, 1), n)\n",
" z = classifyBi(thetaBest, C).reshape(int(np.sqrt(l)), int(np.sqrt(l)))\n",
"\n",
" plt.contour(xx, yy, z, levels=[0.5], lw=3);\n",
" plt.ylim(-1,1.2);\n",
" plt.xlim(-1,1.2);\n",
" plt.legend();\n",
" plt.subplot(122)\n",
" plt.plot(err[0],err[1], lw=3, label=\"Training error\")\n",
" if validate > 0:\n",
" plt.plot(err[2],err[3], lw=3, label=\"Validation error\");\n",
" plt.legend()\n",
" plt.ylim(0.2,0.8);"
]
},
{
"cell_type": "code",
2021-04-16 10:29:03 +02:00
"execution_count": 58,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2021-04-16 10:29:03 +02:00
"<ipython-input-48-09634685a32b>:5: RuntimeWarning: overflow encountered in exp\n",
2021-04-06 11:16:04 +02:00
" y = 1.0/(1.0 + np.exp(-x))\n",
2021-04-16 10:29:03 +02:00
"<ipython-input-57-f0220c89a5e3>:19: UserWarning: The following kwargs were not used by contour: 'lw'\n",
2021-04-06 11:16:04 +02:00
" plt.contour(xx, yy, z, levels=[0.5], lw=3);\n",
"No handles with labels found to put in legend.\n"
]
},
{
"data": {
2021-04-16 10:29:03 +02:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA7AAAAHWCAYAAABOj2WsAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdd3hUxdfA8e/NphdCr6GEGkooIQlNpDfpnYAioBSx+xMUbLw27F0RFEQQCIqAiCBNEOlVOqGX0AkC6ckm8/4xIYW0Td2U83keHph7Z+892YTsnp2ZM4ZSCiGEEEIIIYQQoqCzsXYAQgghhBBCCCGEJSSBFUIIIYQQQghRKEgCK4QQQgghhBCiUJAEVgghhBBCCCFEoSAJrBBCCCGEEEKIQkESWCGEEEIIIYQQhYIksEIIIUQxYxhGd8MwggzDOGUYxstpnHc3DON3wzAOGIZxxDCM0daIUwghhLifIfvACiGEEMWHYRgm4ATQBQgGdgMBSqmjyfpMBdyVUi8ZhlEOCAIqKqVirBGzEEIIcY+MwAohhBDFiz9wSil1JiEhDQT63tdHAW6GYRiAK3ALMOdvmEIIIURqksAKIYQQxUsV4GKydnDCseS+AuoDl4FDwLNKqfj8CU8IIYRIn621A8iOsmXLqho1alg7DCGEEEXE3r17byqlylk7jnxipHHs/vVE3YB/gY5ALWCdYRj/KKXupriQYYwDxgG4uLg09/LyyoNwhcgHsRFw6wzExSYdcy0PJe7/bEcIkV/Se20ulAlsjRo12LNnj7XDEEIIUUQYhnHe2jHko2CgarK2B3qkNbnRwHtKF8o4ZRjGWcAL2JW8k1JqFjALwNfXV8lrsyiUglbDkjEQ6wA4gI0t9PoUfEZaOzIhirX0XptlCrEQQghRvOwG6hiG4WkYhj0wDFhxX58LQCcAwzAqAPWAM/kapRD5Yce3sChAj8ACOLrDI8sleRWiACuUI7BCCCGEyB6llNkwjKeANYAJmKOUOmIYxoSE898CbwFzDcM4hJ5y/JJS6qbVghYit8WZYc0U2DUr6VipGjD8FyhX12phCSEyJwmsEEIIUcwopVYBq+479m2yf18GuuZ3XELki+gwPWX45JqkYx7+ELAIXMpaLy4hhEUkgRVCCFGkxMbGEhwcTFRUVKpzjo6OeHh4YGdnZ4XIhBBWd/cyLBwKVw8mHWvYH/rNADsn68UlciSj3/ui4Mvqa7MksEIIIYqU4OBg3NzcqFGjBnobU00pRUhICMHBwXh6eloxQiGEVVw9BAuGQGiymmUPvAAdXwMbKQtTmKX3e18UfNl5bZb/rUIIIYqUqKgoypQpk+pNjGEYlClTRj6hF6I4OrEW5nRPSl5tbKHPl9D5DUlei4D0fu+Lgi87r80yAiuEEKLISe9NjLy5EaIY2vUdrJ4MKl63HUrAkHlQq4N14xK5Sn6/F15Z/d7JR05CCCGEEKLoiY+HNa/AqheTklf3avDYWkleRa4KCQmhadOmNG3alIoVK1KlSpXEdkxMTIaP3bNnD88880ym92jdunVuhVvoyQisEEIIIYQoWszRsGwCHFmadKyyDwQEglsF68UliqQyZcrw77//AjBt2jRcXV158cUXE8+bzWZsbdNOu3x9ffH19c30Htu2bcudYC0QFxeHyWRKt52ejL7O3CQjsEIIIYocpVSWjgshipDI2zB/QMrk1asXjPpDkleRb0aNGsULL7xAhw4deOmll9i1axetW7emWbNmtG7dmqCgIAA2bdpEr169AJ38jhkzhvbt21OzZk2++OKLxOu5urom9m/fvj2DBg3Cy8uLESNGJL62rVq1Ci8vLx544AGeeeaZxOsmFxcXx6RJk/Dz86Nx48bMnDkz8bodOnRg+PDheHt7p2pHRUUxevRovL29adasGRs3bgRg7ty5DB48mN69e9O1a/7sviYjsELkpqNHYfFiuHkTqlaF4cOhWjVrRyVEseLo6EhISEiqgh73Kh06OjpaMTohRJ66Eww/DYIbx5KO+Y2FHu+DTeYjSKLwq/HyH3l27XPv9cxS/xMnTrB+/XpMJhN3795l8+bN2Nrasn79eqZOncqvv/6a6jHHjx9n48aNhIaGUq9ePZ544olU28vs37+fI0eOULlyZdq0acPWrVvx9fVl/PjxbN68GU9PTwICAtKMafbs2bi7u7N7926io6Np06ZNYuK5a9cuDh8+jKenJ5s2bUrR/vjjjwE4dOgQx48fp2vXrpw4cQKA7du3c/DgQUqXLp2l5ye7JIEVIjfcvAkDBsCePRAbC2Yz2NvDtGnQuzfMmwdOsr+cEPnBw8OD4OBgbty4kercvb3mhBBF0LUjOnlNvk1O52nQ5jmQAj/CCgYPHpw49fbOnTs8+uijnDx5EsMwiI2NTfMxPXv2xMHBAQcHB8qXL8+1a9dSvW75+/snHmvatCnnzp3D1dWVmjVrJm5FExAQwKxZs1Jdf+3atRw8eJAlS5YkxnXy5Ens7e3x9/dPsZVN8vaWLVt4+umnAfDy8qJ69eqJCWyXLl3yLXkFSWCFyLmwMGjVCs6f18nrPfcW7a9cCT16wIYNYMH6ASFEztjZ2ck+r0IUN2c3Q+AIiL6r2zZ20PdraDLUunGJYs3FxSXx36+99hodOnRg2bJlnDt3jvbt26f5GAcHh8R/m0wmzGazRX0sXSKjlOLLL7+kW7duKY5v2rQpRbz3x5/R9e9/XF6TBFaInJoxAy5dSpm8JhcVBXv36kS2b9/8jU0IIYQo6g4tgeVPQFzCB8f2bjB0vlQaLqayOs03v9y5c4cqVaoAet1obvPy8uLMmTOcO3eOGjVqsHjx4jT7devWjRkzZtCxY0fs7Ow4ceJEYlwZefDBB1mwYAEdO3bkxIkTXLhwgXr16rFv377c/lIyJUWchMgJpeCTTyAyMuN+YWHwwQf5E1NBopReF7x1K5w+be1ohBBCFCVKwbYv4dfHkpJX14owepUkr6LAmTx5MlOmTKFNmzbExcXl+vWdnJz45ptv6N69Ow888AAVKlTA3d09Vb/HH3+cBg0a4OPjQ6NGjRg/fnyao7z3mzhxInFxcXh7ezN06FDmzp2bYiQ4PxmFsSKjr6+v2rNnj7XDEALCw6FkSb3mNTMlSsCdO3kfU0GgFPzwA7z5pl4fbGurp1TXqgXvvAN9+lg7QiFSMAxjr1Iq830MRLrktVnkq/h4WDMVds5IOla2Hjy8BEpK8cTi5tixY9SvX9/aYVhdWFgYrq6uKKV48sknqVOnDs8//7y1w7JIWt/D9F6bZQRWiJywsdHJmqV9iwOlYOJEeOYZvS44PFwn7pGRcPgwBAToUWshhBAiO8zRsGR0yuS1WmsY86ckr6JY++6772jatCkNGzbkzp07jB8/3toh5QlZAytETjg5gacnnDqVcT/DgNat8ycma1u5EubP14lrWiIi4NVXoXNnaNw4f2MTQghRuEWHQuBwXbTpngZ9of8ssJMtskTx9vzzzxeaEdecKCZDQkLkocmTwdk54z7OzjBpUv7EY23Tp6efvN4TEwOffpo/8QghhCgawm/C3F4pk1f/8TDoB0lehShGJIEVIqdGjYLmzdPf59XZGQYOhHbt8jUsq4iNhZ07M+8XFwcrVuR9PEIIIYqG2xdgTje48m/SsY6vQY/3wUa2qBOiOJEEVoicsrODtWvhkUfA0RFcXfXfbm763y++qAsaFYdN1KOjLV/re2+fXCGEECIj14/B7G4QkrBcx7CBXp/Bgy8Wj9dWIUQKsgZWiNzg6AgzZ+qtclatgtu3oUIF6NEj/ZHZosjFRY84372bed+qVfM+HiGEEIXbxV2wYDBE3dZtkz0M/F6vexVCFEsyAitEbnJ311V2n3gCBgwoXskr6E/Cx40De/uM+7m4wAsv5E9MQojC48Qa+LEPHFpi7UhEQXByPczrm5S82rvCiCWSvIoCp3379qxZsybFsc8++4yJEydm+Jh7W4899NBD3L59O1WfadOm8dFHH2V47+XLl3P06NHE9uuvv8769euzEn6hIwmsECJ3vfCCTlDTm9Z
2021-04-06 11:16:04 +02:00
"text/plain": [
"<Figure size 1152x576 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"draw_regularization_example(X, Y)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Regularyzacja\n",
"\n",
"Regularyzacja jest metodą zapobiegania zjawisku nadmiernego dopasowania (*overfitting*) poprzez odpowiednie zmodyfikowanie funkcji kosztu.\n",
"\n",
"Do funkcji kosztu dodawane jest specjalne wyrażenie (**wyrazenie regularyzacyjne** – zaznaczone na czerwono w poniższych wzorach), będące „karą” za ekstremalne wartości parametrów $\\theta$.\n",
"\n",
"W ten sposób preferowane są wektory $\\theta$ z mniejszymi wartosciami parametrów – mają automatycznie niższy koszt.\n",
"\n",
"Jak silną regularyzację chcemy zastosować? Możemy o tym zadecydować, dobierajac odpowiednio **parametr regularyzacji** $\\lambda$."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularyzacja dla regresji liniowej – funkcja kosztu\n",
"\n",
"$$\n",
"J(\\theta) \\, = \\, \\dfrac{1}{2m} \\left( \\displaystyle\\sum_{i=1}^{m} h_\\theta(x^{(i)}) - y^{(i)} \\color{red}{ + \\lambda \\displaystyle\\sum_{j=1}^{n} \\theta^2_j } \\right)\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"* $\\lambda$ – parametr regularyzacji\n",
"* jeżeli $\\lambda$ jest zbyt mały, skutkuje to nadmiernym dopasowaniem\n",
"* jeżeli $\\lambda$ jest zbyt duży, skutkuje to niedostatecznym dopasowaniem"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularyzacja dla regresji liniowej – gradient\n",
"\n",
"$$\\small\n",
"\\begin{array}{llll}\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_0} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_0 & \\textrm{dla $j = 0$ }\\\\\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_j} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_j \\color{red}{+ \\dfrac{\\lambda}{m}\\theta_j} & \\textrm{dla $j = 1, 2, \\ldots, n $} \\\\\n",
"\\end{array} \n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularyzacja dla regresji logistycznej – funkcja kosztu\n",
"\n",
"$$\n",
"\\begin{array}{rtl}\n",
"J(\\theta) & = & -\\dfrac{1}{m} \\left( \\displaystyle\\sum_{i=1}^{m} y^{(i)} \\log h_\\theta(x^{(i)}) + \\left( 1-y^{(i)} \\right) \\log \\left( 1-h_\\theta(x^{(i)}) \\right) \\right) \\\\\n",
"& & \\color{red}{ + \\dfrac{\\lambda}{2m} \\displaystyle\\sum_{j=1}^{n} \\theta^2_j } \\\\\n",
"\\end{array}\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularyzacja dla regresji logistycznej – gradient\n",
"\n",
"$$\\small\n",
"\\begin{array}{llll}\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_0} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_0 & \\textrm{dla $j = 0$ }\\\\\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_j} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_j \\color{red}{+ \\dfrac{\\lambda}{m}\\theta_j} & \\textrm{dla $j = 1, 2, \\ldots, n $} \\\\\n",
"\\end{array} \n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Implementacja metody regularyzacji"
]
},
{
"cell_type": "code",
2021-04-16 10:29:03 +02:00
"execution_count": 59,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"def J_(h,theta,X,y,lamb=0):\n",
" \"\"\"Funkcja kosztu z regularyzacją\"\"\"\n",
" m = float(len(y))\n",
" f = h(theta, X, eps=10**-7)\n",
" j = 1.0/m \\\n",
" * -np.sum(np.multiply(y, np.log(f)) + \n",
" np.multiply(1 - y, np.log(1 - f)), axis=0) \\\n",
" + lamb/(2*m) * np.sum(np.power(theta[1:] ,2))\n",
" return j\n",
"\n",
"def dJ_(h,theta,X,y,lamb=0):\n",
" \"\"\"Gradient funkcji kosztu z regularyzacją\"\"\"\n",
" m = float(y.shape[0])\n",
" g = 1.0/y.shape[0]*(X.T*(h(theta,X)-y))\n",
" g[1:] += lamb/m * theta[1:]\n",
" return g"
]
},
{
"cell_type": "code",
2021-04-16 10:29:03 +02:00
"execution_count": 60,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"slider_lambda = widgets.FloatSlider(min=0.0, max=0.5, step=0.005, value=0.01, description=r'$\\lambda$', width=300)\n",
"\n",
"def slide_regularization_example_2(lamb):\n",
" draw_regularization_example(X, Y, lamb=lamb)"
]
},
{
"cell_type": "code",
2021-04-16 10:29:03 +02:00
"execution_count": 61,
2021-04-06 11:16:04 +02:00
"metadata": {
2021-04-16 10:29:03 +02:00
"scrolled": true,
2021-04-06 11:16:04 +02:00
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2021-04-16 10:29:03 +02:00
"model_id": "e49f6a277b3d4b378df0c46330a07094",
2021-04-06 11:16:04 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=0.01, description='$\\\\lambda$', max=0.5, step=0.005), Button(descripti…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.slide_regularization_example_2(lamb)>"
]
},
2021-04-16 10:29:03 +02:00
"execution_count": 61,
2021-04-06 11:16:04 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact_manual(slide_regularization_example_2, lamb=slider_lambda)"
]
},
{
"cell_type": "code",
2021-04-09 15:58:14 +02:00
"execution_count": 38,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def cost_lambda_fun(lamb):\n",
" \"\"\"Koszt w zależności od parametru regularyzacji lambda\"\"\"\n",
" theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1],1)\n",
" thetaBest, err = SGD(h, J, dJ, theta, X, Y, alpha=1, adaGrad=True, maxEpochs=2500, batchSize=100, \n",
" logError=True, validate=0.25, valStep=1, lamb=lamb)\n",
" return err[1][-1], err[3][-1]\n",
"\n",
"def plot_cost_lambda():\n",
" \"\"\"Wykres kosztu w zależności od parametru regularyzacji lambda\"\"\"\n",
" plt.figure(figsize=(16,8))\n",
" ax = plt.subplot(111)\n",
" Lambda = np.arange(0.0, 1.0, 0.01)\n",
" Costs = [cost_lambda_fun(lamb) for lamb in Lambda]\n",
" CostTrain = [cost[0] for cost in Costs]\n",
" CostCV = [cost[1] for cost in Costs]\n",
" plt.plot(Lambda, CostTrain, lw=3, label='training error')\n",
" plt.plot(Lambda, CostCV, lw=3, label='validation error')\n",
" ax.set_xlabel(r'$\\lambda$')\n",
" ax.set_ylabel(u'cost')\n",
" plt.legend()\n",
" plt.ylim(0.2,0.8)"
]
},
{
"cell_type": "code",
2021-04-16 10:29:03 +02:00
"execution_count": 62,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
2021-04-16 10:29:03 +02:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA7AAAAHmCAYAAABK9WIBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdd5Rd5X3v/88zp0w504t6GwkB6kgIIRsZS8Zg7MS44CLXkFybmMSxndwfwU5uwCbxDU4cL9vXIfywf9i+XgQWxjW5gLFuKAZTJFGEEBLq0mik0fReTnl+f+zTp2jaPmXm/VrrrP3sffbZeg4Yaz7zfYqx1goAAAAAgFxXkO0OAAAAAAAwHgRYAAAAAEBeIMACAAAAAPICARYAAAAAkBcIsAAAAACAvECABQAAAADkBVcDrDHmemPMIWPMEWPMl0d4v8IY8x/GmFeNMa8bY/7Yzf4AAAAAAPKXcWsfWGOMR9Kbkq6V1CBpt6SPWWsPJN3zN5IqrLW3GWPqJB2SNM9aO+RKpwAAAAAAecvNCuwWSUestceigfRBSe9Lu8dKKjPGGEmlktokhVzsEwAAAAAgT7kZYBdKOp103hC9lux7klZJapT0mqQvWmsjLvYJAAAAAJCnvC4+24xwLX288rskvSLpHZJWSPqtMeZ31tqulAcZc7OkmyUpEAhcfumll7rQXRf0NkudDU47UCdVLMpufwAAAAAgx+3du7fFWls30ntuBtgGSYuTzhfJqbQm+2NJd1lnIu4RY8xxSZdKejH5JmvtvZLulaTNmzfbPXv2uNbpabXnh9J/fslpb/yQ9L7vZbc/AAAAAJDjjDEnR3vPzSHEuyWtNMbUG2P8knZK+nXaPackXSNJxpi5ki6RdMzFPmWWtyjRDrMuFQAAAABMhWsVWGttyBjzeUm/keSRdJ+19nVjzOei798j6e8l/cgY85qcIce3WWtb3OpTxnn9iXZoIHv9AAAAAIAZwM0hxLLWPiLpkbRr9yS1GyVd52Yfsiq5AhuiAgsAAAAAU+FqgJ31PIWJNhVYAAAAYFoFg0E1NDRoYICftfNRUVGRFi1aJJ/PN+7PEGDd5E0KsMyBBQAAAKZVQ0ODysrKtGzZMhkz0iYoyFXWWrW2tqqhoUH19fXj/pybizjBSwUWAAAAcMvAwIBqamoIr3nIGKOampoJV88JsG5KCbBUYAEAAIDpRnjNX5P5d0eAdRNzYAEAAIAZq6OjQ3ffffekPvue97xHHR0dY95z++23a9euXZN6/kxFgHVTyhzYwez1AwAAAMC0GyvAhsPhMT/7yCOPqLKycsx77rzzTr3zne+cdP8mKr3PF/oOE71vOhBg3ZQyhJgACwAAAMwkX/7yl3X06FFddtlluvXWW/Xkk09qx44d+vjHP65169ZJkt7//vfr8ssv15o1a3TvvffGP7ts2TK1tLToxIkTWrVqlT772c9qzZo1uu6669Tf3y9Juummm/Twww/H77/jjju0adMmrVu3TgcPHpQkNTc369prr9WmTZv0p3/6p1q6dKlaWlqG9fXxxx/XW97yFm3atEkf/vCH1dPTE3/unXfeqW3btumnP/3psPMHHnhA69at09q1a3XbbbfFn1daWqrbb79dV155pZ577jl3/gGPgFWI3ZSyDywBFgAAAHDLsi//H9eefeKuPxjx+l133aX9+/frlVdekSQ9+eSTevHFF7V///74yrr33Xefqqur1d/fryuuuEI33nijampqUp5z+PBhPfDAA/r+97+vj3zkI/rZz36mT37yk8P+vNraWr300ku6++679c1vflM/+MEP9LWvfU3veMc79JWvfEWPPfZYSkiOaWlp0T/8wz9o165dCgQC+sY3vqFvfetbuv322yU529k888wzkpxQHjtvbGzU1q1btXfvXlVVVem6667TL3/5S73//e9Xb2+v1q5dqzvvvHPy/2AngQqsmzz+RJsACwAAAMx4W7ZsSdkW5rvf/a42bNigrVu36vTp0zp8+PCwz9TX1+uyyy6TJF1++eU6ceLEiM/+4Ac/OOyeZ555Rjt37pQkXX/99aqqqhr2ueeff14HDhzQVVddpcsuu0w//vGPdfLkyfj7H/3oR1Puj53v3r1b27dvV11dnbxerz7xiU/o6aefliR5PB7deOON4/lHMq2owLopuQIbHpSslVglDQAAAJixAoFAvP3kk09q165deu6551RSUqLt27ePuG1MYWFi6qHH44kPIR7tPo/Ho1AoJMnZT/VCrLW69tpr9cADD1ywz8nnYz27qKhIHo/ngn/2dCPAusnjlUyBZCPOKxKSPL5s9woAAACYcUYb5uumsrIydXd3j/p+Z2enqqqqVFJSooMHD+r555+f9j5s27ZNDz30kG677TY9/vjjam9vH3bP1q1b9ed//uc6cuSILrroIvX19amhoUEXX3zxmM++8sor9cUvflEtLS2qqqrSAw88oL/4i7+Y9u8wEQwhdhvzYAEAAIAZqaamRldddZXWrl2rW2+9ddj7119/vUKhkNavX6+/+7u/09atW6e9D3fccYcef/xxbdq0SY8++qjmz5+vsrKylHvq6ur0ox/9SB/72Me0fv16bd26Nb4I1Fjmz5+vf/zHf9SOHTu0YcMGbdq0Se973/um/TtMhBlPyTmXbN682e7Zsyfb3Ri/u5ZKA9H9nW49JgVqxr4fAAAAwLi88cYbWrVqVba7kVWDg4PyeDzyer167rnndMstt8QXlcoHI/07NMbstdZuHul+hhC7LX0eLAAAAABMk1OnTukjH/mIIpGI/H6/vv/972e7S64iwLrNm7wS8fAJ2wAAAAAwWStXrtTLL7+c7W5kDHNg3ZYyB3Yoe/0AAAAAgDxHgHWbJ7EkNkOIAQAAAGDyCLBu8yYFWFYhBgAAAIBJI8C6jQALAAAAANOCAOs2AiwAAACAqNLSUklSY2OjPvShD414z/bt23WhrUO//e1vq6+vL37+nve8Rx0dHdPX0RxFgHUbc2ABAAAApFmwYIEefvjhSX8+PcA+8sgjqqysnI6uXVAoFBrzfLyfmwwCrNtSKrBsowMAAADMFLfddpvuvvvu+PlXv/pV/cu//It6enp0zTXXaNOmTVq3bp1+9atfDfvsiRMntHbtWklSf3+/du7cqfXr1+ujH/2o+vv74/fdcsst2rx5s9asWaM77rhDkvTd735XjY2N2rFjh3bs2CFJWrZsmVpaWiRJ3/rWt7R27VqtXbtW3/72t+N/3qpVq/TZz35Wa9as0XXXXZfy58Q0Nzfrxhtv1BVXXKErrrhCzz77bPy73Xzzzbruuuv06U9/etj5yZMndc0112j9+vW65pprdOrUKUnSTTfdpL/6q7/Sjh07dNttt035nzn7wLotJcCyjQ4AAADgiq9WuPjszhEv79y5U1/60pf0Z3/2Z5Kkhx56SI899piKior0i1/8QuXl5WppadHWrVt1ww03yBgz4nP+7d/+TSUlJdq3b5/27dunTZs2xd/7+te/rurqaoXDYV1zzTXat2+fvvCFL+hb3/qWnnjiCdXW1qY8a+/evfrhD3+oF154QdZaXXnllXr729+uqqoqHT58WA888IC+//3v6yMf+Yh+9rOf6ZOf/GTK57/4xS/qL//yL7Vt2zadOnVK73rXu/TGG2/En/3MM8+ouLhYX/3qV1PO3/ve9+rTn/60/uiP/kj33XefvvCFL+iXv/ylJOnNN9/Url275PF4JvfPPwkB1m1UYAEAAIAZaePGjTp//rwaGxvV3NysqqoqLVmyRMFgUH/zN3+jp59+WgUFBTpz5oyampo0b968EZ/z9NNP6wtf+IIkaf369Vq/fn38vYceekj33nuvQqGQzp49qwMHDqS8n+6ZZ57RBz7wAQUCAUnSBz/4Qf3ud7/TDTfcoPr6el122WWSpMsvv1wnTpwY9vldu3bpwIED8fOuri51d3dLkm644QYVFxfH30s+f+655/Tzn/9ckvSpT31Kf/3Xfx2/78Mf/vC0hFeJAOu+lDmwVGABAACAmeRDH/qQHn74YZ07d047d+6
2021-04-06 11:16:04 +02:00
"text/plain": [
"<Figure size 1152x576 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_cost_lambda()"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 3.6. Krzywa uczenia się"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"* Krzywa uczenia pozwala sprawdzić, czy uczenie przebiega poprawnie.\n",
"* Krzywa uczenia to wykres zależności między wielkością zbioru treningowego a wartością funkcji kosztu.\n",
"* Wraz ze wzrostem wielkości zbioru treningowego wartość funkcji kosztu na zbiorze treningowym rośnie.\n",
"* Wraz ze wzrostem wielkości zbioru treningowego wartość funkcji kosztu na zbiorze walidacyjnym maleje."
]
},
{
"cell_type": "code",
2021-04-09 15:58:14 +02:00
"execution_count": 40,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def cost_trainsetsize_fun(m):\n",
" \"\"\"Koszt w zależności od wielkości zbioru uczącego\"\"\"\n",
" theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1],1)\n",
" thetaBest, err = SGD(h, J, dJ, theta, X, Y, alpha=1, adaGrad=True, maxEpochs=2500, batchSize=100, \n",
" logError=True, validate=0.25, valStep=1, lamb=0.01, trainsetsize=m)\n",
" return err[1][-1], err[3][-1]\n",
"\n",
"def plot_learning_curve():\n",
" \"\"\"Wykres krzywej uczenia się\"\"\"\n",
" plt.figure(figsize=(16,8))\n",
" ax = plt.subplot(111)\n",
" M = np.arange(0.3, 1.0, 0.05)\n",
" Costs = [cost_trainsetsize_fun(m) for m in M]\n",
" CostTrain = [cost[0] for cost in Costs]\n",
" CostCV = [cost[1] for cost in Costs]\n",
" plt.plot(M, CostTrain, lw=3, label='training error')\n",
" plt.plot(M, CostCV, lw=3, label='validation error')\n",
" ax.set_xlabel(u'trainset size')\n",
" ax.set_ylabel(u'cost')\n",
" plt.legend()"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Krzywa uczenia a obciążenie i wariancja\n",
"\n",
"Wykreślenie krzywej uczenia pomaga diagnozować nadmierne i niedostateczne dopasowanie:\n",
"\n",
"<img width=\"100%\" src=\"learning-curves.png\"/>\n",
"\n",
"Źródło: http://www.ritchieng.com/machinelearning-learning-curve"
]
},
2021-04-09 15:58:14 +02:00
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
" * Gdy model jest nadmiernie dopasowany, zebranie większej ilości danych uczących może pomóc.\n",
" * Gdy model jest niedostatecznie dopasowany, pomóc może zwiększenie liczby stopni swobody modelu."
]
},
2021-04-06 11:16:04 +02:00
{
"cell_type": "code",
2021-04-09 15:58:14 +02:00
"execution_count": 41,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
2021-04-09 15:58:14 +02:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA7AAAAHgCAYAAACcrIEcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzde3hU5b3//c+dyeQECSQBlYMQwBOncAqCwq4glapttZ4Qj6W2smu7a9t9XW7tfp6tbtv99PBrvayPtf7Qqq0/C+K5dau17EerKLQEFOSggBhCQCGBhEPOM3M/f6xJMkkmRzKzZs28X9c118xa615rvkFUPtzrey9jrRUAAAAAAIkuze0CAAAAAADoDQIsAAAAAMATCLAAAAAAAE8gwAIAAAAAPIEACwAAAADwBAIsAAAAAMAT0t0uoK+GDRtmi4qK3C4DAAAAABADGzdurLLWDo92LGYB1hjzuKSvSDpkrZ3SxZgFkh6Q5JdUZa29oKfrFhUVqbS0dCBLBQAAAAAkCGPM3q6OxfIW4iclXdzVQWPMUEkPS7rMWjtZ0jUxrAUAAAAA4HExC7DW2rclHelmyPWSXrDWlofHH4pVLQAAAAAA73NzEaezJOUbY94yxmw0xtzc1UBjzHJjTKkxprSysjKOJQIAAAAAEoWbizilS5olaZGkbEnrjDHrrbU7Ow601q6QtEKSSkpKbFyrBAAAAJCQmpubVVFRoYaGBrdLQT9kZWVp9OjR8vv9vT7HzQBbIWfhplpJtcaYtyVNk9QpwAIAAABARxUVFcrNzVVRUZGMMW6Xgz6w1urw4cOqqKjQuHHjen2em7cQvyzpn4wx6caYHElzJO1wsR4AAAAAHtLQ0KDCwkLCqwcZY1RYWNjn2fNYPkZnpaQFkoYZYyok3SPncTmy1j5ird1hjHld0hZJIUmPWWu3xqoeAAAAAMmH8Opd/flnF8tViK+z1o6w1vqttaOttb8LB9dHIsb8L2vtJGvtFGvtA7GqBQAAAAAGWk1NjR5++OF+nXvppZeqpqam2zF333231qxZ06/rJys3byEGAAAAAM/qLsAGg8Fuz3311Vc1dOjQbsfcd999+uIXv9jv+vqqY809/Qx9HTcQCLAAAAAA0A933XWXPvnkE02fPl133HGH3nrrLS1cuFDXX3+9pk6dKkn62te+plmzZmny5MlasWJF67lFRUWqqqpSWVmZJk6cqFtvvVWTJ0/W4sWLVV9fL0latmyZnnvuudbx99xzj2bOnKmpU6fqo48+kiRVVlbqoosu0syZM/XP//zPGjt2rKqqqjrV+sYbb+i8887TzJkzdc011+jEiROt173vvvs0f/58Pfvss522V65cqalTp2rKlCm68847W683ePBg3X333ZozZ47WrVsXm1/gKNxchRgAAAAABkTRXf8ds2uX/ezLUff/7Gc/09atW/XBBx9Ikt566y394x//0NatW1tX1n388cdVUFCg+vp6zZ49W1dddZUKCwvbXWfXrl1auXKlHn30US1ZskTPP/+8brzxxk7fN2zYMG3atEkPP/ywfvnLX+qxxx7Tf/7nf+rCCy/Uj370I73++uvtQnKLqqoq/eQnP9GaNWs0aNAg/fznP9f999+vu+++W5LzOJu1a9dKckJ5y/aBAwc0d+5cbdy4Ufn5+Vq8eLFeeuklfe1rX1Ntba2mTJmi++67r/+/sP3ADCwAAAAADJBzzz233WNhHnzwQU2bNk1z587Vvn37tGvXrk7njBs3TtOnT5ckzZo1S2VlZVGvfeWVV3Yas3btWi1dulSSdPHFFys/P7/TeevXr9f27ds1b948TZ8+Xb///e+1d+/e1uPXXnttu/Et2xs2bNCCBQs0fPhwpaen64YbbtDbb78tSfL5fLrqqqt680syoJiBBQAAAIABMmjQoNbPb731ltasWaN169YpJydHCxYsiPrYmMzMzNbPPp+v9Rbirsb5fD4FAgFJzvNUe2Kt1UUXXaSVK1f2WHPkdnfXzsrKks/n6/G7BxoBFgAAAIDndXWbbyzl5ubq+PHjXR4/evSo8vPzlZOTo48++kjr168f8Brmz5+v1atX684779Qbb7yh6urqTmPmzp2r7373u9q9e7fOOOMM1dXVqaKiQmeddVa3154zZ46+//3vq6qqSvn5+Vq5cqW+973vDfjP0BfcQgwAAAAA/VBYWKh58+ZpypQpuuOOOzodv/jiixUIBFRcXKz/+I//0Ny5cwe8hnvuuUdvvPGGZs6cqddee00jRoxQbm5uuzHDhw/Xk08+qeuuu07FxcWaO3du6yJQ3RkxYoR++tOfauHChZo2bZpmzpypyy+/fMB/hr4wvZlyTiQlJSW2tLTU7TIAAAAAuGzHjh2aOHGi22W4qrGxUT6fT+np6Vq3bp1uu+221kWlvCDaP0NjzEZrbUm08dxCjJPXVCu9+f9Ih3dLi38iDTvT7YoAAACAlFBeXq4lS5YoFAopIyNDjz76qNslxRQBFienukxaeb10aFt4h5GuX+VmRQAAAEDKOPPMM/X++++7XUbcEGDRf3vekp5dJtVHNIqXvSMFmyWf362qAAAAACQpFnFC31krrXtYeurK9uFVkppOSJ9tdqcuAAAAAEmNAIu+aW6QXrpN+suPJBt09g0+TTp9TtuYsnfcqQ0AAABAUiPAoveO7peeuETaHPEA5FEl0vK3pBk3te0rWxvvygAAAACkAAIseqd8vbRigXRgU9u+6TdKy/5byhshFc1vPzbYHPcSAQAAgEQ3ePBgSdKBAwd09dVXRx2zYMEC9fTo0AceeEB1dXWt25deeqlqamoGrtAERYBFz0qfkJ78ilR7yNk2PumS/yVd/pDkz3L25RdJeaOdz/TBAgAAAN0aOXKknnvuuX6f3zHAvvrqqxo6dOhAlNajQCDQ7XZvz+sPAiy6FmiSXvmh9MoPpFB4RjWnULr5ZWnOcsmYtrHGtJ+FpQ8WAAAASe7OO+/Uww8/3Lp977336le/+pVOnDihRYsWaebMmZo6dapefvnlTueWlZVpypQpkqT6+notXbpUxcXFuvbaa1VfX9867rbbblNJSYkmT56se+65R5L04IMP6sCBA1q4cKEWLlwoSSoqKlJVVZUk6f7779eUKVM0ZcoUPfDAA63fN3HiRN16662aPHmyFi9e3O57WlRWVuqqq67S7NmzNXv2bL377rutP9vy5cu1ePFi3XzzzZ229+7dq0WLFqm4uFiLFi1SeXm5JGnZsmX613/9Vy1cuFB33nnnSf+a8xgdRHfikLT6Zql8Xdu+06ZKS/8oDR0T/Zyi+dKW8DNgy9ZK838Y+zoBAAAASbp3SAyvfTTq7qVLl+oHP/iBvvOd70iSVq9erddff11ZWVl68cUXlZeXp6qqKs2dO1eXXXaZTOQEUITf/va3ysnJ0ZYtW7RlyxbNnDmz9dh//dd/qaCgQMFgUIsWLdKWLVt0++236/7779ebb76pYcOGtbvWxo0b9cQTT+jvf/+7rLWaM2eOLrjgAuXn52vXrl1auXKlHn30US1ZskTPP/+8brzxxnbnf//739cPf/hDzZ8/X+Xl5frSl76kHTt2tF577dq1ys7O1r333ttu+6tf/apuvvlmff3rX9fjjz+u22+/XS+99JIkaefOnVqzZo18Pl//fv0jEGDR2f5N0jM3Ssf2t+2bcpV02UNSRk7X50Xrg+V5sAAAAEhSM2bM0KFDh3TgwAFVVlYqPz9fY8aMUXNzs/793/9db7/9ttLS0rR//34dPHhQp512WtTrvP3227r99tslScXFxSouLm49tnr1aq1YsUKBQECfffaZtm/f3u54R2vXrtUVV1yhQYMGSZKuvPJKvfPOO7rssss0btw4TZ8+XZI0a9YslZWVdTp/zZo12r59e+v2sWPHdPz4cUnSZZddpuzs7NZjkdvr1q3TCy+8IEm66aab9G//9m+t46655poBCa8SARYdbX5G+vPtUqAhvMNIX7xXmvf99rcMR9PSB3usoq0PdnRJbOsFAAAAXHT11Vfrueee0+eff66lS5dKkp5
2021-04-06 11:16:04 +02:00
"text/plain": [
"<Figure size 1152x576 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_learning_curve()"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 3.7. Warianty metody gradientu prostego\n",
"\n",
"* Batch gradient descent\n",
"* Stochastic gradient descent\n",
"* Mini-batch gradient descent"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### _Batch gradient descent_\n",
"\n",
"* Klasyczna wersja metody gradientu prostego\n",
"* Obliczamy gradient funkcji kosztu względem całego zbioru treningowego:\n",
" $$ \\theta := \\theta - \\alpha \\cdot \\nabla_\\theta J(\\theta) $$\n",
"* Dlatego może działać bardzo powoli\n",
"* Nie można dodawać nowych przykładów na bieżąco w trakcie trenowania modelu (*online learning*)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### *Stochastic gradient descent* (SGD)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"#### Algorytm\n",
"\n",
"Powtórz określoną liczbę razy (liczba epok):\n",
" 1. Randomizuj dane treningowe\n",
" 1. Powtórz dla każdego przykładu $i = 1, 2, \\ldots, m$:\n",
" $$ \\theta := \\theta - \\alpha \\cdot \\nabla_\\theta \\, J \\! \\left( \\theta, x^{(i)}, y^{(i)} \\right) $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"**Randomizacja danych** to losowe potasowanie przykładów uczących (wraz z odpowiedziami)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### SGD - zalety\n",
"\n",
"* Dużo szybszy niż _batch gradient descent_\n",
"* Można dodawać nowe przykłady na bieżąco w trakcie trenowania (*online learning*)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### SGD\n",
"\n",
"* Częsta aktualizacja parametrów z dużą wariancją:\n",
"\n",
"<img src=\"http://ruder.io/content/images/2016/09/sgd_fluctuation.png\" style=\"margin: auto;\" width=\"50%\" />\n",
"\n",
"* Z jednej strony dzięki temu nie utyka w złych minimach lokalnych, ale z drugiej strony może „wyskoczyć” z dobrego minimum"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### _Mini-batch gradient descent_"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"#### Algorytm\n",
"\n",
"1. Ustal rozmiar \"paczki/wsadu\" (*batch*) $b \\leq m$.\n",
"2. Powtórz określoną liczbę razy (liczba epok):\n",
" 1. Powtórz dla każdego batcha (czyli dla $i = 1, 1 + b, 1 + 2 b, \\ldots$):\n",
" $$ \\theta := \\theta - \\alpha \\cdot \\nabla_\\theta \\, J \\left( \\theta, x^{(i : i+b)}, y^{(i : i+b)} \\right) $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### _Mini-batch gradient descent_\n",
"\n",
"* Kompromis między _batch gradient descent_ i SGD\n",
"* Stabilniejsza zbieżność dzięki redukcji wariancji aktualizacji parametrów\n",
"* Szybszy niż klasyczny _batch gradient descent_\n",
"* Typowa wielkość batcha: między kilka a kilkaset przykładów\n",
" * Im większy batch, tym bliżej do BGD; im mniejszy batch, tym bliżej do SGD\n",
" * BGD i SGD można traktować jako odmiany MBGD dla $b = m$ i $b = 1$"
]
},
{
"cell_type": "code",
2021-04-09 15:58:14 +02:00
"execution_count": 42,
2021-04-06 11:16:04 +02:00
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"# Mini-batch gradient descent - przykładowa implementacja\n",
"\n",
"def MiniBatchSGD(h, fJ, fdJ, theta, X, y, \n",
" alpha=0.001, maxEpochs=1.0, batchSize=100, \n",
" logError=True):\n",
" errorsX, errorsY = [], []\n",
" \n",
" m, n = X.shape\n",
" start, end = 0, batchSize\n",
" \n",
" maxSteps = (m * float(maxEpochs)) / batchSize\n",
" for i in range(int(maxSteps)):\n",
" XBatch, yBatch = X[start:end,:], y[start:end,:]\n",
"\n",
" theta = theta - alpha * fdJ(h, theta, XBatch, yBatch)\n",
" \n",
" if logError:\n",
" errorsX.append(float(i*batchSize)/m)\n",
" errorsY.append(fJ(h, theta, XBatch, yBatch).item())\n",
" \n",
" if start + batchSize < m:\n",
" start += batchSize\n",
" else:\n",
" start = 0\n",
" end = min(start + batchSize, m)\n",
" \n",
" return theta, (errorsX, errorsY)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Porównanie uśrednionych krzywych uczenia na przykładzie klasyfikacji dwuklasowej zbioru [MNIST](https://en.wikipedia.org/wiki/MNIST_database):\n",
"\n",
"<img src=\"sgd-comparison.png\" width=\"70%\" />"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Wady klasycznej metody gradientu prostego, czyli dlaczego potrzebujemy optymalizacji\n",
"\n",
"* Trudno dobrać właściwą szybkość uczenia (*learning rate*)\n",
"* Jedna ustalona wartość stałej uczenia się dla wszystkich parametrów\n",
"* Funkcja kosztu dla sieci neuronowych nie jest wypukła, więc uczenie może utknąć w złym minimum lokalnym lub punkcie siodłowym"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 3.8. Algorytmy optymalizacji metody gradientu\n",
"\n",
"* Momentum\n",
"* Nesterov Accelerated Gradient\n",
"* Adagrad\n",
"* Adadelta\n",
"* RMSprop\n",
"* Adam\n",
"* Nadam\n",
"* AMSGrad"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Momentum\n",
"\n",
"* SGD źle radzi sobie w „wąwozach” funkcji kosztu\n",
"* Momentum rozwiązuje ten problem przez dodanie współczynnika $\\gamma$, który można trakować jako „pęd” spadającej piłki:\n",
" $$ v_t := \\gamma \\, v_{t-1} + \\alpha \\, \\nabla_\\theta J(\\theta) $$\n",
" $$ \\theta := \\theta - v_t $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Przyspiesony gradient Nesterova (*Nesterov Accelerated Gradient*, NAG)\n",
"\n",
"* Momentum czasami powoduje niekontrolowane rozpędzanie się piłki, przez co staje się „mniej sterowna”\n",
"* Nesterov do piłki posiadającej pęd dodaje „hamulec”, który spowalnia piłkę przed wzniesieniem:\n",
" $$ v_t := \\gamma \\, v_{t-1} + \\alpha \\, \\nabla_\\theta J(\\theta - \\gamma \\, v_{t-1}) $$\n",
" $$ \\theta := \\theta - v_t $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Adagrad\n",
"\n",
"* “<b>Ada</b>ptive <b>grad</b>ient”\n",
"* Adagrad dostosowuje współczynnik uczenia (*learning rate*) do parametrów: zmniejsza go dla cech występujących częściej, a zwiększa dla występujących rzadziej:\n",
"* Świetny do trenowania na rzadkich (*sparse*) zbiorach danych\n",
"* Wada: współczynnik uczenia może czasami gwałtownie maleć\n",
"* Wyniki badań pokazują, że często **starannie** dobrane $\\alpha$ daje lepsze wyniki na zbiorze testowym"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Adadelta i RMSprop\n",
"* Warianty algorytmu Adagrad, które radzą sobie z problemem gwałtownych zmian współczynnika uczenia"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Adam\n",
"\n",
"* “<b>Ada</b>ptive <b>m</b>oment estimation”\n",
"* Łączy zalety algorytmów RMSprop i Momentum\n",
"* Można go porównać do piłki mającej ciężar i opór\n",
"* Obecnie jeden z najpopularniejszych algorytmów optymalizacji"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Nadam\n",
"* “<b>N</b>esterov-accelerated <b>ada</b>ptive <b>m</b>oment estimation”\n",
"* Łączy zalety algorytmów Adam i Nesterov Accelerated Gradient"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### AMSGrad\n",
"* Wariant algorytmu Adam lepiej dostosowany do zadań takich jak rozpoznawanie obiektów czy tłumaczenie maszynowe"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img src=\"contours_evaluation_optimizers.gif\" style=\"margin: auto;\" width=\"80%\" />"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img src=\"saddle_point_evaluation_optimizers.gif\" style=\"margin: auto;\" width=\"80%\" />"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 3.9. Metody zbiorcze"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
" * **Metody zbiorcze** (*ensemble methods*) używają połączonych sił wielu modeli uczenia maszynowego w celu uzyskania lepszej skuteczności niż mogłaby być osiągnięta przez każdy z tych modeli z osobna."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * Na metodę zbiorczą składa się:\n",
" * dobór modeli\n",
" * sposób agregacji wyników"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * Warto zastosować randomizację, czyli przetasować zbiór uczący przed trenowaniem każdego modelu."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Uśrednianie prawdopodobieństw"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"#### Przykład\n",
"\n",
"Mamy 3 modele, które dla klas $c=1, 2, 3, 4, 5$ zwróciły prawdopodobieństwa:\n",
"\n",
"* $M_1$: [0.10, 0.40, **0.50**, 0.00, 0.00]\n",
"* $M_2$: [0.10, **0.60**, 0.20, 0.00, 0.10]\n",
"* $M_3$: [0.10, 0.30, **0.40**, 0.00, 0.20]\n",
"\n",
"Która klasa zostanie wybrana według średnich prawdopodobieństw dla każdej klasy?"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Średnie prawdopodobieństwo: [0.10, **0.43**, 0.36, 0.00, 0.10]\n",
"\n",
"Została wybrana klasa $c = 2$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Głosowanie klas"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"#### Przykład\n",
"\n",
"Mamy 3 modele, które dla klas $c=1, 2, 3, 4, 5$ zwróciły prawdopodobieństwa:\n",
"\n",
"* $M_1$: [0.10, 0.40, **0.50**, 0.00, 0.00]\n",
"* $M_2$: [0.10, **0.60**, 0.20, 0.00, 0.10]\n",
"* $M_3$: [0.10, 0.30, **0.40**, 0.00, 0.20]\n",
"\n",
"Która klasa zostanie wybrana według głosowania?"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Liczba głosów: [0, 1, **2**, 0, 0]\n",
"\n",
"Została wybrana klasa $c = 3$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Inne metody zbiorcze\n",
"\n",
" * Bagging\n",
" * Boostng\n",
" * Stacking\n",
" \n",
"https://towardsdatascience.com/ensemble-methods-bagging-boosting-and-stacking-c9214a10a205"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
},
"livereveal": {
"start_slideshow_at": "selected",
"theme": "white"
}
},
"nbformat": 4,
"nbformat_minor": 4
}