umz21/wyk/05_Regresja_wielomianowa.ipynb


{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Uczenie maszynowe zastosowania\n",
"# 5. Regresja wielomianowa. Problem nadmiernego dopasowania"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 5.1. Regresja wielomianowa"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Wprowadzenie: wybór cech"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Niech naszym zadaniem będzie przewidzieć cenę działki o kształcie prostokąta.\n",
"\n",
"Jakie cechy wybrać?"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Możemy wybrać dwie cechy:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * $x_1$ szerokość działki, $x_2$ długość działki:\n",
"$$ h_{\\theta}(\\vec{x}) = \\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"...albo jedną:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * $x_1$ powierzchnia działki:\n",
"$$ h_{\\theta}(\\vec{x}) = \\theta_0 + \\theta_1 x_1 $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"**Wniosek:** możemy tworzyć nowe cechy na podstawie innych poprzez wykonywanie na nich różnych operacji matematycznych."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Można też zauważyć, że cecha „powierzchnia działki” powstaje przez pomnożenie dwóch innych cech: długości działki i jej szerokości."
]
},
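A minimal sketch of this idea, using made-up plot dimensions and assuming (hypothetically) that the price is exactly linear in the area. A linear model in the derived feature `area = width * length` then fits the data perfectly, while a linear model in width and length separately cannot. The helper `fit_linear` is illustrative and uses NumPy's least-squares solver instead of the gradient descent used later in this notebook.

```python
import numpy as np

# Hypothetical plots: widths and lengths in metres (illustrative data only)
width = np.array([10.0, 20.0, 15.0, 30.0, 25.0])
length = np.array([20.0, 10.0, 40.0, 30.0, 50.0])
# Assumed "true" price: linear in the area, not in width or length separately
price = 500.0 + 100.0 * width * length

def fit_linear(features, y):
    """Least-squares fit of y ~ theta_0 + sum_j theta_j * feature_j; returns (theta, RSS)."""
    X = np.column_stack([np.ones(len(y))] + list(features))
    theta, *_ = np.linalg.lstsq(X, y, rcond=None)
    rss = float(np.sum((X @ theta - y) ** 2))
    return theta, rss

theta_wl, rss_wl = fit_linear([width, length], price)  # h = th0 + th1*x1 + th2*x2
area = width * length                                   # derived feature
theta_a, rss_a = fit_linear([area], price)              # h = th0 + th1*area
```

Here `theta_a` recovers the assumed coefficients (500, 100) with essentially zero residual, whereas the two-feature model leaves a clearly non-zero residual sum of squares.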
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Regresja wielomianowa"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"W regresji wielomianowej będziemy korzystać z cech, które utworzymy jako potęgi cech wyjściowych."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Przydatne importy\n",
"\n",
"import ipywidgets as widgets\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Przydatne funkcje\n",
"\n",
"def cost(theta, X, y):\n",
" \"\"\"Wersja macierzowa funkcji kosztu\"\"\"\n",
" m = len(y)\n",
" J = 1.0 / (2.0 * m) * ((X * theta - y).T * (X * theta - y))\n",
" return J.item()\n",
"\n",
"def gradient(theta, X, y):\n",
" \"\"\"Wersja macierzowa gradientu funkcji kosztu\"\"\"\n",
" return 1.0 / len(y) * (X.T * (X * theta - y)) \n",
"\n",
"def gradient_descent(fJ, fdJ, theta, X, y, alpha=0.1, eps=10**-5):\n",
" \"\"\"Algorytm gradientu prostego (wersja macierzowa)\"\"\"\n",
" current_cost = fJ(theta, X, y)\n",
" logs = [[current_cost, theta]]\n",
" while True:\n",
" theta = theta - alpha * fdJ(theta, X, y)\n",
" current_cost, prev_cost = fJ(theta, X, y), current_cost\n",
" if abs(prev_cost - current_cost) > 10**15:\n",
" print('Algorithm does not converge!')\n",
" break\n",
" if abs(prev_cost - current_cost) <= eps:\n",
" break\n",
" logs.append([current_cost, theta]) \n",
" return theta, logs\n",
"\n",
"def plot_data(X, y, xlabel, ylabel):\n",
" \"\"\"Wykres danych (wersja macierzowa)\"\"\"\n",
" fig = plt.figure(figsize=(16*.6, 9*.6))\n",
" ax = fig.add_subplot(111)\n",
" fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
" ax.scatter([X[:, 1]], [y], c='r', s=50, label='Dane')\n",
" \n",
" ax.set_xlabel(xlabel)\n",
" ax.set_ylabel(ylabel)\n",
" ax.margins(.05, .05)\n",
" plt.ylim(y.min() - 1, y.max() + 1)\n",
" plt.xlim(np.min(X[:, 1]) - 1, np.max(X[:, 1]) + 1)\n",
" return fig\n",
"\n",
"def plot_fun(fig, fun, X):\n",
" \"\"\"Wykres funkcji `fun`\"\"\"\n",
" ax = fig.axes[0]\n",
" x0 = np.min(X[:, 1]) - 1.0\n",
" x1 = np.max(X[:, 1]) + 1.0\n",
" Arg = np.arange(x0, x1, 0.1)\n",
" Val = fun(Arg)\n",
" return ax.plot(Arg, Val, linewidth='2')"
]
},
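The functions above rely on `np.matrix`, where `*` means matrix multiplication. NumPy no longer recommends `np.matrix` for new code, so below is a sketch of the same cost, gradient and gradient-descent loop for plain 1-D/2-D `ndarray`s using the `@` operator. The `max_steps` cap and the simplified signature (no `fJ`/`fdJ` arguments, no divergence check or logs) are assumptions of this sketch, not part of the original.

```python
import numpy as np

def cost(theta, X, y):
    """J(theta) = 1/(2m) * ||X theta - y||^2 for 1-D arrays theta and y."""
    r = X @ theta - y
    return float(r @ r) / (2.0 * len(y))

def gradient(theta, X, y):
    """Gradient of J: (1/m) * X^T (X theta - y)."""
    return X.T @ (X @ theta - y) / len(y)

def gradient_descent(theta, X, y, alpha=0.1, eps=1e-10, max_steps=100_000):
    """Stop when the cost changes by at most eps, or after max_steps updates."""
    current = cost(theta, X, y)
    for _ in range(max_steps):
        theta = theta - alpha * gradient(theta, X, y)
        current, previous = cost(theta, X, y), current
        if abs(previous - current) <= eps:
            break
    return theta

# Usage on noiseless data y = 1 + 2x
x = np.linspace(0.0, 1.0, 11)
X = np.column_stack([np.ones_like(x), x])
y = 1.0 + 2.0 * x
theta = gradient_descent(np.zeros(2), X, y)
```

Because the stopping rule looks at the change in cost rather than at θ itself, the result is only approximately (1, 2); a tighter `eps` moves it closer.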
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Wczytanie danych (mieszkania) przy pomocy biblioteki pandas\n",
"\n",
"alldata = pandas.read_csv('data_flats.tsv', header=0, sep='\\t',\n",
" usecols=['price', 'rooms', 'sqrMetres'])\n",
"data = np.matrix(alldata[['sqrMetres', 'price']])\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data[:, 0:n]\n",
"Xn /= np.amax(Xn, axis=0)\n",
"Xn2 = np.power(Xn, 2) \n",
"Xn2 /= np.amax(Xn2, axis=0)\n",
"Xn3 = np.power(Xn, 3) \n",
"Xn3 /= np.amax(Xn3, axis=0)\n",
"\n",
"X = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n + 1)\n",
"X2 = np.matrix(np.concatenate((np.ones((m, 1)), Xn, Xn2), axis=1)).reshape(m, 2 * n + 1)\n",
"X3 = np.matrix(np.concatenate((np.ones((m, 1)), Xn, Xn2, Xn3), axis=1)).reshape(m, 3 * n + 1)\n",
"y = np.matrix(data[:, -1]).reshape(m, 1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Postać ogólna regresji wielomianowej:\n",
"\n",
"$$ h_{\\theta}(x) = \\sum_{i=0}^{n} \\theta_i x^i $$"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [],
"source": [
"# Funkcja regresji wielomianowej\n",
"\n",
"def h_poly(Theta, x):\n",
" \"\"\"Funkcja wielomianowa\"\"\"\n",
" return sum(theta * np.power(x, i) for i, theta in enumerate(Theta.tolist()))\n",
"\n",
"def polynomial_regression(theta):\n",
" \"\"\"Funkcja regresji wielomianowej\"\"\"\n",
" return lambda x: h_poly(theta, x)"
]
},
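The hypothesis `h_poly` is simply a dot product between θ and the vector of powers [1, x, x², …]. A sketch with made-up coefficients showing that the same values come from a design matrix built with `np.vander(..., increasing=True)` and from `np.polyval`, which expects the highest-degree coefficient first. Here `theta` is a 1-D array rather than the notebook's column `np.matrix`.

```python
import numpy as np

def h_poly(theta, x):
    """h(x) = sum_i theta_i * x**i for a 1-D coefficient array (theta_0 first)."""
    return sum(t * np.power(x, i) for i, t in enumerate(theta))

theta = np.array([1.0, -2.0, 0.5])      # theta_0 + theta_1*x + theta_2*x^2
x = np.array([0.0, 1.0, 2.0, 3.0])

X = np.vander(x, N=len(theta), increasing=True)   # columns: 1, x, x^2
via_matrix = X @ theta
via_polyval = np.polyval(theta[::-1], x)          # polyval wants theta_n first
values = h_poly(theta, x)                         # -> [1.0, -0.5, -1.0, -0.5]
```

Building the feature matrix once and multiplying by θ is exactly the "polynomial regression as linear regression" view used later in this section.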
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Najprostszym przypadkiem regresji wielomianowej jest funkcja kwadratowa:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Funkcja kwadratowa:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 $$"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x246ed465910>]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmQAAAFvCAYAAADkPtfiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdeVzU1foH8M+ZGYZlQNxABHdB3Jfcy0ottWixbMEs27y3LKvrre7PunWX7m273ZZrpZWZmVlmpWkL2eoS5YZr4gbuCAKKwgCyzZzfH4cRxBlmBmb4DjOf9+vla3Tmy3eewODxnOc8j5BSgoiIiIi0o9M6ACIiIqJAx4SMiIiISGNMyIiIiIg0xoSMiIiISGNMyIiIiIg0xoSMiIiISGPNMiETQiwQQuQJIXa5eP2tQojdQoh0IcTH3o6PiIiIyB2iOfYhE0JcBqAYwCIpZV8n1yYA+BTAWCnlaSFEtJQyryniJCIiInJFs1whk1KuA1BQ+zkhRHchxCohxBYhxC9CiJ7VL/0RwBwp5enqj2UyRkRERD6lWSZkDswD8LCUcjCAxwHMrX6+B4AeQohfhRAbhBBXaRYhERERkR0GrQPwBCFEOICLAXwmhLA9HVz9aACQAGA0gA4AfhFC9JVSnmnqOImIiIjs8YuEDGql74yUcqCd17IAbJBSVgI4JITYB5WgbW7KAImIiIgc8YstSyllEVSydQsACGVA9csrAIypfr4t1BbmQU0CJSIiIrKjWSZkQoglANYDSBRCZAkhpgG4HcA0IcQOAOkAJlZf/h2AU0KI3QBWA/iLlPKUFnETERER2dMs214QERER+ZNmuUJGRERE5E+YkBERERFprNmdsmzbtq3s0qULACD7zFmcKqlAm3AjYiNDtQ2MiJpWVhaQm+v49ZgYIC6u6eKhgFFpkdh7oggCQK/2LaDXCacfQ4Fhy5YtJ6WUUQ352GaXkHXp0gVpaWkAgN+zCnHdm6lobTLityevgNHABT+igDF/PjBzJlBScuFrJhPw7LPAtGlNHxf5vXfXHcRzKXtwVZ8YvD11sNbhkA8RQhxp6Mc26wymb1wLJLaLQEFJBVbv40QkooCSnAzoHHwL0+nU60Re8MW24wCAGwZxBZY8p1knZEII3Dy4AwDg8y1ZGkdDRE0qIgJISVGPJpN6zmSqeT48XNv4yC/tO2HG7pwitAgxYEzPBu1MEdnV7LYs65o4KBYvrtqL1XvzcLK4HG3Dg51/EBH5h1GjgOxsYOlSIDMTiI9XK2NMxshLVmxXq2PX9I9FsEGvcTTkT5p9QhYdEYLRPaLw0948rNyejWmjumodEhE1pfBw1opRk7BaJVZWb1feyO1K8rBmvWVpw21LIiLyto2HCpBdWIa4lqEY0rmV1uGQn/GLhGxsr2i0DAvCnpwipGcXah0OERH5oRW1Vsd0bHVBHuYXCVmwQY+JA2IBcJWMiIg8r6zSgpTfcwAANwyK1Tga8kd+kZABwM2DOwIAVm7PRkWVVeNoiIjIn/y8Nw/m8ir0i4tEfHSE1uGQH/KbhIw9yYiIyFvYe4y8zW8SMvYkIyIibzhdUoE1+/KgE8B1A9prHQ75Kb9JyADVk0yvE+d6khERETXW17/noNIiMSohCtERIVqHQ37KrxIyW0+yKqvEyu3ZWodDRER+wHa6chK3K8mL/CohA9iTjIiIPOfoqVJsOXIaYUY9xvdpp3U45Mf8LiFjTzIiIvIU26ikCX1iEGZs9sNtyIf5XUIWbNDjhoFqWZmrZERE1FBSSizfqn6O8HQleZvfJWRAzbYle5IREVFDbThYgMOnShHTIgSXdG+jdTjk5/wyIesT2wI9Y1RPsp/3sicZERG5b8mmowCAW4d0gEHvlz8uyYf45d8w9iQjIqLGOF1SgVW7TkAI4NahHbUOhwKA1xIyIURHIcRqIcQeIUS6EOJPdq4ZLYQoFEJsr/71d0+9/8
SBcaon2b485JvZk4yIiFy3fNtxVFisuDQhCh1ahWkdDgUAb66QVQF4TErZC8AIADOEEL3tXPeLlHJg9a9/eerNoyKCMSYxCharxMrqUzJERETOSCnxSfV25ZRhXB2jpuG1hExKmSOl3Fr9ezOAPQCa9JhK7W1LKWVTvjURETVTW4+eRkZeMdqGB+OKXuw9Rk2jSWrIhBBdAAwCsNHOyyOFEDuEEN8KIfp48n3H9myHVmFB2HvCjPTsIk/emoiI/NSSTccAqH/UB7GYn5qI1/+mCSHCASwDMFNKWTcr2gqgs5RyAIA3AKxwcI/7hBBpQoi0/Px8l9/baNBhInuSERGRi4rKKvH1TjV6bzKL+akJeTUhE0IEQSVjH0kpl9d9XUpZJKUsrv59CoAgIURbO9fNk1IOkVIOiYqKciuGmp5kx9mTjIiI6rVyezbKKq0Y2a0NurQ1aR0OBRBvnrIUAN4DsEdK+aqDa2Kqr4MQYlh1PKc8GYetJ9np0kr2JCMiIoeklFiyURXzT2YxPzUxb66QXQJgKoCxtdpaJAkhpgshpldfczOAXUKIHQBeBzBZerj6nj3JiIjIFb8fL8TunCK0DAvChD4xWodDAcZrk1KllKkAhJNr3gTwprdisJk4MA4vfLv3XE+yqIhgb78lERE1M7Zi/kmDOiAkSK9xNBRoAuL4CHuSERFRfUrKq/Bl9c+H27hdSRoIiIQMYE8yIiJy7Oud2SipsGBI51ZIaBehdTgUgAImIWNPMiIicsS2XTl5WCeNI6FAFTAJGXuSERGRPXtPFGH7sTOICDHgmn7ttQ6HAlTAJGQAe5IREdGFPqleHbthYBxCjSzmJ20EVELGnmRERFRbWaUFy7eqXRP2HiMtBVRCxp5kRERU27e7clBUVoX+HSLRJzZS63AogAVUQgYANwyKg0EnzvUkIyKiwLVkY3Ux/1AW85O2Ai4haxsejNGJ0exJRkQU4DLzirHpcAHCjHpcPzBW63AowAVcQgbUFPd/lsaeZEREgWrpZjW38voBsQgP9trgGiKXBGRCNrZnNFqbjNiXa8a2Y2e0DoeIiJpYeZUFy7aqXRL2HiNfEJAJmdGgw61D1GmaxeuPaBwNERE1tR9256KgpAI9YyIwoAOL+Ul7AZmQAcDtwztBCODrnTkoKKnQOhwiImpCtt5jtw3rBCGExtEQBXBC1rF1GMYkRqPCYsWnace0DoeIiJrI0VOlSM08iWCDDjdUT3Ah0lrAJmQAMHVEZwDARxuPwGJlcT8RUSBYmqaK+a/p1x6RYUEaR0OkBHRCdlmPKHRsHYpjBWexbn++1uEQEZGXVVqs+CzN1pmfxfzkOwI6IdPrBG4frlbJPtzA4n4iIn/389485JnL0T3KhKFdWmkdDtE5AZ2QAcCtQzrCaNBh9b48HCso1TocIiLyok82qe1KFvOTrwn4hKy1yYhr+7WHlMBHG49qHQ4REXlJ9pmzWLs/H0a9DpMu6qB1OETnCfiEDABury7u/zTtGMqrLBpHQ0RE3vBp2jFYJTC+Tzu0Nhm1DofoPEzIAFzUqSV6t2+BgpIKfPv7Ca3DISIiD7NYJT7dXNN7jMjXMCEDIITA1JEs7ici8lfrMvKRXViGTq3DMLJbG63DIboAE7JqEwfGIiLYgC1HTiM9u1DrcIiIyINsxfzJQztCp2MxP/keJmTVwowG3DRYFXku3sDifiIif5FnLsNPe/Kg1wncMpjF/OSbmJDVckd1cf+KbcdRVFapcTREROQJn2/JQpVV4spe0YhuEaJ1OER2MSGrJT46HBd3b4OzlRYs35KldThERNRIVqs8N0icnfnJlzEhq8M23/LDDUcgJedbEhE1Z+sPnsLRglLEtQzFZQlRWodD5BATsjqu7N0O7VoE40B+CdYfPKV1OERE1AhLqov5bxnSAXoW85MPY0JWR5Beh8lD1bL2YrbAICJqtgpKKvB9ei50Qo3JI/JlTMjsuG1YJ+h1At+l5yK3qEzrcI
iIqAGWb81ChcWKy3tEIbZlqNbhENWLCZkdMZEhGN+7HSxWeW65m4iImg8pa75/s5ifmgMmZA7YivuXbDqKSotV42i
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X2, y, xlabel='x', ylabel='y')\n",
"theta_start = np.matrix([0, 0, 0]).reshape(3, 1)\n",
"theta, logs = gradient_descent(cost, gradient, theta_start, X2, y)\n",
"plot_fun(fig, polynomial_regression(theta), X)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Innym szczególnym przypadkiem regresji wielomianowej jest funkjca sześcienna:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Funkcja sześcienna:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 + \\theta_3 x^3 $$"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 397521.22456017]\n",
" [-841359.33647153]\n",
" [2253763.58150567]\n",
" [-244046.90860749]]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmQAAAFvCAYAAADkPtfiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdd3iUVfYH8O+dkjZpkEJIKKGEjoDSRQRsgIJ9g2DHgouu6LqrrrvqFhd/6FpBXDc2xIIiIkpbXSwE6U0SWkJLB0JIMunJzP398WZICFOTmbxTvp/n4RmZefPOMdHk5Nx7zxFSShARERGRejRqB0BEREQU6JiQEREREamMCRkRERGRypiQEREREamMCRkRERGRypiQEREREanMJxMyIcR7QohTQogMJ6//jRBivxAiUwjxiafjIyIiInKF8MU+ZEKI8QAqACyRUg5ycG0KgM8BTJJSnhVCxEspT7VHnERERETO8MkKmZTyZwAlzZ8TQvQSQqwTQuwUQmwUQvRrfOl+AIuklGcbP5bJGBEREXkVn0zIbHgHwCNSyksAPAHgrcbn+wDoI4TYJITYIoSYrFqERERERFbo1A7AHYQQ4QDGAvhCCGF5OrjxUQcgBcAEAF0AbBRCDJJSlrZ3nERERETW+EVCBqXSVyqlHGrltTwAW6SU9QCOCSEOQUnQtrdngERERES2+MWSpZSyHEqydSsACMWQxpdXApjY+HwslCXMo6oESkRERGSFTyZkQohPAWwG0FcIkSeEmA1gFoDZQoi9ADIBXN94+XoAZ4QQ+wH8AOAPUsozasRNREREZI1Ptr0gIiIi8ic+WSEjIiIi8idMyIiIiIhU5nOnLGNjY2VycrJH7p1bUoXS6nokRIYgLiLY8QcQkXry8oCTJ22/npAAJCW1XzxEAMxS4kChEWYpkRIfjhC9Vu2QqB3t3LmzWEoZ15qP9bmELDk5GTt27PDIvdfuK8RDH+/C0K7RWDn3Uo+8BxG5SVoaMG8eUFl54WsGA/CPfwCzZ7d/XBTQPt56As98lYERyR3wxZyxaodD7UwIcaK1H8sly2Yu7xuHEL0Ge3JLUVhWrXY4RGRPaiqgsfEtTKNRXidqR1JKfLRZ+Xl8++juKkdDvoYJWTNhQTpM6BMPAFiXUaRyNERkV0QEsGaN8mgwKM8ZDE3Ph4erGx8FnJ0nzuJgkRExhiBMHpSgdjjkY3xuydLTpgxOwLrMIqzNKMI9l/ZQOxwismfcOKCgAFi2DMjOBnr3VipjTMZIBUu3KNWx1BFdEazj3jFyDROyFib1i0eQVoPtx0tw2ljLzf1E3i48nHvFSHXFFbVYs68IQgAzR3VTOxzyQVyybCEiRI9xKbGQEvjvfi5bEhGRY5/vyEWdyYxJfePRpUOY2uGQD2JCZoVl7Z/7yIiIyBGTWeKTrTkAgNvHcDM/tQ4TMiuu6t8JWo3A5iNnUFpVp3Y4RETkxX46fAp5Z6vRtWMoLk9pVQsqIiZk1nQwBGFMzxg0mCW+22+n8SQREQW8c60uRnWHRiNUjoZ8FRMyG7hsSUREjuSWVOHHw6cRpNPg1uFd1Q6HfBgTMhuuHtgJQgAbs4phrKlXOxwiIvJCH2/NgZTAdYM7o6MhSO1wyIcxIbMhPiIEI7p3RJ3JjA0HT6kdDhEReZnaBhM+35ELgJv5qe2YkNnBZUsiIrJl7b4ilFTWYUDnSAzrGq12OOTjmJDZYUnIfjx0GtV1JpWjISIib/JRY2f+O8Z0hxDczE9tw4TMjsToUAzpGo3qehN+OsxlSyIiUuwvKMfOE2cREazD9UMT1Q6H/AATMgemNFbJ1nLZkoiIGi3dqlTHbr6kC8KCOIWQ2o4JmQOWhGzDgVOobeCyJRFRoCuvqcfK3fkAgNtHc24luQcTMge6xxjQv3MkjLUN2JRdrHY4RESksq925aOqzoQxPWPQOz5C7XDITzAhc8K5Zct9XL
YkIgpkUkosbbaZn8hdPJaQCSG6CiF+EEIcEEJkCiEetXLNBCFEmRBiT+OfZz0VT1tYErLvDpxEvcmscjRERKSWrcdKkHWqAvERwbhqQCe1wyE/4smdiA0Afi+l3CWEiACwUwjxnZRyf4vrNkopr/NgHG2W0ikCveIMOHK6EluPlmBcSqzaIRERkQosrS5mjOwGvZaLTOQ+HvuvSUpZKKXc1fjPRgAHACR56v08bcqgzgCAtRmFKkdCRERqOGWswfqMImg1AreN5NxKcq92Se+FEMkAhgHYauXlMUKIvUKItUKIge0RT2tYmsSuzzwJk1mqHA0REbW3Zdty0WCWuKp/J3SOClU7HPIzHk/IhBDhAL4EME9KWd7i5V0AuksphwB4E8BKG/d4QAixQwix4/Tp054N2IaBiZHo2jEUxRW12HnirCoxEBGROhpMZnyyLQcAcPtobuYn9/NoQiaE0ENJxj6WUq5o+bqUslxKWdH4z2sA6IUQF2zQklK+I6UcLqUcHhcX58mQbRJCcNmSiChA/e/gKRSW1aBnrAFje8WoHQ75IU+eshQA3gVwQEr5io1rEhqvgxBiZGM8ZzwVU1udW7bMKIKUXLYkIgoUllYXs0Z3h0bDuZXkfp48ZXkpgDsA7BNC7Gl87k8AugGAlPJtALcAeEgI0QCgGsAM6cWZztAu0UiIDEFBWQ325pVhaNdotUMiIiIPO1ZciY1ZxQjRa3DLxV3UDof8lMcSMillOgC7v0ZIKRcCWOipGNxNoxGYPCgBH/xyHGszCpmQEREFgI8bq2PThyQiKkyvcjTkr9hExUWWZct1XLYkIvJ7NfUmfLEzDwBwx+hkdYMhv8aEzEUjkjsixhCEE2eqcKDQqHY4RETkQd/sLUBZdT2GdI3G4C5RaodDfowJmYu0GoGrByrjMtbxtCURkd+SUp7rzH/7qG4qR0P+jglZK0xubH+xLpPDxomI/NWOE2fxa14ZOoTpMW1IotrhkJ9jQtYKY3rGIDJEh8MnK3DkdIXa4RARkQe8u/EYAOCO0d0RoteqHA35OyZkrRCk0+DKAZZlS1bJiIj8zYkzlVi/vwhBWg1uH8PO/OR5TMhaiV37iYj81/ubjkNKYPrQRMRHhKgdDgUAJmStdFlKLAxBWmTklyO3pErtcIiIyE3KquvxxY5cAMDscT1UjoYCBROyVgrRazGxXzwALlsSEfmTZdtzUFlnwqW9Y9C/c6Ta4VCAYELWBly2JCLyLw0mMz7YdBwAcN+4nuoGQwGFCVkbTOgbh2CdBrtySlFUVqN2OERE1EZrM4pQUFaDnnEGXN4nTu1wKIAwIWsDQ7Du3P+w69mTjIjIp0kpkZautLqYPa4HNBq745iJ3IoJWRtNGazMtlyzj8uWRES+bFfOWezNLUWHMD1uGtZF7XAowDAha6Mr+ndCsE6DrcdKkF9arXY4RETUSmmNjWBnjeqO0CA2gqX2xYSsjSJD9LiqsUnsyt35KkdDREStkVtShfWZRdBrBe5kI1hSARMyN7j5YqW0vWJXHqSUKkdDRESuen/TcZglMH1IEuIj2QiW2h8TMje4LCUWseFBOHK6Er/mlakdDhERuaC8ph7LtucAYCNYUg8TMjfQaTWYPiQJgFIlIyIi37FsWy4q60wY2ysGAxLZCJbUwYTMTW66WEnIvvm1EHUNZpWjISIiZzSYzPjgl+MAgPsuY3WM1MOEzE0GJkaib6cIlFTW4afDp9UOh4iInLAuswj5pdXoGWfAhD7xaodDAYwJmZsIIc5VybhsSUTkG95tbAR776VsBEvqYkLmRtcPTYIQwP8OnEJZVb3a4RARkR07T5zF7pxSRIfpz52WJ1ILEzI3SogKwbjesagzmfHtvgK1wyEiIjveTT8KAJg1qhsbwZLqmJC5WdOyJZvEEhF5q9ySKqzLsDSCTVY7HCImZO52zcAEhAVpsfPEWRwvrlQ7HCIisuKDX5RGsNMuSkQnNoIlL8CEzM3CgnSYMq
gzAGAFRykREXkdY009lm3PBQDcy0aw5CWYkHmAZdnyq90cpURE5G2Wbc9FRW0DxvSMwaCkKLXDIQLAhMwjRveMQee
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X3, y, xlabel='x', ylabel='y')\n",
"theta_start = np.matrix([0, 0, 0, 0]).reshape(4, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X3, y)\n",
"plot_fun(fig, polynomial_regression(theta), X)\n",
"\n",
"print(theta)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Regresję wielomianową można potraktować jako szczególny przypadek regresji liniowej wielu zmiennych:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 + \\theta_3 x^3 $$\n",
"$$ x_1 = x, \\quad x_2 = x^2, \\quad x_3 = x^3, \\quad \\vec{x} = \\left[ \\begin{array}{ccc} x_0 \\\\ x_1 \\\\ x_2 \\end{array} \\right] $$"
]
},
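Because the model is linear in θ, any linear-regression solver works once the columns x, x², x³ have been built. A sketch with a hypothetical noiseless cubic, solved here with NumPy's least squares rather than with the gradient descent used in this notebook:

```python
import numpy as np

x = np.linspace(-1.0, 1.0, 21)
# Hypothetical cubic ground truth, no noise
y = 2.0 - 1.0 * x + 0.5 * x**2 + 3.0 * x**3

# Features x0 = 1, x1 = x, x2 = x^2, x3 = x^3
X = np.column_stack([np.ones_like(x), x, x**2, x**3])
theta, *_ = np.linalg.lstsq(X, y, rcond=None)   # ordinary linear least squares
```

The recovered `theta` equals the assumed coefficients (2, −1, 0.5, 3) up to rounding.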
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"(W tym przypadku za kolejne cechy przyjmujemy kolejne potęgi zmiennej $x$)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Uwaga praktyczna: przyda się normalizacja cech, szczególnie skalowanie!"
]
},
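To see why, compare the magnitudes of x, x², x³ for values resembling flat areas in square metres (the numbers below are illustrative, not taken from `data_flats.tsv`). The sketch divides each column by its maximum, similar to the `np.amax` normalization in the data-loading cell, fits on the scaled features, and converts θ back to the original scale.

```python
import numpy as np

x = np.linspace(20.0, 150.0, 30)            # illustrative areas in m^2
raw = np.column_stack([x, x**2, x**3])      # columns span very different ranges
col_max = raw.max(axis=0)                   # [150, 22500, 3375000]
scaled = raw / col_max                      # each column now lies in (0, 1]

# Hypothetical target, exactly cubic in x
y = 1000.0 + 50.0 * x + 0.2 * x**2 + 0.001 * x**3

Xs = np.column_stack([np.ones_like(x), scaled])
theta_scaled, *_ = np.linalg.lstsq(Xs, y, rcond=None)

# theta_j * (x^j / max_j) = (theta_j / max_j) * x^j, so divide to return to raw scale
theta_raw = np.concatenate([theta_scaled[:1], theta_scaled[1:] / col_max])
```

Without scaling, gradient descent would need a learning rate small enough for the x³ column (values in the millions), which makes progress on the x column extremely slow.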
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Do tworzenia cech „pochodnych” możemy używać nie tylko potęgowania, ale też innych operacji matematycznych, np.:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 \\sqrt{x} $$\n",
"$$ x_1 = x, \\quad x_2 = \\sqrt{x}, \\quad \\vec{x} = \\left[ \\begin{array}{ccc} x_0 \\\\ x_1 \\end{array} \\right] $$"
]
},
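The same recipe works for any transformation: build one column per derived feature, then fit a linear model. A sketch with a hypothetical helper `make_features` and a made-up noiseless target of the form θ₀ + θ₁x + θ₂√x:

```python
import numpy as np

def make_features(x, transforms):
    """Bias column plus one column per transform (hypothetical helper)."""
    return np.column_stack([np.ones_like(x)] + [f(x) for f in transforms])

x = np.linspace(0.0, 4.0, 17)
y = 1.0 + 0.5 * x + 2.0 * np.sqrt(x)            # hypothetical noiseless target

X = make_features(x, [lambda v: v, np.sqrt])    # x1 = x, x2 = sqrt(x)
theta, *_ = np.linalg.lstsq(X, y, rcond=None)
```

Swapping `np.sqrt` for `np.log1p`, `np.sin`, or a product of two inputs changes only the `transforms` list.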
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Jakie zatem cechy wybrać? Najlepiej dopasować je do konkretnego problemu."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Wielomianowa regresja logistyczna\n",
"\n",
"Podobne modyfikacje cech możemy również stosować dla regresji logistycznej."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def powerme(x1,x2,n):\n",
" \"\"\"Funkcja, która generuje n potęg dla zmiennych x1 i x2 oraz ich iloczynów\"\"\"\n",
" X = []\n",
" for m in range(n+1):\n",
" for i in range(m+1):\n",
" X.append(np.multiply(np.power(x1,i),np.power(x2,(m-i))))\n",
" return np.hstack(X)"
]
},
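A quick check of what `powerme` produces, rewritten here for 1-D `ndarray` inputs (the notebook version works on column `np.matrix` objects and uses `np.hstack`). For degree n it yields (n+1)(n+2)/2 columns, ordered by total degree and, within a degree, by increasing power of `x1`: 1, x2, x1, x2², x1·x2, x1², … This order differs from the numbering x₃ = x₁², x₄ = x₂², x₅ = x₁x₂ used in the proposed hypothesis; that does not affect the fit, only which θ belongs to which monomial.

```python
import numpy as np

def powerme(x1, x2, n):
    """Monomials x1**i * x2**(m - i) for m = 0..n, i = 0..m, one per column."""
    cols = []
    for m in range(n + 1):
        for i in range(m + 1):
            cols.append(np.power(x1, i) * np.power(x2, m - i))
    return np.column_stack(cols)

row = powerme(np.array([2.0]), np.array([3.0]), 2)
# row -> [[1., 3., 2., 9., 6., 4.]], i.e. 1, x2, x1, x2^2, x1*x2, x1^2
n_features_deg10 = powerme(np.zeros(5), np.zeros(5), 10).shape[1]   # 66
```

The count grows quadratically with the degree, which is why the degree-10 model later in this notebook has 66 parameters for only two input variables.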
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [
{
"data": {
"text/plain": [
"matrix([[ 1. , 0.36596696, -0.11214686],\n",
" [ 0. , 0.4945305 , 0.47110656],\n",
" [ 0. , 0.70290604, -0.92257983],\n",
" [ 0. , 0.46658862, -0.62269739],\n",
" [ 0. , 0.87939462, -0.11408015],\n",
" [ 0. , -0.331185 , 0.84447667],\n",
" [ 0. , -0.54351701, 0.8851383 ],\n",
" [ 0. , 0.91979241, 0.41607012],\n",
" [ 0. , 0.28011742, 0.61431157],\n",
" [ 0. , 0.94754363, -0.78307311]])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Wczytanie danych\n",
"import pandas\n",
"import numpy as np\n",
"\n",
"alldata = pandas.read_csv('polynomial_logistic.tsv', sep='\\t')\n",
"data = np.matrix(alldata)\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data[:, 1:]\n",
"\n",
"Xpl = powerme(data[:, 1], data[:, 2], n)\n",
"Ypl = np.matrix(data[:, 0]).reshape(m, 1)\n",
"\n",
"data[:10]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def plot_data_for_classification(X, Y, xlabel, ylabel):\n",
" \"\"\"Wykres danych (wersja macierzowa)\"\"\"\n",
" fig = plt.figure(figsize=(16*.6, 9*.6))\n",
" ax = fig.add_subplot(111)\n",
" fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
" X = X.tolist()\n",
" Y = Y.tolist()\n",
" X1n = [x[1] for x, y in zip(X, Y) if y[0] == 0]\n",
" X1p = [x[1] for x, y in zip(X, Y) if y[0] == 1]\n",
" X2n = [x[2] for x, y in zip(X, Y) if y[0] == 0]\n",
" X2p = [x[2] for x, y in zip(X, Y) if y[0] == 1]\n",
" ax.scatter(X1n, X2n, c='r', marker='x', s=50, label='Dane')\n",
" ax.scatter(X1p, X2p, c='g', marker='o', s=50, label='Dane')\n",
" \n",
" ax.set_xlabel(xlabel)\n",
" ax.set_ylabel(ylabel)\n",
" ax.margins(.05, .05)\n",
" return fig"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Przyjmijmy, że mamy następujące dane i chcemy przeprowadzić klasyfikację dwuklasową dla następujących klas:\n",
" * czerwone krzyżyki\n",
" * zielone kółka"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmwAAAFmCAYAAADQ5sbeAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3df3Ac533n+c8XEqGUQSQx9cOhafOkXaJcK/k2tMjTJmuWaceRV4LKJignBXm1Xt2eqliqO5VIUcmKqew5Lmev4tNdSEO7XqcUrCvOFi9CtgRSTMRIlnVZ+7guJwa5+kGeogDRObICns1QXnuIXAG05nt/9DTRGMwAM8DM9PN0v19VU5jp7gGexsz0fPvp7/N9zN0FAACAcPXl3QAAAACsjIANAAAgcARsAAAAgSNgAwAACBwBGwAAQOAI2AAAAAJ3dd4NyMN1113nN954Y97NAAAAWOL06dN/6+7X1y8vZcB24403ampqKu9mAAAALGFmf91oOZdEAQAAAkfABgAAEDgCNgAAgMARsAEAAAQu94DNzL5sZt83s7NN1puZPW5mM2b2spndmll3h5m9Vlt3qHetBgAA6J3cAzZJvyfpjhXW3ylpqHbbJ+lLkmRmV0n6Ym39zZI+ZWY3d7WlAAAAOcg9YHP3b0h6a4VN9kj6fU98S9JPm9lmSbdJmnH31919QdKTtW0BAAAKJfeArQVbJH038/jN2rJmywEAAAolhoDNGizzFZY3/iVm+8xsysymLly40LHGNeUuHTuW/GxlOQAAQBMxBGxvSnpv5vF7JM2usLwhd3/C3Xe6+87rr18240PnHT8u3X239PDDi8GZe/L47ruT9QDWhhMiACUTQ8B2QtI/r40W/TlJP3T385K+LWnIzG4ys35J99S2DcPIiLR/vzQ2thi0Pfxw8nj//mQ9gLXhhAhAyeQ+l6iZ/YGkD0u6zszelPQbkjZIkrv/jqSTkoYlzUj6O0n/orbux2b2oKTnJF0l6cvufq7nO9CMmXTkSHJ/bCy5SUmwduRIsj507skX38jI0vY2Ww70SvaESEo+U5wQASgw8xJeOti5c6f3bPJ3d6kv05FZrcYT5Bw7lvRWZIPMbE/h5KS0d2/erURZZd+LqZhOiABEozJf0cS5CU1fnNbQtUMavWVUg9cMduVvmdlpd9+5bDkBWxfF/oVSfxm3vhcjlv1AccV8QgQgCqfeOKXho8OqelVzl+c0sGFAfdank/ee1K6tuzr+95oFbDHksMWpPtipVpfntIUuvaybtruvj2AN4Ug/Y1mxfLYARKEyX9Hw0WFVFiqauzwnSZq7PKfKQrL80sKlnrWFgK1bjh9fHtxkg59YkqKzuXgpgjXkrQgnROgeRhGjQybOTajq1Ybrql7VxNmJnrWFgK1bRkaSHK9scJMGP5OT8SRF04uBEBXlhAjdwShidMj0xekrPWv15i7PaeatmZ61hYCtW8yShPz6nqhmy0NELwZC1coJEb0s5UVZJXTI0LVDGtgw0HDdwIYBbdu0rWdtYdABmmOUKGLG+7fcYh/0hSBU5ivacniLKguVZesG+wc1+8isNvZv7OjfZJRoBgFbi6jDhpgxyhmMIkYHhDJKlIANQHHRy1JevPbooEsLlzRxdkIzb81o26ZtGn3/aMd71lIEbBkEbEBBtNILLNHLUjb0riJi1GEDUDyrjQY8doxRzmXEKGIUEAEbgHitNBrwoYekr3+dUc5lVJSySkAGl0QBxK1ZrtKHPiR98pOMEgUQFXLYMgjYgIJpNBpQYpQzgOiQwwagmJrNxiHFX7y6lyg0DASNgA1AvJiNo3OYzgkIGgEbgHgxGrBzmM6p8+i1RAeRwwYgXszG0VkUm+0spkfDGjDoIIOADQCaYDqnzqGAL9aAQQcAgJU1G8BRwhP7jqi/RN/XR7CGNSNgAwAwgKNb0qAti2ANa0DABgBgAEe30GuJDiFgAw
AwnVM30GuJDro67wYAAAKQFhRudTlW16zXUkqW797N/xYtI2ADAKAb0l7LbHmZNGjbvZteS7SFgA0AgG6g1xIdRA4bAABA4AjYAAAAAhdEwGZmd5jZa2Y2Y2aHGqz/VTN7sXY7a2Zvm9mm2rrvmNkrtXVMXwAAAAon94DNzK6S9EVJd0q6WdKnzOzm7Dbu/r+5+3Z33y7p1yR93d3fymzykdr6ZVM5RItJgwEAQE3uAZuk2yTNuPvr7r4g6UlJe1bY/lOS/qAnLcvT8ePJpMHZWj1pTZ+776aIJQAUGSftqBNCwLZF0nczj9+sLVvGzN4h6Q5JT2UWu6SvmtlpM9vXtVb22sjI8gKL2QKMDAcHgOLipB11Qijr0WhCtWanDh+X9J/rLod+0N1nzewGSc+b2V+4+zeW/ZEkmNsnSVu3bl1vm7uvvsDi2Fhyn0mDAaD4siftUnLc56S91Mxz7lY1s5+X9Fl3/ye1x78mSe7+Ww22PSbpP7r7/9Hkd31W0iV3/99X+ps7d+70qalIxie4S32ZjtBqlWANAMoge2UlxUl74ZnZ6UY5+SFcEv22pCEzu8nM+iXdI+lE/UZm9lOSdkt6OrNswMwG0/uSPibpbE9a3QtMGgwA5ZW90pIiWCut3AM2d/+xpAclPSfpVUl/6O7nzOwBM3sgs+leSV9197nMsndJOmVmL0n6c0nPuPuzvWp7V6QJpdXq0u7vt9+W7rqLSYMBoCw4aUdGCDlscveTkk7WLfuduse/J+n36pa9Lulnu9y83koTTe+6S3rmmSRYO3xYOngweZwGbUwaDADFVT/QLJvDJtHTVkK597ChTppomgZnabCWfmhPnFicTBgAUEzHjy8N1tLLo+lABEaJlk7ugw7yEPygAxJNAaDc3JOgbGRk6XG/2XIURrNBBwRsoWJ0KACgWwgIgxXyKFHUI9EUANBNFOaNDgFbaOoTTavV5TMexIhpVlBAlfmKxs+M69HnH9X4mXFV5it5NwloDbPpRIdLoqE5diw5u8nmrGU/SJOTcY4OLep+obROvXFKw0eHVfWq5i7PaWDDgPqsTyfvPaldW3fl3TxgdeRLB4lLorEYGUmCl8OHky5p98XRQZOT0p49cfZIcTaHAqnMVzR8dFiVhYrmLielIecuz6mykCy/tHAp5xYCLYixMG+Jr9YQsIXGLOlpevrppfkFZklQc/BgnPkF9UPS+/qWD1kHIjFxbkJVrzZcV/WqJs5O9LhFwBrEmC9d4tw7ArZQFbFHKsazOaCB6YvTV3rW6s1dntPMWzM9bhHQpljzpYv43diiIGY6QAPZ4GZsbDHHIOYeqWZnc7HuD0pr6NohDWwYaBi0DWwY0LZN23JoFdCGZoV5pbBn0ynid2OLGHQQuqLUY1tpmpUSfNAQuDZrUlXmK9pyeIsqC8tHhQ72D2r2kVlt7N/Yi5YDaxN7HbaifDc2wKCDGMWYX9BMp6dZKXHiKbqgzbyYwWsGdfLekxrsH9TAhgFJSc/aYH+ynGANwUvzpeuDnGbLQ1Kk78Z2uHvpbjt27PDgVavu+/e7S8nPRo9jUq26T04ub3ez5auZnFz+v8j+jyYnO9NulMMaP2+V+YqPnx73Q88f8vHT416Zr/S44UDJFO27sQFJU94gdsk9eMrjFkXARkCyshJ8aNFj2fdQeuO9BISlBN+NzQI2cthC5ZHnF/SCU/QRHeZt5sXwOQV6qwSfOXLYYhNzfkGvUCYEneRryIspcU0oIBcl/m4kYEO81vIFCzSS7a1tpyZViWtCAegt6rAhTvVfjNkyIRI9bWjPWmtSlbgmFIDeIocNcWIyeXTSevNi2s19A4AmyGFDsYyMJEFZthcj7e2YnORSFNqznrwYLs0D6AECNsSpxImnCMhac98AoE3ksAHAWsU6HyOA6BCwAcBapZfmszluadC2e3fPLs1X5iuaODeh6YvTGrp2SKO3jGrwmsGe/G0AvcGgg04pQTE/AOE59cYpDR8dVt
Wrmrs8p4ENA+qzPp2896R2bd2Vd/MAtIlBB91GAU0APVaZr2j46LAqCxXNXZ6TJM1dnlNlIVl+aeFSzi0E0CkEbJ1
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r'$x_1$', ylabel=r'$x_2$')"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Propozycja hipotezy:\n",
"\n",
"$$ h_\\theta(x) = g(\\theta^T x) = g(\\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 + \\theta_3 x_3 + \\theta_4 x_4 + \\theta_5 x_5) \\; , $$\n",
"\n",
"gdzie $g$ funkcja logistyczna, $x_3 = x_1^2$, $x_4 = x_2^2$, $x_5 = x_1 x_2$."
]
},
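Since g(z) ≥ 0.5 exactly when z ≥ 0, the decision boundary is the set where θᵀx = 0; with quadratic features that set can be a curve rather than a straight line. A sketch with a hand-picked (not learned) θ, using the column order of `powerme(x1, x2, 2)` (1, x2, x1, x2², x1·x2, x1²): θ = [1, 0, 0, −1, 0, −1] gives θᵀx = 1 − x1² − x2², i.e. a unit circle. The helper `quad_features` is hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def quad_features(x1, x2):
    """Same column order as powerme(x1, x2, 2): 1, x2, x1, x2^2, x1*x2, x1^2."""
    return np.column_stack([np.ones_like(x1), x2, x1, x2**2, x1 * x2, x1**2])

theta = np.array([1.0, 0.0, 0.0, -1.0, 0.0, -1.0])   # hand-picked: 1 - x1^2 - x2^2
x1 = np.array([0.0, 0.5, 2.0, 0.0])
x2 = np.array([0.0, 0.5, 0.0, -1.5])
prob = sigmoid(quad_features(x1, x2) @ theta)
pred = (prob >= 0.5).astype(int)     # 1 inside the unit circle, 0 outside
```

The two points inside the circle get class 1 and the two outside get class 0, something no hypothesis that is linear in x₁ and x₂ alone could achieve for these points.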
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def safeSigmoid(x, eps=0):\n",
" \"\"\"Funkcja sigmoidalna zmodyfikowana w taki sposób, \n",
" żeby wartości zawsz były odległe od asymptot o co najmniej eps\n",
" \"\"\"\n",
" y = 1.0/(1.0 + np.exp(-x))\n",
" if eps > 0:\n",
" y[y < eps] = eps\n",
" y[y > 1 - eps] = 1 - eps\n",
" return y\n",
"\n",
"def h(theta, X, eps=0.0):\n",
" \"\"\"Funkcja hipotezy\"\"\"\n",
" return safeSigmoid(X*theta, eps)\n",
"\n",
"def J(h,theta,X,y, lamb=0):\n",
" \"\"\"Funkcja kosztu\"\"\"\n",
" m = len(y)\n",
" f = h(theta, X, eps=10**-7)\n",
" j = -np.sum(np.multiply(y, np.log(f)) + \n",
" np.multiply(1 - y, np.log(1 - f)), axis=0)/m\n",
" if lamb > 0:\n",
" j += lamb/(2*m) * np.sum(np.power(theta[1:],2))\n",
" return j\n",
"\n",
"def dJ(h,theta,X,y,lamb=0):\n",
" \"\"\"Pochodna funkcji kosztu\"\"\"\n",
" g = 1.0/y.shape[0]*(X.T*(h(theta,X)-y))\n",
" if lamb > 0:\n",
" g[1:] += lamb/float(y.shape[0]) * theta[1:] \n",
" return g\n",
"\n",
"def classifyBi(theta, X):\n",
" \"\"\"Funkcja decyzji\"\"\"\n",
" prob = h(theta, X)\n",
" return prob"
]
},
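A standard way to validate a hand-derived gradient such as `dJ` is to compare it with central finite differences. The sketch below re-implements `J` and `dJ` for 1-D `ndarray`s (without the `eps` clipping of `safeSigmoid`) on random, made-up data; `numerical_gradient` is a hypothetical helper.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def J(theta, X, y, lamb=0.0):
    """Cross-entropy cost plus L2 penalty on theta[1:] (bias not regularized)."""
    m = len(y)
    f = sigmoid(X @ theta)
    j = -np.mean(y * np.log(f) + (1 - y) * np.log(1 - f))
    return j + lamb / (2 * m) * np.sum(theta[1:] ** 2)

def dJ(theta, X, y, lamb=0.0):
    """Analytic gradient: (1/m) X^T (g(X theta) - y) + (lamb/m) * theta[1:]."""
    m = len(y)
    g = X.T @ (sigmoid(X @ theta) - y) / m
    g[1:] += lamb / m * theta[1:]
    return g

def numerical_gradient(f, theta, h=1e-6):
    """Central differences (f(theta + h e_k) - f(theta - h e_k)) / (2h)."""
    grad = np.zeros_like(theta)
    for k in range(len(theta)):
        e = np.zeros_like(theta)
        e[k] = h
        grad[k] = (f(theta + e) - f(theta - e)) / (2 * h)
    return grad

rng = np.random.default_rng(0)
X = np.column_stack([np.ones(20), rng.normal(size=(20, 2))])
y = (rng.random(20) > 0.5).astype(float)
theta = rng.normal(size=3)

analytic = dJ(theta, X, y, lamb=0.5)
numeric = numerical_gradient(lambda t: J(t, X, y, lamb=0.5), theta)
```

If the two vectors disagree beyond roughly 1e-6, the analytic gradient (or the cost) most likely contains a bug.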
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def GD(h, fJ, fdJ, theta, X, y, alpha=0.01, eps=10**-3, maxSteps=10000):\n",
" \"\"\"Metoda gradientu prostego dla regresji logistycznej\"\"\"\n",
" errorCurr = fJ(h, theta, X, y)\n",
" errors = [[errorCurr, theta]]\n",
" while True:\n",
" # oblicz nowe theta\n",
" theta = theta - alpha * fdJ(h, theta, X, y)\n",
" # raportuj poziom błędu\n",
" errorCurr, errorPrev = fJ(h, theta, X, y), errorCurr\n",
" # kryteria stopu\n",
" if abs(errorPrev - errorCurr) <= eps:\n",
" break\n",
" if len(errors) > maxSteps:\n",
" break\n",
" errors.append([errorCurr, theta]) \n",
" return theta, errors"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"theta = [[ 1.59558981]\n",
" [ 0.12602307]\n",
" [ 0.65718518]\n",
" [-5.26367581]\n",
" [ 1.96832544]\n",
" [-6.97946065]]\n"
]
}
],
"source": [
"# Uruchomienie metody gradientu prostego dla regresji logistycznej\n",
"theta_start = np.matrix(np.zeros(Xpl.shape[1])).reshape(Xpl.shape[1],1)\n",
"theta, errors = GD(h, J, dJ, theta_start, Xpl, Ypl, \n",
" alpha=0.1, eps=10**-7, maxSteps=10000)\n",
"print(r'theta = {}'.format(theta))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def plot_decision_boundary(fig, theta, X):\n",
" \"\"\"Wykres granicy klas\"\"\"\n",
" ax = fig.axes[0]\n",
" xx, yy = np.meshgrid(np.arange(-1.0, 1.0, 0.02),\n",
" np.arange(-1.0, 1.0, 0.02))\n",
" l = len(xx.ravel())\n",
" C = powerme(xx.reshape(l, 1), yy.reshape(l, 1), n)\n",
" z = classifyBi(theta, C).reshape(int(np.sqrt(l)), int(np.sqrt(l)))\n",
"\n",
" plt.contour(xx, yy, z, levels=[0.5], lw=3);"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"<ipython-input-14-d7e55b0bd73a>:10: UserWarning: The following kwargs were not used by contour: 'lw'\n",
" plt.contour(xx, yy, z, levels=[0.5], lw=3);\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmwAAAFmCAYAAADQ5sbeAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdeVxUVf8H8M9BAWHADXeUx43ccyPNtMg0U0pF08hsecoyy3Jr0bbnycq0LBVbbLGnxayoX4iZZlm2oZmimYkbuCsuCKLDgAzMnN8fw+gwzCDLzNxlPu/Xa14w995hvgPMvd8553vOEVJKEBEREZF6BSgdABERERFVjAkbERERkcoxYSMiIiJSOSZsRERERCrHhI2IiIhI5ZiwEREREalcbaUDUEKjRo1k69atbXeOHgVOnwaaNAFatSp/n4i8Ky8P2L+//HvO/l5s1w6oX1+5+IiIfGjr1q1npJSNnbf7ZcLWunVrpKWl2e5ICUyfDiQm2i4OADB1KrBwISCEckFWlZRASgoQH182bnfbidTC8T04bpztvWe/r8X3IhFRDQghDrvc7o8T58bExMiLCRtgu2AEOPQOW63au0CsWAGMHl32Aud4IUxOBkaNUjpKItcc/1ftmKwRkcKMRUYkpSchIycD0RHRSOiSgPDgcK8+pxBiq5Qyptx2v0/Y9HKhcHwd9vjZSkFaoocPTkSkG6lHUhG3PA5WaYWp2ARDoAEBIgBrxq/BgKgBXntedwmbfw86cE5yrFbb18RE23YtJbNC2JIye/wBAUzWSDvs70VHWnsPEpFuGIuMiFseB6PZCFOxCQBgKjbBaLZtzzfn+zwm/07YUlLKJzWOSU9KitIRVo09fkdM1kjt9PTBiXxDSlsZiPP/hrvtRFWUlJ4Eq7S63GeVViTtTPJxRP6esMXH22q7HJMae9KTnGzbryVspSAt0tsHJ/K+lBRbza7j+c1+/hs9mv8zVGMZORkXW9acmYpNyMzN9HFE/p6wCWErxHdugXK3Xc3YSkFadbkPTiNHsjWFyoqPL39+czz/ae3DNqlOdEQ0DIEGl/sMgQa0b9jexxH5e8KmJ2ylIK263AenlSvZmkJlsWaXvCyhSwIChOsUKUAEIKFrgo8j4ihR/eA8bKRXHAFN7nBkMXmR2kaJMmEjIvXTy/Q75Dn8nyAfyDfnI2lnEjJzM9G+YXskdE1AWFCYV5+TCZsDJmxEGsTWFLJjqyvpGOdhIyL1utw0DVYrR0DTJazZJT/EhI2IlHe5aRpGjOAIaLpEb1MyEVWCXy7+TkQq4zhNA1C2i+vmm4HVq8u3pgC2/bGxXCfX39hHEFd2O5EOMGEjIuU5J2H2xG3qVGDBAtvUHo4jne3Hx8ayNYWI/AIHHRCRenBgARH5OQ46ICJ149Jq3sF1N4l0gQkbESmPS6t5D9fdJNIFJmxEpDxO0+A9XHfT+9iKST7AGjYiUh6XVvMurgrgXStW2ForHX+njr/z5GSOXqVK40oHDpiwEZHf4YAO7+HKC+RBHHRAROSvOKDDu5y78AMCmKyRxzFhIyLSMw7o8A3HuQTtmKyRBzFhIyLSMw7o8A22YpKXMWEjItIzrrvpfWzFJB/g0lRERHrGdTe9z10rJsD1bsljmLARERHVhL0Vk+vdkhcxYSMiIqoJtmKSD7CGjYiIiEjlmLARERERqRwTNiIiIiKVU0XCJoQYKoTYK4TIFELMcrH/CSHE9tLbTiGERQjRsHTfISHEP6X7uN4UERER6Y7igw6EELUAvAXgRgDHAGwRQnwjpdxlP0ZKOR/A/NLjhwOYLqXMdfgxA6WUZ3wYNhEREZHPqKGFrQ+ATCnlASmlGcAXAEZWcPw4AJ/7JLLLkRJYsaL8pIjuthMRETnidYQqSQ0JWySAow73j5VuK0cIEQpgKICvHTZLAD8IIbYKISZ6LUpXUlKA0aPLzmRtn/F69Ggu+UJERB
XjdYQqSfEuUQCuVsZ195FiOIANTt2h/aWUWUKIJgDWCSH2SCl/K/cktmRuIgBERUXVNGab+PhLy48AtkkSHZcn4WSJRERUEV5HqJLUkLAdA9DK4X5LAFlujr0dTt2hUsqs0q+nhRArYOtiLZewSSnfA/AeAMTExHimjdl5+RH7G85xeRIiIiJ3eB2hShJS4f5xIURtAPsADAJwHMAWAHdIKdOdjqsH4CCAVlJKU+k2A4AAKaWx9Pt1AF6QUq6t6DljYmJkWpoHB5RKCQQ49C5brXyTERFR5fE6QqWEEFullDHO2xWvYZNSlgB4BMD3AHYD+FJKmS6EmCSEmORw6CgAP9iTtVJNAaQKIf4GsBnA6sslax4K+lIxqL3WwNG0aSwUJSKiynF1HXGsaSOCOrpEIaVcA2CN07Z3nO5/BOAjp20HAHT3cnjl2YtEp0yx3V+8uOz3ixdfaubmJyQiInLHnqzZa9Yca9gAXkfoIlUkbJrjXCTqKnFLTARiY7nwLxERuZeSUjZZc65p43WESilew6YEj9SwSWnr+ly8+NI2+xsOsL0J4+P5yYiIiNyT0vX1wt120j13NWxM2GqCRaJERETkQaoddKBZLBIlIiIt4GoKusCErTqci0St1ks1bUzaiIhITbiagi5w0EF1sEiUSJeMRUYkpSchIycD0RHRSOiSgPDgcKXDIqoZrqagC6xhqw5/KBL1h9dI5CD1SCrilsfBKq0wFZtgCDQgQARgzfg1GBA1QOnwiGrGsWfIjqspqBJr2DxJiEstaI79/+62axGb0MmPGIuMiFseB6PZCFOxbW5uU7EJRrNte745X+EIiWrIsSfITivJGmvwADBhqxk9JzWOTej218cmdNKppPQkWKXV5T6rtCJpZ5KPIyLyMC0PlNPztbYKWMNWE3quC+CCxORHMnIyLrasOTMVm5CZm+njiIg8SOurKej5WlsFTNhqQu9Jjf31OdY86OF1ETmJjoiGIdDgMmkzBBrQvmF7BaIi8hCtD5TT+7W2kjjowBP0OoEui1TJTxiLjIhcEAmj2VhuX3hQOLIey0JYUJgCkRF5gF4Gken1WuuEgw68Rct1ARXhXHOkVdUoUA4PDsea8WsQHhQOQ6ABgK1lLTzItp3JGmmafUCcc3Ljbrsa6fVaWxVSSr+79e7dW3qE1Srl1KlSAravru5rVXJy+dfh+PqSk6v286xW22OcfyfuthNVVw3+d41FRrl061I5a90suXTrUmksMvooaCJyS8/XWhcApEkXuYviyZMSN48lbJ5OatTE0wmWnn9XpC5+dnIn0j0/u364S9hYw1YTUid1Ab4gKxilxLo48jTH/zc7/p8RaZOfXWvd1bAxYSPf4UWUfElWo0DZzy4MRKQ+HHRAytPyTNukLfYPB44qU6DMCTqJSKWYsJHvVPciSlQVzt3vVRnhzBU+iEilOHEu+UZFNWwAW9rIc2oySSgn6CQilWING/nGihW2LiXHC59jEpecrO6Ztkk7PFGHVp36t8uGJVFUaEbB+QIUnC/EhYIiFBeVoLiouPRWAkuJ5eLxwv58AggMDkRQcCAC61z6GhoeAkO9UISE1bl0LBFpHgcdOGDCpgAWc5NWVGFwjKXEgpwTZ5F9NAfZR8/gzPFc5J0+h3PZ55GXfR55p88hL/s8THkmmM4XwmpxvcB8TQQECBjqhcJQ34C6EeFo2Kw+GjStjwZN66FBs/qIaN4ATVs3RrPWTRDeMIzJHZHKuUvY2CVKvmGfUbuy24mU4KLrPv/hqTie+Amy9hUg65ohyDpwClmZJ3HqUDZysnJhtZb90Fs7sBbqN6mHeo3ron6TeoiMbo6w+gaElLaIGeqGICQ8BHUMwQgMDixtPauNwOBA1KpdCxAAHH6k1Wq92BJnvmBrjSsqNKPQWAjTuQLk55lgOlcA07kCnM8xIvtYDval7Ufe6XPlYgsJq3MxeWt5RQu06hiJqI62r/Ua1fXBL5iIqostbEREAAqMhT
jw5jIcfnoeDve8Docj2uLw7mPIyTpb5riIFg3Qol0zNGvTBE1aNULjVhFo1DICTUq/htU3qKIVy2Kx4HxOPnKO5+L
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r'$x_1$', ylabel=r'$x_2$')\n",
"plot_decision_boundary(fig, theta, Xpl)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Load the data\n",
"\n",
"alldata = pandas.read_csv('polynomial_logistic.tsv', sep='\\t')\n",
"data = np.matrix(alldata)\n",
"\n",
"m, n_plus_1 = data.shape\n",
"Xn = data[:, 1:]\n",
"\n",
"n = 10\n",
"Xpl = powerme(data[:, 1], data[:, 2], n)\n",
"Ypl = np.matrix(data[:, 0]).reshape(m, 1)\n",
"\n",
"theta_start = np.matrix(np.zeros(Xpl.shape[1])).reshape(Xpl.shape[1],1)\n",
"theta, errors = GD(h, J, dJ, theta_start, Xpl, Ypl, \n",
" alpha=0.1, eps=10**-7, maxSteps=10000)"
]
},
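{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The cell above builds the polynomial features with the helper `powerme`, which is defined earlier in the notebook. Below is a minimal sketch of such a mapping. It assumes the mapping produces all monomials $x_1^i x_2^j$ with $i + j \\le n$, where the degree-0 monomial is a column of ones. `poly_features` is a hypothetical name, and its column order may differ from that of `powerme`. For $n = 10$ this sketch yields $(n+1)(n+2)/2 = 66$ columns."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def poly_features(x1, x2, n):\n",
"    \"\"\"Hypothetical sketch: all monomials x1**i * x2**j with i + j <= n.\"\"\"\n",
"    x1 = np.asarray(x1, dtype=float).ravel()\n",
"    x2 = np.asarray(x2, dtype=float).ravel()\n",
"    cols = []\n",
"    for d in range(n + 1):  # total degree d = i + j\n",
"        for j in range(d + 1):\n",
"            cols.append(x1 ** (d - j) * x2 ** j)\n",
"    return np.column_stack(cols)  # the first column (d = 0) is all ones\n",
"\n",
"F = poly_features([1.0, 2.0], [3.0, 0.5], 2)\n",
"print(F.shape)  # (2, 6): columns 1, x1, x2, x1^2, x1*x2, x2^2"
]
},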
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"<ipython-input-14-d7e55b0bd73a>:10: UserWarning: The following kwargs were not used by contour: 'lw'\n",
" plt.contour(xx, yy, z, levels=[0.5], lw=3);\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmwAAAFmCAYAAADQ5sbeAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdeVxU1fsH8M8ZBEQYF3DHXTHXNKW01MykUkxFs8gsv/mzTMvdXMr2zbJc0CwrK7PsG5WIlrj2rRR33MUN3BVXXBgGZJk5vz+G0QFmWGfm3pn5vF+vecHce2d4gJl7nznnOecIKSWIiIiISL00SgdAREREREVjwkZERESkckzYiIiIiFSOCRsRERGRyjFhIyIiIlI5JmxEREREKldB6QCUUL16ddmoUSPTnbNngcuXgZo1gfr1C98nIse6cQM4frzwe878XmzaFKhaVbn4iEh1Lp+5ihtX0uBf2Q81G1SHt6+30iHZza5du65KKWsU3O6RCVujRo2QkJBguiMlMGECEBVlujgAwLhxwJw5gBDKBVlaUgKxsUBERP64bW0nUgvL9+Dgwab3nvm+K74XicjhDAYD/vhyHb57/WcYTxgx9J2n8MSEx+FVwUvp0MpNCHHa6nZPnDg3NDRU3k7YANMFQ2PRO2w0ut4FYvlyYODA/Bc4ywthTAwwYIDSURJZZ/laNWOyRkTFuHIuFZ+P+RZbVuxE0/aNMGnRKIR0aGK359dl6RCdGI2k1CSEBIUgsnUktL5auz2/NUKIXVLK0ELbPT5hc5cLheXvYY6frRTkStzhgxMROZ2UEvHLd+DzMd/ixuWbeGpyfzz31iD4VPQp1/PGn4lH+NJwGKUR+hw9/L39oREaxA2JQ9cGXe0UfWG2EjbPHnRQMMkxGk1fo6JM210pmRXClJSZ49domKyR6zC/Fy252nuQiBQhhEC3gZ2w6OBsPPJcd/zy8XKM7DAFh7YeLfNz6rJ0CF8aDl22DvocPQBAn6OHLtu0PT073V7hl5hnJ2yxsYWTGsukJzZW6QhLxxy/JSZrpHbu9MGJnENKUxlIwdeGre3kEbTVAvDqdy/jo9XTcUt/C+O7vomvXl2CrMysUj9XdGI0jNJodZ9RGhF9MLq84ZaaZydsERGm2i7LpMac9MTEmPa7ErZSkCtytw9O5HixsaaaXcvzm/n8N3AgXzMe7t7H2uObA7MR/mIYfp/9B0Z1nIqjO5NL9RxJqUm3W9YK0ufokXytdM9nD56dsAlhKsQv2AJla7uasZWCXFVxH5z692drCuUXEVH4/GZ5/nO1D9tkd/6VK2H8whGYseYNZOoyMfaB6Vj81i/Iyc4p0eNDgkLg7+1v/bm9/dEssJk9wy0RDjpwFxwlSu6Kr22yxl0GjJHDpd/QY8G477Dhx41odk9jTF0yBo1aFz3Pqi5Lh+DZwdBl6wrt0/pokTIpBQE+AQ6Jl6NELbhlwsZ52MhdcQQ02cKRxVQKm2N3YO5LX0GflokRM59D/9G9IIp4vahtlCgTNiJSP7amUEF8TVAZXL98E7OGf4Htq3bj3t73YPJ3L6NaLdsrqaRnpyP6YDSSryWjWWAzRLaJdFjLmhkTNgtM2IhcEFtTyIytrlQOUkr88eU6fPXqD6ik9cOkb19G58c7Kh3WbZyHjYjUq7hpGoxGjoCmOziymMpBCIF+Lz+GLxI+QWDdaniz38f4csLiEg9IUAoTNiJSXnHTNPTrxxHQdIe7TclEimjYqj7mb5uBiDG9ERO1ChO6vYkLJy8pHZZN7BIlIuUV1cXVpw+wahVHiRKRw2yK2Y5Zw78AALz63cvoOqCTYrGwhs0CEzYiFbJVRD57NrBiBUdAE5FDXThxCR88PQfHEo7jw1Wv477e9ygSBxM2C0zYiFSKAwuISEHZWTmI+2YD+o56FF5eXorEwEEHRKRuXFrNMbjuJlGJ+fh6I2J0b8WStaIwYSMi5X
FpNcfhuptEboEJGxEpj9M0OA7X3XQ8tmKSE7CGjYiUx6XVHIurAjgW17slO+KgAwtM2IjI43BAh+Nw5QWyIw46ICLyVBzQ4VgFu/A1GiZrZHdM2IiI3BkHdDiHOWmzxGSN7IgJGxGRO+OADudgKyY5GBM2IiJ3xnU3HY+tmOQEFZQOgIiIHEgI6yMUbW2n0rPVigmYtnfvzr81lRsTNiIiovIwt2JaTj9jTtq6d2crJtkFEzYiIqLyYCsmOQFr2IiIiIhUjgkbERERkcoxYSMiIiJSOVUkbEKIXkKIo0KIZCHENCv7Jwsh9ubdDgohDEKIwLx9p4QQB/L2cb0pIiIicjuKDzoQQngBWADgEQDnAOwUQqyUUh4yHyOl/BTAp3nH9wUwQUp5zeJpekgprzoxbCIiIiKnUUML230AkqWUJ6SU2QB+AdC/iOMHA/ivUyIrjpTA8uWFJ0W0tZ2IiMgSryNUQmpI2IIBnLW4fy5vWyFCiEoAegFYZrFZAlgnhNglhBjhsCitiY0FBg7MP5O1ecbrgQO55AsRERWN1xEqIcW7RAFYWxnX1keKvgA2F+gO7SKlTBFC1ASwXghxREq5sdAPMSVzIwCgQYMG5Y3ZJCLizvIjgGmSRMvlSThZIhERFYXXESohNSRs5wDUt7hfD0CKjWOfRoHuUCllSt7Xy0KI5TB1sRZK2KSUXwP4GgBCQ0Pt08ZccPkR8xvOcnkSIiIiW3gdoRISUuH+cSFEBQDHAPQEcB7ATgDPSCkTCxxXBcBJAPWllPq8bf4ANFJKXd736wG8J6VcU9TPDA0NlQkJdhxQKiWgsehdNhr5JiMiopLjdYTyCCF2SSlDC25XvIZNSpkLYDSAtQAOA/hVSpkohBgphBhpcegAAOvMyVqeWgDihRD7AOwAsKq4ZM1OQd8pBjXXGlgaP56FokREVDLWriOWNW1EUEeXKKSUcQDiCmxbWOD+YgCLC2w7AaCdg8MrzFwkOnas6f68efm/nzfvTjM3PyEREZEt5mTNXLNmWcMG8DpCt6kiYXM5BYtErSVuUVFA9+5c+JeIiGyLjc2frBWsaeN1hPIoXsOmBLvUsElp6vqcN+/ONvMbDjC9CSMi+MmIiIhsk9L69cLWdnJ7tmrYmLCVB4tEiYiIyI5UO+jAZbFIlIiIXAFXU3ALTNjKomCRqNF4p6aNSRsREakJV1NwCxx0UBYsEiVyS7osHaITo5GUmoSQoBBEto6E1lerdFhE5cPVFNwCa9jKwhOKRD3hdySyEH8mHuFLw2GURuhz9PD39odGaBA3JA5dG3RVOjyi8rHsGTLjagqqxBo2exLiTguaZf+/re2uiE3o5EF0WTqELw2HLlsHfY5pbm59jh66bNP29Ox0hSMkKifLniAzV0nWWIMHgAlb+bhzUmPZhG7+/diETm4qOjEaRmm0us8ojYg+GO3kiIjszJUHyrnztbYUWMNWHu5cF8AFicmDJKUm3W5ZK0ifo0fytWQnR0RkR66+moI7X2tLgQlbebh7UmP+/SxrHtzh9yIqICQoBP7e/laTNn9vfzQLbKZAVER24uoD5dz9WltCHHRgD+46gS6LVMlD6LJ0CJ4dDF22rtA+rY8WKZNSEOAToEBkRHbgLoPI3PVaWwAHHTiKK9cFFIVzzZGrKkOBstZXi7ghcdD6aOHv7Q/A1LKm9TFtZ7JGLs08IK5gcmNruxq567W2NKSUHnfr2LGjtAujUcpx46QETF+t3XdVMTGFfw/L3y8mpnTPZzSaHlPwb2JrO1FZleO1q8vSyUW7Fslp66fJRbsWSV2WzklBE5FN7nyttQJAgrSSuyiePClxs1vCZu+kRk3snWC589+K1MXDTu5Ebs/Drh+2EjbWsJWHdJO6AGeQRYxSYl0c2Zvl682MrzMi1+Rh11pbNWxM2Mh5eBElZ5JlKFD2sAsDEakPBx2Q8lx5pm1yLeYPB5ZKUqDMCTqJSKWYsJHzlPUiSlQaBbvfSzPCmSt8EJ
FKceJcco6iatgAtrSR/ZRnklBO0ElEKsUaNnKO5ctNXUqWFz7LJC4mRt0zbZPrsEcdWlnq34iI7IA1bKSsiAhTUmb
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Example with a larger number of features\n",
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r'$x_1$', ylabel=r'$x_2$')\n",
"plot_decision_boundary(fig, theta, Xpl)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 5.2. The overfitting problem"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Bias vs. variance"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Data for a simple example\n",
"\n",
"data = np.matrix([\n",
" [0.0, 0.0],\n",
" [0.5, 1.8],\n",
" [1.0, 4.8],\n",
" [1.6, 7.2],\n",
" [2.6, 8.8],\n",
" [3.0, 9.0],\n",
" ])\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn1 = data[:, 0:n]\n",
"Xn1 /= np.amax(Xn1, axis=0)\n",
"Xn2 = np.power(Xn1, 2) \n",
"Xn2 /= np.amax(Xn2, axis=0)\n",
"Xn3 = np.power(Xn1, 3) \n",
"Xn3 /= np.amax(Xn3, axis=0)\n",
"Xn4 = np.power(Xn1, 4) \n",
"Xn4 /= np.amax(Xn4, axis=0)\n",
"Xn5 = np.power(Xn1, 5) \n",
"Xn5 /= np.amax(Xn5, axis=0)\n",
"\n",
"X1 = np.matrix(np.concatenate((np.ones((m, 1)), Xn1), axis=1)).reshape(m, n + 1)\n",
"X2 = np.matrix(np.concatenate((np.ones((m, 1)), Xn1, Xn2), axis=1)).reshape(m, 2 * n + 1)\n",
"X5 = np.matrix(np.concatenate((np.ones((m, 1)), Xn1, Xn2, Xn3, Xn4, Xn5), axis=1)).reshape(m, 5 * n + 1)\n",
"y = np.matrix(data[:, -1]).reshape(m, 1)"
]
},
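{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The cell above repeats the same two steps for each degree: it raises $x$ to the power $k$ and then divides by the column maximum. The sketch below performs the same construction in a loop. `design_matrix` is a hypothetical helper, and it returns a plain NumPy array instead of an `np.matrix`. Because the scaled $x$ already has maximum 1, each extra division by the maximum of $x^k$ leaves the values unchanged, as in the original cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def design_matrix(x, degree):\n",
"    \"\"\"Hypothetical helper: [1, x, x^2, ..., x^degree], each power scaled to max 1.\"\"\"\n",
"    x = np.asarray(x, dtype=float).reshape(-1, 1)\n",
"    x = x / np.amax(x, axis=0)  # same scaling as Xn1 above\n",
"    cols = [np.ones_like(x)]\n",
"    for k in range(1, degree + 1):\n",
"        xk = np.power(x, k)\n",
"        cols.append(xk / np.amax(xk, axis=0))  # scale each power separately\n",
"    return np.hstack(cols)\n",
"\n",
"x_demo = np.array([0.0, 0.5, 1.0, 1.6, 2.6, 3.0])\n",
"D5 = design_matrix(x_demo, 5)  # counterpart of X5 above\n",
"print(D5.shape)  # (6, 6)"
]
},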
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmEAAAFoCAYAAAAfEiweAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAT+0lEQVR4nO3df4zteV3f8dd79kKQmWlcwgXWhRbqnYCWP8TeEpRJQ0Xa9bZxW6OZNVFXc5NNm1Kx17RS20jSNC1pGlPbWJvNQtEUYQhi3dhblaJEb7Rk765bYbmSmVCF27u6lzbB2Wkb3M6nf5y5vdfLvXtnl5nve+6cxyPZnJnzPXPOO9987/Dk+2tqjBEAAKa10D0AAMA8EmEAAA1EGABAAxEGANBAhAEANBBhAAANDizCquq9VfVUVX3qmudeUlUfraqN3cc7D+rzAQAOs4PcE/a+JPdc99w7k3xsjLGS5GO73wMAzJ06yJu1VtWrk/ziGOP1u99/JslbxhhPVtVdST4+xnjtgQ0AAHBITX1O2MvHGE8mye7jyyb+fACAQ+FY9wA3U1UPJHkgSRYXF//86173uuaJAAD+pEcfffQLY4zjz+dnp46wP6yqu645HPnUzV44xngwyYNJcvLkyXH+/PmpZgQA2JOq+v3n+7NTH458OMn9u1/fn+QXJv58AIBD4SBvUfGBJL+V5LVVdbGqTid5d5K3VdVGkrftfg8AMHcO7HDkGOO7b7LorQf1mQAAtwt3zAcAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGIgwAoIEIAwBoIMIAABqIMACABiIMAKCBCAMAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGx7oHAOCI2tpK1teTjY1kZSVZW0uWl7ungkNDhAGw/86dS06dSnZ2ku3tZHExOXMmOXs2WV3tng4OBYcjAdhfW1uzANvamgVYMnu88vzTT/fOB4eECANgf62vz/aA3cjOzmw5IMIA2GcbG1f3gF1vezvZ3Jx2HjikRBgA+2tlZXYO2I0sLiYnTkw7DxxSIgxgHm1tJQ89lPzIj8wet7b2773X1pKFm/zPy8LCbDng6kiAuXPQVy4uL8/e6/rPWFiYPb+09JV/BhwBIgxgnlx75eIVV87fOnUquXRpfyJpdXX2Xuvrs3PATpyY7QETYPD/iTCAebKXKxdPn96fz1pa2r/3giPIOWEA88SVi3BoiDCAeeLKRTg0RBjAPHHlIhwaIgxgnly5cnF5+eoescXFq887cR4m48R8gHnjykU4FEQYwDxy5SK0czgSAKCBCAMAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGIgwAoIEIAwBoIMIAABqIMACABiIMAKBBS4RV1d+tqieq6lNV9YGqelHHHAAAXSaPsKq6O8kPJjk5xnh9kjuS3Df1HAAAnboORx5L8lVVdSzJi5NcapoDAKDF5BE2xvjvSf5Fks8leTLJF8cYv3L966rqgao6X1XnL1++PPWYAAAHquNw5J1J7k3ymiRfk2Sxqr7n+teNMR4cY5wcY5w8fvz41GMCAByojsOR35rkv40xLo8x/jjJR5J8c8McAABtOiLsc0neVFUvrqpK8tYkFxrmAABo03FO2CeSfDjJY0k+uTvDg1PPAQDQ6VjHh44x3pXkXR2fDQBwGLhjPgBAAxEGANBAhAEANBBhAAANRBgAQAMRBgDQQIQBADQQYQAADUQYAEADEQYA0ECEAQA0EGEAAA1EGABAg2PdAwC029pK1teTjY1kZSVZW0uWl7unAo44EQbMt3PnklOnkp2dZHs7WVxMzpxJzp5NVle7pwOOMIcjgfm1tTULsK2tWYAls8crzz/9dO98wJEmwoD5tb4+2wN2Izs7s+UAB0SEAfNrY+PqHrDrbW8nm5vTzgPMFREGzK+Vldk5YDeyuJicODHtPMBcEW
HA/FpbSxZu8mtwYWG2HOCAiDBgfi0vz66CXF6+ukdscfHq80tLvfMBR5pbVADzbXU1uXRpdhL+5ubsEOTamgADDpwIA1haSk6f7p4CmDMORwIANBBhAAANRBgAQAMRBgDQQIQBADQQYQAADUQYAEADEQYA0ECEAQA0EGEAAA1EGABAAxEGANBAhAEANBBhAAANRBgAQAMRBgDQQIQBADQQYQAADUQYAEADEQYA0ECEAQA0EGEAAA1EGABAAxEGANCgJcKq6qur6sNV9btVdaGqvqljDgCALseaPvcnkvzSGOM7q+qFSV7cNAcAQIvJI6yq/lSSv5jk+5NkjPGlJF+aeg4AgE4dhyP/bJLLSf5dVf12VT1UVYsNcwAAtOmIsGNJvjHJT40x3pBkO8k7r39RVT1QVeer6vzly5ennhEA4EB1RNjFJBfHGJ/Y/f7DmUXZnzDGeHCMcXKMcfL48eOTDggAcNAmj7Axxh8k+XxVvXb3qbcm+fTUcwAAdOq6OvLvJHn/7pWRn03yA01zAAC0aImwMcbjSU52fDYAwGHgjvkAAA1EGABAAxEGANBAhAEANBBhAAANRBgAQAMRBgDQQIQBADQQYQAADUQYAEADEQYA0ECEAQA0EGEAAA1EGABAAxEGANBAhAEANBBhAAANRBgAQAMRBgDQQIQBADQQYQAADUQYAEADEQYA0ECEAQA0EGEAAA2OdQ8AzJmtrWR9PdnYSFZWkrW1ZHm5eyqAyYkwYDrnziWnTiU7O8n2drK4mJw5k5w9m6yudk8HMCmHI4FpbG3NAmxraxZgyezxyvNPP907H8DERBgwjfX12R6wG9nZmS0HmCMiDJjGxsbVPWDX295ONjennQegmQgDprGyMjsH7EYWF5MTJ6adB6CZCAOmsbaWLNzkV87Cwmw5wBwRYcA0lpdnV0EuL1/dI7a4ePX5paXe+QAm5hYVwHRWV5NLl2Yn4W9uzg5Brq0JMGAuiTBgWktLyenT3VMAtHM4EgCgwS0jrKreXlV3TjEMAMC82MuesFckeaSqPlRV91RVHfRQAABH3S0jbIzxj5KsJHlPku9PslFV/7SqvvaAZwMAOLL2dE7YGGMk+YPd/55JcmeSD1fVPz/A2QAAjqxbXh1ZVT+Y5P4kX0jyUJK/N8b446paSLKR5O8f7IgAAEfPXm5R8dIk3zHG+P1rnxxj7FTVXzuYsQAAjrZbRtgY48eeZdmF/R0HAGA+uE8YAEADEQYA0ECEAQA0EGEAAA1EGABAg7YIq6o7quq3q+oXu2YAAOjSuSfsHUnc4gIAmEstEVZVr0zyVzO7Az8AwNzp2hP2LzP7c0c7N3tBVT1QVeer6vzly5enmwwAYAKTR9junzp6aozx6LO9bozx4Bjj5Bjj5PHjxyeaDgBgGh17wt6c5Nur6veSfDDJt1TVv2+YAwCgzeQRNsb4B2OMV44xXp3kviS/Osb4nqnnAADo5D5hAAANjnV++Bjj40k+3jkDAEAHe8IAABqIMACABiIMAKCBCAMAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGIgwAoIEIAwBoIMIAABqIMACABiIMAKCBCAMAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGIgwAoIEIAwBoIMIAABqIMACABiIMAKCBCAMAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGIgwAoIEIAwBoIMIAABqIMACABiIMAKCBCAMAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGIgwAoMHkEVZVr6qqX6uqC1X1RFW9Y+oZAAC6HWv4zGeS/PAY47GqWk7yaFV9dIzx6YZZAABaTL4nbIzx5Bjjsd2vt5JcSHL31HMAAHRqPSesql6d5A1JPtE5BwDA1NoirKqWkvxckh8aY/zRDZY/UFXnq+r85cuXpx8QAOAAtURYVb0gswB7/xjjIzd6zRjjwTHGyTHGyePHj0
87IADAAZv8xPyqqiTvSXJhjPHjU38+kGRrK1lfTzY2kpWVZG0tWV7ungpgrnRcHfnmJN+b5JNV9fjucz86xjjbMAv
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X1, y, xlabel='x', ylabel='y')"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x246ed4ba250>]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmEAAAFoCAYAAAAfEiweAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXiU9b3+8fubHSZhD/sOkR0SS13Ro3VHLXXBhJ722Nbfse2pBY0b7latlbpC29MeThfb05YEXFHRui9oXdCEsJOwh7AkLGGyL/P9/TGxUJpAEmbmO8v7dV1eSeaZzHNfz/VkvHmWzxhrrQAAABBaca4DAAAAxCJKGAAAgAOUMAAAAAcoYQAAAA5QwgAAAByghAEAADgQtBJmjPm9MWavMWb1EY/1Msa8YYwpbvnaM1jrBwAACGfBPBL2tKSLj3psrqS3rLUZkt5q+RkAACDmmGAOazXGDJf0srV2YsvPGySdY63dZYwZIOlda+2YoAUAAAAIU6G+JqyftXaXJLV87Rvi9QMAAISFBNcB2mKMuV7S9ZLk8Xi+MnbsWMeJAAAdUVZZq31VDUpKiFNG31TFGeM6EhBwn3/+eYW1Nr0zvxvqErbHGDPgiNORe9t6orV2oaSFkjR16lS7YsWKUGUEAJygdzbs1Xf/8JmGxBk9+8MzNGVID9eRgKAwxmzr7O+G+nTkUknXtnx/raQXQ7x+AECQlXvrdeuSlZKkmy8cQwED2hDMERWLJP1d0hhjTKkx5jpJj0i6wBhTLOmClp8BAFHCWqtbn1mpiqoGnT6yt75/9kjXkYCwFbTTkdbaWW0sOi9Y6wQAuPX0R1v17oZy9eiaqCeypygujuvAgLYwMR8AEBDrdh3Sz15dL0l65MrJGtC9i+NEQHijhAEATlhdY7NmLypQQ5NPs04Zqosn9ncdCQh7lDAAwAn76SvrVLy3SiPTPbrnsnGu4wARgRIGADghb67do//7eJsS440W5GSpa1LYjqAEwgolDADQaXsP1em2Z4skSbddNFYTB3V3nAiIHJQwAECn+HxWNy9Zqf3VDToro4+umzbCdSQgolDCAACd8rvlW/RBcYV6eZL0+EzGUQAdRQkDAHTY6p2V+vnf/OMo5l01WX27pThOBEQeShgAoENqGpo0O69Ajc1W3z5tmC4Y3891JCAiUcIAAB3y4MvrtLm8Whl9U3XXpYyjADqLEgYAaLfXVu/Wok+3KykhTgtmZSklMd51JCBiUcIAAO2yq7JWc5/zj6OYe/FYjRvQzXEiILJRwgAAx9Xss8rNX6mDNY06Z0y6vnvmcNeRgIhHCQMAHNfC9zfr75v3qU9qkh69eoqMYRwFcKIoYQCAY1q546Aef32DJOnRq6coPS3ZcSIgOlDCAABtqq5v0py8AjX5rL5zxnCdO7av60hA1KCEAQDadP/SNdq6r0Zj+6dp7iVjXccBogolDADQqpeLyrTk81IlM44CCApKGADgX5QeqNEdz62SJN196Tid1C/NcSIg+iS4DgAACC9fjqPw1jXp/HF99a3ThnXuhbxeKT9fKi6WMjKk7GwpjTIHfIkSBgD4J//9Tok+3bpf6WnJmnfV5M6No1i+XJo+XfL5pOpqyeORcnOlZcukadMCHxqIQJyOBAD8w+fbDuipt4olSY/PnKLeqZ0YR+H1+guY1+svYJL/65ePV1UFMDEQuShhAABJkreuUTfmF6jZZ/WfZ43Q2Seld+6F8vP9R8Ba4/P5lwOghAEA/O59cY127K/V+AHddMtFYzr/QsXFh4+AHa26Wiop6fxrA1GEEgYA0AsFO/V8wU6lJPrHUSQnnMA4iowM/zVgrfF4pNGjO//aQBThwnwAiEVH3Lm4Y/hY3b1rgCTp3ssmaHTf1BN77exs/0X4rYmL8y8HQAkDgJhzxJ2LTTW1mvPtR1U1oK8u6p+gWacMOfHXT0vz3wV59N2RcXH+x1NPsO
QBUYISBgCx5Mg7FyUtmPZNfTFgjPp7K/TI7+6Q+X8bA1OSpk2Tysr8R9tKSvynILOzKWDAEShhABBLjrhz8bNB4/XL07NlrE9PvPyEetYe8i+/7rrArCs1NXCvBUQhShgAxJKWOxcrkz268fJb5IuL1w8+XqIzthf5l3PnIhAylDAAiCUZGbIej+4670fa2b2vJu/aqNwP/uJfxp2LQEgxogIAYkl2tp4dd45eHne2ujbUav5LjynJ1+Rfxp2LQEhxJAwAYsjW+jjdd+EPpWbp/g+e1ogDZdy5CDhCCQOAGNHY7NOc/EJVN0uXjkvXzNFXSOdP4s5FwBFKGADEiKfe3KiVOw5qYPcUPTwzS6brKa4jATGNa8IAIAZ8vHmf/vvdTYoz0pPZmereNdF1JCDmUcIAIModrGnQTfmFslb60bmjderI3q4jARAlDACimrVWdz6/Srsq65Q5pIdmn5fhOhKAFpQwAIhii1fs0LJVu5WanKAFOVlKjOdtHwgX/DUCQJTaVF6l+5eulSQ9MGOChvbu6jgRgCNRwgAgCjU0+XRjXqFqG5s1I3Ogrsga5DoSgKNQwgAgCj3+xgat2lmpwT276MFvTJQxxnUkAEehhAFAlPmwpEL/895mxRlpfk6muqUwjgIIR5QwAIgi+6sblLu4UJI0+7wMfWVYL8eJALSFEgYAUcJaq9ufLdKeQ/WaOqynbjh3tOtIAI6BEgYAUeKvn27XG2v3KC05QU9mZyqBcRRAWOMvFACiQMlerx582T+O4qdXTtKQXoyjAMIdJQwAIlx9U7N+vKhQdY0+XXnyIH19ykDXkQC0g5MSZoy5yRizxhiz2hizyBiT4iIHAESDn7+2Qet2HdLQXl31wIyJruMAaKeQlzBjzCBJsyVNtdZOlBQvKSfUOQAgGry3sVy/W75F8XFG83MylZqc4DoSgHZydToyQVIXY0yCpK6SyhzlAICIVVFVr5sXr5Qk5V5wkrKG9nScCEBHhLyEWWt3SnpM0nZJuyRVWmtfP/p5xpjrjTErjDErysvLQx0TAMKatVa3PVOkiqp6nTqil37wb6NcRwLQQS5OR/aUNEPSCEkDJXmMMd86+nnW2oXW2qnW2qnp6emhjgkAYe1Pf9+mt9fvVbcU/ziK+Dg+lgiINC5OR54vaYu1ttxa2yjpOUlnOMgBABFpw26vfrpsnSTpkasma2CPLo4TAegMFyVsu6TTjDFdjf8TZc+TtM5BDgCIOHWNzZq9qEANTT5lTx2i6ZMGuI4EoJNcXBP2iaRnJH0haVVLhoWhzgEAkeiRV9drwx6vRvTx6N7Lx7uOA+AEOLmX2Vp7n6T7XKwbACLV2+v36OmPtiox3mhBTpY8jKMAIhoT8wEgAuz11unWJUWSpJsvHKNJg7s7TgTgRFHCACDM+XxWtywp0r7qBp0xqreuP2uk60gAAoASBgBh7g8fbdX7G8vVo2uinrgmU3GMowCiAiUMAMLYmrJKzXt1vSRp3lWT1b87H7ULRAtKGACEqdqGZs3JK1RDs0/fPHWoLprQ33UkAAFECQOAMPXQK2tVsrdKo9I9uudSxlEA0YYSBgBh6PU1u/WXT7YrKT5OC2ZlqUtSvOtIAAKMEgYAYWbPoTrd/qx/HMVtF4/RhIGMowCiESUMAMKIz2eVu7hQB2oadVZGH33vzBGuIwEIEkoYAISR//1gsz4s2afeniQ9fs0UxlEAUYwSBgBhYlVppR57fYMk6edXT1bfNMZRANGMEgYAYaCmoUlz8grU2Gx17enDdN64fq4jAQgyShgAhIEHXlqrzRXVGtMvTXdMH+c6DoAQSHAdAACc83ql/HypuFjKyJCys6W0tJCt/tVVu5T32Q4lJcRp/qxMpSQyjgKIBZQwALFt+XJp+nTJ55OqqyWPR8rNlZYtk6ZNC/rqyw7Wau5zqyRJd14yVmP7dwv6OgGEB05HAohdXq+/gHm9/gIm+b9++XhVVVBX3+yzuim/UJW1jTp3TL
quPWN4UNcHILxQwgDErvx8/xGw1vh8/uVB9Jv3NumTLfvVJzVZj86cImMYRwHEEkoYgNhVXHz4CNjRqqulkpKgrbp
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X1, y, xlabel='x', ylabel='y')\n",
"theta_start = np.matrix([0, 0]).reshape(2, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X1, y, eps=0.00001)\n",
"plot_fun(fig, polynomial_regression(theta), X1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"This model has a high **bias** (**systematic error**), which leads to **underfitting**."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x2471ceee370>]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmEAAAFoCAYAAAAfEiweAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dd3yV5cH/8e91ssmAACFA2CTsKWGDIm4caBVxVlstWldt+1StHc/z9Pm1tba1tVatVq0LFQsOFMRRQUWGhr1JGBkEQoCQBZnn+v1xAiKCBEjOdcbn/XrxSnLOCfl6v26O31zXfV+XsdYKAAAA/uVxHQAAACAcUcIAAAAcoIQBAAA4QAkDAABwgBIGAADgACUMAADAgWYrYcaY54wxu40xa494rLUx5kNjTHbDx+Tm+vkAAACBrDlHwp6XdOFRjz0g6T/W2gxJ/2n4GgAAIOyY5lys1RjTTdK71toBDV9vkjTBWrvTGNNB0gJrbe9mCwAAABCg/H1NWKq1dqckNXxs5+efDwAAEBAiXQc4HmPMNEnTJCk+Pn5Ynz59HCcCAAD4umXLlu2x1qacyvf6u4QVGWM6HDEduft4L7TWPi3paUnKzMy0WVlZ/soIAADQKMaY3FP9Xn9PR86WdFPD5zdJetvPPx8AACAgNOcSFa9KWiyptzGmwBhzi6SHJJ1njMmWdF7D1wAAAGGn2aYjrbXXHuepc5rrZwIAAAQLVswHAABwgBIGAADgACUMAADAAUoYAACAA5QwAAAAByhhAAAADlDCAAAAHKCEAQAAOEAJAwAAcIASBgAA4AAlDAAAwAFKGAAAgAOUMAAAAAcoYQAAAA5QwgAAAByghAEAADhACQMAAHCAEgYAAOBApOsAAIDwYa1Vndeqps7r+1Pv+1h91Ne+z+u/9lyrFtE6o0srtUmIcf2fATQJShgAoEmVHqzV5zl79MnaQn2xvkDlNV5VR0SqxhOpmnqvrD29v79H23id0TVZmV2TNaxrsnqmJMjjMU0THvAjShgA4LRYa7WusEyfbC7WJ5uKtSyvRPXeQ00rUjKSvJK8XklShMcoOsKj6MiGPxEexUR+/etjfb6j5KBWFezX1j2V2rqnUjOXFUiSWsZFaVhDIRvWNVmDO7VSXHSEk2MBnAxKGADgpJUeqNVnOcVasKlYn2wuVnF59eHnIow0YscGTcheqjO3r1C7in2Krq/1/YmLVeSOAikh4ZR+bm29V+sLy7Qst0TLckuUlbtPRWXV+njjbn28cbckKdJj1L9jkoZ1ba3Mbr5ilpoU2yT/3UBTMvZ0x4X9IDMz02ZlZbmOAQBhy+v1jXYt2LRbCzYXa0VeibxH/O+jfVKsJvRO0Vm9UjR20Vwl/fRHUmXlN/+i+Hjp0UelW25pklzWWu3Yf/CrUra9RBt3lX0tmyR1So47PH05vHtr9U5NlDFMYeL0GWOWWWszT+V7GQkDABxTSWWNPs32jXR9urlYeypqDj8X6TEa2b21zuqdogm9U75eal7efOwCJvkez8lpsozGGHVKbqFOyS00eUiaJKmiuk4r8/YrK3efluWWaEXefhWUHFRByUG9tbJQknRGl1a6e2KGJvROoYzBGUoYAOCwmjqvZmTl643lBVqVv/9rI0odW8bqrN7tNKF3isamt1VCzHH+F5KR4RvxOt5IWHp684RvkBATqXEZbTUuo60kqd5rtbmoXFm5JVqeW6L5m3Zred5+fe/5LzUgLUl3nZ2h8/ulcnE//I7pSAAIR+Xl0owZUna2lJEh79VXa862Cv3pg03K3XtAkhQVYTSie2tN6OUrXuntEho3alReLqWl+T4eLTFRKiw85WvCmkJldZ1eWZqnpz7dqj0VvmvZeqcm6s6J6bp4YAdFUMZwEk5nOpISBgDhZuFCadIk392KlZVa1GuEHhp7vVa36ylJ6pkSrx+d20vn9Gmn+OONdp3kz1B8vOTxSH
PnSuPGNeF/zKmrqq3XjC/z9Y9PtmhnaZUk3/IXd5ydrslDOioqgvXMcWKUMABA4xwxSrU+pbv+MOEmfdLD9/+PdpUl+vGUkZoytqcim6KAVFT4RttycnxTkFOnOh0BO56aOq9mLS/QEwtylL/voCTfhfw/nNBTVw3rpJhIlrvA8VHCAACN88wzyv/Vb/XIsO/orf4TZI1HidWVun3JTH1vw3/U4s8PN9mdi8Gmtt6r2SsL9fiCHG0t9l3P1j4pVred1UPXDO/C2mM4Ju6OBACcUElljR7PrteL1/9FNZFRiqqv1Y3LZ+uuxa+r9cEy34ua8M7FYBMV4dGVwzrp8qFpmrtmp/7+cY42FZXrf99Zr8fn5+gH43vo+lFdj39DAnCSOJMAIMQdrKnXc59v0z8WbFG56SRFSpevm6+ffvayOpcWffVCP9y5GAwiPEaXDu6oiwd20EcbivTYxzlas6NUv39vo578ZIu+P7a7bhrTTS3jolxHRZBjOhIAQlRdvVczlxXoLx9tVlGZ7y7A8T2Sdf9Dt2vAtjXf/IYAuHMxEFlr9cnmYj32cY6W5ZZIkhJjInXTmG66ZVx3JcdHO04Il7gmDABwmLVWH64v0sPvb1LO7gpJ0oC0JD1wYV/f2llBcOdiILLWavHWvfr7xzlatGWvJKltQrT+MnWIxmekOE4HVyhhAABJUtb2fXrovY3Kahix6dw6Tv91fm9dOqjj1xcjDZI7FwPVstx9+sO8Tfpi2z4ZI90xoad+fG6vprmrFEGFEgYAYa64vFq/fGuN3l/nu8ardXy07p6YrutHdlV0JMWgOdR7rR6fn6O/frRZXisN75asv107VB1axrmOBj+ihAFAGFuVv1+3v7xMO0urFBcVoVvHd9e0M3soMZYLx/1hyda9+tFrK1RUVq3kFlH689WDNbFPqutY8JPTKWH8egQAQez1rHxNeWqxdpZW6YwurfTxf52ln57fmwLmR6N6tNHce8ZrQu8UlRyo1fefz9Jv56xXTZ3XdTQEOEoYAAShmjqvfv32Wt03c7Vq6ry6fmQXvTZtNFNhjrRJiNFzNw3Xzy/qowiP0T8/26YpTy1W/r4DrqMhgFHCACDI7C6v0vXPLNGLi3MVHeHRQ98ZqN9eMZBrvxzzeIxuO6unXr9ttNJaxWlV/n5N+ttnem/NTtfREKD4FwsAQWRFXokufWyhvtxeotSkGM24bZSuGdHFdSwcYVjXZM25Z5zO75eq8qo6/XD6cv367bWqqq13HQ0BhhIGAEFixpd5mvrUEhWVVWt4t2S9c/c4De2S7DoWjqFVi2g9deMw/c+l/RQd4dGLi3P1nScWaWtxhetoCCCUMAAIcDV1Xv3izTW6f9Ya1dR79d3RXTX91lFqlxjrOhq+hTFGN4/trlk/HKOubVpo/c4yXfrYQr29cofraAgQlDAACGC7y6p07T+XaPrSPEVHevTwVYP0m8kDuP4riAzs1FLv3j1OlwzqoMqaev3otZW6f+ZqHaxhejLc8a8YAALUstx9uuSxhVqWW6IOLWP179tG6+rMzq5j4RQkxkbpsWuH6ndXDFRMpEczsvI1+fGF2lxU7joaHKKEAUAAemVpnq55eol2l1drRPfWeufucRrcuZXrWDgNxhhdN7KL3rpzrHqmxGtzUYUu+/tCvf5lvoJh4XQ0PUoYAASQ6rp6/fyN1XrwzTWqrbe6eUw3Tb91pNomxLiOhibSt0OSZt81Tlee0UlVtV7dN2u1fjxjJXdPhqFI1wEAAD5FZVW6/eVlWpG3X9GRHv3uioG6algn17HQDOJjIvXnqwdrdM82+tVba/XWykLtqajRMzdlKjYqwnU8+AkjYQAQALK2+67/WpG3Xx1bxmrW7WMoYGHgqmGd9NadY9U2IUYLc/bo1heyGBELI05KmDHmx8aYdcaYtcaYV40x3GcNICxZa/XSklxd8/QSFZdXa1QP3/VfAzu1dB0NftK7faJe/cFIilgY8nsJM8akSb
pHUqa1doCkCEnX+DsHALhmrdVv3l2vX721VnVeq1vGddfLt4xUG67/CjsZqRSxcORqOjJSUpwxJlJSC0mFjnIAgDM
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X2, y, xlabel='x', ylabel='y')\n",
"theta_start = np.matrix([0, 0, 0]).reshape(3, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X2, y, eps=0.000001)\n",
"plot_fun(fig, polynomial_regression(theta), X1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"This model fits the data well."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x246ed4eb3a0>]"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmEAAAFoCAYAAAAfEiweAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXxU1eH///fJvhMgCzsBEkABFUFEBeqCG221tSK471pbq1Zr66efbp/+PvXbRWu1WvtRUNGqgFatVeuKG4LsyCokQIAQyAYkkz0zc35/TLBUWRLIzJnl9Xw8fCSZmcy8vY/L5c2595xrrLUCAABAaMW5DgAAABCLKGEAAAAOUMIAAAAcoIQBAAA4QAkDAABwgBIGAADgQNBKmDHmCWNMpTFmzX6P9TDGvGOMKW7/2j1Ynw8AABDOgjkS9pSk87702D2S3rPWFkl6r/1nAACAmGOCuVirMaZA0mvW2pHtP2+QdLq1dqcxprekD6y1w4IWAAAAIEyF+pqwfGvtTklq/5oX4s8HAAAICwmuAxyMMeYmSTdJUnp6+pjhw4c7TgQACFdev9WGXR75rdWgnHRlJIftX2+IMsuWLau21uYeye+Gei+tMMb03u90ZOXBXmitfUzSY5I0duxYu3Tp0lBlBABEmP9+ebX2Ltqms4bnaeY1J7mOgxhijNl6pL8b6tORr0q6uv37qyX9I8SfDwCIMiWV9Zq9ZLvijHTP+Zw1QeQI5hIVz0taKGmYMabMGHO9pN9KOtsYUyzp7PafAQA4Yr/91+fy+a2mjxugovxM13GADgva6Uhr7aUHeeqsYH0mACC2fLq5Ru+ur1BaUrzumFzkOg7QKayYDwCISH6/1b1vrJck3TxpiPIyUxwnAjqHEgYAiEj/XFWuVWW1ystM1o2TBrmOA3QaJQwAEHGa23z6/ZsbJEl3nTNUaUksSYHIQwkDAEScpxeWasfeJg3Lz9TFY/q7jgMcEUoYACCi7G1s1cPzSiRJ90wZrvg44zgRcGQoYQCAiPLneSWqa/ZqQmGOTh96RAuVA2GBEgYAiBhbaxr09MJSGSP915ThMoZRMEQuShgAIGL8/q0NavNZfXt0X43o0811HOCoUMIAABFh+bY9en3VTiUnxOlH5wxzHQc4apQwAEDYs9bq3tcDC7NeP2GQ+mSnOk4EHD1KGAAg7L21tkJLt+5Rj/Qkfff0Ia7jAF2CEgYACGsLSqr1q1fXSpLumFykrJREx4mArsESwwCAsLS3sVW/eX29XlhWJkk6oX+2Lh03wHEqoOtQwgAAYcVaq3+u2qlf/3OtqutblRQfpx+cWaibvzZEifGcwEH0oIQBAMLGjr1N+vkrazTv80pJ0rhBPfT/LhqlIbkZjpMBXY8SBgBwzue3enphqf7w1gY1tvqUmZKgn045RtPG9lcctyVClKKEAQCc+nxXne75+2qt3L5XknT+yF76nwtGKC8rxXEyILgoYQAAJ5rbfHp4Xon++uEmef1W+VnJ+v8uHKlzRvRyHQ0ICUoYACDkPt1co5++tFqbqxskSVeMH6Afnzec5ScQUyhhAIDg8HikOXOk4mKpqEiaNk21CSn67b/W6/nF2yVJhXkZ+u1FozS2oIfjsEDoUcIAAF1v/nxpyhTJ75caGmTT0/Wvh2frlxfeqaoWq8R4o++fUahbTh+i5IR412kBJyhhAICu5fEECpjHI0naldFTPz/7u3pn6ClSi9WYfln67dQTVJSf6Tgo4BYlDADQtebMkd9vVZ6Vq3cLT9Z9k65SfXKaMloa9ZOFz+nyG76huPyJrlMCzlHCAABHrMXr05bqBm2qbNCmqnqVVNZrU3GGNt/0lJqS/r3ExNkbF+rX7/5VvT010hnDHSYGwgclDABwWLWNbSqp8mhTZYNKquq1qbJeJVX12r67UX77pRebTClJyq3frcKa7bpq+es6b+MCGUlKT5cKCx38HwDhhx
IGALHoADMXbUaGKj0t2rDLo5L2krWpsl6bqupVXd96wLeJM9KgnHQNyU3XkLwMDcnNUGFGnIZMHKtu1bsO8Atx0rRpQf6fAyIDJQwAYoz9+GNVX3yZNnbvq40Zedq4tlXFHz+mjQOGq67twL+TmhivIXnpgZKVm6EheRkqzMvQwJ5pB57d+PIL/zE7UunpgQL2xhtSBveBBCRKGABEtd0NrdpY4fn3f+W1Kt6wQ3uu/etXX9wmdUtJ0LBeWSrMD5StwrxA4eqdldK5ezhOmCCVlwdG20pKAqcgp02jgAH7oYQBQBSw1mr1jlqtKqtVcYVHGyvqVVzpOfBpxNRMZbY0qKh6m4ZVbVVR9TYNrd6qoY3Vyr33f2RuuL5rQmVkSNd30XsBUYgSBgARzFqrj4qr9dB7xVq2dc9Xnk9PildhfqaG5mVoWK9MFb02V0MfvU+9PDU64LjWppKgZwYQQAkDgAhkrdX7Gyr14Hsl+mz7XklSdlqizhyep6H5mRqWn6mi/Az16Zb6n6cRN+RI/uYDvykzF4GQooQBQASx1uqddRV6aF6x1uyokyT1TE/SjZMG64rxA5WRfJjD+rRp0p13Hvg5Zi4CIUUJA4AI4Pdbvbl2l/48r0TrdwbKV05Gsr77tcG67OQBSkvq4OE8MzMwQ5GZi4BzlDAACGM+v9Xrq3fq4XnF2lhRL0nKz0rWd782RJeOG6CUxCO4+TUzF4GwQAkDgDDk9fn1z1Xl+vO8Em2uapAk9emWolvOKNTUMf2OrHztj5mLgHOUMAAII20+v15ZsUOPvF+i0ppGSVK/7qn6/hmF+s6J/ZSUEOc4IYCuQgkDgDDQ6vXrpeVleuSDEm3f3SRJGtgzTd8/o1DfHt1XifGULyDaUMIAwKEWr08vLC3Tox9s0o69gfI1OCddt55ZqAuO76MEyhcQtShhAOBIXXObrntyiZa2L7JamJehH5xZqG8c10fxnblFEICIRAkDAAdqG9t01ZOL9dn2verdLUU/+/qxOn9kr87dnxFARKOEAUCI7Wlo1RUzF2lteZ36dU/V8zeOV/8eaa5jAQgxShgAhFB1fYuumLFIn+/yaGDPND1/43j1yU51HQuAA5QwAAiRyrpmXT5jkYor6zU4N13P3zhe+VkprmMBcIQSBgAhsKu2WZc9/qk2VzdoaH6Gnr1hvHIzk13HAuAQJQwAgqxsT6Mue3yRtu1u1DG9s/S368epZwYFDIh1lDAACKJtNY269PFPtWNvk0b17aZnrh+n7LQk17EAhAFKGAAEyZbqBl32+KfaWdus0QOy9dS149QtNdF1LABhghIGAEFQUunRZY8vUqWnRScVdNcT15ykzBQKGIB/c3I/DGPMD40xa40xa4wxzxtjmB4EIGps2OXR9Mc+VaWnRacM7qmnrh1HAQPwFSEvYcaYvpJukzTWWjtSUryk6aHOAQDBsLa8VtMfW6jq+lZNLMrRE9ecpPRkTjoA+CpXR4YESanGmDZJaZLKHeUAgC6zqmyvrpy5WLVNbTpzeJ7+cvmJSkmMdx0LQJgK+UiYtXaHpPskbZO0U1KttfbtL7/OGHOTMWapMWZpVVVVqGMCQKcs27pHlz++SLVNbTrn2Hz99YoxFDAAh+TidGR3SRdKGiSpj6R0Y8wVX36dtfYxa+1Ya+3Y3NzcUMcEgA5bvGW3rpq5SJ4Wr74+qrceufxEJSU4ueQWQARxcZSYLGmLtbbKWtsm6SVJpzrIAQBHbUFJta5+YrEaWn361gl99OD0E5QYTwEDcHgujhTbJI03xqQZY4yksyStd5ADAI7KRxurdO1TS9TU5tPFY/rp/ktOUAIFDEAHubgmbJGkFyUtl7S6PcNjoc4BAEdj3ucVumHWUrV4/bp03AD9/jvHKT7OuI4FIII4mR1prf2lpF+6+GwAOFrryuv03WeWq9Xn19WnDNSvLhihwMA+AHQc4+YA0AmtXr9+9MJnavX5NW1sfwoYgCNGCQOATnjk/RKt21
mnAT3S9ItvHksBA3DEKGEA0EFrdtTqkfdLJEm/v/g4VsIHcFQoYQDQAftOQ3r9VtecWqDxg3u6jgQgwlHCAKAD/jy
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X5, y, xlabel='x', ylabel='y')\n",
"theta_start = np.matrix([0, 0, 0, 0, 0, 0]).reshape(6, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X5, y, alpha=0.5, eps=10**-7)\n",
"plot_fun(fig, polynomial_regression(theta), X1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"This model has a high **variance**, which leads to **overfitting**."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"(Note the strange shape of the curve on the left side of the plot: among other things, this is an effect of overfitting.)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Overfitting occurs when a model has too many degrees of freedom relative to the amount of training data.\n",
"\n",
"It is an undesirable phenomenon.\n",
"\n",
"Figuratively speaking, overfitting occurs when the model starts to fit the noise in the data instead of its underlying trend."
]
},
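{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The next cell is a quick numerical check of this statement on the six points from the simple example above. It uses `np.polyfit` in place of the notebook's gradient descent. Because this is a substitution, the fitted curves need not match the plots exactly. A degree-5 polynomial has six parameters, one per data point, so it can pass through every point and drive the training error to numerically zero. That zero training error says nothing about the behaviour of the curve between the points."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# the six points from the simple example above\n",
"x_pts = np.array([0.0, 0.5, 1.0, 1.6, 2.6, 3.0])\n",
"y_pts = np.array([0.0, 1.8, 4.8, 7.2, 8.8, 9.0])\n",
"\n",
"train_mse = {}\n",
"for degree in (1, 2, 5):\n",
"    coeffs = np.polyfit(x_pts, y_pts, degree)  # least-squares polynomial fit\n",
"    residuals = np.polyval(coeffs, x_pts) - y_pts\n",
"    train_mse[degree] = np.mean(residuals ** 2)\n",
"    print(degree, train_mse[degree])  # the training error shrinks as the degree grows"
]
},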
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"See also: https://pl.wikipedia.org/wiki/Nadmierne_dopasowanie"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"90%\" src=\"fit.png\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Bias (systematic error)\n",
"\n",
"* Caused by erroneous assumptions in the learning algorithm.\n",
"* High bias leads to underfitting."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Variance\n",
"\n",
"* Caused by oversensitivity to small fluctuations in the training set.\n",
"* High variance can lead to overfitting (modelling the noise instead of the signal)."
]
},
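{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Both sources of error can be estimated by simulation. The idea is to draw many noisy training sets from a known function and fit a model to each one. The squared bias measures how far the average prediction lies from the truth. The variance measures how widely individual predictions scatter around that average. The sketch below uses assumed settings that do not come from the lecture: the true function $\\sin(2 \\pi x)$, 15 fixed inputs, Gaussian noise with standard deviation 0.3, and `np.polyfit` as the learning algorithm. One would expect the straight line to show the larger bias and the degree-5 polynomial the larger variance."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x_grid = np.linspace(0.0, 1.0, 15)\n",
"f_true = np.sin(2 * np.pi * x_grid)  # assumed true function\n",
"noise_sd, n_trials = 0.3, 500        # assumed noise level and number of training sets\n",
"\n",
"def bias2_and_variance(degree):\n",
"    preds = np.empty((n_trials, x_grid.size))\n",
"    for t in range(n_trials):\n",
"        y_noisy = f_true + rng.normal(0.0, noise_sd, size=x_grid.size)  # a fresh training set\n",
"        preds[t] = np.polyval(np.polyfit(x_grid, y_noisy, degree), x_grid)\n",
"    mean_pred = preds.mean(axis=0)\n",
"    bias2 = np.mean((mean_pred - f_true) ** 2)  # squared bias, averaged over the inputs\n",
"    variance = np.mean(preds.var(axis=0))       # variance, averaged over the inputs\n",
"    return bias2, variance\n",
"\n",
"for degree in (1, 5):\n",
"    print(degree, bias2_and_variance(degree))"
]
},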
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"60%\" src=\"bias2.png\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"60%\" src=\"curves.jpg\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 5.3. Regularization"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def SGD(h, fJ, fdJ, theta, X, Y, \n",
" alpha=0.001, maxEpochs=1.0, batchSize=100, \n",
" adaGrad=False, logError=False, validate=0.0, valStep=100, lamb=0, trainsetsize=1.0):\n",
"    \"\"\"Stochastic Gradient Descent: the stochastic version of gradient descent\n",
"    (covered in more detail in lecture 11)\n",
" \"\"\"\n",
" errorsX, errorsY = [], []\n",
" errorsVX, errorsVY = [], []\n",
" \n",
"    m_end = int(trainsetsize * len(X))\n",
"    XT, YT = X[:m_end], Y[:m_end]  # honour trainsetsize also when there is no validation split\n",
"    \n",
" if validate > 0:\n",
" mv = int(X.shape[0] * validate)\n",
" XV, YV = X[:mv], Y[:mv] \n",
" XT, YT = X[mv:m_end], Y[mv:m_end] \n",
" m, n = XT.shape\n",
"\n",
" start, end = 0, batchSize\n",
" maxSteps = (m * float(maxEpochs)) / batchSize\n",
" \n",
" if adaGrad:\n",
" hgrad = np.matrix(np.zeros(n)).reshape(n,1)\n",
" \n",
" for i in range(int(maxSteps)):\n",
" XBatch, YBatch = XT[start:end,:], YT[start:end,:]\n",
"\n",
" grad = fdJ(h, theta, XBatch, YBatch, lamb=lamb)\n",
" if adaGrad:\n",
" hgrad += np.multiply(grad, grad)\n",
" Gt = 1.0 / (10**-7 + np.sqrt(hgrad))\n",
" theta = theta - np.multiply(alpha * Gt, grad)\n",
" else:\n",
" theta = theta - alpha * grad\n",
" \n",
" if logError:\n",
" errorsX.append(float(i*batchSize)/m)\n",
" errorsY.append(fJ(h, theta, XBatch, YBatch).item())\n",
" if validate > 0 and i % valStep == 0:\n",
" errorsVX.append(float(i*batchSize)/m)\n",
" errorsVY.append(fJ(h, theta, XV, YV).item())\n",
" \n",
" if start + batchSize < m:\n",
" start += batchSize\n",
" else:\n",
" start = 0\n",
" end = min(start + batchSize, m)\n",
" return theta, (errorsX, errorsY, errorsVX, errorsVY)"
]
},
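{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"`SGD` passes the regularization strength `lamb` on to the gradient function `fdJ`, which is defined elsewhere in the notebook. Below is a minimal sketch of an L2-regularized (ridge) logistic-regression cost and gradient. They match the calls `fJ(h, theta, X, Y)` and `fdJ(h, theta, X, Y, lamb=...)` made inside `SGD`. The names `h_logistic`, `J_reg` and `dJ_reg` are hypothetical, and the argument order `h(theta, X)` is an assumption. The penalty term is $\\frac{\\lambda}{2m} \\sum_{j \\ge 1} \\theta_j^2$, so, as is customary, the intercept $\\theta_0$ is not penalized. With `lamb=0` this reduces to plain logistic regression. Larger values shrink the remaining weights, which tends to smooth the decision boundary."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def h_logistic(theta, X):\n",
"    # hypothetical hypothesis function: sigmoid of X @ theta\n",
"    return 1.0 / (1.0 + np.exp(-(X @ theta)))\n",
"\n",
"def J_reg(h, theta, X, y, lamb=0.0):\n",
"    m = len(y)\n",
"    p = np.clip(h(theta, X), 1e-12, 1 - 1e-12)  # keep log() finite\n",
"    cross_entropy = -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))\n",
"    return cross_entropy + lamb / (2 * m) * np.sum(np.power(theta[1:], 2))\n",
"\n",
"def dJ_reg(h, theta, X, y, lamb=0.0):\n",
"    m = len(y)\n",
"    grad = X.T @ (h(theta, X) - y) / m\n",
"    penalty = (lamb / m) * theta\n",
"    penalty[0] = 0.0  # do not penalize the intercept theta_0\n",
"    return grad + penalty"
]
},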
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Przygotowanie danych do przykładu regularyzacji\n",
"\n",
"n = 6\n",
"\n",
"data = np.matrix(np.loadtxt(\"ex2data2.txt\", delimiter=\",\"))\n",
"np.random.shuffle(data)\n",
"\n",
"X = powerme(data[:,0], data[:,1], n)\n",
"Y = data[:,2]"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def draw_regularization_example(X, Y, lamb=0, alpha=1, adaGrad=True, maxEpochs=2500, validate=0.25):\n",
"    \"\"\"Draws the regularization example\"\"\"\n",
" plt.figure(figsize=(16,8))\n",
" plt.subplot(121)\n",
" plt.scatter(X[:, 2].tolist(), X[:, 1].tolist(),\n",
" c=Y.tolist(),\n",
"            s=100, cmap='prism');\n",
"\n",
" theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1],1)\n",
" thetaBest, err = SGD(h, J, dJ, theta, X, Y, alpha=alpha, adaGrad=adaGrad, maxEpochs=maxEpochs, batchSize=100, \n",
" logError=True, validate=validate, valStep=1, lamb=lamb)\n",
"\n",
" xx, yy = np.meshgrid(np.arange(-1.5, 1.5, 0.02),\n",
" np.arange(-1.5, 1.5, 0.02))\n",
" l = len(xx.ravel())\n",
" C = powerme(xx.reshape(l, 1),yy.reshape(l, 1), n)\n",
" z = classifyBi(thetaBest, C).reshape(int(np.sqrt(l)), int(np.sqrt(l)))\n",
"\n",
"    plt.contour(xx, yy, z, levels=[0.5], linewidths=3);\n",
" plt.ylim(-1,1.2);\n",
" plt.xlim(-1,1.2);\n",
" plt.subplot(122)\n",
" plt.plot(err[0],err[1], lw=3, label=\"Training error\")\n",
" if validate > 0:\n",
" plt.plot(err[2],err[3], lw=3, label=\"Validation error\");\n",
" plt.legend()\n",
" plt.ylim(0.2,0.8);"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"<ipython-input-11-7e1d5e247279>:5: RuntimeWarning: overflow encountered in exp\n",
" y = 1.0/(1.0 + np.exp(-x))\n",
"<ipython-input-37-f0220c89a5e3>:19: UserWarning: The following kwargs were not used by contour: 'lw'\n",
" plt.contour(xx, yy, z, levels=[0.5], lw=3);\n",
"No handles with labels found to put in legend.\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 1152x576 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"draw_regularization_example(X, Y)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Regularization\n",
"\n",
"Regularization is a method of preventing overfitting by modifying the cost function.\n",
"\n",
"A special term (the **regularization term**, shown in red in the formulas below) is added to the cost function. It acts as a “penalty” for extreme values of the parameters $\\theta$.\n",
"\n",
"As a result, parameter vectors $\\theta$ with smaller values automatically get a lower cost and are therefore preferred.\n",
"\n",
"How strong should the regularization be? We decide this by choosing an appropriate value of the **regularization parameter** $\\lambda$."
]
},
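{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A small worked example of the penalty, with made-up numbers chosen only for illustration. Take $m = 10$ examples, $\\lambda = 1$ and $\\theta = (\\theta_0, \\theta_1, \\theta_2) = (5, 3, -4)$, and use the penalty term $\\dfrac{\\lambda}{2m} \\sum_{j=1}^{n} \\theta_j^2$ from the formulas below. The sum starts at $j = 1$, so the intercept $\\theta_0$ is not penalized:\n",
"\n",
"$$\n",
"\\dfrac{\\lambda}{2m} \\sum_{j=1}^{2} \\theta_j^2 \\, = \\, \\dfrac{1}{20} \\left( 3^2 + (-4)^2 \\right) \\, = \\, \\dfrac{25}{20} \\, = \\, 1.25\n",
"$$\n",
"\n",
"Halving $\\theta_1$ and $\\theta_2$ would cut the penalty by a factor of four, to $\\dfrac{1}{20} \\left( 1.5^2 + 2^2 \\right) = 0.3125$, which is why smaller parameter vectors are preferred."
]
},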
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for linear regression: cost function\n",
"\n",
"$$\n",
"J(\\theta) \\, = \\, \\dfrac{1}{2m} \\left( \\displaystyle\\sum_{i=1}^{m} \\left( h_\\theta(x^{(i)}) - y^{(i)} \\right)^2 \\color{red}{ + \\lambda \\displaystyle\\sum_{j=1}^{n} \\theta^2_j } \\right)\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"* $\\lambda$: the regularization parameter\n",
"* if $\\lambda$ is too small, the penalty has little effect and the model overfits\n",
"* if $\\lambda$ is too large, the parameters are pushed too close to zero and the model underfits"
]
},
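{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The regularized cost above and the effect of $\\lambda$ described in these bullet points can be illustrated with a short NumPy sketch. It is a self-contained illustration, not the notebook's own implementation: it uses plain 1-D NumPy arrays instead of `np.matrix`, and the helper names `cost_linear_reg`, `grad_linear_reg` and `fit_ridge` are hypothetical. Setting the gradient (given in the next formulas) to zero yields the closed-form solution $\\theta = (X^T X + \\lambda L)^{-1} X^T y$, where $L$ is the identity matrix with $L_{00} = 0$, so that $\\theta_0$ is not penalized; the sketch computes it with `np.linalg.solve` rather than an explicit inverse.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def cost_linear_reg(theta, X, y, lamb=0.0):\n",
"    # (1/2m) * (sum of squared residuals + lambda * sum of theta_j^2 for j >= 1)\n",
"    m = len(y)\n",
"    residuals = X @ theta - y\n",
"    return (residuals @ residuals + lamb * np.sum(theta[1:] ** 2)) / (2 * m)\n",
"\n",
"def grad_linear_reg(theta, X, y, lamb=0.0):\n",
"    # (1/m) * X^T (X theta - y), plus (lambda/m) * theta_j for j >= 1 only\n",
"    m = len(y)\n",
"    g = X.T @ (X @ theta - y) / m\n",
"    g[1:] += lamb / m * theta[1:]\n",
"    return g\n",
"\n",
"def fit_ridge(X, y, lamb=0.0):\n",
"    # Solve (X^T X + lambda * L) theta = X^T y, with L = I except L[0, 0] = 0\n",
"    L = np.eye(X.shape[1])\n",
"    L[0, 0] = 0.0\n",
"    return np.linalg.solve(X.T @ X + lamb * L, X.T @ y)\n",
"\n",
"# Toy data: the first column of X is the constant feature x_0 = 1\n",
"X = np.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]])\n",
"y = np.array([1.0, 3.0, 5.0])\n",
"for lamb in [0.0, 1.0, 10.0]:\n",
"    theta = fit_ridge(X, y, lamb)\n",
"    print(lamb, theta, cost_linear_reg(theta, X, y, lamb))\n",
"```\n",
"\n",
"On this toy data (points lying exactly on $y = 1 + 2x$) the slope $\\theta_1$ shrinks as $\\lambda$ grows: it equals $2$ for $\\lambda = 0$, $4/3$ for $\\lambda = 1$ and $1/3$ for $\\lambda = 10$, while the unpenalized intercept $\\theta_0$ grows to compensate."
]
},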
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for linear regression: gradient\n",
"\n",
"$$\\small\n",
"\\begin{array}{llll}\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_0} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_0 & \\textrm{for $j = 0$ }\\\\\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_j} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_j \\color{red}{+ \\dfrac{\\lambda}{m}\\theta_j} & \\textrm{for $j = 1, 2, \\ldots, n $} \\\\\n",
"\\end{array} \n",
"$$"
]
},
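{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Substituting this gradient into the gradient descent update with learning rate $\\alpha$ gives, for $j = 1, 2, \\ldots, n$, the following equivalent form (a standard algebraic rearrangement, shown for illustration):\n",
"\n",
"$$\n",
"\\theta_j \\, := \\, \\theta_j - \\alpha \\left( \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_j + \\dfrac{\\lambda}{m}\\theta_j \\right) \\, = \\, \\theta_j \\left( 1 - \\alpha \\dfrac{\\lambda}{m} \\right) - \\dfrac{\\alpha}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_j\n",
"$$\n",
"\n",
"When $\\alpha \\lambda / m$ is small and positive, the factor $1 - \\alpha \\lambda / m$ is slightly smaller than 1, so each step first shrinks $\\theta_j$ towards zero and then applies the usual, unregularized update. The update for $\\theta_0$ is unchanged."
]
},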
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for logistic regression: cost function\n",
"\n",
"$$\n",
"\\begin{array}{rcl}\n",
"J(\\theta) & = & -\\dfrac{1}{m} \\left( \\displaystyle\\sum_{i=1}^{m} y^{(i)} \\log h_\\theta(x^{(i)}) + \\left( 1-y^{(i)} \\right) \\log \\left( 1-h_\\theta(x^{(i)}) \\right) \\right) \\\\\n",
"& & \\color{red}{ + \\dfrac{\\lambda}{2m} \\displaystyle\\sum_{j=1}^{n} \\theta^2_j } \\\\\n",
"\\end{array}\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for logistic regression: gradient\n",
"\n",
"$$\\small\n",
"\\begin{array}{llll}\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_0} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_0 & \\textrm{for $j = 0$ }\\\\\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_j} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_j \\color{red}{+ \\dfrac{\\lambda}{m}\\theta_j} & \\textrm{for $j = 1, 2, \\ldots, n $} \\\\\n",
"\\end{array} \n",
"$$"
]
},
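{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"One way to check that a regularized gradient like the one above is implemented correctly is to compare it with a central finite-difference approximation of the cost. The sketch below does this on random data. It is a self-contained illustration with hypothetical helper names (`cost_logistic_reg`, `grad_logistic_reg`, `numerical_grad`) operating on plain NumPy arrays; it is not the notebook's `J_`/`dJ_` functions.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sigmoid(z):\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"def cost_logistic_reg(theta, X, y, lamb=0.0):\n",
"    # Cross-entropy cost plus (lambda / 2m) * sum of theta_j^2 for j >= 1\n",
"    m = len(y)\n",
"    p = sigmoid(X @ theta)\n",
"    data_term = -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))\n",
"    return data_term + lamb / (2 * m) * np.sum(theta[1:] ** 2)\n",
"\n",
"def grad_logistic_reg(theta, X, y, lamb=0.0):\n",
"    # (1/m) * X^T (h - y), plus (lambda / m) * theta_j for j >= 1 only\n",
"    m = len(y)\n",
"    g = X.T @ (sigmoid(X @ theta) - y) / m\n",
"    g[1:] += lamb / m * theta[1:]\n",
"    return g\n",
"\n",
"def numerical_grad(f, theta, eps=1e-6):\n",
"    # Central differences: (f(theta + eps*e_j) - f(theta - eps*e_j)) / (2*eps)\n",
"    g = np.zeros_like(theta)\n",
"    for j in range(len(theta)):\n",
"        step = np.zeros_like(theta)\n",
"        step[j] = eps\n",
"        g[j] = (f(theta + step) - f(theta - step)) / (2 * eps)\n",
"    return g\n",
"\n",
"rng = np.random.default_rng(0)\n",
"X = np.column_stack([np.ones(20), rng.normal(size=(20, 2))])\n",
"y = (rng.random(20) > 0.5).astype(float)\n",
"theta = rng.normal(size=3)\n",
"lamb = 0.5\n",
"\n",
"analytic = grad_logistic_reg(theta, X, y, lamb)\n",
"numeric = numerical_grad(lambda t: cost_logistic_reg(t, X, y, lamb), theta)\n",
"max_diff = np.max(np.abs(analytic - numeric))\n",
"print(max_diff)\n",
"```\n",
"\n",
"A maximum difference many orders of magnitude smaller than the gradient entries suggests that the analytic gradient matches the cost. A mismatch, for example a gradient that also penalizes $\\theta_0$ while the cost does not, would show up as a clearly larger difference in the affected component."
]
},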
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Implementing regularization"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"def J_(h,theta,X,y,lamb=0):\n",
" \"\"\"Regularized cost function (logistic regression)\"\"\"\n",
" m = float(len(y))\n",
" f = h(theta, X, eps=10**-7)\n",
" j = 1.0/m \\\n",
" * -np.sum(np.multiply(y, np.log(f)) + \n",
" np.multiply(1 - y, np.log(1 - f)), axis=0) \\\n",
" + lamb/(2*m) * np.sum(np.power(theta[1:] ,2))\n",
" return j\n",
"\n",
"def dJ_(h,theta,X,y,lamb=0):\n",
" \"\"\"Gradient of the regularized cost function\"\"\"\n",
" m = float(y.shape[0])\n",
" g = 1.0/m * (X.T*(h(theta,X)-y))\n",
" g[1:] += lamb/m * theta[1:]\n",
" return g"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"slider_lambda = widgets.FloatSlider(min=0.0, max=0.5, step=0.005, value=0.01, description=r'$\\lambda$', layout=widgets.Layout(width='300px'))\n",
"\n",
"def slide_regularization_example_2(lamb):\n",
" draw_regularization_example(X, Y, lamb=lamb)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "16432480b157427989a9ae11e70ecf69",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=0.01, description='$\\\\lambda$', max=0.5, step=0.005), Button(descripti…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.slide_regularization_example_2(lamb)>"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact_manual(slide_regularization_example_2, lamb=slider_lambda)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def cost_lambda_fun(lamb):\n",
" \"\"\"Cost as a function of the regularization parameter lambda\"\"\"\n",
" theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1],1)\n",
" thetaBest, err = SGD(h, J, dJ, theta, X, Y, alpha=1, adaGrad=True, maxEpochs=2500, batchSize=100, \n",
" logError=True, validate=0.25, valStep=1, lamb=lamb)\n",
" return err[1][-1], err[3][-1]\n",
"\n",
"def plot_cost_lambda():\n",
" \"\"\"Plot the cost as a function of the regularization parameter lambda\"\"\"\n",
" plt.figure(figsize=(16,8))\n",
" ax = plt.subplot(111)\n",
" Lambda = np.arange(0.0, 1.0, 0.01)\n",
" Costs = [cost_lambda_fun(lamb) for lamb in Lambda]\n",
" CostTrain = [cost[0] for cost in Costs]\n",
" CostCV = [cost[1] for cost in Costs]\n",
" plt.plot(Lambda, CostTrain, lw=3, label='training error')\n",
" plt.plot(Lambda, CostCV, lw=3, label='validation error')\n",
" ax.set_xlabel(r'$\\lambda$')\n",
" ax.set_ylabel(u'cost')\n",
" plt.legend()\n",
" plt.ylim(0.2,0.8)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
1KJ0+enPE7brrpJj311FO65ZZbtHfvXu3Zs0eSm/3Mz89XcXGx2tra9POf/1w333yzJKmwsFADAwNnLSG+6aabdM8
"text/plain": [
"<Figure size 1152x576 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_cost_lambda()"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 5.4. Learning curve"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"* A learning curve lets us check whether learning is proceeding correctly.\n",
"* A learning curve is a plot of the cost function value against the size of the training set.\n",
"* As the training set grows, the cost on the training set typically increases (it becomes harder to fit every example exactly).\n",
"* As the training set grows, the cost on the validation set typically decreases (the model generalizes better)."
]
},
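{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The procedure behind a learning curve can be sketched as follows: train on the first $m$ examples for increasing $m$, and record the cost on those $m$ training examples and on a fixed validation set. This hedged sketch uses closed-form regularized linear regression on synthetic data (a noisy sine fitted with a cubic polynomial) instead of the notebook's `SGD` and logistic regression; the data and the helper names `fit_ridge`, `half_mse`, `learning_curve` and `make_data` are assumptions made for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def fit_ridge(X, y, lamb):\n",
"    # Closed-form regularized least squares; column 0 (intercept) is not penalized\n",
"    L = np.eye(X.shape[1])\n",
"    L[0, 0] = 0.0\n",
"    return np.linalg.solve(X.T @ X + lamb * L, X.T @ y)\n",
"\n",
"def half_mse(theta, X, y):\n",
"    # Unregularized cost (1/2m) * sum of squared residuals\n",
"    r = X @ theta - y\n",
"    return r @ r / (2 * len(y))\n",
"\n",
"def learning_curve(X_train, y_train, X_val, y_val, sizes, lamb=0.01):\n",
"    train_costs, val_costs = [], []\n",
"    for m in sizes:\n",
"        theta = fit_ridge(X_train[:m], y_train[:m], lamb)\n",
"        train_costs.append(half_mse(theta, X_train[:m], y_train[:m]))\n",
"        val_costs.append(half_mse(theta, X_val, y_val))\n",
"    return train_costs, val_costs\n",
"\n",
"rng = np.random.default_rng(42)\n",
"\n",
"def make_data(k):\n",
"    # Features 1, x, x^2, x^3 for a noisy sine target (synthetic, for illustration)\n",
"    x = rng.uniform(-1.0, 1.0, size=k)\n",
"    X = np.column_stack([x ** p for p in range(4)])\n",
"    y = np.sin(3 * x) + rng.normal(scale=0.1, size=k)\n",
"    return X, y\n",
"\n",
"X_train, y_train = make_data(200)\n",
"X_val, y_val = make_data(100)\n",
"sizes = [5, 10, 20, 50, 100, 200]\n",
"train_costs, val_costs = learning_curve(X_train, y_train, X_val, y_val, sizes)\n",
"for m, tr, va in zip(sizes, train_costs, val_costs):\n",
"    print(f'm={m:4d}  train={tr:.4f}  val={va:.4f}')\n",
"```\n",
"\n",
"Plotting `train_costs` and `val_costs` against `sizes` (for example with `plt.plot`) gives the learning curve. As in the notebook's `SGD`, the costs are reported without the penalty term; individual points may deviate from the general trend described above because of noise in the sample."
]
},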
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def cost_trainsetsize_fun(m):\n",
" \"\"\"Cost as a function of the training set size\"\"\"\n",
" theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1],1)\n",
" thetaBest, err = SGD(h, J, dJ, theta, X, Y, alpha=1, adaGrad=True, maxEpochs=2500, batchSize=100, \n",
" logError=True, validate=0.25, valStep=1, lamb=0.01, trainsetsize=m)\n",
" return err[1][-1], err[3][-1]\n",
"\n",
"def plot_learning_curve():\n",
" \"\"\"Plot the learning curve\"\"\"\n",
" plt.figure(figsize=(16,8))\n",
" ax = plt.subplot(111)\n",
" M = np.arange(0.3, 1.0, 0.05)\n",
" Costs = [cost_trainsetsize_fun(m) for m in M]\n",
" CostTrain = [cost[0] for cost in Costs]\n",
" CostCV = [cost[1] for cost in Costs]\n",
" plt.plot(M, CostTrain, lw=3, label='training error')\n",
" plt.plot(M, CostCV, lw=3, label='validation error')\n",
" ax.set_xlabel(u'trainset size')\n",
" ax.set_ylabel(u'cost')\n",
" plt.legend()"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Learning curves, bias and variance\n",
"\n",
"Plotting the learning curve helps diagnose overfitting and underfitting:\n",
"\n",
"<img width=\"100%\" src=\"learning-curves.png\"/>\n",
"\n",
"Source: http://www.ritchieng.com/machinelearning-learning-curve"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
UAnHrqqVSuXHnLufz1pEmTGDlyJADnnXceN96YeMywf//+SQmvkNorsIRh+Drw+lZj/8p3vBDomcoelASrF0f7um5
"text/plain": [
"<Figure size 1152x576 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_learning_curve()"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
},
"livereveal": {
"start_slideshow_at": "selected",
"theme": "white"
}
},
"nbformat": 4,
"nbformat_minor": 4
}