uczenie-maszynowe/wyk/06_Problem_nadmiernego_dopasowania.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Machine learning\n",
"# 6. The problem of overfitting"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 6.1. Polynomial regression"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Introduction: choosing features"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Suppose our task is to predict the price of a rectangular plot of land.\n",
"\n",
"Which features should we choose?"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"We can choose two features:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * $x_1$ = plot width, $x_2$ = plot length:\n",
"$$ h_{\\theta}(\\vec{x}) = \\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"...or just one:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * $x_1$ = plot area:\n",
"$$ h_{\\theta}(\\vec{x}) = \\theta_0 + \\theta_1 x_1 $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Note also that the feature “plot area” is obtained by multiplying two other features: the plot's length and its width."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"**Conclusion:** we can create new features from existing ones by applying various mathematical operations to them."
]
},
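{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A minimal sketch of such a derived feature in NumPy. The plot dimensions are made-up, illustrative values. The `demo_` names are hypothetical and were chosen so they do not overwrite variables used later in this notebook. The area column is simply the element-wise product of the width and length columns."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical plot dimensions in metres (illustrative values only)\n",
"demo_width = np.array([20.0, 25.0, 30.0])\n",
"demo_length = np.array([40.0, 50.0, 30.0])\n",
"\n",
"# Two-feature variant: x1 = width, x2 = length\n",
"demo_X_two = np.column_stack([demo_width, demo_length])\n",
"\n",
"# One-feature variant: x1 = area = width * length\n",
"demo_X_one = (demo_width * demo_length).reshape(-1, 1)\n",
"\n",
"print(demo_X_two.shape, demo_X_one.shape)\n",
"print(demo_X_one.ravel())\n"
]
},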
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Polynomial regression"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"In polynomial regression we use features created as powers of the original features."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Useful imports\n",
"\n",
"import ipywidgets as widgets\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Useful functions\n",
"\n",
"\n",
"def cost(theta, X, y):\n",
"    \"\"\"Cost function (matrix version)\"\"\"\n",
"    m = len(y)\n",
"    J = 1.0 / (2.0 * m) * ((X * theta - y).T * (X * theta - y))\n",
"    return J.item()\n",
"\n",
"\n",
"def gradient(theta, X, y):\n",
"    \"\"\"Gradient of the cost function (matrix version)\"\"\"\n",
"    return 1.0 / len(y) * (X.T * (X * theta - y))\n",
"\n",
"\n",
"def gradient_descent(fJ, fdJ, theta, X, y, alpha=0.1, eps=10**-5):\n",
"    \"\"\"Gradient descent algorithm (matrix version)\"\"\"\n",
"    current_cost = fJ(theta, X, y)\n",
"    logs = [[current_cost, theta]]\n",
"    while True:\n",
"        theta = theta - alpha * fdJ(theta, X, y)\n",
"        current_cost, prev_cost = fJ(theta, X, y), current_cost\n",
"        if abs(prev_cost - current_cost) > 10**15:\n",
"            print(\"Algorithm does not converge!\")\n",
"            break\n",
"        if abs(prev_cost - current_cost) <= eps:\n",
"            break\n",
"        logs.append([current_cost, theta])\n",
"    return theta, logs\n",
"\n",
"\n",
"def plot_data(X, y, xlabel, ylabel):\n",
"    \"\"\"Plot the data (matrix version)\"\"\"\n",
"    fig = plt.figure(figsize=(16 * 0.6, 9 * 0.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"    ax.scatter([X[:, 1]], [y], c=\"r\", s=50, label=\"Data\")\n",
"\n",
"    ax.set_xlabel(xlabel)\n",
"    ax.set_ylabel(ylabel)\n",
"    ax.margins(0.05, 0.05)\n",
"    plt.ylim(y.min() - 1, y.max() + 1)\n",
"    plt.xlim(np.min(X[:, 1]) - 1, np.max(X[:, 1]) + 1)\n",
"    return fig\n",
"\n",
"\n",
"def plot_fun(fig, fun, X):\n",
"    \"\"\"Plot the function `fun`\"\"\"\n",
"    ax = fig.axes[0]\n",
"    x0 = np.min(X[:, 1]) - 1.0\n",
"    x1 = np.max(X[:, 1]) + 1.0\n",
"    Arg = np.arange(x0, x1, 0.1)\n",
"    Val = fun(Arg)\n",
"    return ax.plot(Arg, Val, linewidth=\"2\")\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Load the data (flats) with the pandas library\n",
"\n",
"alldata = pandas.read_csv(\n",
"    \"data_flats.tsv\", header=0, sep=\"\\t\", usecols=[\"price\", \"rooms\", \"sqrMetres\"]\n",
")\n",
"data = np.matrix(alldata[[\"sqrMetres\", \"price\"]])\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data[:, 0:n]\n",
"# Divide each column (and each of its powers) by its maximum value\n",
"Xn /= np.amax(Xn, axis=0)\n",
"Xn2 = np.power(Xn, 2)\n",
"Xn2 /= np.amax(Xn2, axis=0)\n",
"Xn3 = np.power(Xn, 3)\n",
"Xn3 /= np.amax(Xn3, axis=0)\n",
"\n",
"X = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n + 1)\n",
"X2 = np.matrix(np.concatenate((np.ones((m, 1)), Xn, Xn2), axis=1)).reshape(m, 2 * n + 1)\n",
"X3 = np.matrix(np.concatenate((np.ones((m, 1)), Xn, Xn2, Xn3), axis=1)).reshape(\n",
"    m, 3 * n + 1\n",
")\n",
"y = np.matrix(data[:, -1]).reshape(m, 1)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"General form of polynomial regression:\n",
"\n",
"$$ h_{\\theta}(x) = \\sum_{i=0}^{n} \\theta_i x^i $$"
]
},
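{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"In code, the general form above amounts to building a matrix whose columns are the powers $x^0, x^1, \\ldots, x^n$ and multiplying it by the parameter vector. The sketch below uses NumPy's `np.vander`; with `increasing=True` its columns start at $x^0$ and go upward. The numbers and the `demo_` names are illustrative assumptions, not part of the lecture's data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"demo_x = np.array([0.5, 1.0, 2.0])\n",
"demo_theta = np.array([1.0, -2.0, 0.5])  # theta_0, theta_1, theta_2 (illustrative)\n",
"\n",
"# Columns: x**0, x**1, x**2 (increasing powers, matching sum_i theta_i * x**i)\n",
"demo_V = np.vander(demo_x, len(demo_theta), increasing=True)\n",
"demo_h = demo_V @ demo_theta\n",
"\n",
"# np.polyval expects the highest power first, hence the reversed coefficients\n",
"assert np.allclose(demo_h, np.polyval(demo_theta[::-1], demo_x))\n",
"print(demo_h)\n"
]
},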
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [],
"source": [
"# Polynomial regression function\n",
"\n",
"\n",
"def h_poly(Theta, x):\n",
"    \"\"\"Polynomial: the sum of Theta[i] * x**i\"\"\"\n",
"    return sum(theta * np.power(x, i) for i, theta in enumerate(Theta.tolist()))\n",
"\n",
"\n",
"def polynomial_regression(theta):\n",
"    \"\"\"Return the polynomial regression function for the parameters theta\"\"\"\n",
"    return lambda x: h_poly(theta, x)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The simplest case of polynomial regression (beyond the linear one) is the quadratic function:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Quadratic function:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 $$"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7ff4e8126fb0>]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "(base64-encoded PNG plot output truncated in this view)",
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X2, y, xlabel=\"x\", ylabel=\"y\")\n",
"theta_start = np.matrix([0, 0, 0]).reshape(3, 1)\n",
"theta, logs = gradient_descent(cost, gradient, theta_start, X2, y)\n",
"plot_fun(fig, polynomial_regression(theta), X)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Another special case of polynomial regression is the cubic function:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Cubic function:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 + \\theta_3 x^3 $$"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 397519.38046962]\n",
" [-841341.14146733]\n",
" [2253713.97125102]\n",
" [-244009.07081946]]\n"
]
},
{
"data": {
"image/png": "(base64-encoded PNG plot output truncated in this view)",
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X3, y, xlabel=\"x\", ylabel=\"y\")\n",
"theta_start = np.matrix([0, 0, 0, 0]).reshape(4, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X3, y)\n",
"plot_fun(fig, polynomial_regression(theta), X)\n",
"\n",
"print(theta)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Polynomial regression can be treated as a special case of multivariate linear regression:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 + \\theta_3 x^3 $$\n",
"$$ x_1 = x, \\quad x_2 = x^2, \\quad x_3 = x^3, \\quad \\vec{x} = \\left[ \\begin{array}{c} x_0 \\\\ x_1 \\\\ x_2 \\\\ x_3 \\end{array} \\right] $$"
]
},
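{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"This equivalence can be checked numerically. The sketch below uses synthetic, noise-free data, and its `demo_` names are hypothetical. It builds the columns $1, x, x^2, x^3$ as ordinary features and fits them with the linear least-squares solver `np.linalg.lstsq`, used here instead of gradient descent for brevity. The fit recovers the cubic's coefficients."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Known cubic coefficients theta_0 .. theta_3 (chosen for illustration)\n",
"demo_true = np.array([1.0, -2.0, 0.5, 0.25])\n",
"demo_x = np.linspace(-2.0, 2.0, 20)\n",
"\n",
"# Powers of x treated as separate features: x_0 = 1, x_1 = x, x_2 = x**2, x_3 = x**3\n",
"demo_X = np.column_stack([np.ones_like(demo_x), demo_x, demo_x**2, demo_x**3])\n",
"demo_y = demo_X @ demo_true  # noise-free targets\n",
"\n",
"# A plain linear least-squares fit recovers the polynomial coefficients\n",
"demo_theta, *_ = np.linalg.lstsq(demo_X, demo_y, rcond=None)\n",
"print(np.round(demo_theta, 6))\n",
"assert np.allclose(demo_theta, demo_true)\n"
]
},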
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"(Here the successive features are the successive powers of the variable $x$.)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Practical note: feature normalization, especially scaling, is very useful here, because successive powers of $x$ quickly take values of very different magnitudes!"
]
},
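{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A quick illustration of why, using made-up floor areas and hypothetical `demo_` names. The ranges of $x$, $x^2$ and $x^3$ differ by orders of magnitude. Dividing each column by its maximum brings them all to a common scale; the data-loading cell earlier in this lecture does the same for `Xn`, `Xn2` and `Xn3`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Illustrative raw feature, e.g. floor areas in square metres (made-up values)\n",
"demo_x = np.array([30.0, 55.0, 80.0, 120.0])\n",
"demo_P = np.column_stack([demo_x, demo_x**2, demo_x**3])\n",
"\n",
"# Column maxima differ by orders of magnitude: about 1.2e2, 1.4e4 and 1.7e6\n",
"print(demo_P.max(axis=0))\n",
"\n",
"# Max-scaling of every column\n",
"demo_P_scaled = demo_P / demo_P.max(axis=0)\n",
"print(demo_P_scaled.max(axis=0))  # every column now has maximum 1\n"
]
},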
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"To create “derived” features we can use not only powers but also other mathematical operations, e.g.:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 \\sqrt{x} $$\n",
"$$ x_1 = x, \\quad x_2 = \\sqrt{x}, \\quad \\vec{x} = \\left[ \\begin{array}{c} x_0 \\\\ x_1 \\\\ x_2 \\end{array} \\right] $$"
]
},
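{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The same least-squares approach works for any derived features. Below is a sketch with $x$ and $\\sqrt{x}$, which requires $x \\ge 0$. The targets are generated from known coefficients purely for illustration, and the `demo_` names are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"demo_x = np.array([0.0, 1.0, 4.0, 9.0, 16.0])\n",
"\n",
"# Derived features: x_0 = 1 (bias), x_1 = x, x_2 = sqrt(x)\n",
"demo_X = np.column_stack([np.ones_like(demo_x), demo_x, np.sqrt(demo_x)])\n",
"demo_y = demo_X @ np.array([2.0, 0.5, -1.0])  # illustrative target values\n",
"\n",
"demo_theta, *_ = np.linalg.lstsq(demo_X, demo_y, rcond=None)\n",
"print(np.round(demo_theta, 6))\n",
"assert np.allclose(demo_theta, [2.0, 0.5, -1.0])\n"
]
},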
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"So which features should we choose? Ideally, ones tailored to the specific problem."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Polynomial logistic regression\n",
"\n",
"The same kinds of feature transformations can also be applied to logistic regression."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def powerme(x1, x2, n):\n",
"    \"\"\"Generate all monomials x1**i * x2**j of total degree i + j <= n (including the constant 1)\"\"\"\n",
"    X = []\n",
"    for m in range(n + 1):\n",
"        for i in range(m + 1):\n",
"            X.append(np.multiply(np.power(x1, i), np.power(x2, (m - i))))\n",
"    return np.hstack(X)\n"
]
},
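{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"It helps to know which columns `powerme` produces. The sketch below uses a hypothetical helper, `demo_monomial_exponents`, which mirrors the loop in `powerme` but returns the exponent pairs $(i, j)$ of the monomials $x_1^i x_2^j$ instead of arrays. For $n = 2$ the column order is $1, x_2, x_1, x_2^2, x_1 x_2, x_1^2$, so column 1 holds $x_2$ and column 2 holds $x_1$. In general there are $(n+1)(n+2)/2$ columns, e.g. 66 for $n = 10$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def demo_monomial_exponents(n):\n",
"    # Same loop order as powerme: total degree m, then the power i of x1 (x2 gets m - i)\n",
"    return [(i, m - i) for m in range(n + 1) for i in range(m + 1)]\n",
"\n",
"\n",
"print(demo_monomial_exponents(2))\n",
"\n",
"# The number of monomials of total degree <= n is (n + 1)(n + 2) / 2\n",
"for demo_deg in (2, 10):\n",
"    assert len(demo_monomial_exponents(demo_deg)) == (demo_deg + 1) * (demo_deg + 2) // 2\n"
]
},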
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [
{
"data": {
"text/plain": [
"matrix([[ 1. , 0.36596696, -0.11214686],\n",
" [ 0. , 0.4945305 , 0.47110656],\n",
" [ 0. , 0.70290604, -0.92257983],\n",
" [ 0. , 0.46658862, -0.62269739],\n",
" [ 0. , 0.87939462, -0.11408015],\n",
" [ 0. , -0.331185 , 0.84447667],\n",
" [ 0. , -0.54351701, 0.8851383 ],\n",
" [ 0. , 0.91979241, 0.41607012],\n",
" [ 0. , 0.28011742, 0.61431157],\n",
" [ 0. , 0.94754363, -0.78307311]])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the data\n",
"import pandas\n",
"import numpy as np\n",
"\n",
"alldata = pandas.read_csv(\"polynomial_logistic.tsv\", sep=\"\\t\")\n",
"data = np.matrix(alldata)\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data[:, 1:]\n",
"\n",
"Xpl = powerme(data[:, 1], data[:, 2], n)\n",
"Ypl = np.matrix(data[:, 0]).reshape(m, 1)\n",
"\n",
"data[:10]\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def plot_data_for_classification(X, Y, xlabel, ylabel):\n",
"    \"\"\"Plot the two-class data (matrix version)\"\"\"\n",
"    fig = plt.figure(figsize=(16 * 0.6, 9 * 0.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"    X = X.tolist()\n",
"    Y = Y.tolist()\n",
"    # powerme stores x2 in column 1 and x1 in column 2, so x1 goes on the horizontal axis\n",
"    X1n = [x[2] for x, y in zip(X, Y) if y[0] == 0]\n",
"    X1p = [x[2] for x, y in zip(X, Y) if y[0] == 1]\n",
"    X2n = [x[1] for x, y in zip(X, Y) if y[0] == 0]\n",
"    X2p = [x[1] for x, y in zip(X, Y) if y[0] == 1]\n",
"    ax.scatter(X1n, X2n, c=\"r\", marker=\"x\", s=50, label=\"Class 0\")\n",
"    ax.scatter(X1p, X2p, c=\"g\", marker=\"o\", s=50, label=\"Class 1\")\n",
"\n",
"    ax.set_xlabel(xlabel)\n",
"    ax.set_ylabel(ylabel)\n",
"    ax.margins(0.05, 0.05)\n",
"    return fig\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Suppose we have the following data and want to perform binary classification with two classes:\n",
" * red crosses\n",
" * green circles"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "(base64-encoded PNG plot output truncated in this view)",
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r\"$x_1$\", ylabel=r\"$x_2$\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Proposed hypothesis:\n",
"\n",
"$$ h_\\theta(x) = g(\\theta^T x) = g(\\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 + \\theta_3 x_3 + \\theta_4 x_4 + \\theta_5 x_5) \\; , $$\n",
"\n",
"where $g$ is the logistic function, $x_3 = x_1^2$, $x_4 = x_2^2$, $x_5 = x_1 x_2$."
]
},
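{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A numerical sketch of this hypothesis for single points. The parameters are illustrative, not the values learned below, and the `demo_` names are hypothetical. With $\\theta = (0.5, 0, 0, -1, -1, 0)$ we get $\\theta^T x = 0.5 - x_1^2 - x_2^2$, so $h_\\theta(x) \\ge 0.5$ exactly inside the circle $x_1^2 + x_2^2 = 0.5$. Quadratic features thus yield a nonlinear decision boundary, even though the model is still linear in $\\theta$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def demo_g(z):\n",
"    # Logistic (sigmoid) function\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"\n",
"def demo_features(x1, x2):\n",
"    # Order as in the formula above: x0 = 1, x1, x2, x3 = x1**2, x4 = x2**2, x5 = x1*x2\n",
"    return np.array([1.0, x1, x2, x1**2, x2**2, x1 * x2])\n",
"\n",
"\n",
"# Illustrative parameters: h >= 0.5 exactly when x1**2 + x2**2 <= 0.5\n",
"demo_theta = np.array([0.5, 0.0, 0.0, -1.0, -1.0, 0.0])\n",
"\n",
"demo_inside = demo_g(demo_theta @ demo_features(0.1, 0.2))\n",
"demo_outside = demo_g(demo_theta @ demo_features(0.9, 0.9))\n",
"assert demo_inside > 0.5 > demo_outside\n",
"print(demo_inside, demo_outside)\n"
]
},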
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def safeSigmoid(x, eps=0):\n",
"    \"\"\"Sigmoid function modified so that its values\n",
"    always stay at least eps away from the asymptotes 0 and 1\n",
"    \"\"\"\n",
"    y = 1.0 / (1.0 + np.exp(-x))\n",
"    if eps > 0:\n",
"        y[y < eps] = eps\n",
"        y[y > 1 - eps] = 1 - eps\n",
"    return y\n",
"\n",
"\n",
"def h(theta, X, eps=0.0):\n",
"    \"\"\"Hypothesis function\"\"\"\n",
"    return safeSigmoid(X * theta, eps)\n",
"\n",
"\n",
"def J(h, theta, X, y, lamb=0):\n",
"    \"\"\"Cost function (cross-entropy, optionally with L2 regularization)\"\"\"\n",
"    m = len(y)\n",
"    f = h(theta, X, eps=10**-7)\n",
"    j = (\n",
"        -np.sum(np.multiply(y, np.log(f)) + np.multiply(1 - y, np.log(1 - f)), axis=0)\n",
"        / m\n",
"    )\n",
"    if lamb > 0:\n",
"        j += lamb / (2 * m) * np.sum(np.power(theta[1:], 2))\n",
"    return j\n",
"\n",
"\n",
"def dJ(h, theta, X, y, lamb=0):\n",
"    \"\"\"Gradient of the cost function\"\"\"\n",
"    g = 1.0 / y.shape[0] * (X.T * (h(theta, X) - y))\n",
"    if lamb > 0:\n",
"        g[1:] += lamb / float(y.shape[0]) * theta[1:]\n",
"    return g\n",
"\n",
"\n",
"def classifyBi(theta, X):\n",
"    \"\"\"Decision function: probability of class 1 (threshold at 0.5 to get a class)\"\"\"\n",
"    prob = h(theta, X)\n",
"    return prob\n"
]
},
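{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A common way to check that an analytic gradient such as `dJ` is correct is to compare it with central finite differences. The sketch below is self-contained. It re-implements the unregularized cost and gradient with plain NumPy arrays (not `np.matrix`) on random data, and all `demo_` names are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def demo_cost(theta, X, y):\n",
"    # Logistic cross-entropy cost (unregularized), same formula as J above\n",
"    p = 1.0 / (1.0 + np.exp(-X @ theta))\n",
"    return -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))\n",
"\n",
"\n",
"def demo_grad(theta, X, y):\n",
"    # Analytic gradient (1/m) * X^T (h - y), same formula as dJ above\n",
"    p = 1.0 / (1.0 + np.exp(-X @ theta))\n",
"    return X.T @ (p - y) / len(y)\n",
"\n",
"\n",
"def demo_numeric_grad(f, theta, eps=1e-6):\n",
"    # Central finite differences, one coordinate at a time\n",
"    grad = np.zeros_like(theta)\n",
"    for k in range(len(theta)):\n",
"        step = np.zeros_like(theta)\n",
"        step[k] = eps\n",
"        grad[k] = (f(theta + step) - f(theta - step)) / (2 * eps)\n",
"    return grad\n",
"\n",
"\n",
"demo_rng = np.random.default_rng(0)\n",
"demo_X = np.column_stack([np.ones(20), demo_rng.normal(size=(20, 2))])\n",
"demo_y = (demo_rng.random(20) < 0.5).astype(float)\n",
"demo_theta = demo_rng.normal(size=3)\n",
"\n",
"demo_num = demo_numeric_grad(lambda t: demo_cost(t, demo_X, demo_y), demo_theta)\n",
"demo_ana = demo_grad(demo_theta, demo_X, demo_y)\n",
"assert np.allclose(demo_num, demo_ana, atol=1e-6)\n",
"print(np.max(np.abs(demo_num - demo_ana)))\n"
]
},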
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def GD(h, fJ, fdJ, theta, X, y, alpha=0.01, eps=10**-3, maxSteps=10000):\n",
"    \"\"\"Gradient descent for logistic regression\"\"\"\n",
"    errorCurr = fJ(h, theta, X, y)\n",
"    errors = [[errorCurr, theta]]\n",
"    while True:\n",
"        # compute the new theta\n",
"        theta = theta - alpha * fdJ(h, theta, X, y)\n",
"        # record the current error\n",
"        errorCurr, errorPrev = fJ(h, theta, X, y), errorCurr\n",
"        # stopping criteria\n",
"        if abs(errorPrev - errorCurr) <= eps:\n",
"            break\n",
"        if len(errors) > maxSteps:\n",
"            break\n",
"        errors.append([errorCurr, theta])\n",
"    return theta, errors\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"theta = [[ 1.59558981]\n",
" [ 0.12602307]\n",
" [ 0.65718518]\n",
" [-5.26367581]\n",
" [ 1.96832544]\n",
" [-6.97946065]]\n"
]
}
],
"source": [
"# Run gradient descent for logistic regression\n",
"theta_start = np.matrix(np.zeros(Xpl.shape[1])).reshape(Xpl.shape[1], 1)\n",
"theta, errors = GD(\n",
"    h, J, dJ, theta_start, Xpl, Ypl, alpha=0.1, eps=10**-7, maxSteps=10000\n",
")\n",
"print(r\"theta = {}\".format(theta))\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def plot_decision_boundary(fig, theta, X):\n",
"    \"\"\"Plot the decision boundary h(x) = 0.5 (uses the global polynomial degree n)\"\"\"\n",
"    ax = fig.axes[0]\n",
"    xx, yy = np.meshgrid(np.arange(-1.0, 1.0, 0.02), np.arange(-1.0, 1.0, 0.02))\n",
"    l = len(xx.ravel())\n",
"    C = powerme(xx.reshape(l, 1), yy.reshape(l, 1), n)\n",
"    z = classifyBi(theta, C).reshape(int(np.sqrt(l)), int(np.sqrt(l)))\n",
"\n",
"    # contour() expects `linewidths`; the `lw` alias is not supported here\n",
"    ax.contour(xx, yy, z, levels=[0.5], linewidths=3)\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_74/1169766636.py:9: UserWarning: The following kwargs were not used by contour: 'lw'\n",
" plt.contour(xx, yy, z, levels=[0.5], lw=3)\n"
]
},
{
"data": {
"image/png": "(base64-encoded PNG plot output truncated in this view)",
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r\"$x_1$\", ylabel=r\"$x_2$\")\n",
"plot_decision_boundary(fig, theta, Xpl)\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Load the data\n",
"\n",
"alldata = pandas.read_csv(\"polynomial_logistic.tsv\", sep=\"\\t\")\n",
"data = np.matrix(alldata)\n",
"\n",
"m, n_plus_1 = data.shape\n",
"Xn = data[:, 1:]\n",
"\n",
"n = 10\n",
"Xpl = powerme(data[:, 1], data[:, 2], n)\n",
"Ypl = np.matrix(data[:, 0]).reshape(m, 1)\n",
"\n",
"theta_start = np.matrix(np.zeros(Xpl.shape[1])).reshape(Xpl.shape[1], 1)\n",
"theta, errors = GD(\n",
" h, J, dJ, theta_start, Xpl, Ypl, alpha=0.1, eps=10**-7, maxSteps=10000\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_74/1169766636.py:9: UserWarning: The following kwargs were not used by contour: 'lw'\n",
" plt.contour(xx, yy, z, levels=[0.5], lw=3)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1wAAAHvCAYAAABAJN42AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAACZp0lEQVR4nOzdeVxU1fsH8M8sMIA4gCIgCioquO+Jipr9pNwqUisqc8ss0za1zRYtW8wstcUl+6YmbWilZpol5sLmrrmLO6iAC8IIjDMwc39/TIwMDPvM3Fk+79drXsq9Z4Znhpk797nnnOdIBEEQQERERERERBYnFTsAIiIiIiIiZ8WEi4iIiIiIyEqYcBEREREREVkJEy4iIiIiIiIrYcJFRERERERkJUy4iIiIiIiIrIQJFxERERERkZXIxQ7AGej1ely5cgX169eHRCIROxwiIiIiIrIiQRBw69YtBAcHQyqtvA+LCZcFXLlyBSEhIWKHQURERERENpSRkYGmTZtW2oYJlwXUr18fgOEFVyqVd3ZcugQMGwZcuAA0bw4sWwY888ydnzduBKr4AzktrRZwd6/9fiISV9njW8nxrKLtRERk9w5sPYIvp3yDnMxcAMA9j0dh/IePw6ehsvI7uiCVSoWQkBBjHlAZiSAIgg1icmoqlQo+Pj7Iy8szTbgAICMDGDAAOHfuzrawMGD7dsBVe8Xi44GZM4GEBPOvQUYGEB0NzJ4NxMbaPj4iqp7Sx7ewMCAuDhg9+s7PrnycIyJyUAV5BVj+1k/YsORvCIIAZcP6mDR/LKKf7M+pM6VUev5fBhMuC6jyBU9JAaKi7vycnAz06WO7AO2JVgt07AikpZk/ISt9AhceDhw5wp4uInvGi0pERE7p+K40LHz2a5w/kg4A6BbdES8teQbBLYNsHou6SA2VRgWlQglPN0+b/35zapJwsUqhtWVkGK74ljZ6tGG7K3J3N/RshYUZTtAGDLjzWpS9Wp6QwGSLyN6FhBh6tkqLi2OyRUTk4Nr1CsfifXPx1IdPwE3hhgMJR/BMp+n4beFG6HQ6m8SQlJ6EEfEj4D3HG0GfBcF7jjdGxI9AcnqyTX6/pTDhsqayCURysvlEw9WEhBiufpd+LVJSTF8rXh0ncgy8qERE5LTkbnI8PmM4lh3+DJ0HtIdGrcWSaSsxrf9MpJ+8bNXfvWTvEvRf0R8b0jZAL+gBAHpBjw1pG9BvRT8s3bfUqr/fkjik0ALMdimWTbZKEoiKtrsiDkUicmycw0VE5DL0ej02fbMV37wWh8Jbargp3DB65iN49NUHIZPLLPq7ktKT0H9FfwioOE2RQILE8YmICo2qsI01cUih2LRaQ9EHcycdZXt3oqMN7V0RhyIROS5zF4/69Cnfe82eLnIlVX2fu+r3PTkFqVSK+5+9F98cnY+7hnRFkaYIy9/6ES/2eQvnj6Zb9HfNT50PmbTyJE4mlWHBrgUW/b3WwoTLGtzdDRX2wsPNX+EtSbrCww3tXHWeEociETkmXlQiKi8+3lAUqqLvsIwMw/74eNvGRWRhASH++PCPGXh1xRR4+9ZD2r6zmNz9Nfz40W/QFdd9bpe6SI31p9ajWF9cabtifTHWnlwLdZG6zr/T2phwWUtsrKHCXkW9NSEhhv2uWvac89uIHBcvKhGZ0moNy52kpZn/Div5zktLM7TjRQhycBKJBPeNHYBvjs5H7wd7oLhIhxVv/4SX+72DjFN1m9ul0qiMc7aqohf0UGlUdfp9tsA5XBZQkzGcBM5vI3IWNV3AnAuekzPjdxu5KEEQkPD9Tix6cTkK8gqh8HTHhDmjEPP8YEilNe/bURep4T3Hu1pJl1QiRf6MfFFKxXMOF9kvDkUich5VJUel93O4FTk7VuAlFyWRSHDv6Lux7PBn6HZvJ2jUWix+eQVev3c2si9eq/Hjebp5IiYiBnKpvNJ2cqkcw9sMt5t1uSrDhItsi0ORiFwPh1uRqyibdE
VFMdkilxEQ4o+PN7+NFxc9DQ8vBQ5tO4ZnOk3HXyu3oaYD6qb1ngadvvL5YDq9DlN7Ta1LyDbDIYUWwCGFtcChRUSuhcOtyJWkpBiSrRLJyYYqnkQu4vKZTHwybhGOp5wCAPR/pDdeWjIRygb1q/0YS/ctxeSNkyGTykwKaMilcuj0OiwethiTekyyeOzVxSGFZP9qMhSJiBwfh1uRq2AFXiI0adUY83e8h6c+fAIyuQw716RiUpdXcWjb0Wo/xqQek5A4PhExETGQSgwpi1QiRUxEDBLHJ4qabNUUe7gsgD1cRETVxAXPyZlxMXCick7tO4s5oz7H5dOZkEgkeOSVBzHu/Vi4ubtV+zHURWqoNCooFUq7mbPFHi4iIrKt6i74ygXPyVlxMXAisyJ6tMSSA59g6NMDIQgCVs9bj5f6vIX0k9UvH+/p5olA70C7SbZqigkXERHVTU0qEHK4FTkjVuAlqpRnPQ9MXTYJs359BfUbeOP0gfOY0uN1bFm1Q+zQbIIJFxER1V5NKhDOmAHcfTcXPCfnwwq8RNXSd3gklh3+DF0HdsTtQg0+GfcV5j21COqC22KHZlWcw2UBnMNFRC6tOhUIQ0MNbdPTWaWQnBcr8BJVi06nw08frUXce6uh1wto1q4p3o6fhubtHef4zzlcRERkO1VVIGzRApDLyydb5u7L4VbkyFiBl6haZDIZnnznYXySMAsNgnxx8fglPN/zDfy1cpvYoVkFEy4iIqq7yhZ83bED+OgjDrciIiITnQe0x9KD89Dt3k7QqLX49KnF+GTcVyguKq76zg6ECRcREVlGZRUIY2OBI0cqHi4YEmLYHxtr/TiJiMhu+AX6Ys6fb2Hc+49BKpWguKgYMrlM7LAsinO4LIBzuIiIwDW2nBnnJhGRDRxNPomwTs3gVd/+y79zDhcREdlW2eIXrEDoPGpS9p+IqA46RLVxiGSrpphwERFR3XDBV+dVk7L/M2ey4AkRkRlMuIiIqPa44Ktzc3cHEhLMJ85lE+2EBA4rdAZVfUb5GSaqMSZcRERUe1zw1flVVfaf8/ScB4ePElkFi2ZYAItmEJHLY1EF58eiKM5NqzUkU2lp5v+upf/+4eGGqqL8TJMLY9EMIiKyLS746vwqK/tPjo/DR4mshgkXERERVS0jAxg92nTb6NEshuJMOHyUyCqYcBEREVHlWPbfdZRNuqKimGwR1RETLiIiIqoYy/67Hg4fJbIoJlxERERkHsv+uyYOHyWyKCZcREREZB7L/rseDh8lsjiWhbcAloUnIiKnxrL/rsHc8NGQkIq3E7kwloUnIiIiy2HZf+fH4aNEVsOEi4iIiMjVcfgokdVwSKEFcEghEREROQUOHyWqFg4pJCIiIqKa4/BRIotjwkVERERERGQlTLiIiIiIiIishAkXERERERGRlTDhIiIiIiIishImXERERERERFbChIuIiIiIiMhKmHARERERERFZiUMmXIsWLULz5s3h4eGByMhI7Nmzp8K2AwYMgEQiKXcbNmyYsc24cePK7R88eLAtnkrtaLV1209ERERkKzxvIRfncAlXfHw8pk2bhlmzZuHAgQPo3LkzBg0ahKtXr5pt/9tvvyEzM9N4O3r0KGQyGR555BGTdoMHDzZp99NPP9ni6dRcfDzQsSOQkWF+f0aGYX98vG3jIiIiIiqL5y1EjpdwzZ8/HxMnTsT48ePRrl07LF26FF5eXli+fLnZ9g0aNEBQUJDxtmXLFnh5eZVLuBQKhUk7Pz8/WzydmtFqgZkzgbQ0YMCA8gevjAzD9rQ0QzteMSIiIiKx8LyFCICDJVxarRb79+9HdHS0cZtUKkV0dDRSU1Or9RjffvstHnvsMdSrV89k+/bt2xEQEICIiAg899xzuHHjRoWPodFooFKpTG424e4OJCQAYWHAuXOmB6+Sg9a5c4b9CQmG9kRERERi4HkLEQAHS7iuX78OnU6HwMBAk+2BgYHIysqq8v579uzB0aNH8f
TTT5tsHzx4MFatWoWtW7di7ty52LFjB4YMGQKdTmf2cebMmQMfHx/jLSQkpPZPqqZCQoDt200PXikppget7dsN7Yi
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Example with a larger number of features\n",
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r\"$x_1$\", ylabel=r\"$x_2$\")\n",
"plot_decision_boundary(fig, theta, Xpl)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 6.2. The overfitting problem"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Bias versus variance"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Data for a simple example\n",
"\n",
"data = np.matrix(\n",
" [\n",
" [0.0, 0.0],\n",
" [0.5, 1.8],\n",
" [1.0, 4.8],\n",
" [1.6, 7.2],\n",
" [2.6, 8.8],\n",
" [3.0, 9.0],\n",
" ]\n",
")\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn1 = data[:, 0:n]\n",
"Xn1 /= np.amax(Xn1, axis=0)\n",
"Xn2 = np.power(Xn1, 2)\n",
"Xn2 /= np.amax(Xn2, axis=0)\n",
"Xn3 = np.power(Xn1, 3)\n",
"Xn3 /= np.amax(Xn3, axis=0)\n",
"Xn4 = np.power(Xn1, 4)\n",
"Xn4 /= np.amax(Xn4, axis=0)\n",
"Xn5 = np.power(Xn1, 5)\n",
"Xn5 /= np.amax(Xn5, axis=0)\n",
"\n",
"X1 = np.matrix(np.concatenate((np.ones((m, 1)), Xn1), axis=1)).reshape(m, n + 1)\n",
"X2 = np.matrix(np.concatenate((np.ones((m, 1)), Xn1, Xn2), axis=1)).reshape(\n",
" m, 2 * n + 1\n",
")\n",
"X5 = np.matrix(\n",
" np.concatenate((np.ones((m, 1)), Xn1, Xn2, Xn3, Xn4, Xn5), axis=1)\n",
").reshape(m, 5 * n + 1)\n",
"y = np.matrix(data[:, -1]).reshape(m, 1)\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA00AAAH0CAYAAADhWca4AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAlLklEQVR4nO3df3CU9Z3A8U9+lJCKG4oKgTEq/jhsxYqtymAUdeTKtZ4DMuNVz3Oo1ztPGk+Bnq3ejDrW01SvA07vBHrenDpetdpe0dOrOohVCSL+QkutR9VyylkTrdasYhol+9wfW9KLkC+/kuwmeb1mdjL77HfTT/p0Sd59nme3IsuyLAAAANiuylIPAAAAUM5EEwAAQIJoAgAASBBNAAAACaIJAAAgQTQBAAAkiCYAAIAE0QQAAJAgmgAAABJEEwAAQEJJo+mxxx6L008/PSZMmBAVFRVx991393g8y7K44oorYvz48VFbWxszZsyIl156qTTDAgAAw1JJo2nz5s1x1FFHxY033rjdx6+//vr47ne/G8uWLYu1a9fGXnvtFTNnzozf/e53AzwpAAAwXFVkWZaVeoiIiIqKili+fHnMnj07IopHmSZMmBBf//rX4+/+7u8iIqK9vT3GjRsXt9xyS5x11lklnBYAABguqks9QG82btwYra2tMWPGjO5tdXV1MXXq1FizZk2v0dTZ2RmdnZ3d9wuFQrzzzjuxzz77REVFRb/PDQAAlE6WZfHee+/FhAkTorKyb06sK9toam1tjYiIcePG9dg+bty47se2p7m5Oa666qp+nQ0AAChvmzZtiv33379PvlfZRtPuuuyyy2LhwoXd99vb2+OAAw6ITZs2RS6XK+FkAABAf8vn89HQ0BB77713n33Pso2m+vr6iIhoa2uL8ePHd29va2uLKVOm9Pq8mpqaqKmp2WZ7LpcTTQAAMEz05aU5Zfs5TRMnToz6+vpYuXJl97Z8Ph9r166NadOmlXAyAABgOCnpkab3338/Xn755e77GzdujOeeey7GjBkTBxxwQMyfPz/+4R/+IQ477LCYOHFiXH755TFhwoTud9gDAADobyWNpqeffjpOOeWU7vtbr0WaO3du3HLLLfGNb3wjNm/eHOeff368++67ccIJJ8QDDzwQI0eOLNXIAADAMFM2n9PUX/L5fNTV1UV7e7trmgAAYIjrj7//y/aaJgAAgHIgmgAAABJEEwAAQIJoAgAASBBNAAAACaIJAAAgQTQBAAAkiCYAAIAE0QQAAJAgmgAAABJEEwAAQIJoAgAASBBNAAAACaIJAAAgQTQBAAAkiCYAAIAE0QQAAJAgmgAAABJEEwAAQIJoAgAASBBNAAAACaIJAAAgQTQBAAAkiCYAAIAE0QQAAJAgmgAAABJEEwAw9HR0RLS1Fb8C7CHRBAAMHS0tEXPmRIwaFVFfX/w6Z07E6tWlngwYxEQTADA0LF0aMX16xL33RhQKxW2FQvH+iSdGLFtW2vmAQUs0AQCDX0tLRFNTRJZFbNnS87EtW4rbv/Y1R5yA3SKaAIDBb9GiiKqq9JqqqojFiwdmHmBIEU0AwODW0RFxzz3bHmH6uC1bIpYv9+YQwC4TTQDA4JbP/+Eaph0pFIrrAXaBaAIABrdcLqJyJ/+kqawsrgfYBaIJAOg7pfh8pNraiFmzIqqr0+uqqyPOOKO4HmAXiCYAYM+V+vORFi6M6OpKr+nqiliwYGDmAYYU0QQA7Jly+HykE06IWLIkoqJi2yNO1dXF7UuWRDQ29v8swJAjmgCA3VdOn490wQURq1YVT9Xbeo1TZWXx/qpVxccBdsMOTv4FAEjY+vlIqbf73vr5SANxlKexsXjr6Ci+S14u5xomYI+JJgBg92z9fKQdvd33//98pIEKmNpasQT0GafnAQC7x+cjAcOEaAIAdo/PRwKGCdEEAOwen48EDBOiCQDYfT4fCRgGRBMAsPt8PhIwDIgmAGDP+HwkYIjzlu
MAwJ7z+UjAECaaAIC+4/ORgCHI6XkAAAAJogkAACBBNAEAACSIJgAAgATRBAAAkCCaAAAAEkQTAABAgmgCAABIEE0AAAAJogkAACBBNAEAACSIJgAAgATRBAAAkCCaAAAAEkQTAABAgmgCAABIEE0AAAAJogkAACBBNAEAACSIJgAAgISyjqaurq64/PLLY+LEiVFbWxuHHHJIXH311ZFlWalHAwAAhonqUg+Qct1118XSpUvj1ltvjSOOOCKefvrpOO+886Kuri4uuuiiUo8HAAAMA2UdTY8//njMmjUrTjvttIiIOOigg+KOO+6IJ598stfndHZ2RmdnZ/f9fD7f73MCAABDV1mfnnf88cfHypUr45e//GVERDz//PPR0tISX/ziF3t9TnNzc9TV1XXfGhoaBmpcAABgCKrIyvgCoUKhEH//938f119/fVRVVUVXV1dcc801cdlll/X6nO0daWpoaIj29vbI5XIDMTYAAFAi+Xw+6urq+vTv/7I+Pe+uu+6K73//+3H77bfHEUccEc8991zMnz8/JkyYEHPnzt3uc2pqaqKmpmaAJwUAAIaqso6mSy65JC699NI466yzIiLiyCOPjFdffTWam5t7jSYAAIC+VNbXNH3wwQdRWdlzxKqqqigUCiWaCAAAGG7K+kjT6aefHtdcc00ccMABccQRR8S6deti0aJF8Zd/+ZelHg0AABgmyvqNIN577724/PLLY/ny5fHmm2/GhAkT4uyzz44rrrgiRowYsVPfoz8uBAMAAMpTf/z9X9bR1BdEEwAADB/98fd/WV/TBAAAUGqiCQAAIEE0AQAAJIgmAACABNEEAACQIJoAAAASRBMAAECCaAIAAEgQTQAAAAmiCQAAIEE0AQAAJIgmAACABNEEAACQIJoAAAASRBMAAECCaAIAAEgQTQAAAAmiCQAAIEE0ATA8dXREtLUVvwJAgmgCYHhpaYmYMydi1KiI+vri1zlzIlavLvVkAJQp0QTA8LF0acT06RH33htRKBS3FQrF+yeeGLFsWWnnA6AsiSYAhoeWloimpogsi9iypedjW7YUt3/ta444AbAN0QTA8LBoUURVVXpNVVXE4sUDMw8Ag4ZoAmDo6+iIuOeebY8wfdyWLRHLl3tzCAB6EE0ADH35/B+uYdqRQqG4HgB+TzQBMPTlchGVO/krr7KyuB4Afk80ATD01dZGzJoVUV2dXlddHXHGGcX1APB7ogmA4WHhwoiurvSarq6IBQsGZh4ABg3RBMDwcMIJEUuWRFRUbHvEqbq6uH3JkojGxtLMB0DZEk0ADB8XXBCxalXxVL2t1zhVVhbvr1pVfBwAPmYHJ3cDwBDT2Fi8dXQU3yUvl3MNEwBJogmA4am2ViwBsFOcngcAAJAgmgAAABJEEwAAQIJoAgAASBBNAAAACaIJAAAgQTQBAAAkiCYAAIAE0QQAAJAgmgAAABJEEwAAQIJoAgAASBBNAAAACaIJAAAgQTQBAAAkiCYAAIAE0QQAAJAgmgAAABJEEwAAQIJoAgAASBBNAAAACaIJAAAgQTQBAAAkiCYAAIAE0QQAAJAgmgAAABJEEwAAQIJoAgAASBBNAAAACaIJAAAgQTQBAAAkiCYAAIAE0QQAAJAgmgAAABJEEwAAQELZR9Prr78ef/EXfxH77LNP1NbWxpFHHhlPP/10qccCAACGiepSD5Dy29/+NhobG+OUU06J+++/P/bbb7946aWX4lOf+lSpRwMAAIaJso6m6667LhoaGuLmm2/u3jZx4sQSTgQAAAw3ZX163n/+53/GMcccE2eeeWaMHTs2jj766LjpppuSz+ns7Ix8Pt/jBgAAsLvKOpp+9atfxdKlS+Owww6LBx98MObNmxcXXXRR3Hrrrb0+p7m5Oerq6rpvDQ0NAzgxAAAw1FRkWZaVeojejBgxIo455ph4/PHHu7dddNFF8dRTT8WaNWu2+5zOzs7o7Ozsvp/P56OhoSHa29sjl8v1+8wAAE
Dp5PP5qKur69O//8v6SNP48ePjM5/5TI9tn/70p+O1117r9Tk1NTWRy+V63AAAAHZXWUdTY2NjbNiwoce2X/7yl3H
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X1, y, xlabel=\"x\", ylabel=\"y\")\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7ff4e7adf4f0>]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA00AAAH0CAYAAADhWca4AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABJ8klEQVR4nO3dd3RUdf7/8ddMOkkmEEqKhA6hJ1GUpYiyoqjoUpS2rj/X3f1aFpcSRcFdsIuggrCKuLvfFb/uSlOxoKKIhQ6KCZ3QIZSEnklC6sz9/TE4iEIIySR3yvNxzpycz82dmRfnOpl5OTPvazEMwxAAAAAA4IKsZgcAAAAAAG9GaQIAAACAClCaAAAAAKAClCYAAAAAqAClCQAAAAAqQGkCAAAAgApQmgAAAACgApQmAAAAAKgApQkAAAAAKkBpAgAAAIAKmFqali1bpttvv12JiYmyWCz64IMPzvu9YRiaOHGiEhISFBERoT59+mjnzp3mhAUAAAAQkEwtTYWFhUpJSdFrr712wd9PmTJFM2bM0KxZs7R27VpFRkaqb9++Ki4uruWkAAAAAAKVxTAMw+wQkmSxWLRw4UINGDBAkutdpsTERD388MN65JFHJEl5eXmKi4vT7NmzNWzYMBPTAgAAAAgUwWYHuJi9e/cqJydHffr0cW+LiYlR165dtXr16ouWppKSEpWUlLjXTqdTJ0+eVP369WWxWGo8NwAAAADzGIah/Px8JSYmymr1zAfrvLY05eTkSJLi4uLO2x4XF+f+3YVMmjRJTz31VI1mAwAAAODdsrOz1bhxY4/clteWpqoaP3680tPT3eu8vDw1adJE2dnZstlsJiYDAAA17bu9J/WHt76TYUjBVove/uM16tS4rtmxANQiu92upKQkRUdHe+w2vbY0xcfHS5Jyc3OVkJDg3p6bm6vU1NSLXi8sLExhYWG/2G6z2ShNAAD4sdNnSvW3T3fLElpHFkmP3JysHu2bmB0LgEk8+dUcrz1PU/PmzRUfH6+lS5e6t9ntdq1du1bdunUzMRkAAPA2hmHo8YWbdCTPNWG3W4v6ur9XS5NTAfAXpr7TVFBQoF27drnXe/fuVWZmpmJjY9WkSRONHj1azz77rFq3bq3mzZtrwoQJSkxMdE/YAwAAkKT532fr002u7zzHRIRo6tAUBVkZAAXAM0wtTd9//7169+7tXv/4XaR77rlHs2fP1qOPPqrCwkLdd999On36tHr27KnFixcrPDzcrMgAAMDL7D5WoCc/2upeT76jkxJiIkxMBMDfeM15mmqK3W5XTEyM8vLy+E4TAAB+prTcqTteX6VNh/IkScOvSdKkQZ1NTgXATDXx+t9rv9MEAABwKS8vyXIXphYNIzXhtvYmJwLgjyhNAADAJ63cdVxvfLtHkhQSZNGMYWmqE+q1g4EB+DBKEwAA8DknC0uVPj/TvX60b1t1vCLGvEAA/BqlCQAA+BTDMPTYexuVay+RJF3buoH+2LO5yakA+DNKEwAA8CnvrDugJVtzJUmxkaF6eXCKrIwXB1CDKE0AAMBn7Dqar2cW/XS8eGc1snEqEgA1i9IEAAB8Qkm5Q3+Zk6niMqck6e5fNdWN7eNMTgUgEFCaAACAT5iyOEvbjtglSa0bRemv/dqZnAhAoKA0AQAAr/ftjmP63xV7JUmhwVbNGJ6m8JAgk1MBCBSUJgAA4NWOF5To4fkb3OtxN7dVuwSbiYkABBpKEwAA8FqGYeixdzfqeIFrvPj1yQ11b49m5oYCEHAoTQAAwGu9vWa/lm4/KklqEBWqF+9MkcXCeHEAtYvSBAAAvFJWTr6e/WSbe/3inSlqGB1mYiIAgYrSBAAAvE5xmUMj52SotNw1Xvz33Zupd9tGJqcCEKgoTQAAwOu88Nl2ZeXmS5Laxkdr3C1tTU4EIJBRmgAAgFf5anuuZq/aJ0kKY7w4AC9AaQIAAF7jaH6xxi7Y6F7/rV87tYmLNjERAFCaAA
CAl3A6DY1dsFEnCkslSX3aNdLvftXU5FQAQGkCAABe4s1V+/TtjmOSpIbRYZp8R2fGiwPwCpQmAABguq2H7Zr82Xb3euqQFNWPYrw4AO9AaQIAAKYqKnVo5NwMlTpc48X/59rmurZ1Q5NTAcA5lCYAAGCq5z7dql1HCyRJHRJteqRvssmJAOB8lCYAAGCaL7bk6D9rDkiSwkOsmj4sTWHBjBcH4F0oTQAAwBS59mI99t658eJP3N5BrRpFmZgIAC6M0gQAAGqd02kofX6mTp0pkyT17RCnYVcnmZwKAC6M0gQAAGrdv1bs0cpdJyRJ8bZwvTCI8eIAvBelCQAA1KrNh/L04udZkiSLRZo6NEX1IkM9eydFRVJurusnAFQTpQkAANSaM6XlGjknQ2UOQ5L0wHUt1b1lA8/dwYoV0qBBUlSUFB/v+jlokLRypefuA0DAoTQBAIBa88yirdpzvFCS1LlxjMb0aeO5G3/9dalXL+njjyWn65xPcjpd62uvlWbN8tx9AQgolCYAAFArPtt0RHPWZUuS6oQGafqwNIUGe+ilyIoV0ogRkmFI5eXn/6683LX9z3/mHScAVUJpAgAANe5IXpHGvb/JvX7yNx3UvEGk5+5g6lQp6BLndwoKkqZN89x9AggYlCYAAFCjHE5DY+ZlKq/INV68X+cEDb6qsefuoKhI+vDDX77D9HPl5dLChQyHAHDZKE0AAKBGvbFst9bsOSlJSowJ1/MDOnl2vLjdfu47TJfidLr2B4DLQGkCAAA1JjP7tKZ+sUOSZLVIrwxLU0ydEM/eic0mWSv5ksZqde0PAJeB0gQAADznJ+dHKigp16i5GSp3usaLj+jdStc0j/X8fUZESP37S8HBFe8XHCwNHOjaHwAuA6UJAABU3wXOj/Tk6L9r/4kzkqS0JnU18obWNXf/6emSw1HxPg6HNGZMzWUA4LcoTQAAoHoucH6kj9v00LsxrnMwRVmcmj40TSFBNfiyo2dPaeZMyWL55TtOwcGu7TNnSj161FwGAH6L0gQAAKruAudHOmhrqMf7jnDv8szH09Rke0bNZ3ngAWn5ctdH9X78jpPV6lovX+76PQBUwSU+/AsAAFCBH8+PdLYwlVusGnPbI8oPj5Ik9d/yjQZmLXedH6k23uXp0cN1KSpyTcmz2fgOE4BqozQBAICq+fH8SD8Z9z2z2xB9l9RBktT4dI6e+WLm+edHqq0CExFBWQLgMXw8DwAAVM3Pzo+0PrGtpvcYLkmyOh2avugl2UpdgyA4PxIAX0ZpAgAAVfOT8yPZQ+to1O2PyGENkiSNXDVXVx3afm5fzo8EwIdRmgAAQNX85PxIE296UAfrxkuSuhzcoodWzTu3H+dHAuDj+E4TAACouvR0LdxxWh906C1Jii4p1LSPX1awce5je5wfCYCv450mAABQZQfaXakJt58rRM99/pqS7EddC86PBMBPUJoAAECVlDucGjUvQwWG6+XEIPsu/SZrheuXnB8JgB/h43kAAKBKZizdqYwDpyVJTWLr6OmnRkgv38f5kQD4HUoTAAC4bOv2ntSrX++SJAVZLZo+LFVRYcGSgilLAPwOH88DAACXJa+oTGPmZcppuNbpN7ZRWpN65oYCgBpEaQIAAJVmGIYeX7hJh04XSZK6No/VA9e1NDkVANQsShMAAKi0d9cf1Ccbj0iSbOHBmjY0VUFWi8mpAKBmUZoAAECl7D1eqCc+2uJev3BHZyXW5ftLAPwfpQkAAFxSablTo+Zm6EypQ5I0tEuSbu2UYHIqAKgdlCYAAHBJ077coY0H8yRJzRtEauLt7U1OBAC1h9IEAAAqtGr3cc36drckKSTIohnD0hQZxllLAAQOShMAALioU4WlSp+3QcbZ8eIP35SsTo1jzA0FALWM0gQAAC7IMAyNf3+TcuzFkqTuLevrvmtbmJwKAGofpQkAAFzQ3O+ytXhLjiSpbp0QTR2SKivjxQEEIEoTAAD4hV1HC/T0x1vd68
l3dFZ8TLiJiQDAPJQmAABwnpJyh0bNzVBRmWu8+G+7NlHfDvEmpwIA81CaAADAeV7+Yoe2HLZLklo2jNSEfowXBxD
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X1, y, xlabel=\"x\", ylabel=\"y\")\n",
"theta_start = np.matrix([0, 0]).reshape(2, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X1, y, eps=0.00001)\n",
"plot_fun(fig, polynomial_regression(theta), X1)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"This model has high **bias** (systematic error), i.e. it suffers from **underfitting**."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7ff478d03df0>]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA00AAAH0CAYAAADhWca4AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABW+klEQVR4nO3dd3hUZeL28XtKGiEVSIMEQpHem4AUlRUVFcSK6KJrFwu6qyv+XnWtiLrYBcsqrGJXxIqL9F5C7wQChEASWgohdea8f0wciECAkORM+X6uay5yTs6EG8fMzD3Pc55jMQzDEAAAAADgpKxmBwAAAAAAT0ZpAgAAAIBKUJoAAAAAoBKUJgAAAACoBKUJAAAAACpBaQIAAACASlCaAAAAAKASlCYAAAAAqASlCQAAAAAqQWkCAAAAgEqYWprmzZunK6+8UgkJCbJYLPr+++8rfN8wDD311FOKj49XSEiIBg4cqG3btpkTFgAAAIBfMrU0FRQUqGPHjnrnnXdO+v2XX35Zb775piZOnKilS5cqNDRUgwYNUlFRUS0nBQAAAOCvLIZhGGaHkCSLxaKpU6dq6NChklyjTAkJCfr73/+uf/zjH5Kk3NxcxcbGatKkSbrxxhtNTAsAAADAX9jNDnAqaWlpyszM1MCBA937IiIi1LNnTy1evPiUpam4uFjFxcXubafTqUOHDqlevXqyWCw1nhsAAACAeQzDUH5+vhISEmS1Vs/EOo8tTZmZmZKk2NjYCvtjY2Pd3zuZsWPH6plnnqnRbAAAAAA8W3p6uho1alQtP8tjS1NVjRkzRo888oh7Ozc3V0lJSUpPT1d4eLiJyQAAAADUtLy8PCUmJiosLKzafqbHlqa4uDhJUlZWluLj4937s7Ky1KlTp1PeLygoSEFBQSfsDw8PpzQBAAAAfqI6T83x2Os0JScnKy4uTjNnznTvy8vL09KlS9WrVy8TkwEAAADwJ6aONB05ckSpqanu7bS0NK1evVrR0dFKSkrS6NGj9fzzz6tFixZKTk7Wk08+qYSEBPcKewAAAABQ00wtTStWrNCFF17o3v7jXKSRI0dq0qRJeuyxx1RQUKC77rpLOTk5uuCCCzR9+nQFBwebFRkAAACAn/GY6zTVlLy8PEVERCg3N5dzmgAAAAAfVxPv/z32nCYAAAAA8ASUJgAAAACoBKUJAAAAACpBaQIAAACASlCaAAAAAKASlCYAAAAAqASlCQAAAAAqQWkCAAAAgEpQmgAAAACgEpQmAAAAAKgEpQkAAAAAKkFpAgAAAIBKUJoAAAAAoBKUJgAAAACoBKUJAAAAACpBaQIAAACASlCaAAAAAKASlCYAAAAAqASlCQAAAAAqQWkCAAAAgEpQmgAAAACgEpQmAAAAAKgEpQkAAAAAKkFpAgAAAIBKUJoAAAAAoBJ2swMAAAB4gzKHUyUOp0rKXLfiMqcMQ2oYFSKb1WJ2PAA1iNIEAAB8SmZukRZs3KsDB/NUbA9UicXqLjolDlfZOX77ZF8Xl/1xnMO932mc/O8LDbSpc1KUujZ23TonRSosOKB2/9EAahSlCQAAeLVSh1Mrdh7WnK3ZmrtypzbnO2v17y8ocWhB6gEtSD0gSbJapJZx4eraOFLdGkera+MoNYoKkcXCaBTgrShNAADA6+zNKdScLfs1d2u2FqYe1JHisnP+mVaLFGi3KtBmVaDdpiC79bht67Ht4/aVlDm1Zk+OsvKK3T/HaUib9uVp0748fbpktyQpJixI3ZpEqUtSlLo1iVbbhHAF2Di1HPAWlCYAAODxSsqcWrHzkOZs3a85W7K1NevISY+zGE512Jeq/mkr1Dp7p4LKShToKFWgo8z154cfKLBbFwWdpATZq1hiDMNQRk6hUnYdVsquw1qx87A2Z+ZVmM6XnV+sX9Zl6pd1mZKk4ACrOjSKVLfGUe4yFVknsEp/P4CaZzEM4xQzdH1DXl6eIiIilJubq/DwcLPjAACAM7
Tn8FHN2bJfc7bs16LtB3S0xHHS46JDA9Vv30b1X/Cj+qUuV73CvJP/QLtdGjJE+uabGkztkl9UqjXpuVqx65BSdh3Wqt05px0Nax5TV90aR6lL4yh1axyl5PqhTOkDqqAm3v9TmgAAgEcoLnNoWdqh8qKUre37C056nMUidWwUqQEtG2hAyxi1jw6ULTxMcp7BuUxWq3TkiBQSUs3pK+dwGtqSma+U8hK1Ytdh7TlcWOl9okMDy6fzRalX03rq0CiCEgWcAUpTFVCaAADwXLsPHtXcrdnlo0kHVVh68tGkeqGB6n9eA/Vv2UD9WjRQVOhxU9mysqS4uDP/SzMzpdjYc0x+7rLyitzT+VJ2HdKGvXkqO9USfZK6NY7S/Rc1V//zGlCegEpQmqqA0gQAgOcoLnNoyY5DmrMlW3O37NeOAycfTbJapM5JURpwnms0qW1CuKynuhZSYaFUt65HjzSdicISh9bsyXGfG5Wy67ByC0tPOK5Dowjdf2FzDWwde+r/JoAfozRVAaUJAIBaVFgo5eVJ4eEVionDaejblD0aP2OrMvOKTnrXBmFBrtGk8xqob4v6Z7cwwrBh0o8/SmWVnDdUi+c0VQen09COA0e0eMchTV60U6nZFRe/aBUXpvsvaq7L2sVzcV3gOJSmKqA0AQBQCxYskMaPl6ZNc434WK3SkCEyHnlEM6NbaNz0zdr2pzf9NqtFXZOi1L+lqyi1ia9kNOlM/v5+/aTK3tZYLNL8+VKfPlX7O0zkdBqaviFTb81K1aZ9FRe6aNYgVKMubK6rOiZUeQVAwJdQmqqA0gQAQA2bMEEaNUqy2SqM9KQkttW4vrdoWWK7Codf1CpG13ZtpD7N6ysiJKD6ckycKN133wk5ZLdLDof07rvSPfdU399nAsMwNGtztt6clao16TkVvpcUXUf3Dmima7o0UqCd8gT/RWmqAkoTAAA16CQjPNujG+qVfn/V9JYVR3Q6J0VqzGWt1SM5uubyLFwovfaaNHXqsRGvq6+WHn7YK0eYTsUwDC1IPaC3ZqVqWdqhCt+LjwjWPf2b6YbuiQoOsJmUEDAPpakKKE0AANSg484lyg6N0ut9btKXHS+Rw3rszXrTQxl6rHiLBk16tfZWfTvFuVW+aOmOg3p7dqrmbztQYX+DsCDd1bepbuqZpNAgu0npgNpHaaoCShMAADWkfNW6fHuQ3u95jT7sNlSFgcHubzc4ckijF3ymG9b+T3aLPHbVOl+xavdhvTM7Vb9vyq6wP6pOgG6/IFl/7d1E4cHVOB0S8FCUpiqgNAEAUDNK9u7TlKvu1lu9b9ShOhHu/XWLj+rupd/q9hXfq05p8bE7eMj1kXzdhr25emd2qn5dn1lhXYywYLtu691Et/VJrnidK8DHUJqqgNIEAED1cjoN/bh2r16dvlnpOceWDw9wlOrmVb/o/kVfql5hxRXePPn6SL5qW1a+3pmdqh/W7NXx18wNDbTp5l6NdWffpqpfN8i8gEANoTRVAaUJAIDqs2DbAb00fZPWZ1QsRUM2zNHf53+ipNysE+/kZddH8jU7DxTo3Tmp+m5lhsqOa0/BAVYN75Gku/s1U1xEcCU/AfAulKYqoDQBAHDu1mfkatz0zScsNnBBfbse//f9apeZeuo7e/H1kXzJnsNH9d7cHfpyebpKHE73/kCbVdd1a6R7+jdTYnQdExMC1YPSVAWUJgAAqi790FG9+r8tmrZ6b4X9bRPC9fhlrdS3RQO/uD6SL8nKK9L783ZoytJdKio9Vp7sVotu6pmkMZe1VkggS5XDe1GaqoDSBADA2TtUUKK3Zm3Tp0t2qdRx7K1Co6gQPTqopa7skCCr9bjlw/3k+ki+5MCRYv1nQZr+u2inCkoc7v0tY8P09k2d1SI2zMR0QNVRmqqA0gQAwJk7WlKmjxak6b25O5RffGzUKKpOgB64qIVGnJ+kIHsloxB+dH0kX5FztEQfL9yp9+ftUGGpqzyFBNj07JC2uq5bos
npgLNHaaoCShMAAKfncBr6akW6XpuxVdn5x5YJDw6w6o4Lmuqu/k25xo+P25aVr/s/W6UtWfnufcM6N9RzQ9txcVx
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X2, y, xlabel=\"x\", ylabel=\"y\")\n",
"theta_start = np.matrix([0, 0, 0]).reshape(3, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X2, y, eps=0.000001)\n",
"plot_fun(fig, polynomial_regression(theta), X1)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"This model fits the data well."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7ff47947b9a0>]"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA00AAAH0CAYAAADhWca4AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABUuklEQVR4nO3dd3hUZcLG4WcmvYcQ0iBAgBCQ3qUqgqKigrgWLIu9LK4Crm2/VXdtqOuCq2vZxYK9i12RKiC9GhAhgQABkpBQ0uvM+f6YOBCBECCTM+V3X9dc5Jw5Ex4cJzNP3nPe12IYhiEAAAAAwDFZzQ4AAAAAAO6M0gQAAAAA9aA0AQAAAEA9KE0AAAAAUA9KEwAAAADUg9IEAAAAAPWgNAEAAABAPShNAAAAAFAPShMAAAAA1IPSBAAAAAD1MLU0LVq0SBdffLGSkpJksVj0+eef17nfMAw9/PDDSkxMVEhIiEaOHKmMjAxzwgIAAADwSaaWptLSUvXo0UMvvvjiMe9/5pln9Pzzz+uVV17RihUrFBYWplGjRqmioqKJkwIAAADwVRbDMAyzQ0iSxWLRrFmzNHbsWEmOUaakpCTdc889+stf/iJJKiwsVHx8vGbOnKmrrrrKxLQAAAAAfIW/2QGOJysrS7m5uRo5cqRzX1RUlAYMGKBly5YdtzRVVlaqsrLSuW2323XgwAE1b95cFovF5bkBAAAAmMcwDBUXFyspKUlWa+OcWOe2pSk3N1eSFB8fX2d/fHy8875jmTp1qv7xj3+4NBsAAAAA95adna1WrVo1yvdy29J0qh588EFNmTLFuV1YWKjWrVsrOztbkZGRJiYDAAC+qqrGrjEvLlH2gXJJ0ovX9NJZHeNMTgV4p6KiIiUnJysiIqLRvqfblqaEhARJUl5enhITE5378/Ly1LNnz+M+LigoSEFBQUftj4yMpDQBAABTvL4kS3tKLbIGhWpgu+a6qE97LhsAXKwxX2Nuu05TSkqKEhISNG/ePOe+oqIirVixQgMHDjQxGQAAQMMVllfr+fmHl0z5v9GdKUyAhzF1pKmkpESZmZnO7aysLK1fv14xMTFq3bq1Jk2apMcff1ypqalKSUnRQw89pKSkJOcMewAAAO7upQWZOlRWLUm6tFdLdW0ZZXIiACfL1NK0evVqDR8+3Ln927VIEyZM0MyZM3XfffeptLRUt956qw4dOqQhQ4bo+++/V3BwsFmRAQAAGmz3wTK9sXSHJCnQ36p7zutobiAAp8Rt1mlylaKiIkVFRamwsJBrmgAAQJOa9ME6fb5+ryTptrPa6cELOpucCPB+rvj877bXNAEAAHiy9N2FzsLULDRAfzq7g8mJAJwqShMAAEAjMwxDT3z7i3P7rhGpigoJMDERgNNBaQIAAGhk83/dp+XbD0iS2jYP1TUD2picCMDpoDQBAAA0ohqbXVO/+9W5fd/5nRToz0cuwJPxCgYAAGhEH67OVua+EklS79bRuqBrgsmJAJwuShMAAEAjKams0fQ5LGQLeBtKEwAAQCP536LtKiiplCRd0DVBfdrEmJwIQGOgNAEAADSCvKIKzVi0XZLkb7XovvM7mZwIQGOhNAEAADSCaT9sVXm1TZJ07ZltlBIbZnIiAI2F0gQAAHCatuQW6+M12ZKkiCB/3TUi1eREABoTpQkAAOA0Tf1us+yG4+s/De+gmLBAcwMBaFSUJgAAgNOwJKNAC7fkS5KSooJ1w+C25gYC0OgoTQAAAKfIbjf05Lebndv3np+m4AA/ExMBcAVKEwAAwCmatW6PfskpkiR1bRmpMT1ampwIgCtQmgAAAE5BRbVNz/6wxbn91ws6y2plIVvAG1GaAAAATpLdbmj63K3KKayQJA1Pa6FBHWJNTgXAVfzNDgAAAOBJtuWX6MHP0rUy64AkyWqRHryws8mpALgSpQkAAKABqmrs+t+ibXp+fqaqauzO/XeP6KiO8REmJgPgapQmAACAE1i366
Ae+DRdW/KKnfuSY0L05KXdNDS1hYnJADQFShMAAMBxlFTW6NnZW/Tmsh0yahevtVqkW4a206SRHRUSyPTigC+gNAEAABzD/F/z9LdZG7W3drIHSeqSFKmnL+uuri2jTEwGoKlRmgAAAI6QX1ypR7/+RV9t2OvcFxxg1eSRHXXTkBT5+zH5MOBrKE0AAACSDMPQx2t264lvNquwvNq5f0iHWD15aTe1bh5qYjoAZqI0AQAAn7ejoFR/nZWupdv2O/dFhwboodFnaFzvlrJYWLQW8GWUJgAA4LOqbXa9ujhLz83dqsojphEf2zNJD110hpqHB5mYDoC7oDQBAACf9PPuQ7r/03Rtzily7msZHaLHL+2q4WlxJiYD4G4oTQAAwPuUl0tFRVJkpBQSUueusqoaTfthq17/KUv2I6YRv35Qiu45r6PCgvh4BKAupn8BAADeY8kSadw4KTxcSkhw/DlunPTTT5KkH7fm67zpi/TqksOFqVNChGb9abAevvgMChOAY+InAwAA8A4vvyxNnCj5+Un22uuT7Hbpq6+0f/Z8Pf7A/zSrNMx5eKC/VZNGpuqWoe0UwDTiAOpBaQIAAJ5vyRJHYTIMqabGuduQNCttqB4752YdPKIwDWzXXE+O66aU2LBjfDMAqIvSBAAAPN+0aY4RpiMKU3ZUvP46aqIWp/R27ou0VepvV/TT5X1bMY04gAajNAEAAM9WXi598YUMu10HQiK1rXmyViR31UtnXq7ywGDnYaM3L9IjC15V3KO7JQoTgJNAaQIAAB7FZje052C5MvOLtW1fqTJ35Wvb+KnKbJ6sQyGRRx2fWJSvx394SSO2rXLsKCo6akY9AKgPpQkAALilimqbtuWXaFt+qTL3lTi+3leirILSOgvRSpJadTnq8RbDrj+u/Ub3LnpL4VXljp1Wq2MacgA4CZQmAADQeOpZH+l49pdU1i1G+SXK3FeiPYfKZRgN/6sTiwvUfn+2OhRkq/2B3RqQvVEdC3YdPsDfXxozhlEmACeN0gQAAE7fkiWOyRi++MIxzbfV6igo99wjDR4swzC0t7BCW/OKlZl3uBhtyy/RwbLqBv81/laL2saGqX2LMHWIC1f7FuHqEBeudpnpCj/nYtXbsmw2afLkRvjHAvA1lCYAAHB6jlgfybDblRfeXFtjW2vrHinjiY+0tV+eMhSmksqaE3+vWuFB/mofF35UOWodE3rsNZVaDZVeekn605+OmkVP/v6OwvTSS9LgwY3wDwbgayhNAADgpBmGofySSmXMW66tr32vref9SRmxrbU1to2KgsPrHlwpSccuTAmRwWofF6YOLcLVPi7c+WdcRNDJTwl+++1St27S9OnSrFl1R7wmT6YwAThlFsM4mbOFPU9RUZGioqJUWFioSC78BADgpB0ordKW3GJl7CvW1rxibc0rUUZe8UmdVteyulhp3dorNT5cqXERSo0LV7sWYYoIDnBN6FO4tgqAd3DF539GmgAAgCSpsLy6thQVa2tubTnaV6yCkqoGf4/EonylFuxSWsFOpRbsUseCXeqwP1vhNZVSSUnTFZiQEMoSgEZDaQIAwMdl5BXrPwsy9dWGvbI38PyTuIggdWwWqNTP3lFavqMgpRbsUmRV2fEfxPpIADwUpQkAAB+1OadI/5mfqW835hx30rnY8EClxkWoY3y4UuMjlJbgOLUuOjTQcQrcn4c7rh06EdZHAuDBKE0AAPiYjXsK9cL8DM3elFdnf0xYoC7omlBbjBxFqXl40PG/UUiIY5KFr76qO1vd77E+EgAPR2kCAMBHbMg+pBfmZ2ju5n119seGB+m2Ye10zZmtFRp4kh8NpkyRPv+8/mNYHwmAh6M0AQDg5dbsPKgX5mdo4Zb8OvvjIoJ0+1ntNb5/a4UE+p3aNx8yhPWRAHg9ShMAAF5q1Y4Den5ehhZnFNTZnxgVrDvObq8r+iYrOOAUy9KRWB8JgJejNAEA4EUMw9Dy7Y6ytGz7/jr3tYwO0c
ThHXRZn5YK8m+EsnSkwYMdN9ZHAuCFKE0AAHgBwzD0U+Z+PT8vQyt3HKhzX+uYUN05vIMu7d1SAX5W1wZhfSQAXoj
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X5, y, xlabel=\"x\", ylabel=\"y\")\n",
"theta_start = np.matrix([0, 0, 0, 0, 0, 0]).reshape(6, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X5, y, alpha=0.5, eps=10**-7)\n",
"plot_fun(fig, polynomial_regression(theta), X1)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"This model has high **variance**, i.e. it suffers from **overfitting**."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"(Note the odd shape of the curve in the left part of the plot: it is, among other things, a result of overfitting.)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Overfitting occurs when the model has too many degrees of freedom relative to the amount of input data.\n",
"\n",
"It is an undesirable phenomenon.\n",
"\n",
"Informally, overfitting occurs when the model starts to model the noise or disturbances in the data instead of their “main trend”. "
]
},
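{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A worked check with the six data points used in the example above (pure arithmetic, no claim about what gradient descent actually converged to): the degree-5 design matrix `X5` has $6$ columns, one per parameter $\\theta_0, \\dots, \\theta_5$, and there are exactly $m = 6$ examples with distinct values of $x$. The linear system\n",
"$$ X_5 \\, \\theta = y $$\n",
"is therefore square, and its matrix (a Vandermonde matrix whose columns were rescaled by positive constants) is invertible. Some $\\theta$ thus passes through every training point, and a squared-error cost such as\n",
"$$ J(\\theta) = \\frac{1}{2m} \\sum_{i=1}^{m} \\left( h_{\\theta}(x^{(i)}) - y^{(i)} \\right)^2 $$\n",
"has minimum value $0$ (the normalizing constant $\\frac{1}{2m}$ is an assumption and does not affect the argument). Six parameters for six points is exactly the “too many degrees of freedom” situation: zero training error says nothing about how the curve behaves between and beyond the points."
]
},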
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"See also: https://pl.wikipedia.org/wiki/Nadmierne_dopasowanie"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"90%\" src=\"fit.png\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Bias (systematic error)\n",
"\n",
"* Results from erroneous assumptions in the learning algorithm.\n",
"* High bias causes underfitting."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Variance\n",
"\n",
"* Results from oversensitivity to small fluctuations in the training set.\n",
"* High variance can cause overfitting (the model fits the noise instead of the signal)."
]
},
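{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Both notions appear in the standard bias-variance decomposition of the expected squared error. Assume (a modelling assumption, not something stated above) that $y = f(x) + \\varepsilon$, where the noise $\\varepsilon$ has mean $0$ and variance $\\sigma^2$ and is independent of the training set, and that $\\hat{f}$ is the model learned from a randomly drawn training set. Then, at a fixed point $x$:\n",
"$$ \\mathbb{E}\\left[ \\left( y - \\hat{f}(x) \\right)^2 \\right] = \\underbrace{\\left( \\mathbb{E}[\\hat{f}(x)] - f(x) \\right)^2}_{\\text{bias}^2} + \\underbrace{\\mathbb{E}\\left[ \\left( \\hat{f}(x) - \\mathbb{E}[\\hat{f}(x)] \\right)^2 \\right]}_{\\text{variance}} + \\underbrace{\\sigma^2}_{\\text{noise}} $$\n",
"\n",
"A model that is too simple (like the straight line above) inflates the first term; a model that is too flexible (like the degree-5 polynomial) inflates the second; the noise term cannot be reduced by any choice of model."
]
},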
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"40%\" src=\"bias2.png\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"60%\" src=\"curves.jpg\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 6.3. Regularization"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def SGD(\n",
" h,\n",
" fJ,\n",
" fdJ,\n",
" theta,\n",
" X,\n",
" Y,\n",
" alpha=0.001,\n",
" maxEpochs=1.0,\n",
" batchSize=100,\n",
" adaGrad=False,\n",
" logError=False,\n",
" validate=0.0,\n",
" valStep=100,\n",
" lamb=0,\n",
" trainsetsize=1.0,\n",
"):\n",
"    \"\"\"Stochastic Gradient Descent: a stochastic version of gradient descent\n",
"    (more on this in the next lecture)\n",
" \"\"\"\n",
" errorsX, errorsY = [], []\n",
" errorsVX, errorsVY = [], []\n",
"\n",
"    # Use only the first trainsetsize fraction of the data for training\n",
"    m_end = int(trainsetsize * len(X))\n",
"    XT, YT = X[:m_end], Y[:m_end]\n",
"\n",
" if validate > 0:\n",
" mv = int(X.shape[0] * validate)\n",
" XV, YV = X[:mv], Y[:mv]\n",
" XT, YT = X[mv:m_end], Y[mv:m_end]\n",
" m, n = XT.shape\n",
"\n",
" start, end = 0, batchSize\n",
" maxSteps = (m * float(maxEpochs)) / batchSize\n",
"\n",
" if adaGrad:\n",
" hgrad = np.matrix(np.zeros(n)).reshape(n, 1)\n",
"\n",
" for i in range(int(maxSteps)):\n",
" XBatch, YBatch = XT[start:end, :], YT[start:end, :]\n",
"\n",
" grad = fdJ(h, theta, XBatch, YBatch, lamb=lamb)\n",
" if adaGrad:\n",
" hgrad += np.multiply(grad, grad)\n",
" Gt = 1.0 / (10**-7 + np.sqrt(hgrad))\n",
" theta = theta - np.multiply(alpha * Gt, grad)\n",
" else:\n",
" theta = theta - alpha * grad\n",
"\n",
" if logError:\n",
" errorsX.append(float(i * batchSize) / m)\n",
" errorsY.append(fJ(h, theta, XBatch, YBatch).item())\n",
" if validate > 0 and i % valStep == 0:\n",
" errorsVX.append(float(i * batchSize) / m)\n",
" errorsVY.append(fJ(h, theta, XV, YV).item())\n",
"\n",
" if start + batchSize < m:\n",
" start += batchSize\n",
" else:\n",
" start = 0\n",
" end = min(start + batchSize, m)\n",
" return theta, (errorsX, errorsY, errorsVX, errorsVY)\n"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Prepare the data for the regularization example\n",
"\n",
"n = 6\n",
"\n",
"data = np.matrix(np.loadtxt(\"ex2data2.txt\", delimiter=\",\"))\n",
"np.random.shuffle(data)\n",
"\n",
"X = powerme(data[:, 0], data[:, 1], n)\n",
"Y = data[:, 2]\n"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def draw_regularization_example(\n",
" X, Y, lamb=0, alpha=1, adaGrad=True, maxEpochs=2500, validate=0.25\n",
"):\n",
"    \"\"\"Draws the regularization example\"\"\"\n",
" plt.figure(figsize=(16, 8))\n",
" plt.subplot(121)\n",
" plt.scatter(\n",
" X[:, 2].tolist(),\n",
" X[:, 1].tolist(),\n",
" c=Y.tolist(),\n",
" s=100,\n",
"        cmap=\"prism\",\n",
" )\n",
"\n",
" theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1], 1)\n",
" thetaBest, err = SGD(\n",
" h,\n",
" J,\n",
" dJ,\n",
" theta,\n",
" X,\n",
" Y,\n",
" alpha=alpha,\n",
" adaGrad=adaGrad,\n",
" maxEpochs=maxEpochs,\n",
" batchSize=100,\n",
" logError=True,\n",
" validate=validate,\n",
" valStep=1,\n",
" lamb=lamb,\n",
" )\n",
"\n",
" xx, yy = np.meshgrid(np.arange(-1.5, 1.5, 0.02), np.arange(-1.5, 1.5, 0.02))\n",
" l = len(xx.ravel())\n",
" C = powerme(xx.reshape(l, 1), yy.reshape(l, 1), n)\n",
" z = classifyBi(thetaBest, C).reshape(int(np.sqrt(l)), int(np.sqrt(l)))\n",
"\n",
" plt.contour(xx, yy, z, levels=[0.5], linewidths=3)\n",
" plt.ylim(-1, 1.2)\n",
" plt.xlim(-1, 1.2)\n",
" plt.subplot(122)\n",
" plt.plot(err[0], err[1], lw=3, label=\"Training error\")\n",
" if validate > 0:\n",
" plt.plot(err[2], err[3], lw=3, label=\"Validation error\")\n",
" plt.legend()\n",
" plt.ylim(0.2, 0.8)\n"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_74/2678993393.py:5: RuntimeWarning: overflow encountered in exp\n",
" y = 1.0 / (1.0 + np.exp(-x))\n",
"/tmp/ipykernel_74/2651435526.py:38: UserWarning: The following kwargs were not used by contour: 'lw'\n",
" plt.contour(xx, yy, z, levels=[0.5], lw=3)\n",
"No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABSAAAAKZCAYAAACod4UiAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3iTVRvH8W+SbkZBRlmVPZW9BAQRkYKI4GSICDIE2TgAmTIdiExF2SgIoqK8gCCCAxBBQBBkb5A9Cy1dSd4/HikUWpqkSdPx+1xXLpo85zznTtOWJ3fOObfJbrfbEREREREREREREfEAs7cDEBERERERERERkYxLCUgRERERERERERHxGCUgRURERERERERExGOUgBQRERERERERERGPUQJSREREREREREREPEYJSBEREREREREREfEYJSBFRERERERERETEY5SAFBEREREREREREY9RAlJEREREREREREQ8RglIERERERERERER8RglIEVEREREgKlTp1KkSBECAgKoWbMmmzdvvmf7CRMmULp0aQIDAwkNDaVv375ERUWlUrQiIiIi6YcSkCIiIiKS6S1atIh+/foxbNgwtm3bRsWKFQkLC+PcuXOJtl+wYAEDBgxg2LBh7Nmzh5kzZ7Jo0SLefvvtVI5cREREJO0z2e12u7eDEBERERHxppo1a1K9enWmTJkCgM1mIzQ0lJ49ezJgwIC72vfo0YM9e/awZs2a+Mdef/11Nm3axPr161MtbhEREZH0wMfbAXiDzWbj1KlTZMuWDZPJ5O1wRERERJxmt9u5du0aBQoUwGzWopaUiImJYevWrQwcODD+MbPZTMOGDdm4cWOifWrXrs0XX3zB5s2bqVGjBocPH2bFihW89NJLibaPjo4mOjo6/r7NZuPSpUvkypVL16MiIiKSLjlzPZopE5CnTp0iNDTU22GIiIiIpNiJEycoVKiQt8NI1y5cuIDVaiUkJCTB4yEhIezduzfRPm3atOHChQs8/PDD2O124uLi6Nq1a5JLsMeOHcs777zj9thFREREvM2R69FMmYDMli0bYHyDsmfP7uVoRERERJwXHh5OaGho/HWNpK5ffvmFMWPG8PHHH1OzZk0OHjxI7969GTlyJEOGDLmr/cCBA+nXr1/8/atXr3L//ffrelRERETSLWeuRzNlAvLmMpfs2bPrgk9ERETSNS3fTbncuXNjsVg4e/ZsgsfPnj1Lvnz5Eu0zZMgQXnrpJTp16gRA+fLliYiIoEuXLgwaNOiuZUj+/v74+/vfdR5dj4qIiEh658j1qDYMEhEREZFMzc/Pj6pVqyYoKGOz2VizZg21atVKtE9kZORdSUaLxQIY+yGJiIiIyC2ZcgakiIiIiMjt+vXrx8svv0y1atWoUaMGEyZMICIigg4dOgDQrl07ChYsyNixYwFo1qwZ48ePp3LlyvFLsIcMGUKzZs3iE5EiIiIiYlACUkREREQyvZYtW3L+/HmGDh3KmTNnqFSpEitXrowvTHP8+PEEMx4HDx6MyWRi8ODB/Pvvv+TJk4dmzZoxevRobz0FERERkTTLZM+Ea0TCw8MJDg7m6tWr2nNHREREvMJmsxETE5PkcV9f33vOpNP1TPqm109ERFLKarUSGxvr7TAkA3Pn9ahmQIqIiIikspiYGI4cOYLNZrtnuxw5cpAvXz4VmhEREZF4drudM2fOcOXKFW+HIpmAu65HlYAUERERSUV2u53Tp09jsVgIDQ29q5DJzTaRkZGcO3cOgPz586d2mCIiIpJG3Uw+5s2bl6CgIH1QKR7h7utRJSBFREREUlFcXByRkZEUKFCAoKCgJNsFBgYCcO7cOfLmzavCJiIiIoLVao1PPubKlcvb4UgG587r0bs/chcRERERj7FarQD4+fkl2/ZmglL7O4mIiAjcuia414eYIu7krutRJSBFREREvMCR5VJaUiUiIiKJ0TWCpBZ3/awpASkiIiIiIiIiIi
IeowSkiIiIiIiIiIikO0WKFGHChAkOt//ll18wmUyqIO4FSkCKiIiIiIiIiIjHmEyme96GDx/u0nn//PNPunTp4nD72rVrc/r0aYKDg10aT1ynKtgiIiIiIiIiIuIxp0+fjv960aJFDB06lH379sU/ljVr1viv7XY7VqsVH5/kU1Z58uRxKg4/Pz/y5cvnVJ/UEhMTc1eRQqvVislkwmx2bv6gq/08Ke1EIiIiIpKJ2O32ZNvYbLZUiERERETSK5vNzsXr0V672WzJX88A5MuXL/4WHByMyWSKv793716yZcvGDz/8QNWqVfH392f9+vUcOnSI5s2bExISQtasWalevTo//fRTgvPeuQTbZDIxY8YMnn76aYKCgihZsiRLly6NP37nEuw5c+aQI0cOVq1aRdmyZcmaNSuNGzdOkDCNi4ujV69e5MiRg1y5ctG/f39efvllWrRocc/nvH79eurWrUtgYCChoaH06tWLiIiIBLGPHDmSdu3akT17drp06RIfz9KlSylXrhz+/v4cP36cy5cv065dO3LmzElQUBBNmjThwIED8edKql9aohmQIiIiIqnI19cXk8nE+fPnyZMnT6KVBe12OzExMZw/fx6z2XzXp+EiIiIiAJcjY6g66qfkG3rI1sENyZXV3y3nGjBgAOPGjaNYsWLkzJmTEydO8MQTTzB69Gj8/f2ZN28ezZo1Y9++fdx///1Jnuedd97h/fff54MPPmDy5Mm8+OKLHDt2jPvuuy/R9pGRkYwbN47PP/8cs9lM27ZteeONN5g/fz4A7733HvPnz2f27NmULVuWiRMn8t133/Hoo48mGcOhQ4do3Lgxo0aNYtasWZw/f54ePXrQo0cPZs+eHd9u3LhxDB06lGHDhgGwbt06IiMjee+995gxYwa5cuUib968tG7dmgMHDrB06VKyZ89O//79eeKJJ9i9eze+vr7xz+POfmmJEpAikvbs3QuffAK//grXrkH27PDoo9C1K5Qq5e3oRERSxGKxUKhQIU6ePMnRo0fv2TYoKIj7778/TS2fEREREfGEESNG8Pjjj8ffv++++6hYsWL8/ZEjR7JkyRKWLl1Kjx49kjxP+/btad26NQBjxoxh0qRJbN68mcaNGyfaPjY2lmnTplG8eHEAevTowYgRI+KPT548mYEDB/L0008DMGXKFFasWHHP5zJ27FhefPFF+vTpA0DJkiWZNGkSjzzyCJ988gkBAQEANGjQgNdffz2+37p164iNjeXjjz+Of+43E48bNmygdu3aAMyfP5/Q0FC+++47nn/++fjncXu/tEYJSBFJO06dgnbtYM0a8PGBuLhbx3buhI8+grAwmDsXQkK8F6eISAplzZqVkiVLEhsbm2Qbi8WCj49PojMkRURERDKaatWqJbh//fp1hg8fzvLlyzl9+jRxcXHcuHEj2aXFFSpUiP86S5YsZM+enXPnziXZPigoKD75CJA/f/749levXuXs2bPUqFEj/rjFYqFq1ar33Cpnx44d/P333/GzKMFY4WKz2Thy5Ahly5ZN9DmDsU/l7c9hz549+Pj4ULNmzfjHcuXKRenSpdmzZ0+S/dIaJSBFJG04dgxq1YLz5437tycfAaxW4981a6B6ddi4EQoWTN0YRUTcyGKxYLFYvB2GiIiISJqQJUuWBPffeOMNVq9ezbhx4yhRogSBgYE899xzxMTE3PM8N5ck32Qyme6ZLEysvSN7dd/L9evXefXVV+nVq9ddx25fPn7ncwYIDAx06QNoV/ulFiUgRcT7YmONmY3nz9+deLxTXBycPg1NmsBff4HevIuIiIiISCaVM8iPrYMbenV8T9mwYQPt27ePX/p8/fr1ZLevcbfg4GBCQkL4888/qVevHmBUmN62bRuVKlVKsl+VKlXYvXs3JUqUSHEMZcuWJS4ujk2bNsUvwb548SL79u2jXLlyKT5/alECUkS8b+lS2LfP8fZxccaS7B9+gCef9FxcIiIiIiIiaZjZbH
JbEZi0pmTJknz77bc0a9YMk8nEkCFD7jmT0VN69uzJ2LFjKVGiBGXKlGHy5Mlcvnz5nrMN+/fvz0MPPUSPHj3o1Kk
"text/plain": [
"<Figure size 1600x800 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"draw_regularization_example(X, Y)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Regularization\n",
"\n",
"Regularization is a way of preventing overfitting by suitably modifying the cost function.\n",
"\n",
"A special term is added to the cost function (the **regularization term**, shown in red in the formulas below). It acts as a “penalty” for extreme values of the parameters $\\theta$.\n",
"\n",
"As a result, vectors $\\theta$ with smaller parameter values are preferred, because they automatically have a lower cost.\n",
"\n",
"How strong should the regularization be? We control this by choosing an appropriate value of the **regularization parameter** $\\lambda$."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The regularization method presented here is the so-called L2 method (*ridge*). There are also other regularization methods with somewhat different properties, e.g. L1 (*lasso*) or *elastic net*. More on this topic can be found, for example, here:\n",
"* [L1 and L2 Regularization Methods](https://towardsdatascience.com/l1-and-l2-regularization-methods-ce25e7fc831c)\n",
"* [Ridge and Lasso Regression: L1 and L2 Regularization](https://towardsdatascience.com/ridge-and-lasso-regression-a-complete-guide-with-python-scikit-learn-e20e34bcbf0b)\n",
"* [Elastic Net Regression](https://towardsdatascience.com/elastic-net-regression-from-sklearn-to-tensorflow-3b48eee45e91)"
]
},
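To make the difference between these penalties concrete, the sketch below computes the L2 (ridge), L1 (lasso) and elastic-net terms for a parameter vector. Like the formulas in this lecture, it skips the bias $\theta_0$. The mixing weight `l1_ratio` borrows its name from scikit-learn's `ElasticNet`, but the scaling used here is a simplified assumption and not that library's exact objective:

```python
import numpy as np


def l2_penalty(theta, lamb):
    """Ridge: lamb * sum of squared weights (bias theta[0] excluded)."""
    return lamb * np.sum(theta[1:] ** 2)


def l1_penalty(theta, lamb):
    """Lasso: lamb * sum of absolute weights (bias theta[0] excluded)."""
    return lamb * np.sum(np.abs(theta[1:]))


def elastic_net_penalty(theta, lamb, l1_ratio=0.5):
    """Convex combination of the L1 and L2 penalties (simplified scaling)."""
    return l1_ratio * l1_penalty(theta, lamb) + (1 - l1_ratio) * l2_penalty(theta, lamb)


theta = np.array([10.0, 3.0, -0.5, 0.0])
print(l2_penalty(theta, 1.0))           # 9.25
print(l1_penalty(theta, 1.0))           # 3.5
print(elastic_net_penalty(theta, 1.0))  # 6.375
```

The L2 term grows quadratically, so it punishes large weights hardest. The L1 term grows linearly and, when optimized, tends to drive small weights exactly to zero.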
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for linear regression: cost function\n",
"\n",
"$$\n",
"J(\\theta) \\, = \\, \\dfrac{1}{2m} \\left( \\displaystyle\\sum_{i=1}^{m} \\left( h_\\theta(x^{(i)}) - y^{(i)} \\right)^2 \\color{red}{ + \\lambda \\displaystyle\\sum_{j=1}^{n} \\theta^2_j } \\right)\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"* $\\lambda$ is the regularization parameter\n",
"* if $\\lambda$ is too small, the model overfits\n",
"* if $\\lambda$ is too large, the model underfits"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for linear regression: gradient\n",
"\n",
"$$\\small\n",
"\\begin{array}{llll}\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_0} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_0 & \\textrm{for $j = 0$ }\\\\\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_j} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_j \\color{red}{+ \\dfrac{\\lambda}{m}\\theta_j} & \\textrm{for $j = 1, 2, \\ldots, n $} \\\\\n",
"\\end{array} \n",
"$$"
]
},
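The gradient can be checked numerically against the cost function. The sketch below makes the same assumptions as before (plain arrays, a column of ones in `X`). It implements both formulas and compares the analytic gradient with a central finite-difference approximation. All helper names are illustrative:

```python
import numpy as np


def linear_cost_reg(theta, X, y, lamb=0.0):
    """Regularized least-squares cost; theta_0 is not penalized."""
    m = len(y)
    r = X @ theta - y
    return (np.sum(r ** 2) + lamb * np.sum(theta[1:] ** 2)) / (2 * m)


def linear_grad_reg(theta, X, y, lamb=0.0):
    """Analytic gradient: (1/m) X^T (X theta - y), plus (lamb/m) theta_j for j >= 1."""
    m = len(y)
    g = X.T @ (X @ theta - y) / m
    g[1:] += lamb / m * theta[1:]
    return g


def numerical_grad(f, theta, h=1e-6):
    """Central finite differences, one coordinate at a time."""
    g = np.zeros_like(theta)
    for j in range(len(theta)):
        e = np.zeros_like(theta)
        e[j] = h
        g[j] = (f(theta + e) - f(theta - e)) / (2 * h)
    return g


rng = np.random.default_rng(0)
X = np.column_stack([np.ones(20), rng.normal(size=(20, 3))])
y = rng.normal(size=20)
theta = rng.normal(size=4)
analytic = linear_grad_reg(theta, X, y, lamb=0.7)
numeric = numerical_grad(lambda t: linear_cost_reg(t, X, y, lamb=0.7), theta)
print(np.max(np.abs(analytic - numeric)))  # close to 0 (floating-point rounding only)
```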
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for logistic regression: cost function\n",
"\n",
"$$\n",
"\\begin{array}{rcl}\n",
"J(\\theta) & = & -\\dfrac{1}{m} \\left( \\displaystyle\\sum_{i=1}^{m} y^{(i)} \\log h_\\theta(x^{(i)}) + \\left( 1-y^{(i)} \\right) \\log \\left( 1-h_\\theta(x^{(i)}) \\right) \\right) \\\\\n",
"& & \\color{red}{ + \\dfrac{\\lambda}{2m} \\displaystyle\\sum_{j=1}^{n} \\theta^2_j } \\\\\n",
"\\end{array}\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for logistic regression: gradient\n",
"\n",
"$$\\small\n",
"\\begin{array}{llll}\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_0} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_0 & \\textrm{for $j = 0$ }\\\\\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_j} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_j \\color{red}{+ \\dfrac{\\lambda}{m}\\theta_j} & \\textrm{for $j = 1, 2, \\ldots, n $} \\\\\n",
"\\end{array} \n",
"$$"
]
},
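To see the effect of the red term, the sketch below runs plain batch gradient descent on the regularized logistic cost for two values of $\lambda$. It uses synthetic, linearly separable data, and the data, step size and iteration count are arbitrary choices made for illustration. Without regularization the weights keep growing, because a separable training set can always be fitted better with a steeper sigmoid. With a large $\lambda$ they stay small:

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def logistic_grad_reg(theta, X, y, lamb):
    """Gradient of the regularized logistic cost; theta_0 is not penalized."""
    m = len(y)
    g = X.T @ (sigmoid(X @ theta) - y) / m
    g[1:] += lamb / m * theta[1:]
    return g


def fit(X, y, lamb, alpha=0.5, iters=2000):
    """Batch gradient descent starting from theta = 0."""
    theta = np.zeros(X.shape[1])
    for _ in range(iters):
        theta -= alpha * logistic_grad_reg(theta, X, y, lamb)
    return theta


rng = np.random.default_rng(1)
features = rng.normal(size=(100, 2))
y = (features[:, 0] + features[:, 1] > 0).astype(float)  # separable by construction
X = np.column_stack([np.ones(100), features])

theta_plain = fit(X, y, lamb=0.0)
theta_reg = fit(X, y, lamb=100.0)
print(np.linalg.norm(theta_plain[1:]), np.linalg.norm(theta_reg[1:]))
```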
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Implementing regularization"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"def J_(h, theta, X, y, lamb=0):\n",
" \"\"\"Cost function with regularization\"\"\"\n",
" m = float(len(y))\n",
" f = h(theta, X, eps=10**-7)\n",
" j = 1.0 / m * -np.sum(\n",
" np.multiply(y, np.log(f)) + np.multiply(1 - y, np.log(1 - f)), axis=0\n",
" ) + lamb / (2 * m) * np.sum(np.power(theta[1:], 2))\n",
" return j\n",
"\n",
"\n",
"def dJ_(h, theta, X, y, lamb=0):\n",
" \"\"\"Gradient of the cost function with regularization\"\"\"\n",
" m = float(y.shape[0])\n",
" g = 1.0 / m * (X.T * (h(theta, X) - y))\n",
" g[1:] += lamb / m * theta[1:]  # the bias theta_0 is not regularized\n",
" return g\n"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"slider_lambda = widgets.FloatSlider(\n",
" min=0.0, max=0.5, step=0.005, value=0.01, description=r\"$\\lambda$\",\n",
" layout=widgets.Layout(width=\"300px\"),\n",
")\n",
"\n",
"\n",
"def slide_regularization_example_2(lamb):\n",
" draw_regularization_example(X, Y, lamb=lamb)\n"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"scrolled": false,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "546e94aba9124fcf8da68fc67ab9cdd0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=0.01, description='$\\\\lambda$', max=0.5, step=0.005), Button(descripti…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.slide_regularization_example_2(lamb)>"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact_manual(slide_regularization_example_2, lamb=slider_lambda)\n"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def cost_lambda_fun(lamb):\n",
" \"\"\"Cost as a function of the regularization parameter lambda\"\"\"\n",
" theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1], 1)\n",
" thetaBest, err = SGD(\n",
" h,\n",
" J,\n",
" dJ,\n",
" theta,\n",
" X,\n",
" Y,\n",
" alpha=1,\n",
" adaGrad=True,\n",
" maxEpochs=2500,\n",
" batchSize=100,\n",
" logError=True,\n",
" validate=0.25,\n",
" valStep=1,\n",
" lamb=lamb,\n",
" )\n",
" return err[1][-1], err[3][-1]\n",
"\n",
"\n",
"def plot_cost_lambda():\n",
" \"\"\"Plot of the cost as a function of the regularization parameter lambda\"\"\"\n",
" plt.figure(figsize=(16, 8))\n",
" ax = plt.subplot(111)\n",
" Lambda = np.arange(0.0, 1.0, 0.01)\n",
" Costs = [cost_lambda_fun(lamb) for lamb in Lambda]\n",
" CostTrain = [cost[0] for cost in Costs]\n",
" CostCV = [cost[1] for cost in Costs]\n",
" plt.plot(Lambda, CostTrain, lw=3, label=\"training error\")\n",
" plt.plot(Lambda, CostCV, lw=3, label=\"validation error\")\n",
" ax.set_xlabel(r\"$\\lambda$\")\n",
" ax.set_ylabel(\"cost\")\n",
" plt.legend()\n",
" plt.ylim(0.2, 0.8)\n"
]
},
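`plot_cost_lambda` sweeps $\lambda$ and compares the training and validation cost. The same selection procedure can be sketched for regularized linear regression. There, setting the gradient of the cost to zero gives a closed form, $\theta = (X^T X + \lambda L)^{-1} X^T y$, where $L$ is the identity matrix except that $L_{00} = 0$ (so $\theta_0$ is not penalized). The synthetic data, polynomial degree and grid of $\lambda$ values below are assumptions chosen for illustration:

```python
import numpy as np


def ridge_fit(X, y, lamb):
    """Closed-form minimizer of the regularized linear cost (theta_0 unpenalized)."""
    L = np.eye(X.shape[1])
    L[0, 0] = 0.0
    return np.linalg.solve(X.T @ X + lamb * L, X.T @ y)


def mse_cost(theta, X, y):
    """Unregularized cost (1/2m) * sum of squared residuals, used for evaluation."""
    return np.sum((X @ theta - y) ** 2) / (2 * len(y))


rng = np.random.default_rng(2)
x = rng.uniform(-1, 1, size=40)
y = np.sin(3 * x) + rng.normal(scale=0.2, size=40)
X = np.vander(x, 10, increasing=True)  # columns 1, x, ..., x^9
X_train, y_train, X_val, y_val = X[:30], y[:30], X[30:], y[30:]

lambdas = [0.0, 1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10.0]
train_costs, val_costs = [], []
for lamb in lambdas:
    theta = ridge_fit(X_train, y_train, lamb)
    train_costs.append(mse_cost(theta, X_train, y_train))
    val_costs.append(mse_cost(theta, X_val, y_val))

best = lambdas[int(np.argmin(val_costs))]
print("lambda with the lowest validation cost:", best)
```

The training cost can only grow as $\lambda$ increases. The validation cost typically falls at first and then rises again, which is why $\lambda$ is chosen on the validation set.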
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABR8AAAKtCAYAAACuZBksAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB760lEQVR4nOzdeZibZaH//0+SmSSzZvalM9NOaUtpoRvdLIogllMOWkVFKiDbAVzhgCPnCz0IFVArX5Fv1aIosojCjyoK4qEHhAoIpVLa0tJCF7p32tmXZNZkJsnvjydrO11meSazvF/XlSvJM8mTe7QR+va+n9sSDAaDAgAAAAAAAIABZk30AAAAAAAAAACMTMRHAAAAAAAAAKYgPgIAAAAAAAAwBfERAAAAAAAAgCmIjwAAAAAAAABMQXwEAAAAAAAAYAriIwAAAAAAAABTEB8BAAAAAAAAmIL4CAAAAAAAAMAUxEcAAAAAAAAApkh4fHzooYdUXl4up9Op+fPna/369Sd8/YoVKzR58mSlpKSorKxM3/nOd9TZ2TlIowUAAAAAAABwqhIaH1etWqWKigotW7ZMmzZt0owZM7Ro0SLV1tb2+Pqnn35ad9xxh5YtW6bt27fr0Ucf1apVq/Tf//3fgzxyAAAAAAAAACdjCQaDwUR9+Pz58zV37lytXLlSkhQIBFRWVqabb75Zd9xxxzGvv+mmm7R9+3atWbMmcuy73/2u3nnnHb311luDNm4AAAAAAAAAJ5eUqA/2+XzauHGjli5dGjlmtVq1cOFCrVu3rsf3nHPOOfrDH/6g9evXa968edq7d69Wr16tq6666rif4/V65fV6I88DgYAaGxuVm5sri8UycL8QAAAAAAAAMAoEg0G1tLRozJgxslpPvLA6YfGxvr5efr9fhYWFcccLCwu1Y8eOHt9zxRVXqL6+Xp/4xCcUDAbV3d2tb3zjGydcdr18+XLdc889Azp2AAAAAAAAYLQ7dOiQSktLT/iahMXHvnj99df1ox/9SL/85S81f/587d69W7fccovuu+8+3XXXXT2+Z+nSpaqoqIg8d7vdGjt2rA4dOqTMzMzBGjoAAAAAAAAwIng8HpWVlSkjI+Okr01YfMzLy5PNZlNNTU3c8ZqaGhUVFfX4nrvuuktXXXWVbrjhBknStGnT1NbWpq997Wu68847e5zm6XA45HA4jjmemZlJfAQAAAAAAAD66FQuaZiw3a7tdrtmz54dt3lMIBDQmjVrtGDBgh7f097efkxgtNlskoy15gAAAAAAAACGjoQuu66oqNA111yjOXPmaN68eVqxYoXa2tp03XXXSZKuvvpqlZSUaPny5ZKkxYsX68EHH9SsWbMiy67vuusuLV68OBIhAQAAAAAAAAwNCY2PS5YsUV1dne6++25VV1dr5syZeumllyKb0Bw8eDBupuP3vvc9WSwWfe9739Phw4eVn5+vxYsX64c//GGifgUAAAAAAAAAx2EJjrL1yh6PRy6XS263m2s+AgAAAACAYSUYDKq7u1t+vz/RQ8EIl5ycfNyVxr3pa8Nqt2sAAAAAAIDRyufzqaqqSu3t7YkeCkYBi8Wi0tJSpaen9+s8xEcAAAAAAIAhLhAIaN++fbLZbBozZozsdvsp7TQM9EUwGFRdXZ0qKys1adKkfu21QnwEAAAAAAAY4nw+nwKBgMrKypSampro4WAUyM/P1/79+9XV1dWv+Gg9+UsAAAAAAAAwFMRuzAuYaaBm1vInFgAAAAAAAIApiI8AAAAAAAAATEF8BAAAAAAAwLBRXl6uFStWnPLrX3/9dVksFjU3N5s2JhwfG84AAAAAAADANOeff75mzpzZq2B4Iu+++67S0tJO+fXnnHOOqqqq5HK5BuTz0TvERwAAAAAAgGEmEAiqqd2X0DFkp9pltQ7MpiTBYFB+v19JSSdPVfn5+b06t91uV1FRUV+HZiqfzye73R53zO/3y2Kx9Hpzob6+z2zERw
AAAAAAgGGmqd2n2T94NaFj2Pi9hcpNd5zwNddee63eeOMNvfHGG/rZz34mSdq3b5/279+vT33qU1q9erW+973vaevWrfr73/+usrIyVVRU6F//+pfa2to0ZcoULV++XAsXLoycs7y8XLfeeqtuvfVWScauzI888ohefPFFvfzyyyopKdFPf/pTfe5zn5NkLLv+1Kc+paamJmVlZemJJ57QrbfeqlWrVunWW2/VoUOH9IlPfEKPP/64iouLJUnd3d2qqKjQk08+KZvNphtuuEHV1dVyu916/vnnj/v7vvXWW1q6dKk2bNigvLw8feELX9Dy5csjMzXLy8t1/fXX66OPPtLzzz+vL37xizr//PN166236sknn9Qdd9yhXbt2affu3XK5XLrlllv0t7/9TV6vV+edd55+/vOfa9KkSZIU+T2Ofl95eXlf/us0zdBKoQAAAAAAABgxfvazn2nBggW68cYbVVVVpaqqKpWVlUV+fscdd+jHP/6xtm/frunTp6u1tVUXX3yx1qxZo/fee08XXXSRFi9erIMHD57wc+655x5ddtllev/993XxxRfryiuvVGNj43Ff397ergceeEC///3v9c9//lMHDx7UbbfdFvn5/fffr6eeekqPP/641q5dK4/Hc8LoKEl79uzRRRddpC996Ut6//33tWrVKr311lu66aab4l73wAMPaMaMGXrvvfd01113RcZz//3367e//a0++OADFRQU6Nprr9WGDRv0wgsvaN26dQoGg7r44ovV1dUV93sc/b6hhpmPAAAAAAAAMIXL5ZLdbldqamqPS5/vvfdeXXjhhZHnOTk5mjFjRuT5fffdp+eee04vvPDCMREv1rXXXqvLL79ckvSjH/1IP//5z7V+/XpddNFFPb6+q6tLDz/8sCZMmCBJuummm3TvvfdGfv6LX/xCS5cu1Re+8AVJ0sqVK7V69eoT/q7Lly/XlVdeGZmROWnSJP385z/Xeeedp1/96ldyOp2SpAsuuEDf/e53I+9788031dXVpV/+8peR3/2jjz7SCy+8oLVr1+qcc86RJD311FMqKyvT888/ry9/+cuR3yP2fUMR8REAAAAAAAAJMWfOnLjnra2t+v73v68XX3xRVVVV6u7uVkdHx0lnPk6fPj3yOC0tTZmZmaqtrT3u61NTUyPhUZKKi4sjr3e73aqpqdG8efMiP7fZbJo9e7YCgcBxz7llyxa9//77euqppyLHgsGgAoGA9u3bpylTpvT4O0vGdSljf4ft27crKSlJ8+fPjxzLzc3V5MmTtX379uO+bygiPgIAAAAAAAwz2al2bfzewpO/0OQx9NfRu1bfdttteuWVV/TAAw9o4sSJSklJ0aWXXiqf78Sb6yQnJ8c9t1gsJwyFPb0+GAz2cvTxWltb9fWvf13/+Z//eczPxo4dG3nc007dKSkpslh6v3lPX983mIiPAAAAAAAAw4zVajnpZi9Dhd1ul9/vP6XXrl27Vtdee21kuXNra6v2799v4uiO5XK5VFhYqHfffVef/OQnJRk7SW/atEkzZ8487vvOPvtsffjhh5o4cWK/xzBlyhR1d3frnXfeiSy7bmho0M6dOzV16tR+n38wseEMAAAAAAAATFNeXq533nlH+/fvV319/QlnJE6aNEl/+ctftHnzZm3ZskVXXHHFCV9vlptvvlnLly/XX//6V+3cuVO33HKLmpqaTjjL8Pbbb9fbb7+tm266SZs3b9ZHH32kv/71rye8VuXxTJo0SZ///Od144036q233tKWLVv01a9+VSUlJfr85z/fn19t0BEfAQAAAAAAYJrbbrtNNptNU6dOVX5+/gmv3/jggw8qOztb55xzjhYvXqxFixbp7LPPHsTRGm6//XZdfvnluvrqq7VgwQKlp6dr0aJFkU1jejJ9+nS98cYb2rVrl84991zNmjVLd999t8aMGdOnMTz++OOaPXu2PvvZz2rBggUKBoNavXr1MUvGhzpLsL8L2ocZj8
cjl8slt9utzMzMRA8HAAAAAADgpDo7O7Vv3z6NHz/+hAEM5ggEApoyZYouu+wy3XfffYkezqA40Z+53vQ1rvkIAAA
"text/plain": [
"<Figure size 1600x800 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_cost_lambda()\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 6.4. Learning curve"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"* A learning curve lets us check whether training is proceeding correctly.\n",
"* A learning curve is a plot of the cost function value against the size of the training set.\n",
"* As the training set grows, the cost on the training set usually increases.\n",
"* As the training set grows, the cost on the validation set usually decreases."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def cost_trainsetsize_fun(m):\n",
" \"\"\"Cost as a function of the training set size\"\"\"\n",
" theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1], 1)\n",
" thetaBest, err = SGD(\n",
" h,\n",
" J,\n",
" dJ,\n",
" theta,\n",
" X,\n",
" Y,\n",
" alpha=1,\n",
" adaGrad=True,\n",
" maxEpochs=2500,\n",
" batchSize=100,\n",
" logError=True,\n",
" validate=0.25,\n",
" valStep=1,\n",
" lamb=0.01,\n",
" trainsetsize=m,\n",
" )\n",
" return err[1][-1], err[3][-1]\n",
"\n",
"\n",
"def plot_learning_curve():\n",
" \"\"\"Plot of the learning curve\"\"\"\n",
" plt.figure(figsize=(16, 8))\n",
" ax = plt.subplot(111)\n",
" M = np.arange(0.3, 1.0, 0.05)\n",
" Costs = [cost_trainsetsize_fun(m) for m in M]\n",
" CostTrain = [cost[0] for cost in Costs]\n",
" CostCV = [cost[1] for cost in Costs]\n",
" plt.plot(M, CostTrain, lw=3, label=\"training error\")\n",
" plt.plot(M, CostCV, lw=3, label=\"validation error\")\n",
" ax.set_xlabel(\"trainset size\")\n",
" ax.set_ylabel(\"cost\")\n",
" plt.legend()\n"
]
},
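The procedure behind `plot_learning_curve` can be sketched on synthetic data: fit the model on a growing prefix of the training data, then evaluate it on that prefix and on a fixed validation set. This sketch uses ordinary least squares on a linear model instead of the notebook's `SGD`, and the data-generating line and noise level are assumptions for illustration. As the bullets on learning curves say, the training cost tends to rise and the validation cost to fall. Both should settle near the noise floor $\sigma^2/2 = 0.125$:

```python
import numpy as np


def make_data(rng, m):
    """Synthetic data from the line y = 1 + 2x plus Gaussian noise (sigma = 0.5)."""
    x = rng.uniform(-2, 2, size=m)
    y = 1.0 + 2.0 * x + rng.normal(scale=0.5, size=m)
    return np.column_stack([np.ones(m), x]), y


def fit_linear(X, y):
    """Ordinary least squares via lstsq (used here instead of the notebook's SGD)."""
    theta, *_ = np.linalg.lstsq(X, y, rcond=None)
    return theta


def cost(theta, X, y):
    """Unregularized cost (1/2m) * sum of squared residuals."""
    return np.sum((X @ theta - y) ** 2) / (2 * len(y))


rng = np.random.default_rng(3)
X_pool, y_pool = make_data(rng, 200)  # training examples are taken from here
X_val, y_val = make_data(rng, 1000)   # fixed validation set

sizes = [3, 5, 10, 20, 50, 100, 200]
train_costs, val_costs = [], []
for m in sizes:
    theta = fit_linear(X_pool[:m], y_pool[:m])
    train_costs.append(cost(theta, X_pool[:m], y_pool[:m]))
    val_costs.append(cost(theta, X_val, y_val))

for m, tr, va in zip(sizes, train_costs, val_costs):
    print(f"m={m:4d}  train={tr:.3f}  validation={va:.3f}")
```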
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Learning curves and bias vs. variance\n",
"\n",
"Plotting a learning curve helps diagnose overfitting and underfitting:\n",
"\n",
"<img width=\"100%\" src=\"learning-curves.png\"/>\n",
"\n",
"Source: http://www.ritchieng.com/machinelearning-learning-curve"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABR8AAAKnCAYAAAAP/zpKAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAACHqUlEQVR4nOzdeXxcdb3/8ffMJJnsabMvDXShdKE0LQVK6wUBixW8vaA/hQtcWS7i9Qq49KJQxVZA6eWCiFoQZRWFK4hcwEsvWKooS2VrUyhd6EqX7GmzJzPJzPn9Mclktuwzc2Z5PR+PeSTne86Z+aQMbfLO5/v9WgzDMAQAAAAAAAAAYWY1uwAAAAAAAAAAiYnwEQAAAAAAAEBEED4CAAAAAAAAiAjCRwAAAAAAAAARQfgIAAAAAAAAICIIHwEAAAAAAABEBOEjAAAAAAAAgIggfAQAAAAAAAAQESlmFxBtbrdbNTU1ysnJkcViMbscAAAAAAAAIK4YhqH29naVl5fLah2+tzHpwseamhpVVlaaXQYAAAAAAAAQ1w4dOqQpU6YMe03ShY85OTmSPH84ubm5JlcDAAAAAAAAxJe2tjZVVlZ6c7bhJF34ODDVOjc3l/ARAAAAAAAAGKfRLGnIhjMAAAAAAAAAIoLwEQAAAAAAAEBEED4CAAAAAAAAiIikW/MRAAAAAAAgXhmGob6+PrlcLrNLQYJLTU2VzWab8PMQPgIAAAAAAMQBp9Op2tpadXV1mV0KkoDFYtGUKVOUnZ09oechfAQAAAAAAIhxbrdb+/fvl81mU3l5udLS0ka10zAwHoZhqLGxUYcPH9bMmTMn1AFJ+AgAAAAAABDjnE6n3G63KisrlZmZaXY5SAJFRUU6cOCAent7JxQ+suEMAAAAAABAnLBaiXIQHeHqrOUdCwAAAAAAACAiCB8BAAAAAAAQN6ZOnap777131Ne/+uqrslgsamlpiVhNGBprPgIAAAAAACBizj77bC1YsGBMgeFw3nnnHWVlZY36+qVLl6q2tlZ5eXlheX2MDZ2PAAAAAAAAMJVhGOrr6xvVtUVFRWPadCctLU2lpaUxuTu40+kMGnO5XHK73WN+rvHeF2mEjwAAAAAAAHHG7TbU3OEw9eF2GyPWedVVV+mvf/2rfvrTn8pischisejAgQPeqdD/93//p0WLFslut+v111/X3r17deGFF6qkpETZ2dk67bTT9Morr/g9Z+C0a4vFooceekif+9znlJmZqZkzZ+qFF17wng+cdv3YY49p0qRJevnllzVnzhxlZ2frM5/5jGpra7339PX16etf/7omTZqkgoIC3XTTTbryyit10UUXDfv1vv766zrzzDOVkZGhyspKff3rX1dnZ6df7bfffruuuOIK5ebm6itf+Yq3nhdeeEFz586V3W7XwYMHdezYMV1xxRWaPHmyMjMzdf7552v37t3e5xrqvljDtGsAAAAAAIA4c6zLqUU/fGXkCyPovVuWqSDbPuw1P/3pT/XRRx9p3rx5uu222yR5OhcPHDggSbr55pt19913a/r06Zo8ebIOHTqkCy64QD/60Y9kt9v1+OOPa8WKFdq1a5eOO+64IV/n1ltv1X/913/prrvu0s9//nNdfvnl+vjjj5Wfnx/y+q6uLt199936zW9+I6vVqn/5l3/RjTfeqCeeeEKSdOedd+qJJ57Qo48+qjlz5uinP/2pnnvuOZ1zzjlD1rB371595jOf0Q9/+EM98sgjamxs1PXXX6/rr79ejz76qPe6u+++W6tXr9aaNWskSa+99pq6urp055136qGHHlJBQYGKi4t16aWXavfu3XrhhReUm5urm266SRdccIG2b9+u1NRU79cReF+sIXwEAAAAAABAROTl5SktLU2ZmZkqLS0NOn/bbbfpvPPO8x7n5+erqqrKe3z77bfrf/7nf/TCCy/o+uuvH/J1rrrqKl166aWSpDvuuEM/+9nP9Pbbb+szn/lMyO
t7e3v1wAMPaMaMGZKk66+/3huOStLPf/5zrVq1Sp/73OckSevWrdP69euH/VrXrl2ryy+/XN/85jclSTNnztTPfvYzffKTn9QvfvELpaenS5LOPfdc/cd//If3vtdee029vb26//77vV/7QOj4xhtvaOnSpZKkJ554QpWVlXruuef0xS9+0ft1+N4XiwgfAQAAAAAAYIpTTz3V77ijo0M/+MEP9OKLL6q2tlZ9fX3q7u4ecTrx/PnzvZ9nZWUpNzdXDQ0NQ16fmZnpDR4lqayszHt9a2ur6uvrdfrpp3vP22w2LVq0aNg1Fbdu3ar333/f2z0pedaydLvd2r9/v+bMmRPya5Y861L6fg07duxQSkqKFi9e7B0rKCjQrFmztGPHjiHvi0WEjwAAAAAAADBF4K7VN954ozZs2KC7775bJ5xwgjIyMvSFL3wh5MYsvgamIQ+wWCzDBoWhrjeMkdewHE5HR4f+7d/+TV//+teDzvlOGQ+1U3dGRsa4NsQZ733RRPgIAAAAAAAQZyZnpum9W5aZXsNopKWlyeVyjeraN954Q1dddZV3unNHR4d3fchoycvLU0lJid555x2dddZZkjw7SW/evFkLFiwY8r5TTjlF27dv1wknnDDhGubMmaO+vj699dZb3mnXzc3N2rVrl+bOnTvh548mwkcAAAAAAIA4Y7VaRtzsJVZMnTpVb731lg4cOKDs7OwhN4GRPOskPvvss1qxYoUsFou+//3vD9vBGCk33HCD1q5dqxNOOEGzZ8/Wz3/+cx07dmzYLsObbrpJZ5xxhq6//np9+ctfVlZWlrZv364NGzZo3bp1Y3r9mTNn6sILL9S1116rX/7yl8rJydHNN9+siooKXXjhhRP98qLKanYBAAAAAAAASFw33nijbDab5s6dq6KiomHXb7znnns0efJkLV26VCtWrNDy5ct1yimnRLFaj5tuukmXXnqprrjiCi1ZskTZ2dlavny5d9OYUObPn6+//vWv+uijj3TmmWdq4cKFWr16tcrLy8dVw6OPPqpFixbpH//xH7VkyRIZhqH169cHTRmPdRZjohPa40xbW5vy8vLU2tqq3Nxcs8sBAAAAAAAYUU9Pj/bv369p06YNG4AhMtxut+bMmaOLL75Yt99+u9nlRMVw77mx5GtMuwYAAAAAAAB8fPzxx/rTn/6kT37yk3I4HFq3bp3279+vyy67zOzS4g7hYyJxu6TGXVLNZunIZmne/5OmfsLsqgAAAAAAAOKK1WrVY489phtvvFGGYWjevHl65ZVXNGfOHLNLizuEj4nkyUukPRsGjzMmEz4CAAAAAACMUWVlpd544w2zy0gIbDiTSEpO8j+u2WJOHQAAAAAAAIAIHxNL+UL/45rNUnLtJwQAAAAAAIAYQviYSCoCtp7vPia1fGxOLQAAAAAAAEh6hI+JJK9SyizwHzuy2ZxaAAAAAAAAkPQIHxOJxSKVB3Q/1hA+AgAAAAAAwByEj4kmaN3HalPKAAAAAAAAAAgfE03guo811ZLbbUopAAAAAAAA4TB16lTde++93mOLxaLnnntuyOsPHDggi8Wi6urqCb1uuJ4nmaWYXQDCLLDz0dkuNe+Rik40px4AAAAAAIAwq62t1eTJk8P6nFdddZVaWlr8Qs3KykrV1taqsLAwrK+VTOh8TDQ5pVJOuf8Y6z4CAAAAAIAEUlpaKrvdHvHXsdlsKi0tVUpK7PXv9fb2Bo05nc5xPdd47xsNwsdEFDT1eos5dQAAAAAAgMhwu6XOJnMfo1jm7Ve/+pXKy8vlDrj2wgsv1L/+679Kkvbu3asLL7xQJSUlys7O1mmnnaZXXnll2OcNnHb99ttva+HChUpPT9epp56qLVv8sxCXy6VrrrlG06ZNU0ZGhmbNmqWf/vSn3vM/+MEP9Otf/1rPP/+8LBaLLBaLXn311ZDTrv/617/q9NNPl91uV1lZmW6++Wb19fV5z5999tn6+te/ru985zvKz89XaWmpfv
CDH4z4Z/XQQw9pzpw5Sk9P1+zZs3X//fd7zw3U8dRTT+mTn/yk0tPT9cQTT+iqq67SRRddpB/96EcqLy/XrFmzJEk
"text/plain": [
"<Figure size 1600x800 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_learning_curve()\n"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"livereveal": {
"start_slideshow_at": "selected",
"theme": "white"
}
},
"nbformat": 4,
"nbformat_minor": 4
}