{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Machine learning\n", "# 4. Logistic regression" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "**Note**: Despite its name, logistic *regression* is an algorithm for solving *classification* problems (not *regression* problems!)" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "To demonstrate logistic regression, we will use the classic *Iris flower data set*, which consists of 150 examples, each described by the values of 4 features, covering 3 iris species." ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### *Iris flower data set*\n", "\n", "* 150 examples\n", "* 4 features\n", "* 3 categories" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "| | | |\n", "| :--- | :--- | :--- |\n", "| *Iris setosa* | *Iris virginica* | *Iris versicolor* |\n", "| bristle-pointed iris | Virginia iris | harlequin blueflag |\n" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "4 features:\n", " * sepal length (`sl`)\n", " * sepal width (`sw`)\n", " * petal length (`pl`)\n", " * petal width (`pw`)" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## 4.1. 
Two-class logistic regression" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Let us start with the simplest case:\n", " * restrict the problem to **2** classes\n", " * restrict the problem to **1** variable" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "→ two-class logistic regression of one variable" ] },
{ "cell_type": "code", "execution_count": 147, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [
"# Useful imports\n",
"\n",
"import numpy as np\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"import pandas\n",
"import ipywidgets as widgets\n",
"\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format = 'svg'\n",
"\n",
"from IPython.display import display, Math, Latex\n",
"\n",
"# Useful functions\n",
"\n",
"def LatexMatrix(matrix):\n",
"    \"\"\"Render a matrix in LaTeX\"\"\"\n",
"    ltx = r'\\left[\\begin{array}'\n",
"    m, n = matrix.shape\n",
"    ltx += '{' + (\"r\" * n) + '}'\n",
"    for i in range(m):\n",
"        ltx += r\" & \".join([('%.4f' % j.item()) for j in matrix[i]]) + r\" \\\\ \"\n",
"    ltx += r'\\end{array}\\right]'\n",
"    return ltx\n",
"\n",
"def h(theta, X):\n",
"    \"\"\"Hypothesis (matrix version)\"\"\"\n",
"    return X * theta\n",
"\n",
"def regdots(X, y, xlabel, ylabel):\n",
"    \"\"\"Scatter plot of the data (matrix version)\"\"\"\n",
"    fig = plt.figure(figsize=(16*.6, 9*.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"    ax.scatter([X[:, 1]], [y], c='r', s=50, label='Data')\n",
"    \n",
"    ax.set_xlabel(xlabel)\n",
"    ax.set_ylabel(ylabel)\n",
"    ax.margins(.05, .05)\n",
"    plt.ylim(y.min() - 1, y.max() + 1)\n",
"    plt.xlim(np.min(X[:, 1]) - 1, np.max(X[:, 1]) + 1)\n",
"    return fig\n",
"\n",
"def regline(fig, fun, theta, X):\n",
"    \"\"\"Plot of the regression line (matrix version)\"\"\"\n",
"    ax = fig.axes[0]\n",
"    x0 = np.min(X[:, 1]) - 1.0\n",
"    x1 = np.max(X[:, 1]) + 1.0\n",
"    L = [x0, x1]\n",
"    LX = np.matrix([1, x0, 1, x1]).reshape(2, 2)\n",
"    ax.plot(L, fun(theta, LX), linewidth='2',\n",
"            label=(r'$y={theta0:.2}{op}{theta1:.2}x$'.format(\n",
"                theta0=float(theta[0][0]),\n",
"                theta1=(float(theta[1][0]) if theta[1][0] >= 0 else float(-theta[1][0])),\n",
"                op='+' if theta[1][0] >= 0 else '-')))\n",
"\n",
"def legend(fig):\n",
"    \"\"\"Plot legend\"\"\"\n",
"    ax = fig.axes[0]\n",
"    handles, labels = ax.get_legend_handles_labels()\n",
"    # try-except block is a fix for a bug in Poly3DCollection\n",
"    try:\n",
"        fig.legend(handles, labels, fontsize='15', loc='lower right')\n",
"    except AttributeError:\n",
"        pass\n",
"\n",
"def J(theta,X,y):\n",
"    \"\"\"Cost function (matrix version)\"\"\"\n",
"    m = len(y)\n",
"    J = 1.0 / (2.0 * m) * ((X * theta - y).T * ( X * theta - y))\n",
"    return J.item()\n",
"\n",
"def dJ(theta,X,y):\n",
"    \"\"\"Gradient of the cost function (matrix version)\"\"\"\n",
"    return 1.0 / len(y) * (X.T * (X * theta - y)) \n",
"\n",
"def GD(fJ, fdJ, theta, X, y, alpha=0.1, eps=10**-3):\n",
"    \"\"\"Gradient descent implemented with numpy matrices\"\"\"\n",
"    current_cost = fJ(theta, X, y)\n",
"    while True:\n",
"        theta = theta - alpha * fdJ(theta, X, y) # gradient descent update rule\n",
"        current_cost, prev_cost = fJ(theta, X, y), current_cost\n",
"        if current_cost > 10000:\n",
"            break\n",
"        if abs(prev_cost - current_cost) <= eps:\n",
"            break\n",
"    return theta\n",
"\n",
"theta_start = np.matrix([0, 0]).reshape(2, 1)\n",
"\n",
"def threshold(fig, theta):\n",
"    \"\"\"Draw the decision threshold\"\"\"\n",
"    x_thr = (0.5 - theta.item(0)) / theta.item(1)\n",
"    ax = fig.axes[0]\n",
"    ax.plot([x_thr, x_thr], [-1, 2],\n",
"            color='orange', linestyle='dashed',\n",
"            label=u'threshold: $x={:.2F}$'.format(x_thr))" ] },
{ "cell_type": "code", "execution_count": 148, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 
   sl   sw   pl   pw          Gatunek\n",
"0  5.2  3.4  1.4  0.2      Iris-setosa\n",
"1  5.1  3.7  1.5  0.4      Iris-setosa\n",
"2  6.7  3.1  5.6  2.4   Iris-virginica\n",
"3  6.5  3.2  5.1  2.0   Iris-virginica\n",
"4  4.9  2.5  4.5  1.7   Iris-virginica\n",
"5  6.0  2.7  5.1  1.6  Iris-versicolor\n" ] } ], "source": [
"# Load the full (original) data set\n",
"\n",
"data_iris = pandas.read_csv(\"iris.csv\")\n",
"print(data_iris[:6])\n" ] },
{ "cell_type": "code", "execution_count": 149, "metadata": { "scrolled": true, "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"   dł. płatka  Iris setosa?\n",
"0         1.4             1\n",
"1         1.5             1\n",
"2         5.6             0\n",
"3         5.1             0\n",
"4         4.5             0\n",
"5         5.1             0\n" ] } ], "source": [
"# Restrict the data to 2 classes and 1 feature\n",
"\n",
"data_iris_setosa = pandas.DataFrame()\n",
"# \"pl\" stands for \"petal length\"; the column label \"dł. płatka\" means the same\n",
"data_iris_setosa[\"dł. płatka\"] = data_iris[\"pl\"]\n",
"# The \"Gatunek\" column of iris.csv holds the species name\n",
"data_iris_setosa[\"Iris setosa?\"] = data_iris[\"Gatunek\"].apply(\n",
"    lambda x: 1 if x == \"Iris-setosa\" else 0\n",
")\n",
"print(data_iris_setosa[:6])\n" ] },
{ "cell_type": "code", "execution_count": 150, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"# Data preparation\n",
"m, n_plus_1 = data_iris_setosa.values.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data_iris_setosa.values[:, 0:n].reshape(m, n)\n",
"\n",
"# Prepend a column of ones (bias term) to the feature matrix\n",
"X = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n_plus_1)\n",
"y = np.matrix(data_iris_setosa.values[:, 1]).reshape(m, 1)\n",
"\n",
"# Linear regression\n",
"theta_lin = GD(J, dJ, theta_start, X, y, alpha=0.03, eps=0.000001)\n" ] },
{ "cell_type": "code", "execution_count": 151, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "text/plain": [ "[Figure: scatter plot of the data; x axis: x, y axis: Iris setosa?]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"fig = regdots(X, y, \"x\", \"Iris setosa?\")\n",
"legend(fig)\n" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "#### Trying linear regression on a classification problem\n", "\n", "First, out of curiosity, let us check what we would get if we applied linear regression to a classification problem." ] },
{ "cell_type": "code", "execution_count": 152, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "text/plain": [ "[Figure: the same scatter plot with the fitted linear regression line]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"fig = regdots(X, y, \"x\", \"Iris setosa?\")\n",
"regline(fig, h, theta_lin, X)\n",
"legend(fig)\n" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "What if we let the classifier return $1$ for $h(x) > 0.5$ and $0$ otherwise?" ] },
{ "cell_type": "code", "execution_count": 153, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "text/plain": [ "[Figure: the same scatter plot with the linear regression line and a dashed orange threshold line]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"fig = regdots(X, y, \"x\", \"Iris setosa?\")\n",
"theta_lin = GD(J, dJ, theta_start, X, y, alpha=0.03, eps=0.000001)\n",
"regline(fig, h, theta_lin, X)\n",
"threshold(\n",
"    fig, theta_lin\n",
")  # the orange line marks the boundary between class \"1\" and class \"0\" given by the threshold \"h(x) = 0.5\"\n",
"legend(fig)\n" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ " * The linear regression line fits classification data poorly.\n", " * Applying the threshold $y = 0.5$ does not always give a sensible result.\n", " * $h(x)$ can take values below $0$ and above $1$; how should such results be interpreted?\n", "\n", "Conclusion: for classification problems, linear regression does not seem to be the best solution." ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Let us therefore introduce some modifications to our model." 
] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Let us define the following function, which we will call the *logistic* (or *sigmoid*) function:" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "**Logistic (sigmoid) function**:\n", "\n", "$$g(x) = \\dfrac{1}{1+e^{-x}}$$" ] },
{ "cell_type": "code", "execution_count": 154, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [
"def logistic(x):\n",
"    \"\"\"Logistic function\"\"\"\n",
"    return 1.0 / (1.0 + np.exp(-x))\n" ] },
{ "cell_type": "code", "execution_count": 155, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"def plot_logistic():\n",
"    \"\"\"Plot of the logistic function\"\"\"\n",
"    x = np.linspace(-5, 5, 200)\n",
"    y = logistic(x)\n",
"    fig = plt.figure(figsize=(7, 5))\n",
"    ax = fig.add_subplot(111)\n",
"    plt.ylim(-0.1, 1.1)\n",
"    ax.plot(x, y, linewidth=\"2\")\n" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Plot of the logistic function $g(x) = \\dfrac{1}{1+e^{-x}}$:" ] },
{ "cell_type": "code", "execution_count": 156, "metadata": { "scrolled": true, "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "text/plain": [ "[Figure: plot of the logistic function on the interval from -5 to 5]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"plot_logistic()\n" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "The logistic function maps the set of real numbers $\\mathbb{R}$ onto the open interval $(0, 1)$." ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "The logistic regression function for a single example whose features are given by the vector $x$:\n", "\n", "$$h_\\theta(x) = g(\\theta^T \\, x) = \\dfrac{1}{1 + e^{-\\theta^T x}}$$" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "For the whole feature matrix $X$:\n", "\n", "$$h_\\theta(X) = g(X \\, \\theta) = \\dfrac{1}{1 + e^{-X \\theta}}$$" ] },
{ "cell_type": "code", "execution_count": 157, "metadata": { "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [
"def h(theta, X):\n",
"    \"\"\"Logistic regression function\"\"\"\n",
"    return 1.0 / (1.0 + np.exp(-X * theta))\n" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Cost function for logistic regression:\n", "\n", "$$J(\\theta) = -\\dfrac{1}{m} \\sum_{i=1}^{m} \\left[ y^{(i)} \\log h_\\theta( x^{(i)} ) + \\left( 1 - y^{(i)} \\right) \\log \\left( 1 - h_\\theta (x^{(i)}) \\right) \\right]$$" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Gradient for logistic regression (matrix version):\n", "\n", "$$\\nabla J(\\theta) = \\frac{1}{|\\vec y|} X^T \\left( h_\\theta(X) - \\vec y \\right)$$\n", "\n", "(The only difference between the gradient for logistic regression and the gradient for linear regression is the form of $h_\\theta$)." 
] },
{ "cell_type": "code", "execution_count": 158, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [
"def J(h, theta, X, y):\n",
"    \"\"\"Cost function for logistic regression\"\"\"\n",
"    m = len(y)\n",
"    h_val = h(theta, X)\n",
"    s1 = np.multiply(y, np.log(h_val))\n",
"    s2 = np.multiply((1 - y), np.log(1 - h_val))\n",
"    return -np.sum(s1 + s2, axis=0) / m\n" ] },
{ "cell_type": "code", "execution_count": 159, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [
"def dJ(h, theta, X, y):\n",
"    \"\"\"Gradient for logistic regression\"\"\"\n",
"    return 1.0 / len(y) * (X.T * (h(theta, X) - y))\n" ] },
{ "cell_type": "code", "execution_count": 160, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [
"def GD(h, fJ, fdJ, theta, X, y, alpha=0.01, eps=10**-3, max_steps=10000):\n",
"    \"\"\"Gradient descent for logistic regression\"\"\"\n",
"    curr_cost = fJ(h, theta, X, y)\n",
"    history = [[curr_cost, theta]]\n",
"    while True:\n",
"        # compute the new theta\n",
"        theta = theta - alpha * fdJ(h, theta, X, y)\n",
"        # track the cost\n",
"        prev_cost = curr_cost\n",
"        curr_cost = fJ(h, theta, X, y)\n",
"        # stopping criteria\n",
"        if abs(prev_cost - curr_cost) <= eps:\n",
"            break\n",
"        if len(history) > max_steps:\n",
"            break\n",
"        history.append([curr_cost, theta])\n",
"    return theta, history\n" ] },
{ "cell_type": "code", "execution_count": 161, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"Cost: [[0.05755617]]\n",
"theta = [[ 5.02530461]\n",
" [-1.99174803]]\n" ] } ], "source": [
"# Run gradient descent for logistic regression\n",
"theta_best, history = GD(\n",
"    h, J, dJ, theta_start, X, y, alpha=0.1, eps=10**-7, max_steps=1000\n",
")\n",
"print(f\"Cost: {history[-1][0]}\")\n",
"print(f\"theta = {theta_best}\")\n" ] },
{ "cell_type": "code", "execution_count": 
162, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [
"def scalar_logistic_regression_function(theta, x):\n",
"    \"\"\"Logistic regression function (scalar version)\"\"\"\n",
"    return 1.0 / (1.0 + np.exp(-(theta.item(0) + theta.item(1) * x)))\n",
"\n",
"\n",
"def threshold_val(fig, x_thr):\n",
"    \"\"\"Draw the threshold\"\"\"\n",
"    ax = fig.axes[0]\n",
"    ax.plot(\n",
"        [x_thr, x_thr],\n",
"        [-1, 2],\n",
"        color=\"orange\",\n",
"        linestyle=\"dashed\",\n",
"        label=\"threshold: $x={:.2F}$\".format(x_thr),\n",
"    )\n",
"\n",
"\n",
"def logistic_regline(fig, theta, X):\n",
"    \"\"\"Plot of the logistic regression curve\"\"\"\n",
"    ax = fig.axes[0]\n",
"    x0 = np.min(X[:, 1]) - 1.0\n",
"    x1 = np.max(X[:, 1]) + 1.0\n",
"    Arg = np.arange(x0, x1, 0.1)\n",
"    Val = scalar_logistic_regression_function(theta, Arg)\n",
"    ax.plot(Arg, Val, linewidth=\"2\")\n" ] },
{ "cell_type": "code", "execution_count": 163, "metadata": { "scrolled": true, "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "text/plain": [ "[Figure: scatter plot of the data with the logistic regression curve and a dashed orange threshold line at x = 2.50]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"fig = regdots(X, y, xlabel=\"x\", ylabel=\"Iris setosa?\")\n",
"logistic_regline(fig, theta_best, X)\n",
"threshold_val(fig, 2.5)\n" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "We treat the value $h_\\theta(x)$ as the probability that the example belongs to the positive class:\n", "\n", "$$ h_\\theta(x) = P(y = 1 \\, | \\, x; \\theta) $$" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "If $h_\\theta(x) > 0.5$, then for that $x$ we predict $y = 1$.\n", "Otherwise we predict $y = 0$." ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Why can we treat the value of the logistic regression function as a probability?\n", "\n", "You can read about this in external sources, e.g. https://towardsdatascience.com/logit-of-logistic-regression-understanding-the-fundamentals-f384152a33d1" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Two-class logistic regression: more features" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "What should we do if we have more than one feature $x$?\n", "\n", "Let us now take all the features available in the *Iris* data set:\n", "* petal length (`pl`)\n", "* petal width (`pw`)\n", "* sepal length (`sl`)\n", "* sepal width (`sw`)" ] },
{ "cell_type": "code", "execution_count": 164, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"    pl   pw   sl   sw  Iris setosa?\n",
"0  1.4  0.2  5.2  3.4             1\n",
"1  1.5  0.4  5.1  3.7             1\n",
"2  5.6  2.4  6.7  3.1             0\n",
"3  5.1  2.0  6.5  3.2             0\n",
"4  4.5  1.7  4.9  2.5             0\n",
"5  5.1 
1.6 6.0 2.7 0\n" ] } ], "source": [ "data_iris_setosa_multi = pandas.DataFrame()\n", "for feature in [\"pl\", \"pw\", \"sl\", \"sw\"]:\n", " data_iris_setosa_multi[feature] = data_iris[feature]\n", "data_iris_setosa_multi[\"Iris setosa?\"] = data_iris[\"Gatunek\"].apply(\n", " lambda x: 1 if x == \"Iris-setosa\" else 0\n", ")\n", "print(data_iris_setosa_multi[:6])\n" ] }, { "cell_type": "code", "execution_count": 165, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[1. 1.4 0.2 5.2 3.4]\n", " [1. 1.5 0.4 5.1 3.7]\n", " [1. 5.6 2.4 6.7 3.1]\n", " [1. 5.1 2. 6.5 3.2]\n", " [1. 4.5 1.7 4.9 2.5]\n", " [1. 5.1 1.6 6. 2.7]]\n", "[[1.]\n", " [1.]\n", " [0.]\n", " [0.]\n", " [0.]\n", " [0.]]\n" ] } ], "source": [ "# Przygotowanie danych\n", "m, n_plus_1 = data_iris_setosa_multi.values.shape\n", "n = n_plus_1 - 1\n", "Xn = data_iris_setosa_multi.values[:, 0:n].reshape(m, n)\n", "\n", "X = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n_plus_1)\n", "y = np.matrix(data_iris_setosa_multi.values[:, n]).reshape(m, 1)\n", "\n", "print(X[:6])\n", "print(y[:6])\n" ] }, { "cell_type": "code", "execution_count": 166, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "# Podział danych na zbiór trenujący i testowy\n", "XTrain, XTest = X[:100], X[100:]\n", "yTrain, yTest = y[:100], y[100:]\n", "\n", "# Macierz parametrów początkowych\n", "theta_start = np.ones(5).reshape(5, 1)\n" ] }, { "cell_type": "code", "execution_count": 167, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Koszt: [[0.006797]]\n", "theta = [[ 1.11414027]\n", " [-2.89324615]\n", " [-0.66543637]\n", " [ 0.14887292]\n", " [ 2.13284493]]\n" ] } ], "source": [ "theta_best, history = GD(\n", " h, J, dJ, theta_start, XTrain, yTrain, alpha=0.1, eps=10**-7, max_steps=1000\n", ")\n", "print(f\"Koszt: 
{history[-1][0]}\")\n", "print(f\"theta = {theta_best}\")\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Funkcja decyzyjna regresji logistycznej\n", "\n", "Funkcja decyzyjna mówi o tym, kiedy nasz algorytm będzie przewidywał $y = 1$, a kiedy $y = 0$:\n", "\n", "$$ c(x) := \\left\\{ \n", "\\begin{array}{ll}\n", "1, & \\mbox{gdy } P(y=1 \\, | \\, x; \\theta) > 0.5 \\\\\n", "0 & \\mbox{w przeciwnym przypadku}\n", "\\end{array}\\right.\n", "$$\n", "\n", "$$ P(y=1 \\,| \\, x; \\theta) = h_\\theta(x) $$" ] }, { "cell_type": "code", "execution_count": 168, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "theta = [[ 1.11414027]\n", " [-2.89324615]\n", " [-0.66543637]\n", " [ 0.14887292]\n", " [ 2.13284493]]\n", "x0 = [[1. 6.3 1.8 7.3 2.9]]\n", "h(x0) = 1.606143695982487e-05\n", "c(x0) = (0, 1.606143695982487e-05)\n" ] } ], "source": [ "def classifyBi(theta, X):\n", " \"\"\"Funkcja decyzyjna regresji logistycznej\"\"\"\n", " prob = h(theta, X).item()\n", " return (1, prob) if prob > 0.5 else (0, prob)\n", "\n", "\n", "print(f\"theta = {theta_best}\")\n", "print(f\"x0 = {XTest[0]}\")\n", "print(f\"h(x0) = {h(theta_best, XTest[0]).item()}\")\n", "print(f\"c(x0) = {classifyBi(theta_best, XTest[0])}\")\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Obliczmy teraz skuteczność modelu (więcej na ten temat na następnym wykładzie, poświęconym metodom ewaluacji)." 
] },
 { "cell_type": "code", "execution_count": 169, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
 "0 <=> 0 -- prob: 0.0000\n",
 "1 <=> 1 -- prob: 0.9816\n",
 "0 <=> 0 -- prob: 0.0001\n",
 "0 <=> 0 -- prob: 0.0005\n",
 "0 <=> 0 -- prob: 0.0001\n",
 "1 <=> 1 -- prob: 0.9936\n",
 "0 <=> 0 -- prob: 0.0059\n",
 "0 <=> 0 -- prob: 0.0992\n",
 "0 <=> 0 -- prob: 0.0001\n",
 "0 <=> 0 -- prob: 0.0001\n",
 "\n",
 "Accuracy: 1.0\n"
 ] } ], "source": [
 "correct = 0\n",
 "for i, y_true in enumerate(yTest):\n",
 "    expected = y_true.item()\n",
 "    cls, prob = classifyBi(theta_best, XTest[i])\n",
 "    if i < 10:\n",
 "        print(f\"{expected:1.0f} <=> {cls} -- prob: {prob:6.4f}\")\n",
 "    correct += cls == expected\n",
 "accuracy = correct / len(XTest)\n",
 "\n",
 "print(f\"\\nAccuracy: {accuracy}\")\n"
 ] },
 { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [
 "## 4.2. Multiclass logistic regression"
 ] },
 { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [
 "Example: all features and all 3 classes of the *Iris* data set."
 ] },
 { "cell_type": "code", "execution_count": 170, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "text/plain": [
 "    sl   sw   pl   pw          Gatunek\n",
 "0  5.2  3.4  1.4  0.2      Iris-setosa\n",
 "1  5.1  3.7  1.5  0.4      Iris-setosa\n",
 "2  6.7  3.1  5.6  2.4   Iris-virginica\n",
 "3  6.5  3.2  5.1  2.0   Iris-virginica\n",
 "4  4.9  2.5  4.5  1.7   Iris-virginica\n",
 "5  6.0  2.7  5.1  1.6  Iris-versicolor"
 ] }, "execution_count": 170, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "import pandas\n",
 "\n",
 "data_iris = pandas.read_csv(\"iris.csv\")\n",
 "data_iris[:6]\n"
 ] },
 { "cell_type": "code", "execution_count": 171, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
 "X = [[1. 5.2 3.4 1.4 0.2]\n",
 " [1. 5.1 3.7 1.5 0.4]\n",
 " [1. 6.7 3.1 5.6 2.4]\n",
 " [1. 6.5 3.2 5.1 2. ]]\n",
 "y = [['Iris-setosa']\n",
 " ['Iris-setosa']\n",
 " ['Iris-virginica']\n",
 " ['Iris-virginica']]\n"
 ] } ], "source": [
 "# Data preparation\n",
 "\n",
 "import numpy as np\n",
 "\n",
 "features = [\"sl\", \"sw\", \"pl\", \"pw\"]\n",
 "m = len(data_iris)\n",
 "X = np.matrix(data_iris[features])\n",
 "X0 = np.ones(m).reshape(m, 1)\n",
 "X = np.hstack((X0, X))\n",
 "y = np.matrix(data_iris[[\"Gatunek\"]]).reshape(m, 1)\n",
 "\n",
 "print(\"X = \", X[:4])\n",
 "print(\"y = \", y[:4])\n"
 ] },
 { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [
 "Let us convert the text labels in the array $y$ to one-hot vectors (the columns follow the alphabetical order of the class names, which is the order `np.unique` returns):\n",
 "\n",
 "$$\n",
 "\\begin{array}{ccc}\n",
 "\\mbox{\"Iris-setosa\"} & \\mapsto & \\left[ \\begin{array}{ccc} 1 & 0 & 0 \\\\ \\end{array} \\right] \\\\\n",
 "\\mbox{\"Iris-versicolor\"} & \\mapsto & \\left[ \\begin{array}{ccc} 0 & 1 & 0 \\\\ \\end{array} \\right] \\\\\n",
 "\\mbox{\"Iris-virginica\"} & \\mapsto & \\left[ \\begin{array}{ccc} 0 & 0 & 1 \\\\ \\end{array} \\right] \\\\\n",
 "\\end{array}\n",
 "$$"
 ] },
 { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [
 "Then instead of the vector $y$ we get the matrix $Y$:\n",
 "\n",
 "$$\n",
 "y \\; = \\;\n",
 "\\left[\n",
 "\\begin{array}{c}\n",
 "y^{(1)} \\\\\n",
 "y^{(2)} \\\\\n",
 "y^{(3)} \\\\\n",
 "y^{(4)} \\\\\n",
 "y^{(5)} \\\\\n",
 "\\vdots \\\\\n",
 "\\end{array}\n",
 "\\right]\n",
 "\\; = \\;\n",
 "\\left[\n",
 "\\begin{array}{c}\n",
 "\\mbox{\"Iris-setosa\"} \\\\\n",
 "\\mbox{\"Iris-setosa\"} \\\\\n",
 "\\mbox{\"Iris-virginica\"} \\\\\n",
 "\\mbox{\"Iris-versicolor\"} \\\\\n",
 "\\mbox{\"Iris-virginica\"} \\\\\n",
 "\\vdots \\\\\n",
 "\\end{array}\n",
 "\\right]\n",
 "\\quad \\mapsto \\quad\n",
 "Y \\; = \\;\n",
 "\\left[\n",
 "\\begin{array}{ccc}\n",
 "1 & 0 & 0 \\\\\n",
 "1 & 0 & 0 \\\\\n",
 "0 & 0 & 1 \\\\\n",
 "0 & 1 & 0 \\\\\n",
 "0 & 0 & 1 \\\\\n",
 "\\vdots & \\vdots & \\vdots \\\\\n",
 "\\end{array}\n",
 "\\right]\n",
 "$$"
 ] },
 { "cell_type": "code", "execution_count": 172, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [
 "def mapY(y, cls):\n",
 "    \"\"\"Binary indicator column: 1.0 where y equals cls, 0.0 elsewhere\"\"\"\n",
 "    m = len(y)\n",
 "    yBi = np.matrix(np.zeros(m)).reshape(m, 1)\n",
 "    yBi[y == cls] = 1.0\n",
 "    return yBi\n",
 "\n",
 "\n",
 "def indicatorMatrix(y):\n",
 "    \"\"\"One-hot (indicator) matrix; columns follow the sorted class names\"\"\"\n",
 "    classes = np.unique(y.tolist())\n",
 "    m = len(y)\n",
 "    k = len(classes)\n",
 "    Y = np.matrix(np.zeros((m, k)))\n",
 "    for i, cls in enumerate(classes):\n",
 "        Y[:, i] = mapY(y, cls)\n",
 "    return Y\n",
 "\n",
 "\n",
 "# Indicator (one-hot) matrix\n",
 "Y = indicatorMatrix(y)\n"
 ] },
 { "cell_type": "code", "execution_count": 173, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [
 "# Split the data into training and test sets\n",
 "XTrain, XTest = X[:100], X[100:]\n",
 "YTrain, YTest = Y[:100], Y[100:]\n",
 "\n",
 "# Initial parameter vector: all ones\n",
 "theta_start = np.ones(5).reshape(5, 1)\n"
 ] },
 { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [
 "Logistic regression is a method for solving **two-class** classification problems."
] },
 { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [
 "To solve a **multiclass** classification problem with logistic regression, we must transform it into a set of two-class classification problems."
 ] },
 { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [
 "Alternatively, one can use **multinomial logistic regression** (see https://machinelearningmastery.com/multinomial-logistic-regression-with-python)."
 ] },
 { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [
 "### From two-class to multiclass logistic regression\n",
 "\n",
 "* The irises belong to three classes: _Iris-setosa_ (0), _Iris-versicolor_ (1), _Iris-virginica_ (2).\n",
 "* We know how to build two-class classifiers of the form _Iris-setosa_ vs. _not-Iris-setosa_ (so-called *one-vs-all*).\n",
 "* We can build three classifiers $h_{\\theta_1}, h_{\\theta_2}, h_{\\theta_3}$ (obtaining three sets of parameters $\\theta$) and choose the class with the highest probability."
 ] },
 { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [
 "The *softmax* function, a generalization of the logistic function to more dimensions, will help us with this."
] },
 { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [
 "### The _softmax_ function\n",
 "\n",
 "The counterpart of the logistic function in multiclass logistic regression is the $\\mathrm{softmax}$ function:\n",
 "\n",
 "$$ \\textrm{softmax} \\colon \\mathbb{R}^k \\to [0,1]^k $$\n",
 "\n",
 "$$ \\textrm{softmax}(z_1,z_2,\\dots,z_k) = \\left( \\dfrac{e^{z_1}}{\\sum_{i=1}^{k}e^{z_i}}, \\dfrac{e^{z_2}}{\\sum_{i=1}^{k}e^{z_i}}, \\ldots, \\dfrac{e^{z_k}}{\\sum_{i=1}^{k}e^{z_i}} \\right) $$"
 ] },
 { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [
 "$$ \\textrm{softmax}( \\left[ \\begin{array}{c} \\theta_1^T x \\\\ \\theta_2^T x \\\\ \\vdots \\\\ \\theta_k^T x \\end{array} \\right] ) = \\left[ \\begin{array}{c} P(y=1 \\, | \\, x;\\theta_1,\\ldots,\\theta_k) \\\\ P(y=2 \\, | \\, x;\\theta_1,\\ldots,\\theta_k) \\\\ \\vdots \\\\ P(y=k \\, | \\, x;\\theta_1,\\ldots,\\theta_k) \\end{array} \\right] $$"
 ] },
 { "cell_type": "code", "execution_count": 174, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [
 "def softmax(X):\n",
 "    \"\"\"Softmax function (matrix version)\"\"\"\n",
 "    return np.exp(X) / np.sum(np.exp(X))\n"
 ] },
 { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [
 "The values of the $\\mathrm{softmax}$ function sum to 1:"
 ] },
 { "cell_type": "code", "execution_count": 175, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
 "0.9999999999999999\n"
 ] } ], "source": [
 "Z = np.matrix([[2.1, 0.5, 0.8, 0.9, 3.2]])\n",
 "P = softmax(Z)\n",
 "print(np.sum(P))\n"
 ] },
 { "cell_type": "code", "execution_count": 176, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [
 "def multiple_binary_classifiers(X, Y):\n",
 "    n = X.shape[1]\n",
 "    thetas = []\n",
 "    # Train a separate two-class classifier for each class.\n",
 "    for c in range(Y.shape[1]):\n",
 "        YBi = Y[:, c]\n",
 "        theta = np.matrix(np.random.random(n)).reshape(n, 1)\n",
 "        # Parameter vector theta computed separately for each class.\n",
 "        theta_best, history = GD(h, J, dJ, theta, X, YBi, alpha=0.1, eps=10**-4)\n",
 "        thetas.append(theta_best)\n",
 "    return thetas\n"
 ] },
 { "cell_type": "code", "execution_count": 177, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
 "Resulting theta parameters for class 0:\n",
 " [[ 0.30877778]\n",
 " [-0.16504776]\n",
 " [ 1.92701194]\n",
 " [-1.83418434]\n",
 " [-0.50458444]] \n",
 "\n",
 "Resulting theta parameters for class 1:\n",
 " [[ 0.93723385]\n",
 " [-0.13501701]\n",
 " [-0.8448612 ]\n",
 " [ 0.77823106]\n",
 " [-0.95577092]] \n",
 "\n",
 "Resulting theta parameters for class 2:\n",
 " [[-0.72237014]\n",
 " [-1.56606505]\n",
 " [-1.71063165]\n",
 " [ 2.21207268]\n",
 " [ 2.78489436]] \n",
 "\n"
 ] } ], "source": [
 "# Theta parameters for each class\n",
 "thetas = multiple_binary_classifiers(XTrain, YTrain)\n",
 "for c, theta in enumerate(thetas):\n",
 "    print(f\"Resulting theta parameters for class {c}:\\n\", theta, \"\\n\")\n"
 ] },
 { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [
 "### Decision function of multiclass logistic regression\n",
 "\n",
 "$$ c = \\mathop{\\textrm{arg}\\,\\textrm{max}}_{i \\in \\{1, \\ldots ,k\\}} P(y=i|x;\\theta_1,\\ldots,\\theta_k) $$"
 ] },
 { "cell_type": "code", "execution_count": 178, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [
 "def classify(thetas, X, debug=False):\n",
 "    # Linear score of each one-vs-all classifier for the example X\n",
 "    regs = np.array([(X * theta).item() for theta in thetas])\n",
 "    if debug:\n",
 "        print(\"Regression scores: \", regs)\n",
 "    probs = softmax(regs)\n",
 "    if debug:\n",
 "        print(\"Resulting probabilities: \", np.around(probs, decimals=3))\n",
 "    result = np.argmax(probs)\n",
 "    if debug:\n",
 "        print(\"Chosen class: \", result)\n",
 "    return result\n"
 ] },
 { "cell_type": "code", "execution_count": 179, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
 "For x = [[1. 7.3 2.9 6.3 1.8]]:\n",
 "Regression scores: [-7.77134959 0.68398021 1.83339097]\n",
 "Resulting probabilities: [0. 0.241 0.759]\n",
 "Chosen class: 2\n",
 "Predicted y = 2\n",
 "Expected y = 2\n",
 "\n",
 "For x = [[1. 4.8 3. 1.4 0.3]]:\n",
 "Regression scores: [ 2.57835094 -1.4426392 -9.43900726]\n",
 "Resulting probabilities: [0.982 0.018 0. ]\n",
 "Chosen class: 0\n",
 "Predicted y = 0\n",
 "Expected y = 0\n",
 "\n",
 "For x = [[1. 7.1 3. 5.9 2.1]]:\n",
 "Regression scores: [-6.96334044 0.0284738 1.92618005]\n",
 "Resulting probabilities: [0. 0.13 0.87]\n",
 "Chosen class: 2\n",
 "Predicted y = 2\n",
 "Expected y = 2\n",
 "\n",
 "For x = [[1. 5.9 3. 5.1 1.8]]:\n",
 "Regression scores: [-5.14656032 -0.14535936 1.20033165]\n",
 "Resulting probabilities: [0.001 0.206 0.792]\n",
 "Chosen class: 2\n",
 "Predicted y = 2\n",
 "Expected y = 2\n",
 "\n"
 ] } ], "source": [
 "for i in range(4):\n",
 "    print(f\"For x = {XTest[i]}:\")\n",
 "    YPredicted = classify(thetas, XTest[i], debug=True)\n",
 "    print(f\"Predicted y = {YPredicted}\")\n",
 "    print(f\"Expected y = {np.argmax(YTest[i])}\")\n",
 "    print()\n"
 ] }
 ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" }, "livereveal": { "start_slideshow_at": "selected", "theme": "white" }, "vscode": { "interpreter": { "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" } } }, "nbformat": 4, "nbformat_minor": 4 }