umz21/wyk/05_Regresja_wielomianowa.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Machine Learning: Applications\n",
"# 5. Polynomial Regression. The Overfitting Problem"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 5.1. Polynomial regression"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Introduction: choosing features"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Suppose our task is to predict the price of a rectangular plot of land.\n",
"\n",
"Which features should we choose?"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"We can choose two features:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * $x_1$ = plot width, $x_2$ = plot length:\n",
"$$ h_{\\theta}(\\vec{x}) = \\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"...or just one:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * $x_1$ = plot area:\n",
"$$ h_{\\theta}(\\vec{x}) = \\theta_0 + \\theta_1 x_1 $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Note also that the feature “plot area” is obtained by multiplying two other features: the length of the plot and its width."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Conclusion:** we can create new features from existing ones by applying various mathematical operations to them."
]
},
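{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of this idea with NumPy, using hypothetical plot dimensions (not data from this lecture): the derived feature `area` is the element-wise product of the two original features, and each of the two hypotheses above corresponds to its own design matrix, with a leading column of ones for $\\theta_0$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical widths and lengths of three rectangular plots (in metres)\n",
"width = np.array([20.0, 15.0, 30.0])\n",
"length = np.array([40.0, 30.0, 25.0])\n",
"\n",
"# Derived feature: area = width * length (element-wise product)\n",
"area = width * length\n",
"\n",
"# Design matrices: [1, width, length] for the two-feature hypothesis,\n",
"# [1, area] for the one-feature hypothesis\n",
"X_two_features = np.column_stack([np.ones_like(width), width, length])\n",
"X_area = np.column_stack([np.ones_like(area), area])\n",
"\n",
"print(area)\n",
"print(X_area)"
]
},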
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Polynomial regression"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In polynomial regression, we use features created as powers of the original features."
]
},
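{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch with hypothetical input values: for a single input $x$, the degree-3 polynomial features are the columns $x^0 = 1$, $x$, $x^2$, $x^3$. NumPy's `np.vander` with `increasing=True` builds the same matrix."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical values of a single input feature\n",
"x_demo = np.array([1.0, 2.0, 3.0])\n",
"\n",
"# Columns x**0, x**1, x**2, x**3 built explicitly\n",
"X_manual = np.column_stack([x_demo**i for i in range(4)])\n",
"# The same matrix as a Vandermonde matrix with increasing powers\n",
"X_vander = np.vander(x_demo, N=4, increasing=True)\n",
"\n",
"print(X_manual)"
]
},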
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Useful imports\n",
"\n",
"import ipywidgets as widgets\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Helper functions\n",
"\n",
"def cost(theta, X, y):\n",
"    \"\"\"Cost function (matrix version)\"\"\"\n",
"    m = len(y)\n",
"    J = 1.0 / (2.0 * m) * ((X * theta - y).T * (X * theta - y))\n",
"    return J.item()\n",
"\n",
"def gradient(theta, X, y):\n",
"    \"\"\"Gradient of the cost function (matrix version)\"\"\"\n",
"    return 1.0 / len(y) * (X.T * (X * theta - y))\n",
"\n",
"def gradient_descent(fJ, fdJ, theta, X, y, alpha=0.1, eps=10**-5):\n",
"    \"\"\"Gradient descent algorithm (matrix version)\"\"\"\n",
"    current_cost = fJ(theta, X, y)\n",
"    logs = [[current_cost, theta]]\n",
"    while True:\n",
"        theta = theta - alpha * fdJ(theta, X, y)\n",
"        current_cost, prev_cost = fJ(theta, X, y), current_cost\n",
"        if abs(prev_cost - current_cost) > 10**15:\n",
"            print('Algorithm does not converge!')\n",
"            break\n",
"        if abs(prev_cost - current_cost) <= eps:\n",
"            break\n",
"        logs.append([current_cost, theta])\n",
"    return theta, logs\n",
"\n",
"def plot_data(X, y, xlabel, ylabel):\n",
"    \"\"\"Plot the data (matrix version)\"\"\"\n",
"    fig = plt.figure(figsize=(16*.6, 9*.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"    ax.scatter([X[:, 1]], [y], c='r', s=50, label='Data')\n",
"\n",
"    ax.set_xlabel(xlabel)\n",
"    ax.set_ylabel(ylabel)\n",
"    ax.margins(.05, .05)\n",
"    plt.ylim(y.min() - 1, y.max() + 1)\n",
"    plt.xlim(np.min(X[:, 1]) - 1, np.max(X[:, 1]) + 1)\n",
"    return fig\n",
"\n",
"def plot_fun(fig, fun, X):\n",
"    \"\"\"Plot the function `fun`\"\"\"\n",
"    ax = fig.axes[0]\n",
"    x0 = np.min(X[:, 1]) - 1.0\n",
"    x1 = np.max(X[:, 1]) + 1.0\n",
"    Arg = np.arange(x0, x1, 0.1)\n",
"    Val = fun(Arg)\n",
"    return ax.plot(Arg, Val, linewidth=2)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Load the data (flats) using the pandas library\n",
"\n",
"alldata = pandas.read_csv('data_flats.tsv', header=0, sep='\\t',\n",
" usecols=['price', 'rooms', 'sqrMetres'])\n",
"data = np.matrix(alldata[['sqrMetres', 'price']])\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data[:, 0:n]\n",
"Xn /= np.amax(Xn, axis=0)\n",
"Xn2 = np.power(Xn, 2) \n",
"Xn2 /= np.amax(Xn2, axis=0)\n",
"Xn3 = np.power(Xn, 3) \n",
"Xn3 /= np.amax(Xn3, axis=0)\n",
"\n",
"X = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n + 1)\n",
"X2 = np.matrix(np.concatenate((np.ones((m, 1)), Xn, Xn2), axis=1)).reshape(m, 2 * n + 1)\n",
"X3 = np.matrix(np.concatenate((np.ones((m, 1)), Xn, Xn2, Xn3), axis=1)).reshape(m, 3 * n + 1)\n",
"y = np.matrix(data[:, -1]).reshape(m, 1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"General form of polynomial regression:\n",
"\n",
"$$ h_{\\theta}(x) = \\sum_{i=0}^{n} \\theta_i x^i $$"
]
},
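{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small worked example with hypothetical coefficients $\\theta = (1, -2, 0.5)$, i.e. $h_{\\theta}(x) = 1 - 2x + 0.5x^2$: at $x = 2$ the sum is $1 - 4 + 2 = -1$. The sketch below evaluates the sum directly and with `np.polyval`, which expects the coefficients in the opposite order (highest power first)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical parameters theta_0, theta_1, theta_2\n",
"theta_demo = np.array([1.0, -2.0, 0.5])\n",
"x_demo = np.array([0.0, 1.0, 2.0])\n",
"\n",
"# Direct evaluation of sum_i theta_i * x**i\n",
"h_direct = sum(t * x_demo**i for i, t in enumerate(theta_demo))\n",
"\n",
"# np.polyval wants the highest-degree coefficient first, hence the reversal\n",
"h_polyval = np.polyval(theta_demo[::-1], x_demo)\n",
"\n",
"print(h_direct)\n",
"print(h_polyval)"
]
},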
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [],
"source": [
"# Polynomial regression function\n",
"\n",
"def h_poly(Theta, x):\n",
"    \"\"\"Polynomial function: sum of theta_i * x**i\"\"\"\n",
"    return sum(theta * np.power(x, i) for i, theta in enumerate(Theta.tolist()))\n",
"\n",
"def polynomial_regression(theta):\n",
"    \"\"\"Return the polynomial regression hypothesis for parameters theta\"\"\"\n",
"    return lambda x: h_poly(theta, x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The simplest case of polynomial regression beyond the linear one is the quadratic function:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Quadratic function:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 $$"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f32d565c2d0>]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X2, y, xlabel='x', ylabel='y')\n",
"theta_start = np.matrix([0, 0, 0]).reshape(3, 1)\n",
"theta, logs = gradient_descent(cost, gradient, theta_start, X2, y)\n",
"plot_fun(fig, polynomial_regression(theta), X)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Another special case of polynomial regression is the cubic function:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Cubic function:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 + \\theta_3 x^3 $$"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 397521.34262209]\n",
" [-841360.50134775]\n",
" [2253766.75764159]\n",
" [-244049.33104935]]\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X3, y, xlabel='x', ylabel='y')\n",
"theta_start = np.matrix([0, 0, 0, 0]).reshape(4, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X3, y)\n",
"plot_fun(fig, polynomial_regression(theta), X)\n",
"\n",
"print(theta)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Polynomial regression can be treated as a special case of multivariate linear regression:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 + \\theta_3 x^3 $$\n",
"$$ x_0 = 1, \\quad x_1 = x, \\quad x_2 = x^2, \\quad x_3 = x^3, \\quad \\vec{x} = \\left[ \\begin{array}{c} x_0 \\\\ x_1 \\\\ x_2 \\\\ x_3 \\end{array} \\right] $$"
]
},
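{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of this view, assuming noise-free hypothetical data generated from the cubic $y = 2 - x + 0.5x^2 + 0.25x^3$: after building the columns $x_0 = 1$, $x_1 = x$, $x_2 = x^2$, $x_3 = x^3$, an ordinary linear least-squares solver recovers the coefficients. The sketch uses `np.linalg.lstsq` in place of the gradient descent from this lecture; `np.polyfit` fits the same model and returns the coefficients highest power first."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical noise-free data from y = 2 - x + 0.5 x^2 + 0.25 x^3\n",
"x_demo = np.linspace(-2.0, 2.0, 9)\n",
"y_demo = 2.0 - x_demo + 0.5 * x_demo**2 + 0.25 * x_demo**3\n",
"\n",
"# Features x_0 = 1, x_1 = x, x_2 = x^2, x_3 = x^3\n",
"X_demo = np.column_stack([np.ones_like(x_demo), x_demo, x_demo**2, x_demo**3])\n",
"\n",
"# Ordinary least squares on these features = linear regression with 4 parameters\n",
"theta_demo, *_ = np.linalg.lstsq(X_demo, y_demo, rcond=None)\n",
"print(np.round(theta_demo, 6))\n",
"\n",
"# np.polyfit fits the same polynomial; reverse to get theta_0 first\n",
"theta_polyfit = np.polyfit(x_demo, y_demo, deg=3)[::-1]\n",
"print(np.round(theta_polyfit, 6))"
]
},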
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"(Here the successive features are the successive powers of the variable $x$.)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Practical note: feature normalization, especially scaling, is very useful here, because successive powers of $x$ can differ by orders of magnitude!"
]
},
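{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small sketch of why scaling matters, with hypothetical flat areas: the raw powers $x$, $x^2$, $x^3$ differ by orders of magnitude, which makes a single learning rate for gradient descent hard to choose. Dividing $x$ by its maximum before taking powers, as the data-loading cell above does (`Xn /= np.amax(Xn, axis=0)`), keeps every power column within $[0, 1]$ for nonnegative inputs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical flat areas in square metres\n",
"area_demo = np.array([25.0, 50.0, 100.0, 150.0])\n",
"\n",
"# Raw powers: column maxima are 150, 22500 and 3375000\n",
"powers = np.column_stack([area_demo, area_demo**2, area_demo**3])\n",
"print(powers.max(axis=0))\n",
"\n",
"# Scale x by its maximum first, then take powers: every column now lies in [0, 1]\n",
"x_scaled = area_demo / area_demo.max()\n",
"powers_scaled = np.column_stack([x_scaled, x_scaled**2, x_scaled**3])\n",
"print(powers_scaled.max(axis=0))"
]
},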
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"To create “derived” features, we can use not only powers but also other mathematical operations, e.g.:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 \\sqrt{x} $$\n",
"$$ x_0 = 1, \\quad x_1 = x, \\quad x_2 = \\sqrt{x}, \\quad \\vec{x} = \\left[ \\begin{array}{c} x_0 \\\\ x_1 \\\\ x_2 \\end{array} \\right] $$"
]
},
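{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch with hypothetical noise-free data generated from $y = 1 + 0.5x + 3\\sqrt{x}$: once the column $\\sqrt{x}$ is added, the model is still linear in $\\theta$, so ordinary least squares (here `np.linalg.lstsq`) recovers the coefficients."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical noise-free data from y = 1 + 0.5 x + 3 sqrt(x), with x >= 0\n",
"x_demo = np.array([0.0, 1.0, 4.0, 9.0, 16.0])\n",
"y_demo = 1.0 + 0.5 * x_demo + 3.0 * np.sqrt(x_demo)\n",
"\n",
"# Features x_0 = 1, x_1 = x, x_2 = sqrt(x); the model stays linear in theta\n",
"X_demo = np.column_stack([np.ones_like(x_demo), x_demo, np.sqrt(x_demo)])\n",
"theta_demo, *_ = np.linalg.lstsq(X_demo, y_demo, rcond=None)\n",
"print(np.round(theta_demo, 6))"
]
},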
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"So which features should we choose? Preferably ones tailored to the specific problem."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Polynomial logistic regression\n",
"\n",
"Similar feature transformations can also be applied in logistic regression."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def powerme(x1, x2, n):\n",
"    \"\"\"Generate all monomials x1**i * x2**j with i + j <= n\n",
"    (powers of x1 and x2 and their products), one column each\n",
"    \"\"\"\n",
"    X = []\n",
"    for m in range(n + 1):\n",
"        for i in range(m + 1):\n",
"            X.append(np.multiply(np.power(x1, i), np.power(x2, (m - i))))\n",
"    return np.hstack(X)"
]
},
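{
"cell_type": "markdown",
"metadata": {},
"source": [
"The column order produced by `powerme` is easy to misread. A small check, using a hypothetical helper `monomial_exponents` that lists the exponent pairs (power of `x1`, power of `x2`) in the same loop order as `powerme`: for `n = 2` the columns are $1, x_2, x_1, x_2^2, x_1 x_2, x_1^2$, so the fitted `theta` follows this order rather than the numbering $x_3 = x_1^2$, $x_4 = x_2^2$, $x_5 = x_1 x_2$ used in the hypothesis proposal below. In general, degree `n` gives $(n+1)(n+2)/2$ columns, e.g. 66 for `n = 10`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def monomial_exponents(n):\n",
"    \"\"\"Hypothetical helper: (power of x1, power of x2) pairs in powerme's column order\"\"\"\n",
"    return [(i, m - i) for m in range(n + 1) for i in range(m + 1)]\n",
"\n",
"print(monomial_exponents(2))\n",
"# prints [(0, 0), (0, 1), (1, 0), (0, 2), (1, 1), (2, 0)]\n",
"\n",
"# Number of columns for degree n: (n + 1)(n + 2) / 2\n",
"for n_deg in (2, 10):\n",
"    print(n_deg, len(monomial_exponents(n_deg)))"
]
},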
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [
{
"data": {
"text/plain": [
"matrix([[ 1. , 0.36596696, -0.11214686],\n",
" [ 0. , 0.4945305 , 0.47110656],\n",
" [ 0. , 0.70290604, -0.92257983],\n",
" [ 0. , 0.46658862, -0.62269739],\n",
" [ 0. , 0.87939462, -0.11408015],\n",
" [ 0. , -0.331185 , 0.84447667],\n",
" [ 0. , -0.54351701, 0.8851383 ],\n",
" [ 0. , 0.91979241, 0.41607012],\n",
" [ 0. , 0.28011742, 0.61431157],\n",
" [ 0. , 0.94754363, -0.78307311]])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the data\n",
"import pandas\n",
"import numpy as np\n",
"\n",
"alldata = pandas.read_csv('polynomial_logistic.tsv', sep='\\t')\n",
"data = np.matrix(alldata)\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data[:, 1:]\n",
"\n",
"Xpl = powerme(data[:, 1], data[:, 2], n)\n",
"Ypl = np.matrix(data[:, 0]).reshape(m, 1)\n",
"\n",
"data[:10]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def plot_data_for_classification(X, Y, xlabel, ylabel):\n",
"    \"\"\"Plot the data for binary classification (matrix version)\"\"\"\n",
"    fig = plt.figure(figsize=(16*.6, 9*.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"    X = X.tolist()\n",
"    Y = Y.tolist()\n",
"    # powerme(x1, x2, n) returns the columns 1, x2, x1, ..., so x1 is column 2\n",
"    X1n = [x[2] for x, y in zip(X, Y) if y[0] == 0]\n",
"    X1p = [x[2] for x, y in zip(X, Y) if y[0] == 1]\n",
"    X2n = [x[1] for x, y in zip(X, Y) if y[0] == 0]\n",
"    X2p = [x[1] for x, y in zip(X, Y) if y[0] == 1]\n",
"    ax.scatter(X1n, X2n, c='r', marker='x', s=50, label='Class 0')\n",
"    ax.scatter(X1p, X2p, c='g', marker='o', s=50, label='Class 1')\n",
"\n",
"    ax.set_xlabel(xlabel)\n",
"    ax.set_ylabel(ylabel)\n",
"    ax.margins(.05, .05)\n",
"    return fig"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Suppose we have the following data and want to perform binary classification for the following classes:\n",
" * red crosses (class 0)\n",
" * green circles (class 1)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r'$x_1$', ylabel=r'$x_2$')"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Proposed hypothesis:\n",
"\n",
"$$ h_\\theta(x) = g(\\theta^T x) = g(\\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 + \\theta_3 x_3 + \\theta_4 x_4 + \\theta_5 x_5) \\; , $$\n",
"\n",
"where $g$ is the logistic function, $x_3 = x_1^2$, $x_4 = x_2^2$, $x_5 = x_1 x_2$."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def safeSigmoid(x, eps=0):\n",
"    \"\"\"Sigmoid function modified so that its values\n",
"    always stay at least eps away from the asymptotes\n",
"    \"\"\"\n",
"    y = 1.0/(1.0 + np.exp(-x))\n",
"    if eps > 0:\n",
"        y[y < eps] = eps\n",
"        y[y > 1 - eps] = 1 - eps\n",
"    return y\n",
"\n",
"def h(theta, X, eps=0.0):\n",
"    \"\"\"Hypothesis function\"\"\"\n",
"    return safeSigmoid(X*theta, eps)\n",
"\n",
"def J(h, theta, X, y, lamb=0):\n",
"    \"\"\"Cost function (cross-entropy, optionally with L2 regularization)\"\"\"\n",
"    m = len(y)\n",
"    f = h(theta, X, eps=10**-7)\n",
"    j = -np.sum(np.multiply(y, np.log(f)) +\n",
"                np.multiply(1 - y, np.log(1 - f)), axis=0)/m\n",
"    if lamb > 0:\n",
"        j += lamb/(2*m) * np.sum(np.power(theta[1:], 2))\n",
"    return j\n",
"\n",
"def dJ(h, theta, X, y, lamb=0):\n",
"    \"\"\"Gradient of the cost function\"\"\"\n",
"    g = 1.0/y.shape[0]*(X.T*(h(theta, X)-y))\n",
"    if lamb > 0:\n",
"        g[1:] += lamb/float(y.shape[0]) * theta[1:]\n",
"    return g\n",
"\n",
"def classifyBi(theta, X):\n",
"    \"\"\"Probability of class 1 (the decision threshold is 0.5)\"\"\"\n",
"    prob = h(theta, X)\n",
"    return prob"
]
},
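{
"cell_type": "markdown",
"metadata": {},
"source": [
"The analytic gradient used in `dJ` (without regularization) can be sanity-checked numerically. A minimal sketch with plain NumPy arrays and random hypothetical data; `ce_cost` and `ce_grad` are illustrative re-implementations of the unregularized cost and gradient, not functions from this notebook. Central finite differences should agree with the analytic gradient to many decimal places."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def ce_cost(theta, X, y):\n",
"    \"\"\"Mean cross-entropy cost of logistic regression (plain ndarray version)\"\"\"\n",
"    p = 1.0 / (1.0 + np.exp(-X @ theta))\n",
"    return -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))\n",
"\n",
"def ce_grad(theta, X, y):\n",
"    \"\"\"Analytic gradient: X^T (sigmoid(X theta) - y) / m\"\"\"\n",
"    p = 1.0 / (1.0 + np.exp(-X @ theta))\n",
"    return X.T @ (p - y) / len(y)\n",
"\n",
"# Hypothetical random data: bias column plus two features, binary labels\n",
"rng = np.random.default_rng(0)\n",
"X_gc = np.column_stack([np.ones(20), rng.normal(size=(20, 2))])\n",
"y_gc = rng.integers(0, 2, size=20).astype(float)\n",
"theta_gc = rng.normal(size=3)\n",
"\n",
"# Central finite differences, one coordinate direction at a time\n",
"step = 1e-6\n",
"numeric = np.array([\n",
"    (ce_cost(theta_gc + step * e, X_gc, y_gc)\n",
"     - ce_cost(theta_gc - step * e, X_gc, y_gc)) / (2 * step)\n",
"    for e in np.eye(3)\n",
"])\n",
"max_diff = np.max(np.abs(numeric - ce_grad(theta_gc, X_gc, y_gc)))\n",
"print(max_diff)"
]
},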
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def GD(h, fJ, fdJ, theta, X, y, alpha=0.01, eps=10**-3, maxSteps=10000):\n",
"    \"\"\"Gradient descent for logistic regression\"\"\"\n",
"    errorCurr = fJ(h, theta, X, y)\n",
"    errors = [[errorCurr, theta]]\n",
"    while True:\n",
"        # compute the new theta\n",
"        theta = theta - alpha * fdJ(h, theta, X, y)\n",
"        # record the current error\n",
"        errorCurr, errorPrev = fJ(h, theta, X, y), errorCurr\n",
"        # stopping criteria\n",
"        if abs(errorPrev - errorCurr) <= eps:\n",
"            break\n",
"        if len(errors) > maxSteps:\n",
"            break\n",
"        errors.append([errorCurr, theta])\n",
"    return theta, errors"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"theta = [[ 1.59558981]\n",
" [ 0.12602307]\n",
" [ 0.65718518]\n",
" [-5.26367581]\n",
" [ 1.96832544]\n",
" [-6.97946065]]\n"
]
}
],
"source": [
"# Run gradient descent for logistic regression\n",
"theta_start = np.matrix(np.zeros(Xpl.shape[1])).reshape(Xpl.shape[1],1)\n",
"theta, errors = GD(h, J, dJ, theta_start, Xpl, Ypl, \n",
" alpha=0.1, eps=10**-7, maxSteps=10000)\n",
"print(r'theta = {}'.format(theta))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def plot_decision_boundary(fig, theta, X):\n",
"    \"\"\"Plot the decision boundary (the 0.5 level of the class-1 probability)\"\"\"\n",
"    ax = fig.axes[0]\n",
"    xx, yy = np.meshgrid(np.arange(-1.0, 1.0, 0.02),\n",
"                         np.arange(-1.0, 1.0, 0.02))\n",
"    l = len(xx.ravel())\n",
"    C = powerme(xx.reshape(l, 1), yy.reshape(l, 1), n)\n",
"    z = classifyBi(theta, C).reshape(int(np.sqrt(l)), int(np.sqrt(l)))\n",
"\n",
"    ax.contour(xx, yy, z, levels=[0.5], linewidths=3);"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r'$x_1$', ylabel=r'$x_2$')\n",
"plot_decision_boundary(fig, theta, Xpl)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
    "# Load the data\n",
"\n",
"alldata = pandas.read_csv('polynomial_logistic.tsv', sep='\\t')\n",
"data = np.matrix(alldata)\n",
"\n",
"m, n_plus_1 = data.shape\n",
"Xn = data[:, 1:]\n",
"\n",
"n = 10\n",
"Xpl = powerme(data[:, 1], data[:, 2], n)\n",
"Ypl = np.matrix(data[:, 0]).reshape(m, 1)\n",
"\n",
"theta_start = np.matrix(np.zeros(Xpl.shape[1])).reshape(Xpl.shape[1],1)\n",
"theta, errors = GD(h, J, dJ, theta_start, Xpl, Ypl, \n",
" alpha=0.1, eps=10**-7, maxSteps=10000)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmwAAAFmCAYAAADQ5sbeAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOzdd3hT5RcH8O/bSSllD6FsKShDGRVREByAWESGaFXcKIo/ZargAgcIggLFrcWNUkeLqBVEVKQIQhkqZbWsUgq0UEZaSkdyfn8kgdAmJS1J7k3y/TxPnjY3N81pm9ycvPe851UiAiIiIiLSrwCtAyAiIiKiijFhIyIiItI5JmxEREREOseEjYiIiEjnmLARERER6RwTNiIiIiKdC9I6AC3Ur19fWrZsqXUYREREROfYsGHDERFpUHa7XyZsLVu2RGpqqtZhEBEREZ1DKbXP3naeEiUiIiLSOSZsRERERDrHhI2IiIhI55iwEREREekcEzYiIiIinWPCRkRERKRzukjYlFIfKaVylFJbHNyulFLzlVIZSql/lVJdbW67TymVbrnc57moiYiIiDxDFwkbgE8ADKjg9psARFkuowC8CwBKqboApgK4EkB3AFOVUnXcGikRERGRh+kiYRORPwHkVbDLYACfidlaALWVUo0B3AhguYjkicgxAMtRceJHRERE5HV0kbA5IRLAfpvrWZZtjrYTERER+QxvSdiUnW1SwfbyP0CpUUqpVKVUam5urkuDq5AIkJRk/urMdiIiIqIyvCVhywLQzOZ6UwDZFWwvR0Q+EJFoEYlu0KDcmqrus3gxMGwYMH782eRMxHx92DDz7UREREQV8JaEbQmAey2zRXsAOCEiBwEsA9BfKVXHMtmgv2WbfgwZAowdC8TFnU3axo83Xx871nw7EVUOR66J/Nauf/YiMe4nGI1GrUPxKF0kbEqprwCsAdBOKZWllBqplHpUKfWoZZdkALsBZAD4EMBjACAieQBeAbDecnnZsk0/lALmzj2btAUEnE3W5s41305ElcORayK/teKLP/Hu+E8w9urnsOufvVqH4zFK/PCTaHR0tKSmpnr2QUXMyZqVyeQdyZqI+c1vyJBz43W0ncgTyo5Uz51b/jqfl0Q+SUTw+6LVeHf8Jzh51IDbn7wFd08ZjtCwUK1Dcwml1AYRiS67XRcjbD7P+uZiy3ZkQM84kkF6xJFrIr+llML1d/bCgq1z0f/ePlj02mKMumwi/lmZ5rLHMBQZEL8xHpOWT0L8xngYigwu+9lVxRE2d/P2kQBvj598m7eOXBORy2z67T/MHfU+Du4+jEGjb8RDM0egekRYlX9eSmYKYhbGwCQmFJQUIDw4HAEqAMkjktGreS8XRm6foxE2JmzulpRkHomyTW5sk6DERGDoUM/EUlW28VoxWSOt8XlJRBaFBafxyfOLkDQ/GY1a1Mf4D0ej6w2dKv1zDEUGRM6JhKG4/IhaREgEsidmo0ZIDVeE7BBPiWplyBBzUmb7JmI9nZOY6B2zRK3x2uKbImmp7MivyVR+Njb5H84e9lth4dUweu79mPPnywgKCcKkfi8jbvQHKMwvrNTPSUhLgElMdm8ziQkJWxJcEW6VMGFzN6XMI2hlkxtH2/XIm2vwyDctXlz+tLxtTRtrK/0Ta279Xseel+C9TbMxfMIg/PTBr3i0y1PYsnq70/dPP5qOgpICu7cVlBQgIy/DVaFWGhM2qhhHMkiPzjdyPXgwR1r8EfteEoDQsFA88vq9eP33F2EyCSb2mYIPJ32B4qKS8943ql4UwoPD7d4WHhyONnXbuDpcp7GGjSrmCzV45H/4vPVfrG0kG6cMhXh/4qdIjl+BVp2a45mFY9GqY3OH++u5ho0JG1WMfdjIG3F2s3/j7GEq4++fNuD1ke+i4MQpjJp9Dwb/bwCUg+cEZ4nqCBM2Ij/AkRb/xP87OXDs8HG8PvIdrEvehO
4xXfDkgsdQp1Ftu/vmF+cjYUsCMvIy0KZuG8R2jHX7yJoVEzYbTNiIfIAzo78AR1r8CUdW6TxEBN+/vRQfPPU5wmtVx+TPn0C3fpdrHdY52NaDiHzL+WYEJiVxdrO/4exhOg+lFIY8fhPeXj8TtRvUxDMDpuOTKYu8YiF5JmxE5J0qmhE4ZgywciVnN/sbX+h7SR7RqmNzzF/7Kvrd1wcLp32HSf1ewdGDx7QOq0I8JUpE3stRvVLv3sCtt3KWKBGd17JPfseb/4tHWEQY3t3wGupH1tM0Htaw2WDCRuRD7M0IBDi7mYictjdtP377chUemHanw9mjnsIaNiLyPY5W4QC8f4URT+FyTkRo2aEZHpx+l+bJWkWYsBGRd+IqHK7B5ZyIvAITNiLyTpwR6Bpczsn1OGpJbsAaNiLyTlyFw3XYbNa1uDQaXQBOOrDBhI2IqAwu5+Q6bOBLF4CTDoiIyD5Hkzf88AO9S5Q9PR8QwGSNLhgTNiIif8bJG+5hTdpsMVmjC8CEjYjIn3Hyhntw1JJcjAkbEZE/43JOrsdRS3KDIK0DICIiDVmbCTu7nc7P0aglYN7epw//tlRpTNiIiIhcyTpqadtaxpq09enDUUuqEiZsRERErsRRS3ID1rARERER6RwTNiIiIiKdY8JGREREpHO6SNiUUgOUUjuUUhlKqcl2bp+rlNpsuexUSh23uc1oc9sSz0ZORERE5H6aJ2xKqUAAbwO4CUB7AHcqpdrb7iMi40Wks4h0BvAmgESbmwutt4nILR4L3BNEzIsIl+3Z42g7ERF5Lx7zqQKaJ2wAugPIEJHdIlIMYBGAwRXsfyeArzwSmdYWLwaGDTu30aK1IeOwYexATkTkS3jMpwroIWGLBLDf5nqWZVs5SqkWAFoB+M1mczWlVKpSaq1SymFzG6XUKMt+qbm5ua6I2/2GDCnfHdu2ezZ7+RAR+Q4e86kCeujDZm8lXEfjvncA+FZEjDbbmotItlKqNYDflFL/iciucj9Q5AMAHwBAdHS0d4wrl+2OHRdn/t62ezYREfkGHvOpAnoYYcsC0MzmelMA2Q72vQNlToeKSLbl624AfwDo4voQNWT7ArbiC5eIyDfxmE8O6CFhWw8gSinVSikVAnNSVm62p1KqHYA6ANbYbKujlAq1fF8fQE8AWz0StTvZFphah8RtjRvH4lMiIl9k75jPBeMJOjglKiKlSqnHASwDEAjgIxFJU0q9DCBVRKzJ250AFomc86y9FMD7SikTzMnnTBHx/oTNWng6Zoz5+vz5534/f/7ZT2H81EVE5BvK1qzNnXv2OsBjvp9T4odZe3R0tKSmpmodhmO2L1rAceKWmMh16YiIfEVSkvnDum3Nmu37AY/5fkEptUFEosttZ8KmUyLmU5/z55/dZn0RA+ZRuCFD+GmLiMhXiNg/tjvaTj6JCZsNr0jYAPOLNMCmzNBk4ouViIiqhgmhV3CUsOlh0gHZw8JTIiJyJTbm9WpM2PSobOGpyVS+mSIREVFlsDGvV9N8lijZsXjxubOEyjZT7NPHOwtPORxPPsRQZEBCWgLSj6Yjql4UYjvEIiI0QuuwiBxjY16vxho2PfLVxIYzoMhHpGSmIGZhDExiQkFJAcKDwxGgApA8Ihm9mvfSOjyiirE+WtdYw+ZNlDqbuFgb6Fa03VtwOJ58gKHIgJiFMTAUG1BQUgAAKCgpgKHYvD2/OF/jCIkq4G310baN5J3Z7sOYsOmZrxWIWofjrUlbQED5U79EOpeQlgCTmOzeZhITErYkeDgiIid5Y320r70PXgAmbHrmiyNSXCePvFz60fQzI2tlFZQUICMvw8MRETnJUX209X1Gj8mPL74PVhEnHeiZLxaIOhqO99bfh/xOVL0ohAeH203awoPD0aZuGw2iInLCkCHmWmHbOmjr+0yfPvpMfnzxfbCKOOnAG/hKgWhF6+T54YuPdKKSk3wMRQZEzomEodhQ7kdFhEQge2I2aoTU8ETkRP
7DV94HncBJB97K2wpEK+LK4XgWopKrVLJGJiI0AskjkhEREoHw4HAA5pG1iBDzdiZrRC7mS++DF0JE/O7SrVs38Qo
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
    "# Example with a larger number of features\n",
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r'$x_1$', ylabel=r'$x_2$')\n",
"plot_decision_boundary(fig, theta, Xpl)"
]
},
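  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "notes"
    }
   },
   "source": [
    "`powerme` is defined earlier in the notebook. Assuming it returns the constant term plus all monomials $x_1^i x_2^j$ with $1 \\leq i + j \\leq n$, the number of features grows quadratically with the degree: for $n = 10$ there are $(n+1)(n+2)/2 = 66$ columns, which is what allows the decision boundary above to bend so much. The function `monomial_features` below is a hypothetical stand-in written only to illustrate this count; its column order may differ from that of `powerme`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "notes"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def monomial_features(x1, x2, n):\n",
    "    # Hypothetical stand-in for powerme: all terms x1**i * x2**j with i + j <= n\n",
    "    x1 = np.asarray(x1, dtype=float).ravel()\n",
    "    x2 = np.asarray(x2, dtype=float).ravel()\n",
    "    cols = [x1 ** i * x2 ** (d - i) for d in range(n + 1) for i in range(d, -1, -1)]\n",
    "    return np.column_stack(cols)\n",
    "\n",
    "F_demo = monomial_features([0.5, -0.2], [0.1, 0.3], 10)\n",
    "print(F_demo.shape)  # (2, 66)"
   ]
  },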
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
    "## 5.2. The problem of overfitting"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
    "### Bias vs. variance"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
    "# Data for a simple example\n",
"\n",
"data = np.matrix([\n",
" [0.0, 0.0],\n",
" [0.5, 1.8],\n",
" [1.0, 4.8],\n",
" [1.6, 7.2],\n",
" [2.6, 8.8],\n",
" [3.0, 9.0],\n",
" ])\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn1 = data[:, 0:n]\n",
"Xn1 /= np.amax(Xn1, axis=0)\n",
"Xn2 = np.power(Xn1, 2) \n",
"Xn2 /= np.amax(Xn2, axis=0)\n",
"Xn3 = np.power(Xn1, 3) \n",
"Xn3 /= np.amax(Xn3, axis=0)\n",
"Xn4 = np.power(Xn1, 4) \n",
"Xn4 /= np.amax(Xn4, axis=0)\n",
"Xn5 = np.power(Xn1, 5) \n",
"Xn5 /= np.amax(Xn5, axis=0)\n",
"\n",
"X1 = np.matrix(np.concatenate((np.ones((m, 1)), Xn1), axis=1)).reshape(m, n + 1)\n",
"X2 = np.matrix(np.concatenate((np.ones((m, 1)), Xn1, Xn2), axis=1)).reshape(m, 2 * n + 1)\n",
"X5 = np.matrix(np.concatenate((np.ones((m, 1)), Xn1, Xn2, Xn3, Xn4, Xn5), axis=1)).reshape(m, 5 * n + 1)\n",
"y = np.matrix(data[:, -1]).reshape(m, 1)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmAAAAFoCAYAAADw0EcgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAUqUlEQVR4nO3df6zd933X8dfbcbuu915YSrw1azva4at2o3/QYapuvUJVu4rMTA2MTjeTtqWTp2iI0g5PsDLQKk0ICkKDgcYgpKUblPZWWcXCMIzSrhoWLIqTBdrUVPdSWGucLe6GupvLoMvuhz/O9ey518l14vv+2vc8HpJ17j3n3HPe+fp7b57+/ro1xggAAH0OTT0AAMC8EWAAAM0EGABAMwEGANBMgAEANBNgAADN9i3Aqup9VfVEVX3qsvteVFUfrar1ndtb9+v9AQBuVPu5Bez9Se644r53JfnYGGM5ycd2PgcAmCu1nxdiraqXJ/n5Mcardz7/TJI3jDEer6rbk3xijPHKfRsAAOAG1H0M2NeMMR5Pkp3br25+fwCAyR2eeoCrqap7ktyTJAsLC3/8Va961cQTAQD8fg8//PAXxhhHrvXrugPs16vq9st2QT5xtSeOMe5Ncm+SHDt2bJw5c6ZrRgCAPamqX302X9e9C/KBJHfvfHx3kp9rfn8AgMnt52UoPpjkPyd5ZVWdq6oTSd6T5M1VtZ7kzTufAwDMlX3bBTnG+K6rPPSm/XpPAICbgSvhAwA0E2AAAM0EGABAMwEGANBMgAEANBNgAADNBBgAQDMBBgDQTIABADQTYAAAzQQYAEAzAQYA0EyAAQA0E2AAAM0EGABAMwEGANBMgAEANBNgAADNBBgAQLPDUw8AwAG0uZmsrSXr68nycrK6miwtTT0V3DAEGADX1+nTyfHjyfZ2srWVLCwkJ08mp04lKytTTwc3BLsgAbh+Njdn8bW5OYuvZHZ78f4nn5x2PrhBCDAArp+1tdmWr91sb88eBwQYANfR+vqlLV9X2tpKNjZ654EblAAD4PpZXp4d87WbhYXk6NHeeeAGJcAA5tHmZnLffckP//DsdnPz+rzu6mpy6Cr/azl0aPY44CxIgLmzn2cpLi3NXufK1z90aHb/4uL1+W+Am5wAA5gnl5+leNHFY7aOH0/On3/ukbSyMnudtbXZMV9Hj862fIkv+D0CDGCe7OUsxRMnnvv7LC5en9eBA8oxYADzxFmKcEMQYADzxFmKcEMQYADzxFmKcEMQYADz5OJZiktLl7aELSxcut+B8tDCQfgA88ZZijA5AQYwj5ylCJOyCxIAoJkAAwBoJsAAAJoJMACAZgIMAKCZAAMAaCbAAACaCTAAgGYCDACgmQADAGgmwAAAmgkwAIBmAgwAoNkkAVZVf6mqHquqT1XVB6vqBVPMAQAwhfYAq6qXJHlHkmNjjFcnuSXJXd1zAABMZapdkIeTfGVVHU7ywiTnJ5oDAKBde4CNMf5Xkr+b5HNJHk/yxTHGv7/yeVV1T1WdqaozFy5c6B4TAGDfTLEL8tYkdyZ5RZKvTbJQVd995fPGGPeOMY6NMY4dOXKke0wAgH0zxS7Ib03yP8YYF8YYv5PkI0m+ZYI5AAAmMUWAfS7J66rqhVVVSd6U5OwEcwAATGKKY8AeTHJ/kkeSfHJnhnu75wAAmMrhKd50jPHuJO+e4r0BAKbmSvgAAM0EGABAMwEGANBMgAEANBNgAADNBBgAQDMBBgDQTIABADQTYAAAzQQYAEAzAQYA0EyAAQA0E2AAAM0OTz0AwKQ2N5O1tWR9PVleTlZXk6WlqacCDjgBBsyv06eT48eT7e1kaytZWEhOnkxOnUpWVqaeDjjA7IIE5tPm5iy+Njdn8ZXMbi/e/+ST084HHGgCDJhPa2uzLV+72d6ePQ6wTwQYMJ/W1y9t+brS1laysdE7DzBXBBgwn5aXZ8d87WZhITl6tH
ceYK4IMGA+ra4mh67yI/DQodnjAPtEgAHzaWlpdrbj0tKlLWELC5fuX1ycdj7gQHMZCmB+rawk58/PDrjf2JjtdlxdFV/AvhNgwHxbXExOnJh6CmDO2AUJANBMgAEANBNgAADNBBgAQDMBBgDQTIABADQTYAAAzQQYAEAzAQYA0EyAAQA0E2AAAM0EGABAMwEGANBMgAEANBNgAADNBBgAQDMBBgDQTIABADQTYAAAzQQYAEAzAQYA0EyAAQA0E2AAAM0EGABAs0kCrKq+qqrur6r/VlVnq+qbp5gDAGAKhyd6359I8u/GGG+tqucneeFEcwAAtGsPsKr6A0n+ZJK3JckY40tJvtQ9BwDAVKbYBfn1SS4k+WdV9StVdV9VLUwwBwDAJKYIsMNJvinJT40xXpNkK8m7rnxSVd1TVWeq6syFCxe6ZwQA2DdTBNi5JOfGGA/ufH5/ZkH2+4wx7h1jHBtjHDty5EjrgAAA+6k9wMYYv5bk81X1yp273pTk091zAABMZaqzIP9ikg/snAH52STfN9EcAADtJgmwMcajSY5N8d4AAFNzJXwAgGYCDACgmQADAGgmwAAAmgkwAIBmAgwAoJkAAwBoJsAAAJoJMACAZgIMAKCZAAMAaCbAAACaCTAAgGYCDACgmQADAGgmwAAAmgkwAIBmAgwAoJkAAwBoJsAAAJoJMACAZgIMAKCZAAMAaCbAAACaCTAAgGaHpx4AmDObm8naWrK+niwvJ6urydLS1FMBtBJgQJ/Tp5Pjx5Pt7WRrK1lYSE6eTE6dSlZWpp4OoI1dkECPzc1ZfG1uzuIrmd1evP/JJ6edD6CRAAN6rK3NtnztZnt79jjAnBBgQI/19Utbvq60tZVsbPTOAzAhAQb0WF6eHfO1m4WF5OjR3nkAJiTAgB6rq8mhq/zIOXRo9jjAnBBgQI+lpdnZjktLl7aELSxcun9xcdr5ABq5DAXQZ2UlOX9+dsD9xsZst+PqqvgC5o4AA3otLiYnTkw9BcCk7IIEAGj2jAFWVW+vqls7hgEAmAd72QL24iQPVdWHq+qOqqr9HgoA4CB7xgAbY/z1JMtJ3pvkbUnWq+pvVtUf2efZAAAOpD0dAzbGGEl+befPU0luTXJ/Vf2dfZwNAOBAesazIKvqHUnuTvKFJPcl+ctjjN+pqkNJ1pP8lf0dEQDgYNnLZShuS/IdY4xfvfzOMcZ2VX37/owFAHBwPWOAjTF+9GkeO3t9xwEAOPhcBwwAoJkAAwBoJsAAAJoJMACAZgIMAKDZZAFWVbdU1a9U1c9PNQMAwBSm3AL2ziQuYwEAzJ1JAqyqXprkT2d2ZX0AgLky1Rawv5/ZrzDavtoTquqeqjpTVWcuXLjQNxkAwD5rD7CdX1/0xBjj4ad73hjj3jHGsTHGsSNHjjRNBwCw/6bYAvb6JG+pqv+Z5ENJ3lhV/2KCOQAAJtEeYGOMvzrGeOkY4+VJ7kry8THGd3fPAQAwFdcBAwBodnjKNx9jfCLJJ6acAQCgmy1gAADNBBgAQDMBBgDQTIABADQTYAAAzQQYAEAzAQYA0EyAAQA0E2AAAM0EGABAMwEGANBMgAEANBNgAADNBBgAQDMBBgDQTIABADQTYAAAzQQYAEAzAQYA0EyAAQA0E2AAAM0EGABAMwEGANBMgAEANBNgAADNBBgAQDMBBgDQTIABADQTYAAAzQQYAEAzAQYA0EyAAQA0E2AAAM0EGABAMwEGANBMgAEANBNgAADNBBgAQDMBBgDQTIABADQTYAAAzQQYAEAzAQYA0EyAAQA0E2AAAM0EGABAMwEGANCsPcCq6mVV9YtVdbaqHquqd3bPAAAwpcMTvOdTSX5ojPFIVS0lebiqPjrG+PQEswAAtGvfAjbGeHyM8cjOx5tJziZ5SfccAABTmfQYsKp6eZLXJHlwyjkAADpNFmBVtZjkZ5P84Bjjt3Z5/J6qOlNVZy5cuNA/IADAPpkkwKrqeZnF1wfGGB/Z7TljjH
vHGMfGGMeOHDnSOyAAwD5qPwi/qirJe5OcHWP8ePf7A0k2N5O1tWR9PVleTlZXk6WlqacCmBtTnAX5+iTfk+STVfX
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X1, y, xlabel='x', ylabel='y')"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f32473f2ad0>]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmAAAAFoCAYAAADw0EcgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dd3yW9b3/8fcnG5IQVtgbwh6JUkdFrVhH0YqKmPjoLqe257QHFOuqVq2jFheipz39cbrOOe0hEUVFxVFHq2gdaMIeCTuEEVZIApn39/fHHZVigCTkvq57vJ6PB4+Qe11vL6/cvHONz23OOQEAAMA7cX4HAAAAiDUUMAAAAI9RwAAAADxGAQMAAPAYBQwAAMBjFDAAAACPhayAmdkfzGyPma066rauZvZXMytu+tolVMsHAAAIV6HcA/YnSZcec9ttkt5wzmVJeqPpewAAgJhioRzEamaDJL3onBvb9P16SV9xzu00s96S/uacGxGyAAAAAGHI63PAejrndkpS09ceHi8fAADAdwl+BzgeM7te0vWSlJqaevrIkSN9TgQAaI2t+w7rUE29UpPiNTgzTeZ3ICAEPv74473OuczWPs/rArbbzHofdQhyz/Ee6JybL2m+JE2cONEtW7bMq4wAgFP0lw+26o5nV6lXSoJennWu+nXp6HckICTMbGtbnuf1IcjFkr7T9PfvSHre4+UDAEKsZE+l7ntxjSTpgavGUb6AZoRyDMUCSf+QNMLMSs1shqRfSbrIzIolXdT0PQAgStQ2NOrfFxSppj6gaaf10xUT+vgdCQhLITsE6Zy77jh3XRiqZQIA/PXQK+u1duchDezWUb+YOsbvOEDYYhI+AKBd/H1DuX6/dLMS4kzz8nKUlhy213kBvqOAAQBO2d6qWt301HJJ0o0XDVd2/84+JwLCGwUMAHBKnHO6eeFy7a2q1ZmDu+pH5w/1OxIQ9ihgAIBT8j//2Kq31pcro0Oi5uZmKz6OiV/AyVDAAABttm7XIT2wZK0k6cGrx6lP5w4+JwIiAwUMANAmNfWNmrWgSHUNAeVO7K8p43r7HQmIGBQwAECbPLhkrdbvrtSQ7qm66+uj/Y4DRBQKGACg1d5ct1v//Y+tSowPjpxIZeQE0CoUMABAq+yprNFPF66QJP304hEa1y/D50RA5KGAAQBaLBBwuump5dpfXadzhnXTD84d4nckICJRwAAALfaHdzfrneK96tIxUY9Oz1YcIyeANqGAAQBaZHVZhR56Zb0kac608eqVkeJzIiByUcAAACd1pK5RMxcUqq4xoG+cOUAXj+nldyQgolHAAAAndf9La7SxvFrDeqTpzssYOQGcKgoYAOCEXl29S3/5YJuS4uM0Ly9bHZLi/Y4ERDwKGADguHZV1OjWZ4IjJ265dITG9GHkBNAeKGAAgGYFAk6znyrSwcP1Om94pr5/zmC/IwFRgwIGAGjW/Hc26b2N+9QtNUmPTB/PyAmgHVHAAABfsKL0oB55NThy4uHp49UjnZETQHviw7sAAP+kurZBs/KL1BBw+s7ZAzV5ZM/Wv0hlpVRQIBUXS1lZUm6ulJ7e/mGBCEUBAwD8k3tfWKPNe6s1ome6bp8yqvUvsHSpNGWKFAhI1dVSaqo0e7a0ZIk0aVL7BwYiEIcgAQCfWbJypwqWbVdSQpzmXZetlMRWjpyorAyWr8rKYPmSgl8/vb2qqv1DAxGIAgYAkCSVHTyi25pGTtwxZZRG9urU+hcpKAju+WpOIBC8HwAFDAAgNQacbigo0qGaBk0e2UPfPntg216ouPjzPV/Hqq6WSkraHhKIIhQwAIB++/eN+nDzfnVPS9ZD14yXWRtHTmRlBc/5ak5qqjRsWNtDAlGEk/ABIBYddZViYf/RemxHd0nSo9dOUPe05La/bm5u8IT75sTFBe8HQAEDgJhz1FWKVfUBzfr+k2rMkGYMTtL5wzNP7bXT04NXOx
57FWRcXPD2tLT2+W8AIhwFDABiydFXKUq6a8qN2pbRS6N2b9Itv71b+sa2Uy9JkyZJZWXBPWwlJcHDjrm5lC/gKBQwAIglR12l+Pyo87Ro3IVKqa/REy88rOSGuuD9M2ac+nLS0trndYAoRQEDgFjSdJXi9k49dOclP5Yk3fnm75W1b3vwfq5SBDzBVZAAEEuystSQlq4bv36TKpNTddGGf+gbRS8H7+MqRcAzFDAAiCW5ufqPL03Tsn5j1KNyn+a88qQ+GzjBVYqAZzgECQAx5OP99XrijGtkLqC5b/xGXY8c4ipFwAcUMACIEYdq6jUrv0gBST/88gCdM/wHUskFXKUI+IACBgAx4ufPrVLpgSMa27eTbrpsnJQwwe9IQMziHDAAiAHPFpbq+aIydUiM17y8HCUl8PYP+ImfQACIctv2HdbPn1stSbrnitEamsmhRsBvFDAAiGL1jQHNzC9UVW2Dvja2l66d2N/vSABEAQOAqPbEG8Uq2n5QvTNS9ODV42RmJ38SgJCjgAFAlPpg0z79+q0SmUlzc7PVuWOS35EANKGAAUAUqjhcrxsLihRw0r99ZajOGtLN70gAjkIBA4Ao45zTz55bqbKKGk3o31k3fHW435EAHIMCBgBRZuHHpXppxU6lJsVrXm62EuN5qwfCDT+VABBFNu+t1j2LgyMnfjF1rAZ1T/U5EYDmUMAAIErUNQQ0K79Qh+sa9fUJfTTttL5+RwJwHBQwAIgSc1/foBWlFerbuYPuv3IsIyeAMEYBA4Ao8N7Gvfrt3zcqzqTH87KV0SHR70gAToACBgAR7kB1nWYXLJdz0k8mZ+lLg7r6HQnASfhSwMzsRjNbbWarzGyBmaX4kQMAIp1zTrctWqFdh2p02oDOmjl5mN+RALSA5wXMzPpKmilponNurKR4SXle5wCAaJD/0Xa9unq30pMTNC8vRwmMnAAigl8/qQmSOphZgqSOksp8ygEAEatkT5V+8UJw5MT9V41V/64dfU4EoKU8L2DOuR2SHpG0TdJOSRXOudeOfZyZXW9my8xsWXl5udcxASCs1TY0alZ+oWrqA7oqp6+mZjNyAogkfhyC7CJpqqTBkvpISjWzbx77OOfcfOfcROfcxMzMTK9jAkBYe+TV9Vpddkj9u3bQvVPH+B0HQCv5cQjyq5I2O+fKnXP1khZJ+rIPOQAgIr1TXK7/emez4uNM8/JylJ7CyAkg0vhRwLZJOsvMOlpwSuCFktb6kAMAIs6+qlrNfmq5JOmGC7N02oAuPicC0BZ+nAP2gaSnJX0iaWVThvle5wCASOOc063PrFB5Za3OGNRV/3YBIyeASJXgx0Kdc3dLutuPZQNApPrz+1v1+to9Sk9J0Ny8bMXH8VFDQKRiYAwARIANuyt1/0vBszUevHqc+nbu4HMiAKeCAgYAYa6mvlEzFxSqtiGg6af30+Xj+/gdCcApooABQJib88o6rdtVqUHdOuqeKxg5AUQDChgAhLG31u/RH9/dooSmkROpyb6cugugnVHAACBMlVfW6uaFwZETsy8ergn9O/ucCEB7oYABQBhyzunmp5drb1Wdzh7STT88b6jfkQC0IwoYAIShP723RX9bX66MDol6LHcCIyeAKEMBA4Aws3bnIT24ZJ0kac60ceqdwcgJINpQwAAgjHw6cqKuMaDrzuivS8f29jsSgBCggAFAGHngpbUq3lOlIZmp+vnlo/2OAyBEKGAAECZeX7Nb//v+ViXGm57Iy1HHJEZOANGKAgYAYWDPoRrd8swKSdItl4zU2L4ZPicCEEoUMADwWSDgdNPC5dpfXadzs7prxqTBfkcCEGLs3wYQ2yorpYICqbhYysqScnOl9HRPI/x+6Wa9U7xXXVOT9Oj0CYpj5AQQ9ShgAGLX0qXSlClSICBVV0upqdLs2dKSJdKkSZ5EWLWjQg+9+unIifHq0SnFk+UC8BeHIAHEpsrKYPmqrAyWLyn49dPbq6pCHuFwXYNm5heqvt
HpW2cN1EWje4Z8mQDCAwUMQGwqKAju+WpOIBC8P8Tue3GtNpVXK6tHmu64bFTIlwcgfFDAAMSm4uLP93wdq7paKik
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X1, y, xlabel='x', ylabel='y')\n",
"theta_start = np.matrix([0, 0]).reshape(2, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X1, y, eps=0.00001)\n",
"plot_fun(fig, polynomial_regression(theta), X1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
    "This model has high **bias** (**systematic error**): it **underfits** the data (_underfitting_)."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f3247374bd0>]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmAAAAFoCAYAAADw0EcgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dd3zV1eH/8fe5mWQwAmGFEUKYggiEKQpuRetWtC5a/GqtVrtsrfprv99vrVr7rdYqat2ziqsucAsiqEBA9kogjBAgAUIWZN17fn/cMETQAMnn3PF6Ph48Mu5N8uZ6ub5zzvmcY6y1AgAAgHd8rgMAAABEGwoYAACAxyhgAAAAHqOAAQAAeIwCBgAA4DEKGAAAgMearYAZY542xhQbY5bu97k0Y8zHxpi8hrdtmuvnAwAAhKrmHAF7VtKZB3zuNkmfWmt7Sfq04WMAAICoYppzI1ZjTKak96y1Axo+XiVpnLV2szGmk6QZ1to+zRYAAAAgBHm9BqyDtXazJDW8be/xzwcAAHAu1nWAQzHGXCfpOklKTk4e2rdvX8eJAAAAvm3+/PnbrLXph/t1XhewrcaYTvtNQRYf6o7W2sclPS5JOTk5Njc316uMAAAAjWKMWX8kX+f1FOQ7kq5peP8aSW97/PMBAACca85tKF6W9JWkPsaYQmPMJEn3SjrNGJMn6bSGjwEAAKJKs01BWmsvP8RNpzTXzwQAAAgH7IQPAADgMQoYAACAxyhgAAAAHqOAAQAAeIwCBgAA4DEKGAAAgMcoYAAAAB6jgAEAAHiMAgYAAOAxChgAAIDHKGAAAAAeo4ABAAB4jAIGAADgMQoYAACAxyhgAAAAHqOAAQAAeIwCBgAA4DEKGAAAgMdiXQcAAEQPa63q/Fa1/oBq6/f74/er5lsff/v9mvqAOrRM1OBurdUyMc71XwM4ahQwAECT2lZZo5mLN2rGZwv1TYW0OzZeNbHxqvFb1fkDsvbIv7cxUp8OqRravY2Gdm+jnO5p6prWQsaYpvsLAB6ggAEAjoo/YLVwY6lmrCrR56tLtLiwrOGWFpKR5Jfk9++9f4zPKD7Gp/hYnxJig2/jY32Kj/nux8H3YxTnMyrYXqWlm8q0ckuFVm6p0EtzNkiS0lMTNLRbG+VkttGQ7m00oHMrxceywgahjQIGADhsxRXVmrl6m2asKtYXedtUtrtu723x9bUauWGJxq2dr+PXL1Kb3WVKqK9TfFKi4tcVKKZl6hH/3Oo6v5ZsKlPuulLNX1+q+et3qKSiRh8s26IPlm2RJCXE+jSoS2sN6d5GOQ0jZW2S44/67ww0JQoYAOAH1fsD+mbjTs1YVawZq0q0rKj8W7dntk3SuD7tNbZggUb+6Ra1KCv97jeJlfTaq9KkSUecIzEuRsMy0zQsM01ScE1ZwbYq5a4v1fx1pZq/oVT5xZWau26H5q7bsffrstKTldMwZTm8R5oy2yUfcQagKVDAAAAHtbW8Wp83TCt+kVei8ur6vbclxvk0KqutxvZO17g+7fcVmt8/Lx2sfElSVZWUn9+kGY0xykpPUVZ6ii7N6SpJKq2q1TcbS5W7rlS560u1aONOrS2p0tqSKr2aWyhJGtcnXb84OVtDu6c1aR6gsShgAIC9dtXW69kv1+ndRZu1YvO3R7my2iVrbJ9g4RrRI02JcTHf/Qa9eknJycGydaDkZCk7u5mS79MmOV4n9+2gk/t2kCTV1ge0fHP53inLGatK9v4ZldVWvzglW6Oy2rKQH54y9mguR/FITk6Ozc3NdR0DACJHRYU0ZYqUlyf16qW6iy/RlJU79Y9P8rStskaS1CIuRqN7ttW4Puka27u9urVNatz3zcgIvj1QaqpUVCSlpDTxX+bwlFbV6unZBXp29jpV1ARH9YZ2b6ObTs7WuN7pFDEcFmPMfGttzmF/HQUMAKLMrFnS+PFSICBbVaUPjj1Jfxt5uda26SxJGt
S1tW45JVuje7Y7+CjXYXx/VVUFR758PmnaNGnMmCb+yxy5st11ev7LdXpqdoF27gpeRDAwo5VuOjlbp/XrIJ+PIoYfRgEDAPyw/Uao5nQ5Rvec9BMt7NxXktRj52bdOukUnZWTefSjQJWVwRG2/PzgtOOECc5Hvg6lqqZeL369Xk98sVbbKmslSX07purGk7I1fmAnxVDE8D0oYACAH/bkk1r1v3/XX4dfqs+yh0uS2lWW6pbZ/9Zla2Yr7oH7j+oqxXBWXefXK3M36LHP12pLebWk4NWTN47L1rnHdVZcDHuL4buOtICxCB8AokTRzt26P8/ojcvvkzU+Jdfs0nVz39S1895Scl2wcDT1VYrhJDEuRhOP76HLR3TTG/M36ZEZ+VpbUqXfvLZI//h0tW4Ym62LhmYoIfYIpmWBAzACBgARrmxXnR6Zka9nvlyn2vqAYv31umLh+/rFl6+o3a6yfXdMTpYefDBqR8AOVOcP6O2FRXpker7Wbgte1dmpVaKuPzFLlw3vdmTr4xBxmIIEAHxLdZ1fz325TpOn5+/dw+uc/um69fYfq/umNd/9ghC5SjHU+ANWU5ds1uTP8rVqa/DqznYpCbruxB66cmR3JcUzmRTNKGAAAEnBwvDmgkI98PFqFZUFpxZH92yr287qq2O7tA6bqxRDTSBg9fGKrXroszwt3RTcI6172yRN/vEQDcho5TgdXKGAAUCUs9Zq+qpi/fX9VXtHavp1aqnbzuqrE3u1+/aVjWF0lWKosdZqxuoS/fX9lVq5pULxMT7dPr6vrhndBFePIuxQwAAgiq3bVqXfv7FYcwqC5x9mtG6h357RW+cNymA/q2ZSXefXX6au0Atfr5cknXFMB9130SC1SopznAxeooABQJSavqpYt7z8jcqr69U6KU43nZStK0d2Z5G4R6Yt2azfv75YFTX1ymjdQg/9eLCGdGvjOhY8cqQFjE1NACBMWWs1eXq+fvrsPJVX1+vUfh30+W9P0rUnZFG+PDR+YCdNvfkEDerSSpt27talj32lx2euUSAQ+gMccIcCBgBhqLKmXje8uEB/+3CVrJV+dWpvPX7VUKa/HOnWNkmv/Wy0rh3TQ/UBq7unrdSk5+ZpR1Wt62gIURQwAAgzBduqdMHk2fpg2RalJsTqqWtydMupvVjr5Vh8rE93ntNfT16do9ZJcZq+qkRnPThTc9Zudx0NIYgCBgBh5LOVW3Xuw7OUV1yp7PYpevum43VKvw6uY2E/p/bvoGk3n6Cc7m20tbxGlz/xtR76NE9+piSxHwoYAISBQMDqn5/madJzuaqortcZx3TQWzcer6x0to4IRZ1bt9Ar143UjSf1lJX0949X6+qn56i4otp1NIQIChgAhLiK6jpd/+J83f/xaknSrWf00aNXDFVKAjuwh7LYGJ9uPaOvnvvJcLVLidfs/O0a/+AX+iKvxHU0hAAKGACEsPziSp0/ebY+Xr5VqYmxevqaYbrxpGzWe4WRE3una9rNJ2h0z7baVlmrq5+eq//7cJXq/QHX0eAQBQwAQtTHy7fq/MmztaakSr07pOjdm8bopL7tXcfCEWjfMlEvTBqhX53aW0bSw9PzdfkTX2tz2W7X0eAIBQwAQkwgYPXAx6v1X8/nqrKmXuMHdtR/fn68Mtslu46GoxDjM7rl1F566dqRap+aoHnrSjX+wS/02cqtrqPBAQoYAISQ8uo6XfdCrh78NE/GSL87s48m/3iIklnvFTFG9Wyr9285QWN7p6t0V51++myu/jJ1OVdJRhkKGACEiPziCp3/8Gx9sqJYLRNj9czEYfr5uGwOeI5AbVMS9MzEYbrtrL6K8Rk98UWBbn19ESUsivArFQCEgA+XbdGvpyxUVa1ffTum6l9XDVX3tkw5RjKfz+hnY3tqUJfWmvTcPL25YJMk6W8XD1IMF1lEPCcjYMaYXxljlhljlhpjXjbGJLrIAQCuBQJWf/9ola5/Yb6qav06+9hOevPnoylfUW
RUz7Z6ZuIwJcXH6M0FmxgJixKeFzBjTIakmyXlWGsHSIqRdJnXOQDAtTp/QDf+e4Ee+ixfPiP94ay+evjywUqKZ3I
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X2, y, xlabel='x', ylabel='y')\n",
"theta_start = np.matrix([0, 0, 0]).reshape(3, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X2, y, eps=0.000001)\n",
"plot_fun(fig, polynomial_regression(theta), X1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
    "This model fits the data well."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f32bab3cad0>]"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmAAAAFoCAYAAADw0EcgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dd3zV1eH/8fe5mWQRICHsmQACskSGoFVxUkerVrQOVNS2vzqodljbandrh35rh1XBuhXrHtSFoqKyQYaMhB0SSELInjf3/P64UREZCSSfc8fr+XjwSHLvTe7bj5fLO+dzzvkYa60AAADgHZ/rAAAAANGGAgYAAOAxChgAAIDHKGAAAAAeo4ABAAB4jAIGAADgsXYrYMaYh4wxRcaYNfvc1tkY85YxJrf5Y6f2en4AAIBQ1Z4jYA9LOmu/226TNM9amyNpXvPXAAAAUcW050asxph+kl611g5v/nqDpJOttYXGmO6S5ltrB7dbAAAAgBDk9RywLGttoSQ1f+zq8fMDAAA4F+s6wMEYY66XdL0kJScnHzdkyBDHiQAAoaqusUm5RVUyknKyUpUQyxozeGPZsmUl1trM1n6f1wVstzGm+z6nIIsO9kBr7QOSHpCksWPH2qVLl3qVEQAQZqY/tFhVG4s1fWJf/er84a7jIIoYY7Ydyfd5/SvCy5KmN38+XdJLHj8/ACDCLMgt0Xsbi5WaEKubpuS4jgO0SHtuQ/GUpI8lDTbG5BtjZkj6o6TTjTG5kk5v/hoAgCPSFLD63dx1kqTvnTJQXVISHCcCWqbdTkFaay89yF1T2us5AQDR5YUVO7WusEI9Oibqmkn9XccBWoxZigCAsFTX2KS/vrlBknTrGYOVGBfjOBHQchQwAEBYmr1giwrL6zS0e5q+Obqn6zhAq1DAAABhp6SqXvfN3yRJ+tnXj5HPZxwnAlqHAgYACDv3zstVVb1fJw/O1KTsDNdxgFajgAEAwsrm4io9uWi7fEb66dnHuI4DHBEKGAAgrNz1+nr5A1YXj+2twd1SXccBjggFDAAQNhZvKdUba3erQ1yMbjl9kOs4wBGjgAEAwoK1Vr9v3nT1upMGqGtaouNEwJGjgAEAwsJrqwu1ckeZMlIS9J2TBriOAxwVChgAIOTV+5t01+vrJUm3nD5IyQntdiEXwBMUMABAyHvs423aUVqr7K4punhsL9dxgKPGrxAAgJBlrdVrqwv1t7dzJUm3Tx2i2BjGDhD+KGAAgJBUUFarX7y4RvPWF0mSzh7eTacM7uo4FdA2KGAAgJDSFLB6fOE2/en19apuaFJqQqxumzpElx7fR8ZwySFEBgoYACBkbNhVqdueX6UV28skSWcN66ZfnT9MWWw5gQhDAQMAOFfX2KR/vZun+97bpMYmq66pCfr1+cN11vBurqMB7YICBgBwavGWUt32/CptLq6WJF02vo9+cvYQpSXGOU4GtB8KGADAifLaRv3xf+v11OLtkqSBmcn6wwUjNK5/Z8fJgPZHAQMAeO71NYW646W1KqqsV1yM0fdOztb3TxmohNgY19EAT1DAAABtr7JSmjNHys2VcnKkadOk1FTtrqjTHS+t0Rtrd0uSxvRJ1x8vHKFBWamOAwPeooABANrWggXS1KlSICBVV0vJyQrccquevO953bXRr8p6v1ISYvXjswbr8vF95fOxtQSiDwUMANB2KiuD5auy8vOb8hI76adn3qglq+skSacd01W/+cZwde/YwVVKwDkKGACg7cyZIwUC8huf8jtm6aWhX9M/J05TQ2ycMqrL9KuBVlOvnMqGqoh6FDAAwBGrafBrc3G1NhVXKa+oSptyY5U37U/a2qmHGmK/2Ebikk/e0E/ffUgdf3CjRPkCKGAAgEOz1mpPdUOwYH1WtIqrtamoSjvLar/8YJMpZQY/7VFRpEHF2/WdRc9p4o7VUnKylJ3t/X8AEIIoYAAQjQ6wSjGQnK
KdZbXKLapUXtEXRSuvqErltY0H/DFxMUb9uiRrYGaKsrumaGBqjLKvuFADdmxUcmPdlx/s8wVXQwKggAFAtLEffKDCaVdoY3pP5aZkaeO6Jm388CHl9shWTdOBvyc1IVYDuzaXrM/KVmayendOUlyM78sPfuSfwYn48TGfr4KUzyfNnSulpLT/fyAQBihgABChrLUqqqzXxt2V2ri7Srm7K7WxoFy5m3ep8sp/fvUbmqSM5DgN6pamnM/KVtcUZWemKDM1oeUT5ydPlgoKgiNseXnB047TplG+gH1QwAAgAgQCVku37dWnBeXaWNRctnYf5NRhQpI61ZRrUMl2DSrZpkEl25VTsl2DakrU+a7fSjNmHH2glJS2+TlAhKKAAUAYawpYzV1dqL+/k6uNu6u+cn9aYqwGd0tVTlaqBnVN0aCXn1bOP/6kjJoyHXA8Ky+v3TMDoIABQFjyNwX06qpg8dpUXC1J6t4xUV8blBksW1kpGpSVqq77nzpclymZA0+oZ5Ui4B0KGACEEX9TQC+uLNA/383TlpJg8eqZ3kHfPyVbFx7X8/AXs542TbrllgPfxypFwDMUMAAIAw3+gF5Yka9/vrtJ20trJEl9OifphlOy9c0xPb+6EvFgUlODqxH3u1YjqxQBb1HAACCE1fub9OyyfP3r3U2fb3raPyNZN5ySrfNH9VBsS4vXvlilCDhHAQOAEFTX2KQ5S3bo3+9tUmF5cEPT7K4puvHUbJ0zoodifEd5OR9WKQJOUcAAIITUNjTpycXbdf97m1RUWS9JGpyVqhunZOvs4d2PvngBCAkUMAAIAdX1fj2xaJseeH+zSqoaJElDu6fppinZOmNoN/koXkBEoYABgENV9X49+vFWzfpgi0qrg8VrRK+OuunUHE05pmvLd58HEFYoYADgyK7yOn171kJtbt7Ha1TvdN18Wo5OHpRJ8QIiHAUMABzYWVarbz+4UNv21Cina4ruOHeoJmdnULyAKEEBAwCP7Sit0aUPLlT+3loN75mmx64Zr07J8a5jAfAQBQwAPLS1pFqXPrhQheV1GtU7XY9cM04dO8S5jgXAYxQwAPBIXlGVvv3gQhVV1mts3076z9XHKzWR8gVEIwoYAHhgw65KXTZroUqqGjRhQGfNnn68khN4CwaiFX/7AaCdrS0o1+WzFmlvTaMmZ2fowSvHqkP8YS6aDSCiUcAAoB2tyi/TFbMXq7y2UScPztS/Lz9OiXGULyDaUcAAoJ0s375X02cvVmW9X6cPzdI/vj1aCbGULwAUMABoF0u2luqqhxaruqFJU4/tpr9dMlpxMT7XsQCECCfvBsaYHxhj1hpj1hhjnjLGJLrIAQDt4aNNJbpydrB8nTeyh+6lfAHYj+fvCMaYnpJukjTWWjtcUoykS7zOAQDt4f2Nxbr6P0tU29ikC8f00j3TRimW8gVgP67eFWIldTDGxEpKklTgKAcAtJl31xfp2keXqt4f0KXjeuvPF41QjI9LCwH4Ks8LmLV2p6S/SNouqVBSubX2zf0fZ4y53hiz1BiztLi42OuYANAqb6zdpesfW6oGf0BXTuyr333jWPkoXwAOwsUpyE6SzpfUX1IPScnGmMv3f5y19gFr7Vhr7djMzEyvYwJAi722qlDff2K5GpusZkzur1+dN4zyBeCQXJyCPE3SFmttsbW2UdLzkk5wkAMAjtpLK3fqxqeWyx+w+t7JA/Xzrx8jYyhfAA7NRQHbLmmCMSbJBN+lpkha5yAHAByVZ5fla+aclQpY6aYpOfrxmYMpXwBaxPN9wKy1i4wxz0paLskvaYWkB7zOAQBH46nF23X7C6tlrfTDMwbphlNzXEcCEEacbMRqrb1T0p0unhsAjta7G4r00+dXS5JunzpE15800HEiAOGGzWkAoBXKaxp123OrJEm3nD6I8gXgiFDAAKAVfv3qp9pdUa8xfdL1/VOyXccBEKYoYADQQm9/ulvPLc9XQqxPf/nWSDZZBXDEKGAA0A
JlNQ26/YXgvK8fnTlYAzJTHCcCEM4oYADQAr965VMVVdbr+H6ddPWk/q7jAAhzFDAAOIw31u7SCyt2KjHOpz9fxKl
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X5, y, xlabel='x', ylabel='y')\n",
"theta_start = np.matrix([0, 0, 0, 0, 0, 0]).reshape(6, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X5, y, alpha=0.5, eps=10**-7)\n",
"plot_fun(fig, polynomial_regression(theta), X1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
    "This model has high **variance**: it **overfits** the data (_overfitting_)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
    "(Note the odd shape of the curve on the left side of the plot; it is partly an effect of overfitting.)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
    "Overfitting occurs when the model has too many degrees of freedom relative to the amount of training data. In the example above, the degree-5 polynomial has 6 parameters and there are only 6 data points, so it has enough freedom to pass through every point.\n",
    "\n",
    "This is undesirable: such a model generalizes poorly to new data.\n",
    "\n",
    "Figuratively speaking, overfitting occurs when the model starts to model the noise in the data instead of its main trend."
]
},
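  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "notes"
    }
   },
   "source": [
    "To make the degrees-of-freedom argument concrete, the sketch below fits polynomials of degree 1, 2 and 5 to the same six points and prints the training error (MSE). For brevity it uses the closed-form least-squares fit `np.polyfit` on the raw $x$ values instead of the notebook's `gradient_descent` on normalized features, so the numbers need not match the plots above exactly. With 6 coefficients for 6 points, the degree-5 polynomial can interpolate the data, so its training error is numerically zero, even though, as the plot shows, it does not describe the underlying trend any better."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "notes"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# The six training points from the bias/variance example\n",
    "x_pts = np.array([0.0, 0.5, 1.0, 1.6, 2.6, 3.0])\n",
    "y_pts = np.array([0.0, 1.8, 4.8, 7.2, 8.8, 9.0])\n",
    "\n",
    "train_mse = {}\n",
    "for degree in (1, 2, 5):\n",
    "    coeffs = np.polyfit(x_pts, y_pts, degree)  # least-squares polynomial fit\n",
    "    y_hat = np.polyval(coeffs, x_pts)\n",
    "    train_mse[degree] = float(np.mean((y_pts - y_hat) ** 2))\n",
    "    print(f'degree {degree}: training MSE = {train_mse[degree]:.6f}')"
   ]
  },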
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
    "See also: https://pl.wikipedia.org/wiki/Nadmierne_dopasowanie"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"90%\" src=\"fit.png\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
    "### Bias (systematic error)\n",
    "\n",
    "* Results from erroneous assumptions built into the learning algorithm (e.g. assuming a linear relationship when the data follow a curve, as in the straight-line fit above).\n",
    "* High bias causes underfitting."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
    "### Variance\n",
    "\n",
    "* Results from excessive sensitivity to small fluctuations in the training set.\n",
    "* High variance can cause overfitting (the model captures the noise instead of the signal)."
]
},
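  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "notes"
    }
   },
   "source": [
    "The variance of a model can also be quantified directly. Assuming the standard linear-regression setting $y = X\\theta + \\varepsilon$ with independent noise of variance $\\sigma^2$ (an assumption not stated in the notebook), the least-squares prediction at a new input with feature vector $z_0$ has variance $\\sigma^2 \\, z_0^T (X^T X)^{-1} z_0$. The sketch below evaluates this factor for polynomials of degree 1, 2 and 5 on the six example inputs. For nested models the factor cannot decrease as features are added, which illustrates why the more flexible model reacts more strongly to fluctuations in the training data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "notes"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "x_pts = np.array([0.0, 0.5, 1.0, 1.6, 2.6, 3.0])  # inputs from the example above\n",
    "x0 = 0.25  # a new input between two training points\n",
    "\n",
    "def design_matrix(x, degree):\n",
    "    # Columns 1, x, x^2, ..., x^degree\n",
    "    return np.vander(np.atleast_1d(x), degree + 1, increasing=True)\n",
    "\n",
    "variance_factor = {}\n",
    "for degree in (1, 2, 5):\n",
    "    A = design_matrix(x_pts, degree)\n",
    "    z0 = design_matrix(x0, degree)[0]\n",
    "    # z0^T (A^T A)^{-1} z0, using solve instead of an explicit inverse\n",
    "    variance_factor[degree] = float(z0 @ np.linalg.solve(A.T @ A, z0))\n",
    "    print(f'degree {degree}: Var[prediction at x0] / sigma^2 = {variance_factor[degree]:.3f}')"
   ]
  },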
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"60%\" src=\"bias2.png\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"60%\" src=\"curves.jpg\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
    "## 5.3. Regularization"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def SGD(h, fJ, fdJ, theta, X, Y, \n",
" alpha=0.001, maxEpochs=1.0, batchSize=100, \n",
" adaGrad=False, logError=False, validate=0.0, valStep=100, lamb=0, trainsetsize=1.0):\n",
    "    \"\"\"Stochastic Gradient Descent: the stochastic variant of gradient descent\n",
    "    (more on this topic in lecture 11)\n",
" \"\"\"\n",
" errorsX, errorsY = [], []\n",
" errorsVX, errorsVY = [], []\n",
" \n",
    "    m_end = int(trainsetsize * len(X))\n",
    "    \n",
    "    XT, YT = X[:m_end], Y[:m_end]\n",
" \n",
" if validate > 0:\n",
" mv = int(X.shape[0] * validate)\n",
" XV, YV = X[:mv], Y[:mv] \n",
" XT, YT = X[mv:m_end], Y[mv:m_end] \n",
" m, n = XT.shape\n",
"\n",
" start, end = 0, batchSize\n",
" maxSteps = (m * float(maxEpochs)) / batchSize\n",
" \n",
" if adaGrad:\n",
" hgrad = np.matrix(np.zeros(n)).reshape(n,1)\n",
" \n",
" for i in range(int(maxSteps)):\n",
" XBatch, YBatch = XT[start:end,:], YT[start:end,:]\n",
"\n",
" grad = fdJ(h, theta, XBatch, YBatch, lamb=lamb)\n",
" if adaGrad:\n",
" hgrad += np.multiply(grad, grad)\n",
" Gt = 1.0 / (10**-7 + np.sqrt(hgrad))\n",
" theta = theta - np.multiply(alpha * Gt, grad)\n",
" else:\n",
" theta = theta - alpha * grad\n",
" \n",
" if logError:\n",
" errorsX.append(float(i*batchSize)/m)\n",
" errorsY.append(fJ(h, theta, XBatch, YBatch).item())\n",
" if validate > 0 and i % valStep == 0:\n",
" errorsVX.append(float(i*batchSize)/m)\n",
" errorsVY.append(fJ(h, theta, XV, YV).item())\n",
" \n",
" if start + batchSize < m:\n",
" start += batchSize\n",
" else:\n",
" start = 0\n",
" end = min(start + batchSize, m)\n",
" return theta, (errorsX, errorsY, errorsVX, errorsVY)"
]
},
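  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "notes"
    }
   },
   "source": [
    "`SGD` passes `lamb` to the gradient function `fdJ`, but the regularized cost and gradient themselves are defined elsewhere in the notebook. As a sketch, assuming the common L2 penalty for logistic regression that leaves the bias $\\theta_0$ unpenalized,\n",
    "$$ J(\\theta) = -\\frac{1}{m} \\sum_{i=1}^{m} \\left[ y^{(i)} \\log h_\\theta(x^{(i)}) + (1 - y^{(i)}) \\log (1 - h_\\theta(x^{(i)})) \\right] + \\frac{\\lambda}{2m} \\sum_{j=1}^{n} \\theta_j^2, $$\n",
    "the cost and its gradient could be written as below. The names `h_logistic`, `J_reg` and `dJ_reg` are hypothetical; their signatures only mimic the calls `fJ(h, theta, X, Y)` and `fdJ(h, theta, X, Y, lamb=lamb)` made inside `SGD`. The last lines compare the analytic gradient with a finite-difference estimate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "notes"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def h_logistic(theta, X):\n",
    "    # Logistic hypothesis: sigmoid(X theta)\n",
    "    return 1.0 / (1.0 + np.exp(-(X @ theta)))\n",
    "\n",
    "def J_reg(h, theta, X, Y, lamb=0.0):\n",
    "    # L2-regularized logistic cost; theta[0] (the bias) is not penalized\n",
    "    m = len(Y)\n",
    "    p = h(theta, X)\n",
    "    data_term = -np.mean(Y * np.log(p) + (1 - Y) * np.log(1 - p))\n",
    "    return data_term + lamb / (2 * m) * np.sum(np.square(theta[1:]))\n",
    "\n",
    "def dJ_reg(h, theta, X, Y, lamb=0.0):\n",
    "    # Gradient of J_reg with respect to theta\n",
    "    m = len(Y)\n",
    "    reg = lamb / m * theta\n",
    "    reg[0] = 0.0  # do not shrink the bias term\n",
    "    return X.T @ (h(theta, X) - Y) / m + reg\n",
    "\n",
    "# Finite-difference check on random data\n",
    "rng = np.random.default_rng(0)\n",
    "X_demo = np.hstack([np.ones((20, 1)), rng.normal(size=(20, 2))])\n",
    "Y_demo = (rng.random((20, 1)) < 0.5).astype(float)\n",
    "theta_demo = rng.normal(size=(3, 1))\n",
    "\n",
    "step = 1e-6\n",
    "numeric = np.zeros_like(theta_demo)\n",
    "for j in range(theta_demo.shape[0]):\n",
    "    e = np.zeros_like(theta_demo)\n",
    "    e[j] = step\n",
    "    numeric[j] = (J_reg(h_logistic, theta_demo + e, X_demo, Y_demo, lamb=1.0)\n",
    "                  - J_reg(h_logistic, theta_demo - e, X_demo, Y_demo, lamb=1.0)) / (2 * step)\n",
    "analytic = dJ_reg(h_logistic, theta_demo, X_demo, Y_demo, lamb=1.0)\n",
    "print(np.max(np.abs(numeric - analytic)))"
   ]
  },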
{
"cell_type": "code",
"execution_count": 59,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
    "# Prepare the data for the regularization example\n",
"\n",
"n = 6\n",
"\n",
"data = np.matrix(np.loadtxt(\"ex2data2.txt\", delimiter=\",\"))\n",
"np.random.shuffle(data)\n",
"\n",
"X = powerme(data[:,0], data[:,1], n)\n",
"Y = data[:,2]"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def draw_regularization_example(X, Y, lamb=0, alpha=1, adaGrad=True, maxEpochs=2500, validate=0.25):\n",
    "    \"\"\"Draws a regularization example\"\"\"\n",
" plt.figure(figsize=(16,8))\n",
" plt.subplot(121)\n",
" plt.scatter(X[:, 2].tolist(), X[:, 1].tolist(),\n",
" c=Y.tolist(),\n",
    "                s=100, cmap='prism');\n",
"\n",
" theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1],1)\n",
" thetaBest, err = SGD(h, J, dJ, theta, X, Y, alpha=alpha, adaGrad=adaGrad, maxEpochs=maxEpochs, batchSize=100, \n",
" logError=True, validate=validate, valStep=1, lamb=lamb)\n",
"\n",
" xx, yy = np.meshgrid(np.arange(-1.5, 1.5, 0.02),\n",
" np.arange(-1.5, 1.5, 0.02))\n",
" l = len(xx.ravel())\n",
" C = powerme(xx.reshape(l, 1),yy.reshape(l, 1), n)\n",
" z = classifyBi(thetaBest, C).reshape(int(np.sqrt(l)), int(np.sqrt(l)))\n",
"\n",
    "    plt.contour(xx, yy, z, levels=[0.5], linewidths=3);\n",
    "    plt.ylim(-1,1.2);\n",
    "    plt.xlim(-1,1.2);\n",
" plt.subplot(122)\n",
" plt.plot(err[0],err[1], lw=3, label=\"Training error\")\n",
" if validate > 0:\n",
" plt.plot(err[2],err[3], lw=3, label=\"Validation error\");\n",
" plt.legend()\n",
" plt.ylim(0.2,0.8);"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA7AAAAHWCAYAAABOj2WsAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOzdd3iT1RfA8e9tm6aDAmXvIRtKgVKKCLIpgoCKgPCTpUAVZQiIIoogrjJEZIigDEGGyBJUQGSKbGSDbJBSZENL93h/f7yMjqRJ27Rpyvk8Tx/aNzf3noRCcnLvPVdpmoYQQgghhBBCCJHTOdk7ACGEEEIIIYQQwhqSwAohhBBCCCGEcAiSwAohhBBCCCGEcAiSwAohhBBCCCGEcAiSwAohhBBCCCGEcAiSwAohhBBCCCGEcAiSwAohhBCPGaXUM0qpk0qpM0qpESZuL6OU2qyUOqCUOqyUamuPOIUQQoiUlJwDK4QQQjw+lFLOwCmgFRAC7AW6aZp2PEmbWcABTdNmKKWqA79pmlbOHvEKIYQQSckMrBBCCPF4CQDOaJp2TtO0WGAJ8FyKNhqQ9/73+YDQbIxPCCGEMMvF3gEIIYQQIluVBC4l+TkEqJ+izRjgd6XUQMATaJk9oQkhhBBpc8gEtlChQlq5cuXsHYYQQohcYv/+/Tc0TSts7ziyiTJxLeV+om7APE3TvlBKNQAWKKV8NE1LTNaRUkFAEICnp2fdqlWrZjio6Jho3G6eACAeF1xK1MxwX0IIIRyfuddmh0xgy5Urx759++wdhhBCiFxCKXXR3jFkoxCgdJKfS5F6iXAf4BkATdN2KqXcgELAtaSNNE2bBcwC8Pf31zLz2nzy9EmqLAwA4KYqQMHR8jovhBCPM3OvzbIHVgghhHi87AUqKaXKK6Vcga7A6hRt/gVaACilqgFuwPVsjVIIIYQwQRJYIYQQ4jGiaVo8MABYD5wAlmqadkwpNVYp1eF+s2FAP6XUIWAx0FuTYwuEEELkAA65hFgIIYQQGadp2m/AbymufZjk++NAw+yOSwghhLBEElghhBC5SlxcHCEhIURHR6e6zc3NjVKlSmEwGOwQmUiLMllbSggh0pbW//nCMaT3tVkSWCGEELlKSEgIXl5elCtXDqUeJUWapnHz5k1CQkIoX768HSMUQghhK+b+zxeOISOvzbIHVgghRK4SHR1NwYIFU72RUUpRsGBB+ZReCCFyEXP/5wvHkJHXZklghRBC5Drm3sjIGxzHoFIdSyuEEObJ/+2OLb1/f5LACiGEEML+5A2oEMIB3bx5k9q1a1O7dm2KFStGyZIlH/4cGxtrVR+vvPIKJ0+eTLPN9OnTWbhwoS1CdniyB1YIIYQQdifpqxDCERUsWJCDBw8CMGbMGPLkycPbb7+drI2maWiahpOT6bnDuXPnWhznzTffzHyw6RAfH4+Li4vZn629X1aQGVghhBC5jrkjS+UoUyGEENnhzJkz+Pj48Prrr+Pn58eVK1cICgrC39+fGjVqMHbs2IdtGzVqxMGDB4mPjyd//vyMGDGCWrVq0aBBA65duwbABx98wOTJkx+2HzFiBAEBAVSpUoUdO3YAEBERwYsvvkitWrXo1q0b/v7+D5PrpPbu3UuTJk2oW7cubdq04erVqw/7ff/992ncuDHTpk2je/fuDBs2jGbNmjFy5Ehu3LhBhw4d8PX15amnnuLo0aMPY3vttddo1aoVr7zySpY+ryAzsELYjqbB3r2wahWEhUHFivDyy1C4sL0jE+Kx4ubmxs2bN1MV9XhQ6dDNzc2O0QkhhMgq5Ub8mmV9Xwh+Nt33OX78OHPnzuWbb74BIDg4mAIFChAfH0+zZs3o1KkT1atXT3afu3fv0qRJE4KDgxk6dChz5sxhxIgRqfrWNI09e/awevVqxo4dy7p165g6dSrFihVj+fLlHDp0CD8/v1T3i4mJYfDgwaxevZpChQ
qxcOFCRo0axaxZswAICwtj27ZtAHTv3p2zZ8+yceNGnJyc6N+/P/Xr12f16tX8/vvv9O7dm3379gFw4MABtm3bli2vsZLACmEL585Bhw5w4QJERurJrLs7jBgB/frB5Mng7GzvKIV4LJQqVYqQkBCuX7+e6rYHZ82JnE3myYUQuUGFChWoV6/ew58XL17M7NmziY+PJzQ0lOPHj6dKYN3d3WnTpg0AdevW5c8//zTZd8eOHR+2uXDhAgDbt2/n3XffBaBWrVrUqFEj1f1OnDjBsWPHaNmyJQAJCQnJXhe7du2arH3nzp0fLn3evn07v/6qf0gQGBhI7969iYiIAOC5557Ltg+IJYEVIrNCQ6F+fbh1CxITH12PitL/nDNHn5H9/nv7xCfEY8ZgMMg5r45INsEKIXIZT0/Ph9+fPn2ar776ij179pA/f366d+9u8ugYV1fXh987OzsTHx9vsm+j0ZiqjTXbZDRNw9fX12xinDTmlD+n7D/pzynvl5UkgRUis0aPhjt3kievSUVGwrJlMHQo1KqVvbEJIYQQQjwmMrLMN7uEhYXh5eVF3rx5uXLlCuvXr+eZZ56x6RiNGjVi6dKlPP300xw5coTjx4+nalO9enUuX77Mnj17CAgIIDY2ltOnT5ucrU2pcePGLFy4kPfee48//viDUqVKZWvi+oAksEJkRkQELFoEZj4deygmBiZNevxmYRMS4OhR/XkqXVr/EkIIC+QcWCFEbuPn50f16tXx8fHhiSeeoGHDhjYfY+DAgfTs2RNfX1/8/Pzw8fEhX758ydoYjUaWLVvGoEGDCA8PJz4+nmHDhlmVwI4dO5ZXXnkFX19f8uTJY1X15KygHLEio7+/v/Zgw7AQdnXsGDRoAOHhlttWqwYmPgnLleLj9YR94kR9KbWzs57E16kD48bB00/bO0IhklFK7dc0zd/ecTiyzL42nz53hkrz6wJwS+WnwOiLtgpNCJGLnThxgmrVqtk7jBwhPj6e+Ph43NzcOH36NIGBgZw+fTrLj7WxBVN/j+Zem3P+oxEiJzNznlem2zqyhAS9oNXWrfry6aR27oTWrWH+fOjUyT7xCSGEEELkQvfu3aNFixbEx8ejaRozZ850iOQ1vXLfIxIiO1WoAMqKyiMGA9yv9pbrTZtmOnl9ICoKevaEJk3kiCEhxENKqjgJIUSm5M+fn/3799s7jCz3mEwJCZFFXF3h9dfhfiU4s5ydYdCg7InJnjQNxo83n7wm9e23WR+PEEIIIYTIVSSBFSKzPvgAypTRk1lTPDxg5Eh44onsjcsezp7VKzJbEhUFP/6Y9fEIIYQQQohcRRJYITLLywv27IF27cDNDfLkAXd3/bq3t17IaNQoe0eZPaKj9dlma9sKIYQQQgiRDrIHVghbyJ8fli+Hq1dh/Xr92JiyZSEwEHLh5nmzSpaE2Fjr2launLWxCCEcijXlBIQQQgiZgRXClooW1QsU9e8Pbds+Xskr6DPOrVtbfieaJw8MHpw9MQkhHI6cAyuEcBRNmzZl/fr1ya5NnjyZN954I8375cmTB4DQ0FA6mTmZoWnTplg6nmzy5MlEJqk90rZtW+5Ys53LgUkCK4SwrU8/1ZdQm2M0Qo0a0Lx59sUkhHAAMgUrhHA83bp1Y8mSJcmuLVmyhG7dull1/xIlSrBs2bIMj58ygf3tt9/Inz9/hvtLj/j4+DR/NichISFT40oCK4SwLR8fWLsW8uYFT89H15XSf65bV19m/biciyuEEEKIXKtTp0788ssvxMTEAHDhwgVCQ0Np1KjRw3NZ/fz8qFmzJj///HOq+1+4cAEfHx8AoqKi6Nq1K76+vrz00ktERUU9bNe/f3/8/f2pUaMGo0ePBmDKlCmEhobSrFkzmjVrBkC5cuW4ceMGAJMmTcLHxwcfHx8mT578cLxq1arRr18/atSoQWBgYLJxHrh+/Tovvvgi9erVo169evz1118AjBkzhqCgIAIDA+nZsyfz5s2jc+
fOtG/fnsDAQDRNY/jw4fj4+FCzZk1+vF+0c8uWLTRr1oz//e9/1KxZM1PP+WO2vlEIkS0aN4bQUFi0CH74Ae7dg0q
"text/plain": [
"<Figure size 1152x576 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"draw_regularization_example(X, Y)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Regularization\n",
"\n",
"Regularization is a method of preventing overfitting by suitably modifying the cost function.\n",
"\n",
"A special term is added to the cost function (the **regularization term**, marked in red in the formulas below) that acts as a “penalty” for extreme values of the parameters $\\theta$ (the bias $\\theta_0$ is not penalized: the sums run over $j = 1, \\ldots, n$).\n",
"\n",
"As a result, vectors $\\theta$ with smaller parameter values automatically have a lower cost and are therefore preferred.\n",
"\n",
"How strong should the regularization be? We decide this by choosing an appropriate value of the **regularization parameter** $\\lambda$."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for linear regression: cost function\n",
"\n",
"$$\n",
"J(\\theta) \\, = \\, \\dfrac{1}{2m} \\left( \\displaystyle\\sum_{i=1}^{m} \\left( h_\\theta(x^{(i)}) - y^{(i)} \\right)^2 \\color{red}{ + \\lambda \\displaystyle\\sum_{j=1}^{n} \\theta^2_j } \\right)\n",
"$$"
]
},
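{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy sketch of the regularized cost above (an illustration, not part of the original lecture code). It assumes plain `ndarray`s and a design matrix whose first column is all ones, so `theta[0]` plays the role of $\\theta_0$ and is not penalized. The helper `J_linear_reg` and the toy data are hypothetical; distinct variable names are used so that the notebook's `X` and `Y` are not overwritten."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def J_linear_reg(theta, X, y, lamb=0.0):\n",
"    # Hypothetical helper: regularized linear-regression cost, theta[0] is not penalized\n",
"    m = len(y)\n",
"    residuals = X @ theta - y                # h_theta(x^(i)) - y^(i) for every example\n",
"    penalty = lamb * np.sum(theta[1:] ** 2)  # regularization term: lambda * sum_{j=1..n} theta_j^2\n",
"    return (np.sum(residuals ** 2) + penalty) / (2 * m)\n",
"\n",
"# Toy data fitted exactly by the line y = 1 + 2x\n",
"X_toy = np.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]])\n",
"y_toy = np.array([1.0, 3.0, 5.0])\n",
"theta_toy = np.array([1.0, 2.0])\n",
"print(J_linear_reg(theta_toy, X_toy, y_toy))            # 0.0: no fitting error and no penalty\n",
"print(J_linear_reg(theta_toy, X_toy, y_toy, lamb=3.0))  # 2.0: only the penalty, 3 * 2**2 / (2 * 3)"
]
},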
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"* $\\lambda$: the regularization parameter\n",
"* if $\\lambda$ is too small, the model overfits (the penalty barely matters)\n",
"* if $\\lambda$ is too large, the model underfits (the parameters are pushed towards zero)"
]
},
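{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the effect of $\\lambda$ concrete, here is a sketch that goes beyond the lecture code. For linear regression, setting the gradient of the regularized cost to zero gives the closed-form minimizer $\\theta = (X^T X + \\lambda L)^{-1} X^T y$, where $L$ is the identity matrix with $L_{00} = 0$, so that $\\theta_0$ is not penalized. The helper `ridge_normal_eq` and the synthetic data are hypothetical. As $\\lambda$ grows, the printed norm of $(\\theta_1, \\ldots, \\theta_n)$ shrinks; for very large $\\lambda$ the model approaches a constant prediction and underfits."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def ridge_normal_eq(X, y, lam):\n",
"    # Hypothetical helper: solves (X^T X + lam * L) theta = X^T y\n",
"    L = np.eye(X.shape[1])\n",
"    L[0, 0] = 0.0  # theta_0 is not penalized\n",
"    return np.linalg.solve(X.T @ X + lam * L, X.T @ y)\n",
"\n",
"rng = np.random.default_rng(1)\n",
"x_demo = rng.uniform(-1, 1, size=20)\n",
"X_demo = np.vander(x_demo, 7, increasing=True)  # polynomial features 1, x, x^2, ..., x^6\n",
"y_demo = np.sin(3 * x_demo) + rng.normal(scale=0.1, size=20)\n",
"\n",
"norms = []\n",
"for lam in [0.0, 0.01, 1.0, 100.0]:\n",
"    theta_demo = ridge_normal_eq(X_demo, y_demo, lam)\n",
"    norms.append(np.linalg.norm(theta_demo[1:]))\n",
"    print(f'lambda = {lam:>6}: norm of theta_1..theta_n = {norms[-1]:.3f}')"
]
},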
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for linear regression: gradient\n",
"\n",
"$$\\small\n",
"\\begin{array}{llll}\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_0} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_0 & \\textrm{for $j = 0$}\\\\\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_j} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_j \\color{red}{+ \\dfrac{\\lambda}{m}\\theta_j} & \\textrm{for $j = 1, 2, \\ldots, n$} \\\\\n",
"\\end{array} \n",
]
},
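{
"cell_type": "markdown",
"metadata": {},
"source": [
"The gradient above can be sketched in NumPy and checked against central finite differences, a common way to catch mistakes in hand-derived gradients. This illustration assumes plain `ndarray`s and a design matrix whose first column is all ones; the helpers `J_linear_reg` and `dJ_linear_reg` and the random data are hypothetical and are defined here so that the cell runs on its own."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def J_linear_reg(theta, X, y, lamb=0.0):\n",
"    # Regularized linear-regression cost (hypothetical helper)\n",
"    m = len(y)\n",
"    r = X @ theta - y\n",
"    return (np.sum(r ** 2) + lamb * np.sum(theta[1:] ** 2)) / (2 * m)\n",
"\n",
"def dJ_linear_reg(theta, X, y, lamb=0.0):\n",
"    # Gradient from the formulas above (hypothetical helper)\n",
"    m = len(y)\n",
"    g = X.T @ (X @ theta - y) / m  # (1/m) * sum_i (h(x^(i)) - y^(i)) * x_j^(i) for every j\n",
"    g[1:] += lamb / m * theta[1:]  # red term, only for j = 1, ..., n\n",
"    return g\n",
"\n",
"rng = np.random.default_rng(0)\n",
"X_chk = np.hstack([np.ones((5, 1)), rng.normal(size=(5, 2))])\n",
"y_chk = rng.normal(size=5)\n",
"theta_chk = rng.normal(size=3)\n",
"\n",
"# Central differences: (J(theta + step * e_j) - J(theta - step * e_j)) / (2 * step)\n",
"step = 1e-6\n",
"num_grad = np.array([\n",
"    (J_linear_reg(theta_chk + step * e, X_chk, y_chk, 0.5)\n",
"     - J_linear_reg(theta_chk - step * e, X_chk, y_chk, 0.5)) / (2 * step)\n",
"    for e in np.eye(3)\n",
"])\n",
"print(np.allclose(num_grad, dJ_linear_reg(theta_chk, X_chk, y_chk, 0.5), atol=1e-6))  # True"
]
},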
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for logistic regression: cost function\n",
"\n",
"$$\n",
"\\begin{array}{rcl}\n",
"J(\\theta) & = & -\\dfrac{1}{m} \\left( \\displaystyle\\sum_{i=1}^{m} y^{(i)} \\log h_\\theta(x^{(i)}) + \\left( 1-y^{(i)} \\right) \\log \\left( 1-h_\\theta(x^{(i)}) \\right) \\right) \\\\\n",
"& & \\color{red}{ + \\dfrac{\\lambda}{2m} \\displaystyle\\sum_{j=1}^{n} \\theta^2_j } \\\\\n",
"\\end{array}\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for logistic regression: gradient\n",
"\n",
"The formulas look the same as for linear regression, but here $h_\\theta$ is the logistic (sigmoid) function of $\\theta^T x$.\n",
"\n",
"$$\\small\n",
"\\begin{array}{llll}\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_0} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_0 & \\textrm{for $j = 0$}\\\\\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_j} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_j \\color{red}{+ \\dfrac{\\lambda}{m}\\theta_j} & \\textrm{for $j = 1, 2, \\ldots, n$} \\\\\n",
"\\end{array} \n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Implementing regularization"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"def J_(h,theta,X,y,lamb=0):\n",
"    \"\"\"Regularized cost function (logistic regression)\"\"\"\n",
"    m = float(len(y))\n",
"    f = h(theta, X, eps=10**-7)\n",
"    j = 1.0/m \\\n",
"        * -np.sum(np.multiply(y, np.log(f)) + \n",
"                  np.multiply(1 - y, np.log(1 - f)), axis=0) \\\n",
"        + lamb/(2*m) * np.sum(np.power(theta[1:] ,2))  # theta_0 is not penalized\n",
"    return j\n",
"\n",
"def dJ_(h,theta,X,y,lamb=0):\n",
"    \"\"\"Gradient of the regularized cost function\"\"\"\n",
"    m = float(y.shape[0])\n",
"    g = 1.0/m * (X.T*(h(theta,X)-y))\n",
"    g[1:] += lamb/m * theta[1:]  # regularization term only for j = 1, ..., n\n",
"    return g"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"slider_lambda = widgets.FloatSlider(min=0.0, max=0.5, step=0.005, value=0.01, description=r'$\\lambda$', layout=widgets.Layout(width='300px'))\n",
"\n",
"def slide_regularization_example_2(lamb):\n",
" draw_regularization_example(X, Y, lamb=lamb)"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e0246912b4c349958ea4f5391b3ae44f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=0.01, description='$\\\\lambda$', max=0.5, step=0.005), Button(descripti…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.slide_regularization_example_2(lamb)>"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact_manual(slide_regularization_example_2, lamb=slider_lambda)"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def cost_lambda_fun(lamb):\n",
"    \"\"\"Final training and validation cost for a given regularization parameter lambda\"\"\"\n",
"    theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1],1)\n",
"    thetaBest, err = SGD(h, J, dJ, theta, X, Y, alpha=1, adaGrad=True, maxEpochs=2500, batchSize=100, \n",
"                         logError=True, validate=0.25, valStep=1, lamb=lamb)\n",
"    # err = (errorsX, errorsY, errorsVX, errorsVY): last logged training and validation errors\n",
"    return err[1][-1], err[3][-1]\n",
"\n",
"def plot_cost_lambda():\n",
"    \"\"\"Plot the training and validation cost as a function of lambda\"\"\"\n",
"    plt.figure(figsize=(16,8))\n",
"    ax = plt.subplot(111)\n",
"    Lambda = np.arange(0.0, 1.0, 0.01)\n",
"    Costs = [cost_lambda_fun(lamb) for lamb in Lambda]\n",
"    CostTrain = [cost[0] for cost in Costs]\n",
"    CostCV = [cost[1] for cost in Costs]\n",
"    plt.plot(Lambda, CostTrain, lw=3, label='training error')\n",
"    plt.plot(Lambda, CostCV, lw=3, label='validation error')\n",
"    ax.set_xlabel(r'$\\lambda$')\n",
"    ax.set_ylabel('cost')\n",
"    plt.legend()\n",
"    plt.ylim(0.2,0.8)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA7AAAAHmCAYAAABK9WIBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOzdeXhc9WHv/89Xo200Gu3yJsm2bAw2toUxwphAWUIhhNtAQwKBNAvtLbRpc2nTW5rl99yQctvnSdMkTbhJmkKewG1KSAkJDbclCXHLHhNsgzHG+25ZtvZ9JI1m5vv745xZJMuyJM+MdOT363nmOefMnJn5CidYb75nMdZaAQAAAAAw2+XM9AAAAAAAAJgMAhYAAAAA4AkELAAAAADAEwhYAAAAAIAnELAAAAAAAE8gYAEAAAAAnpDRgDXG3GyM2WuMOWCM+dw4ry82xrxgjHnLGLPDGHNLJscDAAAAAPAuk6n7wBpjfJL2SbpRUpOkLZLuttbuStnnEUlvWWv/0RhzsaTnrLVLMzIgAAAAAICnZXIGdoOkA9baQ9basKQfSbptzD5WUom7XiqpOYPjAQAAAAB4WG4GP7tG0vGU7SZJV4zZ50uSnjfG/A9JAUm/ncHxAAAAAAA8LJMBa8Z5buzxyndLetxa+zVjzJWSfmCMWWOtjY36IGPuk3SfJAUCgctWrlyZkQEDAAAAAGbWtm3b2q211eO9lsmAbZJUl7Jdq9MPEf7vkm6WJGvtZmNMoaQqSa2pO1lrH5H0iCQ1NjbarVu3ZmrMAAAAAIAZZIw5eqbXMnkO7BZJK4wx9caYfEl3SXp2zD7HJN0gScaYVZIKJbVlcEwAAAAAAI/KWMBaayOSPi3pl5J2S3rKWvuuMeYhY8yt7m7/U9K9xpi3JT0p6R6bqcsiAwAAAAA8LZOHEMta+5yk58Y898WU9V2SrsrkGAAAAAAAc0NGAxYAAAAAMmVkZERNTU0aGhqa6aFgGgoLC1VbW6u8vLxJv4eABQAAAOBJTU1NCgaDWrp0qYwZ7yYomK2stero6FBTU5Pq6+sn/b5MXsQJAAAAADJmaGhIlZWVxKsHGWNUWVk55dlzAhYAAACAZxGv3jWdPzsCFgAAAACmobu7W9/5znem9d5bbrlF3d3dE+7zxS9+UZs2bZrW589VBCwAAAAATMNEARuNRid873PPPaeysrIJ93nooYf027/929Me31SNHXMkEpnU+ya7XzoQsAAAAAAwDZ/73Od08OBBrVu3Tg888IBefPFFXX/99froRz+qtWvXSpJ+93d/V5dddplWr16tRx55JPHepUuXqr29XUeOHNGqVat07733avXq1brppps0ODgoSbrnnnv09NNPJ/Z/8MEHtX79eq1du1Z79uyRJLW1tenGG2/U+vXr9Ud/9EdasmSJ2tvbTxvr888/ryuvvFLr16/XHXfcof7+/sTnPvTQQ7r66qv14x//WNddd52+8IUv6Nprr9U3v/lNHT16VDfccIMaGhp0ww036NixY4mx/cVf/IWuv/56ffazn83cP+QxuAoxAAAAAM9b+rn/yNhnH/nyfxv3+S9/+cvauXOntm/fLkl68cUX9cYbb2jnzp2JK+t+//vfV0VFhQYHB3X55ZfrQx/6kCorK0d9zv79+/Xkk0/q0Ucf1Z133qmf/OQn+tjHPnba91VVVenNN9/Ud77zHX31q1/V9773Pf31X/+13vve9+rzn/+8fvGLX4yK5Lj29nb9zd/8jTZt2qRAIKC/+7u/09e//nV98YtflOTczubVV1+VJH33u99Vd3e3XnrpJUnSBz7wAX3iE5/QJz/5SX3/+9/X/fffr3/7t3+TJO3bt0+bNm2Sz+ebzj/WaSFgAQAAACBNNmzYMOq2MA8//LCeeeYZSdLx48e1f//+0wK2vr5e69atkyRddtllOnLkyLifffvttyf2+elPfypJevXVVxOff/
PNN6u8vPy0973++uvatWuXrrrqKklSOBzWlVdemXj9Ix/5yKj9U7c3b96c+K6Pf/zj+qu/+qvEa3fccUdW41UiYAEAAAAgbQKBQGL9xRdf1KZNm7R582YVFRXpuuuuG/e2MQUFBYl1n8+XOIT4TPv5fL7EeafW2rOOyVqrG2+8UU8++eRZxzzedqrUKwdPtF+mELAAAAAAPO9Mh/lmUjAYVF9f3xlf7+npUXl5uYqKirRnzx69/vrraR/D1Vdfraeeekqf/exn9fzzz6urq+u0fTZu3Kg//dM/1YEDB3TBBRcoFAqpqalJF1544Vk//z3veY9+9KMf6eMf/7ieeOIJXX311Wn/GaaCizgBAAAAwDRUVlbqqquu0po1a/TAAw+c9vrNN9+sSCSihoYG/a//9b+0cePGtI/hwQcf1PPPP6/169fr5z//uRYuXKhgMDhqn+rqaj3++OO6++671dDQoI0bNyYuAnU2Dz/8sB577DE1NDToBz/4gb75zW+m/WeYCjOZKefZpLGx0W7dunWmhwEAAABghu3evVurVq2a6WHMqOHhYfl8PuXm5mrz5s361Kc+lbiolBeM92dojNlmrW0cb38OIQYAAAAAjzp27JjuvPNOxWIx5efn69FHH53pIWUUAQsAAAAAHrVixQq99dZbMz2MrOEcWAAAAACAJxCwAAAAAABPIGABAAAAAJ5AwAIAAAAAPIGABQAAAIAsKS4uliQ1Nzfrwx/+8Lj7XHfddTrbrUO/8Y1vKBQKJbZvueUWdXd3p2+gsxQBCwAAAABZtmjRIj399NPTfv/YgH3uuedUVlaWjqGdVSQSmXD7TKLR6Dl/NwELAAAAANPw2c9+Vt/5zncS21/60pf0ta99Tf39/brhhhu0fv16rV27Vj/72c9Oe++RI0e0Zs0aSdLg4KDuuusuNTQ06CMf+YgGBwcT+33qU59SY2OjVq9erQcffFCS9PDDD6u5uVnXX3+9rr/+eknS0qVL1d7eLkn6+te/rjVr1mjNmjX6xje+kfi+VatW6d5779Xq1at10003jfqeuLa2Nn3oQx/S5Zdfrssvv1yvvfZa4me77777dNNNN+kTn/iEHn/8cd1xxx36wAc+oJtuuknWWj3wwANas2aN1q5dq3/913+VJL344ou6/vrr9dGPflRr164953/m3AcWAAAAgPd9qTSDn90z7tN33XWX/vzP/1x/8id/Ikl66qmn9Itf/EKFhYV65plnVFJSovb2dm3cuFG33nqrjDHjfs4//uM/qqioSDt27NCOHTu0fv36xGt/+7d/q4qKCkWjUd1www3asWOH7r//fn3961/XCy+8oKqqqlGftW3bNj322GP6zW9+I2utrrjiCl177bUqLy/X/v379eSTT+rRRx/VnXfeqZ/85Cf62Mc+Nur9f/Znf6bPfOYzuvrqq3Xs2DG9733v0+7duxOf/eqrr8rv9+vxxx/X5s2btWPHDlVUVOgnP/mJtm/frrffflvt7e26/PLLdc0110iS3njjDe3cuVP19fXT++efgoAFAAAAgGm49NJL1draqubmZrW1tam8vFyLFy/WyMiIvvCFL+jll19WTk6OTpw4oZaWFi1YsGDcz3n55Zd1//33S5IaGhrU0NCQeO2pp57SI488okgkopMnT2rXrl2jXh/r1Vdf1Qc/+EEFAgFJ0u23365XXnlFt956q+rr67Vu3TpJ0mWXXaYjR46c9v5NmzZp165die3e3l719fVJkm699Vb5/f7EazfeeKMqKioS33v33XfL5/Np/vz5uvbaa7VlyxaVlJRow4YNaYlXiYAFAAAAgGn78Ic/rKefflqnTp3SXXfdJUl64okn1NbWpm3btikvL09Lly7V0NDQhJ8z3uzs4cOH9dWvflVbtmxReXm57rnnnrN+jrX2jK8VFBQk1n0+37iHEMdiMW3evHlUqMbFo3i87Ym+d+z7zgUBCwAAAMD7znCYb6bddddduvfee9Xe3q6XXn
pJktTT06N58+YpLy9PL7zwgo4ePTrhZ1xzzTV64okndP3112vnzp3asWOHJGf2MxAIqLS0VC0tLfr5z3+u6667TpI
"text/plain": [
"<Figure size 1152x576 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_cost_lambda()"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 5.4. Learning curve"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"* A learning curve lets us check whether training is proceeding correctly.\n",
"* A learning curve is a plot of the cost function value against the size of the training set.\n",
"* As the training set grows, the cost on the training set usually increases (it becomes harder to fit every example).\n",
"* As the training set grows, the cost on the validation set usually decreases (the model generalizes better)."
]
},
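{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small self-contained sketch of a learning curve, independent of the lecture's `SGD` function: hypothetical synthetic data, a degree-3 polynomial fitted in closed form with a small $\\lambda$, and the costs (evaluated without the penalty term) on the first `size` training examples and on a fixed validation set. With data like this the training cost typically rises and the validation cost falls as `size` grows, but the exact values depend on the random sample. The helpers `fit_ridge` and `half_mse` are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def fit_ridge(X, y, lam):\n",
"    # Hypothetical helper: closed-form minimizer of the regularized cost, theta_0 not penalized\n",
"    L = np.eye(X.shape[1])\n",
"    L[0, 0] = 0.0\n",
"    return np.linalg.solve(X.T @ X + lam * L, X.T @ y)\n",
"\n",
"def half_mse(theta, X, y):\n",
"    # Unregularized cost: (1/2m) * sum of squared errors\n",
"    r = X @ theta - y\n",
"    return np.sum(r ** 2) / (2 * len(y))\n",
"\n",
"rng = np.random.default_rng(2)\n",
"x_lc = rng.uniform(-1, 1, size=200)\n",
"y_lc = np.sin(3 * x_lc) + rng.normal(scale=0.2, size=200)\n",
"X_lc = np.vander(x_lc, 4, increasing=True)  # features 1, x, x^2, x^3\n",
"X_tr_lc, y_tr_lc = X_lc[:150], y_lc[:150]   # training pool\n",
"X_va_lc, y_va_lc = X_lc[150:], y_lc[150:]   # fixed validation set\n",
"\n",
"curve = []\n",
"for size in [5, 10, 20, 50, 100, 150]:\n",
"    theta_lc = fit_ridge(X_tr_lc[:size], y_tr_lc[:size], lam=1e-3)\n",
"    curve.append((size,\n",
"                  half_mse(theta_lc, X_tr_lc[:size], y_tr_lc[:size]),\n",
"                  half_mse(theta_lc, X_va_lc, y_va_lc)))\n",
"    print('m = %3d  train = %.4f  validation = %.4f' % curve[-1])"
]
},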
{
"cell_type": "code",
"execution_count": 67,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def cost_trainsetsize_fun(m):\n",
"    \"\"\"Final training and validation cost for training set size m\"\"\"\n",
"    theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1],1)\n",
"    thetaBest, err = SGD(h, J, dJ, theta, X, Y, alpha=1, adaGrad=True, maxEpochs=2500, batchSize=100, \n",
"                         logError=True, validate=0.25, valStep=1, lamb=0.01, trainsetsize=m)\n",
"    # err = (errorsX, errorsY, errorsVX, errorsVY): last logged training and validation errors\n",
"    return err[1][-1], err[3][-1]\n",
"\n",
"def plot_learning_curve():\n",
"    \"\"\"Plot the learning curve\"\"\"\n",
"    plt.figure(figsize=(16,8))\n",
"    ax = plt.subplot(111)\n",
"    M = np.arange(0.3, 1.0, 0.05)  # values passed to SGD as trainsetsize, from 0.3 in steps of 0.05\n",
"    Costs = [cost_trainsetsize_fun(m) for m in M]\n",
"    CostTrain = [cost[0] for cost in Costs]\n",
"    CostCV = [cost[1] for cost in Costs]\n",
"    plt.plot(M, CostTrain, lw=3, label='training error')\n",
"    plt.plot(M, CostCV, lw=3, label='validation error')\n",
"    ax.set_xlabel('trainset size')\n",
"    ax.set_ylabel('cost')\n",
"    plt.legend()"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
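"### Learning curves, bias and variance\n",
"\n",
"Plotting the learning curve helps diagnose overfitting and underfitting:\n",
"\n",
"* **high bias (underfitting):** the training and validation errors quickly converge to a similar, high value, and more training data does not help much,\n",
"* **high variance (overfitting):** the training error stays low while the validation error is much higher; the gap typically narrows as more training data is added.\n",
"\n",
"<img width=\"100%\" src=\"learning-curves.png\"/>\n",
"\n",
"Source: http://www.ritchieng.com/machinelearning-learning-curve"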
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA7AAAAHgCAYAAACcrIEcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOzdeXyU5b3///eVyb4ACQTZSQBRBIKs4lJWF7RH675VLV2kte3xtP3Vanu+rT329PdtrcdWT2s96lHOsVZrXVpP64psoqAsKgdBMOxhTSAJgewz1/ePe5KZLJOEMDP3LK/n4zGPXPcyM58EE/POtRlrrQAAAAAAiHUpbhcAAAAAAEBPEGABAAAAAHGBAAsAAAAAiAsEWAAAAABAXCDAAgAAAADiAgEWAAAAABAXUt0u4GQNGDDAFhUVuV0GAAAAACAC1q9fX2GtLezsWtwF2KKiIq1bt87tMgAAAAAAEWCM2R3qGkOIAQAAAABxgQALAAAAAIgLBFgAAAAAQFyIuzmwAAAAACBJTU1NKisrU319vduloBcyMzM1bNgwpaWl9fg5BFgAAAAAcamsrEx5eXkqKiqSMcbtcnASrLU6cuSIysrKVFxc3OPnMYQYAAAAQFyqr69X//79Ca9xyBij/v37n3TvOQEWAAAAQNwivMav3vzbEWABAAAAoBeqqqr0yCOP9Oq5l112maqqqrq85yc/+YmWLFnSq9dPVARYAAAAAOiFrgKs1+vt8rmvvvqq+vXr1+U99913ny688MJe13ey2tfc3Nzco+f19L5wIMACAAAAQC/cc8892r59u84++2zdddddWr58uebOnaubb75ZEydOlCRdeeWVmjp1qsaPH6/HHnus9blFRUWqqKjQrl27NG7cON1+++0aP368Lr74YtXV1UmSFi5cqBdeeKH1/nvvvVdTpkzRxIkT9emnn0qSysvLddFFF2nKlCn6+te/rpEjR6qioqJDrW+++abOPfdcTZkyRdddd52OHz/e+rr33XefLrjgAv35z3/WnDlz9KMf/UizZ8/WQw89pN27d2v+/PkqKSnR/PnztWfPntbavve972nu3Lm6++67I/dFbodViAEAAADEvaJ7/h6x1971i893ev4Xv/iFNm3apI8++kiStHz5cn3wwQfatGlT68q6Tz75pAoKClRXV6fp06frmmuuUf/+/du8zmeffaZnn31Wjz/+uK6//nq9+OKLuuWWWzq834ABA7RhwwY98sgjeuCBB/TEE0/oX/7lXzRv3jz98Ic/1Ouvv94mJLeoqKjQv/7rv2rJkiXKycnRL3/5Sz344IP6yU9+IsnZzmbVqlWSpEcffVRVVVVasWKFJOnyyy/Xbbfdpi996Ut68skndeedd+ovf/mLJGnbtm1asmSJPB5Pb76svUKABQAAAIAwmTFjRpttYR5++GG9/PLLkqS9e/fqs88+6xBgi4uLdfbZZ0uSpk6dql27dnX62ldffXXrPS+99JIkadWqVa2vv2DBAuXn53d43po1a7R582adf/75kqTGxkade+65rddvuOGGNvcHH69evbr1vW699Vb94Ac/aL123XXXRTW8SgRYAAAAAAibnJyc1vby5cu1ZMkSrV69WtnZ2ZozZ06n28ZkZGS0tj0eT+sQ4lD3eTye1nmn1tpua7LW6qKLLtKzzz7bbc2dHQcLXjm4q/sihQALAAAAIO6FGuYbSXl5eaqpqQl5vbq6Wvn5+crOztann36qNWvWhL2GCy64QM8//7zuvvtuvfnmm6qsrOxwz8yZM/Wtb31LpaWlGjNmjGpra1VWVqaxY8d2+/rnnXeennvuOd1666165plndMEFF4T9czgZLOIEAAAAAL3Qv39/nX/++ZowYYLuuuuuDtcXLFig5uZmlZSU6Mc//rFmzpwZ9hruvfdevfnmm5oyZYpee+01DR48WHl5eW3uKSws1OLFi3XTTTeppK
REM2fObF0EqjsPP/ywnnrqKZWUlOjpp5/WQw89FPbP4WSYnnQ5x5Jp06bZdevWuV1G1xpqpIy87u8DAAAA0GtbtmzRuHHj3C7DVQ0NDfJ4PEpNTdXq1at1xx13tC4qFQ86+zc0xqy31k7r7H6GEIeLzyst+/+l0iVS+VbpBzuk9Gy3qwIAAACQwPbs2aPrr79ePp9P6enpevzxx90uKaIIsOGS4pE+/ZtU7u+K3/2udPpF7tYEAAAAIKGdfvrp+vDDD90uI2qYAxtOo+cH2qVvu1cHAAAAACSgiAVYY8yTxpjDxphNXdwzxxjzkTHmE2PMikjVEjVj5gXa2wmwAAAAABBOkeyBXSxpQaiLxph+kh6RdIW1dryk6yJYS3SMPF9KzXTaFdukqr3u1gMAAAAACSRiAdZau1LS0S5uuVnSS9baPf77D0eqlqhJy5JGnhc4phcWAAAAAMLGzTmwYyXlG2OWG2PWG2Nuc7GW8GEeLAAAAIAQcnNzJUn79+/Xtdde2+k9c+bMUXdbh/7mN79RbW1t6/Fll12mqqqq8BUao9wMsKmSpkr6vKRLJP3YGDO2sxuNMYuMMeuMMevKy8ujWePJGxMUYHeskLzN7tUCAAAAICYNGTJEL7zwQq+f3z7Avvrqq+rXr184SutWc3Nzl8eheL3eU35vNwNsmaTXrbUnrLUVklZKmtTZjdbax6y106y10woLC6Na5EkrPFPKG+K0G6qlfevdrQcAAABARNx999165JFHWo9/+tOf6t/+7d90/PhxzZ8/X1OmTNHEiRP117/+tcNzd+3apQkTJkiS6urqdOONN6qkpEQ33HCD6urqWu+74447NG3aNI0fP1733nuvJOnhhx/W/v37NXfuXM2dO1eSVFRUpIqKCknSgw8+qAkTJmjChAn6zW9+0/p+48aN0+23367x48fr4osvbvM+LcrLy3XNNddo+vTpmj59ut59993Wz23RokW6+OKLddttt2nx4sW67rrrdPnll+viiy+WtVZ33XWXJkyYoIkTJ+pPf/qTJGn58uWaO3eubr75Zk2cOPGUv+Zu7gP7V0m/NcakSkqXdI6kX7tYT3gY46xG/OEfnOPSJdKIc9ytCQAAAEh0P+0bwdeu7vT0jTfeqO985zv65je/KUl6/vnn9frrryszM1Mvv/yy+vTpo4qKCs2cOVNXXHGFjDGdvs7vf/97ZWdna+PGjdq4caOmTJnSeu3nP/+5CgoK5PV6NX/+fG3cuFF33nmnHnzwQS1btkwDBgxo81rr16/XU089pffff1/WWp1zzjmaPXu28vPz9dlnn+nZZ5/V448/ruuvv14vvviibrnlljbP/6d/+id997vf1QUXXKA9e/bokksu0ZYtW1pfe9WqVcrKytLixYu1evVqbdy4UQUFBXrxxRf10Ucf6eOPP1ZFRYWmT5+uWbNmSZI++OADbdq0ScXFxb37+geJWIA1xjwraY6kAcaYMkn3SkqTJGvto9baLcaY1yVtlOST9IS1NuSWO3FlzIWBALv9bWneP7tbDwAAAICwmzx5sg4fPqz9+/ervLxc+fn5GjFihJqamvSjH/1IK1euVEpKivbt26dDhw5p0KBBnb7OypUrdeedd0qSSkpKVFJS0nrt+eef12OPPabm5mYdOHBAmzdvbnO9vVWrVumqq65STk6OJOnqq6/WO++8oyuuuELFxcU6++yzJUlTp07Vrl27Ojx/yZIl2rx5c+vxsWPHVFNTI0m64oorlJWV1XrtoosuUkFBQev73nTTTfJ4PDrttNM0e/ZsrV27Vn369NGMGTPCEl6lCAZYa+1NPbjnV5J+FakaXDNqjmRSJOuT9m2Qao9K2QVuVwUAAAAgzK699lq98MILOnjwoG688UZJ0jPPPKPy8nKtX79eaWlpKioqUn19fZev01nv7M6dO/XAAw9o7dq1ys/P18KFC7t9HWttyGsZGRmtbY/H0+kQYp/Pp9WrV7cJqi1aQnFnx129b/
vnnQo3hxAnrqx8aehUqWytJCvtWCZNuMbtqgAAAIDEFWKYb6TdeOONuv3221VRUaEVK1ZIkqqrqzVw4EClpaVp2bJ
"text/plain": [
"<Figure size 1152x576 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_learning_curve()"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
},
"livereveal": {
"start_slideshow_at": "selected",
"theme": "amu"
}
},
"nbformat": 4,
"nbformat_minor": 4
}