{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### AITech — Machine Learning\n", "# 3. Logistic regression" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "**Note**: Despite its name, logistic *regression* is an algorithm for solving *classification* problems (not *regression* problems at all!)" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "To demonstrate logistic regression, we will use the classic *Iris flower data set*, which consists of 150 examples of 3 iris species, each described by the values of 4 features." ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### *Iris flower data set*\n", "\n", "* 150 examples\n", "* 4 features\n", "* 3 categories" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "| | | |\n", "| :--- | :--- | :--- |\n", "| *Iris setosa* | *Iris virginica* | *Iris versicolor* |\n", "| bristle-pointed iris | Virginia iris | harlequin blueflag |\n" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "4 features:\n", " * sepal length (`sl`)\n", " * sepal width (`sw`)\n", " * petal length (`pl`)\n", " * petal width (`pw`)" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## 3.1. Two-class logistic regression" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Let us start with the simplest case:\n", " * restrict ourselves to **2** classes\n", " * restrict ourselves to **1** variable" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "→ two-class logistic regression of one variable" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Useful imports\n", "\n", "import numpy as np\n", "import matplotlib\n", "import matplotlib.pyplot as pl\n", "import pandas\n", "import ipywidgets as widgets\n", "\n", "%matplotlib inline\n", "%config InlineBackend.figure_format = 'svg'\n", "\n", "from IPython.display import display, Math, Latex\n", "\n", "# Useful helper functions\n", "\n", "# Render a matrix as a LaTeX array\n", "def LatexMatrix(matrix):\n", "    ltx = r'\\left[\\begin{array}'\n", "    m, n = matrix.shape\n", "    ltx += '{' + (\"r\" * n) + '}'\n", "    for i in range(m):\n", "        # np.asarray(...).ravel() yields scalars for both np.matrix and ndarray rows\n", "        ltx += r\" & \".join([('%.4f' % j.item()) for j in np.asarray(matrix[i]).ravel()]) + r\" \\\\ \"\n", "    ltx += r'\\end{array}\\right]'\n", "    return ltx\n", "\n", "# Hypothesis of linear regression (matrix version)\n", "def hMx(theta, X):\n", "    return X * theta\n", "\n", "# Scatter plot of the data (matrix version)\n", "def regdotsMx(X, y, xlabel, ylabel):\n", "    fig = pl.figure(figsize=(16*.6, 9*.6))\n", "    ax = fig.add_subplot(111)\n", "    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n", "    ax.scatter([X[:, 1]], [y], c='r', s=50, label='Data')\n", "\n", "    ax.set_xlabel(xlabel)\n", "    ax.set_ylabel(ylabel)\n", "    ax.margins(.05, .05)\n", "    pl.ylim(y.min() - 1, y.max() + 1)\n", "    pl.xlim(np.min(X[:, 1]) - 1, np.max(X[:, 1]) + 1)\n", "    return fig\n", "\n", "# Plot of the regression line (matrix version)\n", "def reglineMx(fig, fun, theta, X):\n", "    ax = fig.axes[0]\n", "    x0 = np.min(X[:, 1]) - 1.0\n", "    x1 = np.max(X[:, 1]) + 1.0\n", "    L = [x0, x1]\n", "    LX = np.matrix([1, x0, 1, x1]).reshape(2, 2)\n", "    ax.plot(L, fun(theta, LX), linewidth='2',\n", "            label=(r'$y={theta0:.2}{op}{theta1:.2}x$'.format(\n", "                theta0=float(theta[0][0]),\n", "                theta1=(float(theta[1][0]) if theta[1][0] >= 0 else float(-theta[1][0])),\n", "                op='+' if theta[1][0] >= 0 else '-')))\n", "\n", "# Plot legend\n", "def legend(fig):\n", "    ax = fig.axes[0]\n", "    handles, labels = ax.get_legend_handles_labels()\n", "    # try-except block is a fix for a bug in Poly3DCollection\n", "    try:\n", "        fig.legend(handles, labels, fontsize='15', loc='lower right')\n", "    except AttributeError:\n", "        pass\n", "\n", "# Cost function of linear regression (matrix version)\n", "def JMx(theta, X, y):\n", "    m = len(y)\n", "    J = 1.0 / (2.0 * m) * ((X * theta - y).T * (X * theta - y))\n", "    return J.item()\n", "\n", "# Gradient of the cost function (matrix version)\n", "def dJMx(theta, X, y):\n", "    return 1.0 / len(y) * (X.T * (X * theta - y))\n", "\n", "# Gradient descent implemented with NumPy matrices\n", "def GDMx(fJ, fdJ, theta, X, y, alpha=0.1, eps=10**-3):\n", "    current_cost = fJ(theta, X, y)\n", "    logs = [[current_cost, theta]]\n", "    while True:\n", "        theta = theta - alpha * fdJ(theta, X, y)  # update rule\n", "        current_cost, prev_cost = fJ(theta, X, y), current_cost\n", "        if current_cost > 10000:  # stop if the method diverges\n", "            break\n", "        if abs(prev_cost - current_cost) <= eps:  # stop once the change in cost falls below eps\n", "            break\n", "        logs.append([current_cost, theta])\n", "    return theta, logs\n", "\n", "thetaStartMx = np.matrix([0, 0]).reshape(2, 1)\n", "\n", "# Draw the decision threshold: the x at which the linear hypothesis equals 0.5\n", "def threshold(fig, theta):\n", "    x_thr = (0.5 - theta.item(0)) / theta.item(1)\n", "    ax = fig.axes[0]\n", "    ax.plot([x_thr, x_thr], [-1, 2],\n", "            color='orange', linestyle='dashed',\n", "            label=u'threshold: $x={:.2F}$'.format(x_thr))" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " sl sw pl pw Gatunek\n", "0 5.2 3.4 1.4 0.2 Iris-setosa\n", "1 5.1 3.7 1.5 0.4 Iris-setosa\n", "2 6.7 3.1 5.6 2.4 Iris-virginica\n", "3 6.5 3.2 5.1 2.0 Iris-virginica\n", "4 4.9 2.5 4.5 1.7 Iris-virginica\n", "5 6.0 2.7 5.1 1.6 Iris-versicolor\n" ] } ], "source": [ "# Load the full (original) data\n", "\n", "data_iris = pandas.read_csv('iris.csv')\n", "print(data_iris[:6])" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": { "scrolled": true, "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " petal length Iris setosa?\n", "0 1.4 1\n", "1 1.5 1\n", "2 5.6 0\n", "3 5.1 0\n", "4 4.5 0\n", "5 5.1 0\n" ] } ], "source": [ "# Restrict the data to 2 classes and 1 feature\n", "\n", "data_iris_setosa = pandas.DataFrame()\n", "data_iris_setosa['petal length'] = data_iris['pl']  # \"pl\" stands for \"petal length\"\n", "# 1 if the species ('Gatunek' column) is Iris setosa, 0 otherwise\n", "data_iris_setosa['Iris setosa?'] = data_iris['Gatunek'].apply(lambda x: 1 if x == 'Iris-setosa' else 0)\n", "print(data_iris_setosa[:6])" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Prepare the data: design matrix X with a leading column of ones, label vector y\n", "m, n_plus_1 = data_iris_setosa.values.shape\n", "n = n_plus_1 - 1\n", "Xn = data_iris_setosa.values[:, 0:n].reshape(m, n)\n", "\n", "XMx3 = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n_plus_1)\n", "yMx3 = np.matrix(data_iris_setosa.values[:, 1]).reshape(m, 1)\n", "\n", "# Linear regression fitted with gradient descent\n", "theta_e3, logs3 = GDMx(JMx, dJMx, thetaStartMx, XMx3, yMx3, alpha=0.03, eps=0.000001)" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "fig = regdotsMx(XMx3, yMx3, 'x', 'Iris setosa?')\n", "legend(fig)" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "#### Attempting to apply linear regression to a classification problem\n", "\n", "First, out of curiosity, let us check what we would get if we applied linear regression to a classification problem." ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "fig = regdotsMx(XMx3, yMx3, 'x', 'Iris setosa?')\n", "reglineMx(fig, hMx, theta_e3, XMx3)\n", "legend(fig)" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "What if we assumed that the classifier returns $1$ for $h(x) > 0.5$ and $0$ otherwise?" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "fig = regdotsMx(XMx3, yMx3, 'x', 'Iris setosa?')\n", "theta_e3, logs3 = GDMx(JMx, dJMx, thetaStartMx, XMx3, yMx3, alpha=0.03, eps=0.000001)\n", "reglineMx(fig, hMx, theta_e3, XMx3)\n", "threshold(fig, theta_e3)  # the orange line marks the boundary between class \"1\" and class \"0\" given by the threshold \"h(x) = 0.5\"\n", "legend(fig)" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ " * The linear regression line does not fit classification data well.\n", " * Applying the threshold $h(x) = 0.5$ does not always lead to a sensible result.\n", " * $h(x)$ can take values smaller than $0$ and larger than $1$: how should such outputs be interpreted?\n", "\n", "Conclusion: for classification problems, linear regression does not seem to be the best choice." ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Let us therefore make some modifications to our model."
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Let us define the following function, which we will call the *logistic* (or *sigmoid*) function:" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "**Logistic (sigmoid) function**:\n", "\n", "$$g(x) = \\dfrac{1}{1+e^{-x}}$$" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "# Logistic function\n", "\n", "def logistic(x):\n", "    return 1.0 / (1.0 + np.exp(-x))" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "# Plot the logistic function on the interval [-5, 5]\n", "def plot_logistic():\n", "    x = np.linspace(-5, 5, 200)\n", "    y = logistic(x)\n", "\n", "    fig = plt.figure(figsize=(7, 5))\n", "    ax = fig.add_subplot(111)\n", "    plt.ylim(-.1, 1.1)\n", "    ax.plot(x, y, linewidth='2')" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Plot of the logistic function $g(x) = \\dfrac{1}{1+e^{-x}}$:" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": { "scrolled": true, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "plot_logistic()" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "The logistic function maps the set of real numbers $\\mathbb{R}$ onto the open interval $(0, 1)$." ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "The logistic regression function for a single example whose features are given by the vector $x$:\n", "\n", "$$h_\\theta(x) = g(\\theta^T \\, x) = \\dfrac{1}{1 + e^{-\\theta^T x}}$$" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "For the whole feature matrix $X$ (with $g$ applied element-wise):\n", "\n", "$$h_\\theta(X) = g(X \\, \\theta) = \\dfrac{1}{1 + e^{-X \\theta}}$$" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": { "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "# Logistic regression function (matrix version; np.exp acts element-wise)\n", "def h(theta, X):\n", "    return 1.0 / (1.0 + np.exp(-X * theta))" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Cost function for logistic regression:\n", "\n", "$$J(\\theta) = -\\dfrac{1}{m} \\left( \\sum_{i=1}^{m} y^{(i)} \\log h_\\theta( x^{(i)} ) + \\left( 1 - y^{(i)} \\right) \\log \\left( 1 - h_\\theta (x^{(i)}) \\right) \\right)$$" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Gradient for logistic regression (matrix version):\n", "\n", "$$\\nabla J(\\theta) = \\frac{1}{|\\vec y|} X^T \\left( h_\\theta(X) - \\vec y \\right)$$\n", "\n", "Here $|\\vec y| = m$ is the number of training examples.\n", "\n", "(The only difference between the gradient for logistic regression and the gradient for linear regression is the form of $h_\\theta$.)"
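, "\n", "\n", "One way to see where this formula comes from (a sketch for a single example and a single parameter $\\theta_j$; averaging over the $m$ examples gives the matrix form above): write $z = \\theta^T x$, so that $h_\\theta(x) = g(z)$. The logistic function has the convenient derivative\n",
"\n",
"$$g'(z) = \\dfrac{e^{-z}}{\\left( 1 + e^{-z} \\right)^2} = g(z) \\left( 1 - g(z) \\right), \\qquad \\dfrac{\\partial z}{\\partial \\theta_j} = x_j .$$\n",
"\n",
"Applying the chain rule to the per-example cost $c(\\theta) = -\\left( y \\log g(z) + (1 - y) \\log \\left( 1 - g(z) \\right) \\right)$:\n",
"\n",
"$$\\dfrac{\\partial c}{\\partial \\theta_j} = -\\left( \\dfrac{y}{g(z)} - \\dfrac{1 - y}{1 - g(z)} \\right) g(z) \\left( 1 - g(z) \\right) x_j = -\\left( y \\left( 1 - g(z) \\right) - (1 - y) \\, g(z) \\right) x_j = \\left( h_\\theta(x) - y \\right) x_j .$$\n",
"\n",
"The factor $g(z) \\left( 1 - g(z) \\right)$ cancels, which is why the result has the same shape as the gradient of the linear-regression cost, with only $h_\\theta$ changed."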
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "# Funkcja kosztu dla regresji logistycznej\n", "def J(h, theta, X, y):\n", " m = len(y)\n", " h_val = h(theta, X)\n", " s1 = np.multiply(y, np.log(h_val))\n", " s2 = np.multiply((1 - y), np.log(1 - h_val))\n", " return -np.sum(s1 + s2, axis=0) / m" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "# Gradient dla regresji logistycznej\n", "def dJ(h, theta, X, y):\n", " return 1.0 / len(y) * (X.T * (h(theta, X) - y))" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "# Metoda gradientu prostego dla regresji logistycznej\n", "def GD(h, fJ, fdJ, theta, X, y, alpha=0.01, eps=10**-3, maxSteps=10000):\n", " errorCurr = fJ(h, theta, X, y)\n", " errors = [[errorCurr, theta]]\n", " while True:\n", " # oblicz nowe theta\n", " theta = theta - alpha * fdJ(h, theta, X, y)\n", " # raportuj poziom błędu\n", " errorCurr, errorPrev = fJ(h, theta, X, y), errorCurr\n", " # kryteria stopu\n", " if abs(errorPrev - errorCurr) <= eps:\n", " break\n", " if len(errors) > maxSteps:\n", " break\n", " errors.append([errorCurr, theta]) \n", " return theta, errors" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "error = [[0.05755617]]\n", "theta = [[ 5.02530461]\n", " [-1.99174803]]\n" ] } ], "source": [ "# Uruchomienie metody gradientu prostego dla regresji logistycznej\n", "thetaBest, errors = GD(h, J, dJ, thetaStartMx, XMx3, yMx3, \n", " alpha=0.1, eps=10**-7, maxSteps=1000)\n", "print(\"error =\", errors[-1][0])\n", "print(\"theta =\", thetaBest)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "slideshow": { "slide_type": "notes" } }, 
"outputs": [], "source": [ "# Logistic regression function (scalar version)\n", "def scalar_logistic_regression_function(theta, x):\n", "    return 1.0/(1.0 + np.exp(-(theta.item(0) + theta.item(1) * x)))\n", "\n", "# Draw the decision threshold as a dashed vertical line\n", "def threshold_val(fig, x_thr):\n", "    ax = fig.axes[0]\n", "    ax.plot([x_thr, x_thr], [-1, 2],\n", "            color='orange', linestyle='dashed',\n", "            label='threshold: $x={:.2f}$'.format(x_thr))\n", "\n", "# Plot the logistic regression curve\n", "def logistic_regline(fig, theta, X):\n", "    ax = fig.axes[0]\n", "    x0 = np.min(X[:, 1]) - 1.0\n", "    x1 = np.max(X[:, 1]) + 1.0\n", "    Arg = np.arange(x0, x1, 0.1)\n", "    Val = scalar_logistic_regression_function(theta, Arg)\n", "    ax.plot(Arg, Val, linewidth=2)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "scrolled": true, "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "fig = regdotsMx(XMx3, yMx3, xlabel='x', ylabel='Iris setosa?')\n", "logistic_regline(fig, thetaBest, XMx3)\n", "threshold_val(fig, 2.5)" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "We interpret the value $h_\\theta(x)$ as the probability that the example belongs to the positive class ($y = 1$):\n", "\n", "$$ h_\\theta(x) = P(y = 1 \\, | \\, x; \\theta) $$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "If $h_\\theta(x) > 0.5$, we predict $y = 1$ for that $x$.\n", "Otherwise, we predict $y = 0$." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Why can the value of the logistic regression function be interpreted as a probability?\n", "\n", "This is explained in external sources, e.g. https://towardsdatascience.com/logit-of-logistic-regression-understanding-the-fundamentals-f384152a33d1" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Binary logistic regression: more features" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "What if we have more than one feature $x$?\n", "\n", "Let us now use all the features of the *Iris* data set." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " dł. płatków szer. płatków dł. dz. k. szer. dz. k. Iris setosa?\n", "0 1.4 0.2 5.2 3.4 1\n", "1 1.5 0.4 5.1 3.7 1\n", "2 5.6 2.4 6.7 3.1 0\n", "3 5.1 2.0 6.5 3.2 0\n", "4 4.5 1.7 4.9 2.5 0\n", "5 5.1 1.6 6.0 2.7 0\n" ] } ], "source": [ "data_iris_setosa_multi = pandas.DataFrame()\n", "data_iris_setosa_multi['dł. 
płatków'] = data_iris['pl'] # \"pl\" stands for \"petal length\"\n", "data_iris_setosa_multi['szer. płatków'] = data_iris['pw'] # \"pw\" stands for \"petal width\"\n", "data_iris_setosa_multi['dł. dz. k.'] = data_iris['sl'] # \"sl\" stands for \"sepal length\"\n", "data_iris_setosa_multi['szer. dz. k.'] = data_iris['sw'] # \"sw\" stands for \"sepal width\"\n", "# target: 1 if the species ('Gatunek') is Iris setosa, 0 otherwise\n", "data_iris_setosa_multi['Iris setosa?'] = data_iris['Gatunek'].apply(lambda x: 1 if x=='Iris-setosa' else 0)\n", "print(data_iris_setosa_multi[:6])" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "%matplotlib inline\n", "\n", "import matplotlib.pyplot as plt\n", "import seaborn" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[1. 1.4 0.2 5.2 3.4]\n", " [1. 1.5 0.4 5.1 3.7]\n", " [1. 5.6 2.4 6.7 3.1]\n", " [1. 5.1 2. 6.5 3.2]\n", " [1. 4.5 1.7 4.9 2.5]\n", " [1. 5.1 1.6 6. 
2.7]]\n", "[[1.]\n", " [1.]\n", " [0.]\n", " [0.]\n", " [0.]\n", " [0.]]\n" ] } ], "source": [ "# Data preparation\n", "m, n_plus_1 = data_iris_setosa_multi.values.shape\n", "n = n_plus_1 - 1\n", "Xn = data_iris_setosa_multi.values[:, 0:n].reshape(m, n)\n", "\n", "# Prepend a column of ones (for the bias term theta_0)\n", "XMx4 = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n_plus_1)\n", "yMx4 = np.matrix(data_iris_setosa_multi.values[:, n]).reshape(m, 1)\n", "\n", "print(XMx4[:6])\n", "print(yMx4[:6])" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "# Split the data into a training set (first 100 examples) and a test set (the rest)\n", "XTrain, XTest = XMx4[:100], XMx4[100:]\n", "yTrain, yTest = yMx4[:100], yMx4[100:]\n", "\n", "# Initial parameter matrix (all ones)\n", "thetaTemp = np.ones(5).reshape(5,1)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "error = [[0.006797]]\n", "theta = [[ 1.11414027]\n", " [-2.89324615]\n", " [-0.66543637]\n", " [ 0.14887292]\n", " [ 2.13284493]]\n" ] } ], "source": [ "thetaBest, errors = GD(h, J, dJ, thetaTemp, XTrain, yTrain,\n", "                       alpha=0.1, eps=10**-7, maxSteps=1000)\n", "print(\"error =\", errors[-1][0])\n", "print(\"theta =\", thetaBest)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Decision function of logistic regression\n", "\n", "The decision function determines when our algorithm predicts $y = 1$ and when it predicts $y = 0$:\n", "\n", "$$ c = \\left\\{ \n", "\\begin{array}{ll}\n", "1, & \\mbox{if } P(y=1 \\, | \\, x; \\theta) > 0.5 \\\\\n", "0, & \\mbox{otherwise}\n", "\\end{array}\\right.\n", "$$\n", "\n", "$$ P(y=1 \\,| \\, x; \\theta) = h_\\theta(x) $$" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", 
"output_type": "stream", "text": [ "theta = [[ 1.11414027]\n", " [-2.89324615]\n", " [-0.66543637]\n", " [ 0.14887292]\n", " [ 2.13284493]]\n", "x0 = [[1. 6.3 1.8 7.3 2.9]]\n", "h(x0) = 1.6061436959824898e-05\n", "c(x0) = (0, 1.6061436959824898e-05) \n", "\n" ] } ], "source": [ "# Binary decision function: returns (predicted class, probability of y = 1)\n", "def classifyBi(theta, X):\n", "    prob = h(theta, X).item()\n", "    return (1, prob) if prob > 0.5 else (0, prob)\n", "\n", "print(\"theta =\", thetaBest)\n", "print(\"x0 =\", XTest[0])\n", "print(\"h(x0) =\", h(thetaBest, XTest[0]).item())\n", "print(\"c(x0) =\", classifyBi(thetaBest, XTest[0]), \"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Let us now compute the accuracy of the model (more on this in the next lecture, devoted to evaluation methods)." ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 <=> 0 -- prob: 0.0\n", "1 <=> 1 -- prob: 0.9816\n", "0 <=> 0 -- prob: 0.0001\n", "0 <=> 0 -- prob: 0.0005\n", "0 <=> 0 -- prob: 0.0001\n", "1 <=> 1 -- prob: 0.9936\n", "0 <=> 0 -- prob: 0.0059\n", "0 <=> 0 -- prob: 0.0992\n", "0 <=> 0 -- prob: 0.0001\n", "0 <=> 0 -- prob: 0.0001\n", "\n", "Accuracy: 1.0\n" ] } ], "source": [ "acc = 0.0\n", "for i in range(len(yTest)):\n", "    cls, prob = classifyBi(thetaBest, XTest[i])\n", "    # show the first 10 test examples: expected class <=> predicted class\n", "    if i < 10:\n", "        print(int(yTest[i].item()), \"<=>\", cls, \"-- prob:\", round(prob, 4))\n", "    acc += cls == yTest[i].item()\n", "\n", "print(\"\\nAccuracy:\", acc / len(XTest))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## 3.2. Multiclass logistic regression" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Example: all features and all 3 classes of the *Iris* data set."
] }, { "cell_type": "code", "execution_count": 25, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\"><th></th><th>sl</th><th>sw</th><th>pl</th><th>pw</th><th>Gatunek</th></tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr><th>0</th><td>5.2</td><td>3.4</td><td>1.4</td><td>0.2</td><td>Iris-setosa</td></tr>\n", "    <tr><th>1</th><td>5.1</td><td>3.7</td><td>1.5</td><td>0.4</td><td>Iris-setosa</td></tr>\n", "    <tr><th>2</th><td>6.7</td><td>3.1</td><td>5.6</td><td>2.4</td><td>Iris-virginica</td></tr>\n", "    <tr><th>3</th><td>6.5</td><td>3.2</td><td>5.1</td><td>2.0</td><td>Iris-virginica</td></tr>\n", "    <tr><th>4</th><td>4.9</td><td>2.5</td><td>4.5</td><td>1.7</td><td>Iris-virginica</td></tr>\n", "    <tr><th>5</th><td>6.0</td><td>2.7</td><td>5.1</td><td>1.6</td><td>Iris-versicolor</td></tr>\n", "  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " sl sw pl pw Gatunek\n", "0 5.2 3.4 1.4 0.2 Iris-setosa\n", "1 5.1 3.7 1.5 0.4 Iris-setosa\n", "2 6.7 3.1 5.6 2.4 Iris-virginica\n", "3 6.5 3.2 5.1 2.0 Iris-virginica\n", "4 4.9 2.5 4.5 1.7 Iris-virginica\n", "5 6.0 2.7 5.1 1.6 Iris-versicolor" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas\n", "data_iris = pandas.read_csv('iris.csv')\n", "data_iris[:6]" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "X = [[1. 5.2 3.4 1.4 0.2]\n", " [1. 5.1 3.7 1.5 0.4]\n", " [1. 6.7 3.1 5.6 2.4]\n", " [1. 6.5 3.2 5.1 2. ]]\n", "y = [['Iris-setosa']\n", " ['Iris-setosa']\n", " ['Iris-virginica']\n", " ['Iris-virginica']]\n" ] } ], "source": [ "# Data preparation\n", "\n", "import numpy as np\n", "\n", "features = ['sl', 'sw', 'pl', 'pw']\n", "m = len(data_iris)\n", "X = np.matrix(data_iris[features])\n", "# prepend a column of ones (for the bias term)\n", "X0 = np.ones(m).reshape(m, 1)\n", "X = np.hstack((X0, X))\n", "y = np.matrix(data_iris[[\"Gatunek\"]]).reshape(m, 1)\n", "\n", "print(\"X = \", X[:4])\n", "print(\"y = \", y[:4])" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Let us replace the text labels in $y$ with one-hot vectors (the classes are taken in alphabetical order, as in the `indicatorMatrix` function below):\n", "\n", "$$\n", "\\begin{array}{ccc}\n", "\\mbox{\"Iris-setosa\"} & \\mapsto & \\left[ \\begin{array}{ccc} 1 & 0 & 0 \\\\ \\end{array} \\right] \\\\\n", "\\mbox{\"Iris-versicolor\"} & \\mapsto & \\left[ \\begin{array}{ccc} 0 & 1 & 0 \\\\ \\end{array} \\right] \\\\\n", "\\mbox{\"Iris-virginica\"} & \\mapsto & \\left[ \\begin{array}{ccc} 0 & 0 & 1 \\\\ \\end{array} \\right] \\\\\n", "\\end{array}\n", "$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Instead of the vector $y$ we then obtain the matrix $Y$:\n", "\n", "$$\n", "y \\; = \\;\n", "\\left[\n", 
"\\begin{array}{c}\n", "y^{(1)} \\\\\n", "y^{(2)} \\\\\n", "y^{(3)} \\\\\n", "y^{(4)} \\\\\n", "y^{(5)} \\\\\n", "\\vdots \\\\\n", "\\end{array}\n", "\\right]\n", "\\; = \\;\n", "\\left[\n", "\\begin{array}{c}\n", "\\mbox{\"Iris-setosa\"} \\\\\n", "\\mbox{\"Iris-setosa\"} \\\\\n", "\\mbox{\"Iris-virginica\"} \\\\\n", "\\mbox{\"Iris-versicolor\"} \\\\\n", "\\mbox{\"Iris-virginica\"} \\\\\n", "\\vdots \\\\\n", "\\end{array}\n", "\\right]\n", "\\quad \\mapsto \\quad\n", "Y \\; = \\;\n", "\\left[\n", "\\begin{array}{ccc}\n", "1 & 0 & 0 \\\\\n", "1 & 0 & 0 \\\\\n", "0 & 0 & 1 \\\\\n", "0 & 1 & 0 \\\\\n", "0 & 0 & 1 \\\\\n", "\\vdots & \\vdots & \\vdots \\\\\n", "\\end{array}\n", "\\right]\n", "$$" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Binary column vector: 1 where the label equals cls, 0 elsewhere\n", "def mapY(y, cls):\n", "    m = len(y)\n", "    yBi = np.matrix(np.zeros(m)).reshape(m, 1)\n", "    yBi[y == cls] = 1.\n", "    return yBi\n", "\n", "# One-hot (indicator) matrix: one column per class, classes in sorted order\n", "def indicatorMatrix(y):\n", "    classes = np.unique(y.tolist())\n", "    m = len(y)\n", "    k = len(classes)\n", "    Y = np.matrix(np.zeros((m, k)))\n", "    for i, cls in enumerate(classes):\n", "        Y[:, i] = mapY(y, cls)\n", "    return Y\n", "\n", "# one-hot matrix\n", "Y = indicatorMatrix(y)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "slideshow": { "slide_type": "notes" } }, "outputs": [], "source": [ "# Split the data into a training set and a test set\n", "XTrain, XTest = X[:100], X[100:]\n", "YTrain, YTest = Y[:100], Y[100:]\n", "\n", "# Initial parameter matrix: all ones\n", "thetaTemp = np.ones(5).reshape(5,1)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### From binary to multiclass logistic regression\n", "\n", "* The irises belong to three classes: _Iris-setosa_ (0), _Iris-versicolor_ (1), _Iris-virginica_ (2).\n", "* We already know how to build binary classifiers such as _Iris-setosa_ vs. 
_non-Iris-setosa_ (the so-called *one-vs-all* approach).\n", "* We can build three such classifiers $h_{\\theta_1}, h_{\\theta_2}, h_{\\theta_3}$ (obtaining three sets of parameters $\\theta$) and choose the class with the highest probability." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "The *softmax* function, a generalization of the logistic function to more dimensions, will help us here." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### The _softmax_ function\n", "\n", "The counterpart of the logistic function in multiclass logistic regression is the $\\mathrm{softmax}$ function:\n", "\n", "$$ \\textrm{softmax} \\colon \\mathbb{R}^k \\to [0,1]^k $$\n", "\n", "$$ \\textrm{softmax}(z_1,z_2,\\dots,z_k) = \\left( \\dfrac{e^{z_1}}{\\sum_{i=1}^{k}e^{z_i}}, \\dfrac{e^{z_2}}{\\sum_{i=1}^{k}e^{z_i}}, \\ldots, \\dfrac{e^{z_k}}{\\sum_{i=1}^{k}e^{z_i}} \\right) $$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "$$ \\textrm{softmax}( \\left[ \\begin{array}{c} \\theta_1^T x \\\\ \\theta_2^T x \\\\ \\vdots \\\\ \\theta_k^T x \\end{array} \\right] ) = \\left[ \\begin{array}{c} P(y=1 \\, | \\, x;\\theta_1,\\ldots,\\theta_k) \\\\ P(y=2 \\, | \\, x;\\theta_1,\\ldots,\\theta_k) \\\\ \\vdots \\\\ P(y=k \\, | \\, x;\\theta_1,\\ldots,\\theta_k) \\end{array} \\right] $$" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "# Softmax in matrix form (normalizes over all entries of X)\n", "def softmax(X):\n", "    return np.exp(X) / np.sum(np.exp(X))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "The values of the $\\mathrm{softmax}$ function sum to 1 (up to floating-point rounding):" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ 
"0.9999999999999999\n" ] } ], "source": [ "Z = np.matrix([[2.1, 0.5, 0.8, 0.9, 3.2]])\n", "P = softmax(Z)\n", "print(np.sum(P))" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Theta parameter matrix for class 0:\n", " [[ 0.68590262]\n", " [ 0.39948964]\n", " [ 1.13312933]\n", " [-2.17550597]\n", " [-0.53088875]] \n", "\n", "Theta parameter matrix for class 1:\n", " [[ 0.95431453]\n", " [ 0.07249434]\n", " [-1.07233395]\n", " [ 0.53801787]\n", " [-0.65001214]] \n", "\n", "Theta parameter matrix for class 2:\n", " [[-0.66101185]\n", " [-1.40133883]\n", " [-2.01776182]\n", " [ 2.18505283]\n", " [ 2.74690482]] \n", "\n" ] } ], "source": [ "# Train a separate binary (one-vs-all) classifier for each class.\n", "\n", "def trainMaxEnt(X, Y):\n", "    n = X.shape[1]\n", "    thetas = []\n", "    for c in range(Y.shape[1]):\n", "        YBi = Y[:,c]\n", "        theta = np.matrix(np.random.random(n)).reshape(n,1)\n", "        # The theta parameter matrix is computed separately for each class.\n", "        thetaBest, errors = GD(h, J, dJ, theta,\n", "                               X, YBi, alpha=0.1, eps=10**-4)\n", "        thetas.append(thetaBest)\n", "    return thetas\n", "\n", "# Theta matrices for each class\n", "thetas = trainMaxEnt(XTrain, YTrain)\n", "for c, theta in enumerate(thetas):\n", "    print(f\"Theta parameter matrix for class {c}:\\n\", theta, \"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Decision function of multiclass logistic regression\n", "\n", "$$ c = \\mathop{\\textrm{arg}\\,\\textrm{max}}_{i \\in \\{1, \\ldots ,k\\}} P(y=i|x;\\theta_1,\\ldots,\\theta_k) $$" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "For x = [[1. 
7.3 2.9 6.3 1.8]]:\n", "Linear scores (theta^T x): [-7.77303536 0.59324542 1.96796697]\n", "Resulting probabilities: [0. 0.202 0.798]\n", "Chosen class: 2\n", "Predicted y = 2\n", "Expected y = 2\n", "\n", "For x = [[1. 4.8 3. 1.4 0.3]]:\n", "Linear scores (theta^T x): [ 2.79786587 -1.35649314 -9.55757825]\n", "Resulting probabilities: [0.985 0.015 0. ]\n", "Chosen class: 0\n", "Predicted y = 0\n", "Expected y = 0\n", "\n", "For x = [[1. 7.1 3. 5.9 2.1]]:\n", "Linear scores (theta^T x): [-7.02868459 0.06130237 1.99650886]\n", "Resulting probabilities: [0. 0.126 0.874]\n", "Chosen class: 2\n", "Predicted y = 2\n", "Expected y = 2\n", "\n", "For x = [[1. 5.9 3. 5.1 1.8]]:\n", "Linear scores (theta^T x): [-5.60840075 -0.26110148 1.10600174]\n", "Resulting probabilities: [0.001 0.203 0.796]\n", "Chosen class: 2\n", "Predicted y = 2\n", "Expected y = 2\n", "\n", "For x = [[1. 6.1 2.6 5.6 1.4]]:\n", "Linear scores (theta^T x): [-6.85715204 0.71134476 1.62660319]\n", "Resulting probabilities: [0. 
0.286 0.714]\n", "Chosen class: 2\n", "Predicted y = 2\n", "Expected y = 2\n", "\n" ] } ], "source": [ "# Multiclass decision function: returns the index of the most probable class\n", "def classify(thetas, X, debug=False):\n", "    regs = np.array([(X*theta).item() for theta in thetas])\n", "    if debug:\n", "        print(\"Linear scores (theta^T x): \", regs)\n", "    probs = softmax(regs)\n", "    if debug:\n", "        print(\"Resulting probabilities: \", np.around(probs,decimals=3))\n", "    result = np.argmax(probs)\n", "    if debug:\n", "        print(\"Chosen class: \", result)\n", "    return result\n", "\n", "for i in range(5):\n", "    print(f\"For x = {XTest[i]}:\")\n", "    YPredicted = classify(thetas, XTest[i], debug=True)\n", "    print(f\"Predicted y = {YPredicted}\")\n", "    print(f\"Expected y = {np.argmax(YTest[i])}\")\n", "    print()" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" }, "livereveal": { "start_slideshow_at": "selected", "theme": "white" } }, "nbformat": 4, "nbformat_minor": 4 }