{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Machine Learning – Applications\n",
"# 3. Logistic Regression"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Despite its name, logistic *regression* is an algorithm for solving *classification* problems (not *regression* problems at all!)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 3.1. Binary logistic regression"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Let us start with the simplest case: we want to assign our data to one of **two** classes.\n",
"For this we will use **binary** logistic regression."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Example: bristle-pointed iris (*Iris setosa*)\n",
"\n",
"We have petal length data and would like to use it to determine whether a plant belongs to the species _Iris setosa_."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Useful imports\n",
"\n",
"import numpy as np\n",
"import matplotlib\n",
"import matplotlib.pyplot as pl\n",
"import pandas\n",
"import ipywidgets as widgets\n",
"\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format = 'svg'\n",
"\n",
"from IPython.display import display, Math, Latex\n",
"\n",
"# Useful functions\n",
"\n",
"# Display a matrix in LaTeX\n",
"def LatexMatrix(matrix):\n",
"    ltx = r'\\left[\\begin{array}'\n",
"    m, n = matrix.shape\n",
"    ltx += '{' + (\"r\" * n) + '}'\n",
"    for i in range(m):\n",
"        # iterate over the scalar entries of row i\n",
"        ltx += r\" & \".join(['%.4f' % v for v in np.asarray(matrix)[i]]) + r\" \\\\ \"\n",
"    ltx += r'\\end{array}\\right]'\n",
"    return ltx\n",
"\n",
"# Hypothesis (matrix version)\n",
"def hMx(theta, X):\n",
"    return X * theta\n",
"\n",
"# Data plot (matrix version)\n",
"def regdotsMx(X, y, xlabel, ylabel):\n",
"    fig = pl.figure(figsize=(16*.6, 9*.6))\n",
"    ax = fig.add_subplot(111)\n",
"    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
"    ax.scatter([X[:, 1]], [y], c='r', s=50, label='Data')\n",
"\n",
"    ax.set_xlabel(xlabel)\n",
"    ax.set_ylabel(ylabel)\n",
"    ax.margins(.05, .05)\n",
"    pl.ylim(y.min() - 1, y.max() + 1)\n",
"    pl.xlim(np.min(X[:, 1]) - 1, np.max(X[:, 1]) + 1)\n",
"    return fig\n",
"\n",
"# Regression line plot (matrix version)\n",
"def reglineMx(fig, fun, theta, X):\n",
"    ax = fig.axes[0]\n",
"    x0 = np.min(X[:, 1]) - 1.0\n",
"    x1 = np.max(X[:, 1]) + 1.0\n",
"    L = [x0, x1]\n",
"    LX = np.matrix([1, x0, 1, x1]).reshape(2, 2)\n",
"    theta0, theta1 = theta.item(0), theta.item(1)\n",
"    ax.plot(L, fun(theta, LX), linewidth=2,\n",
"            label=r'$y={theta0:.2}{op}{theta1:.2}x$'.format(\n",
"                theta0=theta0, theta1=abs(theta1),\n",
"                op='+' if theta1 >= 0 else '-'))\n",
"\n",
"# Plot legend\n",
"def legend(fig):\n",
"    ax = fig.axes[0]\n",
"    handles, labels = ax.get_legend_handles_labels()\n",
"    # try-except block is a fix for a bug in Poly3DCollection\n",
"    try:\n",
"        fig.legend(handles, labels, fontsize='15', loc='lower right')\n",
"    except AttributeError:\n",
"        pass\n",
"\n",
"# Cost function (matrix version)\n",
"def JMx(theta, X, y):\n",
"    m = len(y)\n",
"    J = 1.0 / (2.0 * m) * ((X * theta - y).T * (X * theta - y))\n",
"    return J.item()\n",
"\n",
"# Gradient of the cost function (matrix version)\n",
"def dJMx(theta, X, y):\n",
"    return 1.0 / len(y) * (X.T * (X * theta - y))\n",
"\n",
"# Gradient descent implemented with NumPy matrices\n",
"def GDMx(fJ, fdJ, theta, X, y, alpha=0.1, eps=10**-3):\n",
"    current_cost = fJ(theta, X, y)\n",
"    logs = [[current_cost, theta]]\n",
"    while True:\n",
"        theta = theta - alpha * fdJ(theta, X, y)  # apply the update rule\n",
"        current_cost, prev_cost = fJ(theta, X, y), current_cost\n",
"        if current_cost > 10000:  # stop if the cost diverges\n",
"            break\n",
"        if abs(prev_cost - current_cost) <= eps:\n",
"            break\n",
"        logs.append([current_cost, theta])\n",
"    return theta, logs\n",
"\n",
"thetaStartMx = np.matrix([0, 0]).reshape(2, 1)\n",
"\n",
"# Draw the threshold: the x at which the linear hypothesis equals 0.5\n",
"def threshold(fig, theta):\n",
"    x_thr = (0.5 - theta.item(0)) / theta.item(1)\n",
"    ax = fig.axes[0]\n",
"    ax.plot([x_thr, x_thr], [-1, 2],\n",
"            color='orange', linestyle='dashed',\n",
"            label='threshold: $x={:.2F}$'.format(x_thr))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" sl sw pl pw Gatunek\n",
"0 5.2 3.4 1.4 0.2 Iris-setosa\n",
"1 5.1 3.7 1.5 0.4 Iris-setosa\n",
"2 6.7 3.1 5.6 2.4 Iris-virginica\n",
"3 6.5 3.2 5.1 2.0 Iris-virginica\n",
"4 4.9 2.5 4.5 1.7 Iris-virginica\n",
"5 6.0 2.7 5.1 1.6 Iris-versicolor\n"
]
}
],
"source": [
"# Load the data\n",
"\n",
"data_iris = pandas.read_csv('iris.csv')\n",
"print(data_iris[:6])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" dł. płatka Iris setosa?\n",
"0 1.4 1\n",
"1 1.5 1\n",
"2 5.6 0\n",
"3 5.1 0\n",
"4 4.5 0\n",
"5 5.1 0\n"
]
}
],
"source": [
"data_iris_setosa = pandas.DataFrame()\n",
"data_iris_setosa['dł. płatka'] = data_iris['pl']  # \"pl\" stands for \"petal length\"\n",
"data_iris_setosa['Iris setosa?'] = data_iris['Gatunek'].apply(lambda x: 1 if x=='Iris-setosa' else 0)\n",
"print(data_iris_setosa[:6])"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Can we apply linear regression here?\n",
"\n",
"Let's try:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(150, 2)\n"
]
}
],
"source": [
"print(data_iris_setosa.values.shape)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
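"import numpy as np\n",
"\n",
"# m examples; the last column holds the label, the remaining n columns are features\n",
"m, n_plus_1 = data_iris_setosa.values.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data_iris_setosa.values[:, 0:n].reshape(m, n)\n",
"\n",
"# Design matrix with a leading column of ones (for theta_0) and the label vector\n",
"XMx3 = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n_plus_1)\n",
"yMx3 = np.matrix(data_iris_setosa.values[:, 1]).reshape(m, 1)"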
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\r\n",
"\r\n",
"\r\n",
"\r\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = regdotsMx(XMx3, yMx3, 'x', 'Iris setosa?')\n",
"theta_e3, logs3 = GDMx(JMx, dJMx, thetaStartMx, XMx3, yMx3, alpha=0.03, eps=0.000001)\n",
"reglineMx(fig, hMx, theta_e3, XMx3)\n",
"threshold(fig, theta_e3)\n",
"legend(fig)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
" * The linear regression line fits classification data poorly.\n",
" * Applying the threshold $y = 0.5$ does not always yield a sensible result.\n",
" * $h(x)$ can take values below $0$ and above $1$: how should such outputs be interpreted?\n",
"\n",
"Conclusion: linear regression does not seem to be the best choice for classification problems."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Let us define the following function, which we will call the *logistic* (or *sigmoid*) function:"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"**Logistic (sigmoid) function**:\n",
"\n",
"$$g(x) = \\dfrac{1}{1+e^{-x}}$$"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [],
"source": [
"# Logistic function\n",
"\n",
"def logistic(x):\n",
" return 1.0 / (1.0 + np.exp(-x))"
]
},
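{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick numerical check of three properties of $g$ that are useful when working with logistic regression: $g(0) = 0.5$, the symmetry $g(-x) = 1 - g(x)$, and the derivative $g'(x) = g(x)(1 - g(x))$. This is a minimal sketch in plain NumPy; the helper names `g` and `dg` are illustrative, and the analytic derivative is compared against a central finite difference.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def g(x):\n",
"    # logistic (sigmoid) function\n",
"    return 1.0 / (1.0 + np.exp(-x))\n",
"\n",
"def dg(x):\n",
"    # analytic derivative: g'(x) = g(x) * (1 - g(x))\n",
"    return g(x) * (1.0 - g(x))\n",
"\n",
"x = np.linspace(-5, 5, 11)\n",
"step = 1e-5\n",
"# central finite-difference approximation of g'(x)\n",
"numeric = (g(x + step) - g(x - step)) / (2 * step)\n",
"\n",
"print(g(0.0))                                  # 0.5\n",
"print(np.allclose(g(-x), 1.0 - g(x)))          # True\n",
"print(np.allclose(dg(x), numeric, atol=1e-8))  # True\n",
"```"
]
},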
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"def plot_logistic():\n",
" x = np.linspace(-5,5,200)\n",
" y = logistic(x)\n",
"\n",
" fig = plt.figure(figsize=(7,5))\n",
" ax = fig.add_subplot(111)\n",
" plt.ylim(-.1,1.1)\n",
" ax.plot(x, y, linewidth='2')"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Plot of the logistic function $g(x) = \\dfrac{1}{1+e^{-x}}$:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\r\n",
"\r\n",
"\r\n",
"\r\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_logistic()"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The logistic function maps the set of real numbers $\\mathbb{R}$ onto the open interval $(0, 1)$."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"The logistic regression hypothesis function:\n",
"\n",
"$$h_\\theta(x) = g(\\theta^T \\, x) = \\dfrac{1}{1 + e^{-\\theta^T x}}$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Matrix version:\n",
"\n",
"$$h_\\theta(X) = g(X \\, \\theta) = \\dfrac{1}{1 + e^{-X \\theta}}$$"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Logistic regression hypothesis (matrix version)\n",
"def h(theta, X):\n",
" return 1.0/(1.0 + np.exp(-X * theta))"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Cost function for logistic regression:\n",
"\n",
"$$J(\\theta) = -\\dfrac{1}{m} \\sum_{i=1}^{m} \\left( y^{(i)} \\log h_\\theta( x^{(i)} ) + \\left( 1 - y^{(i)} \\right) \\log \\left( 1 - h_\\theta (x^{(i)}) \\right) \\right)$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Gradient for logistic regression (matrix version):\n",
"\n",
"$$\\nabla J(\\theta) = \\frac{1}{|\\vec y|} X^T \\left( h_\\theta(X) - \\vec y \\right)$$\n",
"\n",
"(The only difference between this gradient and the gradient for linear regression is the form of $h_\\theta$.)"
]
},
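{
"cell_type": "markdown",
"metadata": {},
"source": [
"Where the gradient formula comes from: a short derivation for a single training example, writing $z = \\theta^T x$ (so $h_\\theta(x) = g(z)$ and $\\partial z / \\partial \\theta_j = x_j$) and using the identity $g'(z) = g(z)(1 - g(z))$:\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"\\frac{\\partial}{\\partial \\theta_j} \\left[ y \\log g(z) + (1 - y) \\log (1 - g(z)) \\right]\n",
"&= \\left( \\frac{y}{g(z)} - \\frac{1 - y}{1 - g(z)} \\right) g(z) (1 - g(z)) \\, x_j \\\\\n",
"&= \\left( y (1 - g(z)) - (1 - y) g(z) \\right) x_j \\\\\n",
"&= \\left( y - h_\\theta(x) \\right) x_j\n",
"\\end{aligned}\n",
"$$\n",
"\n",
"Multiplying by $-\\frac{1}{m}$ and summing over all $m$ examples gives $\\frac{\\partial J}{\\partial \\theta_j} = \\frac{1}{m} \\sum_{i=1}^{m} \\left( h_\\theta(x^{(i)}) - y^{(i)} \\right) x_j^{(i)}$, which is the $j$-th component of the matrix formula above."
]
},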
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"# Cost function for logistic regression\n",
"def J(h, theta, X, y):\n",
" m = len(y)\n",
" h_val = h(theta, X)\n",
" s1 = np.multiply(y, np.log(h_val))\n",
" s2 = np.multiply((1 - y), np.log(1 - h_val))\n",
" return -np.sum(s1 + s2, axis=0) / m"
]
},
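{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cost `J` above applies `np.log` to $h_\\theta$ directly, so it returns `inf` or `nan` once $h_\\theta(x)$ rounds to exactly $0$ or $1$, which happens for large $|\\theta^T x|$. A common remedy, sketched here with plain NumPy arrays instead of `np.matrix`, rewrites both logarithms with `np.logaddexp`, using $-\\log g(z) = \\log(1 + e^{-z})$ and $-\\log(1 - g(z)) = \\log(1 + e^{z})$. The function names `stable_cost` and `naive_cost` are illustrative.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def stable_cost(theta, X, y):\n",
"    # X: (m, n+1) array, theta: (n+1,) array, y: (m,) array of 0/1 labels\n",
"    z = X @ theta\n",
"    # -log g(z) = log(1 + e^(-z)) and -log(1 - g(z)) = log(1 + e^z),\n",
"    # both evaluated without overflow by np.logaddexp(0, .)\n",
"    losses = y * np.logaddexp(0.0, -z) + (1 - y) * np.logaddexp(0.0, z)\n",
"    return losses.mean()\n",
"\n",
"def naive_cost(theta, X, y):\n",
"    # direct translation of the formula, for comparison\n",
"    h = 1.0 / (1.0 + np.exp(-(X @ theta)))\n",
"    return -np.mean(y * np.log(h) + (1 - y) * np.log(1 - h))\n",
"\n",
"X = np.array([[1.0, 1.4], [1.0, 1.5], [1.0, 5.6], [1.0, 5.1]])\n",
"y = np.array([1.0, 1.0, 0.0, 0.0])\n",
"theta = np.array([5.0, -2.0])\n",
"\n",
"print(np.isclose(stable_cost(theta, X, y), naive_cost(theta, X, y)))  # True\n",
"# with very large parameters h rounds to 0 or 1; the stable version stays finite\n",
"big = theta * 100\n",
"print(np.isfinite(stable_cost(big, X, y)))  # True\n",
"```"
]
},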
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [],
"source": [
"# Gradient for logistic regression\n",
"def dJ(h, theta, X, y):\n",
" return 1.0 / len(y) * (X.T * (h(theta, X) - y))"
]
},
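{
"cell_type": "markdown",
"metadata": {},
"source": [
"A finite-difference gradient check is one way to confirm that `dJ` really is the gradient of `J`: perturb each parameter by a small $\\epsilon$ and compare $\\left(J(\\theta + \\epsilon e_j) - J(\\theta - \\epsilon e_j)\\right) / 2\\epsilon$ with the $j$-th component of the analytic gradient. The sketch below is self-contained and uses plain NumPy arrays; the names `cost` and `grad` and the random test data are illustrative assumptions, not part of the notebook.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sigmoid(z):\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"def cost(theta, X, y):\n",
"    # cross-entropy cost J(theta)\n",
"    h = sigmoid(X @ theta)\n",
"    return -np.mean(y * np.log(h) + (1 - y) * np.log(1 - h))\n",
"\n",
"def grad(theta, X, y):\n",
"    # analytic gradient: (1/m) X^T (h - y)\n",
"    return X.T @ (sigmoid(X @ theta) - y) / len(y)\n",
"\n",
"rng = np.random.default_rng(0)\n",
"X = np.hstack([np.ones((20, 1)), rng.normal(size=(20, 2))])  # intercept + 2 features\n",
"y = (rng.random(20) > 0.5).astype(float)\n",
"theta = rng.normal(size=3)\n",
"\n",
"eps = 1e-6\n",
"numeric = np.zeros_like(theta)\n",
"for j in range(len(theta)):\n",
"    e = np.zeros_like(theta)\n",
"    e[j] = eps\n",
"    # central difference along coordinate j\n",
"    numeric[j] = (cost(theta + e, X, y) - cost(theta - e, X, y)) / (2 * eps)\n",
"\n",
"print(np.allclose(numeric, grad(theta, X, y), atol=1e-7))  # True\n",
"```"
]
},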
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"# Gradient descent for logistic regression\n",
"def GD(h, fJ, fdJ, theta, X, y, alpha=0.01, eps=10**-3, maxSteps=10000):\n",
"    errorCurr = fJ(h, theta, X, y)\n",
"    errors = [[errorCurr, theta]]\n",
"    while True:\n",
"        # compute the new theta\n",
"        theta = theta - alpha * fdJ(h, theta, X, y)\n",
"        # record the current error\n",
"        errorCurr, errorPrev = fJ(h, theta, X, y), errorCurr\n",
"        # stopping criteria\n",
"        if abs(errorPrev - errorCurr) <= eps:\n",
"            break\n",
"        if len(errors) > maxSteps:\n",
"            break\n",
"        errors.append([errorCurr, theta])\n",
"    return theta, errors"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"error = [[0.05755617]]\n",
"theta = [[ 5.02530461]\n",
" [-1.99174803]]\n"
]
}
],
"source": [
"# Run gradient descent for logistic regression\n",
"thetaBest, errors = GD(h, J, dJ, thetaStartMx, XMx3, yMx3, \n",
" alpha=0.1, eps=10**-7, maxSteps=1000)\n",
"print(\"error =\", errors[-1][0])\n",
"print(\"theta =\", thetaBest)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Logistic regression hypothesis (scalar version)\n",
"def scalar_logistic_regression_function(theta, x):\n",
"    return 1.0/(1.0 + np.exp(-(theta.item(0) + theta.item(1) * x)))\n",
"\n",
"# Draw a vertical threshold line at x = x_thr\n",
"def threshold_val(fig, x_thr):\n",
"    ax = fig.axes[0]\n",
"    ax.plot([x_thr, x_thr], [-1, 2],\n",
"            color='orange', linestyle='dashed',\n",
"            label='threshold: $x={:.2F}$'.format(x_thr))\n",
"\n",
"# Plot the logistic regression curve\n",
"def logistic_regline(fig, theta, X):\n",
"    ax = fig.axes[0]\n",
"    x0 = np.min(X[:, 1]) - 1.0\n",
"    x1 = np.max(X[:, 1]) + 1.0\n",
"    Arg = np.arange(x0, x1, 0.1)\n",
"    Val = scalar_logistic_regression_function(theta, Arg)\n",
"    ax.plot(Arg, Val, linewidth=2)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\r\n",
"\r\n",
"\r\n",
"\r\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = regdotsMx(XMx3, yMx3, xlabel='x', ylabel='Iris setosa?')\n",
"logistic_regline(fig, thetaBest, XMx3)\n",
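"# decision boundary: h(x) = 0.5 where theta_0 + theta_1 * x = 0\n",
"threshold_val(fig, -thetaBest.item(0) / thetaBest.item(1))"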
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"We interpret the value $h_\\theta(x)$ as the probability that the label is positive ($y = 1$) given $x$:\n",
"\n",
"$$ h_\\theta(x) = P(y = 1 \\, | \\, x; \\theta) $$"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If $h_\\theta(x) > 0.5$, we predict $y = 1$ for that $x$.\n",
"Otherwise, we predict $y = 0$."
]
},
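{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because $g(z) > 0.5$ exactly when $z > 0$, the rule above is equivalent to predicting $y = 1$ whenever $\\theta^T x > 0$. With a single feature, the decision boundary is therefore the point $x = -\\theta_0 / \\theta_1$. The sketch below plugs in the parameters printed by the single-feature gradient descent run above, copied by hand, so the result is only as precise as that printout:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# parameters copied from the printed output of the gradient descent run\n",
"theta0, theta1 = 5.02530461, -1.99174803\n",
"\n",
"# h(x) = 0.5  <=>  theta0 + theta1 * x = 0\n",
"x_boundary = -theta0 / theta1\n",
"print(round(x_boundary, 3))  # 2.523\n",
"\n",
"def predict(x):\n",
"    # predict 1 when theta^T x > 0, i.e. when h(x) > 0.5\n",
"    return (theta0 + theta1 * np.asarray(x) > 0).astype(int)\n",
"\n",
"print(predict([1.4, 2.5, 2.6, 5.1]))  # [1 1 0 0]\n",
"```"
]
},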
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Binary logistic regression: more features"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"What should we do if we have more than one feature $x$?"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Let us consider the following features:\n",
" * sepal length\n",
" * sepal width\n",
" * petal length\n",
" * petal width"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" dł. płatków szer. płatków dł. dz. k. szer. dz. k. Iris setosa?\n",
"0 1.4 0.2 5.2 3.4 1\n",
"1 1.5 0.4 5.1 3.7 1\n",
"2 5.6 2.4 6.7 3.1 0\n",
"3 5.1 2.0 6.5 3.2 0\n",
"4 4.5 1.7 4.9 2.5 0\n",
"5 5.1 1.6 6.0 2.7 0\n"
]
}
],
"source": [
"data_iris_setosa_multi = pandas.DataFrame()\n",
"data_iris_setosa_multi['dł. płatków'] = data_iris['pl']  # \"pl\" stands for \"petal length\"\n",
"data_iris_setosa_multi['szer. płatków'] = data_iris['pw']  # \"pw\" stands for \"petal width\"\n",
"data_iris_setosa_multi['dł. dz. k.'] = data_iris['sl']  # \"sl\" stands for \"sepal length\"\n",
"data_iris_setosa_multi['szer. dz. k.'] = data_iris['sw']  # \"sw\" stands for \"sepal width\"\n",
"data_iris_setosa_multi['Iris setosa?'] = data_iris['Gatunek'].apply(lambda x: 1 if x=='Iris-setosa' else 0)\n",
"print(data_iris_setosa_multi[:6])"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import seaborn"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/svg+xml": [
"\r\n",
"\r\n",
"\r\n",
"\r\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# The plot below shows the pairwise relationships between all features\n",
"\n",
"seaborn.pairplot(\n",
" data_iris_setosa_multi,\n",
" vars=[c for c in data_iris_setosa_multi.columns if c != 'Iris setosa?'], \n",
" hue='Iris setosa?', height=1.5, aspect=1.75)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[1. 1.4 0.2 5.2 3.4]\n",
" [1. 1.5 0.4 5.1 3.7]\n",
" [1. 5.6 2.4 6.7 3.1]\n",
" [1. 5.1 2. 6.5 3.2]\n",
" [1. 4.5 1.7 4.9 2.5]\n",
" [1. 5.1 1.6 6. 2.7]]\n",
"[[1.]\n",
" [1.]\n",
" [0.]\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n"
]
}
],
"source": [
"# Prepare the data\n",
"m, n_plus_1 = data_iris_setosa_multi.values.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data_iris_setosa_multi.values[:, 0:n].reshape(m, n)\n",
"\n",
"XMx4 = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n_plus_1)\n",
"yMx4 = np.matrix(data_iris_setosa_multi.values[:, n]).reshape(m, 1)\n",
"\n",
"print(XMx4[:6])\n",
"print(yMx4[:6])"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"# Split the data into training and test sets\n",
"XTrain, XTest = XMx4[:100], XMx4[100:]\n",
"yTrain, yTest = yMx4[:100], yMx4[100:]\n",
"\n",
"# Initial parameter matrix\n",
"thetaTemp = np.ones(5).reshape(5,1)"
]
},
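{
"cell_type": "markdown",
"metadata": {},
"source": [
"This split uses the first 100 rows for training and the remaining 50 for testing, which is only safe if the rows of `iris.csv` are in random order (the first rows printed above are mixed across species, but nothing in the code guarantees it). A sketch of a shuffled split with a fixed seed, using NumPy's `Generator.permutation`; the function name `shuffled_split`, the seed, and the toy data are assumptions:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def shuffled_split(X, y, n_train, seed=0):\n",
"    # permute the row indices once and apply the same order to X and y\n",
"    rng = np.random.default_rng(seed)\n",
"    idx = rng.permutation(len(y))\n",
"    train, test = idx[:n_train], idx[n_train:]\n",
"    return X[train], X[test], y[train], y[test]\n",
"\n",
"# toy data: 150 rows, intercept + 4 features, binary labels\n",
"X = np.hstack([np.ones((150, 1)), np.arange(600, dtype=float).reshape(150, 4)])\n",
"y = (np.arange(150) < 50).astype(float).reshape(150, 1)\n",
"\n",
"XTr, XTe, yTr, yTe = shuffled_split(X, y, 100)\n",
"print(XTr.shape, XTe.shape, yTr.shape, yTe.shape)  # (100, 5) (50, 5) (100, 1) (50, 1)\n",
"```\n",
"\n",
"With the notebook's matrices this could be called as `shuffled_split(XMx4, yMx4, 100)`; the resulting parameters and accuracy would then differ from the ones printed below."
]
},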
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"error = [[0.006797]]\n",
"theta = [[ 1.11414027]\n",
" [-2.89324615]\n",
" [-0.66543637]\n",
" [ 0.14887292]\n",
" [ 2.13284493]]\n"
]
}
],
"source": [
"thetaBest, errors = GD(h, J, dJ, thetaTemp, XTrain, yTrain, \n",
" alpha=0.1, eps=10**-7, maxSteps=1000)\n",
"print(\"error =\", errors[-1][0])\n",
"print(\"theta =\", thetaBest)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Decision function of logistic regression\n",
"\n",
"The decision function specifies when our algorithm predicts $y = 1$ and when it predicts $y = 0$:\n",
"\n",
"$$ c = \\left\\{ \n",
"\\begin{array}{ll}\n",
"1, & \\mbox{if } P(y=1 \\, | \\, x; \\theta) > 0.5 \\\\\n",
"0, & \\mbox{otherwise}\n",
"\\end{array}\\right.\n",
"$$\n",
"\n",
"$$ P(y=1 \\,| \\, x; \\theta) = h_\\theta(x) $$"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"theta = [[ 1.11414027]\n",
" [-2.89324615]\n",
" [-0.66543637]\n",
" [ 0.14887292]\n",
" [ 2.13284493]]\n",
"x0 = [[1. 6.3 1.8 7.3 2.9]]\n",
"h(x0) = 1.6061436959824898e-05\n",
"c(x0) = (0, 1.6061436959824898e-05) \n",
"\n"
]
}
],
"source": [
"def classifyBi(theta, X):\n",
"    # returns (predicted class, probability P(y=1|x))\n",
"    prob = h(theta, X).item()\n",
"    return (1, prob) if prob > 0.5 else (0, prob)\n",
"\n",
"print(\"theta =\", thetaBest)\n",
"print(\"x0 =\", XTest[0])\n",
"print(\"h(x0) =\", h(thetaBest, XTest[0]).item())\n",
"print(\"c(x0) =\", classifyBi(thetaBest, XTest[0]), \"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Accuracy"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 <=> 0 -- prob: 0.0\n",
"1 <=> 1 -- prob: 0.9816\n",
"0 <=> 0 -- prob: 0.0001\n",
"0 <=> 0 -- prob: 0.0005\n",
"0 <=> 0 -- prob: 0.0001\n",
"1 <=> 1 -- prob: 0.9936\n",
"0 <=> 0 -- prob: 0.0059\n",
"0 <=> 0 -- prob: 0.0992\n",
"0 <=> 0 -- prob: 0.0001\n",
"0 <=> 0 -- prob: 0.0001\n",
"\n",
"Accuracy: 1.0\n"
]
}
],
"source": [
"acc = 0.0\n",
"for i in range(len(yTest)):\n",
"    cls, prob = classifyBi(thetaBest, XTest[i])\n",
"    if i < 10:  # show the first 10 predictions\n",
"        print(int(yTest[i].item()), \"<=>\", cls, \"-- prob:\", round(prob, 4))\n",
"    acc += cls == yTest[i].item()\n",
"\n",
"print(\"\\nAccuracy:\", acc / len(XTest))"
]
},
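{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop above classifies one row at a time. The same accuracy can be computed in one vectorized step by thresholding all probabilities at once. A sketch with plain NumPy; the function name `accuracy` and the probabilities below are made up for illustration and are not the notebook's results:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def accuracy(probs, y, threshold=0.5):\n",
"    # predict 1 where P(y=1|x) > threshold, then compare with the true labels\n",
"    preds = (np.asarray(probs).ravel() > threshold).astype(int)\n",
"    return np.mean(preds == np.asarray(y).ravel())\n",
"\n",
"probs = np.array([0.01, 0.98, 0.40, 0.70, 0.10])\n",
"y_true = np.array([0, 1, 1, 1, 0])\n",
"print(accuracy(probs, y_true))  # 0.8\n",
"```\n",
"\n",
"In the notebook this could be called as `accuracy(h(thetaBest, XTest), yTest)`."
]
},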
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 3.4. Multiclass logistic regression"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Example: iris species"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Bristle-pointed iris (*Iris setosa*)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Virginia iris (*Iris virginica*)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Harlequin blueflag (*Iris versicolor*)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Features:\n",
" * sepal length\n",
" * sepal width\n",
" * petal length\n",
" * petal width"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Loading the data"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [
{
"data": {
"text/plain": [
" sl sw pl pw Gatunek\n",
"0 5.2 3.4 1.4 0.2 Iris-setosa\n",
"1 5.1 3.7 1.5 0.4 Iris-setosa\n",
"2 6.7 3.1 5.6 2.4 Iris-virginica\n",
"3 6.5 3.2 5.1 2.0 Iris-virginica\n",
"4 4.9 2.5 4.5 1.7 Iris-virginica\n",
"5 6.0 2.7 5.1 1.6 Iris-versicolor"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas\n",
"data_iris = pandas.read_csv('iris.csv')\n",
"data_iris[:6]"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Preparing the data"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"X = [[1. 5.2 3.4 1.4 0.2]\n",
" [1. 5.1 3.7 1.5 0.4]\n",
" [1. 6.7 3.1 5.6 2.4]\n",
" [1. 6.5 3.2 5.1 2. ]]\n",
"y = [['Iris-setosa']\n",
" ['Iris-setosa']\n",
" ['Iris-virginica']\n",
" ['Iris-virginica']]\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"features = ['sl', 'sw', 'pl', 'pw']\n",
"m = len(data_iris)\n",
"X = np.matrix(data_iris[features])\n",
"X0 = np.ones(m).reshape(m, 1)\n",
"X = np.hstack((X0, X))\n",
"y = np.matrix(data_iris[[\"Gatunek\"]]).reshape(m, 1)\n",
"\n",
"print(\"X = \", X[:4])\n",
"print(\"y = \", y[:4])"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Let us replace the text labels in the array $y$ with one-hot vectors:\n",
"\n",
"$$\n",
"\\begin{array}{ccc}\n",
"\\mbox{\"Iris-setosa\"} & \\mapsto & \\left[ \\begin{array}{ccc} 1 & 0 & 0 \\\\ \\end{array} \\right] \\\\\n",
"\\mbox{\"Iris-virginica\"} & \\mapsto & \\left[ \\begin{array}{ccc} 0 & 1 & 0 \\\\ \\end{array} \\right] \\\\\n",
"\\mbox{\"Iris-versicolor\"} & \\mapsto & \\left[ \\begin{array}{ccc} 0 & 0 & 1 \\\\ \\end{array} \\right] \\\\\n",
"\\end{array}\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Instead of the vector $y$, we then obtain the matrix $Y$:\n",
"\n",
"$$\n",
"y \\; = \\;\n",
"\\left[\n",
"\\begin{array}{c}\n",
"y^{(1)} \\\\\n",
"y^{(2)} \\\\\n",
"y^{(3)} \\\\\n",
"y^{(4)} \\\\\n",
"y^{(5)} \\\\\n",
"\\vdots \\\\\n",
"\\end{array}\n",
"\\right]\n",
"\\; = \\;\n",
"\\left[\n",
"\\begin{array}{c}\n",
"\\mbox{\"Iris-setosa\"} \\\\\n",
"\\mbox{\"Iris-setosa\"} \\\\\n",
"\\mbox{\"Iris-virginica\"} \\\\\n",
"\\mbox{\"Iris-versicolor\"} \\\\\n",
"\\mbox{\"Iris-virginica\"} \\\\\n",
"\\vdots \\\\\n",
"\\end{array}\n",
"\\right]\n",
"\\quad \\mapsto \\quad\n",
"Y \\; = \\;\n",
"\\left[\n",
"\\begin{array}{ccc}\n",
"1 & 0 & 0 \\\\\n",
"1 & 0 & 0 \\\\\n",
"0 & 1 & 0 \\\\\n",
"0 & 0 & 1 \\\\\n",
"0 & 1 & 0 \\\\\n",
"\\vdots & \\vdots & \\vdots \\\\\n",
"\\end{array}\n",
"\\right]\n",
"$$"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Binary indicator vector: 1 where y equals cls, 0 elsewhere\n",
"def mapY(y, cls):\n",
"    m = len(y)\n",
"    yBi = np.matrix(np.zeros(m)).reshape(m, 1)\n",
"    yBi[y == cls] = 1.\n",
"    return yBi\n",
"\n",
"# Indicator (one-hot) matrix: one column per class\n",
"def indicatorMatrix(y):\n",
"    classes = np.unique(y.tolist())\n",
"    m = len(y)\n",
"    k = len(classes)\n",
"    Y = np.matrix(np.zeros((m, k)))\n",
"    for i, cls in enumerate(classes):\n",
"        Y[:, i] = mapY(y, cls)\n",
"    return Y\n",
"\n",
"# one-hot matrix\n",
"Y = indicatorMatrix(y)"
]
},
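{
"cell_type": "markdown",
"metadata": {},
"source": [
"`indicatorMatrix` obtains the classes from `np.unique`, which returns them in sorted (alphabetical) order, so the columns of `Y` correspond to `Iris-setosa`, `Iris-versicolor`, `Iris-virginica`; this differs from the order in the illustrative mapping shown earlier. A compact alternative sketch uses `np.unique(..., return_inverse=True)` and row indexing into an identity matrix; the function name `one_hot` is hypothetical:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def one_hot(labels):\n",
"    # classes: sorted unique labels; inverse: index of each label in classes\n",
"    classes, inverse = np.unique(np.asarray(labels).ravel(), return_inverse=True)\n",
"    # row i of the identity matrix is the one-hot vector for class i\n",
"    return classes, np.eye(len(classes))[inverse]\n",
"\n",
"labels = ['Iris-setosa', 'Iris-setosa', 'Iris-virginica', 'Iris-versicolor', 'Iris-virginica']\n",
"classes, Y = one_hot(labels)\n",
"print(classes.tolist())  # ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']\n",
"print(Y)\n",
"```\n",
"\n",
"Applied to the notebook's `y`, `one_hot(y)[1]` should contain the same rows as `Y` from `indicatorMatrix`, since both rely on the sorted order returned by `np.unique`."
]
},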
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Split the data into training and test sets\n",
"XTrain, XTest = X[:100], X[100:]\n",
"YTrain, YTest = Y[:100], Y[100:]\n",
"\n",
"# Initial parameter matrix\n",
"thetaTemp = np.ones(5).reshape(5,1)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\pawel\\anaconda3\\lib\\site-packages\\seaborn\\axisgrid.py:2071: UserWarning: The `size` parameter has been renamed to `height`; please update your code.\n",
" warnings.warn(msg, UserWarning)\n"
]
},
{
"data": {
"image/svg+xml": [
"\r\n",
"\r\n",
"\r\n",
"\r\n"
],
"text/plain": [
"