2022-11-24 07:22:33 +01:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Machine learning\n",
"# 6. The overfitting problem"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 6.1. Polynomial regression"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Introduction: feature selection"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Suppose our task is to predict the price of a rectangular plot of land.\n",
"\n",
"Which features should we choose?"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"We can choose two features:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * $x_1$ is the plot width and $x_2$ is the plot length:\n",
"$$ h_{\\theta}(\\vec{x}) = \\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"...or just one:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * $x_1$ is the plot area:\n",
"$$ h_{\\theta}(\\vec{x}) = \\theta_0 + \\theta_1 x_1 $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Note that the feature “plot area” is the product of two other features: the length of the plot and its width."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"**Conclusion:** we can create new features from existing ones by applying various mathematical operations to them."
]
},
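{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A minimal sketch of building such a derived feature. The plot dimensions and the names `width`, `length`, `area` and `X_area` are illustrative assumptions; they do not come from the course data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical plot dimensions in metres (illustrative values only)\n",
"width = np.array([20.0, 25.0, 30.0])\n",
"length = np.array([30.0, 40.0, 50.0])\n",
"\n",
"# Derived feature: the area is the product of the two original features\n",
"area = width * length\n",
"\n",
"# Design matrix with a bias column and the single derived feature\n",
"X_area = np.column_stack((np.ones_like(area), area))\n",
"print(X_area)\n"
]
},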
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Polynomial regression"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"In polynomial regression we use features created as powers of the original features."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Useful imports\n",
"\n",
"import ipywidgets as widgets\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Useful functions\n",
"\n",
"\n",
"def cost(theta, X, y):\n",
"    \"\"\"Cost function (matrix version)\"\"\"\n",
"    m = len(y)\n",
"    J = 1.0 / (2.0 * m) * ((X * theta - y).T * (X * theta - y))\n",
"    return J.item()\n",
"\n",
"\n",
"def gradient(theta, X, y):\n",
"    \"\"\"Gradient of the cost function (matrix version)\"\"\"\n",
"    return 1.0 / len(y) * (X.T * (X * theta - y))\n",
"\n",
"\n",
"def gradient_descent(fJ, fdJ, theta, X, y, alpha=0.1, eps=10**-5):\n",
"    \"\"\"Gradient descent algorithm (matrix version)\"\"\"\n",
"    current_cost = fJ(theta, X, y)\n",
"    logs = [[current_cost, theta]]\n",
"    while True:\n",
"        theta = theta - alpha * fdJ(theta, X, y)\n",
"        current_cost, prev_cost = fJ(theta, X, y), current_cost\n",
"        # Stop if the cost changes by an absurdly large amount (divergence)\n",
"        if abs(prev_cost - current_cost) > 10**15:\n",
"            print(\"Algorithm does not converge!\")\n",
"            break\n",
"        # Stop once the improvement falls below the tolerance\n",
"        if abs(prev_cost - current_cost) <= eps:\n",
"            break\n",
"        logs.append([current_cost, theta])\n",
"    return theta, logs\n",
"\n",
"\n",
"def plot_data(X, y, xlabel, ylabel):\n",
"    \"\"\"Plot the data (matrix version)\"\"\"\n",
"    fig = plt.figure(figsize=(16 * 0.6, 9 * 0.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"    ax.scatter([X[:, 1]], [y], c=\"r\", s=50, label=\"Data\")\n",
"\n",
"    ax.set_xlabel(xlabel)\n",
"    ax.set_ylabel(ylabel)\n",
"    ax.margins(0.05, 0.05)\n",
"    plt.ylim(y.min() - 1, y.max() + 1)\n",
"    plt.xlim(np.min(X[:, 1]) - 1, np.max(X[:, 1]) + 1)\n",
"    return fig\n",
"\n",
"\n",
"def plot_fun(fig, fun, X):\n",
"    \"\"\"Plot the function `fun`\"\"\"\n",
"    ax = fig.axes[0]\n",
"    x0 = np.min(X[:, 1]) - 1.0\n",
"    x1 = np.max(X[:, 1]) + 1.0\n",
"    Arg = np.arange(x0, x1, 0.1)\n",
"    Val = fun(Arg)\n",
"    return ax.plot(Arg, Val, linewidth=2)\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Load the data (flats) using the pandas library\n",
"\n",
"alldata = pandas.read_csv(\n",
"    \"data_flats.tsv\", header=0, sep=\"\\t\", usecols=[\"price\", \"rooms\", \"sqrMetres\"]\n",
")\n",
"data = np.matrix(alldata[[\"sqrMetres\", \"price\"]])\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"# Scale the feature and its powers by their maximum values\n",
"Xn = data[:, 0:n]\n",
"Xn /= np.amax(Xn, axis=0)\n",
"Xn2 = np.power(Xn, 2)\n",
"Xn2 /= np.amax(Xn2, axis=0)\n",
"Xn3 = np.power(Xn, 3)\n",
"Xn3 /= np.amax(Xn3, axis=0)\n",
"\n",
"# Design matrices for the linear, quadratic and cubic models\n",
"X = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n + 1)\n",
"X2 = np.matrix(np.concatenate((np.ones((m, 1)), Xn, Xn2), axis=1)).reshape(m, 2 * n + 1)\n",
"X3 = np.matrix(np.concatenate((np.ones((m, 1)), Xn, Xn2, Xn3), axis=1)).reshape(\n",
"    m, 3 * n + 1\n",
")\n",
"y = np.matrix(data[:, -1]).reshape(m, 1)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"General form of polynomial regression:\n",
"\n",
"$$ h_{\\theta}(x) = \\sum_{i=0}^{n} \\theta_i x^i $$"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [],
"source": [
"# Polynomial regression function\n",
"\n",
"\n",
"def h_poly(Theta, x):\n",
"    \"\"\"Polynomial function\"\"\"\n",
"    return sum(theta * np.power(x, i) for i, theta in enumerate(Theta.tolist()))\n",
"\n",
"\n",
"def polynomial_regression(theta):\n",
"    \"\"\"Polynomial regression function\"\"\"\n",
"    return lambda x: h_poly(theta, x)\n"
]
},
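{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"As a quick sanity check of the general form $\\sum_i \\theta_i x^i$, the sketch below evaluates the sum directly and compares it with `np.polyval`, which expects coefficients ordered from the highest power down. The coefficients and the names `coeffs_demo`, `x_eval`, `values_demo` and `values_polyval` are assumptions chosen for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Illustrative coefficients theta_0, theta_1, theta_2 (not fitted to any data)\n",
"coeffs_demo = np.array([1.0, -2.0, 3.0])\n",
"x_eval = np.array([0.0, 1.0, 2.0])\n",
"\n",
"# Direct evaluation of sum_i theta_i * x^i\n",
"values_demo = sum(c * x_eval**i for i, c in enumerate(coeffs_demo))\n",
"\n",
"# np.polyval takes the highest-degree coefficient first, hence the reversal\n",
"values_polyval = np.polyval(coeffs_demo[::-1], x_eval)\n",
"\n",
"print(values_demo)\n",
"print(values_polyval)\n"
]
},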
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The simplest case of polynomial regression is the quadratic function:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Quadratic function:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 $$"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f4283c91b10>]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X2, y, xlabel=\"x\", ylabel=\"y\")\n",
"theta_start = np.matrix([0, 0, 0]).reshape(3, 1)\n",
"theta, logs = gradient_descent(cost, gradient, theta_start, X2, y)\n",
"plot_fun(fig, polynomial_regression(theta), X)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Another special case of polynomial regression is the cubic function:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Cubic function:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 + \\theta_3 x^3 $$"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 397519.38046962]\n",
" [-841341.14146733]\n",
" [2253713.97125102]\n",
" [-244009.07081946]]\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X3, y, xlabel=\"x\", ylabel=\"y\")\n",
"theta_start = np.matrix([0, 0, 0, 0]).reshape(4, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X3, y)\n",
"plot_fun(fig, polynomial_regression(theta), X)\n",
"\n",
"print(theta)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Polynomial regression can be treated as a special case of multivariate linear regression:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 + \\theta_3 x^3 $$\n",
"$$ x_1 = x, \\quad x_2 = x^2, \\quad x_3 = x^3, \\quad \\vec{x} = \\left[ \\begin{array}{c} x_0 \\\\ x_1 \\\\ x_2 \\\\ x_3 \\end{array} \\right] $$"
]
},
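{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A minimal sketch of this idea, assuming synthetic data generated from a known cubic: once the powers of $x$ are placed in separate columns, an ordinary linear least-squares solver recovers the coefficients. It uses `np.linalg.lstsq` instead of the gradient descent used elsewhere in this notebook, and the names `x_demo`, `y_demo`, `X_demo` and `theta_demo` are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Synthetic, noise-free data from a known cubic (an assumption for illustration)\n",
"x_demo = np.linspace(-1.0, 1.0, 20)\n",
"y_demo = 1.0 + 2.0 * x_demo - 3.0 * x_demo**2 + 0.5 * x_demo**3\n",
"\n",
"# Columns x^0, x^1, x^2, x^3: each power of x acts as a separate feature\n",
"X_demo = np.column_stack([x_demo**i for i in range(4)])\n",
"\n",
"# Ordinary linear least squares on the expanded feature matrix\n",
"theta_demo, *_ = np.linalg.lstsq(X_demo, y_demo, rcond=None)\n",
"print(np.round(theta_demo, 3))\n"
]
},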
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"(In this case the successive features are the successive powers of the variable $x$.)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Practical note: feature normalization is useful here, scaling in particular!"
]
},
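{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The sketch below illustrates why scaling matters: the powers of an unscaled feature differ by orders of magnitude, while dividing each column by its maximum (the same kind of scaling applied to the flat data earlier in this notebook) brings all of them into $(0, 1]$. The areas and the names `area_demo`, `raw_demo` and `scaled_demo` are illustrative assumptions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical flat areas in square metres (illustrative values only)\n",
"area_demo = np.array([30.0, 50.0, 80.0, 120.0])\n",
"\n",
"# Raw polynomial features: x, x^2 and x^3 differ by orders of magnitude\n",
"raw_demo = np.column_stack((area_demo, area_demo**2, area_demo**3))\n",
"print(raw_demo.max(axis=0))\n",
"\n",
"# Dividing each column by its maximum puts every feature into (0, 1]\n",
"scaled_demo = raw_demo / raw_demo.max(axis=0)\n",
"print(scaled_demo.max(axis=0))\n"
]
},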
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"To create “derived” features we can use not only powers but also other mathematical operations, e.g.:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 \\sqrt{x} $$\n",
"$$ x_1 = x, \\quad x_2 = \\sqrt{x}, \\quad \\vec{x} = \\left[ \\begin{array}{c} x_0 \\\\ x_1 \\\\ x_2 \\end{array} \\right] $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"So which features should we choose? Ideally, ones tailored to the specific problem."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Polynomial logistic regression\n",
"\n",
"Similar feature transformations can also be applied to logistic regression."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def powerme(x1, x2, n):\n",
"    \"\"\"Generate the powers of x1 and x2 and of their products, up to degree n\"\"\"\n",
"    X = []\n",
"    for m in range(n + 1):\n",
"        for i in range(m + 1):\n",
"            X.append(np.multiply(np.power(x1, i), np.power(x2, (m - i))))\n",
"    return np.hstack(X)\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [
{
"data": {
"text/plain": [
"matrix([[ 1. , 0.36596696, -0.11214686],\n",
" [ 0. , 0.4945305 , 0.47110656],\n",
" [ 0. , 0.70290604, -0.92257983],\n",
" [ 0. , 0.46658862, -0.62269739],\n",
" [ 0. , 0.87939462, -0.11408015],\n",
" [ 0. , -0.331185 , 0.84447667],\n",
" [ 0. , -0.54351701, 0.8851383 ],\n",
" [ 0. , 0.91979241, 0.41607012],\n",
" [ 0. , 0.28011742, 0.61431157],\n",
" [ 0. , 0.94754363, -0.78307311]])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the data\n",
"import pandas\n",
"import numpy as np\n",
"\n",
"alldata = pandas.read_csv(\"polynomial_logistic.tsv\", sep=\"\\t\")\n",
"data = np.matrix(alldata)\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data[:, 1:]\n",
"\n",
"Xpl = powerme(data[:, 1], data[:, 2], n)\n",
"Ypl = np.matrix(data[:, 0]).reshape(m, 1)\n",
"\n",
"data[:10]\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def plot_data_for_classification(X, Y, xlabel, ylabel):\n",
"    \"\"\"Plot the data (matrix version)\"\"\"\n",
"    fig = plt.figure(figsize=(16 * 0.6, 9 * 0.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"    X = X.tolist()\n",
"    Y = Y.tolist()\n",
"    # Split the points by class label (0: negative, 1: positive)\n",
"    X1n = [x[1] for x, y in zip(X, Y) if y[0] == 0]\n",
"    X1p = [x[1] for x, y in zip(X, Y) if y[0] == 1]\n",
"    X2n = [x[2] for x, y in zip(X, Y) if y[0] == 0]\n",
"    X2p = [x[2] for x, y in zip(X, Y) if y[0] == 1]\n",
"    ax.scatter(X1n, X2n, c=\"r\", marker=\"x\", s=50, label=\"Data\")\n",
"    ax.scatter(X1p, X2p, c=\"g\", marker=\"o\", s=50, label=\"Data\")\n",
"\n",
"    ax.set_xlabel(xlabel)\n",
"    ax.set_ylabel(ylabel)\n",
"    ax.margins(0.05, 0.05)\n",
"    return fig\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Suppose we have the following data and want to perform binary classification into the following classes:\n",
" * red crosses\n",
" * green circles"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r\"$x_1$\", ylabel=r\"$x_2$\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Proposed hypothesis:\n",
"\n",
"$$ h_\\theta(x) = g(\\theta^T x) = g(\\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 + \\theta_3 x_3 + \\theta_4 x_4 + \\theta_5 x_5) \\; , $$\n",
"\n",
"where $g$ is the logistic function, $x_3 = x_1^2$, $x_4 = x_2^2$, $x_5 = x_1 x_2$."
]
},
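{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A minimal sketch of this hypothesis with hand-picked parameters (an assumption, not fitted values): choosing $\\theta = (0.25, 0, 0, -1, -1, 0)$ gives $\\theta^T x = 0.25 - x_1^2 - x_2^2$, so the decision boundary $h_\\theta(x) = 0.5$ is the circle of radius $0.5$ centred at the origin. The helper names `sigmoid_demo`, `quadratic_features_demo` and `theta_circle` are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def sigmoid_demo(z):\n",
"    \"\"\"Logistic function g(z) = 1 / (1 + exp(-z))\"\"\"\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"\n",
"def quadratic_features_demo(x1, x2):\n",
"    \"\"\"Return the columns [1, x1, x2, x1^2, x2^2, x1*x2]\"\"\"\n",
"    return np.column_stack((np.ones_like(x1), x1, x2, x1**2, x2**2, x1 * x2))\n",
"\n",
"\n",
"# Hand-picked parameters: theta^T x = 0.25 - x1^2 - x2^2 (illustrative only)\n",
"theta_circle = np.array([0.25, 0.0, 0.0, -1.0, -1.0, 0.0])\n",
"\n",
"# Three points: inside the circle, on the circle and outside it\n",
"points_x1 = np.array([0.0, 0.5, 1.0])\n",
"points_x2 = np.array([0.0, 0.0, 0.0])\n",
"\n",
"probs_demo = sigmoid_demo(quadratic_features_demo(points_x1, points_x2) @ theta_circle)\n",
"print(probs_demo)\n"
]
},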
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def safeSigmoid(x, eps=0):\n",
"    \"\"\"Sigmoid function modified so that its values always stay\n",
"    at least eps away from the asymptotes\n",
"    \"\"\"\n",
"    y = 1.0 / (1.0 + np.exp(-x))\n",
"    if eps > 0:\n",
"        y[y < eps] = eps\n",
"        y[y > 1 - eps] = 1 - eps\n",
"    return y\n",
"\n",
"\n",
"def h(theta, X, eps=0.0):\n",
"    \"\"\"Hypothesis function\"\"\"\n",
"    return safeSigmoid(X * theta, eps)\n",
"\n",
"\n",
"def J(h, theta, X, y, lamb=0):\n",
"    \"\"\"Cost function\"\"\"\n",
"    m = len(y)\n",
"    f = h(theta, X, eps=10**-7)\n",
"    j = (\n",
"        -np.sum(np.multiply(y, np.log(f)) + np.multiply(1 - y, np.log(1 - f)), axis=0)\n",
"        / m\n",
"    )\n",
"    if lamb > 0:\n",
"        j += lamb / (2 * m) * np.sum(np.power(theta[1:], 2))\n",
"    return j\n",
"\n",
"\n",
"def dJ(h, theta, X, y, lamb=0):\n",
"    \"\"\"Gradient of the cost function\"\"\"\n",
"    g = 1.0 / y.shape[0] * (X.T * (h(theta, X) - y))\n",
"    if lamb > 0:\n",
"        g[1:] += lamb / float(y.shape[0]) * theta[1:]\n",
"    return g\n",
"\n",
"\n",
"def classifyBi(theta, X):\n",
"    \"\"\"Decision function (returns the predicted probability)\"\"\"\n",
"    prob = h(theta, X)\n",
"    return prob\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def GD(h, fJ, fdJ, theta, X, y, alpha=0.01, eps=10**-3, maxSteps=10000):\n",
"    \"\"\"Gradient descent for logistic regression\"\"\"\n",
"    errorCurr = fJ(h, theta, X, y)\n",
"    errors = [[errorCurr, theta]]\n",
"    while True:\n",
"        # compute the new theta\n",
"        theta = theta - alpha * fdJ(h, theta, X, y)\n",
"        # record the error level\n",
"        errorCurr, errorPrev = fJ(h, theta, X, y), errorCurr\n",
"        # stopping criteria\n",
"        if abs(errorPrev - errorCurr) <= eps:\n",
"            break\n",
"        if len(errors) > maxSteps:\n",
"            break\n",
"        errors.append([errorCurr, theta])\n",
"    return theta, errors\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"theta = [[ 1.59558981]\n",
" [ 0.12602307]\n",
" [ 0.65718518]\n",
" [-5.26367581]\n",
" [ 1.96832544]\n",
" [-6.97946065]]\n"
]
}
],
"source": [
"# Run gradient descent for logistic regression\n",
"theta_start = np.matrix(np.zeros(Xpl.shape[1])).reshape(Xpl.shape[1], 1)\n",
"theta, errors = GD(\n",
"    h, J, dJ, theta_start, Xpl, Ypl, alpha=0.1, eps=10**-7, maxSteps=10000\n",
")\n",
"print(r\"theta = {}\".format(theta))\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def plot_decision_boundary(fig, theta, X):\n",
"    \"\"\"Plot the decision boundary between the classes\"\"\"\n",
"    ax = fig.axes[0]\n",
"    xx, yy = np.meshgrid(np.arange(-1.0, 1.0, 0.02), np.arange(-1.0, 1.0, 0.02))\n",
"    l = len(xx.ravel())\n",
"    C = powerme(xx.reshape(l, 1), yy.reshape(l, 1), n)\n",
"    z = classifyBi(theta, C).reshape(int(np.sqrt(l)), int(np.sqrt(l)))\n",
"\n",
"    # `linewidths` is the contour keyword for line width (`lw` is ignored with a warning)\n",
"    plt.contour(xx, yy, z, levels=[0.5], linewidths=3)\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r\"$x_1$\", ylabel=r\"$x_2$\")\n",
"plot_decision_boundary(fig, theta, Xpl)\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Load the data\n",
"\n",
"alldata = pandas.read_csv(\"polynomial_logistic.tsv\", sep=\"\\t\")\n",
"data = np.matrix(alldata)\n",
"\n",
"m, n_plus_1 = data.shape\n",
"Xn = data[:, 1:]\n",
"\n",
"n = 10\n",
"Xpl = powerme(data[:, 1], data[:, 2], n)\n",
"Ypl = np.matrix(data[:, 0]).reshape(m, 1)\n",
"\n",
"theta_start = np.matrix(np.zeros(Xpl.shape[1])).reshape(Xpl.shape[1], 1)\n",
"theta, errors = GD(\n",
" h, J, dJ, theta_start, Xpl, Ypl, alpha=0.1, eps=10**-7, maxSteps=10000\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_3088/1169766636.py:9: UserWarning: The following kwargs were not used by contour: 'lw'\n",
"  plt.contour(xx, yy, z, levels=[0.5], lw=3)\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Example with a larger number of features\n",
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r\"$x_1$\", ylabel=r\"$x_2$\")\n",
"plot_decision_boundary(fig, theta, Xpl)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 6.2. The overfitting problem"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Bias versus variance"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Data for a simple example\n",
"\n",
"data = np.matrix(\n",
" [\n",
" [0.0, 0.0],\n",
" [0.5, 1.8],\n",
" [1.0, 4.8],\n",
" [1.6, 7.2],\n",
" [2.6, 8.8],\n",
" [3.0, 9.0],\n",
" ]\n",
")\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn1 = data[:, 0:n]\n",
"Xn1 /= np.amax(Xn1, axis=0)\n",
"Xn2 = np.power(Xn1, 2)\n",
"Xn2 /= np.amax(Xn2, axis=0)\n",
"Xn3 = np.power(Xn1, 3)\n",
"Xn3 /= np.amax(Xn3, axis=0)\n",
"Xn4 = np.power(Xn1, 4)\n",
"Xn4 /= np.amax(Xn4, axis=0)\n",
"Xn5 = np.power(Xn1, 5)\n",
"Xn5 /= np.amax(Xn5, axis=0)\n",
"\n",
"X1 = np.matrix(np.concatenate((np.ones((m, 1)), Xn1), axis=1)).reshape(m, n + 1)\n",
"X2 = np.matrix(np.concatenate((np.ones((m, 1)), Xn1, Xn2), axis=1)).reshape(\n",
" m, 2 * n + 1\n",
")\n",
"X5 = np.matrix(\n",
" np.concatenate((np.ones((m, 1)), Xn1, Xn2, Xn3, Xn4, Xn5), axis=1)\n",
").reshape(m, 5 * n + 1)\n",
"y = np.matrix(data[:, -1]).reshape(m, 1)\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X1, y, xlabel=\"x\", ylabel=\"y\")\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f428364a290>]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X1, y, xlabel=\"x\", ylabel=\"y\")\n",
"theta_start = np.matrix([0, 0]).reshape(2, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X1, y, eps=0.00001)\n",
"plot_fun(fig, polynomial_regression(theta), X1)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"This model has high **bias** (systematic error): it **underfits** the data."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f42157242e0>]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X2, y, xlabel=\"x\", ylabel=\"y\")\n",
"theta_start = np.matrix([0, 0, 0]).reshape(3, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X2, y, eps=0.000001)\n",
"plot_fun(fig, polynomial_regression(theta), X1)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"This model fits the data well."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f421483cd60>]"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X5, y, xlabel=\"x\", ylabel=\"y\")\n",
"theta_start = np.matrix([0, 0, 0, 0, 0, 0]).reshape(6, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X5, y, alpha=0.5, eps=10**-7)\n",
"plot_fun(fig, polynomial_regression(theta), X1)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"This model has high **variance**: it **overfits** the data."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"(Note the odd shape of the curve in the left part of the plot; it is, among other things, an effect of overfitting.)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Overfitting occurs when a model has too many degrees of freedom relative to the amount of training data.\n",
"\n",
"It is undesirable, because an overfitted model generalizes poorly to new data.\n",
"\n",
"Figuratively speaking, overfitting occurs when the model starts to model the noise in the data instead of the underlying trend."
]
},
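{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The effect can be checked numerically. The cell below is a minimal sketch and not part of the original lecture code: it fits polynomials to the six points of the simple example with `np.polyfit` (instead of this notebook's gradient descent) and compares the training error with a leave-one-out error. The names `x_demo`, `y_demo` and `loo_error`, and the choice of degrees 1, 2 and 4, are assumptions made for illustration; degree 4 is the highest degree that five training points determine uniquely.\n",
"\n",
"The training error cannot increase as the degree grows, whereas the leave-one-out error would be expected to rise again once the model becomes too flexible."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# The six points of the simple example (before feature scaling)\n",
"x_demo = np.array([0.0, 0.5, 1.0, 1.6, 2.6, 3.0])\n",
"y_demo = np.array([0.0, 1.8, 4.8, 7.2, 8.8, 9.0])\n",
"\n",
"\n",
"def loo_error(degree):\n",
"    # Hypothetical helper: mean squared leave-one-out error of a polynomial fit\n",
"    errors = []\n",
"    for i in range(len(x_demo)):\n",
"        keep = np.arange(len(x_demo)) != i  # hold out the i-th point\n",
"        coeffs = np.polyfit(x_demo[keep], y_demo[keep], degree)\n",
"        errors.append((np.polyval(coeffs, x_demo[i]) - y_demo[i]) ** 2)\n",
"    return float(np.mean(errors))\n",
"\n",
"\n",
"for degree in (1, 2, 4):\n",
"    coeffs_all = np.polyfit(x_demo, y_demo, degree)\n",
"    train_error = float(np.mean((np.polyval(coeffs_all, x_demo) - y_demo) ** 2))\n",
"    print(f'degree {degree}: training MSE {train_error:.4f}, leave-one-out MSE {loo_error(degree):.4f}')\n"
]
},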
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"See also (in Polish): https://pl.wikipedia.org/wiki/Nadmierne_dopasowanie"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"90%\" src=\"fit.png\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Bias (systematic error)\n",
"\n",
"* Results from erroneous assumptions in the learning algorithm.\n",
"* High bias causes underfitting."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Variance\n",
"\n",
"* Results from oversensitivity to small fluctuations in the training set.\n",
"* High variance can cause overfitting (the model fits noise instead of signal)."
]
},
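{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Sensitivity to small fluctuations can be illustrated with a small experiment. This is a hedged sketch rather than the lecture's own code: it adds small Gaussian noise to the targets of the six-point example and measures how far the fitted curve moves on average. The noise scale 0.1, the 200 trials, the names `x_var`, `y_var` and `rng_var`, and the use of `np.polyfit` are all assumptions chosen for illustration. A larger average shift means the fit reacts more strongly to the perturbation, that is, it has higher variance."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng_var = np.random.default_rng(0)  # fixed seed so the experiment is repeatable\n",
"x_var = np.array([0.0, 0.5, 1.0, 1.6, 2.6, 3.0])\n",
"y_var = np.array([0.0, 1.8, 4.8, 7.2, 8.8, 9.0])\n",
"grid = np.linspace(0.0, 3.0, 50)\n",
"\n",
"for degree in (1, 2, 5):\n",
"    # Curve fitted to the unperturbed targets\n",
"    base_curve = np.polyval(np.polyfit(x_var, y_var, degree), grid)\n",
"    shifts = []\n",
"    for _ in range(200):\n",
"        y_noisy = y_var + rng_var.normal(scale=0.1, size=y_var.shape)\n",
"        noisy_curve = np.polyval(np.polyfit(x_var, y_noisy, degree), grid)\n",
"        # Largest vertical distance between the two curves on the grid\n",
"        shifts.append(np.max(np.abs(noisy_curve - base_curve)))\n",
"    print(f'degree {degree}: mean maximal shift of the curve {np.mean(shifts):.4f}')\n"
]
},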
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"40%\" src=\"bias2.png\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"60%\" src=\"curves.jpg\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 6.3. Regularization"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def SGD(\n",
"    h,\n",
"    fJ,\n",
"    fdJ,\n",
"    theta,\n",
"    X,\n",
"    Y,\n",
"    alpha=0.001,\n",
"    maxEpochs=1.0,\n",
"    batchSize=100,\n",
"    adaGrad=False,\n",
"    logError=False,\n",
"    validate=0.0,\n",
"    valStep=100,\n",
"    lamb=0,\n",
"    trainsetsize=1.0,\n",
"):\n",
"    \"\"\"Stochastic Gradient Descent: the stochastic version of gradient descent\n",
"    (more on this in the next lecture)\n",
" \"\"\"\n",
" errorsX, errorsY = [], []\n",
" errorsVX, errorsVY = [], []\n",
"\n",
" XT, YT = X, Y\n",
"\n",
" m_end = int(trainsetsize * len(X))\n",
"\n",
" if validate > 0:\n",
" mv = int(X.shape[0] * validate)\n",
" XV, YV = X[:mv], Y[:mv]\n",
" XT, YT = X[mv:m_end], Y[mv:m_end]\n",
" m, n = XT.shape\n",
"\n",
" start, end = 0, batchSize\n",
" maxSteps = (m * float(maxEpochs)) / batchSize\n",
"\n",
" if adaGrad:\n",
" hgrad = np.matrix(np.zeros(n)).reshape(n, 1)\n",
"\n",
" for i in range(int(maxSteps)):\n",
" XBatch, YBatch = XT[start:end, :], YT[start:end, :]\n",
"\n",
" grad = fdJ(h, theta, XBatch, YBatch, lamb=lamb)\n",
" if adaGrad:\n",
" hgrad += np.multiply(grad, grad)\n",
" Gt = 1.0 / (10**-7 + np.sqrt(hgrad))\n",
" theta = theta - np.multiply(alpha * Gt, grad)\n",
" else:\n",
" theta = theta - alpha * grad\n",
"\n",
" if logError:\n",
" errorsX.append(float(i * batchSize) / m)\n",
" errorsY.append(fJ(h, theta, XBatch, YBatch).item())\n",
" if validate > 0 and i % valStep == 0:\n",
" errorsVX.append(float(i * batchSize) / m)\n",
" errorsVY.append(fJ(h, theta, XV, YV).item())\n",
"\n",
" if start + batchSize < m:\n",
" start += batchSize\n",
" else:\n",
" start = 0\n",
" end = min(start + batchSize, m)\n",
" return theta, (errorsX, errorsY, errorsVX, errorsVY)\n"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Prepare the data for the regularization example\n",
"\n",
"n = 6\n",
"\n",
"data = np.matrix(np.loadtxt(\"ex2data2.txt\", delimiter=\",\"))\n",
"np.random.shuffle(data)\n",
"\n",
"X = powerme(data[:, 0], data[:, 1], n)\n",
"Y = data[:, 2]\n"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def draw_regularization_example(\n",
"    X, Y, lamb=0, alpha=1, adaGrad=True, maxEpochs=2500, validate=0.25\n",
"):\n",
"    \"\"\"Draw the regularization example\"\"\"\n",
"    plt.figure(figsize=(16, 8))\n",
"    plt.subplot(121)\n",
"    plt.scatter(\n",
"        X[:, 2].tolist(),\n",
"        X[:, 1].tolist(),\n",
"        c=Y.tolist(),\n",
"        s=100,\n",
"        cmap=\"prism\",\n",
"    )\n",
"\n",
"    theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1], 1)\n",
"    thetaBest, err = SGD(\n",
"        h,\n",
"        J,\n",
"        dJ,\n",
"        theta,\n",
"        X,\n",
"        Y,\n",
"        alpha=alpha,\n",
"        adaGrad=adaGrad,\n",
"        maxEpochs=maxEpochs,\n",
"        batchSize=100,\n",
"        logError=True,\n",
"        validate=validate,\n",
"        valStep=1,\n",
"        lamb=lamb,\n",
"    )\n",
"\n",
"    xx, yy = np.meshgrid(np.arange(-1.5, 1.5, 0.02), np.arange(-1.5, 1.5, 0.02))\n",
"    l = len(xx.ravel())\n",
"    C = powerme(xx.reshape(l, 1), yy.reshape(l, 1), n)\n",
"    z = classifyBi(thetaBest, C).reshape(int(np.sqrt(l)), int(np.sqrt(l)))\n",
"\n",
"    # Decision boundary: contour at predicted probability 0.5\n",
"    plt.contour(xx, yy, z, levels=[0.5], linewidths=3)\n",
"    plt.ylim(-1, 1.2)\n",
"    plt.xlim(-1, 1.2)\n",
" plt.subplot(122)\n",
" plt.plot(err[0], err[1], lw=3, label=\"Training error\")\n",
" if validate > 0:\n",
" plt.plot(err[2], err[3], lw=3, label=\"Validation error\")\n",
" plt.legend()\n",
" plt.ylim(0.2, 0.8)\n"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_3088/2678993393.py:5: RuntimeWarning: overflow encountered in exp\n",
" y = 1.0 / (1.0 + np.exp(-x))\n",
"/tmp/ipykernel_3088/2651435526.py:38: UserWarning: The following kwargs were not used by contour: 'lw'\n",
" plt.contour(xx, yy, z, levels=[0.5], lw=3)\n",
"No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 1600x800 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"draw_regularization_example(X, Y)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Regularization\n",
"\n",
"Regularization is a method of preventing overfitting by suitably modifying the cost function.\n",
"\n",
"A special term (the **regularization term**, marked in red in the formulas below) is added to the cost function as a “penalty” for extreme values of the parameters $\\theta$.\n",
"\n",
"As a result, vectors $\\theta$ with smaller parameter values are preferred, because they automatically have a lower cost.\n",
"\n",
"How strong should the regularization be? We control this by choosing a suitable value of the **regularization parameter** $\\lambda$."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The regularization method presented here is known as L2 (*ridge*). Other regularization methods with somewhat different properties also exist, e.g. L1 (*lasso*) or *elastic net*. More on this topic can be found, for example, here:\n",
"* [L1 and L2 Regularization Methods](https://towardsdatascience.com/l1-and-l2-regularization-methods-ce25e7fc831c)\n",
"* [Ridge and Lasso Regression: L1 and L2 Regularization](https://towardsdatascience.com/ridge-and-lasso-regression-a-complete-guide-with-python-scikit-learn-e20e34bcbf0b)\n",
"* [Elastic Net Regression](https://towardsdatascience.com/elastic-net-regression-from-sklearn-to-tensorflow-3b48eee45e91)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for linear regression: cost function\n",
"\n",
"$$\n",
"J(\\theta) \\, = \\, \\dfrac{1}{2m} \\left( \\displaystyle\\sum_{i=1}^{m} \\left( h_\\theta(x^{(i)}) - y^{(i)} \\right)^2 \\color{red}{ + \\lambda \\displaystyle\\sum_{j=1}^{n} \\theta^2_j } \\right)\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"* $\\lambda$: the regularization parameter\n",
"* if $\\lambda$ is too small, the model overfits\n",
"* if $\\lambda$ is too large, the model underfits"
]
},
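{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"As a hedged illustration (the numbers below are made up for this note, not taken from the dataset), assume $m = 10$, $\\lambda = 0.1$ and $\\theta = (1, 2, -3)^T$, with $\\theta_0 = 1$ excluded from the penalty because the sum runs over $j = 1, \\ldots, n$. The regularization term then contributes\n",
"$$\n",
"\\dfrac{\\lambda}{2m} \\displaystyle\\sum_{j=1}^{n} \\theta_j^2 \\, = \\, \\dfrac{0.1}{20} \\left( 2^2 + (-3)^2 \\right) \\, = \\, 0.005 \\cdot 13 \\, = \\, 0.065\n",
"$$\n",
"to the cost. Doubling every $\\theta_j$ quadruples this term to $0.26$, which is why large parameter values are discouraged."
]
},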
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularyzacja dla regresji liniowej – gradient\n",
"\n",
"$$\\small\n",
"\\begin{array}{llll}\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_0} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_0 & \\textrm{dla $j = 0$ }\\\\\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_j} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_j \\color{red}{+ \\dfrac{\\lambda}{m}\\theta_j} & \\textrm{dla $j = 1, 2, \\ldots, n $} \\\\\n",
"\\end{array} \n",
"$$"
]
},
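{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A short derivation sketch of the gradient above, assuming the squared-error cost $J(\\theta) = \\frac{1}{2m} \\left( \\sum_{i=1}^{m} ( h_\\theta(x^{(i)}) - y^{(i)} )^2 + \\lambda \\sum_{k=1}^{n} \\theta_k^2 \\right)$ and the linear hypothesis $h_\\theta(x) = \\sum_{k=0}^{n} \\theta_k x_k$ with $x_0 = 1$. For $j \\geq 1$:\n",
"$$\n",
"\\begin{array}{rcl}\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_j} & = & \\dfrac{1}{2m} \\left( \\displaystyle\\sum_{i=1}^{m} 2 \\left( h_\\theta(x^{(i)}) - y^{(i)} \\right) \\dfrac{\\partial h_\\theta(x^{(i)})}{\\partial \\theta_j} + 2 \\lambda \\theta_j \\right) \\\\\n",
"& = & \\dfrac{1}{m} \\displaystyle\\sum_{i=1}^{m} \\left( h_\\theta(x^{(i)}) - y^{(i)} \\right) x^{(i)}_j + \\dfrac{\\lambda}{m} \\theta_j \\\\\n",
"\\end{array}\n",
"$$\n",
"For $j = 0$ the penalty sum does not contain $\\theta_0$, so the second term vanishes."
]
},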
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularyzacja dla regresji logistycznej – funkcja kosztu\n",
"\n",
"$$\n",
"\\begin{array}{rtl}\n",
"J(\\theta) & = & -\\dfrac{1}{m} \\left( \\displaystyle\\sum_{i=1}^{m} y^{(i)} \\log h_\\theta(x^{(i)}) + \\left( 1-y^{(i)} \\right) \\log \\left( 1-h_\\theta(x^{(i)}) \\right) \\right) \\\\\n",
"& & \\color{red}{ + \\dfrac{\\lambda}{2m} \\displaystyle\\sum_{j=1}^{n} \\theta^2_j } \\\\\n",
"\\end{array}\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularyzacja dla regresji logistycznej – gradient\n",
"\n",
"$$\\small\n",
"\\begin{array}{llll}\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_0} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_0 & \\textrm{dla $j = 0$ }\\\\\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_j} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_j \\color{red}{+ \\dfrac{\\lambda}{m}\\theta_j} & \\textrm{dla $j = 1, 2, \\ldots, n $} \\\\\n",
"\\end{array} \n",
"$$"
]
},
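{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Substituting this gradient into the plain gradient-descent update $\\theta_j := \\theta_j - \\alpha \\frac{\\partial J(\\theta)}{\\partial \\theta_j}$ (that is, assuming no AdaGrad scaling of the step) gives, for $j \\geq 1$:\n",
"$$\n",
"\\theta_j := \\theta_j \\left( 1 - \\dfrac{\\alpha \\lambda}{m} \\right) - \\dfrac{\\alpha}{m} \\displaystyle\\sum_{i=1}^{m} \\left( h_\\theta(x^{(i)}) - y^{(i)} \\right) x^{(i)}_j\n",
"$$\n",
"Each step therefore first shrinks $\\theta_j$ by the factor $1 - \\alpha\\lambda/m$ and then applies the usual unregularized correction. This is a sketch of why L2 regularization is often described as weight decay."
]
},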
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Implementacja metody regularyzacji"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"def J_(h, theta, X, y, lamb=0):\n",
"    \"\"\"Cost function with regularization\"\"\"\n",
"    m = float(len(y))\n",
"    f = h(theta, X, eps=10**-7)\n",
"    j = 1.0 / m * -np.sum(\n",
"        np.multiply(y, np.log(f)) + np.multiply(1 - y, np.log(1 - f)), axis=0\n",
"    ) + lamb / (2 * m) * np.sum(np.power(theta[1:], 2))\n",
"    return j\n",
"\n",
"\n",
"def dJ_(h, theta, X, y, lamb=0):\n",
"    \"\"\"Gradient of the cost function with regularization\"\"\"\n",
"    m = float(y.shape[0])\n",
"    g = 1.0 / y.shape[0] * (X.T * (h(theta, X) - y))\n",
"    g[1:] += lamb / m * theta[1:]\n",
"    return g\n"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"slider_lambda = widgets.FloatSlider(\n",
"    min=0.0, max=0.5, step=0.005, value=0.01, description=r\"$\\lambda$\", width=300\n",
")\n",
"\n",
"\n",
"def slide_regularization_example_2(lamb):\n",
"    draw_regularization_example(X, Y, lamb=lamb)\n"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"scrolled": false,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3bab505ae15548059e1037eda3e1d9f8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=0.01, description='$\\\\lambda$', max=0.5, step=0.005), Button(descripti…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.slide_regularization_example_2(lamb)>"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact_manual(slide_regularization_example_2, lamb=slider_lambda)\n"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def cost_lambda_fun(lamb):\n",
"    \"\"\"Cost as a function of the regularization parameter lambda\"\"\"\n",
"    theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1], 1)\n",
"    thetaBest, err = SGD(\n",
"        h,\n",
"        J,\n",
"        dJ,\n",
"        theta,\n",
"        X,\n",
"        Y,\n",
"        alpha=1,\n",
"        adaGrad=True,\n",
"        maxEpochs=2500,\n",
"        batchSize=100,\n",
"        logError=True,\n",
"        validate=0.25,\n",
"        valStep=1,\n",
"        lamb=lamb,\n",
"    )\n",
"    # last logged training cost and last logged validation cost\n",
"    return err[1][-1], err[3][-1]\n",
"\n",
"\n",
"def plot_cost_lambda():\n",
"    \"\"\"Plot the cost as a function of the regularization parameter lambda\"\"\"\n",
"    plt.figure(figsize=(16, 8))\n",
"    ax = plt.subplot(111)\n",
"    Lambda = np.arange(0.0, 1.0, 0.01)\n",
"    Costs = [cost_lambda_fun(lamb) for lamb in Lambda]\n",
"    CostTrain = [cost[0] for cost in Costs]\n",
"    CostCV = [cost[1] for cost in Costs]\n",
"    plt.plot(Lambda, CostTrain, lw=3, label=\"training error\")\n",
"    plt.plot(Lambda, CostCV, lw=3, label=\"validation error\")\n",
"    ax.set_xlabel(r\"$\\lambda$\")\n",
"    ax.set_ylabel(\"cost\")\n",
"    plt.legend()\n",
"    plt.ylim(0.2, 0.8)\n"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1600x800 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_cost_lambda()\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 5.4. Krzywa uczenia się"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"* Krzywa uczenia pozwala sprawdzić, czy uczenie przebiega poprawnie.\n",
"* Krzywa uczenia to wykres zależności między wielkością zbioru treningowego a wartością funkcji kosztu.\n",
"* Wraz ze wzrostem wielkości zbioru treningowego wartość funkcji kosztu na zbiorze treningowym rośnie.\n",
"* Wraz ze wzrostem wielkości zbioru treningowego wartość funkcji kosztu na zbiorze walidacyjnym maleje."
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def cost_trainsetsize_fun(m):\n",
"    \"\"\"Cost as a function of the training set size\"\"\"\n",
"    theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1], 1)\n",
"    thetaBest, err = SGD(\n",
"        h,\n",
"        J,\n",
"        dJ,\n",
"        theta,\n",
"        X,\n",
"        Y,\n",
"        alpha=1,\n",
"        adaGrad=True,\n",
"        maxEpochs=2500,\n",
"        batchSize=100,\n",
"        logError=True,\n",
"        validate=0.25,\n",
"        valStep=1,\n",
"        lamb=0.01,\n",
"        trainsetsize=m,\n",
"    )\n",
"    # last logged training cost and last logged validation cost\n",
"    return err[1][-1], err[3][-1]\n",
"\n",
"\n",
"def plot_learning_curve():\n",
"    \"\"\"Plot the learning curve\"\"\"\n",
"    plt.figure(figsize=(16, 8))\n",
"    ax = plt.subplot(111)\n",
"    M = np.arange(0.3, 1.0, 0.05)\n",
"    Costs = [cost_trainsetsize_fun(m) for m in M]\n",
"    CostTrain = [cost[0] for cost in Costs]\n",
"    CostCV = [cost[1] for cost in Costs]\n",
"    plt.plot(M, CostTrain, lw=3, label=\"training error\")\n",
"    plt.plot(M, CostCV, lw=3, label=\"validation error\")\n",
"    ax.set_xlabel(\"trainset size\")\n",
"    ax.set_ylabel(\"cost\")\n",
"    plt.legend()\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Krzywa uczenia a obciążenie i wariancja\n",
"\n",
"Wykreślenie krzywej uczenia pomaga diagnozować nadmierne i niedostateczne dopasowanie:\n",
"\n",
"<img width=\"100%\" src=\"learning-curves.png\"/>\n",
"\n",
"Źródło: http://www.ritchieng.com/machinelearning-learning-curve"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1600x800 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_learning_curve()\n"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"livereveal": {
"start_slideshow_at": "selected",
"theme": "white"
}
},
"nbformat": 4,
"nbformat_minor": 4
}