uczenie-maszynowe/wyk/06_Problem_nadmiernego_dopa...

1833 lines
582 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Uczenie maszynowe\n",
"# 6. Problem nadmiernego dopasowania"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 6.1. Regresja wielomianowa"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Wprowadzenie: wybór cech"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Niech naszym zadaniem będzie przewidzieć cenę działki o kształcie prostokąta.\n",
"\n",
"Jakie cechy wybrać?"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Możemy wybrać dwie cechy:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * $x_1$ szerokość działki, $x_2$ długość działki:\n",
"$$ h_{\\theta}(\\vec{x}) = \\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"...albo jedną:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * $x_1$ powierzchnia działki:\n",
"$$ h_{\\theta}(\\vec{x}) = \\theta_0 + \\theta_1 x_1 $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Można też zauważyć, że cecha „powierzchnia działki” powstaje przez pomnożenie dwóch innych cech: długości działki i jej szerokości."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"**Wniosek:** możemy tworzyć nowe cechy na podstawie innych poprzez wykonywanie na nich różnych operacji matematycznych."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Regresja wielomianowa"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"W regresji wielomianowej będziemy korzystać z cech, które utworzymy jako potęgi cech wyjściowych."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Przydatne importy\n",
"\n",
"import ipywidgets as widgets\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Przydatne funkcje\n",
"\n",
"\n",
"def cost(theta, X, y):\n",
" \"\"\"Wersja macierzowa funkcji kosztu\"\"\"\n",
" m = len(y)\n",
" J = 1.0 / (2.0 * m) * ((X * theta - y).T * (X * theta - y))\n",
" return J.item()\n",
"\n",
"\n",
"def gradient(theta, X, y):\n",
" \"\"\"Wersja macierzowa gradientu funkcji kosztu\"\"\"\n",
" return 1.0 / len(y) * (X.T * (X * theta - y))\n",
"\n",
"\n",
"def gradient_descent(fJ, fdJ, theta, X, y, alpha=0.1, eps=10**-5):\n",
" \"\"\"Algorytm gradientu prostego (wersja macierzowa)\"\"\"\n",
" current_cost = fJ(theta, X, y)\n",
" logs = [[current_cost, theta]]\n",
" while True:\n",
" theta = theta - alpha * fdJ(theta, X, y)\n",
" current_cost, prev_cost = fJ(theta, X, y), current_cost\n",
" if abs(prev_cost - current_cost) > 10**15:\n",
" print(\"Algorithm does not converge!\")\n",
" break\n",
" if abs(prev_cost - current_cost) <= eps:\n",
" break\n",
" logs.append([current_cost, theta])\n",
" return theta, logs\n",
"\n",
"\n",
"def plot_data(X, y, xlabel, ylabel):\n",
" \"\"\"Wykres danych (wersja macierzowa)\"\"\"\n",
" fig = plt.figure(figsize=(16 * 0.6, 9 * 0.6))\n",
" ax = fig.add_subplot(111)\n",
" fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
" ax.scatter([X[:, 1]], [y], c=\"r\", s=50, label=\"Dane\")\n",
"\n",
" ax.set_xlabel(xlabel)\n",
" ax.set_ylabel(ylabel)\n",
" ax.margins(0.05, 0.05)\n",
" plt.ylim(y.min() - 1, y.max() + 1)\n",
" plt.xlim(np.min(X[:, 1]) - 1, np.max(X[:, 1]) + 1)\n",
" return fig\n",
"\n",
"\n",
"def plot_fun(fig, fun, X):\n",
" \"\"\"Wykres funkcji `fun`\"\"\"\n",
" ax = fig.axes[0]\n",
" x0 = np.min(X[:, 1]) - 1.0\n",
" x1 = np.max(X[:, 1]) + 1.0\n",
" Arg = np.arange(x0, x1, 0.1)\n",
" Val = fun(Arg)\n",
" return ax.plot(Arg, Val, linewidth=\"2\")\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Wczytanie danych (mieszkania) przy pomocy biblioteki pandas\n",
"\n",
"alldata = pandas.read_csv(\n",
" \"data_flats.tsv\", header=0, sep=\"\\t\", usecols=[\"price\", \"rooms\", \"sqrMetres\"]\n",
")\n",
"data = np.matrix(alldata[[\"sqrMetres\", \"price\"]])\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data[:, 0:n]\n",
"Xn /= np.amax(Xn, axis=0)\n",
"Xn2 = np.power(Xn, 2)\n",
"Xn2 /= np.amax(Xn2, axis=0)\n",
"Xn3 = np.power(Xn, 3)\n",
"Xn3 /= np.amax(Xn3, axis=0)\n",
"\n",
"X = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n + 1)\n",
"X2 = np.matrix(np.concatenate((np.ones((m, 1)), Xn, Xn2), axis=1)).reshape(m, 2 * n + 1)\n",
"X3 = np.matrix(np.concatenate((np.ones((m, 1)), Xn, Xn2, Xn3), axis=1)).reshape(\n",
" m, 3 * n + 1\n",
")\n",
"y = np.matrix(data[:, -1]).reshape(m, 1)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Postać ogólna regresji wielomianowej:\n",
"\n",
"$$ h_{\\theta}(x) = \\sum_{i=0}^{n} \\theta_i x^i $$"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [],
"source": [
"# Funkcja regresji wielomianowej\n",
"\n",
"\n",
"def h_poly(Theta, x):\n",
" \"\"\"Funkcja wielomianowa\"\"\"\n",
" return sum(theta * np.power(x, i) for i, theta in enumerate(Theta.tolist()))\n",
"\n",
"\n",
"def polynomial_regression(theta):\n",
" \"\"\"Funkcja regresji wielomianowej\"\"\"\n",
" return lambda x: h_poly(theta, x)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Najprostszym przypadkiem regresji wielomianowej jest funkcja kwadratowa:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Funkcja kwadratowa:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 $$"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f011cc2e620>]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1IAAAH+CAYAAACWZz+7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAACEnUlEQVR4nO3dd3hUZd7/8c+UNCAJhpIQCB0B6SAg0hUbrgtixbp2FAuwxeW3++w+W1zXZxXcIrDrqqy62FBAxAYISKgConTpNQk9ISF1Zn5/HDIppMwkM3OmvF/XNVdyJmdmvimE88l939/b4nK5XAIAAAAAeMxqdgEAAAAAEGoIUgAAAADgJYIUAAAAAHiJIAUAAAAAXiJIAQAAAICXCFIAAAAA4CWCFAAAAAB4iSAFAAAAAF4iSAEAAACAlwhSAAAAAOCliA5SX3/9tW666SalpqbKYrFo/vz5Xj+Hy+XSiy++qEsvvVQxMTFq2bKlnnvuOd8XCwAAACBo2M0uwEx5eXnq1auXHnzwQY0bN65Oz/HMM8/oyy+/1IsvvqgePXro9OnTOn36tI8rBQAAABBMLC6Xy2V2EcHAYrFo3rx5Gjt2rPu+wsJC/epXv9I777yjs2fPqnv37nrhhRc0YsQISdKOHTvUs2dPbd26VZ07dzancAAAAAABF9FT+2rz5JNPas2aNXr33Xf1/fff67bbbtP111+v3bt3S5IWLlyo9u3b65NPPlG7du3Utm1bPfzww4xIAQAAAGGOIFWNQ4cO6Y033tAHH3ygoUOHqkOHDvrZz36mIUOG6I033pAk7du3TwcPHtQHH3ygN998U7Nnz9bGjRt16623mlw9AAAAAH+K6DVSNdmyZYscDocuvfTSCvcXFhaqSZMmkiSn06nCwkK9+eab7vNee+019evXT7t27WK6HwAAABCmCFLVyM3Nlc1m08aNG2Wz2Sp8rFGjRpKkFi1ayG63VwhbXbt2lWSMaBGkAAAAgPBEkKpGnz595HA4dPz4cQ0dOrTKcwYPHqySkhLt3btXHTp0kCT98MMPkqQ2bdoErFYAAAAAgRXRXftyc3O1Z88eSUZwmjZtmkaOHKmkpCS1bt1a99xzj1atWqWXXnpJffr00YkTJ7R06VL17NlTN954o5xOp/r3769GjRrp5ZdfltPp1MSJE5WQkKAvv/zS5M8OAAAAgL9EdJBavny5Ro4cedH9999/v2bPnq3i4mL98Y9/1JtvvqmjR4+qadOmuuKKK/S73/1OPXr0kCQdO3ZMTz31lL788ks1bNhQN9xwg1566SUlJSUF+tMBAAAAECARHaQAAAAAoC5ofw4AAAAAXiJIAQAAAICXIq5rn9Pp1LFjxxQfHy+LxWJ2OQAAAAD8yOVy6dy5c0pNTZXV6rtxpIgLUseOHVNaWprZZQAAAAAIoMOHD6tVq1Y+e76IC1Lx8fGSjC9kQkKCJKmg2KGRLy7TuQKHYqKsWv6zEYqPjTKzTACITPn5UosWkid9kCwWKSNDiovzf11AmJq9ar9e/NLYA3PC8A568qqOJlcE+F5OTo7S0tLcOcBXIi5IlU7nS0hIcAepBEk3D+ykt9ceUrGkVQfP6/b+jFoBQMAlJEhjx0oLF0olJdWfZ7dLY8ZIyckBKw0IR5//kCNrTANJ0p2DL1VCQiOTKwL8x9fLemg2ccGt/cqC09yNR0ysBAAi3JQpksNR8zkOhzR5cmDqAcLUrsxz2p6RI0nqldZY7ZsRogBvEKQu6NUqUR2bG79A1h84rQMn80yuCAAi1JAh0owZxtQ9e6WJE3a7cf+MGdLgwebUB4SJ+ZuPut+/uXeqiZUAoYkgdYHFYtGt/coWn320iVEpADDNhAnSypXG9L3SDktWq3G8cqXxcQB15nS6tOBbI0jZrBb9qBdBCvAWQaqcm/u0lPXC1MkPNx2V0+nBYmcAgH8MHizNnSvl5kqZmcbbuXMZiQJ8YN3+0zqWXSBJGtapqZ
o2ijG5IiD0EKTKSU6I1bBLm0mSjp7N19p9p0yuCACguDijqQTd+QCfmf9t2bS+sX1amlgJELoIUpWUn95H0wkAABBuCood+nRLhiSpYbRN116WYnJFQGgiSFUyqmuyEmKNxc2fbs3QuYJikysCAADwna92Hte5QmN7geu7t1BctM3kioDQRJCqJDbKph9f6FxTUOzUZ1syTa4IAADAd+aVm9Z3M9P6gDojSFWBPaUAAEA4OpNXpOW7jkuSmsfHaFCHJiZXBIQuglQV2FMKAACEo0VbMlTsMLoSj+mdKltpu2IAXiNIVYE9pQAAQDiaR7c+wGcIUtVgTykAABBODp06r40Hz0iSLk1upMtaJJhcERDaCFLVYE8pAAAQTuZvLt9kopUsFqb1AfVBkKoBe0oBAIBw4HK5KmzCO+ZCh2IAdUeQqsGorslKjIuSxJ5SAAAgdH1/JFv7LjTPuqJ9klIbx5lcERD6CFI1iI2y6ce9yvaUKt0FHAAAIJSwdxTgewSpWjC9DwAAhLJih1MLvzsmSYq2W3V99xYmVwSEB4JULXq2SlSnC3tKfXPgDHtKAQCAkJK++6RO5RVJkkZ1be5etgCgfghStai8p9SH7CkFAABCSIW9o3ozrQ/wFYKUByrsKbXxCHtKAQCAkJBbWKIvt2dKkho3iNKIzs1NrggIHwQpDzRPiNXwC3tKHcsu0Br2lAIAACFg4XfHVFDslCTd2KOFou1c+gG+wr8mD93aL839Pk0nAABAKHh3/SH3+3f0T6vhTADeIkh56OpyizM/Y08pAAAQ5LYdy9Z3R7IlSd1SE9SjZaLJFQHhhSDlIfaUAgAAoeTd9Yfd7985oLUsFouJ1QDhhyDlBfaUAgAAoSC/yKH5m41ufXFRNo3pnWpyRUD4IUh5gT2lAABAKFi0JUPnCkokSTf2bKGEWPaOAnzN1CD1/PPPq3///oqPj1fz5s01duxY7dq1q8bHzJ49WxaLpcItNjY2IPWypxQAAAgF5ZtMjB9AkwnAH0wNUitWrNDEiRO1du1aLV68WMXFxbr22muVl1fzSE9CQoIyMjLct4MHDwaoYvaUAgAAwW131jltOHhGktSpeSP1bX2JyRUB4clu5ot//vnnFY5nz56t5s2ba+PGjRo2bFi1j7NYLEpJSfF3eVUq3VNq2a4T7j2lBndsakotAAAAlb37DU0mgEAIqjVS2dlGi86kpKQaz8vNzVWbNm2UlpamMWPGaNu2bdWeW1hYqJycnAq3+mJPKQAAEIwKih366MLSg2i7VeP6tDS5IiB8BU2QcjqdmjRpkgYPHqzu3btXe17nzp31+uuva8GCBXr77bfldDp15ZVX6siRqgPN888/r8TERPctLa3+84TZUwoAAASjL7Zl6sx547rkhu4puqRhtMkVAeEraILUxIkTtXXrVr377rs1njdo0CDdd9996t27t4YPH66PPvpIzZo10z//+c8qz586daqys7Pdt8OHD1d5njdiy7URLSh2atH37CkFAADMV2HvqP6tTawECH9BEaSefPJJffLJJ1q2bJlatWpV+wPKiYqKUp8+fbRnz54qPx4TE6OEhIQKN19gTykAABBMDpzM05p9pyRJ7Zo21BXta14qAaB+TA1SLpdLTz75pObNm6evvvpK7dq18/o5HA6HtmzZohYtWvihwur1aJmoS5ONPaU2HDyj/ewpBQAATFS+ycQd/dNoMgH4malBauLEiXr77bc1Z84cxcfHKzMzU5mZmcrPz3efc99992nq1Knu49///vf68ssvtW/fPm3atEn33HOPDh48qIcffjigtV+0pxSjUgAAwCTFDqd7hozdatEtfb2b4QPAe6YGqZkzZyo7O1sjRoxQixYt3Lf33nvPfc6hQ4eUkVG2BunMmTN65JFH1LVrV40ePVo5OTlavXq1LrvssoDXP7Z3S9kubCr14aYjcrCnFAAAMMHSHVk6mVsoSbrmsmQ1i48xuSIg/Jm6j5TLVXvwWL
58eYXj6dOna/r06X6qyDule0p9tfO4MrILtGbvKQ3pxJ5SAAAgsN5ZX3HvKAD+FxTNJkJZxaYT9e8ICAAA4I0jZ87
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X2, y, xlabel=\"x\", ylabel=\"y\")\n",
"theta_start = np.matrix([0, 0, 0]).reshape(3, 1)\n",
"theta, logs = gradient_descent(cost, gradient, theta_start, X2, y)\n",
"plot_fun(fig, polynomial_regression(theta), X)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Innym szczególnym przypadkiem regresji wielomianowej jest funkjca sześcienna:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Funkcja sześcienna:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 + \\theta_3 x^3 $$"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 397519.38046962]\n",
" [-841341.14146733]\n",
" [2253713.97125102]\n",
" [-244009.07081946]]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1IAAAH+CAYAAACWZz+7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAACEN0lEQVR4nO3dd3yV9fn/8fcZWRAChJEQCHvvoSAbFAcuhlrFWatWFFsRW1t/9Wtrl61V0FoB2zrqHiggThBkhKks2XtmsRMSMs85vz9ucpJAxjnJObnPeD0fjzyS+5z7nHNlEO4r1+dzXRaXy+USAAAAAMBjVrMDAAAAAIBgQyIFAAAAAF4ikQIAAAAAL5FIAQAAAICXSKQAAAAAwEskUgAAAADgJRIpAAAAAPASiRQAAAAAeIlECgAAAAC8RCIFAAAAAF4K60Rq+fLluuGGG5SUlCSLxaJ58+Z5/Rwul0vPP/+8OnfurKioKLVs2VJ/+ctffB8sAAAAgIBhNzsAM+Xm5qpPnz762c9+pokTJ9boOR599FEtXLhQzz//vHr16qVTp07p1KlTPo4UAAAAQCCxuFwul9lBBAKLxaK5c+dq/Pjx7tsKCgr0u9/9Tu+//77OnDmjnj176u9//7tGjRolSdqxY4d69+6trVu3qkuXLuYEDgAAAKDOhfXSvuo88sgjWr16tT744AP9+OOPuuWWW3TNNddoz549kqQFCxaoffv2+vzzz9WuXTu1bdtW999/PxUpAAAAIMSRSFXi8OHDeuONN/Txxx9r+PDh6tChg371q19p2LBheuONNyRJ+/fv16FDh/Txxx/rrbfe0ptvvqn169fr5ptvNjl6AAAAAP4U1nukqrJlyxY5HA517ty53O0FBQVq0qSJJMnpdKqgoEBvvfWW+7zXXntNAwYM0K5du1juBwAAAIQoEqlK5OTkyGazaf369bLZbOXui42NlSS1aNFCdru9XLLVrVs3SUZFi0QKAAAACE0kUpXo16+fHA6Hjh07puHDh1d4ztChQ1VcXKx9+/apQ4cOkqTdu3dLktq0aVNnsQIAAACoW2HdtS8nJ0d79+6VZCRO06dP1+jRoxUfH6/WrVvrzjvv1MqVK/XCCy+oX79+On78uBYvXqzevXvruuuuk9Pp1KWXXqrY2Fi9+OKLcjqdmjJliuLi4rRw4UKTPzsAAAAA/hLWidTSpUs1evToi26/55579Oabb6qoqEh//vOf9dZbbyk1NVVNmzbVZZddpmeeeUa9evWSJKWlpekXv/iFFi5cqPr162vs2LF64YUXFB8fX9efDgAAAIA6EtaJFAAAAADUBO3PAQAAAMBLJFIAAAAA4KWw69rndDqVlpamBg0ayGKxmB0OAAAAAD9yuVw6e/askpKSZLX6ro4UdolUWlqakpOTzQ4DAAAAQB06cuSIWrVq5bPnC7tEqkGDBpKML2RcXJwpMSzclqFpH22WJN08oJX+cGMPU+IAgICTlye1aCF50gfJYpHS06WYGP/HBYS5Jz/5UQt+TJck/WlcD03o77uLUcDfsrOzlZyc7M4DfCXsEqmS5XxxcXGmJVLXDqinp77cp/wip5YfzFH92AayWVlmCACKi5PGj5cWLJCKiys/z26Xxo2TEhLqLDQgXJ3KLdSivWdljaqnhjERunVoF0VH2MwOC/Car7f10GzCBPUi7RrZuZkk6UROob4/eMrkiAAggEybJjkcVZ/jcEiPPVY38QBh7qMfjqjQ4ZQk3TKgFUkUcB6JlEnG9mzh/vjrrRkmRgIAAWbYMGnmTGPpnv2ChRN2u3H7zJnS0KHmxAeEEafTpXfXHnIf33FZGxOjAQILiZRJLu/WXBE2o7z49dYMOZ3MRQYAt8mTpRUrjOV7JR2WrFbjeMUK434Afrdsz3EdOZUnSRreqanaNa1vckRA4Ai7PVKBIi46QsM6NtV3u44rIztfm46eUf/Wjc0OCwACx9ChxltenpSdbeyfor
EEUKfeWV1ajbqTahRQDhUpE7G8DwA8EBNjNJUgiQLq1JFT57Rk1zFJUouG0bqia3OTIwICC4mUia7snuDu1vfV1nS5PGn3CwAAUAfeX3fYPYng9oGtZbdx2QiUxb8IEzWuH6nL2sdLko6cytO2tGyTIwIAAJAKih368PsjkiS71aJbByabHBEQeEikTHYNy/sAAECA+Xprhk7mFkqSru6ZqOYNok2OCAg8JFImu7pHgkpmg321Nd3cYAAAACS9s6a0ycRdNJkAKkQiZbLmDaJ1SRujW9++47nak3nW5IgAAEA425mRre8PnpYkdWoeq0Ht4k2OCAhMJFIBoOzyvq9Y3gcAAExUthp152VtZClZOgOgHBKpAHBNz0T3xyRSAADALGfzizR3Q6okqV6kTRP6tzQ5IiBwkUgFgJaNYtSnVUNJ0o70bB06mWtyRAAAIBzN25iq3EKHJGl8v5aKi44wOSIgcJFIBQiW9wEAADO5XC69s+aw+/jOQTSZAKpCIhUgxrK8DwAAmOj7g6e163zTqwFtGqt7UpzJEQGBjUQqQLRtWl9dExtIkjYfOaO0M3kmRwQAAMLJ27Q8B7xCIhVAxjKcFwAAmOD42QJ9fX6eZXz9SI3tlVjNIwCQSAWQsr+0SKQAAEBd+eiHIypyuCRJP7kkWVF2m8kRAYGPRCqAdGoeq/bN6kuSvj90SsfO5pscEQAACHUOp0vvnl/WZ7FIdwxqbXJEQHAgkQogFovF3XTC5ZIWbss0OSIAABDqluw8prQs44+3ozo3U3J8PZMjAoIDiVSAYZ8UAACoS++UbTIxmCYTgKdIpAJMj6Q4tWocI0lavf+kTucWmhwRAAAIVYdO5mrZ7uOSpFaNYzSyc3OTIwKCB4lUgCm7vM/hdGnRDpb3AQAA/3h3bekA3jsGtZHNajExGiC4kEgFoGtY3gcAAPwsv8ihj344IkmKtFn1k0tamRwREFxIpAJQv+RGSoiLkiSl7Dmhs/lFJkcEAABCzRc/puvMOeMa49peiWoSG2VyREBwIZEKQFarRdf0MJb3FTqcWrLzmMkRAQCAUPM2TSaAWjE1kXr22Wd16aWXqkGDBmrevLnGjx+vXbt2VfmYN998UxaLpdxbdHR0HUVcd8ou7/tqC8v7AACA72xNzdKmI2ckSV0TG6h/68bmBgQEIVMTqWXLlmnKlClas2aNFi1apKKiIl111VXKzc2t8nFxcXFKT093vx06dKjK84PRwHbxalI/UpK0dPcxnSssNjkiAAAQKi5seW6x0GQC8JbdzBf/+uuvyx2/+eabat68udavX68RI0ZU+jiLxaLExER/h2cqm9Wiq3ok6P11R5Rf5NSyXcc1tleL6h8IAABQhay8Is3blCpJio2ya3zfliZHBASngNojlZWVJUmKj4+v8rycnBy1adNGycnJGjdunLZt21bpuQUFBcrOzi73FizKLe+jex8AAPCBTzccVX6RU5I0sX9L1Y8y9e/qQNAKmETK6XRq6tSpGjp0qHr27FnpeV26dNHrr7+u+fPn65133pHT6dSQIUN09OjRCs9/9tln1bBhQ/dbcnKyvz4Fnxvcvonioo1fbkt2HlNBscPkiAAAQDBzuVzlmkzceRlNJoCaCphEasqUKdq6das++OCDKs8bPHiw7r77bvXt21cjR47Up59+qmbNmunVV1+t8Pwnn3xSWVlZ7rcjR474I3y/iLRbNaZ7giQpp6BYKXtOmBwRAAAIZqv3ndT+48Ze9EHt4tU5oYHJEQHBKyASqUceeUSff/65vvvuO7Vq5d0wuIiICPXr10979+6t8P6oqCjFxcWVewsmY1neBwAAfOSdtVSjAF8xNZFyuVx65JFHNHfuXC1ZskTt2rXz+jkcDoe2bNmiFi1CsxHD8E5NVS/SJklatD1TRQ6nyREBAIBglJmdr2+2ZUqSmsZG6eoeod24C/A3UxOpKVOm6J133tF7772nBg0aKCMjQxkZGcrLy3Ofc/fdd+
vJJ590H//xj3/UwoULtX//fm3YsEF33nmnDh06pPvvv9+MT8HvoiNsGt21uSSjy86a/SdNjggAAASj99cdlsPpkiR
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X3, y, xlabel=\"x\", ylabel=\"y\")\n",
"theta_start = np.matrix([0, 0, 0, 0]).reshape(4, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X3, y)\n",
"plot_fun(fig, polynomial_regression(theta), X)\n",
"\n",
"print(theta)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Regresję wielomianową można potraktować jako szczególny przypadek regresji liniowej wielu zmiennych:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 + \\theta_3 x^3 $$\n",
"$$ x_1 = x, \\quad x_2 = x^2, \\quad x_3 = x^3, \\quad \\vec{x} = \\left[ \\begin{array}{ccc} x_0 \\\\ x_1 \\\\ x_2 \\end{array} \\right] $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"(W tym przypadku za kolejne cechy przyjmujemy kolejne potęgi zmiennej $x$)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Uwaga praktyczna: przyda się normalizacja cech, szczególnie skalowanie!"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Do tworzenia cech „pochodnych” możemy używać nie tylko potęgowania, ale też innych operacji matematycznych, np.:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 \\sqrt{x} $$\n",
"$$ x_1 = x, \\quad x_2 = \\sqrt{x}, \\quad \\vec{x} = \\left[ \\begin{array}{ccc} x_0 \\\\ x_1 \\end{array} \\right] $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Jakie zatem cechy wybrać? Najlepiej dopasować je do konkretnego problemu."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Wielomianowa regresja logistyczna\n",
"\n",
"Podobne modyfikacje cech możemy również stosować dla regresji logistycznej."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def powerme(x1, x2, n):\n",
" \"\"\"Funkcja, która generuje n potęg dla zmiennych x1 i x2 oraz ich iloczynów\"\"\"\n",
" X = []\n",
" for m in range(n + 1):\n",
" for i in range(m + 1):\n",
" X.append(np.multiply(np.power(x1, i), np.power(x2, (m - i))))\n",
" return np.hstack(X)\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [
{
"data": {
"text/plain": [
"matrix([[ 1. , 0.36596696, -0.11214686],\n",
" [ 0. , 0.4945305 , 0.47110656],\n",
" [ 0. , 0.70290604, -0.92257983],\n",
" [ 0. , 0.46658862, -0.62269739],\n",
" [ 0. , 0.87939462, -0.11408015],\n",
" [ 0. , -0.331185 , 0.84447667],\n",
" [ 0. , -0.54351701, 0.8851383 ],\n",
" [ 0. , 0.91979241, 0.41607012],\n",
" [ 0. , 0.28011742, 0.61431157],\n",
" [ 0. , 0.94754363, -0.78307311]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Wczytanie danych\n",
"import pandas\n",
"import numpy as np\n",
"\n",
"alldata = pandas.read_csv(\"polynomial_logistic.tsv\", sep=\"\\t\")\n",
"data = np.matrix(alldata)\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data[:, 1:]\n",
"\n",
"Xpl = powerme(data[:, 1], data[:, 2], n)\n",
"Ypl = np.matrix(data[:, 0]).reshape(m, 1)\n",
"\n",
"data[:10]\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def plot_data_for_classification(X, Y, xlabel, ylabel):\n",
" \"\"\"Wykres danych (wersja macierzowa)\"\"\"\n",
" fig = plt.figure(figsize=(16 * 0.6, 9 * 0.6))\n",
" ax = fig.add_subplot(111)\n",
" fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
" X = X.tolist()\n",
" Y = Y.tolist()\n",
" X1n = [x[1] for x, y in zip(X, Y) if y[0] == 0]\n",
" X1p = [x[1] for x, y in zip(X, Y) if y[0] == 1]\n",
" X2n = [x[2] for x, y in zip(X, Y) if y[0] == 0]\n",
" X2p = [x[2] for x, y in zip(X, Y) if y[0] == 1]\n",
" ax.scatter(X1n, X2n, c=\"r\", marker=\"x\", s=50, label=\"Dane\")\n",
" ax.scatter(X1p, X2p, c=\"g\", marker=\"o\", s=50, label=\"Dane\")\n",
"\n",
" ax.set_xlabel(xlabel)\n",
" ax.set_ylabel(ylabel)\n",
" ax.margins(0.05, 0.05)\n",
" return fig\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Przyjmijmy, że mamy następujące dane i chcemy przeprowadzić klasyfikację dwuklasową dla następujących klas:\n",
" * czerwone krzyżyki\n",
" * zielone kółka"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1wAAAHvCAYAAABAJN42AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABpKUlEQVR4nO3df3hU1b3v8c9MQkI0TpALJFCHatSAtigKNYYgckuuYOkxQs9psDYqh+pB0VagKvTWeMS2WH/gfWr50VpFm1ZJ9RGRaqkGxSYBQUEqKhBRbKKSWOWQAYyZJLPvH9OMmclkMpPMr73n/XqeeZS91wxrNntm9nd/1/oum2EYhgAAAAAAUWdPdAcAAAAAwKoIuAAAAAAgRgi4AAAAACBGCLgAAAAAIEYIuAAAAAAgRgi4AAAAACBGCLgAAAAAIEbSE90BK/B4PPr444910kknyWazJbo7AAAAAGLIMAwdPXpUo0aNkt0eOodFwBUFH3/8sZxOZ6K7AQAAACCOGhsbdcopp4RsQ8AVBSeddJIk7wF3OBwJ7g0AAACAWHK5XHI6nb44IBQCrijoGkbocDgIuAAAAIAUEc50IopmAAAAAECMEHABAAAAQIwQcAEAAABAjBBwAQAAAECMEHABAAAAQIwQcAEAAABAjBBwAQAAAECMmCrg+tvf/qZ/+7d/06hRo2Sz2fTMM8/0+ZwtW7bo/PPPV2Zmps444ww9+uijPdqsXLlSp556qgYPHqzCwkLt2LEj+p0HAAAAkHJMFXAdP35c5557rlauXBlW+4MHD2rmzJn63//7f2v37t26+eab9YMf/EB//etffW2qqqq0aNEi3XHHHdq1a5fOPfdcTZ8+XZ988kms3gYAAACAFGEzDMNIdCf6w2azaf369br88st7bXPbbbfpueee01tvveXbNmfOHB05ckSbNm2SJBUWFuob3/iGfv3rX0uSPB6PnE6nbrrpJi1ZsiSsvrhcLuXk5KilpUUOh6P/bwoAAABA0ovk+t9UGa5Ibdu2TSUlJX7bpk+frm3btkmS3G63du7c6dfGbrerpKTE1yaYtrY2uVwuvwcAAAAABLJ0wNXU1KTc3Fy/bbm5uXK5XGptbdWnn36qzs7OoG2ampp6fd3ly5crJyfH93A6nTHp/4C53QPbDwAAAGBALB1wxcrSpUvV0tLiezQ2Nia6Sz1VVUnjxkm99a2x0bu/qiq+/QIAAABSiKUDrry8PDU3N/tta25ulsPhUFZWloYNG6a0tLSgbfLy8np93czMTDkcDr9HUnG7pYoKqb5emjq1Z9DV2OjdXl/vbUemC0CikZEHAFiUpQOuoqIibd682W/biy++qKKiIklSRkaGJkyY4NfG4/Fo8+bNvjamlJEhVVdL+fnS++/7B11dwdb773v3V1d72wNAopCRBwBYmKkCrmPHjmn37t3avXu3JG/Z9927d6uhoUGSd6jfVVdd5Ws/f/58vf/++7r11lu1b98+rVq1Sn/605+0cOFCX5tFixbpoYce0mOPPaa9e/fq+uuv1/HjxzV37ty4vreoczqlLVv8g66tW/2DrS1bvO1SBXfQgeRDRh4AYHGmCrhef/11nXfeeTrvvPMkeYOl8847TxUVFZKkQ4cO+YIvSTrttNP03HPP6cUXX9S5556r+++/X7/73e80ffp0X5uysjLdd999qqio0Pjx47V7925t2rSpRyENUwoMuoqLUzfY4g46kJzIyAMALM6063Alk6Rfh2vrVm+w1aWuTpo0KXH9iTe32xtM1dcHDza7X9QVFEh79nBRB8RbYHBVWSmVl6fuTSIAQA+t7a1ytbnkyHQoa1BWQvvCOlz4UmOj96Klu/Ly3jM9VsQddCD5kZEHAPSitqFWs6tmK3t5tvLuz1P28mzNrpqtuoa6RHctLARcVhYYTNTVBQ86UgFz2oDk53R6M1vdVVbyuQSAFLb6tdWasnaKNtZvlMfwSJI8hkcb6zfqor
UXac3raxLcw74xpDAKknJIYWCw1RVM9LY9VXR//11S8TgAyYjPJwCgm9qGWk1ZO0WGeg9XbLKpZm6NikcX99omFhhSmOrcbqmkJHhQFZjpKSlJrapf3EEHkhMZeaQaKucCfVqxbYXS7Gkh26TZ0/TAqw/EqUf9Q8BlRRkZ0rJl3gIQwe4MdwVdBQXedqk0Z4k5bUDyCZZ5nzSp5zBgPqewCirnAn1qbW/Vhv0b1OHpCNmuw9Oh9fvWq7W9NU49ixwBl1WVlXmr7fWWuXE6vfvLyuLbr0TiDjqQfPqbkSc7ALNi7TkgLK42l2/OVl88hkeuNleMe9R/BFxW1lfmKtUyW9xBB5JPfzLyZAdgZlTOBcLiyHTIbgsvVLHb7HJkJkkdhSAomhEFSVk0A19iHS4g+bndoT93Xfv5PMMqWHsO6NPsqtnaWL8x5LDCdHu6SseU6qnvPhXHnlE0A/DHnDYg+YWbkSc7AKtg7TmgT4uKFqnT0xmyTaenUwsvXBinHvUPGa4oIMNlEuHeQQeQ/MgOwCq2bvUGW13q6rxD3gFIkta8vkY3PHeD0uxpfpmudHu6Oj2dWjVzleZPnB/3fpHhAoJhThtgHWQHYAVUzgX6NH/ifNXMrVHpmFLfnC67za7SMaWqmVuTkGArUmS4ooAMFwBEUSTZaLIDMCuytEDEWttb5WpzyZHpUNagrIT2hQwXAMCcIqlASHYAZkXlXKBfsgZlKTc7N+HBVqQIuAAAySGS9YmWLpUuvph19WA+/V17DoBpEXABAJJDuBUIR4+WOjulgwfJDsB8qJwLpBzmcEUBc7gAIIpCzW057TTJZut9ngvrcMEsqJwLmBpzuAAA5hWqAuErr0i/+AXZAZgflXOBlEGGKwrIcAFADISqQEh2IPnxbwTAwshwAQDMra8KhGQHklsk1SYBwOIIuAAAySVwDhcVCM0lkmqTFRVU4QNgeQRcAIDkwfpE5hdutcn8fG87spHm01eQTBAN+CHgAgAkB9Ynso7Af6+pU71z8gKD6cCiJ0h+DBcFIkbABQBIDqxPZC2hqk0SbJkTw0WBfqFKYRRQpRAAoojqdtYSqtokzCfYsF+ns/ftgEVRpRAAYF5UILSOvqpNwnwYLgpEjIALAABEH9UmrYvhokBECLgAAEB0UW3S+pxOqbLSf1tlJcEWEAQBFwAAiB6qTaYGhosCYSPgAgAA0UO1SetjuCgQEaoURgFVCgEACEC1SWuiSiEgiSqFAAAg0ag2aT0MFwX6hYALAAAAfWO4KNAvDCmMAoYUAgCAlMFwUYAhhQAAAIgRhosCESHgAgAAAIAYIeACAAAAgBgh4AIAAACAGCHgAgAAAIAYIeACAAAAgBgxZcC1cuVKnXrqqRo8eLAKCwu1Y8eOXttOnTpVNputx2PmzJm+Ntdcc02P/TNmzIjHWzGfvhYxZJFDAAAAwMd0AVdVVZUWLVqkO+64Q7t27dK5556r6dOn65NPPgna/umnn9ahQ4d8j7feektpaWn6j//4D792M2bM8Gv3xBNPxOPtmEtVlTRunNTYGHx/Y6N3f1VVfPsFAADMi5u5sDjTBVwrVqzQtddeq7lz5+rss8/WmjVrdMIJJ+iRRx4J2n7o0KHKy8vzPV588UWdcMIJPQKuzMxMv3Ynn3xyPN6OebjdUkWFVF8vTZ3aM+hqbPRur6/3tuPLEQAA9IWbuUgBpgq43G63du7cqZKSEt82u92ukpISbdu2LazXePjhhzVnzhydeOKJftu3bNmiESNGaMyYMbr++uv12Wef9foabW1tcrlcfg/Ly8iQqqul/Hzp/ff9g66uYOv99737q6tZ9BAAAITGzVykCFMFXJ9++qk6OzuVm5vrtz03N1dNTU19Pn/Hjh1666239IMf/MBv+4wZM/T73/9emzdv1i9/+Uu98soruvTSS9XZ2Rn0dZYvX66cnBzfw+l09v9NmYnTKW3Z4h90bd3qH2xt2eJtBwAAEAo3c5EibIZhGI
nuRLg+/vhjfeUrX9HWrVtVVFTk237rrbfqlVde0fbt20M+/7/+67+0bds2vfnmmyHbvf/++zr99NNVXV2tadOm9dj
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r\"$x_1$\", ylabel=r\"$x_2$\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Propozycja hipotezy:\n",
"\n",
"$$ h_\\theta(x) = g(\\theta^T x) = g(\\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 + \\theta_3 x_3 + \\theta_4 x_4 + \\theta_5 x_5) \\; , $$\n",
"\n",
"gdzie $g$ funkcja logistyczna, $x_3 = x_1^2$, $x_4 = x_2^2$, $x_5 = x_1 x_2$."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def safeSigmoid(x, eps=0):\n",
" \"\"\"Funkcja sigmoidalna zmodyfikowana w taki sposób,\n",
" żeby wartości zawsz były odległe od asymptot o co najmniej eps\n",
" \"\"\"\n",
" y = 1.0 / (1.0 + np.exp(-x))\n",
" if eps > 0:\n",
" y[y < eps] = eps\n",
" y[y > 1 - eps] = 1 - eps\n",
" return y\n",
"\n",
"\n",
"def h(theta, X, eps=0.0):\n",
" \"\"\"Funkcja hipotezy\"\"\"\n",
" return safeSigmoid(X * theta, eps)\n",
"\n",
"\n",
"def J(h, theta, X, y, lamb=0):\n",
" \"\"\"Funkcja kosztu\"\"\"\n",
" m = len(y)\n",
" f = h(theta, X, eps=10**-7)\n",
" j = (\n",
" -np.sum(np.multiply(y, np.log(f)) + np.multiply(1 - y, np.log(1 - f)), axis=0)\n",
" / m\n",
" )\n",
" if lamb > 0:\n",
" j += lamb / (2 * m) * np.sum(np.power(theta[1:], 2))\n",
" return j\n",
"\n",
"\n",
"def dJ(h, theta, X, y, lamb=0):\n",
" \"\"\"Pochodna funkcji kosztu\"\"\"\n",
" g = 1.0 / y.shape[0] * (X.T * (h(theta, X) - y))\n",
" if lamb > 0:\n",
" g[1:] += lamb / float(y.shape[0]) * theta[1:]\n",
" return g\n",
"\n",
"\n",
"def classifyBi(theta, X):\n",
" \"\"\"Funkcja decyzji\"\"\"\n",
" prob = h(theta, X)\n",
" return prob\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def GD(h, fJ, fdJ, theta, X, y, alpha=0.01, eps=10**-3, maxSteps=10000):\n",
" \"\"\"Metoda gradientu prostego dla regresji logistycznej\"\"\"\n",
" errorCurr = fJ(h, theta, X, y)\n",
" errors = [[errorCurr, theta]]\n",
" while True:\n",
" # oblicz nowe theta\n",
" theta = theta - alpha * fdJ(h, theta, X, y)\n",
" # raportuj poziom błędu\n",
" errorCurr, errorPrev = fJ(h, theta, X, y), errorCurr\n",
" # kryteria stopu\n",
" if abs(errorPrev - errorCurr) <= eps:\n",
" break\n",
" if len(errors) > maxSteps:\n",
" break\n",
" errors.append([errorCurr, theta])\n",
" return theta, errors\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"theta = [[ 1.59558981]\n",
" [ 0.12602307]\n",
" [ 0.65718518]\n",
" [-5.26367581]\n",
" [ 1.96832544]\n",
" [-6.97946065]]\n"
]
}
],
"source": [
"# Uruchomienie metody gradientu prostego dla regresji logistycznej\n",
"theta_start = np.matrix(np.zeros(Xpl.shape[1])).reshape(Xpl.shape[1], 1)\n",
"theta, errors = GD(\n",
" h, J, dJ, theta_start, Xpl, Ypl, alpha=0.1, eps=10**-7, maxSteps=10000\n",
")\n",
"print(r\"theta = {}\".format(theta))\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def plot_decision_boundary(fig, theta, X):\n",
" \"\"\"Wykres granicy klas\"\"\"\n",
" ax = fig.axes[0]\n",
" xx, yy = np.meshgrid(np.arange(-1.0, 1.0, 0.02), np.arange(-1.0, 1.0, 0.02))\n",
" l = len(xx.ravel())\n",
" C = powerme(xx.reshape(l, 1), yy.reshape(l, 1), n)\n",
" z = classifyBi(theta, C).reshape(int(np.sqrt(l)), int(np.sqrt(l)))\n",
"\n",
" plt.contour(xx, yy, z, levels=[0.5], lw=3)\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_539/3318422759.py:10: UserWarning: The following kwargs were not used by contour: 'lw'\n",
" plt.contour(xx, yy, z, levels=[0.5], lw=3);\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1wAAAHvCAYAAABAJN42AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAACXNklEQVR4nOzdeVxU1fsH8M8srOKAioDmuKCCS+6kImqUlKglaQu2aJplLm1iWVZqaqWpad/KpcWlKJOs1Ex/lpioLO6au7gzLuAWjMDIwMz9/TExMjDsM3Nn+bxfL17KvWeGZy7D3Pvcc85zJIIgCCAiIiIiIiKLk4odABERERERkbNiwkVERERERGQlTLiIiIiIiIishAkXERERERGRlTDhIiIiIiIishImXERERERERFbChIuIiIiIiMhK5GIH4Az0ej2uXLmCunXrQiKRiB0OERERERFZkSAIuH37Nho3bgyptOI+LCZcFnDlyhUolUqxwyAiIiIiIhtSqVRo0qRJhW2YcFlA3bp1ARgOuEKhuLvj0iVg0CDgwgWgeXPg66+BMWPufr9xI1DJL8hpabWAu3vN9xORuEp/vhV/npW3nYiIyImo1WoolUpjHlARiSAIgg1icmpqtRq+vr7IyckxTbgAQKUCIiOBc+fubgsOBpKSAFftFUtIAKZNAxITzR8DlQqIigJmzgRiY20fHxFVTcnPt+BgID4eGD787veu/DlHREROrcLr/1KYcFlApQc8NRWIiLj7fUoK0KuX7QK0J1ot0KEDkJ5u/oKs5AVcSAhw5Ah7uojsGW8qERGRlWkKNVAXqKHwUMDLzUvscABUL+FilUJrU6kMd3xLGj7csN0VubsberaCgw0XaJGRd49F6bvliYlMtojsnVJp6NkqKT6eyRYREdVackYyhiYMhc9sHwR9GgSf2T4YmjAUKRkpYodWLUy4rKl0ApGSYj7RcDVKpeHud8ljkZpqeqx4d5zIMfCmEhERWcGSvUvQd0VfbEjfAL2gBwDoBT02pG9AnxV9sHTfUpEjrDomXNZSOtlKSjIMIyydaLjqRUnppCsigskWkaPhTSUiIrKC5IxkTNg0AQIEFOmLTPYV6YsgQMD4jeMdpqeLCZc1aLWGog/mEojSiUZUlKG9K+JQJCLHxZtKRGVVdj531fM9UTUtSFsAmVRWYRuZVIaFuxbaKKLaYcJlDe7uhgp7ISHme2uKk66QEEM7V52nxKFIRI6JN5WIykpIMBSFKu8cplIZ9ick2DYuIgejKdRg/an1ZXq2SivSF2HtybXQFGpsFFnNMeGylthYQ4W98nprlErDflcte86hSESOizeViExptYblTtLTzZ/Dis956emGdrwJQVQudYHaOGerMnpBD3WB2soR1R4TLmuq7CLDVS9COBSJyPHV5KYSh1uRs2IFXiKLUXgoIJVULUWRSqRQeFRckt0eMOEi2+JQJCLnUZ2bShxuRc6OFXiJLMLLzQsxoTGQS+UVtpNL5RjSZojdrMtVESZcZFscikTkejjcilwFK/ASWURceBx0el2FbXR6HSb2nGijiGqHCRfZHue3EbkWDrciV8IKvES11rtpbywetBgSSMr0dMmlckggweJBixHRNEKkCKuHCReJg/PbiFwLh1uRq2AFXiKLGBs2FjtH7URMaIxxTpdUIkVMaAx2jtqJsWFjRY6w6iSCIAhiB+Ho1Go1fH19kZOTA4XC/ifuERGJpmSPVjEmW+QsSvfYxscbki3eVCCqFU2hBuoCNRQeCruZs1Wd63/2cBERUe1VtQIhh1uRs2IFXiKr8XLzQqBPoN0kW9XFhIuIiGqnOhUIOdyKnBEr8BJRBZhwERFRzVWnAuGUKcD993PBc3I+rMBLRBXgHC4L4BwuInJp5oZSKZWm25s2NbTNyCi/Dee4kKPTaitOpirbT0QOg3O4iIjIdiqrQNiiBSCXl022zD2Ww63IkbECLxGZwYSLiIhqr6IFX7
dvBz7+mMOtiIjIJXFIoQVwSCER0X9SUw3JVrGUFEOlNoDDrYiIyGlwSCEREdleZRUIOdzKcVW17D8REZXBhIuIiGqvdPELViB0HtUp+09ERGUw4SIiotrhgq/Oqzpl/6dNY08XEZEZTLiIiKjmuOCrc3N3BxITzSfOpRPtxEQOC3UGHD5KZHFMuIiIqOa44Kvzq6zsP9dPcx4cPkpkFaxSaAGsUkhELo8VCJ1fyR6tYky2nIdWa0im0tPN/15L/v5DQoAjR/g3TS6NVQqJiMi2WIHQ+SmVQHy86bb4eCZbzoLDR4mshgkXERERVa6ysv/k+Dh8lMgqmHARERFRxVj233WUTroiIphsEdUSEy4iIiIqH8v+ux4OHyWyKCZcREREZB7L/rsmDh8lsigmXERERGQey/67Hg4fJbI4loW3AJaFJyIip8ay/67B3PBRpbL87UQujGXhiYiIyHJY9t/5cfgokdUw4SIiIiJydRw+SmQ1HFJoARxSSERERE6Bw0eJqoRDComIiIio+jh8lMjimHARERERERFZCRMuIiIiIiIiK2HCRUREREREZCVMuIiIiIiIiKyECRcREREREZGVMOEiIiIiIiKyEiZcREREREREVuKQCdeiRYvQvHlzeHp6okePHtizZ0+5bSMjIyGRSMp8DRo0yNhm5MiRZfZHR0fb4qXUjFZbu/1EREREtsLrFnJxDpdwJSQkIC4uDtOnT8eBAwfQqVMn9O/fH9euXTPb/rfffsPVq1eNX0ePHoVMJsOTTz5p0i46Otqk3U8//WSLl1N9CQlAhw6ASmV+v0pl2J+QYNu4iIiIiErjdQuR4yVcCxYswEsvvYRRo0ahXbt2WLp0Kby9vbF8+XKz7evXr4+goCDj15YtW+Dt7V0m4fLw8DBpV69ePVu8nOrRaoFp04D0dCAysuyHl0pl2J6ebmjHO0ZEREQkFl63EAFwsIRLq9Vi//79iIqKMm6TSqWIiopCWlpalZ5j2bJlGDZsGOrUqWOyPSkpCQEBAQgNDcW4ceNw8+bNcp+joKAAarXa5Msm3N2BxEQgOBg4d870w6v4Q+vcOcP+xERDeyIiIiIx8LqFCICDJVw3btyATqdDYGCgyfbAwEBkZmZW+vg9e/bg6NGjePHFF022R0dH4/vvv8fWrVvxySefYPv27RgwYAB0Op3Z55k9ezZ8fX2NX0qlsuYvqrqUSiApyfTDKzXV9EMrKcnQjoiIiEhMvG4hgkQQBEHsIKrqypUruOeee5Camorw8HDj9smTJ2P79u3YvXt3hY9/+eWXkZaWhsOHD1fY7ty5c2jZsiUSExPRr1+/MvsLCgpQUFBg/F6tVkOpVCInJwcKhaKar6qGSt4ZKsYPLSIiIrJHvG4hJ6NWq+Hr61ul63+H6uHy9/eHTCZDVlaWyfasrCwEBQVV+Ni8vDysXr0ao0ePrvTnBAcHw9/fH2fOnDG738PDAwqFwuTL5pRKID7edFt8PD+0iIiIyP7wuoVcmEMlXO7u7ujWrRu2bt1q3KbX67F161aTHi9z1qxZg4KCAjz33HOV/pxLly7h5s2baNSoUa1jthqVChg+3HTb8OHlVwEiIiIiEguvW8iFOVTCBQBxcXH45ptv8N133+HEiRMYN24c8vLyMGrUKADAiBEjMGXKlDKPW7ZsGR577DE0aNDAZHtubi7eeust7Nq1CxcuXMDWrVsRExODVq1aoX///jZ5TVVSsnJP6YmmKSlAixZlJ6QSERERic3cdYu5QhpETkoudgDVFRsbi+vXr2PatGnIzMxE586dsXnzZmMhjYyMDEilpnnkqVOnkJycjL/++qvM88lkMhw+fBjfffcdsrOz0bhxYzz88MOYNWsWPDw8bPKaKpWQYCiXmpho+L70RNNiAQF3P7w4JpqIiIjEVjrZKr4+SUq6u53XLeTkHKpohr2qzqS5atNqDQsCpqcberEkkrLJVvEHVosWhu/PnwdCQoAjR1hilYiIiMRR8h
rGXIGMkskYr1vIwTht0QyXVHINi/Pngdxc88lWcDCwfbvhKyQEmDmTH1pEREQkHnd3w/VISIj5Hqzini5et5CTYw+
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r\"$x_1$\", ylabel=r\"$x_2$\")\n",
"plot_decision_boundary(fig, theta, Xpl)\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Load the data\n",
"\n",
"alldata = pandas.read_csv(\"polynomial_logistic.tsv\", sep=\"\\t\")\n",
"data = np.matrix(alldata)\n",
"\n",
"m, n_plus_1 = data.shape\n",
"Xn = data[:, 1:]\n",
"\n",
"n = 10\n",
"Xpl = powerme(data[:, 1], data[:, 2], n)\n",
"Ypl = np.matrix(data[:, 0]).reshape(m, 1)\n",
"\n",
"theta_start = np.matrix(np.zeros(Xpl.shape[1])).reshape(Xpl.shape[1], 1)\n",
"theta, errors = GD(\n",
" h, J, dJ, theta_start, Xpl, Ypl, alpha=0.1, eps=10**-7, maxSteps=10000\n",
")\n"
]
},
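{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The cell above expands the two input features into polynomial features of degree up to `n` using the helper `powerme`, which is defined earlier in the notebook. Below is a minimal sketch of such an expansion. The name `polynomial_features_2d` is hypothetical, and its column order may differ from the one `powerme` actually uses. It builds every monomial $x_1^i x_2^j$ with $i + j \\leq n$, so for `n = 10` there are $(n+1)(n+2)/2 = 66$ features, including the constant one.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def polynomial_features_2d(x1, x2, n):\n",
"    # Hypothetical helper: all monomials x1**i * x2**j with i + j <= n\n",
"    x1 = np.asarray(x1, dtype=float).ravel()\n",
"    x2 = np.asarray(x2, dtype=float).ravel()\n",
"    columns = []\n",
"    for degree in range(n + 1):  # total degree of the monomial\n",
"        for i in range(degree, -1, -1):  # exponent of x1\n",
"            columns.append(x1**i * x2 ** (degree - i))\n",
"    return np.column_stack(columns)  # first column: constant 1 (bias term)\n",
"\n",
"\n",
"X = polynomial_features_2d([1.0, 2.0], [3.0, 0.5], 2)\n",
"print(X.shape)  # (2, 6): columns 1, x1, x2, x1^2, x1*x2, x2^2\n",
"```"
]
},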
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_539/3318422759.py:10: UserWarning: The following kwargs were not used by contour: 'lw'\n",
" plt.contour(xx, yy, z, levels=[0.5], lw=3);\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1wAAAHvCAYAAABAJN42AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAACZp0lEQVR4nOzdeVxU1fsH8M8sMIA4gCIgCioquO+Jipr9pNwqUisqc8ss0za1zRYtW8wstcUl+6YmbWilZpol5sLmrrmLO6iAC8IIjDMwc39/TIwMDPvM3Fk+79drXsq9Z4Znhpk797nnnOdIBEEQQERERERERBYnFTsAIiIiIiIiZ8WEi4iIiIiIyEqYcBEREREREVkJEy4iIiIiIiIrYcJFRERERERkJUy4iIiIiIiIrIQJFxERERERkZXIxQ7AGej1ely5cgX169eHRCIROxwiIiIiIrIiQRBw69YtBAcHQyqtvA+LCZcFXLlyBSEhIWKHQURERERENpSRkYGmTZtW2oYJlwXUr18fgOEFVyqVd3ZcugQMGwZcuAA0bw4sWwY888ydnzduBKr4AzktrRZwd6/9fiISV9njW8nxrKLtRERk9w5sPYIvp3yDnMxcAMA9j0dh/IePw6ehsvI7uiCVSoWQkBBjHlAZiSAIgg1icmoqlQo+Pj7Iy8szTbgAICMDGDAAOHfuzrawMGD7dsBVe8Xi44GZM4GEBPOvQUYGEB0NzJ4NxMbaPj4iqp7Sx7ewMCAuDhg9+s7PrnycIyJyUAV5BVj+1k/YsORvCIIAZcP6mDR/LKKf7M+pM6VUev5fBhMuC6jyBU9JAaKi7vycnAz06WO7AO2JVgt07AikpZk/ISt9AhceDhw5wp4uInvGi0pERE7p+K40LHz2a5w/kg4A6BbdES8teQbBLYNsHou6SA2VRgWlQglPN0+b/35zapJwsUqhtWVkGK74ljZ6tGG7K3J3N/RshYUZTtAGDLjzWpS9Wp6QwGSLyN6FhBh6tkqLi2OyRUTk4Nr1CsfifXPx1IdPwE3hhgMJR/BMp+n4beFG6HQ6m8SQlJ6EEfEj4D3HG0GfBcF7jjdGxI9AcnqyTX6/pTDhsqayCURysvlEw9WEhBiufpd+LVJSTF8rXh0ncgy8qERE5LTkbnI8PmM4lh3+DJ0HtIdGrcWSaSsxrf9MpJ+8bNXfvWTvEvRf0R8b0jZAL+gBAHpBjw1pG9BvRT8s3bfUqr/fkjik0ALMdimWTbZKEoiKtrsiDkUicmycw0VE5DL0ej02fbMV37wWh8Jbargp3DB65iN49NUHIZPLLPq7ktKT0H9FfwioOE2RQILE8YmICo2qsI01cUih2LRaQ9EHcycdZXt3oqMN7V0RhyIROS5zF4/69Cnfe82eLnIlVX2fu+r3PTkFqVSK+5+9F98cnY+7hnRFkaYIy9/6ES/2eQvnj6Zb9HfNT50PmbTyJE4mlWHBrgUW/b3WwoTLGtzdDRX2wsPNX+EtSbrCww3tXHWeEociETkmXlQiKi8+3lAUqqLvsIwMw/74eNvGRWRhASH++PCPGXh1xRR4+9ZD2r6zmNz9Nfz40W/QFdd9bpe6SI31p9ajWF9cabtifTHWnlwLdZG6zr/T2phwWUtsrKHCXkW9NSEhhv2uWvac89uIHBcvKhGZ0moNy52kpZn/Div5zktLM7TjRQhycBKJBPeNHYBvjs5H7wd7oLhIhxVv/4SX+72DjFN1m9ul0qiMc7aqohf0UGlUdfp9tsA5XBZQkzGcBM5vI3IWNV3AnAuekzPjdxu5KEEQkPD9Tix6cTkK8gqh8HTHhDmjEPP8YEilNe/bURep4T3Hu1pJl1QiRf6MfFFKxXMOF9kvDkUich5VJUel93O4FTk7VuAlFyWRSHDv6Lux7PBn6HZvJ2jUWix+eQVev3c2si9eq/Hjebp5IiYiBnKpvNJ2cqkcw9sMt5t1uSrDhItsi0ORiFwPh1uRqyibdE
VFMdkilxEQ4o+PN7+NFxc9DQ8vBQ5tO4ZnOk3HXyu3oaYD6qb1ngadvvL5YDq9DlN7Ta1LyDbDIYUWwCGFtcChRUSuhcOtyJWkpBiSrRLJyYYqnkQu4vKZTHwybhGOp5wCAPR/pDdeWjIRygb1q/0YS/ctxeSNkyGTykwKaMilcuj0OiwethiTekyyeOzVxSGFZP9qMhSJiBwfh1uRq2AFXiI0adUY83e8h6c+fAIyuQw716RiUpdXcWjb0Wo/xqQek5A4PhExETGQSgwpi1QiRUxEDBLHJ4qabNUUe7gsgD1cRETVxAXPyZlxMXCick7tO4s5oz7H5dOZkEgkeOSVBzHu/Vi4ubtV+zHURWqoNCooFUq7mbPFHi4iIrKt6i74ygXPyVlxMXAisyJ6tMSSA59g6NMDIQgCVs9bj5f6vIX0k9UvH+/p5olA70C7SbZqigkXERHVTU0qEHK4FTkjVuAlqpRnPQ9MXTYJs359BfUbeOP0gfOY0uN1bFm1Q+zQbIIJFxER1V5NKhDOmAHcfTcXPCfnwwq8RNXSd3gklh3+DF0HdsTtQg0+GfcV5j21COqC22KHZlWcw2UBnMNFRC6tOhUIQ0MNbdPTWaWQnBcr8BJVi06nw08frUXce6uh1wto1q4p3o6fhubtHef4zzlcRERkO1VVIGzRApDLyydb5u7L4VbkyFiBl6haZDIZnnznYXySMAsNgnxx8fglPN/zDfy1cpvYoVkFEy4iIqq7yhZ83bED+OgjDrciIiITnQe0x9KD89Dt3k7QqLX49KnF+GTcVyguKq76zg6ECRcREVlGZRUIY2OBI0cqHi4YEmLYHxtr/TiJiMhu+AX6Ys6fb2Hc+49BKpWguKgYMrlM7LAsinO4LIBzuIiIwDW2nBnnJhGRDRxNPomwTs3gVd/+y79zDhcREdlW2eIXrEDoPGpS9p+IqA46RLVxiGSrpphwERFR3XDBV+dVk7L/M2ey4AkRkRlMuIiIqPa44Ktzc3cHEhLMJ85lE+2EBA4rdAZVfUb5GSaqMSZcRERUe1zw1flVVfaf8/ScB4ePElkFi2ZYAItmEJHLY1EF58eiKM5NqzUkU2lp5v+upf/+4eGGqqL8TJMLY9EMIiKyLS746vwqK/tPjo/DR4mshgkXERERVS0jAxg92nTb6NEshuJMOHyUyCqYcBEREVHlWPbfdZRNuqKimGwR1RETLiIiIqoYy/67Hg4fJbIoJlxERERkHsv+uyYOHyWyKCZcREREZB7L/rseDh8lsjiWhbcAloUnIiKnxrL/rsHc8NGQkIq3E7kwloUnIiIiy2HZf+fH4aNEVsOEi4iIiMjVcfgokdVwSKEFcEghEREROQUOHyWqFg4pJCIiIqKa4/BRIotjwkVERERERGQlTLiIiIiIiIishAkXERERERGRlTDhIiIiIiIishImXERERERERFbChIuIiIiIiMhKmHARERERERFZiUMmXIsWLULz5s3h4eGByMhI7Nmzp8K2AwYMgEQiKXcbNmyYsc24cePK7R88eLAtnkrtaLV1209ERERkKzxvIRfncAlXfHw8pk2bhlmzZuHAgQPo3LkzBg0ahKtXr5pt/9tvvyEzM9N4O3r0KGQyGR555BGTdoMHDzZp99NPP9ni6dRcfDzQsSOQkWF+f0aGYX98vG3jIiIiIiqL5y1EjpdwzZ8/HxMnTsT48ePRrl07LF26FF5eXli+fLnZ9g0aNEBQUJDxtmXLFnh5eZVLuBQKhUk7Pz8/WzydmtFqgZkzgbQ0YMCA8gevjAzD9rQ0QzteMSIiIiKx8LyFCICDJVxarRb79+9HdHS0cZtUKkV0dDRSU1Or9RjffvstHnvsMdSrV89k+/bt2xEQEICIiAg899xzuHHjRoWPodFooFKpTG424e4OJCQAYWHAuXOmB6+Sg9a5c4b9CQmG9kRERERi4HkLEQAHS7iuX78OnU6HwMBAk+2BgYHIysqq8v579uzB0aNH8f
TTT5tsHzx4MFatWoWtW7di7ty52LFjB4YMGQKdTmf2cebMmQMfHx/jLSQkpPZPqqZCQoDt200PXikppget7dsN7Yi
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Example with a larger number of features\n",
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r\"$x_1$\", ylabel=r\"$x_2$\")\n",
"plot_decision_boundary(fig, theta, Xpl)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 6.2. The overfitting problem"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Bias versus variance"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Data for a simple example\n",
"\n",
"data = np.matrix(\n",
" [\n",
" [0.0, 0.0],\n",
" [0.5, 1.8],\n",
" [1.0, 4.8],\n",
" [1.6, 7.2],\n",
" [2.6, 8.8],\n",
" [3.0, 9.0],\n",
" ]\n",
")\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn1 = data[:, 0:n]\n",
"Xn1 /= np.amax(Xn1, axis=0)\n",
"Xn2 = np.power(Xn1, 2)\n",
"Xn2 /= np.amax(Xn2, axis=0)\n",
"Xn3 = np.power(Xn1, 3)\n",
"Xn3 /= np.amax(Xn3, axis=0)\n",
"Xn4 = np.power(Xn1, 4)\n",
"Xn4 /= np.amax(Xn4, axis=0)\n",
"Xn5 = np.power(Xn1, 5)\n",
"Xn5 /= np.amax(Xn5, axis=0)\n",
"\n",
"X1 = np.matrix(np.concatenate((np.ones((m, 1)), Xn1), axis=1)).reshape(m, n + 1)\n",
"X2 = np.matrix(np.concatenate((np.ones((m, 1)), Xn1, Xn2), axis=1)).reshape(\n",
" m, 2 * n + 1\n",
")\n",
"X5 = np.matrix(\n",
" np.concatenate((np.ones((m, 1)), Xn1, Xn2, Xn3, Xn4, Xn5), axis=1)\n",
").reshape(m, 5 * n + 1)\n",
"y = np.matrix(data[:, -1]).reshape(m, 1)\n"
]
},
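{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The cell above builds the design matrices `X1`, `X2` and `X5` column by column. It scales $x$ to $[0, 1]$ by dividing by its maximum, then normalizes each power the same way. Because the inputs here are non-negative, every power of the scaled $x$ already has maximum 1. The same matrices can therefore be sketched more compactly with `np.vander`. This sketch uses plain NumPy arrays instead of `np.matrix`, and the helper name `design_matrix` is hypothetical.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# The six x values from the cell above\n",
"x = np.array([0.0, 0.5, 1.0, 1.6, 2.6, 3.0])\n",
"\n",
"\n",
"def design_matrix(x, degree):\n",
"    # Columns 1, x, x^2, ..., x^degree of x scaled to [0, 1];\n",
"    # assumes non-negative x, so every power column has maximum 1\n",
"    return np.vander(x / np.max(x), degree + 1, increasing=True)\n",
"\n",
"\n",
"X1_sketch = design_matrix(x, 1)\n",
"X2_sketch = design_matrix(x, 2)\n",
"X5_sketch = design_matrix(x, 5)\n",
"print(X1_sketch.shape, X2_sketch.shape, X5_sketch.shape)  # (6, 2) (6, 3) (6, 6)\n",
"```"
]
},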
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA00AAAH0CAYAAADhWca4AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAlLklEQVR4nO3df3CU9Z3A8U9+lJCKG4oKgTEq/jhsxYqtymAUdeTKtZ4DMuNVz3Oo1ztPGk+Bnq3ejDrW01SvA07vBHrenDpetdpe0dOrOohVCSL+QkutR9VyylkTrdasYhol+9wfW9KLkC+/kuwmeb1mdjL77HfTT/p0Sd59nme3IsuyLAAAANiuylIPAAAAUM5EEwAAQIJoAgAASBBNAAAACaIJAAAgQTQBAAAkiCYAAIAE0QQAAJAgmgAAABJEEwAAQEJJo+mxxx6L008/PSZMmBAVFRVx991393g8y7K44oorYvz48VFbWxszZsyIl156qTTDAgAAw1JJo2nz5s1x1FFHxY033rjdx6+//vr47ne/G8uWLYu1a9fGXnvtFTNnzozf/e53AzwpAAAwXFVkWZaVeoiIiIqKili+fHnMnj07IopHmSZMmBBf//rX4+/+7u8iIqK9vT3GjRsXt9xyS5x11lklnBYAABguqks9QG82btwYra2tMWPGjO5tdXV1MXXq1FizZk2v0dTZ2RmdnZ3d9wuFQrzzzjuxzz77REVFRb/PDQAAlE6WZfHee+/FhAkTorKyb06sK9toam1tjYiIcePG9dg+bty47se2p7m5Oa666qp+nQ0AAChvmzZtiv33379PvlfZRtPuuuyyy2LhwoXd99vb2+OAAw6ITZs2RS6XK+FkAABAf8vn89HQ0BB77713n33Pso2m+vr6iIhoa2uL8ePHd29va2uLKVOm9Pq8mpqaqKmp2WZ7LpcTTQAAMEz05aU5Zfs5TRMnToz6+vpYuXJl97Z8Ph9r166NadOmlXAyAABgOCnpkab3338/Xn755e77GzdujOeeey7GjBkTBxxwQMyfPz/+4R/+IQ477LCYOHFiXH755TFhwoTud9gDAADobyWNpqeffjpOOeWU7vtbr0WaO3du3HLLLfGNb3wjNm/eHOeff368++67ccIJJ8QDDzwQI0eOLNXIAADAMFM2n9PUX/L5fNTV1UV7e7trmgAAYIjrj7//y/aaJgAAgHIgmgAAABJEEwAAQIJoAgAASBBNAAAACaIJAAAgQTQBAAAkiCYAAIAE0QQAAJAgmgAAABJEEwAAQIJoAgAASBBNAAAACaIJAAAgQTQBAAAkiCYAAIAE0QQAAJAgmgAAABJEEwAAQIJoAgAASBBNAAAACaIJAAAgQTQBAAAkiCYAAIAE0QQAAJAgmgAAABJEEwAw9HR0RLS1Fb8C7CHRBAAMHS0tEXPmRIwaFVFfX/w6Z07E6tWlngwYxEQTADA0LF0aMX16xL33RhQKxW2FQvH+iSdGLFtW2vmAQUs0AQCDX0tLRFNTRJZFbNnS87EtW4rbv/Y1R5yA3SKaAIDBb9GiiKqq9JqqqojFiwdmHmBIEU0AwODW0RFxzz3bHmH6uC1bIpYv9+YQwC4TTQDA4JbP/+Eaph0pFIrrAXaBaAIABrdcLqJyJ/+kqawsrgfYBaIJAOg7pfh8pNraiFmzIqqr0+uqqyPOOKO4HmAXiCYAYM+V+vORFi6M6OpKr+nqiliwYGDmAYYU0QQA7Jly+HykE06IWLIkoqJi2yNO1dXF7UuWRDQ29v8swJAjmgCA3VdOn490wQURq1YVT9Xbeo1TZWXx/qpVxccBdsMOTv4FAEjY+vlIqbf73vr5SANxlKexsXjr6Ci+S14u5xomYI+JJgBg92z9fKQdvd33//98pIEKmNpasQT0GafnAQC7x+cjAcOEaAIAdo/PRwKGCdEEAOwen48EDBOiCQDYfT4fCRgGRBMAsPt8PhIwDIgmAGDP+HwkYIjzlu
MAwJ7z+UjAECaaAIC+4/ORgCHI6XkAAAAJogkAACBBNAEAACSIJgAAgATRBAAAkCCaAAAAEkQTAABAgmgCAABIEE0AAAAJogkAACBBNAEAACSIJgAAgATRBAAAkCCaAAAAEkQTAABAgmgCAABIEE0AAAAJogkAACBBNAEAACSIJgAAgISyjqaurq64/PLLY+LEiVFbWxuHHHJIXH311ZFlWalHAwAAhonqUg+Qct1118XSpUvj1ltvjSOOOCKefvrpOO+886Kuri4uuuiiUo8HAAAMA2UdTY8//njMmjUrTjvttIiIOOigg+KOO+6IJ598stfndHZ2RmdnZ/f9fD7f73MCAABDV1mfnnf88cfHypUr45e//GVERDz//PPR0tISX/ziF3t9TnNzc9TV1XXfGhoaBmpcAABgCKrIyvgCoUKhEH//938f119/fVRVVUVXV1dcc801cdlll/X6nO0daWpoaIj29vbI5XIDMTYAAFAi+Xw+6urq+vTv/7I+Pe+uu+6K73//+3H77bfHEUccEc8991zMnz8/JkyYEHPnzt3uc2pqaqKmpmaAJwUAAIaqso6mSy65JC699NI466yzIiLiyCOPjFdffTWam5t7jSYAAIC+VNbXNH3wwQdRWdlzxKqqqigUCiWaCAAAGG7K+kjT6aefHtdcc00ccMABccQRR8S6deti0aJF8Zd/+ZelHg0AABgmyvqNIN577724/PLLY/ny5fHmm2/GhAkT4uyzz44rrrgiRowYsVPfoz8uBAMAAMpTf/z9X9bR1BdEEwAADB/98fd/WV/TBAAAUGqiCQAAIEE0AQAAJIgmAACABNEEAACQIJoAAAASRBMAAECCaAIAAEgQTQAAAAmiCQAAIEE0AQAAJIgmAACABNEEAACQIJoAAAASRBMAAECCaAIAAEgQTQAAAAmiCQAAIEE0ATA8dXREtLUVvwJAgmgCYHhpaYmYMydi1KiI+vri1zlzIlavLvVkAJQp0QTA8LF0acT06RH33htRKBS3FQrF+yeeGLFsWWnnA6AsiSYAhoeWloimpogsi9iypedjW7YUt3/ta444AbAN0QTA8LBoUURVVXpNVVXE4sUDMw8Ag4ZoAmDo6+iIuOeebY8wfdyWLRHLl3tzCAB6EE0ADH35/B+uYdqRQqG4HgB+TzQBMPTlchGVO/krr7KyuB4Afk80ATD01dZGzJoVUV2dXlddHXHGGcX1APB7ogmA4WHhwoiurvSarq6IBQsGZh4ABg3RBMDwcMIJEUuWRFRUbHvEqbq6uH3JkojGxtLMB0DZEk0ADB8XXBCxalXxVL2t1zhVVhbvr1pVfBwAPmYHJ3cDwBDT2Fi8dXQU3yUvl3MNEwBJogmA4am2ViwBsFOcngcAAJAgmgAAABJEEwAAQIJoAgAASBBNAAAACaIJAAAgQTQBAAAkiCYAAIAE0QQAAJAgmgAAABJEEwAAQIJoAgAASBBNAAAACaIJAAAgQTQBAAAkiCYAAIAE0QQAAJAgmgAAABJEEwAAQIJoAgAASBBNAAAACaIJAAAgQTQBAAAkiCYAAIAE0QQAAJAgmgAAABJEEwAAQIJoAgAASBBNAAAACaIJAAAgQTQBAAAkiCYAAIAE0QQAAJAgmgAAABJEEwAAQELZR9Prr78ef/EXfxH77LNP1NbWxpFHHhlPP/10qccCAACGiepSD5Dy29/+NhobG+OUU06J+++/P/bbb7946aWX4lOf+lSpRwMAAIaJso6m6667LhoaGuLmm2/u3jZx4sQSTgQAAAw3ZX163n/+53/GMcccE2eeeWaMHTs2jj766LjpppuSz+ns7Ix8Pt/jBgAAsLvKOpp+9atfxdKlS+Owww6LBx98MObNmxcXXXRR3Hrrrb0+p7m5Oerq6rpvDQ0NAzgxAAAw1FRkWZaVeojejBgxIo455ph4/PHHu7dddNFF8dRTT8WaNWu2+5zOzs7o7Ozsvp/P56OhoSHa29sjl8v1+8wAAE
Dp5PP5qKur69O//8v6SNP48ePjM5/5TI9tn/70p+O1117r9Tk1NTWRy+V63AAAAHZXWUdTY2NjbNiwoce2X/7yl3H
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X1, y, xlabel=\"x\", ylabel=\"y\")\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f011c936560>]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA00AAAH0CAYAAADhWca4AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABJ8klEQVR4nO3dd3RUdf7/8ddMOkkmEEqKhA6hJ1GUpYiyoqjoUpS2rj/X3f1aFpcSRcFdsIuggrCKuLvfFb/uSlOxoKKIhQ6KCZ3QIZSEnklC6sz9/TE4iEIIySR3yvNxzpycz82dmRfnOpl5OTPvazEMwxAAAAAA4IKsZgcAAAAAAG9GaQIAAACAClCaAAAAAKAClCYAAAAAqAClCQAAAAAqQGkCAAAAgApQmgAAAACgApQmAAAAAKgApQkAAAAAKkBpAgAAAIAKmFqali1bpttvv12JiYmyWCz64IMPzvu9YRiaOHGiEhISFBERoT59+mjnzp3mhAUAAAAQkEwtTYWFhUpJSdFrr712wd9PmTJFM2bM0KxZs7R27VpFRkaqb9++Ki4uruWkAAAAAAKVxTAMw+wQkmSxWLRw4UINGDBAkutdpsTERD388MN65JFHJEl5eXmKi4vT7NmzNWzYMBPTAgAAAAgUwWYHuJi9e/cqJydHffr0cW+LiYlR165dtXr16ouWppKSEpWUlLjXTqdTJ0+eVP369WWxWGo8NwAAAADzGIah/Px8JSYmymr1zAfrvLY05eTkSJLi4uLO2x4XF+f+3YVMmjRJTz31VI1mAwAAAODdsrOz1bhxY4/clteWpqoaP3680tPT3eu8vDw1adJE2dnZstlsJiYDAAA17bu9J/WHt76TYUjBVove/uM16tS4rtmxANQiu92upKQkRUdHe+w2vbY0xcfHS5Jyc3OVkJDg3p6bm6vU1NSLXi8sLExhYWG/2G6z2ShNAAD4sdNnSvW3T3fLElpHFkmP3JysHu2bmB0LgEk8+dUcrz1PU/PmzRUfH6+lS5e6t9ntdq1du1bdunUzMRkAAPA2hmHo8YWbdCTPNWG3W4v6ur9XS5NTAfAXpr7TVFBQoF27drnXe/fuVWZmpmJjY9WkSRONHj1azz77rFq3bq3mzZtrwoQJSkxMdE/YAwAAkKT532fr002u7zzHRIRo6tAUBVkZAAXAM0wtTd9//7169+7tXv/4XaR77rlHs2fP1qOPPqrCwkLdd999On36tHr27KnFixcrPDzcrMgAAMDL7D5WoCc/2upeT76jkxJiIkxMBMDfeM15mmqK3W5XTEyM8vLy+E4TAAB+prTcqTteX6VNh/IkScOvSdKkQZ1NTgXATDXx+t9rv9MEAABwKS8vyXIXphYNIzXhtvYmJwLgjyhNAADAJ63cdVxvfLtHkhQSZNGMYWmqE+q1g4EB+DBKEwAA8DknC0uVPj/TvX60b1t1vCLGvEAA/BqlCQAA+BTDMPTYexuVay+RJF3buoH+2LO5yakA+DNKEwAA8CnvrDugJVtzJUmxkaF6eXCKrIwXB1CDKE0AAMBn7Dqar2cW/XS8eGc1snEqEgA1i9IEAAB8Qkm5Q3+Zk6niMqck6e5fNdWN7eNMTgUgEFCaAACAT5iyOEvbjtglSa0bRemv/dqZnAhAoKA0AQAAr/ftjmP63xV7JUmhwVbNGJ6m8JAgk1MBCBSUJgAA4NWOF5To4fkb3OtxN7dVuwSbiYkABBpKEwAA8FqGYeixdzfqeIFrvPj1yQ11b49m5oYCEHAoTQAAwGu9vWa/lm4/KklqEBWqF+9MkcXCeHEAtYvSBAAAvFJWTr6e/WSbe/3inSlqGB1mYiIAgYrSBAAAvE5xmUMj52SotNw1Xvz33Zupd9tGJqcCEKgoTQAAwOu88Nl2ZeXmS5Laxkdr3C1tTU4EIJBRmgAAgFf5anuuZq/aJ0kKY7w4AC9AaQIAAF7jaH6xxi7Y6F7/rV87tYmLNjERAFCaAA
CAl3A6DY1dsFEnCkslSX3aNdLvftXU5FQAQGkCAABe4s1V+/TtjmOSpIbRYZp8R2fGiwPwCpQmAABguq2H7Zr82Xb3euqQFNWPYrw4AO9AaQIAAKYqKnVo5NwMlTpc48X/59rmurZ1Q5NTAcA5lCYAAGCq5z7dql1HCyRJHRJteqRvssmJAOB8lCYAAGCaL7bk6D9rDkiSwkOsmj4sTWHBjBcH4F0oTQAAwBS59mI99t658eJP3N5BrRpFmZgIAC6M0gQAAGqd02kofX6mTp0pkyT17RCnYVcnmZwKAC6M0gQAAGrdv1bs0cpdJyRJ8bZwvTCI8eIAvBelCQAA1KrNh/L04udZkiSLRZo6NEX1IkM9eydFRVJurusnAFQTpQkAANSaM6XlGjknQ2UOQ5L0wHUt1b1lA8/dwYoV0qBBUlSUFB/v+jlokLRypefuA0DAoTQBAIBa88yirdpzvFCS1LlxjMb0aeO5G3/9dalXL+njjyWn65xPcjpd62uvlWbN8tx9AQgolCYAAFArPtt0RHPWZUuS6oQGafqwNIUGe+ilyIoV0ogRkmFI5eXn/6683LX9z3/mHScAVUJpAgAANe5IXpHGvb/JvX7yNx3UvEGk5+5g6lQp6BLndwoKkqZN89x9AggYlCYAAFCjHE5DY+ZlKq/INV68X+cEDb6qsefuoKhI+vDDX77D9HPl5dLChQyHAHDZKE0AAKBGvbFst9bsOSlJSowJ1/MDOnl2vLjdfu47TJfidLr2B4DLQGkCAAA1JjP7tKZ+sUOSZLVIrwxLU0ydEM/eic0mWSv5ksZqde0PAJeB0gQAADznJ+dHKigp16i5GSp3usaLj+jdStc0j/X8fUZESP37S8HBFe8XHCwNHOjaHwAuA6UJAABU3wXOj/Tk6L9r/4kzkqS0JnU18obWNXf/6emSw1HxPg6HNGZMzWUA4LcoTQAAoHoucH6kj9v00LsxrnMwRVmcmj40TSFBNfiyo2dPaeZMyWL55TtOwcGu7TNnSj161FwGAH6L0gQAAKruAudHOmhrqMf7jnDv8szH09Rke0bNZ3ngAWn5ctdH9X78jpPV6lovX+76PQBUwSU+/AsAAFCBH8+PdLYwlVusGnPbI8oPj5Ik9d/yjQZmLXedH6k23uXp0cN1KSpyTcmz2fgOE4BqozQBAICq+fH8SD8Z9z2z2xB9l9RBktT4dI6e+WLm+edHqq0CExFBWQLgMXw8DwAAVM3Pzo+0PrGtpvcYLkmyOh2avugl2UpdgyA4PxIAX0ZpAgAAVfOT8yPZQ+to1O2PyGENkiSNXDVXVx3afm5fzo8EwIdRmgAAQNX85PxIE296UAfrxkuSuhzcoodWzTu3H+dHAuDj+E4TAACouvR0LdxxWh906C1Jii4p1LSPX1awce5je5wfCYCv450mAABQZQfaXakJt58rRM99/pqS7EddC86PBMBPUJoAAECVlDucGjUvQwWG6+XEIPsu/SZrheuXnB8JgB/h43kAAKBKZizdqYwDpyVJTWLr6OmnRkgv38f5kQD4HUoTAAC4bOv2ntSrX++SJAVZLZo+LFVRYcGSgilLAPwOH88DAACXJa+oTGPmZcppuNbpN7ZRWpN65oYCgBpEaQIAAJVmGIYeX7hJh04XSZK6No/VA9e1NDkVANQsShMAAKi0d9cf1Ccbj0iSbOHBmjY0VUFWi8mpAKBmUZoAAECl7D1eqCc+2uJev3BHZyXW5ftLAPwfpQkAAFxSablTo+Zm6EypQ5I0tEuSbu2UYHIqAKgdlCYAAHBJ077coY0H8yRJzRtEauLt7U1OBAC1h9IEAAAqtGr3cc36drckKSTIohnD0hQZxllLAAQOShMAALioU4WlSp+3QcbZ8eIP35SsTo1jzA0FALWM0gQAAC7IMAyNf3+TcuzFkqTuLevrvmtbmJwKAGofpQkAAFzQ3O+ytXhLjiSpbp0QTR2SKivjxQEEIEoTAAD4hV1HC/T0x1vd68
l3dFZ8TLiJiQDAPJQmAABwnpJyh0bNzVBRmWu8+G+7NlHfDvEmpwIA81CaAADAeV7+Yoe2HLZLklo2jNSEfowXBxD
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X1, y, xlabel=\"x\", ylabel=\"y\")\n",
"theta_start = np.matrix([0, 0]).reshape(2, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X1, y, eps=0.00001)\n",
"plot_fun(fig, polynomial_regression(theta), X1)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"This model has high **bias** (**systematic error**): it **underfits** the data (*underfitting*)."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f00aeaea680>]"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA00AAAH0CAYAAADhWca4AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABW+klEQVR4nO3dd3hUZeL28XtKGiEVSIMEQpHem4AUlRUVFcSK6KJrFwu6qyv+XnWtiLrYBcsqrGJXxIqL9F5C7wQChEASWgohdea8f0wciECAkORM+X6uay5yTs6EG8fMzD3Pc55jMQzDEAAAAADgpKxmBwAAAAAAT0ZpAgAAAIBKUJoAAAAAoBKUJgAAAACoBKUJAAAAACpBaQIAAACASlCaAAAAAKASlCYAAAAAqASlCQAAAAAqQWkCAAAAgEqYWprmzZunK6+8UgkJCbJYLPr+++8rfN8wDD311FOKj49XSEiIBg4cqG3btpkTFgAAAIBfMrU0FRQUqGPHjnrnnXdO+v2XX35Zb775piZOnKilS5cqNDRUgwYNUlFRUS0nBQAAAOCvLIZhGGaHkCSLxaKpU6dq6NChklyjTAkJCfr73/+uf/zjH5Kk3NxcxcbGatKkSbrxxhtNTAsAAADAX9jNDnAqaWlpyszM1MCBA937IiIi1LNnTy1evPiUpam4uFjFxcXubafTqUOHDqlevXqyWCw1nhsAAACAeQzDUH5+vhISEmS1Vs/EOo8tTZmZmZKk2NjYCvtjY2Pd3zuZsWPH6plnnqnRbAAAAAA8W3p6uho1alQtP8tjS1NVjRkzRo888oh7Ozc3V0lJSUpPT1d4eLiJyQAAAADUtLy8PCUmJiosLKzafqbHlqa4uDhJUlZWluLj4937s7Ky1KlTp1PeLygoSEFBQSfsDw8PpzQBAAAAfqI6T83x2Os0JScnKy4uTjNnznTvy8vL09KlS9WrVy8TkwEAAADwJ6aONB05ckSpqanu7bS0NK1evVrR0dFKSkrS6NGj9fzzz6tFixZKTk7Wk08+qYSEBPcKewAAAABQ00wtTStWrNCFF17o3v7jXKSRI0dq0qRJeuyxx1RQUKC77rpLOTk5uuCCCzR9+nQFBwebFRkAAACAn/GY6zTVlLy8PEVERCg3N5dzmgAAAAAfVxPv/z32nCYAAAAA8ASUJgAAAACoBKUJAAAAACpBaQIAAACASlCaAAAAAKASlCYAAAAAqASlCQAAAAAqQWkCAAAAgEpQmgAAAACgEpQmAAAAAKgEpQkAAAAAKkFpAgAAAIBKUJoAAAAAoBKUJgAAAACoBKUJAAAAACpBaQIAAACASlCaAAAAAKASlCYAAAAAqASlCQAAAAAqQWkCAAAAgEpQmgAAAACgEpQmAAAAAKgEpQkAAAAAKkFpAgAAAIBKUJoAAAAAoBJ2swMAAAB4gzKHUyUOp0rKXLfiMqcMQ2oYFSKb1WJ2PAA1iNIEAAB8SmZukRZs3KsDB/NUbA9UicXqLjolDlfZOX77ZF8Xl/1xnMO932mc/O8LDbSpc1KUujZ23TonRSosOKB2/9EAahSlCQAAeLVSh1Mrdh7WnK3ZmrtypzbnO2v17y8ocWhB6gEtSD0gSbJapJZx4eraOFLdGkera+MoNYoKkcXCaBTgrShNAADA6+zNKdScLfs1d2u2FqYe1JHisnP+mVaLFGi3KtBmVaDdpiC79bht67Ht4/aVlDm1Zk+OsvKK3T/HaUib9uVp0748fbpktyQpJixI3ZpEqUtSlLo1iVbbhHAF2Di1HPAWlCYAAODxSsqcWrHzkOZs3a85W7K1NevISY+zGE512Jeq/mkr1Dp7p4LKShToKFWgo8z154cfKLBbFwWdpATZq1hiDMNQRk6hUnYdVsquw1qx87A2Z+ZVmM6XnV+sX9Zl6pd1mZKk4ACrOjSKVLfGUe4yFVknsEp/P4CaZzEM4xQzdH1DXl6eIiIilJubq/DwcLPjAACAM7
Tn8FHN2bJfc7bs16LtB3S0xHHS46JDA9Vv30b1X/Cj+qUuV73CvJP/QLtdGjJE+uabGkztkl9UqjXpuVqx65BSdh3Wqt05px0Nax5TV90aR6lL4yh1axyl5PqhTOkDqqAm3v9TmgAAgEcoLnNoWdqh8qKUre37C056nMUidWwUqQEtG2hAyxi1jw6ULTxMcp7BuUxWq3TkiBQSUs3pK+dwGtqSma+U8hK1Ytdh7TlcWOl9okMDy6fzRalX03rq0CiCEgWcAUpTFVCaAADwXLsPHtXcrdnlo0kHVVh68tGkeqGB6n9eA/Vv2UD9WjRQVOhxU9mysqS4uDP/SzMzpdjYc0x+7rLyitzT+VJ2HdKGvXkqO9USfZK6NY7S/Rc1V//zGlCegEpQmqqA0gQAgOcoLnNoyY5DmrMlW3O37NeOAycfTbJapM5JURpwnms0qW1CuKynuhZSYaFUt65HjzSdicISh9bsyXGfG5Wy67ByC0tPOK5Dowjdf2FzDWwde+r/JoAfozRVAaUJAIBaVFgo5eVJ4eEVionDaejblD0aP2OrMvOKTnrXBmFBrtGk8xqob4v6Z7cwwrBh0o8/SmWVnDdUi+c0VQen09COA0e0eMchTV60U6nZFRe/aBUXpvsvaq7L2sVzcV3gOJSmKqA0AQBQCxYskMaPl6ZNc434WK3SkCEyHnlEM6NbaNz0zdr2pzf9NqtFXZOi1L+lqyi1ia9kNOlM/v5+/aTK3tZYLNL8+VKfPlX7O0zkdBqaviFTb81K1aZ9FRe6aNYgVKMubK6rOiZUeQVAwJdQmqqA0gQAQA2bMEEaNUqy2SqM9KQkttW4vrdoWWK7Codf1CpG13ZtpD7N6ysiJKD6ckycKN133wk5ZLdLDof07rvSPfdU399nAsMwNGtztt6clao16TkVvpcUXUf3Dmima7o0UqCd8gT/RWmqAkoTAAA16CQjPNujG+qVfn/V9JYVR3Q6J0VqzGWt1SM5uubyLFwovfaaNHXqsRGvq6+WHn7YK0eYTsUwDC1IPaC3ZqVqWdqhCt+LjwjWPf2b6YbuiQoOsJmUEDAPpakKKE0AANSg484lyg6N0ut9btKXHS+Rw3rszXrTQxl6rHiLBk16tfZWfTvFuVW+aOmOg3p7dqrmbztQYX+DsCDd1bepbuqZpNAgu0npgNpHaaoCShMAADWkfNW6fHuQ3u95jT7sNlSFgcHubzc4ckijF3ymG9b+T3aLPHbVOl+xavdhvTM7Vb9vyq6wP6pOgG6/IFl/7d1E4cHVOB0S8FCUpiqgNAEAUDNK9u7TlKvu1lu9b9ShOhHu/XWLj+rupd/q9hXfq05p8bE7eMj1kXzdhr25emd2qn5dn1lhXYywYLtu691Et/VJrnidK8DHUJqqgNIEAED1cjoN/bh2r16dvlnpOceWDw9wlOrmVb/o/kVfql5hxRXePPn6SL5qW1a+3pmdqh/W7NXx18wNDbTp5l6NdWffpqpfN8i8gEANoTRVAaUJAIDqs2DbAb00fZPWZ1QsRUM2zNHf53+ipNysE+/kZddH8jU7DxTo3Tmp+m5lhsqOa0/BAVYN75Gku/s1U1xEcCU/AfAulKYqoDQBAHDu1mfkatz0zScsNnBBfbse//f9apeZeuo7e/H1kXzJnsNH9d7cHfpyebpKHE73/kCbVdd1a6R7+jdTYnQdExMC1YPSVAWUJgAAqi790FG9+r8tmrZ6b4X9bRPC9fhlrdS3RQO/uD6SL8nKK9L783ZoytJdKio9Vp7sVotu6pmkMZe1VkggS5XDe1GaqoDSBADA2TtUUKK3Zm3Tp0t2qdRx7K1Co6gQPTqopa7skCCr9bjlw/3k+ki+5MCRYv1nQZr+u2inCkoc7v0tY8P09k2d1SI2zMR0QNVRmqqA0gQAwJk7WlKmjxak6b25O5RffGzUKKpOgB64qIVGnJ+kIHsloxB+dH0kX5FztEQfL9yp9+ftUGGpqzyFBNj07JC2uq5bos
npgLNHaaoCShMAAKfncBr6akW6XpuxVdn5x5YJDw6w6o4Lmuqu/k25xo+P25aVr/s/W6UtWfnufcM6N9RzQ9txcVx
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X2, y, xlabel=\"x\", ylabel=\"y\")\n",
"theta_start = np.matrix([0, 0, 0]).reshape(3, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X2, y, eps=0.000001)\n",
"plot_fun(fig, polynomial_regression(theta), X1)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"This model fits the data appropriately."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f00ae55ca00>]"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA00AAAH0CAYAAADhWca4AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABUuklEQVR4nO3dd3hUZcLG4WcmvYcQ0iBAgBCQ3qUqgqKigrgWLIu9LK4Crm2/VXdtqOuCq2vZxYK9i12RKiC9GhAhgQABkpBQ0uvM+f6YOBCBECCTM+V3X9dc5Jw5Ex4cJzNP3nPe12IYhiEAAAAAwDFZzQ4AAAAAAO6M0gQAAAAA9aA0AQAAAEA9KE0AAAAAUA9KEwAAAADUg9IEAAAAAPWgNAEAAABAPShNAAAAAFAPShMAAAAA1IPSBAAAAAD1MLU0LVq0SBdffLGSkpJksVj0+eef17nfMAw9/PDDSkxMVEhIiEaOHKmMjAxzwgIAAADwSaaWptLSUvXo0UMvvvjiMe9/5pln9Pzzz+uVV17RihUrFBYWplGjRqmioqKJkwIAAADwVRbDMAyzQ0iSxWLRrFmzNHbsWEmOUaakpCTdc889+stf/iJJKiwsVHx8vGbOnKmrrrrKxLQAAAAAfIW/2QGOJysrS7m5uRo5cqRzX1RUlAYMGKBly5YdtzRVVlaqsrLSuW2323XgwAE1b95cFovF5bkBAAAAmMcwDBUXFyspKUlWa+OcWOe2pSk3N1eSFB8fX2d/fHy8875jmTp1qv7xj3+4NBsAAAAA95adna1WrVo1yvdy29J0qh588EFNmTLFuV1YWKjWrVsrOztbkZGRJiYDAAC+qqrGrjEvLlH2gXJJ0ovX9NJZHeNMTgV4p6KiIiUnJysiIqLRvqfblqaEhARJUl5enhITE5378/Ly1LNnz+M+LigoSEFBQUftj4yMpDQBAABTvL4kS3tKLbIGhWpgu+a6qE97LhsAXKwxX2Nuu05TSkqKEhISNG/ePOe+oqIirVixQgMHDjQxGQAAQMMVllfr+fmHl0z5v9GdKUyAhzF1pKmkpESZmZnO7aysLK1fv14xMTFq3bq1Jk2apMcff1ypqalKSUnRQw89pKSkJOcMewAAAO7upQWZOlRWLUm6tFdLdW0ZZXIiACfL1NK0evVqDR8+3Ln927VIEyZM0MyZM3XfffeptLRUt956qw4dOqQhQ4bo+++/V3BwsFmRAQAAGmz3wTK9sXSHJCnQ36p7zutobiAAp8Rt1mlylaKiIkVFRamwsJBrmgAAQJOa9ME6fb5+ryTptrPa6cELOpucCPB+rvj877bXNAEAAHiy9N2FzsLULDRAfzq7g8mJAJwqShMAAEAjMwxDT3z7i3P7rhGpigoJMDERgNNBaQIAAGhk83/dp+XbD0iS2jYP1TUD2picCMDpoDQBAAA0ohqbXVO/+9W5fd/5nRToz0cuwJPxCgYAAGhEH67OVua+EklS79bRuqBrgsmJAJwuShMAAEAjKams0fQ5LGQLeBtKEwAAQCP536LtKiiplCRd0DVBfdrEmJwIQGOgNAEAADSCvKIKzVi0XZLkb7XovvM7mZwIQGOhNAEAADSCaT9sVXm1TZJ07ZltlBIbZnIiAI2F0gQAAHCatuQW6+M12ZKkiCB/3TUi1eREABoTpQkAAOA0Tf1us+yG4+s/De+gmLBAcwMBaFSUJgAAgNOwJKNAC7fkS5KSooJ1w+C25gYC0OgoTQAAAKfIbjf05Lebndv3np+m4AA/ExMBcAVKEwAAwCmatW6PfskpkiR1bRmpMT1ampwIgCtQmgAAAE5BRbVNz/6wxbn91ws6y2plIVvAG1GaAAAATpLdbmj63K3KKayQJA1Pa6FBHWJNTgXAVfzNDgAAAOBJtuWX6MHP0rUy64AkyWqRHryws8mpALgSpQkAAKABqmrs+t+ibXp+fqaqauzO/XeP6KiO8REmJgPgapQmAACAE1i366
Ae+DRdW/KKnfuSY0L05KXdNDS1hYnJADQFShMAAMBxlFTW6NnZW/Tmsh0yahevtVqkW4a206SRHRUSyPTigC+gNAEAABzD/F/z9LdZG7W3drIHSeqSFKmnL+uuri2jTEwGoKlRmgAAAI6QX1ypR7/+RV9t2OvcFxxg1eSRHXXTkBT5+zH5MOBrKE0AAACSDMPQx2t264lvNquwvNq5f0iHWD15aTe1bh5qYjoAZqI0AQAAn7ejoFR/nZWupdv2O/dFhwboodFnaFzvlrJYWLQW8GWUJgAA4LOqbXa9ujhLz83dqsojphEf2zNJD110hpqHB5mYDoC7oDQBAACf9PPuQ7r/03Rtzily7msZHaLHL+2q4WlxJiYD4G4oTQAAwPuUl0tFRVJkpBQSUueusqoaTfthq17/KUv2I6YRv35Qiu45r6PCgvh4BKAupn8BAADeY8kSadw4KTxcSkhw/DlunPTTT5KkH7fm67zpi/TqksOFqVNChGb9abAevvgMChOAY+InAwAA8A4vvyxNnCj5+Un22uuT7Hbpq6+0f/Z8Pf7A/zSrNMx5eKC/VZNGpuqWoe0UwDTiAOpBaQIAAJ5vyRJHYTIMqabGuduQNCttqB4752YdPKIwDWzXXE+O66aU2LBjfDMAqIvSBAAAPN+0aY4RpiMKU3ZUvP46aqIWp/R27ou0VepvV/TT5X1bMY04gAajNAEAAM9WXi598YUMu10HQiK1rXmyViR31UtnXq7ywGDnYaM3L9IjC15V3KO7JQoTgJNAaQIAAB7FZje052C5MvOLtW1fqTJ35Wvb+KnKbJ6sQyGRRx2fWJSvx394SSO2rXLsKCo6akY9AKgPpQkAALilimqbtuWXaFt+qTL3lTi+3leirILSOgvRSpJadTnq8RbDrj+u/Ub3LnpL4VXljp1Wq2MacgA4CZQmAADQeOpZH+l49pdU1i1G+SXK3FeiPYfKZRgN/6sTiwvUfn+2OhRkq/2B3RqQvVEdC3YdPsDfXxozhlEmACeN0gQAAE7fkiWOyRi++MIxzbfV6igo99wjDR4swzC0t7BCW/OKlZl3uBhtyy/RwbLqBv81/laL2saGqX2LMHWIC1f7FuHqEBeudpnpCj/nYtXbsmw2afLkRvjHAvA1lCYAAHB6jlgfybDblRfeXFtjW2vrHinjiY+0tV+eMhSmksqaE3+vWuFB/mofF35UOWodE3rsNZVaDZVeekn605+OmkVP/v6OwvTSS9LgwY3wDwbgayhNAADgpBmGofySSmXMW66tr32vref9SRmxrbU1to2KgsPrHlwpSccuTAmRwWofF6YOLcLVPi7c+WdcRNDJTwl+++1St27S9OnSrFl1R7wmT6YwAThlFsM4mbOFPU9RUZGioqJUWFioSC78BADgpB0ordKW3GJl7CvW1rxibc0rUUZe8UmdVteyulhp3dorNT5cqXERSo0LV7sWYYoIDnBN6FO4tgqAd3DF539GmgAAgCSpsLy6thQVa2tubTnaV6yCkqoGf4/EonylFuxSWsFOpRbsUseCXeqwP1vhNZVSSUnTFZiQEMoSgEZDaQIAwMdl5BXrPwsy9dWGvbI38PyTuIggdWwWqNTP3lFavqMgpRbsUmRV2fEfxPpIADwUpQkAAB+1OadI/5mfqW835hx30rnY8EClxkWoY3y4UuMjlJbgOLUuOjTQcQrcn4c7rh06EdZHAuDBKE0AAPiYjXsK9cL8DM3elFdnf0xYoC7omlBbjBxFqXl40PG/UUiIY5KFr76qO1vd77E+EgAPR2kCAMBHbMg+pBfmZ2ju5n119seGB+m2Ye10zZmtFRp4kh8NpkyRPv+8/mNYHwmAh6M0AQDg5dbsPKgX5mdo4Zb8OvvjIoJ0+1ntNb5/a4UE+p3aNx8yhPWRAHg9ShMAAF5q1Y4Den5ehhZnFNTZnxgVrDvObq8r+iYrOOAUy9KRWB8JgJejNAEA4EUMw9Dy7Y6ytGz7/jr3tYwO0c
ThHXRZn5YK8m+EsnSkwYMdN9ZHAuCFKE0AAHgBwzD0U+Z+PT8vQyt3HKhzX+uYUN05vIMu7d1SAX5W1wZhfSQAXoj
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X5, y, xlabel=\"x\", ylabel=\"y\")\n",
"theta_start = np.matrix([0, 0, 0, 0, 0, 0]).reshape(6, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X5, y, alpha=0.5, eps=10**-7)\n",
"plot_fun(fig, polynomial_regression(theta), X1)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"This model has high **variance**: it **overfits** the data (*overfitting*)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"(Note the odd shape of the curve in the left part of the plot; this is, among other things, an effect of overfitting.)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Overfitting occurs when the model has too many degrees of freedom relative to the amount of training data.\n",
"\n",
"It is an undesirable phenomenon.\n",
"\n",
"Figuratively speaking, overfitting occurs when the model starts to model the noise in the data instead of its underlying trend."
]
},
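{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"To quantify the pattern in the three plots above, the sketch below compares the training error of polynomial fits of degree 1, 2 and 5 on the same six points. It uses `np.polyfit`, which computes the exact least-squares solution, so the values will differ somewhat from the gradient-descent fits above. The training error keeps decreasing as the degree grows. The degree-5 polynomial has six coefficients for six points and passes through every point. A near-zero training error is therefore not evidence of a good model. As the odd curve above suggests, it is a symptom of overfitting.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# The six (x, y) points used in the cells above\n",
"x = np.array([0.0, 0.5, 1.0, 1.6, 2.6, 3.0])\n",
"y = np.array([0.0, 1.8, 4.8, 7.2, 8.8, 9.0])\n",
"\n",
"\n",
"def training_mse(degree):\n",
"    # Least-squares polynomial fit, evaluated on the training points themselves\n",
"    coeffs = np.polyfit(x, y, degree)\n",
"    return np.mean((np.polyval(coeffs, x) - y) ** 2)\n",
"\n",
"\n",
"for degree in (1, 2, 5):\n",
"    print(degree, training_mse(degree))\n",
"```"
]
},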
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"See also: https://pl.wikipedia.org/wiki/Nadmierne_dopasowanie"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"90%\" src=\"fit.png\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Bias (systematic error)\n",
"\n",
"* Results from erroneous assumptions made by the learning algorithm.\n",
"* High bias causes underfitting."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Variance\n",
"\n",
"* Results from oversensitivity to small fluctuations in the training set.\n",
"* High variance can cause overfitting (the model fits the noise instead of the signal)."
]
},
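{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Both notions can be made precise. For data $y = f(x) + \\varepsilon$ with noise variance $\\sigma^2$, the expected squared error of a model $\\hat{f}$ trained on a random training set decomposes as $\\mathbb{E}[(y - \\hat{f}(x))^2] = \\mathrm{Bias}[\\hat{f}(x)]^2 + \\mathrm{Var}[\\hat{f}(x)] + \\sigma^2$. The sketch below estimates the first two terms by simulation. The setup is an arbitrary assumption chosen for illustration, not the lecture data: true function $\\sin(2\\pi x)$, 10 fixed inputs, Gaussian noise with standard deviation 0.3, and 500 training sets. A degree-1 polynomial should show high bias and low variance, and a degree-7 polynomial the opposite.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"# Assumed synthetic setup (not from the lecture data)\n",
"x = np.linspace(0.0, 1.0, 10)\n",
"f_true = np.sin(2 * np.pi * x)\n",
"noise_std = 0.3\n",
"n_repeats = 500\n",
"\n",
"\n",
"def bias_variance(degree):\n",
"    # Train on many independently drawn noisy training sets\n",
"    # and collect the predictions at the training inputs x\n",
"    predictions = np.empty((n_repeats, len(x)))\n",
"    for r in range(n_repeats):\n",
"        y = f_true + rng.normal(0.0, noise_std, size=len(x))\n",
"        predictions[r] = np.polyval(np.polyfit(x, y, degree), x)\n",
"    mean_prediction = predictions.mean(axis=0)\n",
"    bias_sq = np.mean((mean_prediction - f_true) ** 2)  # squared bias, averaged over x\n",
"    variance = np.mean(predictions.var(axis=0))  # variance, averaged over x\n",
"    return bias_sq, variance\n",
"\n",
"\n",
"for degree in (1, 7):\n",
"    print(degree, bias_variance(degree))\n",
"```"
]
},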
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"40%\" src=\"bias2.png\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"60%\" src=\"curves.jpg\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 6.3. Regularization"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def SGD(\n",
" h,\n",
" fJ,\n",
" fdJ,\n",
" theta,\n",
" X,\n",
" Y,\n",
" alpha=0.001,\n",
" maxEpochs=1.0,\n",
" batchSize=100,\n",
" adaGrad=False,\n",
" logError=False,\n",
" validate=0.0,\n",
" valStep=100,\n",
" lamb=0,\n",
" trainsetsize=1.0,\n",
"):\n",
"    \"\"\"Stochastic gradient descent (SGD): the stochastic variant of gradient descent\n",
"    (covered in more detail in the next lecture).\n",
" \"\"\"\n",
" errorsX, errorsY = [], []\n",
" errorsVX, errorsVY = [], []\n",
"\n",
" XT, YT = X, Y\n",
"\n",
" m_end = int(trainsetsize * len(X))\n",
"\n",
" if validate > 0:\n",
" mv = int(X.shape[0] * validate)\n",
" XV, YV = X[:mv], Y[:mv]\n",
" XT, YT = X[mv:m_end], Y[mv:m_end]\n",
" m, n = XT.shape\n",
"\n",
" start, end = 0, batchSize\n",
" maxSteps = (m * float(maxEpochs)) / batchSize\n",
"\n",
" if adaGrad:\n",
" hgrad = np.matrix(np.zeros(n)).reshape(n, 1)\n",
"\n",
" for i in range(int(maxSteps)):\n",
" XBatch, YBatch = XT[start:end, :], YT[start:end, :]\n",
"\n",
" grad = fdJ(h, theta, XBatch, YBatch, lamb=lamb)\n",
" if adaGrad:\n",
" hgrad += np.multiply(grad, grad)\n",
" Gt = 1.0 / (10**-7 + np.sqrt(hgrad))\n",
" theta = theta - np.multiply(alpha * Gt, grad)\n",
" else:\n",
" theta = theta - alpha * grad\n",
"\n",
" if logError:\n",
" errorsX.append(float(i * batchSize) / m)\n",
" errorsY.append(fJ(h, theta, XBatch, YBatch).item())\n",
" if validate > 0 and i % valStep == 0:\n",
" errorsVX.append(float(i * batchSize) / m)\n",
" errorsVY.append(fJ(h, theta, XV, YV).item())\n",
"\n",
" if start + batchSize < m:\n",
" start += batchSize\n",
" else:\n",
" start = 0\n",
" end = min(start + batchSize, m)\n",
" return theta, (errorsX, errorsY, errorsVX, errorsVY)\n"
]
},
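{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"`SGD` passes the regularization parameter `lamb` to the gradient function `fdJ`, which is defined elsewhere in the notebook. Below is a sketch of a hypothetical stand-in, not necessarily identical to the notebook's `dJ`. It gives the cost and gradient of logistic regression with the common L2 penalty $\\frac{\\lambda}{2m}\\sum_{j \\geq 1} \\theta_j^2$, which by convention leaves the bias $\\theta_0$ unpenalized. The penalty adds $\\frac{\\lambda}{m}\\theta_j$ to every gradient component except the first, pulling the weights toward zero.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def sigmoid(z):\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"\n",
"def regularized_cost(theta, X, y, lamb=0.0):\n",
"    # Cross-entropy cost plus (lamb / 2m) * sum of squared weights, excluding theta[0]\n",
"    m = X.shape[0]\n",
"    h = sigmoid(X @ theta)\n",
"    data_term = -np.mean(y * np.log(h) + (1 - y) * np.log(1 - h))\n",
"    return data_term + lamb / (2 * m) * np.sum(theta[1:] ** 2)\n",
"\n",
"\n",
"def regularized_gradient(theta, X, y, lamb=0.0):\n",
"    # Gradient of regularized_cost with respect to theta (1-D arrays)\n",
"    m = X.shape[0]\n",
"    grad = X.T @ (sigmoid(X @ theta) - y) / m\n",
"    penalty = (lamb / m) * theta  # new array, theta itself is not modified\n",
"    penalty[0] = 0.0  # the bias term is not regularized\n",
"    return grad + penalty\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"X = np.column_stack([np.ones(20), rng.normal(size=(20, 2))])\n",
"y = (rng.random(20) > 0.5).astype(float)\n",
"theta = np.array([0.1, -0.2, 0.3])\n",
"print(regularized_gradient(theta, X, y, lamb=0.0))\n",
"print(regularized_gradient(theta, X, y, lamb=1.0))  # only the last two entries change\n",
"```"
]
},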
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Prepare the data for the regularization example\n",
"\n",
"n = 6\n",
"\n",
"data = np.matrix(np.loadtxt(\"ex2data2.txt\", delimiter=\",\"))\n",
"np.random.shuffle(data)\n",
"\n",
"X = powerme(data[:, 0], data[:, 1], n)\n",
"Y = data[:, 2]\n"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def draw_regularization_example(\n",
" X, Y, lamb=0, alpha=1, adaGrad=True, maxEpochs=2500, validate=0.25\n",
"):\n",
"    \"\"\"Draw the regularization example: data with the decision boundary (left) and error curves (right).\"\"\"\n",
" plt.figure(figsize=(16, 8))\n",
" plt.subplot(121)\n",
" plt.scatter(\n",
" X[:, 2].tolist(),\n",
" X[:, 1].tolist(),\n",
" c=Y.tolist(),\n",
" s=100,\n",
"        cmap=\"prism\",\n",
" )\n",
"\n",
" theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1], 1)\n",
" thetaBest, err = SGD(\n",
" h,\n",
" J,\n",
" dJ,\n",
" theta,\n",
" X,\n",
" Y,\n",
" alpha=alpha,\n",
" adaGrad=adaGrad,\n",
" maxEpochs=maxEpochs,\n",
" batchSize=100,\n",
" logError=True,\n",
" validate=validate,\n",
" valStep=1,\n",
" lamb=lamb,\n",
" )\n",
"\n",
" xx, yy = np.meshgrid(np.arange(-1.5, 1.5, 0.02), np.arange(-1.5, 1.5, 0.02))\n",
" l = len(xx.ravel())\n",
" C = powerme(xx.reshape(l, 1), yy.reshape(l, 1), n)\n",
" z = classifyBi(thetaBest, C).reshape(int(np.sqrt(l)), int(np.sqrt(l)))\n",
"\n",
"    plt.contour(xx, yy, z, levels=[0.5], linewidths=3)\n",
" plt.ylim(-1, 1.2)\n",
" plt.xlim(-1, 1.2)\n",
" plt.subplot(122)\n",
" plt.plot(err[0], err[1], lw=3, label=\"Training error\")\n",
" if validate > 0:\n",
" plt.plot(err[2], err[3], lw=3, label=\"Validation error\")\n",
" plt.legend()\n",
" plt.ylim(0.2, 0.8)\n"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABSAAAAKZCAYAAACod4UiAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeZzN1R/H8de9d/Zhxm7se7YsEUIJ0ZBEK0lKlshSJJItS0lJsv+SrRLSImWnEhIiW5bs69iXYYZZ7r2/P74Zhplx7517587yfj4e95F77znf87lmTPd+5nPOx2S32+2IiIiIiIiIiIiIeIDZ2wGIiIiIiIiIiIhI5qUEpIiIiIiIiIiIiHiMEpAiIiIiIiIiIiLiMUpAioiIiIiIiIiIiMcoASkiIiIiIiIiIiIeowSkiIiIiIiIiIiIeIwSkCIiIiIiIiIiIuIxSkCKiIiIiIiIiIiIxygBKSIiIiIiIiIiIh6jBKSIiIiIiIiIiIh4jBKQIiIiIiLAxIkTKV68OAEBAdSqVYuNGzemOH7s2LGULVuWwMBAihQpQq9evbh+/XoaRSsiIiKScSgBKSIiIiJZ3rx58+jduzdDhgxhy5YtVKlShfDwcM6cOZPk+K+//pq3336bIUOGsHv3bqZNm8a8efN455130jhyERERkfTPZLfb7d4OQkRERETEm2rVqkWNGjWYMGECADabjSJFitCjRw/efvvtO8Z3796d3bt3s2rVqoTH3nzzTTZs2MDatWvTLG4RERGRjMDH2wF4g81m4+TJk2TPnh2TyeTtcEREREScZrfbuXLlCgULFsRs1qaW1IiNjWXz5s30798/4TGz2UyjRo1Yv359knPq1KnDV199xcaNG6lZsyYHDx5k8eLFvPjii0mOj4mJISYmJuG+zWbjwoUL5M6dW+9HRUREJENy5v1olkxAnjx5kiJFing7DBEREZFUO3bsGIULF/Z2GBnauXPnsFqt5M+fP9Hj+fPnZ8+ePUnOadOmDefOnePBBx/EbrcTHx9Ply5dkt2CPXLkSIYOHer22EVERES8zZH3o1kyAZk9e3bA+AsKCQnxcjSSXn3/6SJmDZ5HcI4gJmwYSa6wnN4OKUM6e/w8HSv2wuJjYV7EZ/j6+Xo7JBGRTCEyMpIiRYokvK+RtPXbb7/x/vvvM2nSJGrVqsX+/ft5/fXXGT58OIMGDbpjfP/+/endu3fC/cuXL1O0aFHPvh8dedsHgZINodUXnllLREREshxn3o9myQTkjW0uISEhSkBKsl7s/xwbF25l3+aDzOg3j3e/f0tbpFyw9+gBfEy+FL2nELnz5PZ2OCIimY7+35R6efLkwWKxcPr06USPnz59mrCwsCTnDBo0iBdffJGOHTsCUKlSJaKioujcuTMDBgy4YxuSv78//v7+d1zHo+9H/W/73gjyA733FRERETdz5P2oDgwSSYbFx8Jb01/Dx9fCHz9u4rd5f3g7pAzp8D/HAShWUcceiIhI+uTn50f16tUTNZSx2WysWrWK2rVrJzknOjr6jiSjxWIBjPOQ0iWT3vqLiIiId+hdiEgKSlQqRpsBTwMwsec0Lp297OWIMp5DO44AULyCEpAiIpJ+9e7dm6lTpzJr1ix2795N165diYqKon379gC0a9cuUZOa5s2bM3nyZObOncuhQ4dYsWIFgwYNonnz5gmJyHRH1bIiIiLiJVlyC7aIM1q/3ZK132/g4PYjTO41k/5fve7tkDKUfZsPAlCmekkvRyIiIpK8Vq1acfbsWQYPHsypU6eoWrUqS5cuTWhMc/To0UQVjwMHDsRkMjFw4EBOnDhB3rx5ad68Oe+99563XoIDlIAUERER7zDZ0+0eEc+JjIwkNDSUy5cv6wxIcci/mw/QvWZ/7HY7Y1YPo9JD5b0dUoYQFxtH82wvYo23MvvIZPIVyePtkERE0g2bzUZsbGyyz/v6+qZYSaf3Mxlbmnz93g1NfL/c49B6tm
fWEhGRNGe1WomLi/N2GJKJufP9qCogRRxwT/VSPNbxERZNXcmEntOY9Neo9Lu9Kh05/m8E1ngrQSGB5C2sBjQiIjfExsZy6NAhbDZbiuNy5MhBWFiYGs2IiIhIArvdzqlTp7h06ZK3Q5EswF3vR5WAFHHQyyNas3r+eg5uO8LPU1bQolsTb4eU7h355xgAxSsW0YdnEZH/2O12IiIisFgsFClS5I5GJjfGREdHc+bMGQAKFCiQ1mFKZqT/F4uIZAo3ko/58uUjKChIn7XEI9z9flQJSBEH5cgbykvDWjGx53Q+f/srajSpSsFSYd4OK107vNNIQBZTAxoRkQTx8fFER0dTsGBBgoKCkh0XGBgIwJkzZ8iXL58q78UN9AFVRCSjs1qtCcnH3Lm1y0w8y53vR9UFW8QJT7wWTuWHK3A9KoZRL03AarV6O6R07fCumxWQIiJiuPH/Dj8/v7uOvZGg1PlO4hYmvfUXEcnobrwnSOmXmCLu5K73o3oXIuIEs9lM35ndCcoeyK4/9rJ46ipvh5SuHd19AoCiFQp7ORIRkfTHke1S2lIlbqXvJxGRTEPvESStuOt7TQlIESflL5aX9iOeB2DmoLlcuXjVyxGlT1arlYgDpwAoUragl6MREREREREREW9RAlLEBc27PkrxikWIPH+FL4fO93Y46dKZo+eIj7Pi6+9L3iI6m0RERMTrtAVbREQymeLFizN27FiHx//222+YTCZ1EPcCvQsRcYHFx0LXT14G4MeJSzny31mHctOJfUb1Y4GS+ZLs8CoiIiJpTdv1RETEO0wmU4q3d99916Xrbtq0ic6dOzs8vk6dOkRERBAaGurSeuI6ZQVEXFStUWXqtKiBzWpjUq+Z2O12b4eUrpzYFwFAoTIFvByJiIiIADoDUkREvCYiIiLhNnbsWEJCQhI91qdPn4Sxdrud+Ph4h66bN29epxry+Pn5ERYWli7P0IyNjb3jMavVis1mc/pars7zJCUgRVLh1dHt8PXzYcuK7WxYtMXb4aQrJ/cbFZCFSisBKSKSFEd+cZXe3jhKRpf+PmyJiEjq2Gx2zl+N8drNZnOsECcsLCzhFhoaislkSri/Z88esmfPzpIlS6hevTr+/v6sXbuWAwcO0KJFC/Lnz0+2bNmoUaMGK1euTHTd27dgm0wmPv/8c5588kmCgoIoU6YMCxcuTHj+9i3YM2fOJEeOHCxbtozy5cuTLVs2mjRpQkRERMKc+Ph4evbsSY4cOcidOzf9+vXjpZdeomXLlim+5rVr1/LQQw8RGBhIkSJF6NmzJ1FRUYliHz58OO3atSMkJITOnTsnxLNw4UIqVKiAv78/R48e5eLFi7Rr146cOXMSFBRE06ZN2bdvX8K1kpuXnvh4OwCRjKxgqTCefL0Z33z0I1P7fcn94VXw8dU/K4BTh88AUKBkfi9HIiKSvvj6+mIymTh79ix58+ZN8jfwdrud2NhYzp49i9lsxs/PzwuRSqaTDqs9REQkdS5Gx1J9xMq7D/SQzQMbkTubv1uu9fbbbzN69GhKlixJzpw5OXbsGI899hjvvfce/v7+fPHFFzRv3py9e/dStGjRZK8zdOhQPvzwQz766CPGjx/PCy+8wJEjR8iVK1eS46Ojoxk9ejRffvklZrOZtm3b0qdPH2bPng3AqFGjmD17NjNmzKB8+fJ8+umnLFiwgAYNGiQbw4EDB2jSpAkjRoxg+vTpnD17lu7du9O9e3dmzJiRMG706NEMHjyYIUOGALBmzRqio6MZNWoUn3/+Oblz5yZfvnw8//zz7Nu3j4ULFxISEkK/fv147LHH2LVrF76+vgmv4/Z56YkyJSKp9Hz/J1k+81eO7j7BgvFLeKZ3c2+HlC6cOXoOgPzF8jg/ec8emDwZVq+GK1cgJAQaNIAuXeCee9wcqYhI2rJYLBQuXJjjx49z+PDhFMcGBQVRtGhRnaUrbqIEpIiIpF/Dhg2jcePGCfdz5cpFlSpVEu4PHz6cH3
74gYULF9K9e/dkr/Pyyy/z/PPPA/D+++8zbtw4Nm7cSJMmTZIcHxcXx5QpUyhVqhQA3bt3Z9iwYQnPjx8/nv79+/P
"text/plain": [
"<Figure size 1600x800 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"draw_regularization_example(X, Y)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Regularization\n",
"\n",
"Regularization is a method of preventing overfitting by modifying the cost function.\n",
"\n",
"A special term is added to the cost function (the **regularization term**, marked in red in the formulas below) that acts as a “penalty” for extreme values of the parameters $\\theta$.\n",
"\n",
"As a result, vectors $\\theta$ with smaller parameter values automatically get a lower cost and are therefore preferred.\n",
"\n",
"How strong should the regularization be? We decide this by choosing an appropriate value of the **regularization parameter** $\\lambda$."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The regularization method presented here is the so-called L2 (*ridge*) method. There are also other regularization methods with somewhat different properties, e.g. L1 (*lasso*) or *elastic net*. More on this topic can be found, e.g., here:\n",
"* [L1 and L2 Regularization Methods](https://towardsdatascience.com/l1-and-l2-regularization-methods-ce25e7fc831c)\n",
"* [Ridge and Lasso Regression: L1 and L2 Regularization](https://towardsdatascience.com/ridge-and-lasso-regression-a-complete-guide-with-python-scikit-learn-e20e34bcbf0b)\n",
"* [Elastic Net Regression](https://towardsdatascience.com/elastic-net-regression-from-sklearn-to-tensorflow-3b48eee45e91)"
]
},
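{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A minimal NumPy sketch of the two penalty terms (the function names `l2_penalty` and `l1_penalty` are hypothetical, not part of this notebook). L2 (*ridge*) sums the squared parameters, L1 (*lasso*) sums their absolute values; in both cases $\\theta_0$ is left out and the $\\frac{1}{2m}$ scaling is omitted:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def l2_penalty(theta, lamb):\n",
"    # Ridge (L2): lamb * sum of squared parameters, skipping theta_0\n",
"    return lamb * np.sum(theta[1:] ** 2)\n",
"\n",
"\n",
"def l1_penalty(theta, lamb):\n",
"    # Lasso (L1): lamb * sum of absolute values, skipping theta_0\n",
"    return lamb * np.sum(np.abs(theta[1:]))\n",
"\n",
"\n",
"theta = np.array([5.0, 0.5, -2.0, 0.1])\n",
"print(l2_penalty(theta, 1.0))  # 0.5**2 + 2.0**2 + 0.1**2, about 4.26\n",
"print(l1_penalty(theta, 1.0))  # 0.5 + 2.0 + 0.1, about 2.6\n",
"```\n",
"\n",
"The pull of the L2 penalty towards zero weakens as a parameter gets small (the derivative of $\\theta_j^2$ is $2\\theta_j$), while the pull of the L1 penalty stays constant; this is why L1 tends to drive some parameters exactly to zero."
]
},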
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for linear regression: cost function\n",
"\n",
"$$\n",
"J(\\theta) \\, = \\, \\dfrac{1}{2m} \\left( \\displaystyle\\sum_{i=1}^{m} \\left( h_\\theta(x^{(i)}) - y^{(i)} \\right)^2 \\color{red}{ + \\lambda \\displaystyle\\sum_{j=1}^{n} \\theta^2_j } \\right)\n",
"$$"
]
},
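{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The regularized cost above can be computed in vectorized form. A minimal sketch, assuming plain NumPy arrays (the notebook itself uses `np.matrix`) and a design matrix `X` whose first column is all ones; `ridge_cost` is a hypothetical name:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def ridge_cost(theta, X, y, lamb):\n",
"    # X: (m, n+1) design matrix whose first column is all ones (x_0 = 1)\n",
"    m = len(y)\n",
"    residuals = X @ theta - y\n",
"    penalty = lamb * np.sum(theta[1:] ** 2)  # theta_0 is not regularized\n",
"    return (np.sum(residuals ** 2) + penalty) / (2 * m)\n",
"\n",
"\n",
"X = np.array([[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])\n",
"y = np.array([2.0, 4.0, 6.0])\n",
"theta = np.array([0.0, 2.0])  # fits the data exactly\n",
"print(ridge_cost(theta, X, y, lamb=0.0))  # 0.0\n",
"print(ridge_cost(theta, X, y, lamb=3.0))  # (0 + 3 * 2**2) / (2 * 3) = 2.0\n",
"```\n",
"\n",
"For a perfect fit the first term vanishes, so with $\\lambda > 0$ the whole cost comes from the penalty on $\\theta_1$."
]
},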
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"* $\\lambda$ is the regularization parameter\n",
"* if $\\lambda$ is too small, the model overfits\n",
"* if $\\lambda$ is too large, the model underfits"
]
},
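{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"To see why a large $\\lambda$ leads to underfitting, it helps to look at linear regression, where the minimum of the regularized cost can be found in closed form: setting its gradient to zero gives $(X^T X + \\lambda L)\\,\\theta = X^T y$, where $L$ is the identity matrix with $L_{00} = 0$ (because $\\theta_0$ is not penalized). This regularized normal equation is not used elsewhere in this notebook, which trains with SGD; the sketch below uses synthetic data and hypothetical names:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def ridge_normal_equation(X, y, lamb):\n",
"    # Solves (X^T X + lamb * L) theta = X^T y, where L is the identity\n",
"    # matrix with L[0, 0] = 0 so that theta_0 is not penalized\n",
"    L = np.eye(X.shape[1])\n",
"    L[0, 0] = 0.0\n",
"    return np.linalg.solve(X.T @ X + lamb * L, X.T @ y)\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x = rng.uniform(-1, 1, size=20)\n",
"X = np.vander(x, 6, increasing=True)  # columns 1, x, x^2, ..., x^5\n",
"y = np.sin(3 * x) + rng.normal(scale=0.1, size=20)\n",
"\n",
"norms = []\n",
"for lamb in [0.0, 0.1, 1.0, 10.0, 100.0]:\n",
"    theta = ridge_normal_equation(X, y, lamb)\n",
"    norms.append(np.linalg.norm(theta[1:]))\n",
"    print(f'lambda={lamb:>6}: ||theta[1:]|| = {norms[-1]:.3f}')\n",
"```\n",
"\n",
"The norm of $(\\theta_1, \\ldots, \\theta_n)$ shrinks as $\\lambda$ grows; in the limit the model reduces to the constant $\\theta_0$, i.e. it underfits."
]
},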
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for linear regression: gradient\n",
"\n",
"$$\\small\n",
"\\begin{array}{llll}\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_0} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_0 & \\textrm{for $j = 0$}\\\\\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_j} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_j \\color{red}{+ \\dfrac{\\lambda}{m}\\theta_j} & \\textrm{for $j = 1, 2, \\ldots, n$} \\\\\n",
"\\end{array} \n",
"$$"
]
},
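{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A standard way to check an analytic gradient like the one above is to compare it with central finite differences of the cost. A minimal sketch for the linear-regression case (all names are hypothetical, the data is random):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def ridge_cost(theta, X, y, lamb):\n",
"    m = len(y)\n",
"    r = X @ theta - y\n",
"    return (np.sum(r ** 2) + lamb * np.sum(theta[1:] ** 2)) / (2 * m)\n",
"\n",
"\n",
"def ridge_grad(theta, X, y, lamb):\n",
"    m = len(y)\n",
"    g = X.T @ (X @ theta - y) / m\n",
"    g[1:] += lamb / m * theta[1:]  # no penalty term for j = 0\n",
"    return g\n",
"\n",
"\n",
"def numerical_grad(f, theta, eps=1e-6):\n",
"    # Central differences, one coordinate at a time\n",
"    g = np.zeros_like(theta)\n",
"    for j in range(len(theta)):\n",
"        e = np.zeros_like(theta)\n",
"        e[j] = eps\n",
"        g[j] = (f(theta + e) - f(theta - e)) / (2 * eps)\n",
"    return g\n",
"\n",
"\n",
"rng = np.random.default_rng(1)\n",
"X = np.column_stack([np.ones(10), rng.normal(size=(10, 2))])\n",
"y = rng.normal(size=10)\n",
"theta = rng.normal(size=3)\n",
"\n",
"analytic = ridge_grad(theta, X, y, lamb=0.5)\n",
"numeric = numerical_grad(lambda t: ridge_cost(t, X, y, 0.5), theta)\n",
"print(np.max(np.abs(analytic - numeric)))  # should be close to 0\n",
"```\n",
"\n",
"If the $\\frac{\\lambda}{m}\\theta_j$ term were also added for $j = 0$, the first component of the two gradients would no longer agree."
]
},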
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for logistic regression: cost function\n",
"\n",
"$$\n",
"\\begin{array}{rcl}\n",
"J(\\theta) & = & -\\dfrac{1}{m} \\left( \\displaystyle\\sum_{i=1}^{m} y^{(i)} \\log h_\\theta(x^{(i)}) + \\left( 1-y^{(i)} \\right) \\log \\left( 1-h_\\theta(x^{(i)}) \\right) \\right) \\\\\n",
"& & \\color{red}{ + \\dfrac{\\lambda}{2m} \\displaystyle\\sum_{j=1}^{n} \\theta^2_j } \\\\\n",
"\\end{array}\n",
"$$"
]
},
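{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A minimal NumPy sketch of the regularized logistic cost above (hypothetical names, plain arrays instead of `np.matrix`). Predictions are clipped to $[\\varepsilon, 1 - \\varepsilon]$ so that the logarithm is never taken of 0; this is an assumption of the sketch:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def sigmoid(z):\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"\n",
"def logistic_cost_reg(theta, X, y, lamb, eps=1e-7):\n",
"    # Regularized cross-entropy cost; predictions are clipped to avoid log(0)\n",
"    m = len(y)\n",
"    h = np.clip(sigmoid(X @ theta), eps, 1 - eps)\n",
"    data_term = -np.sum(y * np.log(h) + (1 - y) * np.log(1 - h)) / m\n",
"    return data_term + lamb / (2 * m) * np.sum(theta[1:] ** 2)  # theta_0 not penalized\n",
"\n",
"\n",
"X = np.array([[1.0, -1.0], [1.0, 1.0]])\n",
"y = np.array([0.0, 1.0])\n",
"theta_big = np.array([0.0, 2.0])\n",
"\n",
"print(logistic_cost_reg(np.zeros(2), X, y, lamb=1.0))  # log(2), about 0.693\n",
"print(logistic_cost_reg(theta_big, X, y, lamb=0.0))    # about 0.127: better fit\n",
"print(logistic_cost_reg(theta_big, X, y, lamb=1.0))    # about 1.127: penalty of 1.0 added\n",
"```\n",
"\n",
"With $\\lambda = 0$ the larger parameter vector wins because it separates the two points better; with $\\lambda = 1$ its penalty outweighs that gain and the all-zero vector has the lower cost."
]
},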
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for logistic regression: gradient\n",
"\n",
"$$\\small\n",
"\\begin{array}{llll}\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_0} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_0 & \\textrm{for $j = 0$}\\\\\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_j} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_j \\color{red}{+ \\dfrac{\\lambda}{m}\\theta_j} & \\textrm{for $j = 1, 2, \\ldots, n$} \\\\\n",
"\\end{array} \n",
"$$"
]
},
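{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The gradient has the same form as for linear regression; only $h_\\theta$ changes to the sigmoid. A minimal sketch with a finite-difference check (hypothetical names, random data):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def sigmoid(z):\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"\n",
"def logistic_cost_reg(theta, X, y, lamb):\n",
"    m = len(y)\n",
"    h = sigmoid(X @ theta)\n",
"    data_term = -np.sum(y * np.log(h) + (1 - y) * np.log(1 - h)) / m\n",
"    return data_term + lamb / (2 * m) * np.sum(theta[1:] ** 2)\n",
"\n",
"\n",
"def logistic_grad_reg(theta, X, y, lamb):\n",
"    m = len(y)\n",
"    g = X.T @ (sigmoid(X @ theta) - y) / m\n",
"    g[1:] += lamb / m * theta[1:]  # no penalty for j = 0\n",
"    return g\n",
"\n",
"\n",
"rng = np.random.default_rng(2)\n",
"X = np.column_stack([np.ones(20), rng.normal(size=(20, 3))])\n",
"y = (rng.uniform(size=20) > 0.5).astype(float)\n",
"theta = rng.normal(scale=0.5, size=4)\n",
"lamb = 1.0\n",
"\n",
"# Central differences along each unit vector e\n",
"eps = 1e-6\n",
"numeric = np.array([\n",
"    (logistic_cost_reg(theta + eps * e, X, y, lamb)\n",
"     - logistic_cost_reg(theta - eps * e, X, y, lamb)) / (2 * eps)\n",
"    for e in np.eye(len(theta))\n",
"])\n",
"analytic = logistic_grad_reg(theta, X, y, lamb)\n",
"print(np.max(np.abs(analytic - numeric)))  # should be close to 0\n",
"```"
]
},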
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Implementing regularization"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"def J_(h, theta, X, y, lamb=0):\n",
"    \"\"\"Regularized cost function (logistic regression)\"\"\"\n",
"    m = float(len(y))\n",
"    f = h(theta, X, eps=10**-7)\n",
"    j = 1.0 / m * -np.sum(\n",
"        np.multiply(y, np.log(f)) + np.multiply(1 - y, np.log(1 - f)), axis=0\n",
"    ) + lamb / (2 * m) * np.sum(np.power(theta[1:], 2))  # theta_0 is not regularized\n",
"    return j\n",
"\n",
"\n",
"def dJ_(h, theta, X, y, lamb=0):\n",
"    \"\"\"Gradient of the regularized cost function (logistic regression)\"\"\"\n",
"    m = float(y.shape[0])\n",
"    g = 1.0 / m * (X.T * (h(theta, X) - y))\n",
"    g[1:] += lamb / m * theta[1:]  # theta_0 is not regularized\n",
"    return g\n"
]
},
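{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Since the regularized gradient contains the extra term $\\frac{\\lambda}{m}\\theta_j$, one gradient-descent step $\\theta_j := \\theta_j - \\alpha \\frac{\\partial J}{\\partial \\theta_j}$ can be rewritten, for $j \\geq 1$, as $\\theta_j := \\theta_j \\left(1 - \\alpha\\frac{\\lambda}{m}\\right) - \\frac{\\alpha}{m}\\sum_{i=1}^m \\left( h_\\theta(x^{(i)}) - y^{(i)} \\right) x^{(i)}_j$: each step first shrinks the parameters slightly and then takes the ordinary, unregularized step (often called *weight decay*). The sketch below checks numerically that both forms agree, using a linear hypothesis; it is independent of `J_`/`dJ_` and all names in it are hypothetical:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def step_regularized_gradient(theta, X, y, alpha, lamb):\n",
"    # theta := theta - alpha * (regularized gradient)\n",
"    m = len(y)\n",
"    g = X.T @ (X @ theta - y) / m\n",
"    g[1:] += lamb / m * theta[1:]\n",
"    return theta - alpha * g\n",
"\n",
"\n",
"def step_weight_decay(theta, X, y, alpha, lamb):\n",
"    # Shrink theta_1..theta_n by (1 - alpha * lamb / m), then take the plain step\n",
"    m = len(y)\n",
"    g = X.T @ (X @ theta - y) / m\n",
"    shrink = np.ones_like(theta)\n",
"    shrink[1:] = 1 - alpha * lamb / m  # theta_0 is not shrunk\n",
"    return shrink * theta - alpha * g\n",
"\n",
"\n",
"rng = np.random.default_rng(3)\n",
"X = np.column_stack([np.ones(8), rng.normal(size=(8, 2))])\n",
"y = rng.normal(size=8)\n",
"theta = rng.normal(size=3)\n",
"a = step_regularized_gradient(theta, X, y, alpha=0.1, lamb=2.0)\n",
"b = step_weight_decay(theta, X, y, alpha=0.1, lamb=2.0)\n",
"print(np.allclose(a, b))  # True\n",
"```"
]
},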
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"slider_lambda = widgets.FloatSlider(\n",
"    min=0.0,\n",
"    max=0.5,\n",
"    step=0.005,\n",
"    value=0.01,\n",
"    description=r\"$\\lambda$\",\n",
"    layout=widgets.Layout(width=\"300px\"),  # the width is set through the layout\n",
")\n",
"\n",
"\n",
"def slide_regularization_example_2(lamb):\n",
" draw_regularization_example(X, Y, lamb=lamb)\n"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"scrolled": false,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "77f5c3dccbd04409b7873f54bfc535eb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=0.01, description='$\\\\lambda$', max=0.5, step=0.005), Button(descripti…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.slide_regularization_example_2(lamb)>"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact_manual(slide_regularization_example_2, lamb=slider_lambda)\n"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def cost_lambda_fun(lamb):\n",
"    \"\"\"Final training and validation cost for a given regularization parameter lambda\"\"\"\n",
" theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1], 1)\n",
" thetaBest, err = SGD(\n",
" h,\n",
" J,\n",
" dJ,\n",
" theta,\n",
" X,\n",
" Y,\n",
" alpha=1,\n",
" adaGrad=True,\n",
" maxEpochs=2500,\n",
" batchSize=100,\n",
" logError=True,\n",
" validate=0.25,\n",
" valStep=1,\n",
" lamb=lamb,\n",
" )\n",
" return err[1][-1], err[3][-1]\n",
"\n",
"\n",
"def plot_cost_lambda():\n",
"    \"\"\"Plots the cost as a function of the regularization parameter lambda\"\"\"\n",
" plt.figure(figsize=(16, 8))\n",
" ax = plt.subplot(111)\n",
" Lambda = np.arange(0.0, 1.0, 0.01)\n",
" Costs = [cost_lambda_fun(lamb) for lamb in Lambda]\n",
" CostTrain = [cost[0] for cost in Costs]\n",
" CostCV = [cost[1] for cost in Costs]\n",
" plt.plot(Lambda, CostTrain, lw=3, label=\"training error\")\n",
" plt.plot(Lambda, CostCV, lw=3, label=\"validation error\")\n",
" ax.set_xlabel(r\"$\\lambda$\")\n",
" ax.set_ylabel(\"cost\")\n",
" plt.legend()\n",
" plt.ylim(0.2, 0.8)\n"
]
},
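{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The same procedure (train for several values of $\\lambda$ and keep the one with the lowest validation cost) can be sketched without SGD, using the closed-form solution of regularized linear regression, $(X^T X + \\lambda L)\\,\\theta = X^T y$ with $L$ the identity matrix except $L_{00} = 0$, on synthetic polynomial data. All names and data here are hypothetical:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def ridge_fit(X, y, lamb):\n",
"    L = np.eye(X.shape[1])\n",
"    L[0, 0] = 0.0  # theta_0 is not penalized\n",
"    return np.linalg.solve(X.T @ X + lamb * L, X.T @ y)\n",
"\n",
"\n",
"def half_mse(theta, X, y):\n",
"    # Unregularized cost 1/(2m) * sum of squared errors\n",
"    return np.mean((X @ theta - y) ** 2) / 2\n",
"\n",
"\n",
"rng = np.random.default_rng(42)\n",
"x = rng.uniform(-1, 1, size=60)\n",
"X = np.vander(x, 10, increasing=True)  # degree-9 polynomial features\n",
"y = np.sin(3 * x) + rng.normal(scale=0.2, size=60)\n",
"X_tr, y_tr, X_val, y_val = X[:45], y[:45], X[45:], y[45:]\n",
"\n",
"lambdas = np.logspace(-6, 2, 17)\n",
"train_cost, val_cost = [], []\n",
"for lamb in lambdas:\n",
"    theta = ridge_fit(X_tr, y_tr, lamb)\n",
"    train_cost.append(half_mse(theta, X_tr, y_tr))\n",
"    val_cost.append(half_mse(theta, X_val, y_val))\n",
"\n",
"best = lambdas[int(np.argmin(val_cost))]\n",
"print('lambda with the lowest validation cost:', best)\n",
"```\n",
"\n",
"The training cost can only grow with $\\lambda$, while the validation cost typically first falls and then rises; the $\\lambda$ at its minimum is the one worth keeping."
]
},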
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABR8AAAKtCAYAAACuZBksAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAACCYklEQVR4nOzdeZxbdaH//3eSmWTWzL53SksppaUbdLOgglgsogiuKCjLleIGwq3cC70KiCiVy/KtWhQ3QBR+cN0QL8hiBS7USqHQ0kI3SvfOPp1k1mQmOb8/TtaZ6TIzOZOZ5PV8PHKTnCQnn7lSSl/9fM7HZhiGIQAAAAAAAABIMHuyBwAAAAAAAAAgNREfAQAAAAAAAFiC+AgAAAAAAADAEsRHAAAAAAAAAJYgPgIAAAAAAACwBPERAAAAAAAAgCWIjwAAAAAAAAAsQXwEAAAAAAAAYAniIwAAAAAAAABLEB8BAAAAAAAAWCLp8fG+++7TpEmTlJWVpUWLFmn9+vVHff+qVas0bdo0ZWdnq7a2Vv/+7/+unp6eURotAAAAAAAAgOOV1Pj4+OOPa/ny5br11lv1xhtvaM6cOVq6dKkaGxsHff+jjz6qm266Sbfeequ2bt2qX//613r88cf1X//1X6M8cgAAAAAAAADHYjMMw0jWly9atEgLFizQ6tWrJUnBYFC1tbW69tprddNNNw14/zXXXKOtW7dqzZo1kWPf+ta39Oqrr+qVV14ZtXEDAAAAAAAAOLaMZH2x3+/Xhg0btGLFisgxu92uJUuWaN26dYN+5owzztDvfvc7rV+/XgsXLtR7772np59+Wl/60peO+D0+n08+ny/yPBgMqrW1VSUlJbLZbIn7gQAAAAAAAIA0YBiG2tvbVV1dLbv96AurkxYfm5ubFQgEVFFREXe8oqJC27ZtG/Qzl1xyiZqbm/X+979fhmGor69PX/3qV4+67HrlypW67bbbEjp2AAAAAAAAIN3t379fEyZMOOp7khYfh+PFF1/UHXfcoZ/+9KdatGiR3n33XV133XW6/fbbdfPNNw/6mRUrVmj58uWR5x6PRxMnTtT+/fvldrtHa+gAAAAAAABASvB6vaqtrVV+fv4x35u0+FhaWiqHw6GGhoa44w0NDaqsrBz0MzfffLO+9KUv6aqrrpIkzZo1S52dnbr66qv17W9/e9Bpni6XSy6Xa8Bxt9tNfAQAAAAAAACG6XguaZi03a6dTqfmzZsXt3lMMBjUmjVrtHjx4kE/09XVNSAwOhwOSeZacwAAAAAAAABjR1KXXS9fvlyXX3655s+fr4ULF2rVqlXq7OzUlVdeKUm67LLLVFNTo5UrV0qSLrjgAt1777067bTTIsuub775Zl1wwQWRCAkAAAAAAABgbEhqfLz44ovV1NSkW265RfX19Zo7d66eeeaZyCY0+/bti5vp+J3vfEc2m03f+c53dPDgQZWVlemCCy7QD37wg2T9CAAAAAAAAACOwGak2Xplr9ergoICeTwervkIAAAAAADGFcMw1NfXp0AgkOyhIMVlZmYecaXxUPrauNrtGgAAAAAAIF35/X7V1dWpq6sr2UNBGrDZbJowYYLy8vJGdB7iIwAAAAAAwBgXDAa1e/duORwOVVdXy+l0HtdOw8BwGIahpqYmHThwQFOnTh3RXivERwAAAAAAgDHO7/crGAyqtrZWOTk5yR4O0kBZWZn27Nmj3t7eEcVH+7HfAgAAAAAAgLEgdmNewEqJmlnLP7EAAAAAAAAALEF8BAAAAAAAAGAJ4iMAAAAAAADGjUmTJmnVqlXH/f4XX3xRNptNbW1tlo0JR8aGMwAAAAAAALDM2Wefrblz5w4pGB7Na6+9ptzc3ON+/xlnnKG6ujoVFBQk5PsxNMRHAAAAAACAcSYYNHS4y5/UMRTlOGW3J2ZTEsMwFAgElJFx7FRVVlY2pHM7nU5VVlYOd2iW8vv9cjqdcccCgYBsNtuQNxca7uesRn
wEAAAAAAAYZw53+TXv+39P6hg2fGeJSvJcR33PFVdcoZdeekkvvfSSfvSjH0mSdu/erT179uhDH/qQnn76aX3nO9/R5s2b9dxzz6m2tlbLly/Xv/71L3V2dmr69OlauXKllixZEjnnpEmTdP311+v666+XZO7K/Mtf/lJPPfWUnn32WdXU1Oiee+7RJz7xCUnmsusPfehDOnz4sAoLC/XQQw/p+uuv1+OPP67rr79e+/fv1/vf/349+OCDqqqqkiT19fVp+fLlevjhh+VwOHTVVVepvr5eHo9HTzzxxBF/3ldeeUUrVqzQ66+/rtLSUn3yk5/UypUrIzM1J02apC9/+cvauXOnnnjiCX3qU5/S2Wefreuvv14PP/ywbrrpJu3YsUPvvvuuCgoKdN111+mvf/2rfD6fzjrrLP34xz/W1KlTJSnyc/T/3KRJk4bzP6dlxlYKBQAAAAAAQMr40Y9+pMWLF2vZsmWqq6tTXV2damtrI6/fdNNN+uEPf6itW7dq9uzZ6ujo0Pnnn681a9bozTff1HnnnacLLrhA+/btO+r33Hbbbfrc5z6nt956S+eff74uvfRStba2HvH9XV1duvvuu/Xb3/5W//d//6d9+/bphhtuiLx+55136pFHHtGDDz6otWvXyuv1HjU6StKuXbt03nnn6dOf/rTeeustPf7443rllVd0zTXXxL3v7rvv1pw5c/Tmm2/q5ptvjoznzjvv1K9+9Su9/fbbKi8v1xVXXKHXX39dTz75pNatWyfDMHT++eert7c37ufo/7mxhpmPAAAAAAAAsERBQYGcTqdycnIGXfr8ve99T+eee27keXFxsebMmRN5fvvtt+vPf/6znnzyyQERL9YVV1yhL3zhC5KkO+64Qz/+8Y+1fv16nXfeeYO+v7e3V/fff7+mTJkiSbrmmmv0ve99L/L6T37yE61YsUKf/OQnJUmrV6/W008/fdSfdeXKlbr00ksjMzKnTp2qH//4xzrrrLP0s5/9TFlZWZKkc845R9/61rcin3v55ZfV29urn/70p5GffefOnXryySe1du1anXHGGZKkRx55RLW1tXriiSf02c9+NvJzxH5uLCI+AgAAAAAAICnmz58f97yjo0Pf/e539dRTT6murk59fX3q7u4+5szH2bNnRx7n5ubK7XarsbHxiO/PycmJhEdJqqqqirzf4/GooaFBCxcujLzucDg0b948BYPBI55z06ZNeuutt/TII49EjhmGoWAwqN27d2v69OmD/sySeV3K2J9h69atysjI0KJFiyLHSkpKNG3aNG3duvWInxuLiI8AAAAAAADjTFGOUxu+s+TYb7R4DCPVf9fqG264Qc8//7zuvvtunXTSScrOztZnPvMZ+f1H31wnMzMz7rnNZjtqKBzs/YZhDHH08To6OvSVr3xF3/zmNwe8NnHixMjjwXbqzs7Ols029M17hvu50UR8BAAAAAAAGGfsdtsxN3sZK5xOpwKBwHG9d+3atbriiisiy507Ojq0Z88eC0c3UEFBgSoqKvTaa6/pgx/8oCRzJ+k33nhDc+fOPeLnTj/9dL3zzjs66aSTRjyG6dOnq6+vT6+++mpk2XVLS4u2b9+uGTNmjPj8o4kNZwAAAAAAAGCZSZMm6dVXX9WePXvU3Nx81BmJU6dO1Z/+9Cdt3LhRmzZt0iWXXHLU91vl2muv1cqVK/WXv/xF27dv13XXXafDhw8fdZbhjTfeqH/+85+65pprtHHjRu3cuVN/+ctfjnqtyiOZOnWqLrzwQi1btkyvvPKKNm3apC9+8YuqqanRhRdeOJIfbdQRHwEAAAAAAGCZG264QQ6HQzNmzFBZWdlRr9947733qqioSGeccYYuuOACLV26VKeffvoojtZ044036gtf+IIuu+wyLV68WHl5eVq6dGlk05jBzJ49Wy+99JJ27NihD3zgAzrttNN0yy23qLq6elhjePDBBzVv3jx9/OMf1+LFi2UYhp5++ukBS8bHOp
sx0gXt44zX61VBQYE8Ho/cbneyhwMAAAAAAHBMPT092r17tyZPnnzUAAZrBINBTZ8+XZ/73Od0++23J3s4o+Jo/8w
"text/plain": [
"<Figure size 1600x800 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_cost_lambda()\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 6.4. Learning curve"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"* A learning curve lets us check whether training is proceeding correctly.\n",
"* A learning curve is a plot of the cost function value against the size of the training set.\n",
"* As the training set grows, the cost on the training set typically increases.\n",
"* As the training set grows, the cost on the validation set typically decreases."
]
},
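{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A learning curve can be computed with any training procedure: fit the model on the first $m$ training examples for increasing $m$ and record the cost on those $m$ examples and on a fixed validation set. A minimal sketch using ordinary least squares instead of the notebook's SGD, on synthetic data; a straight line is deliberately fitted to $\\sin(3x)$ so that the model has high bias. Names and data are hypothetical:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def fit_least_squares(X, y):\n",
"    theta, *_ = np.linalg.lstsq(X, y, rcond=None)\n",
"    return theta\n",
"\n",
"\n",
"def half_mse(theta, X, y):\n",
"    # Same scaling as the linear-regression cost: 1/(2m) * sum of squared errors\n",
"    return np.mean((X @ theta - y) ** 2) / 2\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x = rng.uniform(-1, 1, size=200)\n",
"X = np.column_stack([np.ones_like(x), x])  # a straight line: deliberately too simple\n",
"y = np.sin(3 * x) + rng.normal(scale=0.1, size=200)\n",
"X_train, y_train, X_val, y_val = X[:150], y[:150], X[150:], y[150:]\n",
"\n",
"sizes = [2, 5, 10, 20, 50, 100, 150]\n",
"train_err, val_err = [], []\n",
"for m in sizes:\n",
"    theta = fit_least_squares(X_train[:m], y_train[:m])  # train on the first m examples only\n",
"    train_err.append(half_mse(theta, X_train[:m], y_train[:m]))\n",
"    val_err.append(half_mse(theta, X_val, y_val))\n",
"    print(f'm={m:4d}  train={train_err[-1]:.3f}  val={val_err[-1]:.3f}')\n",
"```\n",
"\n",
"With only 2 examples the line passes through both points, so the training cost is zero; as $m$ grows the training cost rises and the validation cost generally falls, and for a model this simple both should level off at a similar, fairly high value, the pattern typical of underfitting."
]
},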
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def cost_trainsetsize_fun(m):\n",
"    \"\"\"Final training and validation cost for a given training set size\"\"\"\n",
" theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1], 1)\n",
" thetaBest, err = SGD(\n",
" h,\n",
" J,\n",
" dJ,\n",
" theta,\n",
" X,\n",
" Y,\n",
" alpha=1,\n",
" adaGrad=True,\n",
" maxEpochs=2500,\n",
" batchSize=100,\n",
" logError=True,\n",
" validate=0.25,\n",
" valStep=1,\n",
" lamb=0.01,\n",
" trainsetsize=m,\n",
" )\n",
" return err[1][-1], err[3][-1]\n",
"\n",
"\n",
"def plot_learning_curve():\n",
"    \"\"\"Plots the learning curve\"\"\"\n",
" plt.figure(figsize=(16, 8))\n",
" ax = plt.subplot(111)\n",
" M = np.arange(0.3, 1.0, 0.05)\n",
" Costs = [cost_trainsetsize_fun(m) for m in M]\n",
" CostTrain = [cost[0] for cost in Costs]\n",
" CostCV = [cost[1] for cost in Costs]\n",
" plt.plot(M, CostTrain, lw=3, label=\"training error\")\n",
" plt.plot(M, CostCV, lw=3, label=\"validation error\")\n",
" ax.set_xlabel(\"trainset size\")\n",
" ax.set_ylabel(\"cost\")\n",
" plt.legend()\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Learning curve vs. bias and variance\n",
"\n",
"Plotting a learning curve helps diagnose overfitting and underfitting:\n",
"\n",
"<img width=\"100%\" src=\"learning-curves.png\"/>\n",
"\n",
"Source: http://www.ritchieng.com/machinelearning-learning-curve"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABSgAAAKnCAYAAACF0KMLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAACNQElEQVR4nOzdd3xb9b3/8bckW/J2vDdkEkLITkgTShkNJKHNJdCyyr0hlHE7gHJTCuR3bxmlbS4t7WWWUUagFy7QEgK90DBygZYQCFnMEMgi8YzjxHY8ZUv6/SFb9pG8LelI8uv5eOghne/5nqOPghPst7/D4vF4PAIAAAAAAAAAE1jNLgAAAAAAAADAyEVACQAAAAAAAMA0BJQAAAAAAAAATENACQAAAAAAAMA0BJQAAAAAAAAATENACQAAAAAAAMA0BJQAAAAAAAAATENACQAAAAAAAMA0cWYXEIncbrfKy8uVmpoqi8VidjkAAAAAAABAVPF4PDp69KgKCwtltfY9RpKAsgfl5eUqKSkxuwwAAAAAAAAgqh04cEDFxcV99iGg7EFqaqok7x9gWlqaydUAAAAAAAAA0aW+vl4lJSW+nK0vBJQ96JzWnZaWRkAJAAAAAAAADNFAlk9kkxwAAAAAAAAApiGgBAAAAAAAAGAaAkoAAAAAAAAApmENSgAAAAAAgBji8XjU3t4ul8tldimIYTabTXFxcQNaY7I/BJQAAAAAAAAxwul0qqKiQk1NTWaXghEgKSlJBQUFstvtw7oPASUAAAAAAEAMcLvd2rt3r2w2mwoLC2W324Myug3w5/F45HQ6VV1drb1792rChAmyWoe+kiQBJQAAAAAAQAxwOp1yu90qKSlRUlKS2eUgxiUmJio+Pl5fffWVnE6nEhIShnwvNskBAAAAAACIIcMZyQYMRrC+1viKBQAAAAAAAGAaAkoAAAAAAADElNGjR+uuu+4acP+33npLFotFtbW1IasJvWMNSgAAAAAAAJjqtNNO0/Tp0wcVKvblgw8+UHJy8oD7z58/XxUVFUpPTw/K+2NwGEEJAAAAAACAiOfxeNTe3j6gvjk5OYPaKMhutys/Pz8idz13Op0BbS6XS263e9D3Gup1oUZACQAAAAAAEIPcbo9qGlpNfbjdnn7rXL58ud5++23dfffdslgsslgs2rdvn2/a9d/+9jfNmjVLDodD77zzjnbv3q1zzjlHeXl5SklJ0Zw5c/TGG28Y7uk/xdtiseiRRx7Rueeeq6SkJE2YMEEvvfSS77z/FO/Vq1dr1KhRevXVVzVp0iSlpKRo0aJFqqio8F3T3t6ua6+9VqNGjVJWVpZuvPFGXXrppVq6dGmfn/edd97RKaecosTERJWUlOjaa69VY2Ojofbbb79dy5YtU1pamq666ipfPS+99JJOOOEEORwO7d+/X0eOHNGyZcuUkZGhpKQkLV68WF9++aXvXr1dF2mY4g0AAAAAABCDjjQ5NeuXb/TfMYS2/McCZaU4+uxz991364svvtCJJ56oX/ziF5K8IyD37dsnSbrpppt05513auzYscrIyNCBAwd09tln61e/+pUcDoeefPJJLVmyRDt37tQxxxzT6/vcdttt+s1vfqPf/va3uvfee3XJJZfoq6++UmZmZo/9m5qadOedd+pPf/qTrFar/vmf/1nXX3+9nnrqKUnSHXfcoaeeekqPP/64Jk2apLvvvltr167V6aef3msNu3fv1qJFi/TLX/5Sjz32mKqrq3X11Vfr6quv1uOPP+7rd+edd+rmm2/WLbfcIkn6xz/+oaamJt1xxx165JFHlJWVpdzcXF188cX68ssv9dJLLyktLU033nijzj77bH322WeKj4/3fQ7/6yINASUAAAAAAABMk56eLrvdrqSkJOXn5wec/8UvfqEzzzzTd5yZmalp06b5jm+//Xa98MILeumll3T11Vf3+j7Lly/XxRdfLEn69a9/rXvuuUebNm3SokWLeu
zf1tamBx98UOPGjZMkXX311b4AVZLuvfderVy5Uueee64k6b777tMrr7zS52ddtWqVLrnkEl133XWSpAkTJuiee+7RqaeeqgceeEAJCQmSpDPOOEM//elPfdf94x//UFtbm/7whz/4PntnMLlhwwbNnz9fkvTUU0+ppKREa9eu1fnnn+/7HN2vi0QElAAAAAAAAIhYs2fPNhw3NDTo1ltv1csvv6yKigq1t7erubm536nLU6dO9b1OTk5WWlqaDh482Gv/pKQkXzgpSQUFBb7+dXV1qqqq0kknneQ7b7PZNGvWrD7XePzwww/10Ucf+UZhSt61Nd1ut/bu3atJkyb1+Jkl7zqZ3T/Djh07FBcXp7lz5/rasrKyNHHiRO3YsaPX6yIRASUAAAAAAAAilv9u3Ndff71ef/113XnnnRo/frwSExP13e9+t8fNZLrrnPLcyWKx9Bkm9tTf4+l/Tc2+NDQ06F//9V917bXXBpzrPj29px3IExMTh7SJz1CvCycCSgAAAAAAgBiUkWTXlv9YYHoNA2G32+VyuQbUd8OGDVq+fLlvanVDQ4NvvcpwSU9PV15enj744AN94xvfkOTdIXvr1q2aPn16r9fNnDlTn332mcaPHz/sGiZNmqT29na9//77vineNTU12rlzp0444YRh3z+cCCgBAAAAAABikNVq6XeDmkgxevRovf/++9q3b59SUlJ63bhG8q7buGbNGi1ZskQWi0U///nP+xwJGSrXXHONVq1apfHjx+v444/XvffeqyNHjvQ5WvHGG2/U1772NV199dW64oorlJycrM8++0yvv/667rvvvkG9/4QJE3TOOefoyiuv1EMPPaTU1FTddNNNKioq0jnnnDPcjxdWVrMLAAAAAAAAwMh2/fXXy2az6YQTTlBOTk6f60n+/ve/V0ZGhubPn68lS5Zo4cKFmjlzZhir9brxxht18cUXa9myZZo3b55SUlK0cOFC30Y3PZk6darefvttffHFFzrllFM0Y8YM3XzzzSosLBxSDY8//rhmzZqlb3/725o3b548Ho9eeeWVgOnpkc7iGe7k+RhUX1+v9PR01dXVKS0tzexyAAAAAAAA+tXS0qK9e/dqzJgxfYZkCA23261Jkybpggsu0O233252OWHR19fcYPI1pngDAAAAAAAAg/TVV1/ptdde06mnnqrW1lbdd9992rt3r773ve+ZXVrUYYr3SNPWIh3YJG38g/SX70st9WZXBAAAAAAAEHWsVqtWr16tOXPm6OSTT9bHH3+sN954Q5MmTTK7tKjDCMqRxNUm/Was1NbY1TbzUmnsqebVBAAAAAAAEIVKSkq0YcMGs8uICYygHEls8VLOcca2ss3m1AIAAAAAAACIgHLkKZptPC4loAQAAAAAAIB5CChHmuI5xuPSzRIbuQMAAAAAAMAkBJQjTbHfCMrGg1LdAXNqAQAAAAAAwIhHQDnSZI6VEjOMbUzzBgAAAAAAgEkIKEcai0UqmmVsI6AEAAAAAACASQgoRyL/jXLYyRsAAAAAAES50aNH66677vIdWywWrV27ttf++/btk8Vi0fbt24f1vsG6z0gWZ3YBMIH/RjkVH0quNskWb049AAAAAAAAQVZRUaGMjIz+Ow7C8uXLVVtbawg+S0pKVFFRoezs7KC+10jCCMqRqGim8bi9Rar6xJxaAAAAAAAAQiA/P18OhyPk72Oz2ZSfn6+4uMgbB9jW1hbQ5nQ6h3SvoV43EASUI1FSppQ5ztjGOpQAAAAAAMQWt1tqPGTuw+3ut8yHH35YhYWFcvv1Peecc/T9739fkrR7926dc845ysvLU0pKiubMmaM33nijz/v6T/HetGmTZsyYoYSEBM2ePVvbtm0z9He5XLr88ss1ZswYJSYmauLEibr77rt952+99VY98cQTevHFF2WxWGSxWPTWW2/1OMX77bff1kknnSSHw6GCggLddNNNam9v950/7bTTdO211+qGG25QZmam8vPzde
utt/b7Z/XII49o0qRJSkhI0PHHH68//OEPvnOddTz77LM69dRTlZCQoKeeekrLly/X0qVL9atf/UqFhYWaOHGiJOn
"text/plain": [
"<Figure size 1600x800 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_learning_curve()\n"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"livereveal": {
"start_slideshow_at": "selected",
"theme": "white"
}
},
"nbformat": 4,
"nbformat_minor": 4
}