{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Machine Learning – Applications\n",
"# 5. Polynomial regression. The problem of overfitting"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 5.1. Polynomial regression"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Introduction: choosing features"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Suppose our task is to predict the price of a rectangular plot of land.\n",
"\n",
"Which features should we choose?"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"We can choose two features:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * $x_1$ – plot width, $x_2$ – plot length:\n",
"$$ h_{\\theta}(\\vec{x}) = \\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"...or just one:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * $x_1$ – plot area:\n",
"$$ h_{\\theta}(\\vec{x}) = \\theta_0 + \\theta_1 x_1 $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"**Conclusion:** we can create new features from existing ones by applying various mathematical operations to them."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Note also that the feature “plot area” is obtained by multiplying two other features: the plot's length and its width."
]
},
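{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A minimal plain-Python sketch of the two feature choices above. The function name `plot_features` and the sample dimensions are hypothetical, chosen only for illustration:\n",
"\n",
"```python\n",
"def plot_features(width, length):\n",
"    # option 1: two separate features, x_1 = width and x_2 = length\n",
"    two_features = [width, length]\n",
"    # option 2: a single derived feature, x_1 = area = width * length\n",
"    one_feature = [width * length]\n",
"    return two_features, one_feature\n",
"\n",
"two, one = plot_features(20.0, 30.0)\n",
"print(two, one)  # [20.0, 30.0] [600.0]\n",
"```"
]
},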
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Polynomial regression"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"In polynomial regression we use features created as powers of the original features."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Useful imports\n",
"\n",
"import ipywidgets as widgets\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Helper functions\n",
"\n",
"def cost(theta, X, y):\n",
"    \"\"\"Cost function J(theta) (matrix version)\"\"\"\n",
"    m = len(y)\n",
"    J = 1.0 / (2.0 * m) * ((X * theta - y).T * (X * theta - y))\n",
"    return J.item()\n",
"\n",
"def gradient(theta, X, y):\n",
"    \"\"\"Gradient of the cost function (matrix version)\"\"\"\n",
"    return 1.0 / len(y) * (X.T * (X * theta - y))\n",
"\n",
"def gradient_descent(fJ, fdJ, theta, X, y, alpha=0.1, eps=10**-5):\n",
"    \"\"\"Gradient descent (matrix version)\"\"\"\n",
"    current_cost = fJ(theta, X, y)\n",
"    logs = [[current_cost, theta]]\n",
"    while True:\n",
"        theta = theta - alpha * fdJ(theta, X, y)\n",
"        current_cost, prev_cost = fJ(theta, X, y), current_cost\n",
"        # stop if the cost explodes (divergence)\n",
"        if abs(prev_cost - current_cost) > 10**15:\n",
"            print('Algorithm does not converge!')\n",
"            break\n",
"        # stop once the cost no longer changes noticeably\n",
"        if abs(prev_cost - current_cost) <= eps:\n",
"            break\n",
"        logs.append([current_cost, theta])\n",
"    return theta, logs\n",
"\n",
"def plot_data(X, y, xlabel, ylabel):\n",
"    \"\"\"Plot the data (matrix version)\"\"\"\n",
"    fig = plt.figure(figsize=(16*.6, 9*.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"    ax.scatter([X[:, 1]], [y], c='r', s=50, label='Data')\n",
"\n",
"    ax.set_xlabel(xlabel)\n",
"    ax.set_ylabel(ylabel)\n",
"    ax.margins(.05, .05)\n",
"    plt.ylim(y.min() - 1, y.max() + 1)\n",
"    plt.xlim(np.min(X[:, 1]) - 1, np.max(X[:, 1]) + 1)\n",
"    return fig\n",
"\n",
"def plot_fun(fig, fun, X):\n",
"    \"\"\"Plot the function `fun`\"\"\"\n",
"    ax = fig.axes[0]\n",
"    x0 = np.min(X[:, 1]) - 1.0\n",
"    x1 = np.max(X[:, 1]) + 1.0\n",
"    Arg = np.arange(x0, x1, 0.1)\n",
"    Val = fun(Arg)\n",
"    return ax.plot(Arg, Val, linewidth=2)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Load the data (flats) with the pandas library\n",
"\n",
"alldata = pandas.read_csv('data_flats.tsv', header=0, sep='\\t',\n",
"                          usecols=['price', 'rooms', 'sqrMetres'])\n",
"data = np.matrix(alldata[['sqrMetres', 'price']])\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"# scale the feature and each of its powers so that the maximum is 1\n",
"Xn = data[:, 0:n]\n",
"Xn /= np.amax(Xn, axis=0)\n",
"Xn2 = np.power(Xn, 2)\n",
"Xn2 /= np.amax(Xn2, axis=0)\n",
"Xn3 = np.power(Xn, 3)\n",
"Xn3 /= np.amax(Xn3, axis=0)\n",
"\n",
"# design matrices: [1, x], [1, x, x^2], [1, x, x^2, x^3]\n",
"X = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n + 1)\n",
"X2 = np.matrix(np.concatenate((np.ones((m, 1)), Xn, Xn2), axis=1)).reshape(m, 2 * n + 1)\n",
"X3 = np.matrix(np.concatenate((np.ones((m, 1)), Xn, Xn2, Xn3), axis=1)).reshape(m, 3 * n + 1)\n",
"y = np.matrix(data[:, -1]).reshape(m, 1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Postać ogólna regresji wielomianowej:\n",
"\n",
"$$ h_{\\theta}(x) = \\sum_{i=0}^{n} \\theta_i x^i $$"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [],
"source": [
"# Polynomial regression function\n",
"\n",
"def h_poly(Theta, x):\n",
"    \"\"\"Polynomial function: the sum of theta_i * x^i\"\"\"\n",
"    return sum(theta * np.power(x, i) for i, theta in enumerate(Theta.tolist()))\n",
"\n",
"def polynomial_regression(theta):\n",
"    \"\"\"Return the fitted polynomial as a function of x\"\"\"\n",
"    return lambda x: h_poly(theta, x)"
]
},
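{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The `h_poly` function above works on NumPy matrices. For intuition, here is a plain-Python sketch of the same formula (the names `h_poly_scalar` and `h_poly_horner` are hypothetical). Horner's scheme, a standard alternative, gives the same value with fewer multiplications:\n",
"\n",
"```python\n",
"def h_poly_scalar(theta, x):\n",
"    # direct sum: theta_0 + theta_1 x + ... + theta_n x^n\n",
"    return sum(t * x ** i for i, t in enumerate(theta))\n",
"\n",
"def h_poly_horner(theta, x):\n",
"    # Horner's scheme: (...((theta_n x + theta_(n-1)) x + ...) x + theta_0\n",
"    result = 0.0\n",
"    for t in reversed(theta):\n",
"        result = result * x + t\n",
"    return result\n",
"\n",
"theta = [1.0, -2.0, 3.0]  # 1 - 2x + 3x^2\n",
"print(h_poly_scalar(theta, 2.0), h_poly_horner(theta, 2.0))  # 9.0 9.0\n",
"```"
]
},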
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The simplest non-linear case of polynomial regression is the quadratic function:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Quadratic function:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 $$"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x27dd04c6250>]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X2, y, xlabel='x', ylabel='y')\n",
"theta_start = np.matrix([0, 0, 0]).reshape(3, 1)\n",
"theta, logs = gradient_descent(cost, gradient, theta_start, X2, y)\n",
"plot_fun(fig, polynomial_regression(theta), X)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Another special case of polynomial regression is the cubic function:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Cubic function:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 + \\theta_3 x^3 $$"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 397521.22456017]\n",
" [-841359.33647153]\n",
" [2253763.58150567]\n",
" [-244046.90860749]]\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X3, y, xlabel='x', ylabel='y')\n",
"theta_start = np.matrix([0, 0, 0, 0]).reshape(4, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X3, y)\n",
"plot_fun(fig, polynomial_regression(theta), X)\n",
"\n",
"print(theta)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Polynomial regression can be treated as a special case of multivariate linear regression:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 + \\theta_3 x^3 $$\n",
"$$ x_0 = 1, \\quad x_1 = x, \\quad x_2 = x^2, \\quad x_3 = x^3, \\quad \\vec{x} = \\left[ \\begin{array}{c} x_0 \\\\ x_1 \\\\ x_2 \\\\ x_3 \\end{array} \\right] $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"(In this case the successive features are the successive powers of the variable $x$.)"
]
},
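{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A minimal sketch of this view, assuming plain Python lists instead of NumPy matrices (`polynomial_features` and `predict` are hypothetical names): expand each $x$ into the row $[1, x, x^2, x^3]$, then apply the ordinary linear hypothesis $\\theta^T \\vec{x}$:\n",
"\n",
"```python\n",
"def polynomial_features(xs, degree):\n",
"    # each row is [x^0, x^1, ..., x^degree]; the constant x^0 = 1 plays the role of x_0\n",
"    return [[x ** k for k in range(degree + 1)] for x in xs]\n",
"\n",
"def predict(theta, row):\n",
"    # the ordinary linear-regression hypothesis theta^T x applied to an expanded row\n",
"    return sum(t * v for t, v in zip(theta, row))\n",
"\n",
"rows = polynomial_features([1.0, 2.0], degree=3)\n",
"print(rows)  # [[1.0, 1.0, 1.0, 1.0], [1.0, 2.0, 4.0, 8.0]]\n",
"print(predict([1.0, 0.0, 0.0, 1.0], rows[1]))  # 9.0, i.e. 1 + 2^3\n",
"```"
]
},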
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Practical note: feature normalization, especially scaling, is very useful here, because the powers $x^2$ and $x^3$ can differ from $x$ by orders of magnitude!"
]
},
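{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A small plain-Python sketch of why scaling matters, with made-up sample values: the cubes of a feature span orders of magnitude, which is why the notebook divides each power of the feature by its maximum (`Xn /= np.amax(Xn, axis=0)` and so on). The helper `scale_by_max` is a hypothetical stand-in for that step:\n",
"\n",
"```python\n",
"def scale_by_max(column):\n",
"    # divide every value by the largest absolute value, so the column fits in [-1, 1]\n",
"    m = max(abs(v) for v in column)\n",
"    return [v / m for v in column] if m else list(column)\n",
"\n",
"areas = [30.0, 60.0, 120.0]\n",
"cubes = [a ** 3 for a in areas]\n",
"print(cubes)                # [27000.0, 216000.0, 1728000.0]\n",
"print(scale_by_max(cubes))  # [0.015625, 0.125, 1.0]\n",
"```"
]
},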
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"To create “derived” features we can use not only exponentiation but also other mathematical operations, e.g.:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 \\sqrt{x} $$\n",
"$$ x_0 = 1, \\quad x_1 = x, \\quad x_2 = \\sqrt{x}, \\quad \\vec{x} = \\left[ \\begin{array}{c} x_0 \\\\ x_1 \\\\ x_2 \\end{array} \\right] $$"
]
},
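{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The same idea with a square-root feature, as a plain-Python sketch (`sqrt_features` is a hypothetical name; it assumes $x \\ge 0$):\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def sqrt_features(x):\n",
"    # feature vector [x_0, x_1, x_2] = [1, x, sqrt(x)]; math.sqrt raises ValueError for x < 0\n",
"    return [1.0, x, math.sqrt(x)]\n",
"\n",
"print(sqrt_features(4.0))  # [1.0, 4.0, 2.0]\n",
"```"
]
},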
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"So which features should we choose? Ideally, ones tailored to the specific problem."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Polynomial logistic regression\n",
"\n",
"Similar feature transformations can also be applied to logistic regression."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def powerme(x1, x2, n):\n",
"    \"\"\"Generate all products x1^i * x2^j with i + j <= n (powers of x1 and x2 and their products)\"\"\"\n",
"    X = []\n",
"    for m in range(n + 1):\n",
"        for i in range(m + 1):\n",
"            X.append(np.multiply(np.power(x1, i), np.power(x2, (m - i))))\n",
"    return np.hstack(X)"
]
},
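{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"To see which columns `powerme` produces and in what order, one can list the exponent pairs of $x_1$ and $x_2$ visited by its two loops. The helper `monomial_exponents` is hypothetical and only mirrors those loops without NumPy:\n",
"\n",
"```python\n",
"def monomial_exponents(n):\n",
"    # same loop order as powerme: total degree m, then the power i of x1 (x2 gets m - i)\n",
"    return [(i, m - i) for m in range(n + 1) for i in range(m + 1)]\n",
"\n",
"print(monomial_exponents(2))\n",
"# [(0, 0), (0, 1), (1, 0), (0, 2), (1, 1), (2, 0)]\n",
"```\n",
"\n",
"So for $n = 2$ the columns are $1, x_2, x_1, x_2^2, x_1 x_2, x_1^2$, and in general there are $(n+1)(n+2)/2$ of them."
]
},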
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [
{
"data": {
"text/plain": [
"matrix([[ 1. , 0.36596696, -0.11214686],\n",
" [ 0. , 0.4945305 , 0.47110656],\n",
" [ 0. , 0.70290604, -0.92257983],\n",
" [ 0. , 0.46658862, -0.62269739],\n",
" [ 0. , 0.87939462, -0.11408015],\n",
" [ 0. , -0.331185 , 0.84447667],\n",
" [ 0. , -0.54351701, 0.8851383 ],\n",
" [ 0. , 0.91979241, 0.41607012],\n",
" [ 0. , 0.28011742, 0.61431157],\n",
" [ 0. , 0.94754363, -0.78307311]])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the data\n",
"import pandas\n",
"import numpy as np\n",
"\n",
"alldata = pandas.read_csv('polynomial_logistic.tsv', sep='\\t')\n",
"data = np.matrix(alldata)\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data[:, 1:]\n",
"\n",
"Xpl = powerme(data[:, 1], data[:, 2], n)\n",
"Ypl = np.matrix(data[:, 0]).reshape(m, 1)\n",
"\n",
"data[:10]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def plot_data_for_classification(X, Y, xlabel, ylabel):\n",
"    \"\"\"Plot two-class data (matrix version), using columns 1 and 2 of X\"\"\"\n",
"    fig = plt.figure(figsize=(16*.6, 9*.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"    X = X.tolist()\n",
"    Y = Y.tolist()\n",
"    # split the points by class label (0 or 1)\n",
"    X1n = [x[1] for x, y in zip(X, Y) if y[0] == 0]\n",
"    X1p = [x[1] for x, y in zip(X, Y) if y[0] == 1]\n",
"    X2n = [x[2] for x, y in zip(X, Y) if y[0] == 0]\n",
"    X2p = [x[2] for x, y in zip(X, Y) if y[0] == 1]\n",
"    ax.scatter(X1n, X2n, c='r', marker='x', s=50, label='Class 0')\n",
"    ax.scatter(X1p, X2p, c='g', marker='o', s=50, label='Class 1')\n",
"\n",
"    ax.set_xlabel(xlabel)\n",
"    ax.set_ylabel(ylabel)\n",
"    ax.margins(.05, .05)\n",
"    return fig"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Suppose we have the following data and want to perform two-class classification with these classes:\n",
" * red crosses\n",
" * green circles"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r'$x_1$', ylabel=r'$x_2$')"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Proposed hypothesis:\n",
"\n",
"$$ h_\\theta(x) = g(\\theta^T x) = g(\\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 + \\theta_3 x_3 + \\theta_4 x_4 + \\theta_5 x_5) \\; , $$\n",
"\n",
"where $g$ is the logistic function, $x_3 = x_1^2$, $x_4 = x_2^2$, $x_5 = x_1 x_2$."
]
},
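{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Why the squared features help: with them, the boundary $\\theta^T x = 0$ can be a curve rather than a straight line. A plain-Python sketch with made-up parameters chosen so that the boundary is the circle $x_1^2 + x_2^2 = 0.5$ (the feature order follows the formula above, which differs from the column order produced by `powerme`; `hypothesis` is a hypothetical name):\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def sigmoid(z):\n",
"    return 1.0 / (1.0 + math.exp(-z))\n",
"\n",
"def hypothesis(theta, x1, x2):\n",
"    # x = [1, x1, x2, x3, x4, x5] with x3 = x1^2, x4 = x2^2, x5 = x1 * x2\n",
"    x = [1.0, x1, x2, x1 ** 2, x2 ** 2, x1 * x2]\n",
"    return sigmoid(sum(t * v for t, v in zip(theta, x)))\n",
"\n",
"# made-up parameters: theta^T x = 0.5 - x1^2 - x2^2, so the boundary is a circle\n",
"theta = [0.5, 0.0, 0.0, -1.0, -1.0, 0.0]\n",
"print(hypothesis(theta, 0.0, 0.0) > 0.5, hypothesis(theta, 1.0, 1.0) > 0.5)  # True False\n",
"```"
]
},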
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def safeSigmoid(x, eps=0):\n",
"    \"\"\"Sigmoid function modified so that its values\n",
"    always stay at least eps away from the asymptotes\n",
"    \"\"\"\n",
"    y = 1.0/(1.0 + np.exp(-x))\n",
"    if eps > 0:\n",
"        y[y < eps] = eps\n",
"        y[y > 1 - eps] = 1 - eps\n",
"    return y\n",
"\n",
"def h(theta, X, eps=0.0):\n",
"    \"\"\"Hypothesis function\"\"\"\n",
"    return safeSigmoid(X*theta, eps)\n",
"\n",
"def J(h, theta, X, y, lamb=0):\n",
"    \"\"\"Cost function (cross-entropy, with optional L2 regularization lamb)\"\"\"\n",
"    m = len(y)\n",
"    f = h(theta, X, eps=10**-7)\n",
"    j = -np.sum(np.multiply(y, np.log(f)) +\n",
"                np.multiply(1 - y, np.log(1 - f)), axis=0)/m\n",
"    if lamb > 0:\n",
"        j += lamb/(2*m) * np.sum(np.power(theta[1:], 2))\n",
"    return j\n",
"\n",
"def dJ(h, theta, X, y, lamb=0):\n",
"    \"\"\"Gradient of the cost function\"\"\"\n",
"    g = 1.0/y.shape[0]*(X.T*(h(theta, X)-y))\n",
"    if lamb > 0:\n",
"        g[1:] += lamb/float(y.shape[0]) * theta[1:]\n",
"    return g\n",
"\n",
"def classifyBi(theta, X):\n",
"    \"\"\"Decision function: returns the probability of class 1\"\"\"\n",
"    prob = h(theta, X)\n",
"    return prob"
]
},
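{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The role of `eps` in `safeSigmoid` can be seen in a scalar plain-Python sketch (the names `safe_sigmoid` and `cross_entropy` are hypothetical): for a large argument the sigmoid rounds to exactly 1.0 in floating point, and without clipping the term $\\log(1 - f)$ in the cost would fail with a math domain error:\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def safe_sigmoid(z, eps=0.0):\n",
"    # clip the output into [eps, 1 - eps] so that log() never receives 0\n",
"    y = 1.0 / (1.0 + math.exp(-z))\n",
"    return min(max(y, eps), 1.0 - eps) if eps > 0 else y\n",
"\n",
"def cross_entropy(probs, labels):\n",
"    # J = -(1/m) * sum(y log p + (1 - y) log(1 - p))\n",
"    m = len(labels)\n",
"    return -sum(y * math.log(p) + (1 - y) * math.log(1 - p)\n",
"                for p, y in zip(probs, labels)) / m\n",
"\n",
"p = safe_sigmoid(50.0, eps=1e-7)  # the unclipped value rounds to 1.0 in floating point\n",
"print(cross_entropy([p], [0]))    # about 16.1: large but finite\n",
"```"
]
},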
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def GD(h, fJ, fdJ, theta, X, y, alpha=0.01, eps=10**-3, maxSteps=10000):\n",
"    \"\"\"Gradient descent for logistic regression\"\"\"\n",
"    errorCurr = fJ(h, theta, X, y)\n",
"    errors = [[errorCurr, theta]]\n",
"    while True:\n",
"        # compute the new theta\n",
"        theta = theta - alpha * fdJ(h, theta, X, y)\n",
"        # report the error level\n",
"        errorCurr, errorPrev = fJ(h, theta, X, y), errorCurr\n",
"        # stopping criteria\n",
"        if abs(errorPrev - errorCurr) <= eps:\n",
"            break\n",
"        if len(errors) > maxSteps:\n",
"            break\n",
"        errors.append([errorCurr, theta])\n",
"    return theta, errors"
]
},
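{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A stripped-down, plain-Python sketch of the same loop for a single feature plus a bias term (the name `gd_logistic` and the toy data are hypothetical). It shows the two kinds of stopping criteria used by `GD`: the change of the cost falling below `eps`, and a limit on the number of steps:\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def gd_logistic(xs, ys, alpha=0.5, eps=1e-7, max_steps=10000):\n",
"    # hypothesis h(x) = sigmoid(t0 + t1 * x): one feature plus a bias term\n",
"    m = len(ys)\n",
"\n",
"    def sigmoid(z):\n",
"        return 1.0 / (1.0 + math.exp(-z))\n",
"\n",
"    def cost(t0, t1):\n",
"        # cross-entropy cost, with probabilities clipped to keep log() finite\n",
"        total = 0.0\n",
"        for x, y in zip(xs, ys):\n",
"            p = min(max(sigmoid(t0 + t1 * x), 1e-7), 1 - 1e-7)\n",
"            total -= y * math.log(p) + (1 - y) * math.log(1 - p)\n",
"        return total / m\n",
"\n",
"    t0, t1 = 0.0, 0.0\n",
"    prev = cost(t0, t1)\n",
"    for _ in range(max_steps):  # stopping criterion 1: step limit\n",
"        # gradient of the cost: the mean of (h(x) - y) * [1, x]\n",
"        errs = [sigmoid(t0 + t1 * x) - y for x, y in zip(xs, ys)]\n",
"        g0 = sum(errs) / m\n",
"        g1 = sum(e * x for e, x in zip(errs, xs)) / m\n",
"        t0, t1 = t0 - alpha * g0, t1 - alpha * g1\n",
"        curr = cost(t0, t1)\n",
"        if abs(prev - curr) <= eps:  # stopping criterion 2: the cost has stopped changing\n",
"            break\n",
"        prev = curr\n",
"    return t0, t1\n",
"\n",
"t0, t1 = gd_logistic([-2.0, -1.0, 1.0, 2.0], [0, 0, 1, 1])\n",
"print(t1 > 0)  # True: larger x means a higher probability of class 1\n",
"```"
]
},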
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"theta = [[ 1.59558981]\n",
" [ 0.12602307]\n",
" [ 0.65718518]\n",
" [-5.26367581]\n",
" [ 1.96832544]\n",
" [-6.97946065]]\n"
]
}
],
"source": [
"# Run gradient descent for logistic regression\n",
"theta_start = np.matrix(np.zeros(Xpl.shape[1])).reshape(Xpl.shape[1], 1)\n",
"theta, errors = GD(h, J, dJ, theta_start, Xpl, Ypl,\n",
"                   alpha=0.1, eps=10**-7, maxSteps=10000)\n",
"print(r'theta = {}'.format(theta))"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def plot_decision_boundary(fig, theta, X):\n",
"    \"\"\"Plot the decision boundary (the 0.5 contour of the class-1 probability)\"\"\"\n",
"    ax = fig.axes[0]\n",
"    xx, yy = np.meshgrid(np.arange(-1.0, 1.0, 0.02),\n",
"                         np.arange(-1.0, 1.0, 0.02))\n",
"    l = len(xx.ravel())\n",
"    # uses the global polynomial degree n set when the data were loaded\n",
"    C = powerme(xx.reshape(l, 1), yy.reshape(l, 1), n)\n",
"    z = np.asarray(classifyBi(theta, C)).reshape(xx.shape)\n",
"\n",
"    # contour() accepts `linewidths`, not `lw`\n",
"    ax.contour(xx, yy, z, levels=[0.5], linewidths=3)"
]
},
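{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"`plot_decision_boundary` draws the 0.5 contour of the predicted probability. Since the sigmoid equals 0.5 at 0, this contour is the curve $\\theta^T x = 0$. A plain-Python sketch of classifying single points (the helpers `powerme_row` and `classify` are hypothetical; `powerme_row` mirrors the column order of `powerme`, and the parameters are made up):\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def powerme_row(x1, x2, n):\n",
"    # one row of powerme for scalar inputs: x1^i * x2^(m - i), in the same column order\n",
"    return [x1 ** i * x2 ** (m - i) for m in range(n + 1) for i in range(m + 1)]\n",
"\n",
"def classify(theta, features, threshold=0.5):\n",
"    # class 1 iff sigmoid(theta^T x) >= threshold; for 0.5 this is mathematically theta^T x >= 0\n",
"    z = sum(t * v for t, v in zip(theta, features))\n",
"    return int(1.0 / (1.0 + math.exp(-z)) >= threshold)\n",
"\n",
"# made-up theta in powerme order [1, x2, x1, x2^2, x1*x2, x1^2]:\n",
"# the boundary is 0.5 - x1^2 - x2^2 = 0, a circle of radius sqrt(0.5)\n",
"theta = [0.5, 0.0, 0.0, -1.0, 0.0, -1.0]\n",
"print(classify(theta, powerme_row(0.0, 0.0, 2)), classify(theta, powerme_row(0.9, 0.9, 2)))  # 1 0\n",
"```"
]
},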
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"<ipython-input-20-d7e55b0bd73a>:10: UserWarning: The following kwargs were not used by contour: 'lw'\n",
" plt.contour(xx, yy, z, levels=[0.5], lw=3);\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmwAAAFmCAYAAADQ5sbeAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdeVxUVf8H8M9BAWHADXeUx43ccyPNtMg0U0pF08hsecoyy3Jr0bbnycq0LBVbbLGnxayoX4iZZlm2oZmimYkbuCsuCKLDgAzMnN8fw+gwzCDLzNxlPu/Xa14w995hvgPMvd8553vOEVJKEBEREZF6BSgdABERERFVjAkbERERkcoxYSMiIiJSOSZsRERERCrHhI2IiIhI5ZiwEREREalcbaUDUEKjRo1k69atbXeOHgVOnwaaNAFatSp/n4i8Ky8P2L+//HvO/l5s1w6oX1+5+IiIfGjr1q1npJSNnbf7ZcLWunVrpKWl2e5ICUyfDiQm2i4OADB1KrBwISCEckFWlZRASgoQH182bnfbidTC8T04bpztvWe/r8X3IhFRDQghDrvc7o8T58bExMiLCRtgu2AEOPQOW63au0CsWAGMHl32Aud4IUxOBkaNUjpKItcc/1ftmKwRkcKMRUYkpSchIycD0RHRSOiSgPDgcK8+pxBiq5Qyptx2v0/Y9HKhcHwd9vjZSkFaoocPTkSkG6lHUhG3PA5WaYWp2ARDoAEBIgBrxq/BgKgBXntedwmbfw86cE5yrFbb18RE23YtJbNC2JIye/wBAUzWSDvs70VHWnsPEpFuGIuMiFseB6PZCFOxCQBgKjbBaLZtzzfn+zwm/07YUlLKJzWOSU9KitIRVo09fkdM1kjt9PTBiXxDSlsZiPP/hrvtRFWUlJ4Eq7S63GeVViTtTPJxRP6esMXH22q7HJMae9KTnGzbryVspSAt0tsHJ/K+lBRbza7j+c1+/hs9mv8zVGMZORkXW9acmYpNyMzN9HFE/p6wCWErxHdugXK3Xc3YSkFadbkPTiNHsjWFyoqPL39+czz/ae3DNqlOdEQ0DIEGl/sMgQa0b9jexxH5e8KmJ2ylIK263AenlSvZmkJlsWaXvCyhSwIChOsUKUAEIKFrgo8j4ihR/eA8bKRXHAFN7nBkMXmR2kaJMmEjIvXTy/Q75Dn8nyAfyDfnI2lnEjJzM9G+YXskdE1AWFCYV5+TCZsDJmxEGsTWFLJjqyvpGOdhIyL1utw0DVYrR0DTJazZJT/EhI2IlHe5aRpGjOAIaLpEb1MyEVWCXy7+TkQq4zhNA1C2i+vmm4HVq8u3pgC2/bGxXCfX39hHEFd2O5EOMGEjIuU5J2H2xG3qVGDBAtvUHo4jne3Hx8ayNYWI/AIHHRCRenBgARH5OQ46ICJ149Jq3sF1N4l0gQkbESmPS6t5D9fdJNIFJmxEpDxO0+A9XHfT+9iKST7AGjYiUh6XVvMurgrgXStW2ForHX+njr/z5GSOXqVK40oHDpiwEZHf4YAO7+HKC+RBHHRAROSvOKDDu5y78AMCmKyRxzFhIyLSMw7o8A3HuQTtmKyRBzFhIyLSMw7o8A22YpKXMWEjItIzrrvpfWzFJB/g0lRERHrGdTe9z10rJsD1bsljmLARERHVhL0Vk+vdkhcxYSMiIqoJtmKSD7CGjYiIiEjlmLARERERqRwTNiIiIiKVU0XCJoQYKoTYK4TIFELMcrH/CSHE9tLbTiGERQjRsHTfISHEP6X7uN4UERER6Y7igw6EELUAvAXgRgDHAGwRQnwjpdxlP0ZKOR/A/NLjhwOYLqXMdfgxA6WUZ3wYNhEREZHPqKGFrQ+ATCnlASmlGcAXAEZWcPw4AJ/7JLLLkRJYsaL8pIjuthMRETnidYQqSQ0JWySAow73j5VuK0cIEQpgKICvHTZLAD8IIbYKISZ6LUpXUlKA0aPLzmRtn/F69Ggu+UJERB
XjdYQqSfEuUQCuVsZ195FiOIANTt2h/aWUWUKIJgDWCSH2SCl/K/cktmRuIgBERUXVNGab+PhLy48AtkkSHZcn4WSJRERUEV5HqJLUkLAdA9DK4X5LAFlujr0dTt2hUsqs0q+nhRArYOtiLZewSSnfA/AeAMTExHimjdl5+RH7G85xeRIiIiJ3eB2hShJS4f5xIURtAPsADAJwHMAWAHdIKdOdjqsH4CCAVlJKU+k2A4AAKaWx9Pt1AF6QUq6t6DljYmJkWpoHB5RKCQQ49C5brXyTERFR5fE6QqWEEFullDHO2xWvYZNSlgB4BMD3AHYD+FJKmS6EmCSEmORw6CgAP9iTtVJNAaQKIf4GsBnA6sslax4K+lIxqL3WwNG0aSwUJSKiynF1HXGsaSOCOrpEIaVcA2CN07Z3nO5/BOAjp20HAHT3cnjl2YtEp0yx3V+8uOz3ixdfaubmJyQiInLHnqzZa9Yca9gAXkfoIlUkbJrjXCTqKnFLTARiY7nwLxERuZeSUjZZc65p43WESilew6YEj9SwSWnr+ly8+NI2+xsOsL0J4+P5yYiIiNyT0vX1wt120j13NWxM2GqCRaJERETkQaoddKBZLBIlIiIt4GoKusCErTqci0St1ks1bUzaiIhITbiagi5w0EF1sEiUSJeMRUYkpSchIycD0RHRSOiSgPDgcKXDIqoZrqagC6xhqw5/KBL1h9dI5CD1SCrilsfBKq0wFZtgCDQgQARgzfg1GBA1QOnwiGrGsWfIjqspqBJr2DxJiEstaI79/+62axGb0MmPGIuMiFseB6PZCFOxbW5uU7EJRrNte745X+EIiWrIsSfITivJGmvwADBhqxk9JzWOTej218cmdNKppPQkWKXV5T6rtCJpZ5KPIyLyMC0PlNPztbYKWMNWE3quC+CCxORHMnIyLrasOTMVm5CZm+njiIg8SOurKej5WlsFTNhqQu9Jjf31OdY86OF1ETmJjoiGIdDgMmkzBBrQvmF7BaIi8hCtD5TT+7W2kjjowBP0OoEui1TJTxiLjIhcEAmj2VhuX3hQOLIey0JYUJgCkRF5gF4Gken1WuuEgw68Rct1ARXhXHOkVdUoUA4PDsea8WsQHhQOQ6ABgK1lLTzItp3JGmmafUCcc3Ljbrsa6fVaWxVSSr+79e7dW3qE1Srl1KlSAravru5rVXJy+dfh+PqSk6v286xW22OcfyfuthNVVw3+d41FRrl061I5a90suXTrUmksMvooaCJyS8/XWhcApEkXuYviyZMSN48lbJ5OatTE0wmWnn9XpC5+dnIn0j0/u364S9hYw1YTUid1Ab4gKxilxLo48jTH/zc7/p8RaZOfXWvd1bAxYSPf4UWUfElWo0DZzy4MRKQ+HHRAytPyTNukLfYPB44qU6DMCTqJSKWYsJHvVPciSlQVzt3vVRnhzBU+iEilOHEu+UZFNWwAW9rIc2oySSgn6CQilWING/nGihW2LiXHC59jEpecrO6Ztkk7PFGHVp36t8uGJVFUaEbB+QIUnC/EhYIiFBeVoLiouPRWAkuJ5eLxwv58AggMDkRQcCAC61z6GhoeAkO9UISE1bl0LBFpHgcdOGDCpgAWc5NWVGFwjKXEgpwTZ5F9NAfZR8/gzPFc5J0+h3PZ55GXfR55p88hL/s8THkmmM4XwmpxvcB8TQQECBjqhcJQ34C6EeFo2Kw+GjStjwZN66FBs/qIaN4ATVs3RrPWTRDeMIzJHZHKuUvY2CVKvmGfUbuy24mU4KLrPv/hqTie+Amy9hUg65ohyDpwClmZJ3HqUDZysnJhtZb90Fs7sBbqN6mHeo3ron6TeoiMbo6w+gaElLaIGeqGICQ8BHUMwQgMDixtPauNwOBA1KpdCxAAHH6k1Wq92BJnvmBrjSsqNKPQWAjTuQLk55lgOlcA07kCnM8xIvtYDval7Ufe6XPlYgsJq3MxeWt5RQu06hiJqI62r/Ua1fXBL5iIqostbEREAAqMhT
jw5jIcfnoeDve8Docj2uLw7mPIyTpb5riIFg3Qol0zNGvTBE1aNULjVhFo1DICTUq/htU3qKIVy2Kx4HxOPnKO5+L
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r'$x_1$', ylabel=r'$x_2$')\n",
"plot_decision_boundary(fig, theta, Xpl)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Load the data\n",
"\n",
"alldata = pandas.read_csv('polynomial_logistic.tsv', sep='\\t')\n",
"data = np.matrix(alldata)\n",
"\n",
"m, n_plus_1 = data.shape\n",
"Xn = data[:, 1:]\n",
"\n",
"n = 10\n",
"Xpl = powerme(data[:, 1], data[:, 2], n)\n",
"Ypl = np.matrix(data[:, 0]).reshape(m, 1)\n",
"\n",
"theta_start = np.matrix(np.zeros(Xpl.shape[1])).reshape(Xpl.shape[1],1)\n",
"theta, errors = GD(h, J, dJ, theta_start, Xpl, Ypl, \n",
" alpha=0.1, eps=10**-7, maxSteps=10000)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"<ipython-input-20-d7e55b0bd73a>:10: UserWarning: The following kwargs were not used by contour: 'lw'\n",
" plt.contour(xx, yy, z, levels=[0.5], lw=3);\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmwAAAFmCAYAAADQ5sbeAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdeVxU1fsH8M8ZBEQYF3DHXTHXNKW01MykUkxFs8gsv/mzTMvdXMr2zbJc0CwrK7PsG5WIlrj2rRR33MUN3BVXXBgGZJk5vz+G0QFmWGfm3pn5vF+vecHce2d4gJl7nznnOecIKSWIiIiISL00SgdAREREREVjwkZERESkckzYiIiIiFSOCRsRERGRyjFhIyIiIlI5JmxEREREKldB6QCUUL16ddmoUSPTnbNngcuXgZo1gfr1C98nIse6cQM4frzwe878XmzaFKhaVbn4iEh1Lp+5ihtX0uBf2Q81G1SHt6+30iHZza5du65KKWsU3O6RCVujRo2QkJBguiMlMGECEBVlujgAwLhxwJw5gBDKBVlaUgKxsUBERP64bW0nUgvL9+Dgwab3nvm+K74XicjhDAYD/vhyHb57/WcYTxgx9J2n8MSEx+FVwUvp0MpNCHHa6nZPnDg3NDRU3k7YANMFQ2PRO2w0ut4FYvlyYODA/Bc4ywthTAwwYIDSURJZZ/laNWOyRkTFuHIuFZ+P+RZbVuxE0/aNMGnRKIR0aGK359dl6RCdGI2k1CSEBIUgsnUktL5auz2/NUKIXVLK0ELbPT5hc5cLheXvYY6frRTkStzhgxMROZ2UEvHLd+DzMd/ixuWbeGpyfzz31iD4VPQp1/PGn4lH+NJwGKUR+hw9/L39oREaxA2JQ9cGXe0UfWG2EjbPHnRQMMkxGk1fo6JM210pmRXClJSZ49domKyR6zC/Fy252nuQiBQhhEC3gZ2w6OBsPPJcd/zy8XKM7DAFh7YeLfNz6rJ0CF8aDl22DvocPQBAn6OHLtu0PT073V7hl5hnJ2yxsYWTGsukJzZW6QhLxxy/JSZrpHbu9MGJnENKUxlIwdeGre3kEbTVAvDqdy/jo9XTcUt/C+O7vomvXl2CrMysUj9XdGI0jNJodZ9RGhF9MLq84ZaaZydsERGm2i7LpMac9MTEmPa7ErZSkCtytw9O5HixsaaaXcvzm/n8N3AgXzMe7t7H2uObA7MR/mIYfp/9B0Z1nIqjO5NL9RxJqUm3W9YK0ufokXytdM9nD56dsAlhKsQv2AJla7uasZWCXFVxH5z692drCuUXEVH4/GZ5/nO1D9tkd/6VK2H8whGYseYNZOoyMfaB6Vj81i/Iyc4p0eNDgkLg7+1v/bm9/dEssJk9wy0RDjpwFxwlSu6Kr22yxl0GjJHDpd/QY8G477Dhx41odk9jTF0yBo1aFz3Pqi5Lh+DZwdBl6wrt0/pokTIpBQE+AQ6Jl6NELbhlwsZ52MhdcQQ02cKRxVQKm2N3YO5LX0GflokRM59D/9G9IIp4vahtlCgTNiJSP7amUEF8TVAZXL98E7OGf4Htq3bj3t73YPJ3L6NaLdsrqaRnpyP6YDSSryWjWWAzRLaJdFjLmhkTNgtM2IhcEFtTyIytrlQOUkr88eU6fPXqD6ik9cOkb19G58c7Kh3WbZyHjYjUq7hpGoxGjoCmOziymMpBCIF+Lz+GLxI+QWDdaniz38f4csLiEg9IUAoTNiJSXnHTNPTrxxHQdIe7TclEimjYqj7mb5uBiDG9ERO1ChO6vYkLJy8pHZZN7BIlIuUV1cXVpw+wahVHiRKRw2yK2Y5Zw78AALz63cvoOqCTYrGwhs0CEzYiFbJVRD57NrBiBUdAE5FDXThxCR88PQfHEo7jw1Wv477e9ygSBxM2C0zYiFSKAwuISEHZWTmI+2YD+o56FF5eXorEwEEHRKRuXFrNMbjuJlGJ+fh6I2J0b8WStaIwYSMi5X
FpNcfhuptEboEJGxEpj9M0OA7X3XQ8tmKSE7CGjYiUx6XVHIurAjgW17slO+KgAwtM2IjI43BAh+Nw5QWyIw46ICLyVBzQ4VgFu/A1GiZrZHdM2IiI3BkHdDiHOWmzxGSN7IgJGxGRO+OADudgKyY5GBM2IiJ3xnU3HY+tmOQEFZQOgIiIHEgI6yMUbW2n0rPVigmYtnfvzr81lRsTNiIiovIwt2JaTj9jTtq6d2crJtkFEzYiIqLyYCsmOQFr2IiIiIhUjgkbERERkcoxYSMiIiJSOVUkbEKIXkKIo0KIZCHENCv7Jwsh9ubdDgohDEKIwLx9p4QQB/L2cb0pIiIicjuKDzoQQngBWADgEQDnAOwUQqyUUh4yHyOl/BTAp3nH9wUwQUp5zeJpekgprzoxbCIiIiKnUUML230AkqWUJ6SU2QB+AdC/iOMHA/ivUyIrjpTA8uWFJ0W0tZ2IiMgSryNUQmpI2IIBnLW4fy5vWyFCiEoAegFYZrFZAlgnhNglhBjhsCitiY0FBg7MP5O1ecbrgQO55AsRERWN1xEqIcW7RAFYWxnX1keKvgA2F+gO7SKlTBFC1ASwXghxREq5sdAPMSVzIwCgQYMG5Y3ZJCLizvIjgGmSRMvlSThZIhERFYXXESohNSRs5wDUt7hfD0CKjWOfRoHuUCllSt7Xy0KI5TB1sRZK2KSUXwP4GgBCQ0Pt08ZccPkR8xvOcnkSIiIiW3gdoRISUuH+cSFEBQDHAPQEcB7ATgDPSCkTCxxXBcBJAPWllPq8bf4ANFJKXd736wG8J6VcU9TPDA0NlQkJdhxQKiWgsehdNhr5JiMiopLjdYTyCCF2SSlDC25XvIZNSpkLYDSAtQAOA/hVSpkohBgphBhpcegAAOvMyVqeWgDihRD7AOwAsKq4ZM1OQd8pBjXXGlgaP56FokREVDLWriOWNW1EUEeXKKSUcQDiCmxbWOD+YgCLC2w7AaCdg8MrzFwkOnas6f68efm/nzfvTjM3PyEREZEt5mTNXLNmWcMG8DpCt6kiYXM5BYtErSVuUVFA9+5c+JeIiGyLjc2frBWsaeN1hPIoXsOmBLvUsElp6vqcN+/ONvMbDjC9CSMi+MmIiIhsk9L69cLWdnJ7tmrYmLCVB4tEiYiIyI5UO+jAZbFIlIiIXAFXU3ALTNjKomCRqNF4p6aNSRsREakJV1NwCxx0UBYsEiVyS7osHaITo5GUmoSQoBBEto6E1lerdFhE5cPVFNwCa9jKwhOKRD3hdySyEH8mHuFLw2GURuhz9PD39odGaBA3JA5dG3RVOjyi8rHsGTLjagqqxBo2exLiTguaZf+/re2uiE3o5EF0WTqELw2HLlsHfY5pbm59jh66bNP29Ox0hSMkKifLniAzV0nWWIMHgAlb+bhzUmPZhG7+/diETm4qOjEaRmm0us8ojYg+GO3kiIjszJUHyrnztbYUWMNWHu5cF8AFicmDJKUm3W5ZK0ifo0fytWQnR0RkR66+moI7X2tLgQlbebh7UmP+/SxrHtzh9yIqICQoBP7e/laTNn9vfzQLbKZAVER24uoD5dz9WltCHHRgD+46gS6LVMlD6LJ0CJ4dDF22rtA+rY8WKZNSEOAToEBkRHbgLoPI3PVaWwAHHTiKK9cFFIVzzZGrKkOBstZXi7ghcdD6aOHv7Q/A1LKm9TFtZ7JGLs08IK5gcmNruxq567W2NKSUHnfr2LGjtAujUcpx46QETF+t3XdVMTGFfw/L3y8mpnTPZzSaHlPwb2JrO1FZleO1q8vSyUW7Fslp66fJRbsWSV2WzklBE5FN7nyttQJAgrSSuyiePClxs1vCZu+kRk3snWC589+K1MXDTu5Ebs/Drh+2EjbWsJWHdJO6AGeQRYxSYl0c2Zvl682MrzMi1+Rh11pbNWxM2Mh5eBElZ5JlKFD2sAsDEakPBx2Q8lx5pm1yLeYPB5ZKUqDMCTqJSKWYsJHzlPUiSlQaBbvfSzPCmSt8EJ
FKceJcco6iatgAtrSR/ZRnklBO0ElEKsUaNnKO5ctNXUqWFz7LJC4mRt0zbZPrsEcdWlnq34iI7IA1bKSsiAhTUmb
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Example with more features\n",
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r'$x_1$', ylabel=r'$x_2$')\n",
"plot_decision_boundary(fig, theta, Xpl)"
]
},
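{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A minimal sketch of how polynomial features of two variables can be generated. The helper `powerme` used above is defined elsewhere in this notebook; `poly_features_2d` below is a hypothetical stand-in that builds every monomial $x_1^i x_2^j$ with $i + j \\le n$ (the column order and any scaling used by `powerme` may differ). For $n = 10$ this construction already gives $(n+1)(n+2)/2 = 66$ features from only two inputs, which greatly increases the flexibility of the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def poly_features_2d(x1, x2, n):\n",
"    # Hypothetical stand-in for `powerme`: returns a design matrix whose\n",
"    # columns are all monomials x1**i * x2**j with i + j <= n.\n",
"    # The first column (i = j = 0) is the constant 1 (bias term).\n",
"    x1 = np.asarray(x1, dtype=float).ravel()\n",
"    x2 = np.asarray(x2, dtype=float).ravel()\n",
"    columns = []\n",
"    for total in range(n + 1):      # total degree of the monomial\n",
"        for j in range(total + 1):  # power of x2\n",
"            i = total - j           # power of x1\n",
"            columns.append(x1 ** i * x2 ** j)\n",
"    return np.column_stack(columns)\n",
"\n",
"# Degree 2: columns are 1, x1, x2, x1^2, x1*x2, x2^2, so F.shape == (3, 6)\n",
"F = poly_features_2d([1.0, 2.0, 0.5], [3.0, -1.0, 2.0], 2)"
]
},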
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 5.2. The overfitting problem"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Bias vs. variance"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Data for a simple example\n",
"\n",
"data = np.matrix([\n",
" [0.0, 0.0],\n",
" [0.5, 1.8],\n",
" [1.0, 4.8],\n",
" [1.6, 7.2],\n",
" [2.6, 8.8],\n",
" [3.0, 9.0],\n",
" ])\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn1 = data[:, 0:n]\n",
"Xn1 /= np.amax(Xn1, axis=0)\n",
"Xn2 = np.power(Xn1, 2) \n",
"Xn2 /= np.amax(Xn2, axis=0)\n",
"Xn3 = np.power(Xn1, 3) \n",
"Xn3 /= np.amax(Xn3, axis=0)\n",
"Xn4 = np.power(Xn1, 4) \n",
"Xn4 /= np.amax(Xn4, axis=0)\n",
"Xn5 = np.power(Xn1, 5) \n",
"Xn5 /= np.amax(Xn5, axis=0)\n",
"\n",
"X1 = np.matrix(np.concatenate((np.ones((m, 1)), Xn1), axis=1)).reshape(m, n + 1)\n",
"X2 = np.matrix(np.concatenate((np.ones((m, 1)), Xn1, Xn2), axis=1)).reshape(m, 2 * n + 1)\n",
"X5 = np.matrix(np.concatenate((np.ones((m, 1)), Xn1, Xn2, Xn3, Xn4, Xn5), axis=1)).reshape(m, 5 * n + 1)\n",
"y = np.matrix(data[:, -1]).reshape(m, 1)"
]
},
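{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The cell above builds `X1`, `X2` and `X5` by hand. The same construction can be written as a loop over powers. `design_matrix` is a hypothetical helper that, like the cell above, divides $x$ by its maximum and then each power by its own maximum, so all columns lie in $[0, 1]$ (this assumes non-negative inputs, as in this example). Keeping the columns on a comparable scale helps gradient descent converge."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def design_matrix(x, degree):\n",
"    # Column of ones (bias) followed by scaled x, x^2, ..., x^degree.\n",
"    # As in the cell above: x is divided by its maximum, then every power\n",
"    # is divided by its own maximum (assumes x >= 0 and max(x) > 0).\n",
"    x = np.asarray(x, dtype=float).reshape(-1, 1)\n",
"    x = x / x.max()\n",
"    columns = [np.ones_like(x)]\n",
"    for d in range(1, degree + 1):\n",
"        xd = x ** d\n",
"        columns.append(xd / xd.max())\n",
"    return np.hstack(columns)\n",
"\n",
"x = [0.0, 0.5, 1.0, 1.6, 2.6, 3.0]\n",
"X5_sketch = design_matrix(x, 5)  # same layout as X5: 1 bias column + 5 powers"
]
},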
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmEAAAFoCAYAAAAfEiweAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAT+0lEQVR4nO3df4zteV3f8dd79kKQmWlcwgXWhRbqnYCWP8TeEpRJQ0Xa9bZxW6OZNVFXc5NNm1Kx17RS20jSNC1pGlPbWJvNQtEUYQhi3dhblaJEb7Rk765bYbmSmVCF27u6lzbB2Wkb3M6nf5y5vdfLvXtnl5nve+6cxyPZnJnzPXPOO9987/Dk+2tqjBEAAKa10D0AAMA8EmEAAA1EGABAAxEGANBAhAEANBBhAAANDizCquq9VfVUVX3qmudeUlUfraqN3cc7D+rzAQAOs4PcE/a+JPdc99w7k3xsjLGS5GO73wMAzJ06yJu1VtWrk/ziGOP1u99/JslbxhhPVtVdST4+xnjtgQ0AAHBITX1O2MvHGE8mye7jyyb+fACAQ+FY9wA3U1UPJHkgSRYXF//86173uuaJAAD+pEcfffQLY4zjz+dnp46wP6yqu645HPnUzV44xngwyYNJcvLkyXH+/PmpZgQA2JOq+v3n+7NTH458OMn9u1/fn+QXJv58AIBD4SBvUfGBJL+V5LVVdbGqTid5d5K3VdVGkrftfg8AMHcO7HDkGOO7b7LorQf1mQAAtwt3zAcAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGIgwAoIEIAwBoIMIAABqIMACABiIMAKCBCAMAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGx7oHAOCI2tpK1teTjY1kZSVZW0uWl7ungkNDhAGw/86dS06dSnZ2ku3tZHExOXMmOXs2WV3tng4OBYcjAdhfW1uzANvamgVYMnu88vzTT/fOB4eECANgf62vz/aA3cjOzmw5IMIA2GcbG1f3gF1vezvZ3Jx2HjikRBgA+2tlZXYO2I0sLiYnTkw7DxxSIgxgHm1tJQ89lPzIj8wet7b2773X1pKFm/zPy8LCbDng6kiAuXPQVy4uL8/e6/rPWFiYPb+09JV/BhwBIgxgnlx75eIVV87fOnUquXRpfyJpdXX2Xuvrs3PATpyY7QETYPD/iTCAebKXKxdPn96fz1pa2r/3giPIOWEA88SVi3BoiDCAeeLKRTg0RBjAPHHlIhwaIgxgnly5cnF5+eoescXFq887cR4m48R8gHnjykU4FEQYwDxy5SK0czgSAKCBCAMAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGIgwAoIEIAwBoIMIAABqIMACABiIMAKBBS4RV1d+tqieq6lNV9YGqelHHHAAAXSaPsKq6O8kPJjk5xnh9kjuS3Df1HAAAnboORx5L8lVVdSzJi5NcapoDAKDF5BE2xvjvSf5Fks8leTLJF8cYv3L966rqgao6X1XnL1++PPWYAAAHquNw5J1J7k3ymiRfk2Sxqr7n+teNMR4cY5wcY5w8fvz41GMCAByojsOR35rkv40xLo8x/jjJR5J8c8McAABtOiLsc0neVFUvrqpK8tYkFxrmAABo03FO2CeSfDjJY0k+uTvDg1PPAQDQ6VjHh44x3pXkXR2fDQBwGLhjPgBAAxEGANBAhAEANBBhAAANRBgAQAMRBgDQQIQBADQQYQAADUQYAEADEQYA0ECEAQA0EGEAAA1EGABAg2PdAwC029pK1teTjY1kZSVZW0uWl7unAo44EQbMt3PnklOnkp2dZHs7WVxMzpxJzp5NVle7pwOOMIcjgfm1tTULsK2tWYAls8crzz/9dO98wJEmwoD5tb4+2wN2Izs7s+UAB0SEAfNrY+PqHrDrbW8nm5vTzgPMFREGzK+Vldk5YDeyuJicODHtPMBcEW
HA/FpbSxZu8mtwYWG2HOCAiDBgfi0vz66CXF6+ukdscfHq80tLvfMBR5pbVADzbXU1uXRpdhL+5ubsEOTamgADDpwIA1haSk6f7p4CmDMORwIANBBhAAANRBgAQAMRBgDQQIQBADQQYQAADUQYAEADEQYA0ECEAQA0EGEAAA1EGABAAxEGANBAhAEANBBhAAANRBgAQAMRBgDQQIQBADQQYQAADUQYAEADEQYA0ECEAQA0EGEAAA1EGABAAxEGANCgJcKq6qur6sNV9btVdaGqvqljDgCALseaPvcnkvzSGOM7q+qFSV7cNAcAQIvJI6yq/lSSv5jk+5NkjPGlJF+aeg4AgE4dhyP/bJLLSf5dVf12VT1UVYsNcwAAtOmIsGNJvjHJT40x3pBkO8k7r39RVT1QVeer6vzly5ennhEA4EB1RNjFJBfHGJ/Y/f7DmUXZnzDGeHCMcXKMcfL48eOTDggAcNAmj7Axxh8k+XxVvXb3qbcm+fTUcwAAdOq6OvLvJHn/7pWRn03yA01zAAC0aImwMcbjSU52fDYAwGHgjvkAAA1EGABAAxEGANBAhAEANBBhAAANRBgAQAMRBgDQQIQBADQQYQAADUQYAEADEQYA0ECEAQA0EGEAAA1EGABAAxEGANBAhAEANBBhAAANRBgAQAMRBgDQQIQBADQQYQAADUQYAEADEQYA0ECEAQA0EGEAAA2OdQ8AzJmtrWR9PdnYSFZWkrW1ZHm5eyqAyYkwYDrnziWnTiU7O8n2drK4mJw5k5w9m6yudk8HMCmHI4FpbG3NAmxraxZgyezxyvNPP907H8DERBgwjfX12R6wG9nZmS0HmCMiDJjGxsbVPWDX295ONjennQegmQgDprGyMjsH7EYWF5MTJ6adB6CZCAOmsbaWLNzkV87Cwmw5wBwRYcA0lpdnV0EuL1/dI7a4ePX5paXe+QAm5hYVwHRWV5NLl2Yn4W9uzg5Brq0JMGAuiTBgWktLyenT3VMAtHM4EgCgwS0jrKreXlV3TjEMAMC82MuesFckeaSqPlRV91RVHfRQAABH3S0jbIzxj5KsJHlPku9PslFV/7SqvvaAZwMAOLL2dE7YGGMk+YPd/55JcmeSD1fVPz/A2QAAjqxbXh1ZVT+Y5P4kX0jyUJK/N8b446paSLKR5O8f7IgAAEfPXm5R8dIk3zHG+P1rnxxj7FTVXzuYsQAAjrZbRtgY48eeZdmF/R0HAGA+uE8YAEADEQYA0ECEAQA0EGEAAA1EGABAg7YIq6o7quq3q+oXu2YAAOjSuSfsHUnc4gIAmEstEVZVr0zyVzO7Az8AwNzp2hP2LzP7c0c7N3tBVT1QVeer6vzly5enmwwAYAKTR9junzp6aozx6LO9bozx4Bjj5Bjj5PHjxyeaDgBgGh17wt6c5Nur6veSfDDJt1TVv2+YAwCgzeQRNsb4B2OMV44xXp3kviS/Osb4nqnnAADo5D5hAAANjnV++Bjj40k+3jkDAEAHe8IAABqIMACABiIMAKCBCAMAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGIgwAoIEIAwBoIMIAABqIMACABiIMAKCBCAMAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGIgwAoIEIAwBoIMIAABqIMACABiIMAKCBCAMAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGIgwAoIEIAwBoIMIAABqIMACABiIMAKCBCAMAaCDCAAAaiDAAgAYiDACggQgDAGggwgAAGogwAIAGIgwAoMHkEVZVr6qqX6uqC1X1RFW9Y+oZAAC6HWv4zGeS/PAY47GqWk7yaFV9dIzx6YZZAABaTL4nbIzx5Bjjsd2vt5JcSHL31HMAAHRqPSesql6d5A1JPtE5BwDA1NoirKqWkvxckh8aY/zRDZY/UFXnq+r85cuXpx8QAOAAtURYVb0gswB7/xjjIzd6zRjjwTHGyTHGyePHj0
87IADAAZv8xPyqqiTvSXJhjPHjU38+kGRrK1lfTzY2kpWVZG0tWV7ungpgrnRcHfnmJN+b5JNV9fjucz86xjjbMAv
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X1, y, xlabel='x', ylabel='y')"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x27dd8b81340>]"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmEAAAFoCAYAAAAfEiweAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXiU9b3+8fubHSZhD/sOkR0SS13Ro3VHLXXBhJ722Nbfse2pBY0b7latlbpC29MeThfb05YEXFHRui9oXdCEsJOwh7AkLGGyL/P9/TGxUJpAEmbmO8v7dV1eSeaZzHNfz/VkvHmWzxhrrQAAABBaca4DAAAAxCJKGAAAgAOUMAAAAAcoYQAAAA5QwgAAAByghAEAADgQtBJmjPm9MWavMWb1EY/1Msa8YYwpbvnaM1jrBwAACGfBPBL2tKSLj3psrqS3rLUZkt5q+RkAACDmmGAOazXGDJf0srV2YsvPGySdY63dZYwZIOlda+2YoAUAAAAIU6G+JqyftXaXJLV87Rvi9QMAAISFBNcB2mKMuV7S9ZLk8Xi+MnbsWMeJAAAdUVZZq31VDUpKiFNG31TFGeM6EhBwn3/+eYW1Nr0zvxvqErbHGDPgiNORe9t6orV2oaSFkjR16lS7YsWKUGUEAJygdzbs1Xf/8JmGxBk9+8MzNGVID9eRgKAwxmzr7O+G+nTkUknXtnx/raQXQ7x+AECQlXvrdeuSlZKkmy8cQwED2hDMERWLJP1d0hhjTKkx5jpJj0i6wBhTLOmClp8BAFHCWqtbn1mpiqoGnT6yt75/9kjXkYCwFbTTkdbaWW0sOi9Y6wQAuPX0R1v17oZy9eiaqCeypygujuvAgLYwMR8AEBDrdh3Sz15dL0l65MrJGtC9i+NEQHijhAEATlhdY7NmLypQQ5NPs04Zqosn9ncdCQh7lDAAwAn76SvrVLy3SiPTPbrnsnGu4wARgRIGADghb67do//7eJsS440W5GSpa1LYjqAEwgolDADQaXsP1em2Z4skSbddNFYTB3V3nAiIHJQwAECn+HxWNy9Zqf3VDToro4+umzbCdSQgolDCAACd8rvlW/RBcYV6eZL0+EzGUQAdRQkDAHTY6p2V+vnf/OMo5l01WX27pThOBEQeShgAoENqGpo0O69Ajc1W3z5tmC4Y3891JCAiUcIAAB3y4MvrtLm8Whl9U3XXpYyjADqLEgYAaLfXVu/Wok+3KykhTgtmZSklMd51JCBiUcIAAO2yq7JWc5/zj6OYe/FYjRvQzXEiILJRwgAAx9Xss8rNX6mDNY06Z0y6vnvmcNeRgIhHCQMAHNfC9zfr75v3qU9qkh69eoqMYRwFcKIoYQCAY1q546Aef32DJOnRq6coPS3ZcSIgOlDCAABtqq5v0py8AjX5rL5zxnCdO7av60hA1KCEAQDadP/SNdq6r0Zj+6dp7iVjXccBogolDADQqpeLyrTk81IlM44CCApKGADgX5QeqNEdz62SJN196Tid1C/NcSIg+iS4DgAACC9fjqPw1jXp/HF99a3ThnXuhbxeKT9fKi6WMjKk7GwpjTIHfIkSBgD4J//9Tok+3bpf6WnJmnfV5M6No1i+XJo+XfL5pOpqyeORcnOlZcukadMCHxqIQJyOBAD8w+fbDuipt4olSY/PnKLeqZ0YR+H1+guY1+svYJL/65ePV1UFMDEQuShhAABJkreuUTfmF6jZZ/WfZ43Q2Seld+6F8vP9R8Ba4/P5lwOghAEA/O59cY127K/V+AHddMtFYzr/QsXFh4+AHa26Wiop6fxrA1GEEgYA0AsFO/V8wU6lJPrHUSQnnMA4iowM/zVgrfF4pNGjO//aQBThwnwAiEVH3Lm4Y/hY3b1rgCTp3ssmaHTf1BN77exs/0X4rYmL8y8HQAkDgJhzxJ2LTTW1mvPtR1U1oK8u6p+gWacMOfHXT0vz3wV59N2RcXH+x1NPsO
QBUYISBgCx5Mg7FyUtmPZNfTFgjPp7K/TI7+6Q+X8bA1OSpk2Tysr8R9tKSvynILOzKWDAEShhABBLjrhz8bNB4/XL07NlrE9PvPyEetYe8i+/7rrArCs1NXCvBUQhShgAxJKWOxcrkz268fJb5IuL1w8+XqIzthf5l3PnIhAylDAAiCUZGbIej+4670fa2b2vJu/aqNwP/uJfxp2LQEgxogIAYkl2tp4dd45eHne2ujbUav5LjynJ1+Rfxp2LQEhxJAwAYsjW+jjdd+EPpWbp/g+e1ogDZdy5CDhCCQOAGNHY7NOc/EJVN0uXjkvXzNFXSOdP4s5FwBFKGADEiKfe3KiVOw5qYPcUPTwzS6brKa4jATGNa8IAIAZ8vHmf/vvdTYoz0pPZmereNdF1JCDmUcIAIModrGnQTfmFslb60bmjderI3q4jARAlDACimrVWdz6/Srsq65Q5pIdmn5fhOhKAFpQwAIhii1fs0LJVu5WanKAFOVlKjOdtHwgX/DUCQJTaVF6l+5eulSQ9MGOChvbu6jgRgCNRwgAgCjU0+XRjXqFqG5s1I3Ogrsga5DoSgKNQwgAgCj3+xgat2lmpwT276MFvTJQxxnUkAEehhAFAlPmwpEL/895mxRlpfk6muqUwjgIIR5QwAIgi+6sblLu4UJI0+7wMfWVYL8eJALSFEgYAUcJaq9ufLdKeQ/WaOqynbjh3tOtIAI6BEgYAUeKvn27XG2v3KC05QU9mZyqBcRRAWOMvFACiQMlerx582T+O4qdXTtKQXoyjAMIdJQwAIlx9U7N+vKhQdY0+XXnyIH19ykDXkQC0g5MSZoy5yRizxhiz2hizyBiT4iIHAESDn7+2Qet2HdLQXl31wIyJruMAaKeQlzBjzCBJsyVNtdZOlBQvKSfUOQAgGry3sVy/W75F8XFG83MylZqc4DoSgHZydToyQVIXY0yCpK6SyhzlAICIVVFVr5sXr5Qk5V5wkrKG9nScCEBHhLyEWWt3SnpM0nZJuyRVWmtfP/p5xpjrjTErjDErysvLQx0TAMKatVa3PVOkiqp6nTqil37wb6NcRwLQQS5OR/aUNEPSCEkDJXmMMd86+nnW2oXW2qnW2qnp6emhjgkAYe1Pf9+mt9fvVbcU/ziK+Dg+lgiINC5OR54vaYu1ttxa2yjpOUlnOMgBABFpw26vfrpsnSTpkasma2CPLo4TAegMFyVsu6TTjDFdjf8TZc+TtM5BDgCIOHWNzZq9qEANTT5lTx2i6ZMGuI4EoJNcXBP2iaRnJH0haVVLhoWhzgEAkeiRV9drwx6vRvTx6N7Lx7uOA+AEOLmX2Vp7n6T7XKwbACLV2+v36OmPtiox3mhBTpY8jKMAIhoT8wEgAuz11unWJUWSpJsvHKNJg7s7TgTgRFHCACDM+XxWtywp0r7qBp0xqreuP2uk60gAAoASBgBh7g8fbdX7G8vVo2uinrgmU3GMowCiAiUMAMLYmrJKzXt1vSRp3lWT1b87H7ULRAtKGACEqdqGZs3JK1RDs0/fPHWoLprQ33UkAAFECQOAMPXQK2tVsrdKo9I9uudSxlEA0YYSBgBh6PU1u/WXT7YrKT5OC2ZlqUtSvOtIAAKMEgYAYWbPoTrd/qx/HMVtF4/RhIGMowCiESUMAMKIz2eVu7hQB2oadVZGH33vzBGuIwEIEkoYAISR//1gsz4s2afeniQ9fs0UxlEAUYwSBgBhYlVppR57fYMk6edXT1bfNMZRANGMEgYAYaCmoUlz8grU2Gx17enDdN64fq4jAQgyShgAhIEHXlqrzRXVGtMvTXdMH+c6DoAQSHAdAACc83ql/HypuFjKyJCys6W0tJCt/tVVu5T32Q4lJcRp/qxMpSQyjgKIBZQwALFt+XJp+nTJ55OqqyWPR8rNlZYtk6ZNC/rqyw7Wau5zqyRJd14yVmP7dwv6OgGEB05HAohdXq+/gHm9/gIm+b9++XhVVVBX3+yzuim/UJW1jTp3TL
quPWN4UNcHILxQwgDErvx8/xGw1vh8/uVB9Jv3NumTLfvVJzVZj86cImMYRwHEEkoYgNhVXHz4CNjRqqulkpKgrbp
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X1, y, xlabel='x', ylabel='y')\n",
"theta_start = np.matrix([0, 0]).reshape(2, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X1, y, eps=0.00001)\n",
"plot_fun(fig, polynomial_regression(theta), X1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"This model has a high **bias** (**systematic error**), i.e. it **underfits** the data (_underfitting_)."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x27e2530e820>]"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmEAAAFoCAYAAAAfEiweAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dd3yV5cH/8e91ssmAACFA2CTsKWGDIm4caBVxVlstWldt+1StHc/z9Pm1tba1tVatVq0LFQsOFMRRQUWGhr1JGBkEQoCQBZnn+v1xAiKCBEjOdcbn/XrxSnLOCfl6v26O31zXfV+XsdYKAAAA/uVxHQAAACAcUcIAAAAcoIQBAAA4QAkDAABwgBIGAADgACUMAADAgWYrYcaY54wxu40xa494rLUx5kNjTHbDx+Tm+vkAAACBrDlHwp6XdOFRjz0g6T/W2gxJ/2n4GgAAIOyY5lys1RjTTdK71toBDV9vkjTBWrvTGNNB0gJrbe9mCwAAABCg/H1NWKq1dqckNXxs5+efDwAAEBAiXQc4HmPMNEnTJCk+Pn5Ynz59HCcCAAD4umXLlu2x1qacyvf6u4QVGWM6HDEduft4L7TWPi3paUnKzMy0WVlZ/soIAADQKMaY3FP9Xn9PR86WdFPD5zdJetvPPx8AACAgNOcSFa9KWiyptzGmwBhzi6SHJJ1njMmWdF7D1wAAAGGn2aYjrbXXHuepc5rrZwIAAAQLVswHAABwgBIGAADgACUMAADAAUoYAACAA5QwAAAAByhhAAAADlDCAAAAHKCEAQAAOEAJAwAAcIASBgAA4AAlDAAAwAFKGAAAgAOUMAAAAAcoYQAAAA5QwgAAAByghAEAADhACQMAAHCAEgYAAOBApOsAAIDwYa1Vndeqps7r+1Pv+1h91Ne+z+u/9lyrFtE6o0srtUmIcf2fATQJShgAoEmVHqzV5zl79MnaQn2xvkDlNV5VR0SqxhOpmnqvrD29v79H23id0TVZmV2TNaxrsnqmJMjjMU0THvAjShgA4LRYa7WusEyfbC7WJ5uKtSyvRPXeQ00rUjKSvJK8XklShMcoOsKj6MiGPxEexUR+/etjfb6j5KBWFezX1j2V2rqnUjOXFUiSWsZFaVhDIRvWNVmDO7VSXHSEk2MBnAxKGADgpJUeqNVnOcVasKlYn2wuVnF59eHnIow0YscGTcheqjO3r1C7in2Krq/1/YmLVeSOAikh4ZR+bm29V+sLy7Qst0TLckuUlbtPRWXV+njjbn28cbckKdJj1L9jkoZ1ba3Mbr5ilpoU2yT/3UBTMvZ0x4X9IDMz02ZlZbmOAQBhy+v1jXYt2LRbCzYXa0VeibxH/O+jfVKsJvRO0Vm9UjR20Vwl/fRHUmXlN/+i+Hjp0UelW25pklzWWu3Yf/CrUra9RBt3lX0tmyR1So47PH05vHtr9U5NlDFMYeL0GWOWWWszT+V7GQkDABxTSWWNPs32jXR9urlYeypqDj8X6TEa2b21zuqdogm9U75eal7efOwCJvkez8lpsozGGHVKbqFOyS00eUiaJKmiuk4r8/YrK3efluWWaEXefhWUHFRByUG9tbJQknRGl1a6e2KGJvROoYzBGUoYAOCwmjqvZmTl643lBVqVv/9rI0odW8bqrN7tNKF3isamt1VCzHH+F5KR4RvxOt5IWHp684RvkBATqXEZbTUuo60kqd5rtbmoXFm5JVqeW6L5m3Zred5+fe/5LzUgLUl3nZ2h8/ulcnE//I7pSAAIR+Xl0owZUna2lJEh79VXa862Cv3pg03K3XtAkhQVYTSie2tN6OUrXuntEho3alReLqWl+T4eLTFRKiw85WvCmkJldZ1eWZqnpz7dqj0VvmvZeqcm6s6J6bp4YAdFUMZwEk5nOpISBgDhZuFCadIk392KlZVa1GuEHhp7vVa36ylJ6pkSrx+d20vn9Gmn+OONdp3kz1B8vOTxSH
PnSuPGNeF/zKmrqq3XjC/z9Y9PtmhnaZUk3/IXd5ydrslDOioqgvXMcWKUMABA4xwxSrU+pbv+MOEmfdLD9/+PdpUl+vGUkZoytqcim6KAVFT4RttycnxTkFOnOh0BO56aOq9mLS/QEwtylL/voCTfhfw/nNBTVw3rpJhIlrvA8VHCAACN88wzyv/Vb/XIsO/orf4TZI1HidWVun3JTH1vw3/U4s8PN9mdi8Gmtt6r2SsL9fiCHG0t9l3P1j4pVred1UPXDO/C2mM4Ju6OBACcUElljR7PrteL1/9FNZFRiqqv1Y3LZ+uuxa+r9cEy34ua8M7FYBMV4dGVwzrp8qFpmrtmp/7+cY42FZXrf99Zr8fn5+gH43vo+lFdj39DAnCSOJMAIMQdrKnXc59v0z8WbFG56SRFSpevm6+ffvayOpcWffVCP9y5GAwiPEaXDu6oiwd20EcbivTYxzlas6NUv39vo578ZIu+P7a7bhrTTS3jolxHRZBjOhIAQlRdvVczlxXoLx9tVlGZ7y7A8T2Sdf9Dt2vAtjXf/IYAuHMxEFlr9cnmYj32cY6W5ZZIkhJjInXTmG66ZVx3JcdHO04Il7gmDABwmLVWH64v0sPvb1LO7gpJ0oC0JD1wYV/f2llBcOdiILLWavHWvfr7xzlatGWvJKltQrT+MnWIxmekOE4HVyhhAABJUtb2fXrovY3Kahix6dw6Tv91fm9dOqjj1xcjDZI7FwPVstx9+sO8Tfpi2z4ZI90xoad+fG6vprmrFEGFEgYAYa64vFq/fGuN3l/nu8ardXy07p6YrutHdlV0JMWgOdR7rR6fn6O/frRZXisN75asv107VB1axrmOBj+ihAFAGFuVv1+3v7xMO0urFBcVoVvHd9e0M3soMZYLx/1hyda9+tFrK1RUVq3kFlH689WDNbFPqutY8JPTKWH8egQAQez1rHxNeWqxdpZW6YwurfTxf52ln57fmwLmR6N6tNHce8ZrQu8UlRyo1fefz9Jv56xXTZ3XdTQEOEoYAAShmjqvfv32Wt03c7Vq6ry6fmQXvTZtNFNhjrRJiNFzNw3Xzy/qowiP0T8/26YpTy1W/r4DrqMhgFHCACDI7C6v0vXPLNGLi3MVHeHRQ98ZqN9eMZBrvxzzeIxuO6unXr9ttNJaxWlV/n5N+ttnem/NTtfREKD4FwsAQWRFXokufWyhvtxeotSkGM24bZSuGdHFdSwcYVjXZM25Z5zO75eq8qo6/XD6cv367bWqqq13HQ0BhhIGAEFixpd5mvrUEhWVVWt4t2S9c/c4De2S7DoWjqFVi2g9deMw/c+l/RQd4dGLi3P1nScWaWtxhetoCCCUMAAIcDV1Xv3izTW6f9Ya1dR79d3RXTX91lFqlxjrOhq+hTFGN4/trlk/HKOubVpo/c4yXfrYQr29cofraAgQlDAACGC7y6p07T+XaPrSPEVHevTwVYP0m8kDuP4riAzs1FLv3j1OlwzqoMqaev3otZW6f+ZqHaxhejLc8a8YAALUstx9uuSxhVqWW6IOLWP179tG6+rMzq5j4RQkxkbpsWuH6ndXDFRMpEczsvI1+fGF2lxU7joaHKKEAUAAemVpnq55eol2l1drRPfWeufucRrcuZXrWDgNxhhdN7KL3rpzrHqmxGtzUYUu+/tCvf5lvoJh4XQ0PUoYAASQ6rp6/fyN1XrwzTWqrbe6eUw3Tb91pNomxLiOhibSt0OSZt81Tlee0UlVtV7dN2u1fjxjJXdPhqFI1wEAAD5FZVW6/eVlWpG3X9GRHv3uioG6algn17HQDOJjIvXnqwdrdM82+tVba/XWykLtqajRMzdlKjYqwnU8+AkjYQAQALK2+67/WpG3Xx1bxmrW7WMoYGHgqmGd9NadY9U2IUYLc/bo1heyGBELI05KmDHmx8aYdcaYtcaYV40x3GcNICxZa/XSklxd8/QSFZdXa1QP3/VfAzu1dB0NftK7faJe/cFIilgY8nsJM8akSb
pHUqa1doCkCEnX+DsHALhmrdVv3l2vX721VnVeq1vGddfLt4xUG67/CjsZqRSxcORqOjJSUpwxJlJSC0mFjnIAgDM
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X2, y, xlabel='x', ylabel='y')\n",
"theta_start = np.matrix([0, 0, 0]).reshape(3, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X2, y, eps=0.000001)\n",
"plot_fun(fig, polynomial_regression(theta), X1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"This model fits the data appropriately."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x27e250169d0>]"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmEAAAFoCAYAAAAfEiweAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXxU1eH///fJvhMgCzsBEkABFUFEBeqCG221tSK471pbq1Zr66efbp/+PvXbRWu1WvtRUNGqgFatVeuKG4LsyCokQIAQyAYkkz0zc35/TLBUWRLIzJnl9Xw8fCSZmcy8vY/L5c2595xrrLUCAABAaMW5DgAAABCLKGEAAAAOUMIAAAAcoIQBAAA4QAkDAABwgBIGAADgQNBKmDHmCWNMpTFmzX6P9TDGvGOMKW7/2j1Ynw8AABDOgjkS9pSk87702D2S3rPWFkl6r/1nAACAmGOCuVirMaZA0mvW2pHtP2+QdLq1dqcxprekD6y1w4IWAAAAIEyF+pqwfGvtTklq/5oX4s8HAAAICwmuAxyMMeYmSTdJUnp6+pjhw4c7TgQACFdev9WGXR75rdWgnHRlJIftX2+IMsuWLau21uYeye+Gei+tMMb03u90ZOXBXmitfUzSY5I0duxYu3Tp0lBlBABEmP9+ebX2Ltqms4bnaeY1J7mOgxhijNl6pL8b6tORr0q6uv37qyX9I8SfDwCIMiWV9Zq9ZLvijHTP+Zw1QeQI5hIVz0taKGmYMabMGHO9pN9KOtsYUyzp7PafAQA4Yr/91+fy+a2mjxugovxM13GADgva6Uhr7aUHeeqsYH0mACC2fLq5Ru+ur1BaUrzumFzkOg7QKayYDwCISH6/1b1vrJck3TxpiPIyUxwnAjqHEgYAiEj/XFWuVWW1ystM1o2TBrmOA3QaJQwAEHGa23z6/ZsbJEl3nTNUaUksSYHIQwkDAEScpxeWasfeJg3Lz9TFY/q7jgMcEUoYACCi7G1s1cPzSiRJ90wZrvg44zgRcGQoYQCAiPLneSWqa/ZqQmGOTh96RAuVA2GBEgYAiBhbaxr09MJSGSP915ThMoZRMEQuShgAIGL8/q0NavNZfXt0X43o0811HOCoUMIAABFh+bY9en3VTiUnxOlH5wxzHQc4apQwAEDYs9bq3tcDC7NeP2GQ+mSnOk4EHD1KGAAg7L21tkJLt+5Rj/Qkfff0Ia7jAF2CEgYACGsLSqr1q1fXSpLumFykrJREx4mArsESwwCAsLS3sVW/eX29XlhWJkk6oX+2Lh03wHEqoOtQwgAAYcVaq3+u2qlf/3OtqutblRQfpx+cWaibvzZEifGcwEH0oIQBAMLGjr1N+vkrazTv80pJ0rhBPfT/LhqlIbkZjpMBXY8SBgBwzue3enphqf7w1gY1tvqUmZKgn045RtPG9lcctyVClKKEAQCc+nxXne75+2qt3L5XknT+yF76nwtGKC8rxXEyILgoYQAAJ5rbfHp4Xon++uEmef1W+VnJ+v8uHKlzRvRyHQ0ICUoYACDkPt1co5++tFqbqxskSVeMH6Afnzec5ScQUyhhAIDg8HikOXOk4mKpqEiaNk21CSn67b/W6/nF2yVJhXkZ+u1FozS2oIfjsEDoUcIAAF1v/nxpyhTJ75caGmTT0/Wvh2frlxfeqaoWq8R4o++fUahbTh+i5IR412kBJyhhAICu5fEECpjHI0naldFTPz/7u3pn6ClSi9WYfln67dQTVJSf6Tgo4BYlDADQtebMkd9vVZ6Vq3cLT9Z9k65SfXKaMloa9ZOFz+nyG76huPyJrlMCzlHCAABHrMXr05bqBm2qbNCmqnqVVNZrU3GGNt/0lJqS/r3ExNkbF+rX7/5VvT010hnDHSYGwgclDABwWLWNbSqp8mhTZYNKquq1qbJeJVX12r67UX77pRebTClJyq3frcKa7bpq+es6b+MCGUlKT5cKCx38HwDhhx
IGALHoADMXbUaGKj0t2rDLo5L2krWpsl6bqupVXd96wLeJM9KgnHQNyU3XkLwMDcnNUGFGnIZMHKtu1bsO8Atx0rRpQf6fAyIDJQwAYoz9+GNVX3yZNnbvq40Zedq4tlXFHz+mjQOGq67twL+TmhivIXnpgZKVm6EheRkqzMvQwJ5pB57d+PIL/zE7UunpgQL2xhtSBveBBCRKGABEtd0NrdpY4fn3f+W1Kt6wQ3uu/etXX9wmdUtJ0LBeWSrMD5StwrxA4eqdldK5ezhOmCCVlwdG20pKAqcgp02jgAH7oYQBQBSw1mr1jlqtKqtVcYVHGyvqVVzpOfBpxNRMZbY0qKh6m4ZVbVVR9TYNrd6qoY3Vyr33f2RuuL5rQmVkSNd30XsBUYgSBgARzFqrj4qr9dB7xVq2dc9Xnk9PildhfqaG5mVoWK9MFb02V0MfvU+9PDU64LjWppKgZwYQQAkDgAhkrdX7Gyr14Hsl+mz7XklSdlqizhyep6H5mRqWn6mi/Az16Zb6n6cRN+RI/uYDvykzF4GQooQBQASx1uqddRV6aF6x1uyokyT1TE/SjZMG64rxA5WRfJjD+rRp0p13Hvg5Zi4CIUUJA4AI4Pdbvbl2l/48r0TrdwbKV05Gsr77tcG67OQBSkvq4OE8MzMwQ5GZi4BzlDAACGM+v9Xrq3fq4XnF2lhRL0nKz0rWd782RJeOG6CUxCO4+TUzF4GwQAkDgDDk9fn1z1Xl+vO8Em2uapAk9emWolvOKNTUMf2OrHztj5mLgHOUMAAII20+v15ZsUOPvF+i0ppGSVK/7qn6/hmF+s6J/ZSUEOc4IYCuQgkDgDDQ6vXrpeVleuSDEm3f3SRJGtgzTd8/o1DfHt1XifGULyDaUMIAwKEWr08vLC3Tox9s0o69gfI1OCddt55ZqAuO76MEyhcQtShhAOBIXXObrntyiZa2L7JamJehH5xZqG8c10fxnblFEICIRAkDAAdqG9t01ZOL9dn2verdLUU/+/qxOn9kr87dnxFARKOEAUCI7Wlo1RUzF2lteZ36dU/V8zeOV/8eaa5jAQgxShgAhFB1fYuumLFIn+/yaGDPND1/43j1yU51HQuAA5QwAAiRyrpmXT5jkYor6zU4N13P3zhe+VkprmMBcIQSBgAhsKu2WZc9/qk2VzdoaH6Gnr1hvHIzk13HAuAQJQwAgqxsT6Mue3yRtu1u1DG9s/S368epZwYFDIh1lDAACKJtNY269PFPtWNvk0b17aZnrh+n7LQk17EAhAFKGAAEyZbqBl32+KfaWdus0QOy9dS149QtNdF1LABhghIGAEFQUunRZY8vUqWnRScVdNcT15ykzBQKGIB/c3I/DGPMD40xa40xa4wxzxtjmB4EIGps2OXR9Mc+VaWnRacM7qmnrh1HAQPwFSEvYcaYvpJukzTWWjtSUryk6aHOAQDBsLa8VtMfW6jq+lZNLMrRE9ecpPRkTjoA+CpXR4YESanGmDZJaZLKHeUAgC6zqmyvrpy5WLVNbTpzeJ7+cvmJSkmMdx0LQJgK+UiYtXaHpPskbZO0U1KttfbtL7/OGHOTMWapMWZpVVVVqGMCQKcs27pHlz++SLVNbTrn2Hz99YoxFDAAh+TidGR3SRdKGiSpj6R0Y8wVX36dtfYxa+1Ya+3Y3NzcUMcEgA5bvGW3rpq5SJ4Wr74+qrceufxEJSU4ueQWQARxcZSYLGmLtbbKWtsm6SVJpzrIAQBHbUFJta5+YrEaWn361gl99OD0E5QYTwEDcHgujhTbJI03xqQZY4yksyStd5ADAI7KRxurdO1TS9TU5tPFY/rp/ktOUAIFDEAHubgmbJGkFyUtl7S6PcNjoc4BAEdj3ucVumHWUrV4/bp03AD9/jvHKT7OuI4FIII4mR1prf2lpF+6+GwAOFrryuv03WeWq9Xn19WnDNSvLhihwMA+AHQc4+YA0AmtXr9+9MJnavX5NW1sfwoYgCNGCQOATnjk/RKt21
mnAT3S9ItvHksBA3DEKGEA0EFrdtTqkfdLJEm/v/g4VsIHcFQoYQDQAftOQ3r9VtecWqDxg3u6jgQgwlHCAKAD/jy
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X5, y, xlabel='x', ylabel='y')\n",
"theta_start = np.matrix([0, 0, 0, 0, 0, 0]).reshape(6, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X5, y, alpha=0.5, eps=10**-7)\n",
"plot_fun(fig, polynomial_regression(theta), X1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"This model has a high **variance**, i.e. it **overfits** the data (_overfitting_)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"(Note the odd shape of the curve on the left-hand side of the plot; among other things, it is an effect of overfitting.)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Overfitting occurs when the model has too many degrees of freedom relative to the amount of training data.\n",
"\n",
"It is an undesirable phenomenon.\n",
"\n",
"Figuratively speaking, overfitting occurs when the model starts to model the noise in the data instead of its “main trend”."
]
},
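{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A quick numerical check of this statement on the six points from this section, using NumPy's least-squares fit `np.polyfit` instead of the notebook's gradient descent: the training error cannot increase as the degree grows, and a degree-5 polynomial (6 parameters for 6 points) passes through every point, so its training error is practically zero. A near-zero training error is therefore no evidence that the model is good."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# The six (x, y) points used in this section.\n",
"x = np.array([0.0, 0.5, 1.0, 1.6, 2.6, 3.0])\n",
"y = np.array([0.0, 1.8, 4.8, 7.2, 8.8, 9.0])\n",
"\n",
"def training_mse(degree):\n",
"    # Least-squares polynomial fit, evaluated on the same (training) points.\n",
"    coeffs = np.polyfit(x, y, degree)\n",
"    residuals = np.polyval(coeffs, x) - y\n",
"    return np.mean(residuals ** 2)\n",
"\n",
"for degree in (1, 2, 5):\n",
"    print(degree, training_mse(degree))"
]
},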
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"See also: https://pl.wikipedia.org/wiki/Nadmierne_dopasowanie"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"90%\" src=\"fit.png\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Bias (systematic error)\n",
"\n",
"* Results from erroneous assumptions made by the learning algorithm.\n",
"* High bias causes underfitting."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Variance\n",
"\n",
"* Results from oversensitivity to small fluctuations in the training set.\n",
"* High variance can cause overfitting (the model captures noise instead of the signal)."
]
},
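{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Bias and variance can also be estimated by simulation. In the sketch below the “true” function $f(x) = 10(1 - e^{-x})$, the noise level and the number of trials are illustrative assumptions, not the data from this notebook. We repeatedly draw a new noisy training set at the same six $x$ positions, fit a degree-1 and a degree-5 polynomial with `np.polyfit`, and measure on a test grid how far the average prediction is from $f$ (squared bias) and how much the predictions scatter between training sets (variance). The linear model should come out with the higher bias, the degree-5 model with the higher variance."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"def f(x):\n",
"    # Assumed true function (illustrative only).\n",
"    return 10 * (1 - np.exp(-x))\n",
"\n",
"x_train = np.linspace(0, 3, 6)   # six fixed training inputs\n",
"x_test = np.linspace(0, 3, 50)   # where bias and variance are measured\n",
"noise_std, n_trials = 0.5, 500\n",
"\n",
"def bias_variance(degree):\n",
"    # Predictions on x_test for many independently drawn noisy training sets.\n",
"    preds = np.empty((n_trials, x_test.size))\n",
"    for t in range(n_trials):\n",
"        y_train = f(x_train) + rng.normal(0, noise_std, x_train.size)\n",
"        preds[t] = np.polyval(np.polyfit(x_train, y_train, degree), x_test)\n",
"    mean_pred = preds.mean(axis=0)\n",
"    bias_sq = np.mean((mean_pred - f(x_test)) ** 2)  # averaged over x_test\n",
"    variance = np.mean(preds.var(axis=0))             # averaged over x_test\n",
"    return bias_sq, variance\n",
"\n",
"for degree in (1, 5):\n",
"    print(degree, bias_variance(degree))"
]
},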
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"60%\" src=\"bias2.png\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"60%\" src=\"curves.jpg\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 5.3. Regularization"
]
},
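{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"As a preview of this section: a common form of regularization (L2, also called ridge) adds the penalty $\\frac{\\lambda}{2m}\\sum_{j=1}^{n}\\theta_j^2$ to the cost function, conventionally leaving the bias weight $\\theta_0$ unpenalized. The sketch below shows this for logistic regression with plain NumPy arrays. The names `reg_cost` and `reg_gradient` are hypothetical; the exact form used later in this notebook (where `SGD` passes the coefficient to the gradient as `lamb`) may differ."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def sigmoid(z):\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"def reg_cost(theta, X, y, lamb=0.0):\n",
"    # Cross-entropy cost of logistic regression plus an L2 penalty;\n",
"    # theta[0] (the bias weight) is not penalized.\n",
"    m = len(y)\n",
"    p = sigmoid(X @ theta)\n",
"    cross_entropy = -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))\n",
"    penalty = lamb / (2 * m) * np.sum(theta[1:] ** 2)\n",
"    return cross_entropy + penalty\n",
"\n",
"def reg_gradient(theta, X, y, lamb=0.0):\n",
"    # Gradient of reg_cost with respect to theta.\n",
"    m = len(y)\n",
"    grad = X.T @ (sigmoid(X @ theta) - y) / m\n",
"    grad[1:] += lamb / m * theta[1:]\n",
"    return grad\n",
"\n",
"# At theta = 0 every prediction is 0.5, so the cost equals log(2) and the penalty is 0.\n",
"X_demo = np.array([[1.0, 0.5], [1.0, -1.0], [1.0, 2.0]])\n",
"y_demo = np.array([1.0, 0.0, 1.0])\n",
"print(reg_cost(np.zeros(2), X_demo, y_demo, lamb=1.0))"
]
},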
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def SGD(h, fJ, fdJ, theta, X, Y, \n",
" alpha=0.001, maxEpochs=1.0, batchSize=100, \n",
" adaGrad=False, logError=False, validate=0.0, valStep=100, lamb=0, trainsetsize=1.0):\n",
"    \"\"\"Stochastic Gradient Descent: the stochastic version of gradient descent\n",
"    (covered in more detail in lecture 11)\n",
" \"\"\"\n",
" errorsX, errorsY = [], []\n",
" errorsVX, errorsVY = [], []\n",
" \n",
" XT, YT = X, Y\n",
" \n",
" m_end=int(trainsetsize*len(X))\n",
" \n",
" if validate > 0:\n",
" mv = int(X.shape[0] * validate)\n",
" XV, YV = X[:mv], Y[:mv] \n",
" XT, YT = X[mv:m_end], Y[mv:m_end] \n",
" m, n = XT.shape\n",
"\n",
" start, end = 0, batchSize\n",
" maxSteps = (m * float(maxEpochs)) / batchSize\n",
" \n",
" if adaGrad:\n",
" hgrad = np.matrix(np.zeros(n)).reshape(n,1)\n",
" \n",
" for i in range(int(maxSteps)):\n",
" XBatch, YBatch = XT[start:end,:], YT[start:end,:]\n",
"\n",
" grad = fdJ(h, theta, XBatch, YBatch, lamb=lamb)\n",
" if adaGrad:\n",
" hgrad += np.multiply(grad, grad)\n",
" Gt = 1.0 / (10**-7 + np.sqrt(hgrad))\n",
" theta = theta - np.multiply(alpha * Gt, grad)\n",
" else:\n",
" theta = theta - alpha * grad\n",
" \n",
" if logError:\n",
" errorsX.append(float(i*batchSize)/m)\n",
" errorsY.append(fJ(h, theta, XBatch, YBatch).item())\n",
" if validate > 0 and i % valStep == 0:\n",
" errorsVX.append(float(i*batchSize)/m)\n",
" errorsVY.append(fJ(h, theta, XV, YV).item())\n",
" \n",
" if start + batchSize < m:\n",
" start += batchSize\n",
" else:\n",
" start = 0\n",
" end = min(start + batchSize, m)\n",
" return theta, (errorsX, errorsY, errorsVX, errorsVY)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Prepare the data for the regularization example\n",
"\n",
"n = 6  # degree of the polynomial features built by powerme\n",
"\n",
"data = np.matrix(np.loadtxt(\"ex2data2.txt\", delimiter=\",\"))\n",
"np.random.shuffle(data)\n",
"\n",
"X = powerme(data[:,0], data[:,1], n)\n",
"Y = data[:,2]"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def draw_regularization_example(X, Y, lamb=0, alpha=1, adaGrad=True, maxEpochs=2500, validate=0.25):\n",
"    \"\"\"Draws a regularization example: decision boundary (left) and error curves (right)\"\"\"\n",
"    plt.figure(figsize=(16,8))\n",
"    plt.subplot(121)\n",
"    plt.scatter(X[:, 2].tolist(), X[:, 1].tolist(),\n",
"                c=Y.tolist(),\n",
"                s=100, cmap='prism');\n",
"\n",
"    theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1],1)\n",
"    thetaBest, err = SGD(h, J, dJ, theta, X, Y, alpha=alpha, adaGrad=adaGrad, maxEpochs=maxEpochs, batchSize=100,\n",
"                         logError=True, validate=validate, valStep=1, lamb=lamb)\n",
"\n",
"    # Evaluate the classifier on a grid to draw the decision boundary\n",
"    xx, yy = np.meshgrid(np.arange(-1.5, 1.5, 0.02),\n",
"                         np.arange(-1.5, 1.5, 0.02))\n",
"    l = len(xx.ravel())\n",
"    C = powerme(xx.reshape(l, 1),yy.reshape(l, 1), n)\n",
"    z = classifyBi(thetaBest, C).reshape(int(np.sqrt(l)), int(np.sqrt(l)))\n",
"\n",
"    plt.contour(xx, yy, z, levels=[0.5], linewidths=3);\n",
"    plt.ylim(-1,1.2);\n",
"    plt.xlim(-1,1.2);\n",
"    plt.subplot(122)\n",
"    plt.plot(err[0],err[1], lw=3, label=\"Training error\")\n",
"    if validate > 0:\n",
"        plt.plot(err[2],err[3], lw=3, label=\"Validation error\");\n",
"    plt.legend()\n",
"    plt.ylim(0.2,0.8);"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"<ipython-input-17-7e1d5e247279>:5: RuntimeWarning: overflow encountered in exp\n",
" y = 1.0/(1.0 + np.exp(-x))\n",
"<ipython-input-31-f0220c89a5e3>:19: UserWarning: The following kwargs were not used by contour: 'lw'\n",
" plt.contour(xx, yy, z, levels=[0.5], lw=3);\n",
"No handles with labels found to put in legend.\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 1152x576 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"draw_regularization_example(X, Y)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Regularization\n",
"\n",
"Regularization is a method of preventing overfitting by suitably modifying the cost function.\n",
"\n",
"A special term is added to the cost function (the **regularization term**, marked in red in the formulas below) that acts as a \"penalty\" for extreme values of the parameters $\\theta$.\n",
"\n",
"In this way, vectors $\\theta$ with smaller parameter values are preferred, because they automatically have a lower cost.\n",
"\n",
"How strong should the regularization be? We control this by choosing an appropriate value of the **regularization parameter** $\\lambda$."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for linear regression – cost function\n",
"\n",
"$$\n",
"J(\\theta) \\, = \\, \\dfrac{1}{2m} \\left( \\displaystyle\\sum_{i=1}^{m} \\left( h_\\theta(x^{(i)}) - y^{(i)} \\right)^2 \\color{red}{ + \\lambda \\displaystyle\\sum_{j=1}^{n} \\theta^2_j } \\right)\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"* $\\lambda$ – the regularization parameter\n",
"* if $\\lambda$ is too small, the model overfits\n",
"* if $\\lambda$ is too large, the model underfits"
]
},
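{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A minimal sketch of how $\\lambda$ affects the solution, assuming plain NumPy instead of the SGD routine used in this lecture. For linear regression, setting the gradient of the regularized cost to zero gives the closed-form solution $\\theta = (X^T X + \\lambda L)^{-1} X^T y$, where $L$ is the identity matrix with $L_{00} = 0$, so that $\\theta_0$ is not penalized. The data and the helper `ridge_closed_form` are hypothetical. As $\\lambda$ grows, the norm of $(\\theta_1, \\ldots, \\theta_n)$ should shrink and the training cost should rise."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def ridge_closed_form(X, y, lamb):\n",
"    \"\"\"Exact minimizer of the regularized linear-regression cost.\n",
"\n",
"    X must contain a leading column of ones; theta_0 is not penalized.\n",
"    \"\"\"\n",
"    L = np.eye(X.shape[1])\n",
"    L[0, 0] = 0.0  # no penalty on the bias term theta_0\n",
"    return np.linalg.solve(X.T @ X + lamb * L, X.T @ y)\n",
"\n",
"# Hypothetical data: degree-6 polynomial features of a noisy sine\n",
"rng = np.random.default_rng(1)\n",
"x = rng.uniform(-1, 1, 30)\n",
"y = np.sin(3 * x) + rng.normal(0, 0.2, x.size)\n",
"X = np.vander(x, 7, increasing=True)  # columns: 1, x, x^2, ..., x^6\n",
"\n",
"for lamb in (0.0, 0.1, 1.0, 10.0):\n",
"    theta = ridge_closed_form(X, y, lamb)\n",
"    train_cost = np.mean((X @ theta - y) ** 2) / 2  # data term only, without the penalty\n",
"    print(f'lambda={lamb:5.1f}  norm of theta[1:]={np.linalg.norm(theta[1:]):.3f}  training cost={train_cost:.4f}')"
]
},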
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for linear regression – gradient\n",
"\n",
"$$\\small\n",
"\\begin{array}{llll}\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_0} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_0 & \\textrm{for $j = 0$ }\\\\\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_j} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_j \\color{red}{+ \\dfrac{\\lambda}{m}\\theta_j} & \\textrm{for $j = 1, 2, \\ldots, n $} \\\\\n",
"\\end{array} \n",
"$$"
]
},
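{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The gradient formulas above can be checked numerically. A minimal sketch in plain NumPy (the helpers `cost_reg`, `grad_reg` and `numerical_grad` and the random data are illustrative assumptions, not part of the lecture code) compares the analytic gradient, including the extra $\\frac{\\lambda}{m}\\theta_j$ term for $j \\geq 1$, with central finite differences of the regularized mean-squared-error cost $J(\\theta) = \\frac{1}{2m}\\left(\\sum_i (h_\\theta(x^{(i)}) - y^{(i)})^2 + \\lambda \\sum_{j \\geq 1} \\theta_j^2\\right)$. The printed maximum difference should be on the order of round-off error."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def cost_reg(theta, X, y, lamb):\n",
"    \"\"\"Regularized mean-squared-error cost; theta_0 is not penalized.\"\"\"\n",
"    m = len(y)\n",
"    r = X @ theta - y\n",
"    return (r @ r + lamb * theta[1:] @ theta[1:]) / (2 * m)\n",
"\n",
"def grad_reg(theta, X, y, lamb):\n",
"    \"\"\"Analytic gradient: data term plus (lambda/m)*theta_j for j >= 1.\"\"\"\n",
"    m = len(y)\n",
"    g = X.T @ (X @ theta - y) / m\n",
"    g[1:] += lamb / m * theta[1:]\n",
"    return g\n",
"\n",
"def numerical_grad(f, theta, eps=1e-6):\n",
"    \"\"\"Central finite differences, one coordinate at a time.\"\"\"\n",
"    g = np.zeros_like(theta)\n",
"    for j in range(theta.size):\n",
"        e = np.zeros_like(theta)\n",
"        e[j] = eps\n",
"        g[j] = (f(theta + e) - f(theta - e)) / (2 * eps)\n",
"    return g\n",
"\n",
"# Hypothetical random data; the first column of X is the constant feature x_0 = 1\n",
"rng = np.random.default_rng(2)\n",
"X = np.column_stack([np.ones(20), rng.normal(size=(20, 3))])\n",
"y = rng.normal(size=20)\n",
"theta = rng.normal(size=4)\n",
"lamb = 0.5\n",
"\n",
"analytic = grad_reg(theta, X, y, lamb)\n",
"numeric = numerical_grad(lambda t: cost_reg(t, X, y, lamb), theta)\n",
"print(np.max(np.abs(analytic - numeric)))  # expected to be tiny (round-off only)"
]
},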
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for logistic regression – cost function\n",
"\n",
"$$\n",
"\\begin{array}{rcl}\n",
"J(\\theta) & = & -\\dfrac{1}{m} \\left( \\displaystyle\\sum_{i=1}^{m} y^{(i)} \\log h_\\theta(x^{(i)}) + \\left( 1-y^{(i)} \\right) \\log \\left( 1-h_\\theta(x^{(i)}) \\right) \\right) \\\\\n",
"& & \\color{red}{ + \\dfrac{\\lambda}{2m} \\displaystyle\\sum_{j=1}^{n} \\theta^2_j } \\\\\n",
"\\end{array}\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for logistic regression – gradient\n",
"\n",
"$$\\small\n",
"\\begin{array}{llll}\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_0} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_0 & \\textrm{for $j = 0$ }\\\\\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_j} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_j \\color{red}{+ \\dfrac{\\lambda}{m}\\theta_j} & \\textrm{for $j = 1, 2, \\ldots, n $} \\\\\n",
"\\end{array} \n",
"$$"
]
},
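{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"For logistic regression, the gradient has exactly the same form as for linear regression; only $h_\\theta$ is now the sigmoid function. The following self-contained sketch checks this with central finite differences on hypothetical random data. The helpers `sigmoid`, `logistic_cost_reg` and `logistic_grad_reg` are illustrative and independent of the implementation used elsewhere in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def sigmoid(z):\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"def logistic_cost_reg(theta, X, y, lamb):\n",
"    \"\"\"Regularized cross-entropy cost; theta_0 is not penalized.\"\"\"\n",
"    m = len(y)\n",
"    p = sigmoid(X @ theta)\n",
"    ce = -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p)) / m\n",
"    return ce + lamb / (2 * m) * np.sum(theta[1:] ** 2)\n",
"\n",
"def logistic_grad_reg(theta, X, y, lamb):\n",
"    \"\"\"Same formula as for linear regression, with h = sigmoid(X @ theta).\"\"\"\n",
"    m = len(y)\n",
"    g = X.T @ (sigmoid(X @ theta) - y) / m\n",
"    g[1:] += lamb / m * theta[1:]\n",
"    return g\n",
"\n",
"# Hypothetical random data with a constant feature x_0 = 1\n",
"rng = np.random.default_rng(3)\n",
"X = np.column_stack([np.ones(30), rng.normal(size=(30, 2))])\n",
"y = (rng.uniform(size=30) < 0.5).astype(float)\n",
"theta = 0.1 * rng.normal(size=3)\n",
"lamb = 1.0\n",
"\n",
"# Central finite differences along each coordinate direction e_j\n",
"eps = 1e-6\n",
"numeric = np.array([\n",
"    (logistic_cost_reg(theta + eps * e, X, y, lamb)\n",
"     - logistic_cost_reg(theta - eps * e, X, y, lamb)) / (2 * eps)\n",
"    for e in np.eye(3)\n",
"])\n",
"print(np.max(np.abs(logistic_grad_reg(theta, X, y, lamb) - numeric)))  # expected to be tiny"
]
},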
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Implementing regularization"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"def J_(h,theta,X,y,lamb=0):\n",
"    \"\"\"Cost function with regularization\"\"\"\n",
"    m = float(len(y))\n",
"    f = h(theta, X, eps=10**-7)\n",
"    j = 1.0/m \\\n",
"        * -np.sum(np.multiply(y, np.log(f)) + \n",
"                  np.multiply(1 - y, np.log(1 - f)), axis=0) \\\n",
"        + lamb/(2*m) * np.sum(np.power(theta[1:] ,2))  # theta_0 is not regularized\n",
"    return j\n",
"\n",
"def dJ_(h,theta,X,y,lamb=0):\n",
"    \"\"\"Gradient of the cost function with regularization\"\"\"\n",
"    m = float(y.shape[0])\n",
"    g = 1.0/m * (X.T * (h(theta,X) - y))\n",
"    g[1:] += lamb/m * theta[1:]  # theta_0 is not regularized\n",
"    return g"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"slider_lambda = widgets.FloatSlider(min=0.0, max=0.5, step=0.005, value=0.01, description=r'$\\lambda$', layout=widgets.Layout(width='300px'))\n",
"\n",
"def slide_regularization_example_2(lamb):\n",
"    draw_regularization_example(X, Y, lamb=lamb)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "187489841a6b4b6a8fac8d4e5f2e28de",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=0.01, description='$\\\\lambda$', max=0.5, step=0.005), Button(descripti…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.slide_regularization_example_2(lamb)>"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact_manual(slide_regularization_example_2, lamb=slider_lambda)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def cost_lambda_fun(lamb):\n",
"    \"\"\"Cost as a function of the regularization parameter lambda\"\"\"\n",
"    theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1],1)\n",
"    thetaBest, err = SGD(h, J, dJ, theta, X, Y, alpha=1, adaGrad=True, maxEpochs=2500, batchSize=100,\n",
"                         logError=True, validate=0.25, valStep=1, lamb=lamb)\n",
"    # Last logged training (mini-batch) cost and last validation cost\n",
"    return err[1][-1], err[3][-1]\n",
"\n",
"def plot_cost_lambda():\n",
"    \"\"\"Plots the cost as a function of the regularization parameter lambda\"\"\"\n",
"    plt.figure(figsize=(16,8))\n",
"    ax = plt.subplot(111)\n",
"    Lambda = np.arange(0.0, 1.0, 0.01)\n",
"    Costs = [cost_lambda_fun(lamb) for lamb in Lambda]\n",
"    CostTrain = [cost[0] for cost in Costs]\n",
"    CostCV = [cost[1] for cost in Costs]\n",
"    plt.plot(Lambda, CostTrain, lw=3, label='training error')\n",
"    plt.plot(Lambda, CostCV, lw=3, label='validation error')\n",
"    ax.set_xlabel(r'$\\lambda$')\n",
"    ax.set_ylabel(u'cost')\n",
"    plt.legend()\n",
"    plt.ylim(0.2,0.8)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1152x576 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_cost_lambda()"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 5.4. Learning curve"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"* A learning curve lets us check whether training is proceeding correctly.\n",
"* A learning curve plots the value of the cost function against the size of the training set.\n",
"* As the training set grows, the cost on the training set typically increases (it becomes harder to fit every example).\n",
"* As the training set grows, the cost on the validation set typically decreases (the model generalizes better)."
]
},
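{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A minimal learning-curve sketch. It assumes hypothetical data (a noisy quadratic) and fits a degree-2 polynomial by ordinary least squares with `np.polyfit` instead of using the SGD routine from this notebook; the helper `half_mse` is illustrative. The model is trained on the first $m$ examples only, and its cost is then measured on those $m$ examples and on a fixed validation set. Typically, the training cost grows and the validation cost falls as $m$ increases."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical data: a noisy quadratic, split into training and validation parts\n",
"rng = np.random.default_rng(4)\n",
"x = rng.uniform(-3, 3, 200)\n",
"y = 0.5 * x ** 2 - x + rng.normal(0, 1.0, x.size)\n",
"x_train, y_train = x[:150], y[:150]\n",
"x_val, y_val = x[150:], y[150:]\n",
"\n",
"def half_mse(coeffs, x, y):\n",
"    \"\"\"Cost J = 1/(2m) * sum of squared errors of the polynomial with the given coefficients.\"\"\"\n",
"    return np.mean((np.polyval(coeffs, x) - y) ** 2) / 2\n",
"\n",
"for m in (5, 10, 20, 50, 100, 150):\n",
"    # Fit a degree-2 polynomial on the first m training examples only\n",
"    coeffs = np.polyfit(x_train[:m], y_train[:m], 2)\n",
"    print(f'm={m:3d}  training cost={half_mse(coeffs, x_train[:m], y_train[:m]):.3f}  '\n",
"          f'validation cost={half_mse(coeffs, x_val, y_val):.3f}')"
]
},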
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def cost_trainsetsize_fun(m):\n",
"    \"\"\"Cost as a function of the training set size (m is the fraction passed to SGD as trainsetsize)\"\"\"\n",
"    theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1],1)\n",
"    thetaBest, err = SGD(h, J, dJ, theta, X, Y, alpha=1, adaGrad=True, maxEpochs=2500, batchSize=100,\n",
"                         logError=True, validate=0.25, valStep=1, lamb=0.01, trainsetsize=m)\n",
"    # Last logged training (mini-batch) cost and last validation cost\n",
"    return err[1][-1], err[3][-1]\n",
"\n",
"def plot_learning_curve():\n",
"    \"\"\"Plots the learning curve\"\"\"\n",
"    plt.figure(figsize=(16,8))\n",
"    ax = plt.subplot(111)\n",
"    M = np.arange(0.3, 1.0, 0.05)\n",
"    Costs = [cost_trainsetsize_fun(m) for m in M]\n",
"    CostTrain = [cost[0] for cost in Costs]\n",
"    CostCV = [cost[1] for cost in Costs]\n",
"    plt.plot(M, CostTrain, lw=3, label='training error')\n",
"    plt.plot(M, CostCV, lw=3, label='validation error')\n",
"    ax.set_xlabel(u'trainset size')\n",
"    ax.set_ylabel(u'cost')\n",
"    plt.legend()"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Learning curves vs. bias and variance\n",
"\n",
"Plotting the learning curve helps diagnose overfitting and underfitting:\n",
"\n",
"<img width=\"100%\" src=\"learning-curves.png\"/>\n",
"\n",
"Source: http://www.ritchieng.com/machinelearning-learning-curve"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1152x576 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_learning_curve()"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
},
"livereveal": {
"start_slideshow_at": "selected",
"theme": "white"
}
},
"nbformat": 4,
"nbformat_minor": 4
}