From 85885137284ed0cae06f0b352d54c77870fb8090 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Sk=C3=B3rzewski?=
Date: Fri, 14 Oct 2022 16:11:55 +0200
Subject: [PATCH] =?UTF-8?q?Wyk=C5=82ad=204.=20Regresja=20logistyczna?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 wyk/04_Regresja_logistyczna.ipynb | 6193 +++++++++++++++++++++++++++++
 1 file changed, 6193 insertions(+)
 create mode 100644 wyk/04_Regresja_logistyczna.ipynb

diff --git a/wyk/04_Regresja_logistyczna.ipynb b/wyk/04_Regresja_logistyczna.ipynb
new file mode 100644
index 0000000..b61990d
--- /dev/null
+++ b/wyk/04_Regresja_logistyczna.ipynb
@@ -0,0 +1,6193 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "slideshow": {
+     "slide_type": "slide"
+    }
+   },
+   "source": [
+    "### AITech — Uczenie maszynowe\n",
+    "# 3. Regresja logistyczna"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "slideshow": {
+     "slide_type": "subslide"
+    }
+   },
+   "source": [
+    "**Uwaga**: Wbrew nazwie, *regresja* logistyczna jest algorytmem służącym do rozwiązywania problemów *klasyfikacji* (wcale nie problemów *regresji*!)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "slideshow": {
+     "slide_type": "notes"
+    }
+   },
+   "source": [
+    "Do demonstracji metody regresji logistycznej wykorzystamy klasyczny zbiór danych *Iris flower data set*, składający się ze 150 przykładów wartości 4 cech dla 3 gatunków irysów (kosaćców)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "slideshow": {
+     "slide_type": "subslide"
+    }
+   },
+   "source": [
+    "### *Iris flower data set*\n",
+    "\n",
+    "* 150 przykładów\n",
+    "* 4 cechy\n",
+    "* 3 kategorie"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "slideshow": {
+     "slide_type": "subslide"
+    }
+   },
+   "source": [
+    "| | | |\n",
+    "| :--- | :--- | :--- |\n",
+    "| *Iris setosa* | *Iris virginica* | *Iris versicolor* |\n",
+    "| kosaciec szczecinkowy | kosaciec amerykański | kosaciec różnobarwny |\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "slideshow": {
+     "slide_type": "subslide"
+    }
+   },
+   "source": [
+    "4 cechy:\n",
+    " * długość działek kielicha (*sepal length*, `sl`)\n",
+    " * szerokość działek kielicha (*sepal width*, `sw`)\n",
+    " * długość płatka (*petal length*, `pl`)\n",
+    " * szerokość płatka (*petal width*, `pw`)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "slideshow": {
+     "slide_type": "slide"
+    }
+   },
+   "source": [
+    "## 3.1. 
Dwuklasowa regresja logistyczna" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Zacznijmy od najprostszego przypadku:\n", + " * ograniczmy się do **2** klas\n", + " * ograniczmy się do **1** zmiennej" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "→ dwuklasowa regresja logistyczna jednej zmiennej" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "# Przydatne importy\n", + "\n", + "import numpy as np\n", + "import matplotlib\n", + "import matplotlib.pyplot as pl\n", + "import pandas\n", + "import ipywidgets as widgets\n", + "\n", + "%matplotlib inline\n", + "%config InlineBackend.figure_format = 'svg'\n", + "\n", + "from IPython.display import display, Math, Latex\n", + "\n", + "# Przydatne funkcje\n", + "\n", + "# Wyświetlanie macierzy w LaTeX-u\n", + "def LatexMatrix(matrix):\n", + " ltx = r'\\left[\\begin{array}'\n", + " m, n = matrix.shape\n", + " ltx += '{' + (\"r\" * n) + '}'\n", + " for i in range(m):\n", + " ltx += r\" & \".join([('%.4f' % j.item()) for j in matrix[i]]) + r\" \\\\ \"\n", + " ltx += r'\\end{array}\\right]'\n", + " return ltx\n", + "\n", + "# Hipoteza (wersja macierzowa)\n", + "def hMx(theta, X):\n", + " return X * theta\n", + "\n", + "# Wykres danych (wersja macierzowa)\n", + "def regdotsMx(X, y, xlabel, ylabel): \n", + " fig = pl.figure(figsize=(16*.6, 9*.6))\n", + " ax = fig.add_subplot(111)\n", + " fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n", + " ax.scatter([X[:, 1]], [y], c='r', s=50, label='Dane')\n", + " \n", + " ax.set_xlabel(xlabel)\n", + " ax.set_ylabel(ylabel)\n", + " ax.margins(.05, .05)\n", + " pl.ylim(y.min() - 1, y.max() + 1)\n", + " pl.xlim(np.min(X[:, 1]) - 1, np.max(X[:, 1]) + 1)\n", + " return fig\n", + "\n", + "# Wykres krzywej regresji (wersja macierzowa)\n", + "def reglineMx(fig, fun, theta, X):\n", + " ax = fig.axes[0]\n", + " x0 = np.min(X[:, 1]) - 1.0\n", + " x1 = np.max(X[:, 1]) + 1.0\n", + " L = [x0, x1]\n", + " LX = np.matrix([1, x0, 1, x1]).reshape(2, 2)\n", + " ax.plot(L, fun(theta, LX), linewidth='2',\n", + " label=(r'$y={theta0:.2}{op}{theta1:.2}x$'.format(\n", + " theta0=float(theta[0][0]),\n", + " theta1=(float(theta[1][0]) if theta[1][0] >= 0 else float(-theta[1][0])),\n", + " op='+' if theta[1][0] >= 0 else '-')))\n", + "\n", + "# Legenda wykresu\n", + "def legend(fig):\n", + " ax = fig.axes[0]\n", + " handles, labels = ax.get_legend_handles_labels()\n", + " # try-except block is a fix for a bug in Poly3DCollection\n", + " try:\n", + " fig.legend(handles, labels, fontsize='15', loc='lower right')\n", + " except AttributeError:\n", + " pass\n", + "\n", + "# Wersja macierzowa funkcji kosztu\n", + "def JMx(theta,X,y):\n", + " m = len(y)\n", + " J = 1.0 / (2.0 * m) * ((X * theta - y).T * ( X * theta - y))\n", + " return J.item()\n", + "\n", + "# Wersja macierzowa gradientu funkcji kosztu\n", + "def dJMx(theta,X,y):\n", + " return 1.0 / len(y) * (X.T * (X * theta - y)) \n", + "\n", + "# Implementacja algorytmu gradientu prostego za pomocą numpy i macierzy\n", + "def GDMx(fJ, fdJ, theta, X, y, alpha=0.1, eps=10**-3):\n", + " current_cost = fJ(theta, X, y)\n", + " logs = [[current_cost, theta]]\n", + " while True:\n", + " theta = theta - alpha * fdJ(theta, X, y) # implementacja wzoru\n", + " current_cost, prev_cost = fJ(theta, X, y), current_cost\n", + " 
if current_cost > 10000:  # zabezpieczenie przed rozbieżnością: przerwij, gdy koszt za bardzo wzrośnie\n",
+    "            break\n",
+    "        if abs(prev_cost - current_cost) <= eps:\n",
+    "            break\n",
+    "        logs.append([current_cost, theta]) \n",
+    "    return theta, logs\n",
+    "\n",
+    "thetaStartMx = np.matrix([0, 0]).reshape(2, 1)\n",
+    "\n",
+    "# Funkcja, która rysuje próg\n",
+    "def threshold(fig, theta):\n",
+    "    x_thr = (0.5 - theta.item(0)) / theta.item(1)\n",
+    "    ax = fig.axes[0]\n",
+    "    ax.plot([x_thr, x_thr], [-1, 2],\n",
+    "            color='orange', linestyle='dashed',\n",
+    "            label=u'próg: $x={:.2F}$'.format(x_thr))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "slideshow": {
+     "slide_type": "subslide"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "    sl   sw   pl   pw          Gatunek\n",
+      "0  5.2  3.4  1.4  0.2      Iris-setosa\n",
+      "1  5.1  3.7  1.5  0.4      Iris-setosa\n",
+      "2  6.7  3.1  5.6  2.4   Iris-virginica\n",
+      "3  6.5  3.2  5.1  2.0   Iris-virginica\n",
+      "4  4.9  2.5  4.5  1.7   Iris-virginica\n",
+      "5  6.0  2.7  5.1  1.6  Iris-versicolor\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Wczytanie pełnych (oryginalnych) danych\n",
+    "\n",
+    "data_iris = pandas.read_csv('iris.csv')\n",
+    "print(data_iris[:6])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "scrolled": true,
+    "slideshow": {
+     "slide_type": "subslide"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "   dł. płatka  Iris setosa?\n",
+      "0         1.4             1\n",
+      "1         1.5             1\n",
+      "2         5.6             0\n",
+      "3         5.1             0\n",
+      "4         4.5             0\n",
+      "5         5.1             0\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Ograniczenie danych do 2 klas i 1 cechy\n",
+    "\n",
+    "data_iris_setosa = pandas.DataFrame()\n",
+    "data_iris_setosa['dł. płatka'] = data_iris['pl'] # \"pl\" oznacza \"petal length\"\n",
+    "data_iris_setosa['Iris setosa?'] = data_iris['Gatunek'].apply(lambda x: 1 if x=='Iris-setosa' else 0)\n",
+    "print(data_iris_setosa[:6])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "slideshow": {
+     "slide_type": "notes"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "# Przygotowanie danych\n",
+    "m, n_plus_1 = data_iris_setosa.values.shape\n",
+    "n = n_plus_1 - 1\n",
+    "Xn = data_iris_setosa.values[:, 0:n].reshape(m, n)\n",
+    "\n",
+    "XMx3 = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n_plus_1)\n",
+    "yMx3 = np.matrix(data_iris_setosa.values[:, 1]).reshape(m, 1)\n",
+    "\n",
+    "# Regresja liniowa\n",
+    "theta_e3, logs3 = GDMx(JMx, dJMx, thetaStartMx, XMx3, yMx3, alpha=0.03, eps=0.000001)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "slideshow": {
+     "slide_type": "subslide"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "image/svg+xml": [
+       "[wykres SVG pominięty: punkty danych – długość płatka (oś x) a klasa Iris setosa (oś y)]\r\n"
+      ],
+      "text/plain": [
+       "<Figure size 960x540 with 1 Axes>"
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig = regdotsMx(XMx3, yMx3, 'x', 'Iris setosa?')\n", + "legend(fig)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### Próba zastosowania regresji liniowej do problemu klasyfikacji\n", + "\n", + "Najpierw z ciekawości sprawdźmy, co otrzymalibyśmy, gdybyśmy zastosowali regresję liniową do problemu klasyfikacji." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\r\n", + "\r\n", + "\r\n", + "\r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " 
\r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + "\r\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig = regdotsMx(XMx3, yMx3, 'x', 'Iris setosa?')\n", + "reglineMx(fig, hMx, theta_e3, XMx3)\n", + "legend(fig)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "A gdyby tak przyjąć, że klasyfikator zwraca $1$ dla $h(x) > 0.5$ i $0$ w przeciwnym przypadku?" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\r\n", + "\r\n", + "\r\n", + "\r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " 
\r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + "\r\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig = regdotsMx(XMx3, yMx3, 'x', 'Iris setosa?')\n", + "theta_e3, logs3 = GDMx(JMx, dJMx, thetaStartMx, XMx3, yMx3, alpha=0.03, eps=0.000001)\n", + "reglineMx(fig, hMx, theta_e3, XMx3)\n", + "threshold(fig, theta_e3) # pomarańczowa linia oznacza granicę między klasą \"1\" a klasą \"0\" wyznaczoną przez próg \"h(x) = 0.5\"\n", + "legend(fig)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + " * Krzywa regresji liniowej jest niezbyt dopasowana do danych klasyfikacyjnych.\n", + " * Zastosowanie progu $y = 0.5$ nie zawsze pomaga uzyskać sensowny rezultat.\n", + " * $h(x)$ może przyjmować wartości mniejsze od $0$ i większe od $1$ – jak interpretować takie wyniki?\n", + "\n", + "Wniosek: w przypadku problemów klasyfikacyjnych regresja liniowa nie wydaje się najlepszym rozwiązaniem." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Wprowadźmy zatem pewne modyfikacje do naszego modelu." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Zdefiniujmy następującą funkcję, którą będziemy nazywać funkcją *logistyczną* (albo *sigmoidalną*):" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "**Funkcja logistyczna (sigmoidalna)**:\n", + "\n", + "$$g(x) = \\dfrac{1}{1+e^{-x}}$$" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "outputs": [], + "source": [ + "# Funkjca logistycza\n", + "\n", + "def logistic(x):\n", + " return 1.0 / (1.0 + np.exp(-x))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "def plot_logistic():\n", + " x = np.linspace(-5,5,200)\n", + " y = logistic(x)\n", + "\n", + " fig = plt.figure(figsize=(7,5))\n", + " ax = fig.add_subplot(111)\n", + " plt.ylim(-.1,1.1)\n", + " ax.plot(x, y, linewidth='2')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Wykres funkcji logistycznej $g(x) = \\dfrac{1}{1+e^{-x}}$:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "scrolled": true, + "slideshow": { + "slide_type": "fragment" + } + }, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\r\n", + "\r\n", + "\r\n", + "\r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + 
" \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + " \r\n", + "\r\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plot_logistic()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Funkcja logistyczna przekształca zbiór liczb rzeczywistych $\\mathbb{R}$ w przedział otwarty $(0, 1)$." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Funkcja regresji logistycznej dla pojedynczego przykładu o cechach wyrażonych wektorem $x$:\n", + "\n", + "$$h_\\theta(x) = g(\\theta^T \\, x) = \\dfrac{1}{1 + e^{-\\theta^T x}}$$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "Dla całej macierzy cech $X$:\n", + "\n", + "$$h_\\theta(X) = g(X \\, \\theta) = \\dfrac{1}{1 + e^{-X \\theta}}$$" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "# Funkcja regresji logistcznej\n", + "def h(theta, X):\n", + " return 1.0/(1.0 + np.exp(-X * theta))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Funkcja kosztu dla regresji logistycznej:\n", + "\n", + "$$J(\\theta) = -\\dfrac{1}{m} \\left( \\sum_{i=1}^{m} y^{(i)} \\log h_\\theta( x^{(i)} ) + \\left( 1 - y^{(i)} \\right) \\log \\left( 1 - h_\\theta (x^{(i)}) \\right) \\right)$$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Gradient dla regresji logistycznej (wersja macierzowa):\n", + "\n", + "$$\\nabla J(\\theta) = \\frac{1}{|\\vec y|} X^T \\left( h_\\theta(X) - \\vec y \\right)$$\n", + "\n", + "(Jedyna różnica między gradientem dla regresji logistycznej a gradientem dla regresji liniowej to postać $h_\\theta$)." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {
+    "slideshow": {
+     "slide_type": "subslide"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "# Funkcja kosztu dla regresji logistycznej\n",
+    "def J(h, theta, X, y):\n",
+    "    m = len(y)\n",
+    "    h_val = h(theta, X)\n",
+    "    s1 = np.multiply(y, np.log(h_val))\n",
+    "    s2 = np.multiply((1 - y), np.log(1 - h_val))\n",
+    "    return -np.sum(s1 + s2, axis=0) / m"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {
+    "slideshow": {
+     "slide_type": "fragment"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "# Gradient dla regresji logistycznej\n",
+    "def dJ(h, theta, X, y):\n",
+    "    return 1.0 / len(y) * (X.T * (h(theta, X) - y))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {
+    "slideshow": {
+     "slide_type": "subslide"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "# Metoda gradientu prostego dla regresji logistycznej\n",
+    "def GD(h, fJ, fdJ, theta, X, y, alpha=0.01, eps=10**-3, maxSteps=10000):\n",
+    "    errorCurr = fJ(h, theta, X, y)\n",
+    "    errors = [[errorCurr, theta]]\n",
+    "    while True:\n",
+    "        # oblicz nowe theta\n",
+    "        theta = theta - alpha * fdJ(h, theta, X, y)\n",
+    "        # raportuj poziom błędu\n",
+    "        errorCurr, errorPrev = fJ(h, theta, X, y), errorCurr\n",
+    "        # kryteria stopu\n",
+    "        if abs(errorPrev - errorCurr) <= eps:\n",
+    "            break\n",
+    "        if len(errors) > maxSteps:\n",
+    "            break\n",
+    "        errors.append([errorCurr, theta]) \n",
+    "    return theta, errors"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {
+    "slideshow": {
+     "slide_type": "subslide"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "error = [[0.05755617]]\n",
+      "theta = [[ 5.02530461]\n",
+      " [-1.99174803]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Uruchomienie metody gradientu prostego dla regresji logistycznej\n",
+    "thetaBest, errors = GD(h, J, dJ, thetaStartMx, XMx3, yMx3, \n",
+    "                       alpha=0.1, eps=10**-7, maxSteps=1000)\n",
+    "print(\"error =\", errors[-1][0])\n",
+    "print(\"theta =\", thetaBest)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {
+    "slideshow": {
+     "slide_type": "notes"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "# Funkcja regresji logistycznej (wersja skalarna)\n",
+    "def scalar_logistic_regression_function(theta, x):\n",
+    "    return 1.0/(1.0 + np.exp(-(theta.item(0) + theta.item(1) * x)))\n",
+    "\n",
+    "# Rysowanie progu\n",
+    "def threshold_val(fig, x_thr):\n",
+    "    ax = fig.axes[0]\n",
+    "    ax.plot([x_thr, x_thr], [-1, 2],\n",
+    "            color='orange', linestyle='dashed',\n",
+    "            label=u'próg: $x={:.2F}$'.format(x_thr))\n",
+    "\n",
+    "# Wykres krzywej regresji logistycznej\n",
+    "def logistic_regline(fig, theta, X):\n",
+    "    ax = fig.axes[0]\n",
+    "    x0 = np.min(X[:, 1]) - 1.0\n",
+    "    x1 = np.max(X[:, 1]) + 1.0\n",
+    "    Arg = np.arange(x0, x1, 0.1)\n",
+    "    Val = scalar_logistic_regression_function(theta, Arg)\n",
+    "    ax.plot(Arg, Val, linewidth='2')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {
+    "scrolled": true,
+    "slideshow": {
+     "slide_type": "subslide"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "image/svg+xml": [
+       "[wykres SVG pominięty: dane, krzywa regresji logistycznej i próg x = 2,5]\r\n"
+      ],
+      "text/plain": [
+       "<Figure size 960x540 with 1 Axes>"
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig = regdotsMx(XMx3, yMx3, xlabel='x', ylabel='Iris setosa?')\n", + "logistic_regline(fig, thetaBest, XMx3)\n", + "threshold_val(fig, 2.5)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Traktujemy wartość $h_\\theta(x)$ jako prawdopodobieństwo, że cecha przyjmie wartość pozytywną:\n", + "\n", + "$$ h_\\theta(x) = P(y = 1 \\, | \\, x; \\theta) $$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "Jeżeli $h_\\theta(x) > 0.5$, to dla takiego $x$ będziemy przewidywać wartość $y = 1$.\n", + "W przeciwnym wypadku uprzewidzimy $y = 0$." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Dlaczego możemy traktować wartość funkcji regresji logistycznej jako prawdopodobieństwo?\n", + "\n", + "Można o tym poczytać w zewnętrznych źródłach, np. https://towardsdatascience.com/logit-of-logistic-regression-understanding-the-fundamentals-f384152a33d1" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Dwuklasowa regresja logistyczna: więcej cech" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Jak postąpić, jeżeli będziemy mieli więcej niż jedną cechę $x$?\n", + "\n", + "Weźmy teraz wszystkie cechy występujące w zbiorze *Iris*." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " dł. płatków szer. płatków dł. dz. k. szer. dz. k. Iris setosa?\n", + "0 1.4 0.2 5.2 3.4 1\n", + "1 1.5 0.4 5.1 3.7 1\n", + "2 5.6 2.4 6.7 3.1 0\n", + "3 5.1 2.0 6.5 3.2 0\n", + "4 4.5 1.7 4.9 2.5 0\n", + "5 5.1 1.6 6.0 2.7 0\n" + ] + } + ], + "source": [ + "data_iris_setosa_multi = pandas.DataFrame()\n", + "data_iris_setosa_multi['dł. płatków'] = data_iris['pl'] # \"pl\" oznacza \"petal length\" (długość płatków)\n", + "data_iris_setosa_multi['szer. płatków'] = data_iris['pw'] # \"pw\" oznacza \"petal width\" (szerokość płatków)\n", + "data_iris_setosa_multi['dł. dz. k.'] = data_iris['sl'] # \"sl\" oznacza \"sepal length\" (długość działek kielicha)\n", + "data_iris_setosa_multi['szer. dz. k.'] = data_iris['sw'] # \"sw\" oznacza \"sepal width\" (szerokość działek kielicha)\n", + "data_iris_setosa_multi['Iris setosa?'] = data_iris['Gatunek'].apply(lambda x: 1 if x=='Iris-setosa' else 0)\n", + "print(data_iris_setosa_multi[:6])" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import seaborn" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[1. 1.4 0.2 5.2 3.4]\n", + " [1. 1.5 0.4 5.1 3.7]\n", + " [1. 5.6 2.4 6.7 3.1]\n", + " [1. 5.1 2. 6.5 3.2]\n", + " [1. 4.5 1.7 4.9 2.5]\n", + " [1. 5.1 1.6 6. 
2.7]]\n", + "[[1.]\n", + " [1.]\n", + " [0.]\n", + " [0.]\n", + " [0.]\n", + " [0.]]\n" + ] + } + ], + "source": [ + "# Przygotowanie danych\n", + "m, n_plus_1 = data_iris_setosa_multi.values.shape\n", + "n = n_plus_1 - 1\n", + "Xn = data_iris_setosa_multi.values[:, 0:n].reshape(m, n)\n", + "\n", + "XMx4 = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n_plus_1)\n", + "yMx4 = np.matrix(data_iris_setosa_multi.values[:, n]).reshape(m, 1)\n", + "\n", + "print(XMx4[:6])\n", + "print(yMx4[:6])" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [], + "source": [ + "# Podział danych na zbiór trenujący i testowy\n", + "XTrain, XTest = XMx4[:100], XMx4[100:]\n", + "yTrain, yTest = yMx4[:100], yMx4[100:]\n", + "\n", + "# Macierz parametrów początkowych\n", + "thetaTemp = np.ones(5).reshape(5,1)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "error = [[0.006797]]\n", + "theta = [[ 1.11414027]\n", + " [-2.89324615]\n", + " [-0.66543637]\n", + " [ 0.14887292]\n", + " [ 2.13284493]]\n" + ] + } + ], + "source": [ + "thetaBest, errors = GD(h, J, dJ, thetaTemp, XTrain, yTrain, \n", + " alpha=0.1, eps=10**-7, maxSteps=1000)\n", + "print(\"error =\", errors[-1][0])\n", + "print(\"theta =\", thetaBest)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Funkcja decyzyjna regresji logistycznej\n", + "\n", + "Funkcja decyzyjna mówi o tym, kiedy nasz algorytm będzie przewidywał $y = 1$, a kiedy $y = 0$\n", + "\n", + "$$ c = \\left\\{ \n", + "\\begin{array}{ll}\n", + "1, & \\mbox{gdy } P(y=1 \\, | \\, x; \\theta) > 0.5 \\\\\n", + "0 & \\mbox{w przeciwnym przypadku}\n", + "\\end{array}\\right.\n", + "$$\n", + "\n", + "$$ P(y=1 \\,| \\, x; \\theta) = h_\\theta(x) $$" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "theta = [[ 1.11414027]\n", + " [-2.89324615]\n", + " [-0.66543637]\n", + " [ 0.14887292]\n", + " [ 2.13284493]]\n", + "x0 = [[1. 6.3 1.8 7.3 2.9]]\n", + "h(x0) = 1.6061436959824898e-05\n", + "c(x0) = (0, 1.6061436959824898e-05) \n", + "\n" + ] + } + ], + "source": [ + "def classifyBi(theta, X):\n", + " prob = h(theta, X).item()\n", + " return (1, prob) if prob > 0.5 else (0, prob)\n", + "\n", + "print(\"theta =\", thetaBest)\n", + "print(\"x0 =\", XTest[0])\n", + "print(\"h(x0) =\", h(thetaBest, XTest[0]).item())\n", + "print(\"c(x0) =\", classifyBi(thetaBest, XTest[0]), \"\\n\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "source": [ + "Obliczmy teraz skuteczność modelu (więcej na ten temat na następnym wykładzie, poświęconym metodom ewaluacji)." 
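+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "slideshow": {
+     "slide_type": "notes"
+    }
+   },
+   "source": [
+    "Zanim przejdziemy do pętli z podglądem pojedynczych przykładów – minimalny szkic tego samego obliczenia w wersji zwektoryzowanej (przy założeniu, że `h`, `thetaBest`, `XTest` i `yTest` są zdefiniowane jak wyżej):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "slideshow": {
+     "slide_type": "notes"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "# Minimalny szkic (założenia: h, thetaBest, XTest, yTest jak wyżej)\n",
+    "# Reguła decyzyjna: przewidujemy klasę 1, gdy h(theta, x) > 0.5\n",
+    "predictions = (h(thetaBest, XTest) > 0.5).astype(int)\n",
+    "print(np.mean(predictions == yTest))"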
+ ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 <=> 0 -- prob: 0.0\n", + "1 <=> 1 -- prob: 0.9816\n", + "0 <=> 0 -- prob: 0.0001\n", + "0 <=> 0 -- prob: 0.0005\n", + "0 <=> 0 -- prob: 0.0001\n", + "1 <=> 1 -- prob: 0.9936\n", + "0 <=> 0 -- prob: 0.0059\n", + "0 <=> 0 -- prob: 0.0992\n", + "0 <=> 0 -- prob: 0.0001\n", + "0 <=> 0 -- prob: 0.0001\n", + "\n", + "Accuracy: 1.0\n" + ] + } + ], + "source": [ + "acc = 0.0\n", + "for i, rest in enumerate(yTest):\n", + " cls, prob = classifyBi(thetaBest, XTest[i])\n", + " if i < 10:\n", + " print(int(yTest[i].item()), \"<=>\", cls, \"-- prob:\", round(prob, 4))\n", + " acc += cls == yTest[i].item()\n", + "\n", + "print(\"\\nAccuracy:\", acc / len(XTest))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "## 3.2. Wieloklasowa regresja logistyczna" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Przykład: wszystkie cechy ze zbioru *Iris*, wszystkie 3 klasy ze zbioru *Iris*." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
slswplpwGatunek
05.23.41.40.2Iris-setosa
15.13.71.50.4Iris-setosa
26.73.15.62.4Iris-virginica
36.53.25.12.0Iris-virginica
44.92.54.51.7Iris-virginica
56.02.75.11.6Iris-versicolor
\n", + "
" + ], + "text/plain": [ + " sl sw pl pw Gatunek\n", + "0 5.2 3.4 1.4 0.2 Iris-setosa\n", + "1 5.1 3.7 1.5 0.4 Iris-setosa\n", + "2 6.7 3.1 5.6 2.4 Iris-virginica\n", + "3 6.5 3.2 5.1 2.0 Iris-virginica\n", + "4 4.9 2.5 4.5 1.7 Iris-virginica\n", + "5 6.0 2.7 5.1 1.6 Iris-versicolor" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas\n", + "data_iris = pandas.read_csv('iris.csv')\n", + "data_iris[:6]" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "X = [[1. 5.2 3.4 1.4 0.2]\n", + " [1. 5.1 3.7 1.5 0.4]\n", + " [1. 6.7 3.1 5.6 2.4]\n", + " [1. 6.5 3.2 5.1 2. ]]\n", + "y = [['Iris-setosa']\n", + " ['Iris-setosa']\n", + " ['Iris-virginica']\n", + " ['Iris-virginica']]\n" + ] + } + ], + "source": [ + "# Przygotowanie danych\n", + "\n", + "import numpy as np\n", + "\n", + "features = ['sl', 'sw', 'pl', 'pw']\n", + "m = len(data_iris)\n", + "X = np.matrix(data_iris[features])\n", + "X0 = np.ones(m).reshape(m, 1)\n", + "X = np.hstack((X0, X))\n", + "y = np.matrix(data_iris[[\"Gatunek\"]]).reshape(m, 1)\n", + "\n", + "print(\"X = \", X[:4])\n", + "print(\"y = \", y[:4])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Zamieńmy etykiety tekstowe w tablicy $y$ na wektory jednostkowe (*one-hot vectors*):\n", + "\n", + "$$\n", + "\\begin{array}{ccc}\n", + "\\mbox{\"Iris-setosa\"} & \\mapsto & \\left[ \\begin{array}{ccc} 1 & 0 & 0 \\\\ \\end{array} \\right] \\\\\n", + "\\mbox{\"Iris-virginica\"} & \\mapsto & \\left[ \\begin{array}{ccc} 0 & 1 & 0 \\\\ \\end{array} \\right] \\\\\n", + "\\mbox{\"Iris-versicolor\"} & \\mapsto & \\left[ \\begin{array}{ccc} 0 & 0 & 1 \\\\ \\end{array} \\right] \\\\\n", + "\\end{array}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Wówczas zamiast wektora $y$ otrzymamy macierz $Y$:\n", + "\n", + "$$\n", + "y \\; = \\;\n", + "\\left[\n", + "\\begin{array}{c}\n", + "y^{(1)} \\\\\n", + "y^{(2)} \\\\\n", + "y^{(3)} \\\\\n", + "y^{(4)} \\\\\n", + "y^{(5)} \\\\\n", + "\\vdots \\\\\n", + "\\end{array}\n", + "\\right]\n", + "\\; = \\;\n", + "\\left[\n", + "\\begin{array}{c}\n", + "\\mbox{\"Iris-setosa\"} \\\\\n", + "\\mbox{\"Iris-setosa\"} \\\\\n", + "\\mbox{\"Iris-virginica\"} \\\\\n", + "\\mbox{\"Iris-versicolor\"} \\\\\n", + "\\mbox{\"Iris-virginica\"} \\\\\n", + "\\vdots \\\\\n", + "\\end{array}\n", + "\\right]\n", + "\\quad \\mapsto \\quad\n", + "Y \\; = \\;\n", + "\\left[\n", + "\\begin{array}{ccc}\n", + "1 & 0 & 0 \\\\\n", + "1 & 0 & 0 \\\\\n", + "0 & 1 & 0 \\\\\n", + "0 & 0 & 1 \\\\\n", + "0 & 1 & 0 \\\\\n", + "\\vdots & \\vdots & \\vdots \\\\\n", + "\\end{array}\n", + "\\right]\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, + "outputs": [], + "source": [ + "def mapY(y, cls):\n", + " m = len(y)\n", + " yBi = np.matrix(np.zeros(m)).reshape(m, 1)\n", + " yBi[y == cls] = 1.\n", + " return yBi\n", + "\n", + "def indicatorMatrix(y):\n", + " classes = np.unique(y.tolist())\n", + " m = len(y)\n", + " k = len(classes)\n", + " Y = np.matrix(np.zeros((m, k)))\n", + " for i, cls in enumerate(classes):\n", + " Y[:, i] = mapY(y, cls)\n", + " return Y\n", + "\n", + "# one-hot matrix\n", + "Y = 
indicatorMatrix(y)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {
+    "slideshow": {
+     "slide_type": "notes"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "# Podział danych na zbiór trenujący i testowy\n",
+    "XTrain, XTest = X[:100], X[100:]\n",
+    "YTrain, YTest = Y[:100], Y[100:]\n",
+    "\n",
+    "# Macierz parametrów początkowych - niech składa się z samych jedynek\n",
+    "thetaTemp = np.ones(5).reshape(5,1)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "slideshow": {
+     "slide_type": "slide"
+    }
+   },
+   "source": [
+    "### Od regresji logistycznej dwuklasowej do wieloklasowej\n",
+    "\n",
+    "* Irysy są przydzielone do trzech klas: _Iris-setosa_ (0), _Iris-versicolor_ (1), _Iris-virginica_ (2).\n",
+    "* Wiemy, jak stworzyć klasyfikatory dwuklasowe typu _Iris-setosa_ vs. _Nie-Iris-setosa_ (tzw. *one-vs-all*).\n",
+    "* Możemy stworzyć trzy klasyfikatory $h_{\\theta_1}, h_{\\theta_2}, h_{\\theta_3}$ (otrzymując trzy zestawy parametrów $\\theta$) i wybrać klasę o najwyższym prawdopodobieństwie."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "slideshow": {
+     "slide_type": "notes"
+    }
+   },
+   "source": [
+    "Pomoże nam w tym funkcja *softmax*, która jest uogólnieniem funkcji logistycznej na większą liczbę wymiarów."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "slideshow": {
+     "slide_type": "subslide"
+    }
+   },
+   "source": [
+    "### Funkcja _softmax_\n",
+    "\n",
+    "Odpowiednikiem funkcji logistycznej dla wieloklasowej regresji logistycznej jest funkcja $\\mathrm{softmax}$:\n",
+    "\n",
+    "$$ \\textrm{softmax} \\colon \\mathbb{R}^k \\to [0,1]^k $$\n",
+    "\n",
+    "$$ \\textrm{softmax}(z_1,z_2,\\dots,z_k) = \\left( \\dfrac{e^{z_1}}{\\sum_{i=1}^{k}e^{z_i}}, \\dfrac{e^{z_2}}{\\sum_{i=1}^{k}e^{z_i}}, \\ldots, \\dfrac{e^{z_k}}{\\sum_{i=1}^{k}e^{z_i}} \\right) $$"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "slideshow": {
+     "slide_type": "subslide"
+    }
+   },
+   "source": [
+    "$$ \\textrm{softmax}( \\left[ \\begin{array}{c} \\theta_1^T x \\\\ \\theta_2^T x \\\\ \\vdots \\\\ \\theta_k^T x \\end{array} \\right] ) = \\left[ \\begin{array}{c} P(y=1 \\, | \\, x;\\theta_1,\\ldots,\\theta_k) \\\\ P(y=2 \\, | \\, x;\\theta_1,\\ldots,\\theta_k) \\\\ \\vdots \\\\ P(y=k \\, | \\, x;\\theta_1,\\ldots,\\theta_k) \\end{array} \\right] $$"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {
+    "slideshow": {
+     "slide_type": "subslide"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "# Zapis macierzowy funkcji softmax (wersja dla pojedynczego wektora k wartości)\n",
+    "# Uwaga: dla dużych argumentów warto przed np.exp odjąć np.max(X) –\n",
+    "# nie zmienia to wyniku, a zapobiega przepełnieniu (stabilność numeryczna)\n",
+    "def softmax(X):\n",
+    "    return np.exp(X) / np.sum(np.exp(X))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "slideshow": {
+     "slide_type": "subslide"
+    }
+   },
+   "source": [
+    "Wartości funkcji $\\mathrm{softmax}$ sumują się do 1:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {
+    "slideshow": {
+     "slide_type": "fragment"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.9999999999999999\n"
+     ]
+    }
+   ],
+   "source": [
+    "Z = np.matrix([[2.1, 0.5, 0.8, 0.9, 3.2]])\n",
+    "P = softmax(Z)\n",
+    "print(np.sum(P)) "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {
+    "slideshow": {
+     "slide_type": "subslide"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Otrzymana macierz parametrów theta dla klasy 0:\n",
+      " [[ 0.68590262]\n",
+      " [ 0.39948964]\n",
+      " [ 1.13312933]\n",
+      " [-2.17550597]\n",
+      " [-0.53088875]] \n",
+      "\n",
+      "Otrzymana macierz parametrów theta dla klasy 1:\n",
+ " [[ 0.95431453]\n", + " [ 0.07249434]\n", + " [-1.07233395]\n", + " [ 0.53801787]\n", + " [-0.65001214]] \n", + "\n", + "Otrzymana macierz parametrów theta dla klasy 2:\n", + " [[-0.66101185]\n", + " [-1.40133883]\n", + " [-2.01776182]\n", + " [ 2.18505283]\n", + " [ 2.74690482]] \n", + "\n" + ] + } + ], + "source": [ + "# Dla każdej klasy wytrenujmy osobny klasyfikator dwuklasowy.\n", + "\n", + "def trainMaxEnt(X, Y):\n", + " n = X.shape[1]\n", + " thetas = []\n", + " for c in range(Y.shape[1]):\n", + " YBi = Y[:,c]\n", + " theta = np.matrix(np.random.random(n)).reshape(n,1)\n", + " # Macierz parametrów theta obliczona dla każdej klasy osobno.\n", + " thetaBest, errors = GD(h, J, dJ, theta, \n", + " X, YBi, alpha=0.1, eps=10**-4)\n", + " thetas.append(thetaBest)\n", + " return thetas\n", + "\n", + "# Macierze theta dla każdej klasy\n", + "thetas = trainMaxEnt(XTrain, YTrain);\n", + "for c, theta in enumerate(thetas):\n", + " print(f\"Otrzymana macierz parametrów theta dla klasy {c}:\\n\", theta, \"\\n\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "### Funkcja decyzyjna wieloklasowej regresji logistycznej\n", + "\n", + "$$ c = \\mathop{\\textrm{arg}\\,\\textrm{max}}_{i \\in \\{1, \\ldots ,k\\}} P(y=i|x;\\theta_1,\\ldots,\\theta_k) $$" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dla x = [[1. 7.3 2.9 6.3 1.8]]:\n", + "Po zastosowaniu regresji: [-7.77303536 0.59324542 1.96796697]\n", + "Otrzymane prawdopodobieństwa: [0. 0.202 0.798]\n", + "Wybrana klasa: 2\n", + "Obliczone y = 2\n", + "Oczekiwane y = 2\n", + "\n", + "Dla x = [[1. 4.8 3. 1.4 0.3]]:\n", + "Po zastosowaniu regresji: [ 2.79786587 -1.35649314 -9.55757825]\n", + "Otrzymane prawdopodobieństwa: [0.985 0.015 0. ]\n", + "Wybrana klasa: 0\n", + "Obliczone y = 0\n", + "Oczekiwane y = 0\n", + "\n", + "Dla x = [[1. 7.1 3. 5.9 2.1]]:\n", + "Po zastosowaniu regresji: [-7.02868459 0.06130237 1.99650886]\n", + "Otrzymane prawdopodobieństwa: [0. 0.126 0.874]\n", + "Wybrana klasa: 2\n", + "Obliczone y = 2\n", + "Oczekiwane y = 2\n", + "\n", + "Dla x = [[1. 5.9 3. 5.1 1.8]]:\n", + "Po zastosowaniu regresji: [-5.60840075 -0.26110148 1.10600174]\n", + "Otrzymane prawdopodobieństwa: [0.001 0.203 0.796]\n", + "Wybrana klasa: 2\n", + "Obliczone y = 2\n", + "Oczekiwane y = 2\n", + "\n", + "Dla x = [[1. 6.1 2.6 5.6 1.4]]:\n", + "Po zastosowaniu regresji: [-6.85715204 0.71134476 1.62660319]\n", + "Otrzymane prawdopodobieństwa: [0. 
0.286 0.714]\n", + "Wybrana klasa: 2\n", + "Obliczone y = 2\n", + "Oczekiwane y = 2\n", + "\n" + ] + } + ], + "source": [ + "def classify(thetas, X, debug=False):\n", + " regs = np.array([(X*theta).item() for theta in thetas])\n", + " if debug:\n", + " print(\"Po zastosowaniu regresji: \", regs)\n", + " probs = softmax(regs)\n", + " if debug:\n", + " print(\"Otrzymane prawdopodobieństwa: \", np.around(probs,decimals=3))\n", + " result = np.argmax(probs)\n", + " if debug:\n", + " print(\"Wybrana klasa: \", result)\n", + " return result\n", + "\n", + "for i in range(5):\n", + " print(f\"Dla x = {XTest[i]}:\")\n", + " YPredicted = classify(thetas, XTest[i], debug=True)\n", + " print(f\"Obliczone y = {YPredicted}\")\n", + " print(f\"Oczekiwane y = {np.argmax(YTest[i])}\")\n", + " print()" + ] + } + ], + "metadata": { + "celltoolbar": "Slideshow", + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + }, + "livereveal": { + "start_slideshow_at": "selected", + "theme": "white" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}