{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Machine learning\n",
"# 6. The overfitting problem"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 6.1. Polynomial regression"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Introduction: choosing features"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Suppose our task is to predict the price of a rectangular plot of land.\n",
"\n",
"Which features should we choose?"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"We can choose two features:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * $x_1$: plot width, $x_2$: plot length:\n",
"$$ h_{\\theta}(\\vec{x}) = \\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"...or just one:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
" * $x_1$: plot area:\n",
"$$ h_{\\theta}(\\vec{x}) = \\theta_0 + \\theta_1 x_1 $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Note that the feature “plot area” is obtained by multiplying two other features: the plot's length and its width."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"**Conclusion:** we can create new features from existing ones by applying various mathematical operations to them."
]
},
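{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A minimal sketch of such a derived feature, assuming NumPy and a few hypothetical plot dimensions (not data from this lecture): the area is the element-wise product of width and length, and it becomes the single feature $x_1$ of the one-feature model above.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical widths and lengths of three rectangular plots (in metres)\n",
"width = np.array([20.0, 30.0, 15.0])\n",
"length = np.array([40.0, 25.0, 60.0])\n",
"\n",
"# Derived feature: area = width * length (element-wise product)\n",
"area = width * length\n",
"\n",
"# Design matrix for h(x) = theta_0 + theta_1 * area (first column is x_0 = 1)\n",
"X_area = np.column_stack([np.ones_like(area), area])\n",
"print(X_area)\n",
"```"
]
},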
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Polynomial regression"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"In polynomial regression we will use features created as powers of the original features."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Useful imports\n",
"\n",
"import ipywidgets as widgets\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Helper functions\n",
"\n",
"\n",
"def cost(theta, X, y):\n",
"    \"\"\"Cost function (matrix version)\"\"\"\n",
"    m = len(y)\n",
"    J = 1.0 / (2.0 * m) * ((X * theta - y).T * (X * theta - y))\n",
"    return J.item()\n",
"\n",
"\n",
"def gradient(theta, X, y):\n",
"    \"\"\"Gradient of the cost function (matrix version)\"\"\"\n",
"    return 1.0 / len(y) * (X.T * (X * theta - y))\n",
"\n",
"\n",
"def gradient_descent(fJ, fdJ, theta, X, y, alpha=0.1, eps=10**-5):\n",
"    \"\"\"Gradient descent algorithm (matrix version)\"\"\"\n",
"    current_cost = fJ(theta, X, y)\n",
"    logs = [[current_cost, theta]]\n",
"    while True:\n",
"        theta = theta - alpha * fdJ(theta, X, y)\n",
"        current_cost, prev_cost = fJ(theta, X, y), current_cost\n",
"        # Stop if the cost explodes (divergence) or stops changing (convergence)\n",
"        if abs(prev_cost - current_cost) > 10**15:\n",
"            print(\"Algorithm does not converge!\")\n",
"            break\n",
"        if abs(prev_cost - current_cost) <= eps:\n",
"            break\n",
"        logs.append([current_cost, theta])\n",
"    return theta, logs\n",
"\n",
"\n",
"def plot_data(X, y, xlabel, ylabel):\n",
"    \"\"\"Scatter plot of the data (matrix version)\"\"\"\n",
"    fig = plt.figure(figsize=(16 * 0.6, 9 * 0.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"    ax.scatter([X[:, 1]], [y], c=\"r\", s=50, label=\"Data\")\n",
"\n",
"    ax.set_xlabel(xlabel)\n",
"    ax.set_ylabel(ylabel)\n",
"    ax.margins(0.05, 0.05)\n",
"    plt.ylim(y.min() - 1, y.max() + 1)\n",
"    plt.xlim(np.min(X[:, 1]) - 1, np.max(X[:, 1]) + 1)\n",
"    return fig\n",
"\n",
"\n",
"def plot_fun(fig, fun, X):\n",
"    \"\"\"Plot the function `fun` on the first axes of `fig`\"\"\"\n",
"    ax = fig.axes[0]\n",
"    x0 = np.min(X[:, 1]) - 1.0\n",
"    x1 = np.max(X[:, 1]) + 1.0\n",
"    Arg = np.arange(x0, x1, 0.1)\n",
"    Val = fun(Arg)\n",
"    return ax.plot(Arg, Val, linewidth=2)\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Load the data (flats) with the pandas library\n",
"\n",
"alldata = pandas.read_csv(\n",
"    \"data_flats.tsv\", header=0, sep=\"\\t\", usecols=[\"price\", \"rooms\", \"sqrMetres\"]\n",
")\n",
"data = np.matrix(alldata[[\"sqrMetres\", \"price\"]])\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"# Scale the feature and each of its powers by the column maximum (feature scaling)\n",
"Xn = data[:, 0:n]\n",
"Xn /= np.amax(Xn, axis=0)\n",
"Xn2 = np.power(Xn, 2)\n",
"Xn2 /= np.amax(Xn2, axis=0)\n",
"Xn3 = np.power(Xn, 3)\n",
"Xn3 /= np.amax(Xn3, axis=0)\n",
"\n",
"# Design matrices with a column of ones: linear (X), quadratic (X2) and cubic (X3)\n",
"X = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n + 1)\n",
"X2 = np.matrix(np.concatenate((np.ones((m, 1)), Xn, Xn2), axis=1)).reshape(m, 2 * n + 1)\n",
"X3 = np.matrix(np.concatenate((np.ones((m, 1)), Xn, Xn2, Xn3), axis=1)).reshape(\n",
"    m, 3 * n + 1\n",
")\n",
"y = np.matrix(data[:, -1]).reshape(m, 1)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"General form of polynomial regression:\n",
"\n",
"$$ h_{\\theta}(x) = \\sum_{i=0}^{n} \\theta_i x^i $$"
]
},
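{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The sum above can also be evaluated with Horner's scheme, which needs only one multiplication and one addition per coefficient. A minimal sketch assuming NumPy; the function name `h_poly_horner` and the coefficients are hypothetical. Note that `np.polyval` expects coefficients from the highest power down, hence the reversal.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def h_poly_horner(theta, x):\n",
"    # Evaluate sum_i theta[i] * x**i as theta_0 + x * (theta_1 + x * (theta_2 + ...))\n",
"    result = np.zeros_like(x, dtype=float)\n",
"    for coef in reversed(theta):\n",
"        result = result * x + coef\n",
"    return result\n",
"\n",
"\n",
"theta = [1.0, -2.0, 0.5]  # hypothetical theta_0, theta_1, theta_2\n",
"x = np.array([0.0, 1.0, 2.0, 3.0])\n",
"\n",
"direct = sum(t * x**i for i, t in enumerate(theta))  # the sum exactly as written\n",
"horner = h_poly_horner(theta, x)\n",
"polyval = np.polyval(theta[::-1], x)  # highest power first\n",
"print(direct, horner, polyval)\n",
"```"
]
},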
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [],
"source": [
"# Polynomial regression function\n",
"\n",
"\n",
"def h_poly(Theta, x):\n",
"    \"\"\"Polynomial: the sum of Theta[i] * x**i\"\"\"\n",
"    return sum(theta * np.power(x, i) for i, theta in enumerate(Theta.tolist()))\n",
"\n",
"\n",
"def polynomial_regression(theta):\n",
"    \"\"\"Return the polynomial regression hypothesis as a function of x\"\"\"\n",
"    return lambda x: h_poly(theta, x)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The simplest case of polynomial regression (beyond a straight line) is the quadratic function:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Quadratic function:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 $$"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f737ce1baf0>]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X2, y, xlabel=\"x\", ylabel=\"y\")\n",
"theta_start = np.matrix([0, 0, 0]).reshape(3, 1)\n",
"theta, logs = gradient_descent(cost, gradient, theta_start, X2, y)\n",
"plot_fun(fig, polynomial_regression(theta), X)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Another special case of polynomial regression is the cubic function:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Cubic function:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 + \\theta_3 x^3 $$"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 397519.38046962]\n",
" [-841341.14146733]\n",
" [2253713.97125102]\n",
" [-244009.07081946]]\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X3, y, xlabel=\"x\", ylabel=\"y\")\n",
"theta_start = np.matrix([0, 0, 0, 0]).reshape(4, 1)\n",
"theta, _ = gradient_descent(cost, gradient, theta_start, X3, y)\n",
"plot_fun(fig, polynomial_regression(theta), X)\n",
"\n",
"print(theta)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Polynomial regression can be viewed as a special case of multivariate linear regression:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 x^2 + \\theta_3 x^3 $$\n",
"$$ x_1 = x, \\quad x_2 = x^2, \\quad x_3 = x^3, \\quad \\vec{x} = \\left[ \\begin{array}{c} x_0 \\\\ x_1 \\\\ x_2 \\\\ x_3 \\end{array} \\right] $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"(In this case, the successive features are the successive powers of the variable $x$.)"
]
},
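{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Because the model is linear in $\\theta$, any linear-regression solver can be applied to the expanded features. A minimal sketch assuming NumPy and synthetic data generated from hypothetical coefficients; it swaps in the closed-form least-squares solver `np.linalg.lstsq` instead of the gradient descent used in this lecture.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x = np.linspace(0.0, 1.0, 50)\n",
"# Hypothetical true coefficients theta_0, theta_1, theta_2, theta_3\n",
"true_theta = np.array([1.0, -2.0, 3.0, 0.5])\n",
"\n",
"# Columns x**0, x**1, x**2, x**3, i.e. the features x_0, x_1, x_2, x_3\n",
"X_poly = np.vander(x, N=4, increasing=True)\n",
"y = X_poly @ true_theta + rng.normal(scale=0.001, size=x.shape)\n",
"\n",
"# Ordinary least squares on the expanded features = cubic regression\n",
"theta_hat, *_ = np.linalg.lstsq(X_poly, y, rcond=None)\n",
"print(theta_hat)\n",
"```"
]
},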
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Practical note: feature normalization, especially scaling, is very useful here, because the columns $x$, $x^2$, $x^3$ can differ by orders of magnitude!"
]
},
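{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A small illustration of why scaling matters, assuming NumPy and hypothetical flat areas: the raw columns $x$, $x^2$, $x^3$ differ by several orders of magnitude, so a single gradient-descent learning rate suits none of them well. Dividing each column by its maximum, similar to what the data-loading cell above does, brings every column to a maximum of 1.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical flat areas in square metres\n",
"x = np.array([25.0, 48.0, 60.0, 120.0])\n",
"X_raw = np.column_stack([x, x**2, x**3])\n",
"print(X_raw.max(axis=0))  # column maxima: 120, 14400 and 1728000\n",
"\n",
"# Scale every column by its maximum\n",
"X_scaled = X_raw / X_raw.max(axis=0)\n",
"print(X_scaled.max(axis=0))  # every column now has maximum 1\n",
"```"
]
},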
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"To create “derived” features we can use not only powers but also other mathematical operations, e.g.:\n",
"\n",
"$$ h_{\\theta}(x) = \\theta_0 + \\theta_1 x + \\theta_2 \\sqrt{x} $$\n",
"$$ x_1 = x, \\quad x_2 = \\sqrt{x}, \\quad \\vec{x} = \\left[ \\begin{array}{c} x_0 \\\\ x_1 \\\\ x_2 \\end{array} \\right] $$"
]
},
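{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The same trick works for $\\sqrt{x}$: build the columns $x_0 = 1$, $x_1 = x$, $x_2 = \\sqrt{x}$ and fit an ordinary linear model. A minimal sketch assuming NumPy, noise-free synthetic data from hypothetical coefficients, and `np.linalg.lstsq` as the solver (not the gradient descent from this lecture).\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"x = np.linspace(1.0, 100.0, 40)\n",
"# Hypothetical target generated from theta = (2, 0.1, 3)\n",
"y = 2.0 + 0.1 * x + 3.0 * np.sqrt(x)\n",
"\n",
"# Features x_0 = 1, x_1 = x, x_2 = sqrt(x)\n",
"X_sqrt = np.column_stack([np.ones_like(x), x, np.sqrt(x)])\n",
"theta_hat, *_ = np.linalg.lstsq(X_sqrt, y, rcond=None)\n",
"print(theta_hat)  # close to [2.0, 0.1, 3.0]\n",
"```"
]
},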
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"So which features should we choose? Ideally, features suited to the specific problem."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Polynomial logistic regression\n",
"\n",
"Similar feature transformations can also be applied to logistic regression."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def powerme(x1, x2, n):\n",
"    \"\"\"Generate all monomials x1**i * x2**(m - i) of total degree m <= n,\n",
"    including the constant term 1, ordered by increasing m\n",
"    \"\"\"\n",
"    X = []\n",
"    for m in range(n + 1):\n",
"        for i in range(m + 1):\n",
"            X.append(np.multiply(np.power(x1, i), np.power(x2, (m - i))))\n",
"    return np.hstack(X)\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [
{
"data": {
"text/plain": [
"matrix([[ 1. , 0.36596696, -0.11214686],\n",
" [ 0. , 0.4945305 , 0.47110656],\n",
" [ 0. , 0.70290604, -0.92257983],\n",
" [ 0. , 0.46658862, -0.62269739],\n",
" [ 0. , 0.87939462, -0.11408015],\n",
" [ 0. , -0.331185 , 0.84447667],\n",
" [ 0. , -0.54351701, 0.8851383 ],\n",
" [ 0. , 0.91979241, 0.41607012],\n",
" [ 0. , 0.28011742, 0.61431157],\n",
" [ 0. , 0.94754363, -0.78307311]])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the data\n",
"import pandas\n",
"import numpy as np\n",
"\n",
"alldata = pandas.read_csv(\"polynomial_logistic.tsv\", sep=\"\\t\")\n",
"data = np.matrix(alldata)\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1  # n = 2 input features; also used below as the polynomial degree\n",
"Xn = data[:, 1:]\n",
"\n",
"Xpl = powerme(data[:, 1], data[:, 2], n)\n",
"Ypl = np.matrix(data[:, 0]).reshape(m, 1)\n",
"\n",
"data[:10]\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def plot_data_for_classification(X, Y, xlabel, ylabel):\n",
"    \"\"\"Scatter plot of two-class data (matrix version)\"\"\"\n",
"    fig = plt.figure(figsize=(16 * 0.6, 9 * 0.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"    X = X.tolist()\n",
"    Y = Y.tolist()\n",
"    # powerme(x1, x2, n) stores x2 in column 1 and x1 in column 2,\n",
"    # so column 2 goes on the horizontal axis and column 1 on the vertical axis\n",
"    X1n = [x[2] for x, y in zip(X, Y) if y[0] == 0]\n",
"    X1p = [x[2] for x, y in zip(X, Y) if y[0] == 1]\n",
"    X2n = [x[1] for x, y in zip(X, Y) if y[0] == 0]\n",
"    X2p = [x[1] for x, y in zip(X, Y) if y[0] == 1]\n",
"    ax.scatter(X1n, X2n, c=\"r\", marker=\"x\", s=50, label=\"y = 0\")\n",
"    ax.scatter(X1p, X2p, c=\"g\", marker=\"o\", s=50, label=\"y = 1\")\n",
"\n",
"    ax.set_xlabel(xlabel)\n",
"    ax.set_ylabel(ylabel)\n",
"    ax.margins(0.05, 0.05)\n",
"    return fig\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Suppose we have the following data and want to perform binary classification with these two classes:\n",
" * red crosses ($y = 0$)\n",
" * green circles ($y = 1$)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r\"$x_1$\", ylabel=r\"$x_2$\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Proposed hypothesis:\n",
"\n",
"$$ h_\\theta(x) = g(\\theta^T x) = g(\\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 + \\theta_3 x_3 + \\theta_4 x_4 + \\theta_5 x_5) \\; , $$\n",
"\n",
"where $g$ is the logistic function, $x_3 = x_1^2$, $x_4 = x_2^2$, $x_5 = x_1 x_2$."
]
},
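{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"With two input variables the number of such features grows quickly with the degree: there are $(d+1)(d+2)/2$ monomials of total degree at most $d$. A minimal sketch that counts them, assuming NumPy; `monomial_features` is a hypothetical helper that follows the same column ordering as `powerme` defined above (constant, then $x_2$, $x_1$, then $x_2^2$, $x_1 x_2$, $x_1^2$, and so on). For $d = 2$ it gives the 6 features of the hypothesis above; for $d = 10$, the degree used later in this lecture, it already gives 66.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def monomial_features(x1, x2, degree):\n",
"    # Same ordering as powerme: for each total degree m, terms x1**i * x2**(m - i), i = 0..m\n",
"    cols = []\n",
"    for m in range(degree + 1):\n",
"        for i in range(m + 1):\n",
"            cols.append(x1**i * x2**(m - i))\n",
"    return np.column_stack(cols)\n",
"\n",
"\n",
"# Two hypothetical points (x1, x2)\n",
"x1 = np.array([0.5, -0.2])\n",
"x2 = np.array([0.1, 0.3])\n",
"for degree in (2, 10):\n",
"    F = monomial_features(x1, x2, degree)\n",
"    # Number of monomials of total degree <= d in two variables: (d + 1)(d + 2) / 2\n",
"    print(degree, F.shape[1], (degree + 1) * (degree + 2) // 2)  # 2 6 6, then 10 66 66\n",
"```"
]
},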
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def safeSigmoid(x, eps=0):\n",
"    \"\"\"Sigmoid function modified so that its values\n",
"    always stay at least eps away from the asymptotes 0 and 1\n",
"    \"\"\"\n",
"    y = 1.0 / (1.0 + np.exp(-x))\n",
"    if eps > 0:\n",
"        y[y < eps] = eps\n",
"        y[y > 1 - eps] = 1 - eps\n",
"    return y\n",
"\n",
"\n",
"def h(theta, X, eps=0.0):\n",
"    \"\"\"Hypothesis function\"\"\"\n",
"    return safeSigmoid(X * theta, eps)\n",
"\n",
"\n",
"def J(h, theta, X, y, lamb=0):\n",
"    \"\"\"Cost function (cross-entropy, with optional L2 regularization)\"\"\"\n",
"    m = len(y)\n",
"    f = h(theta, X, eps=10**-7)\n",
"    j = (\n",
"        -np.sum(np.multiply(y, np.log(f)) + np.multiply(1 - y, np.log(1 - f)), axis=0)\n",
"        / m\n",
"    )\n",
"    if lamb > 0:\n",
"        j += lamb / (2 * m) * np.sum(np.power(theta[1:], 2))\n",
"    return j\n",
"\n",
"\n",
"def dJ(h, theta, X, y, lamb=0):\n",
"    \"\"\"Gradient of the cost function\"\"\"\n",
"    g = 1.0 / y.shape[0] * (X.T * (h(theta, X) - y))\n",
"    if lamb > 0:\n",
"        g[1:] += lamb / float(y.shape[0]) * theta[1:]\n",
"    return g\n",
"\n",
"\n",
"def classifyBi(theta, X):\n",
"    \"\"\"Return the predicted probability of class 1 (threshold at 0.5 to get the class)\"\"\"\n",
"    prob = h(theta, X)\n",
"    return prob\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def GD(h, fJ, fdJ, theta, X, y, alpha=0.01, eps=10**-3, maxSteps=10000):\n",
"    \"\"\"Gradient descent for logistic regression\"\"\"\n",
"    errorCurr = fJ(h, theta, X, y)\n",
"    errors = [[errorCurr, theta]]\n",
"    while True:\n",
"        # compute the new theta\n",
"        theta = theta - alpha * fdJ(h, theta, X, y)\n",
"        # record the current error\n",
"        errorCurr, errorPrev = fJ(h, theta, X, y), errorCurr\n",
"        # stopping criteria\n",
"        if abs(errorPrev - errorCurr) <= eps:\n",
"            break\n",
"        if len(errors) > maxSteps:\n",
"            break\n",
"        errors.append([errorCurr, theta])\n",
"    return theta, errors\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"theta = [[ 1.59558981]\n",
" [ 0.12602307]\n",
" [ 0.65718518]\n",
" [-5.26367581]\n",
" [ 1.96832544]\n",
" [-6.97946065]]\n"
]
}
],
"source": [
"# Run gradient descent for logistic regression\n",
"theta_start = np.matrix(np.zeros(Xpl.shape[1])).reshape(Xpl.shape[1], 1)\n",
"theta, errors = GD(\n",
"    h, J, dJ, theta_start, Xpl, Ypl, alpha=0.1, eps=10**-7, maxSteps=10000\n",
")\n",
"print(r\"theta = {}\".format(theta))\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
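"def plot_decision_boundary(fig, theta, X):\n",
"    \"\"\"Plot the decision boundary (the 0.5 level set of the predicted probability)\"\"\"\n",
"    ax = fig.axes[0]\n",
"    xx, yy = np.meshgrid(np.arange(-1.0, 1.0, 0.02), np.arange(-1.0, 1.0, 0.02))\n",
"    l = len(xx.ravel())\n",
"    # Uses the global polynomial degree n, which must match the degree used for training\n",
"    C = powerme(xx.reshape(l, 1), yy.reshape(l, 1), n)\n",
"    z = classifyBi(theta, C).reshape(int(np.sqrt(l)), int(np.sqrt(l)))\n",
"\n",
"    # contour() takes `linewidths`; the former `lw` keyword was silently ignored\n",
"    ax.contour(xx, yy, z, levels=[0.5], linewidths=3)\n"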
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_868/1169766636.py:9: UserWarning: The following kwargs were not used by contour: 'lw'\n",
" plt.contour(xx, yy, z, levels=[0.5], lw=3)\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r\"$x_1$\", ylabel=r\"$x_2$\")\n",
"plot_decision_boundary(fig, theta, Xpl)\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Load the data\n",
"\n",
"alldata = pandas.read_csv(\"polynomial_logistic.tsv\", sep=\"\\t\")\n",
"data = np.matrix(alldata)\n",
"\n",
"m, n_plus_1 = data.shape\n",
"Xn = data[:, 1:]\n",
"\n",
"n = 10\n",
"Xpl = powerme(data[:, 1], data[:, 2], n)\n",
"Ypl = np.matrix(data[:, 0]).reshape(m, 1)\n",
"\n",
"theta_start = np.matrix(np.zeros(Xpl.shape[1])).reshape(Xpl.shape[1], 1)\n",
"theta, errors = GD(\n",
" h, J, dJ, theta_start, Xpl, Ypl, alpha=0.1, eps=10**-7, maxSteps=10000\n",
")\n"
]
},
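Here `powerme` expands the two input features into polynomial features of degree up to n = 10. Its definition is not shown in this section. If it generates all monomials x1^i * x2^j with i + j <= n, the number of columns is (n + 1)(n + 2) / 2, which is 66 for n = 10. The hypothetical `polynomial_features_2d` below is a minimal sketch of such a mapping under that assumption. Its column order may differ from `powerme`'s.

```python
import numpy as np


def polynomial_features_2d(x1, x2, degree):
    """Hypothetical stand-in for `powerme`: returns the columns x1**i * x2**(d - i)
    for every total degree d = 0..degree, starting with the constant column 1."""
    x1 = np.asarray(x1, dtype=float).ravel()
    x2 = np.asarray(x2, dtype=float).ravel()
    columns = [
        x1**i * x2 ** (d - i)
        for d in range(degree + 1)  # total degree of the monomial
        for i in range(d, -1, -1)  # power of x1, from d down to 0
    ]
    return np.column_stack(columns)


X = polynomial_features_2d([0.5, -0.2, 1.0], [0.1, 0.3, -0.7], 10)
print(X.shape)  # 3 samples, (10 + 1) * (10 + 2) / 2 = 66 features
```

With only two raw features, the model already has 66 parameters. This is the kind of flexibility that makes the curved decision boundary possible, and it also makes overfitting possible.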
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_868/1169766636.py:9: UserWarning: The following kwargs were not used by contour: 'lw'\n",
" plt.contour(xx, yy, z, levels=[0.5], lw=3)\n"
]
},
{
"data": {
"image/png": "[base64-encoded PNG plot omitted]",
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Example with a larger number of features\n",
"fig = plot_data_for_classification(Xpl, Ypl, xlabel=r\"$x_1$\", ylabel=r\"$x_2$\")\n",
"plot_decision_boundary(fig, theta, Xpl)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 6.2. The problem of overfitting"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Bias versus variance"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Data for a simple example\n",
"\n",
"data = np.matrix(\n",
" [\n",
" [0.0, 0.0],\n",
" [0.5, 1.8],\n",
" [1.0, 4.8],\n",
" [1.6, 7.2],\n",
" [2.6, 8.8],\n",
" [3.0, 9.0],\n",
" ]\n",
")\n",
"\n",
"m, n_plus_1 = data.shape\n",
"n = n_plus_1 - 1\n",
"Xn1 = data[:, 0:n]\n",
"Xn1 /= np.amax(Xn1, axis=0)\n",
"Xn2 = np.power(Xn1, 2)\n",
"Xn2 /= np.amax(Xn2, axis=0)\n",
"Xn3 = np.power(Xn1, 3)\n",
"Xn3 /= np.amax(Xn3, axis=0)\n",
"Xn4 = np.power(Xn1, 4)\n",
"Xn4 /= np.amax(Xn4, axis=0)\n",
"Xn5 = np.power(Xn1, 5)\n",
"Xn5 /= np.amax(Xn5, axis=0)\n",
"\n",
"X1 = np.matrix(np.concatenate((np.ones((m, 1)), Xn1), axis=1)).reshape(m, n + 1)\n",
"X2 = np.matrix(np.concatenate((np.ones((m, 1)), Xn1, Xn2), axis=1)).reshape(\n",
" m, 2 * n + 1\n",
")\n",
"X5 = np.matrix(\n",
" np.concatenate((np.ones((m, 1)), Xn1, Xn2, Xn3, Xn4, Xn5), axis=1)\n",
").reshape(m, 5 * n + 1)\n",
"y = np.matrix(data[:, -1]).reshape(m, 1)\n"
]
},
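The Xn1 ... Xn5 construction above can be written more compactly. The helper below, `scaled_design_matrix`, is an illustrative name and not part of the notebook. It builds the same kind of matrix: a column of ones followed by x, x^2, ..., x^d, with each power column divided by its maximum. For nonnegative x, as here, this is the same as first scaling x to [0, 1] and then taking powers.

```python
import numpy as np


def scaled_design_matrix(x, degree):
    """Return [1, x, x**2, ..., x**degree] with every power column divided by its
    maximum (illustrative helper mirroring the Xn1 ... Xn5 cell above)."""
    x = np.asarray(x, dtype=float).reshape(-1, 1)
    powers = np.hstack([x**k for k in range(1, degree + 1)])
    powers = powers / powers.max(axis=0)  # each power column now has maximum 1
    return np.hstack([np.ones((x.shape[0], 1)), powers])


x = np.array([0.0, 0.5, 1.0, 1.6, 2.6, 3.0])
X5 = scaled_design_matrix(x, 5)
print(X5.shape)  # (6, 6): six points, six parameters
```

Note that `X5` has as many columns as there are data points. This matters for the degree-5 fit further below.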
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "[base64-encoded PNG plot omitted]",
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X1, y, xlabel=\"x\", ylabel=\"y\")\n"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Koszt: 0.41863137063802436\n"
]
},
{
"data": {
"image/png": "[base64-encoded PNG plot omitted]",
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X1, y, xlabel=\"x\", ylabel=\"y\")\n",
"theta_start = np.matrix([0, 0]).reshape(2, 1)\n",
"theta, history = gradient_descent(cost, gradient, theta_start, X1, y, eps=10**-8)\n",
"plot_fun(fig, polynomial_regression(theta), X1)\n",
"print(f\"Koszt: {history[-1][0]}\")"
]
},
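The `gradient_descent` helper is defined earlier in the notebook. As a self-contained sketch, not the notebook's implementation, plain batch gradient descent on the same six points converges to the ordinary least-squares solution. It uses the assumed cost J(theta) = 1/(2m) * sum((X theta - y)^2), and the result can be checked against `np.linalg.lstsq`. The step size 1.0 and the iteration count are assumptions chosen for this small, well-scaled problem.

```python
import numpy as np

# The toy data from the cell above, with x scaled to [0, 1] as in Xn1.
x = np.array([0.0, 0.5, 1.0, 1.6, 2.6, 3.0]) / 3.0
y = np.array([0.0, 1.8, 4.8, 7.2, 8.8, 9.0])
X = np.column_stack([np.ones_like(x), x])  # design matrix [1, x]
m = len(y)


def cost(theta):
    """Mean squared error cost J(theta) = 1/(2m) * sum((X theta - y)^2)."""
    r = X @ theta - y
    return r @ r / (2 * m)


theta = np.zeros(2)
alpha = 1.0  # step size (assumed; stable here because x is scaled to [0, 1])
for _ in range(5000):
    grad = X.T @ (X @ theta - y) / m  # gradient of J
    theta = theta - alpha * grad

theta_ls, *_ = np.linalg.lstsq(X, y, rcond=None)  # closed-form least squares
print(theta, theta_ls)
print(cost(theta))
```

The final cost closely matches the value printed above for the straight-line model. Even the best possible line cannot follow the curvature of the data.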
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"This model has high **bias** (**systematic error**), so **underfitting** occurs."
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Koszt: 0.05470339875188901\n"
]
},
{
"data": {
"image/png": "[base64-encoded PNG plot omitted]",
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X2, y, xlabel=\"x\", ylabel=\"y\")\n",
"theta_start = np.matrix([0, 0, 0]).reshape(3, 1)\n",
"theta, history = gradient_descent(cost, gradient, theta_start, X2, y, eps=10**-8)\n",
"plot_fun(fig, polynomial_regression(theta), X1)\n",
"print(f\"Koszt: {history[-1][0]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"This model fits the data well."
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Koszt: 0.007232337911078077\n"
]
},
{
"data": {
"image/png": "[base64-encoded PNG plot omitted]",
"text/plain": [
"<Figure size 960x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plot_data(X5, y, xlabel=\"x\", ylabel=\"y\")\n",
"theta_start = np.matrix([0, 0, 0, 0, 0, 0]).reshape(6, 1)\n",
"theta, history = gradient_descent(cost, gradient, theta_start, X5, y, alpha=0.5, eps=10**-8)\n",
"plot_fun(fig, polynomial_regression(theta), X1)\n",
"print(f\"Koszt: {history[-1][0]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"This model has high **variance**, so **overfitting** occurs."
]
},
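The costs printed above fall as the degree grows. A closed-form least-squares fit on the same six points shows where this ends. The sketch assumes the cost J = 1/(2m) * sum((h(x) - y)^2), which is consistent with the printed value for the straight line. With 6 parameters and 6 distinct points, the degree-5 polynomial interpolates the data, so its training cost drops to zero. The nonzero cost printed for `X5` reflects gradient descent stopping before it reaches this exact minimum.

```python
import numpy as np

x = np.array([0.0, 0.5, 1.0, 1.6, 2.6, 3.0]) / 3.0  # scaled to [0, 1]
y = np.array([0.0, 1.8, 4.8, 7.2, 8.8, 9.0])


def training_cost(degree):
    """Least-squares training cost J = 1/(2m) * sum(residual^2) for a polynomial of the given degree."""
    X = np.vander(x, degree + 1, increasing=True)  # columns 1, x, ..., x**degree
    theta, *_ = np.linalg.lstsq(X, y, rcond=None)
    r = X @ theta - y
    return r @ r / (2 * len(y))


costs = {d: training_cost(d) for d in (1, 2, 5)}
print(costs)
```

A lower training cost is therefore not evidence of a better model. The degree-5 curve reaches zero cost precisely because it has enough freedom to pass through every point.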
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"(Note the odd shape of the curve in the left part of the plot; this is, among other things, an effect of overfitting.)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Overfitting occurs when the model has too many degrees of freedom relative to the amount of training data.\n",
"\n",
"It is an undesirable phenomenon.\n",
"\n",
"Figuratively speaking, overfitting occurs when the model starts to model the noise in the data instead of its “main trend”."
]
},
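You can check directly that a model with too many degrees of freedom ends up modeling noise. A polynomial with as many coefficients as there are data points passes through every point, even when the targets are pure noise. This is a minimal sketch using NumPy's `numpy.polynomial.polynomial` module, and the random data are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
m = 6
x = np.linspace(0.0, 1.0, m)
noise = rng.normal(size=m)  # targets that contain no signal at all

# Degree m - 1 means m coefficients: as many degrees of freedom as data points.
coeffs = np.polynomial.polynomial.polyfit(x, noise, deg=m - 1)
fitted = np.polynomial.polynomial.polyval(x, coeffs)

print(np.max(np.abs(fitted - noise)))  # essentially zero training error
```

The zero training error here says nothing about an underlying signal, because there is none. This is why training error alone cannot reveal overfitting.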
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"See also: https://pl.wikipedia.org/wiki/Nadmierne_dopasowanie"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"90%\" src=\"fit.png\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Bias (systematic error)\n",
"\n",
"* Results from erroneous assumptions built into the learning algorithm.\n",
"* High bias causes underfitting."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Variance\n",
"\n",
"* Results from oversensitivity to small fluctuations in the training set.\n",
"* High variance can cause overfitting (the model fits the noise instead of the signal)."
]
},
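A small simulation makes bias and variance concrete. The setup is synthetic: the true function sin(2*pi*x), the noise level 0.3 and the degrees 1 and 7 are all assumptions, not taken from the lecture. The sketch repeatedly draws a fresh noisy training set and fits a polynomial. It then measures two things at the training inputs:

* how far the average prediction lies from the truth (squared bias);
* how much individual predictions scatter around that average (variance).

```python
import numpy as np
from numpy.polynomial import Polynomial

rng = np.random.default_rng(42)
x = np.linspace(0.0, 1.0, 20)  # fixed training inputs
f_true = np.sin(2 * np.pi * x)  # assumed "true" function (illustration only)
sigma = 0.3  # assumed noise standard deviation
n_trials = 500


def bias_variance(degree):
    """Monte Carlo estimate of squared bias and variance of a polynomial fit, averaged over x."""
    preds = np.empty((n_trials, x.size))
    for t in range(n_trials):
        y = f_true + rng.normal(scale=sigma, size=x.size)  # a fresh noisy training set
        preds[t] = Polynomial.fit(x, y, degree)(x)
    mean_pred = preds.mean(axis=0)
    bias_sq = np.mean((mean_pred - f_true) ** 2)
    variance = np.mean(preds.var(axis=0))
    return bias_sq, variance


b1, v1 = bias_variance(1)
b7, v7 = bias_variance(7)
print(f"degree 1: bias^2 = {b1:.4f}, variance = {v1:.4f}")
print(f"degree 7: bias^2 = {b7:.4f}, variance = {v7:.4f}")
```

The straight line has the larger bias, because it cannot represent the curve. The degree-7 polynomial has the larger variance, because it follows each particular noise pattern more closely.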
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"40%\" src=\"bias2.png\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"<img style=\"margin:auto\" width=\"60%\" src=\"curves.jpg\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 6.3. Regularization"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def SGD(\n",
" h,\n",
" fJ,\n",
" fdJ,\n",
" theta,\n",
" X,\n",
" Y,\n",
" alpha=0.001,\n",
" maxEpochs=1.0,\n",
" batchSize=100,\n",
" adaGrad=False,\n",
" logError=False,\n",
" validate=0.0,\n",
" valStep=100,\n",
" lamb=0,\n",
" trainsetsize=1.0,\n",
"):\n",
"    \"\"\"Stochastic Gradient Descent: the stochastic version of gradient descent\n",
"    (more on this topic in the next lecture)\n",
"    \"\"\"\n",
" errorsX, errorsY = [], []\n",
" errorsVX, errorsVY = [], []\n",
"\n",
"    m_end = int(trainsetsize * len(X))\n",
"\n",
"    XT, YT = X[:m_end], Y[:m_end]\n",
"\n",
" if validate > 0:\n",
" mv = int(X.shape[0] * validate)\n",
" XV, YV = X[:mv], Y[:mv]\n",
" XT, YT = X[mv:m_end], Y[mv:m_end]\n",
" m, n = XT.shape\n",
"\n",
" start, end = 0, batchSize\n",
" maxSteps = (m * float(maxEpochs)) / batchSize\n",
"\n",
" if adaGrad:\n",
" hgrad = np.matrix(np.zeros(n)).reshape(n, 1)\n",
"\n",
" for i in range(int(maxSteps)):\n",
" XBatch, YBatch = XT[start:end, :], YT[start:end, :]\n",
"\n",
" grad = fdJ(h, theta, XBatch, YBatch, lamb=lamb)\n",
" if adaGrad:\n",
" hgrad += np.multiply(grad, grad)\n",
" Gt = 1.0 / (10**-7 + np.sqrt(hgrad))\n",
" theta = theta - np.multiply(alpha * Gt, grad)\n",
" else:\n",
" theta = theta - alpha * grad\n",
"\n",
" if logError:\n",
" errorsX.append(float(i * batchSize) / m)\n",
" errorsY.append(fJ(h, theta, XBatch, YBatch).item())\n",
" if validate > 0 and i % valStep == 0:\n",
" errorsVX.append(float(i * batchSize) / m)\n",
" errorsVY.append(fJ(h, theta, XV, YV).item())\n",
"\n",
" if start + batchSize < m:\n",
" start += batchSize\n",
" else:\n",
" start = 0\n",
" end = min(start + batchSize, m)\n",
" return theta, (errorsX, errorsY, errorsVX, errorsVY)\n"
]
},
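`SGD` passes the regularization strength `lamb` to the gradient function `fdJ`, which is defined elsewhere. Below is a sketch of a common form of the L2-regularized logistic regression cost and gradient. It is an assumption and not necessarily identical to the notebook's `J` and `dJ`. The penalty (lamb / 2m) * sum(theta_j^2 for j >= 1) adds (lamb / m) * theta_j to every gradient component except the intercept. A central finite-difference check confirms that the analytic gradient matches the cost.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def regularized_cost(theta, X, y, lamb):
    """Cross-entropy cost plus (lamb / 2m) * sum(theta[1:]**2); the bias theta[0] is not penalized."""
    m = len(y)
    p = sigmoid(X @ theta)
    data_term = -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))
    return data_term + lamb / (2 * m) * np.sum(theta[1:] ** 2)


def regularized_gradient(theta, X, y, lamb):
    """Analytic gradient of regularized_cost: X^T (h - y) / m + (lamb / m) * theta, bias excluded."""
    m = len(y)
    grad = X.T @ (sigmoid(X @ theta) - y) / m
    penalty = (lamb / m) * theta  # a new array, so theta itself is not modified
    penalty[0] = 0.0  # the intercept is left unregularized
    return grad + penalty


rng = np.random.default_rng(1)
X = np.column_stack([np.ones(50), rng.normal(size=(50, 3))])
y = (rng.random(50) < 0.5).astype(float)
theta = rng.normal(size=4)
lamb = 2.0

# Compare the analytic gradient with central finite differences.
eps = 1e-6
numeric = np.array([
    (regularized_cost(theta + eps * e, X, y, lamb)
     - regularized_cost(theta - eps * e, X, y, lamb)) / (2 * eps)
    for e in np.eye(4)
])
analytic = regularized_gradient(theta, X, y, lamb)
max_err = np.max(np.abs(numeric - analytic))
print(max_err)  # tiny: the two gradients agree
```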
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Prepare the data for the regularization example\n",
"\n",
"n = 6\n",
"\n",
"data = np.matrix(np.loadtxt(\"ex2data2.txt\", delimiter=\",\"))\n",
"np.random.shuffle(data)\n",
"\n",
"X = powerme(data[:, 0], data[:, 1], n)\n",
"Y = data[:, 2]\n"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def draw_regularization_example(\n",
" X, Y, lamb=0, alpha=1, adaGrad=True, maxEpochs=2500, validate=0.25\n",
"):\n",
"    \"\"\"Draws a regularization example\"\"\"\n",
"    plt.figure(figsize=(16, 8))\n",
"    plt.subplot(121)\n",
"    plt.scatter(\n",
"        X[:, 2].tolist(),\n",
"        X[:, 1].tolist(),\n",
"        c=Y.tolist(),\n",
"        s=100,\n",
"        cmap=\"prism\",\n",
"    )\n",
"\n",
" theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1], 1)\n",
" thetaBest, err = SGD(\n",
" h,\n",
" J,\n",
" dJ,\n",
" theta,\n",
" X,\n",
" Y,\n",
" alpha=alpha,\n",
" adaGrad=adaGrad,\n",
" maxEpochs=maxEpochs,\n",
" batchSize=100,\n",
" logError=True,\n",
" validate=validate,\n",
" valStep=1,\n",
" lamb=lamb,\n",
" )\n",
"\n",
" xx, yy = np.meshgrid(np.arange(-1.5, 1.5, 0.02), np.arange(-1.5, 1.5, 0.02))\n",
" l = len(xx.ravel())\n",
" C = powerme(xx.reshape(l, 1), yy.reshape(l, 1), n)\n",
" z = classifyBi(thetaBest, C).reshape(int(np.sqrt(l)), int(np.sqrt(l)))\n",
"\n",
" plt.contour(xx, yy, z, levels=[0.5], linewidths=3)\n",
" plt.ylim(-1, 1.2)\n",
" plt.xlim(-1, 1.2)\n",
" plt.subplot(122)\n",
" plt.plot(err[0], err[1], lw=3, label=\"Training error\")\n",
" if validate > 0:\n",
" plt.plot(err[2], err[3], lw=3, label=\"Validation error\")\n",
" plt.legend()\n",
" plt.ylim(0.2, 0.8)\n"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1600x800 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"draw_regularization_example(X, Y)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Regularization\n",
"\n",
"Regularization is a method of preventing overfitting by suitably modifying the cost function.\n",
"\n",
"A special term (the **regularization term**, marked in red in the formulas below) is added to the cost function. It acts as a “penalty” for large (extreme) values of the parameters $\\theta$.\n",
"\n",
"As a result, vectors $\\theta$ with smaller parameter values automatically have a lower cost and are therefore preferred.\n",
"\n",
"How strong should the regularization be? We decide this by choosing an appropriate value of the **regularization parameter** $\\lambda$."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The regularization method presented here is L2 regularization (*ridge*). There are also other regularization methods with somewhat different properties, e.g. L1 (*lasso*) or *elastic net*. More on this topic can be found, for example, here:\n",
"* [L1 and L2 Regularization Methods](https://towardsdatascience.com/l1-and-l2-regularization-methods-ce25e7fc831c)\n",
"* [Ridge and Lasso Regression: L1 and L2 Regularization](https://towardsdatascience.com/ridge-and-lasso-regression-a-complete-guide-with-python-scikit-learn-e20e34bcbf0b)\n",
"* [Elastic Net Regression](https://towardsdatascience.com/elastic-net-regression-from-sklearn-to-tensorflow-3b48eee45e91)"
]
},
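{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"To make the difference between these penalties concrete, here is a minimal sketch of the three penalty terms. It follows the convention of the formulas in this lecture: $\\theta_0$ is not penalized, and the L2 term is scaled by $\\frac{\\lambda}{2m}$. The scaling of the L1 term and the mixing weight `l1_ratio` are assumptions made for illustration; libraries may scale these terms differently. The L1 penalty grows linearly in $|\\theta_j|$, so it pulls small parameters towards zero with constant strength and tends to produce sparse solutions. The L2 penalty, by contrast, shrinks all parameters smoothly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def l2_penalty(theta, lamb, m):\n",
"    \"\"\"Ridge (L2) penalty: lamb / (2m) * sum of theta_j^2 for j >= 1.\"\"\"\n",
"    return lamb / (2 * m) * np.sum(theta[1:] ** 2)\n",
"\n",
"\n",
"def l1_penalty(theta, lamb, m):\n",
"    \"\"\"Lasso (L1) penalty: lamb / m * sum of |theta_j| for j >= 1 (scaling assumed).\"\"\"\n",
"    return lamb / m * np.sum(np.abs(theta[1:]))\n",
"\n",
"\n",
"def elastic_net_penalty(theta, lamb, m, l1_ratio=0.5):\n",
"    \"\"\"Elastic net: weighted mix of the L1 and L2 penalties (l1_ratio is an assumed parameter).\"\"\"\n",
"    return l1_ratio * l1_penalty(theta, lamb, m) + (1 - l1_ratio) * l2_penalty(theta, lamb, m)\n",
"\n",
"\n",
"theta_pen = np.array([5.0, 1.0, -2.0])  # theta_pen[0] is the bias and is not penalized\n",
"print(l2_penalty(theta_pen, 1.0, 1), l1_penalty(theta_pen, 1.0, 1), elastic_net_penalty(theta_pen, 1.0, 1))\n"
]
},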
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for linear regression: cost function\n",
"\n",
"$$\n",
"J(\\theta) \\, = \\, \\dfrac{1}{2m} \\left( \\displaystyle\\sum_{i=1}^{m} \\left( h_\\theta(x^{(i)}) - y^{(i)} \\right)^2 \\color{red}{ + \\lambda \\displaystyle\\sum_{j=1}^{n} \\theta^2_j } \\right)\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"* $\\lambda$ is the regularization parameter\n",
"* if $\\lambda$ is too small, the result is overfitting\n",
"* if $\\lambda$ is too large, the result is underfitting"
]
},
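{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The effect of $\\lambda$ can also be seen without gradient descent. Setting the gradient of the regularized linear-regression cost to zero gives the normal equation $(X^T X + \\lambda L)\\,\\theta = X^T y$, where $L$ is the identity matrix with $L_{00} = 0$ (so $\\theta_0$ is not penalized). The sketch below uses synthetic data and a hypothetical helper, `ridge_closed_form`. It fits a degree-9 polynomial for several values of $\\lambda$ and prints the norm of the penalized parameters. The norm decreases as $\\lambda$ grows: the model becomes smoother and, for very large $\\lambda$, too simple to fit the data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def ridge_closed_form(X, y, lamb):\n",
"    \"\"\"Solve (X^T X + lamb * L) theta = X^T y, where L = I except L[0, 0] = 0.\"\"\"\n",
"    L = np.eye(X.shape[1])\n",
"    L[0, 0] = 0.0  # do not penalize the bias theta_0\n",
"    return np.linalg.solve(X.T @ X + lamb * L, X.T @ y)\n",
"\n",
"\n",
"rng_lam = np.random.default_rng(1)\n",
"x_lam = rng_lam.uniform(-1, 1, size=30)\n",
"y_lam = np.sin(3 * x_lam) + 0.1 * rng_lam.normal(size=30)\n",
"X_lam = np.vander(x_lam, 10, increasing=True)  # columns 1, x, x^2, ..., x^9\n",
"\n",
"lambdas_demo = [0.0, 0.01, 1.0, 100.0]\n",
"norms_demo = []\n",
"for lamb_demo in lambdas_demo:\n",
"    theta_lam = ridge_closed_form(X_lam, y_lam, lamb_demo)\n",
"    norms_demo.append(np.linalg.norm(theta_lam[1:]))\n",
"    print(f\"lambda = {lamb_demo:>6}: ||theta[1:]|| = {norms_demo[-1]:.3f}\")\n"
]
},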
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for linear regression: gradient\n",
"\n",
"$$\\small\n",
"\\begin{array}{llll}\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_0} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_0 & \\textrm{for $j = 0$}\\\\\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_j} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_j \\color{red}{+ \\dfrac{\\lambda}{m}\\theta_j} & \\textrm{for $j = 1, 2, \\ldots, n$} \\\\\n",
"\\end{array} \n",
"$$"
]
},
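{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The two formulas above translate directly into NumPy. The sketch below is a minimal version that uses plain 1-D arrays instead of `np.matrix`. It assumes that the first column of `X` is the bias column of ones, so `theta[0]` is excluded from the penalty. The names `ridge_cost` and `ridge_grad` are hypothetical. The analytic gradient is compared with a central finite-difference approximation, a common way to catch mistakes in gradient code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def ridge_cost(theta, X, y, lamb):\n",
"    \"\"\"Regularized linear-regression cost; theta[0] (bias) is not penalized.\"\"\"\n",
"    m = len(y)\n",
"    residuals = X @ theta - y\n",
"    return (residuals @ residuals + lamb * np.sum(theta[1:] ** 2)) / (2 * m)\n",
"\n",
"\n",
"def ridge_grad(theta, X, y, lamb):\n",
"    \"\"\"Gradient of ridge_cost: lamb / m * theta_j is added only for j >= 1.\"\"\"\n",
"    m = len(y)\n",
"    grad = X.T @ (X @ theta - y) / m\n",
"    grad[1:] += lamb / m * theta[1:]\n",
"    return grad\n",
"\n",
"\n",
"rng_rc = np.random.default_rng(0)\n",
"X_rc = np.hstack([np.ones((20, 1)), rng_rc.normal(size=(20, 3))])  # bias column + 3 features\n",
"y_rc = rng_rc.normal(size=20)\n",
"theta_rc = rng_rc.normal(size=4)\n",
"h_step = 1e-6\n",
"# Central differences along each coordinate direction e\n",
"numeric_grad_rc = np.array([\n",
"    (ridge_cost(theta_rc + h_step * e, X_rc, y_rc, 0.5)\n",
"     - ridge_cost(theta_rc - h_step * e, X_rc, y_rc, 0.5)) / (2 * h_step)\n",
"    for e in np.eye(4)\n",
"])\n",
"print(np.max(np.abs(numeric_grad_rc - ridge_grad(theta_rc, X_rc, y_rc, 0.5))))  # expected to be tiny\n"
]
},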
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for logistic regression: cost function\n",
"\n",
"$$\n",
"\\begin{array}{rcl}\n",
"J(\\theta) & = & -\\dfrac{1}{m} \\left( \\displaystyle\\sum_{i=1}^{m} y^{(i)} \\log h_\\theta(x^{(i)}) + \\left( 1-y^{(i)} \\right) \\log \\left( 1-h_\\theta(x^{(i)}) \\right) \\right) \\\\\n",
"& & \\color{red}{ + \\dfrac{\\lambda}{2m} \\displaystyle\\sum_{j=1}^{n} \\theta^2_j } \\\\\n",
"\\end{array}\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Regularization for logistic regression: gradient\n",
"\n",
"$$\\small\n",
"\\begin{array}{llll}\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_0} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_0 & \\textrm{for $j = 0$}\\\\\n",
"\\dfrac{\\partial J(\\theta)}{\\partial \\theta_j} &=& \\dfrac{1}{m}\\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)})-y^{(i)} \\right) x^{(i)}_j \\color{red}{+ \\dfrac{\\lambda}{m}\\theta_j} & \\textrm{for $j = 1, 2, \\ldots, n$} \\\\\n",
"\\end{array} \n",
"$$"
]
},
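{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Substituting this gradient into the gradient descent update $\\theta_j := \\theta_j - \\alpha \\dfrac{\\partial J(\\theta)}{\\partial \\theta_j}$ gives, for $j \\geq 1$:\n",
"\n",
"$$\\theta_j := \\theta_j \\left( 1 - \\alpha \\dfrac{\\lambda}{m} \\right) - \\alpha \\dfrac{1}{m} \\displaystyle\\sum_{i=1}^m \\left( h_{\\theta}(x^{(i)}) - y^{(i)} \\right) x^{(i)}_j$$\n",
"\n",
"Each step therefore first shrinks the parameter by a constant factor and then applies the ordinary, unregularized update. This is why L2 regularization is also called *weight decay*. The sketch below checks numerically that both forms of the update agree for logistic regression. It uses hypothetical helper names, random data and plain NumPy arrays."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def logistic_fn(z):\n",
"    \"\"\"Logistic (sigmoid) function.\"\"\"\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"\n",
"def step_regularized(theta, X, y, alpha, lamb):\n",
"    \"\"\"One gradient descent step with the regularized gradient (theta[0] not penalized).\"\"\"\n",
"    m = len(y)\n",
"    grad = X.T @ (logistic_fn(X @ theta) - y) / m\n",
"    grad[1:] += lamb / m * theta[1:]\n",
"    return theta - alpha * grad\n",
"\n",
"\n",
"def step_weight_decay(theta, X, y, alpha, lamb):\n",
"    \"\"\"The same step written as: shrink theta_j (j >= 1), then take the unregularized step.\"\"\"\n",
"    m = len(y)\n",
"    shrink = np.full_like(theta, 1.0 - alpha * lamb / m)\n",
"    shrink[0] = 1.0  # the bias is not shrunk\n",
"    grad_unreg = X.T @ (logistic_fn(X @ theta) - y) / m\n",
"    return theta * shrink - alpha * grad_unreg\n",
"\n",
"\n",
"rng_wd = np.random.default_rng(3)\n",
"X_wd = np.hstack([np.ones((50, 1)), rng_wd.normal(size=(50, 2))])\n",
"y_wd = (rng_wd.random(50) < 0.5).astype(float)\n",
"theta_wd = rng_wd.normal(size=3)\n",
"print(np.allclose(step_regularized(theta_wd, X_wd, y_wd, 0.1, 2.0),\n",
"                  step_weight_decay(theta_wd, X_wd, y_wd, 0.1, 2.0)))\n"
]
},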
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Implementing regularization"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"def J_(h, theta, X, y, lamb=0):\n",
" \"\"\"Regularized cost function (logistic regression).\"\"\"\n",
" m = float(len(y))\n",
" f = h(theta, X, eps=10**-7)\n",
" j = 1.0 / m * -np.sum(\n",
" np.multiply(y, np.log(f)) + np.multiply(1 - y, np.log(1 - f)), axis=0\n",
" ) + lamb / (2 * m) * np.sum(np.power(theta[1:], 2))\n",
" return j\n",
"\n",
"\n",
"def dJ_(h, theta, X, y, lamb=0):\n",
" \"\"\"Gradient of the regularized cost function (logistic regression).\"\"\"\n",
" m = float(y.shape[0])\n",
" g = 1.0 / m * (X.T * (h(theta, X) - y))\n",
" g[1:] += lamb / m * theta[1:]\n",
" return g\n"
]
},
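{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Before training with `J_` and `dJ_`, it is worth checking that the gradient actually matches the cost. The cell below sketches such a check in a self-contained way. It re-implements the same regularized logistic cost and gradient with plain NumPy arrays; `logreg_cost` and `logreg_grad` are hypothetical names, and the clipping constant `1e-7` is an assumption used to avoid $\\log 0$. It then compares the analytic gradient with central finite differences on random data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def logreg_cost(theta, X, y, lamb):\n",
"    \"\"\"Regularized logistic-regression cost; probabilities are clipped to avoid log(0).\"\"\"\n",
"    m = len(y)\n",
"    p = np.clip(1.0 / (1.0 + np.exp(-(X @ theta))), 1e-7, 1 - 1e-7)\n",
"    data_term = -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))\n",
"    return data_term + lamb / (2 * m) * np.sum(theta[1:] ** 2)\n",
"\n",
"\n",
"def logreg_grad(theta, X, y, lamb):\n",
"    \"\"\"Analytic gradient of logreg_cost (ignores the clipping, which is inactive here).\"\"\"\n",
"    m = len(y)\n",
"    p = 1.0 / (1.0 + np.exp(-(X @ theta)))\n",
"    grad = X.T @ (p - y) / m\n",
"    grad[1:] += lamb / m * theta[1:]\n",
"    return grad\n",
"\n",
"\n",
"rng_lr = np.random.default_rng(4)\n",
"X_lr = np.hstack([np.ones((40, 1)), rng_lr.normal(size=(40, 3))])\n",
"y_lr = (rng_lr.random(40) < 0.5).astype(float)\n",
"theta_lr = 0.1 * rng_lr.normal(size=4)  # small weights keep probabilities away from the clip bounds\n",
"step_lr = 1e-6\n",
"numeric_lr = np.array([\n",
"    (logreg_cost(theta_lr + step_lr * e, X_lr, y_lr, 1.0)\n",
"     - logreg_cost(theta_lr - step_lr * e, X_lr, y_lr, 1.0)) / (2 * step_lr)\n",
"    for e in np.eye(4)\n",
"])\n",
"print(np.max(np.abs(numeric_lr - logreg_grad(theta_lr, X_lr, y_lr, 1.0))))  # expected to be tiny\n"
]
},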
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"slider_lambda = widgets.FloatSlider(\n",
" min=0.0, max=0.5, step=0.005, value=0.01, description=r\"$\\lambda$\", layout=widgets.Layout(width=\"300px\")\n",
")\n",
"\n",
"\n",
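"def slide_regularization_example_2(lamb):\n",
" \"\"\"Redraw the regularization example for the value of lambda chosen with the slider.\"\"\"\n",
" draw_regularization_example(X, Y, lamb=lamb)\n"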
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2485672f2aee49bbba398c5847212626",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=0.01, description='$\\\\lambda$', max=0.5, step=0.005), Button(descripti…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.slide_regularization_example_2(lamb)>"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"widgets.interact_manual(slide_regularization_example_2, lamb=slider_lambda)\n"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def cost_lambda_fun(lamb):\n",
" \"\"\"Final training and validation cost for a given regularization parameter lambda.\"\"\"\n",
" theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1], 1)\n",
" thetaBest, err = SGD(\n",
" h,\n",
" J,\n",
" dJ,\n",
" theta,\n",
" X,\n",
" Y,\n",
" alpha=1,\n",
" adaGrad=True,\n",
" maxEpochs=2500,\n",
" batchSize=100,\n",
" logError=True,\n",
" validate=0.25,\n",
" valStep=1,\n",
" lamb=lamb,\n",
" )\n",
" return err[1][-1], err[3][-1]\n",
"\n",
"\n",
"def plot_cost_lambda():\n",
" \"\"\"Plot the training and validation cost as a function of the regularization parameter lambda.\"\"\"\n",
" plt.figure(figsize=(16, 8))\n",
" ax = plt.subplot(111)\n",
" Lambda = np.arange(0.0, 1.0, 0.01)\n",
" Costs = [cost_lambda_fun(lamb) for lamb in Lambda]\n",
" CostTrain = [cost[0] for cost in Costs]\n",
" CostCV = [cost[1] for cost in Costs]\n",
" plt.plot(Lambda, CostTrain, lw=3, label=\"training error\")\n",
" plt.plot(Lambda, CostCV, lw=3, label=\"validation error\")\n",
" ax.set_xlabel(r\"$\\lambda$\")\n",
" ax.set_ylabel(\"cost\")\n",
" plt.legend()\n",
" plt.ylim(0.2, 0.8)\n"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1600x800 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_cost_lambda()\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 6.4. Learning curve"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"* A learning curve lets us check whether training is proceeding correctly.\n",
"* A learning curve is a plot of the cost function value against the size of the training set.\n",
"* As the training set grows, the cost on the training set typically increases, because it becomes harder to fit every example.\n",
"* As the training set grows, the cost on the validation set typically decreases, because the model generalizes better."
]
},
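{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A learning curve can be sketched without the `SGD` function used in this notebook. Fit the model on the first $m$ training examples only, then record the cost on those $m$ examples and on a fixed validation set. The cell below is a minimal example on synthetic data: a noisy quadratic function fitted by least squares on the features $1, x, x^2$. The helper `learning_curve_demo` and all the data are hypothetical. Both costs are computed as $\\frac{1}{2m} \\sum (h_\\theta(x) - y)^2$, i.e. the unregularized linear-regression cost."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def learning_curve_demo(X, y, X_val, y_val, sizes):\n",
"    \"\"\"Train on the first k examples for each k in sizes; return training and validation costs.\"\"\"\n",
"    train_err, val_err = [], []\n",
"    for k in sizes:\n",
"        theta, *_ = np.linalg.lstsq(X[:k], y[:k], rcond=None)  # least-squares fit on k examples\n",
"        train_err.append(np.mean((X[:k] @ theta - y[:k]) ** 2) / 2)\n",
"        val_err.append(np.mean((X_val @ theta - y_val) ** 2) / 2)\n",
"    return train_err, val_err\n",
"\n",
"\n",
"rng_lc = np.random.default_rng(2)\n",
"x_lc = rng_lc.uniform(-1, 1, size=200)\n",
"X_lc = np.column_stack([np.ones_like(x_lc), x_lc, x_lc ** 2])\n",
"y_lc = 1.0 + 2.0 * x_lc - x_lc ** 2 + 0.3 * rng_lc.normal(size=200)\n",
"sizes_lc = [5, 10, 20, 50, 100, 150]\n",
"# First 150 examples for training, last 50 for validation\n",
"train_lc, val_lc = learning_curve_demo(X_lc[:150], y_lc[:150], X_lc[150:], y_lc[150:], sizes_lc)\n",
"for k_lc, tr_lc, va_lc in zip(sizes_lc, train_lc, val_lc):\n",
"    print(f\"m = {k_lc:>3}: training cost = {tr_lc:.4f}, validation cost = {va_lc:.4f}\")\n"
]
},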
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def cost_trainsetsize_fun(m):\n",
" \"\"\"Final training and validation cost for a given training set size m (passed to SGD as trainsetsize).\"\"\"\n",
" theta = np.matrix(np.zeros(X.shape[1])).reshape(X.shape[1], 1)\n",
" thetaBest, err = SGD(\n",
" h,\n",
" J,\n",
" dJ,\n",
" theta,\n",
" X,\n",
" Y,\n",
" alpha=1,\n",
" adaGrad=True,\n",
" maxEpochs=2500,\n",
" batchSize=100,\n",
" logError=True,\n",
" validate=0.25,\n",
" valStep=1,\n",
" lamb=0.01,\n",
" trainsetsize=m,\n",
" )\n",
" return err[1][-1], err[3][-1]\n",
"\n",
"\n",
"def plot_learning_curve():\n",
" \"\"\"Plot the learning curve.\"\"\"\n",
" plt.figure(figsize=(16, 8))\n",
" ax = plt.subplot(111)\n",
" M = np.arange(0.3, 1.0, 0.05)\n",
" Costs = [cost_trainsetsize_fun(m) for m in M]\n",
" CostTrain = [cost[0] for cost in Costs]\n",
" CostCV = [cost[1] for cost in Costs]\n",
" plt.plot(M, CostTrain, lw=3, label=\"training error\")\n",
" plt.plot(M, CostCV, lw=3, label=\"validation error\")\n",
" ax.set_xlabel(\"trainset size\")\n",
" ax.set_ylabel(\"cost\")\n",
" plt.legend()\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Learning curves and bias vs. variance\n",
"\n",
"Plotting the learning curve helps diagnose overfitting (high variance) and underfitting (high bias):\n",
"\n",
"<img width=\"100%\" src=\"learning-curves.png\"/>\n",
"\n",
"Source: http://www.ritchieng.com/machinelearning-learning-curve"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1600x800 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_learning_curve()\n"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"livereveal": {
"start_slideshow_at": "selected",
"theme": "white"
}
},
"nbformat": 4,
"nbformat_minor": 4
}