{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Uczenie maszynowe — zastosowania\n",
"# 3. Regresja logistyczna"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"**Uwaga**: Wbrew nazwie, *regresja* logistyczna jest algorytmem służącym do rozwiązywania problemów *klasyfikacji* (wcale nie problemów *regresji*!)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Do demonstracji metody regresji ligistycznej wykorzystamy klasyczny zbiór danych <a href=\"https://en.wikipedia.org/wiki/Iris_flower_data_set\">*Iris flower data set*</a>, składający się ze 150 przykładów wartości 4 cech dla 3 gatunków irysów (kosaćców)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### *Iris flower data set*\n",
"\n",
"* 150 przykładów\n",
"* 4 cechy\n",
"* 3 kategorie"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"| <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/5/56/Kosaciec_szczecinkowaty_Iris_setosa.jpg/450px-Kosaciec_szczecinkowaty_Iris_setosa.jpg\"> | <img style=\"float: right;\" src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/9/9f/Iris_virginica.jpg/736px-Iris_virginica.jpg\"> | <img style=\"float: right;\" src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/2/27/Blue_Flag%2C_Ottawa.jpg/600px-Blue_Flag%2C_Ottawa.jpg\"> |\n",
"| :--- | :--- | :--- |\n",
"| *Iris setosa* | *Iris virginica* | *Iris versicolor* |\n",
"| kosaciec szczecinkowy | kosaciec amerykański | kosaciec różnobarwny |\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"4 cechy:\n",
" * długość działek kielicha (*sepal length*, `sl`)\n",
" * szerokość działek kielicha (*sepal width*, `sw`)\n",
" * długość płatka (*petal length*, `pl`)\n",
" * szerokość płatka (*petal width*, `pw`)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 3.1. Dwuklasowa regresja logistyczna"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Zacznijmy od najprostszego przypadku:\n",
" * ograniczmy się do **2** klas\n",
" * ograniczmy się do **1** zmiennej"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"→ dwuklasowa regresja logistyczna jednej zmiennej"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Przydatne importy\n",
"\n",
"import numpy as np\n",
"import matplotlib\n",
"import matplotlib.pyplot as pl\n",
"import pandas\n",
"import ipywidgets as widgets\n",
"\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format = 'svg'\n",
"\n",
"from IPython.display import display, Math, Latex\n",
"\n",
"# Przydatne funkcje\n",
"\n",
"# Wyświetlanie macierzy w LaTeX-u\n",
"def LatexMatrix(matrix):\n",
" ltx = r'\\left[\\begin{array}'\n",
" m, n = matrix.shape\n",
" ltx += '{' + (\"r\" * n) + '}'\n",
" for i in range(m):\n",
" ltx += r\" & \".join([('%.4f' % j.item()) for j in matrix[i]]) + r\" \\\\ \"\n",
" ltx += r'\\end{array}\\right]'\n",
" return ltx\n",
"\n",
"# Hipoteza (wersja macierzowa)\n",
"def hMx(theta, X):\n",
" return X * theta\n",
"\n",
"# Wykres danych (wersja macierzowa)\n",
"def regdotsMx(X, y, xlabel, ylabel): \n",
" fig = pl.figure(figsize=(16*.6, 9*.6))\n",
" ax = fig.add_subplot(111)\n",
" fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
" ax.scatter([X[:, 1]], [y], c='r', s=50, label='Dane')\n",
" \n",
" ax.set_xlabel(xlabel)\n",
" ax.set_ylabel(ylabel)\n",
" ax.margins(.05, .05)\n",
" pl.ylim(y.min() - 1, y.max() + 1)\n",
" pl.xlim(np.min(X[:, 1]) - 1, np.max(X[:, 1]) + 1)\n",
" return fig\n",
"\n",
"# Wykres krzywej regresji (wersja macierzowa)\n",
"def reglineMx(fig, fun, theta, X):\n",
" ax = fig.axes[0]\n",
" x0 = np.min(X[:, 1]) - 1.0\n",
" x1 = np.max(X[:, 1]) + 1.0\n",
" L = [x0, x1]\n",
" LX = np.matrix([1, x0, 1, x1]).reshape(2, 2)\n",
" ax.plot(L, fun(theta, LX), linewidth='2',\n",
" label=(r'$y={theta0:.2}{op}{theta1:.2}x$'.format(\n",
" theta0=float(theta[0][0]),\n",
" theta1=(float(theta[1][0]) if theta[1][0] >= 0 else float(-theta[1][0])),\n",
" op='+' if theta[1][0] >= 0 else '-')))\n",
"\n",
"# Legenda wykresu\n",
"def legend(fig):\n",
" ax = fig.axes[0]\n",
" handles, labels = ax.get_legend_handles_labels()\n",
" # try-except block is a fix for a bug in Poly3DCollection\n",
" try:\n",
" fig.legend(handles, labels, fontsize='15', loc='lower right')\n",
" except AttributeError:\n",
" pass\n",
"\n",
"# Wersja macierzowa funkcji kosztu\n",
"def JMx(theta,X,y):\n",
" m = len(y)\n",
" J = 1.0 / (2.0 * m) * ((X * theta - y).T * ( X * theta - y))\n",
" return J.item()\n",
"\n",
"# Wersja macierzowa gradientu funkcji kosztu\n",
"def dJMx(theta,X,y):\n",
" return 1.0 / len(y) * (X.T * (X * theta - y)) \n",
"\n",
"# Implementacja algorytmu gradientu prostego za pomocą numpy i macierzy\n",
"def GDMx(fJ, fdJ, theta, X, y, alpha=0.1, eps=10**-3):\n",
" current_cost = fJ(theta, X, y)\n",
" logs = [[current_cost, theta]]\n",
" while True:\n",
" theta = theta - alpha * fdJ(theta, X, y) # implementacja wzoru\n",
" current_cost, prev_cost = fJ(theta, X, y), current_cost\n",
" if current_cost > 10000:\n",
" break\n",
" if abs(prev_cost - current_cost) <= eps:\n",
" break\n",
" logs.append([current_cost, theta]) \n",
" return theta, logs\n",
"\n",
"thetaStartMx = np.matrix([0, 0]).reshape(2, 1)\n",
"\n",
"# Funkcja, która rysuje próg\n",
"def threshold(fig, theta):\n",
" x_thr = (0.5 - theta.item(0)) / theta.item(1)\n",
" ax = fig.axes[0]\n",
" ax.plot([x_thr, x_thr], [-1, 2],\n",
" color='orange', linestyle='dashed',\n",
" label=u'próg: $x={:.2F}$'.format(x_thr))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" sl sw pl pw Gatunek\n",
"0 5.2 3.4 1.4 0.2 Iris-setosa\n",
"1 5.1 3.7 1.5 0.4 Iris-setosa\n",
"2 6.7 3.1 5.6 2.4 Iris-virginica\n",
"3 6.5 3.2 5.1 2.0 Iris-virginica\n",
"4 4.9 2.5 4.5 1.7 Iris-virginica\n",
"5 6.0 2.7 5.1 1.6 Iris-versicolor\n"
]
}
],
"source": [
"# Wczytanie pełnych (oryginalnych) danych\n",
"\n",
"data_iris = pandas.read_csv('iris.csv')\n",
"print(data_iris[:6])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" dł. płatka Iris setosa?\n",
"0 1.4 1\n",
"1 1.5 1\n",
"2 5.6 0\n",
"3 5.1 0\n",
"4 4.5 0\n",
"5 5.1 0\n"
]
}
],
"source": [
"# Ograniczenie danych do 2 klas i 1 cechy\n",
"\n",
"data_iris_setosa = pandas.DataFrame()\n",
"data_iris_setosa['dł. płatka'] = data_iris['pl'] # \"pl\" oznacza \"petal length\"\n",
"data_iris_setosa['Iris setosa?'] = data_iris['Gatunek'].apply(lambda x: 1 if x=='Iris-setosa' else 0)\n",
"print(data_iris_setosa[:6])"
]
},
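{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"If `iris.csv` is not at hand, the same two-class, one-feature data can be built from the copy of the Iris data bundled with scikit-learn. This is a sketch assuming scikit-learn is installed: `load_iris()` stores the features in the order sepal length, sepal width, petal length, petal width, and encodes *Iris setosa* as target `0`. Its row order differs from `iris.csv`, so the first rows will not match the table above.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"\n",
"iris = load_iris()\n",
"x_pl = iris.data[:, 2]                     # petal length (third column)\n",
"y_setosa = (iris.target == 0).astype(int)  # 1 for Iris setosa, 0 otherwise\n",
"\n",
"# Design matrix with a column of ones for the intercept term\n",
"X = np.column_stack([np.ones_like(x_pl), x_pl])\n",
"print(X.shape, y_setosa.sum())  # (150, 2) 50\n",
"```"
]
},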
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Przygotowanie danych\n",
"m, n_plus_1 = data_iris_setosa.values.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data_iris_setosa.values[:, 0:n].reshape(m, n)\n",
"\n",
"XMx3 = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n_plus_1)\n",
"yMx3 = np.matrix(data_iris_setosa.values[:, 1]).reshape(m, 1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Created with matplotlib (https://matplotlib.org/) -->\n<svg height=\"360.619219pt\" version=\"1.1\" viewBox=\"0 0 673.940937 360.619219\" width=\"673.940937pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n <defs>\n <style type=\"text/css\">\n*{stroke-linecap:butt;stroke-linejoin:round;}\n </style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 360.619219 \nL 673.940937 360.619219 \nL 673.940937 0 \nL 0 0 \nz\n\" style=\"fill:none;\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 52.160938 322.039219 \nL 605.120937 322.039219 \nL 605.120937 10.999219 \nL 52.160938 10.999219 \nz\n\" style=\"fill:#ffffff;\"/>\n </g>\n <g id=\"PathCollection_1\">\n <defs>\n <path d=\"M 0 3.535534 \nC 0.937635 3.535534 1.836992 3.163008 2.5 2.5 \nC 3.163008 1.836992 3.535534 0.937635 3.535534 0 \nC 3.535534 -0.937635 3.163008 -1.836992 2.5 -2.5 \nC 1.836992 -3.163008 0.937635 -3.535534 0 -3.535534 \nC -0.937635 -3.535534 -1.836992 -3.163008 -2.5 -2.5 \nC -3.163008 -1.836992 -3.535534 -0.937635 -3.535534 0 \nC -3.535534 0.937635 -3.163008 1.836992 -2.5 2.5 \nC -1.836992 3.163008 -0.937635 3.535534 0 3.535534 \nz\n\" id=\"m02e6b61b87\" style=\"stroke:#ff0000;\"/>\n </defs>\n <g clip-path=\"url(#p0b6dbced90)\">\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"150.153849\" xlink:href=\"#m02e6b61b87\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"157.153343\" xlink:href=\"#m02e6b61b87\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"444.132583\" xlink:href=\"#m02e6b61b87\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"409.135115\" xlink:href=\"#m02e6b61b87\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"367.138153\" 
xlink:href=\"#m02e6b61b87\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"409.135115\" xlink:href=\"#m02e6b61b87\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"297.143216\" xlink:href=\"#m02e6b61b87\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"297.143216\" xlink:href=\"#m02e6b61b87\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"157.153343\" xlink:href=\"#m02e6b61b87\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"150.153849\" xlink:href=\"#m02e6b61b87\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"451.132077\" xlink:href=\"#m02e6b61b87\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"395.136127\" xlink:href=\"#m02e6b61b87\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"185.151317\" xlink:href=\"#m02e6b61b87\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"409.135115\" xlink:href=\"#m02e6b61b87\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"150.153849\" xlink:href=\"#m02e6b61b87\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"458.13157\" xlink:href=\"#m02e6b61b87\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"143.154355\" xlink:href=\"#m02e6b61b87\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"143.154355\" xlink:href=\"#m02e6b61b87\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"143.154355\" xlink:href=\"#m02e6b61b87\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"521.127013\" xlink:href=\"#m02e6b61b87\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"150.153849\" xlink:href=\"#m02e6b61b87\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"388.136634\" xlink:href=\"#m02e6b61b87\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"402.135621\" xlink:hre
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = regdotsMx(XMx3, yMx3, 'x', 'Iris setosa?')\n",
"legend(fig)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Próba zastosowania regresji liniowej do problemu klasyfikacji\n",
"\n",
"Najpierw z ciekawości sprawdźmy, co otrzymalibyśmy, gdybyśmy zastosowali regresję liniową do problemu klasyfikacji."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Przygotowanie danych\n",
"m, n_plus_1 = data_iris_setosa.values.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data_iris_setosa.values[:, 0:n].reshape(m, n)\n",
"\n",
"XMx3 = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n_plus_1)\n",
"yMx3 = np.matrix(data_iris_setosa.values[:, 1]).reshape(m, 1)\n",
"\n",
"# Regresja liniowa\n",
"theta_e3, logs3 = GDMx(JMx, dJMx, thetaStartMx, XMx3, yMx3, alpha=0.03, eps=0.000001)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Created with matplotlib (https://matplotlib.org/) -->\n<svg height=\"360.619219pt\" version=\"1.1\" viewBox=\"0 0 673.940937 360.619219\" width=\"673.940937pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n <defs>\n <style type=\"text/css\">\n*{stroke-linecap:butt;stroke-linejoin:round;}\n </style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 360.619219 \nL 673.940937 360.619219 \nL 673.940937 0 \nL 0 0 \nz\n\" style=\"fill:none;\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 52.160938 322.039219 \nL 605.120937 322.039219 \nL 605.120937 10.999219 \nL 52.160938 10.999219 \nz\n\" style=\"fill:#ffffff;\"/>\n </g>\n <g id=\"PathCollection_1\">\n <defs>\n <path d=\"M 0 3.535534 \nC 0.937635 3.535534 1.836992 3.163008 2.5 2.5 \nC 3.163008 1.836992 3.535534 0.937635 3.535534 0 \nC 3.535534 -0.937635 3.163008 -1.836992 2.5 -2.5 \nC 1.836992 -3.163008 0.937635 -3.535534 0 -3.535534 \nC -0.937635 -3.535534 -1.836992 -3.163008 -2.5 -2.5 \nC -3.163008 -1.836992 -3.535534 -0.937635 -3.535534 0 \nC -3.535534 0.937635 -3.163008 1.836992 -2.5 2.5 \nC -1.836992 3.163008 -0.937635 3.535534 0 3.535534 \nz\n\" id=\"m7ce270b550\" style=\"stroke:#ff0000;\"/>\n </defs>\n <g clip-path=\"url(#pdc2641a829)\">\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"150.153849\" xlink:href=\"#m7ce270b550\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"157.153343\" xlink:href=\"#m7ce270b550\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"444.132583\" xlink:href=\"#m7ce270b550\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"409.135115\" xlink:href=\"#m7ce270b550\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"367.138153\" 
xlink:href=\"#m7ce270b550\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"409.135115\" xlink:href=\"#m7ce270b550\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"297.143216\" xlink:href=\"#m7ce270b550\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"297.143216\" xlink:href=\"#m7ce270b550\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"157.153343\" xlink:href=\"#m7ce270b550\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"150.153849\" xlink:href=\"#m7ce270b550\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"451.132077\" xlink:href=\"#m7ce270b550\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"395.136127\" xlink:href=\"#m7ce270b550\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"185.151317\" xlink:href=\"#m7ce270b550\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"409.135115\" xlink:href=\"#m7ce270b550\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"150.153849\" xlink:href=\"#m7ce270b550\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"458.13157\" xlink:href=\"#m7ce270b550\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"143.154355\" xlink:href=\"#m7ce270b550\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"143.154355\" xlink:href=\"#m7ce270b550\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"143.154355\" xlink:href=\"#m7ce270b550\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"521.127013\" xlink:href=\"#m7ce270b550\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"150.153849\" xlink:href=\"#m7ce270b550\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"388.136634\" xlink:href=\"#m7ce270b550\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"402.135621\" xlink:hre
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = regdotsMx(XMx3, yMx3, 'x', 'Iris setosa?')\n",
"reglineMx(fig, hMx, theta_e3, XMx3)\n",
"legend(fig)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"A gdyby tak przyjąć, że klasyfikator zwraca $1$ dla $h(x) > 0.5$ i $0$ w przeciwnym przypadku?"
]
},
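{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"For a linear hypothesis $h(x) = \\theta_0 + \\theta_1 x$ with $\\theta_1 \\neq 0$, the point where it crosses the threshold $0.5$ follows from a one-line rearrangement; it is the value that the `threshold` helper draws as a dashed line:\n",
"\n",
"$$h(x) = 0.5 \\iff \\theta_0 + \\theta_1 x = 0.5 \\iff x = \\frac{0.5 - \\theta_0}{\\theta_1}$$\n",
"\n",
"If $\\theta_1 < 0$, as one would expect here since longer petals indicate *not* Iris setosa, the classifier returns $1$ for $x$ to the left of this point."
]
},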
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Created with matplotlib (https://matplotlib.org/) -->\n<svg height=\"360.619219pt\" version=\"1.1\" viewBox=\"0 0 673.940937 360.619219\" width=\"673.940937pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n <defs>\n <style type=\"text/css\">\n*{stroke-linecap:butt;stroke-linejoin:round;}\n </style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 360.619219 \nL 673.940937 360.619219 \nL 673.940937 0 \nL 0 0 \nz\n\" style=\"fill:none;\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 52.160938 322.039219 \nL 605.120937 322.039219 \nL 605.120937 10.999219 \nL 52.160938 10.999219 \nz\n\" style=\"fill:#ffffff;\"/>\n </g>\n <g id=\"PathCollection_1\">\n <defs>\n <path d=\"M 0 3.535534 \nC 0.937635 3.535534 1.836992 3.163008 2.5 2.5 \nC 3.163008 1.836992 3.535534 0.937635 3.535534 0 \nC 3.535534 -0.937635 3.163008 -1.836992 2.5 -2.5 \nC 1.836992 -3.163008 0.937635 -3.535534 0 -3.535534 \nC -0.937635 -3.535534 -1.836992 -3.163008 -2.5 -2.5 \nC -3.163008 -1.836992 -3.535534 -0.937635 -3.535534 0 \nC -3.535534 0.937635 -3.163008 1.836992 -2.5 2.5 \nC -1.836992 3.163008 -0.937635 3.535534 0 3.535534 \nz\n\" id=\"mdc4dda7440\" style=\"stroke:#ff0000;\"/>\n </defs>\n <g clip-path=\"url(#pa55f031633)\">\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"150.153849\" xlink:href=\"#mdc4dda7440\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"157.153343\" xlink:href=\"#mdc4dda7440\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"444.132583\" xlink:href=\"#mdc4dda7440\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"409.135115\" xlink:href=\"#mdc4dda7440\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"367.138153\" 
xlink:href=\"#mdc4dda7440\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"409.135115\" xlink:href=\"#mdc4dda7440\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"297.143216\" xlink:href=\"#mdc4dda7440\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"297.143216\" xlink:href=\"#mdc4dda7440\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"157.153343\" xlink:href=\"#mdc4dda7440\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"150.153849\" xlink:href=\"#mdc4dda7440\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"451.132077\" xlink:href=\"#mdc4dda7440\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"395.136127\" xlink:href=\"#mdc4dda7440\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"185.151317\" xlink:href=\"#mdc4dda7440\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"409.135115\" xlink:href=\"#mdc4dda7440\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"150.153849\" xlink:href=\"#mdc4dda7440\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"458.13157\" xlink:href=\"#mdc4dda7440\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"143.154355\" xlink:href=\"#mdc4dda7440\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"143.154355\" xlink:href=\"#mdc4dda7440\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"143.154355\" xlink:href=\"#mdc4dda7440\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"521.127013\" xlink:href=\"#mdc4dda7440\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"150.153849\" xlink:href=\"#mdc4dda7440\" y=\"114.679219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"388.136634\" xlink:href=\"#mdc4dda7440\" y=\"218.359219\"/>\n <use style=\"fill:#ff0000;stroke:#ff0000;\" x=\"402.135621\" xlink:hre
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = regdotsMx(XMx3, yMx3, 'x', 'Iris setosa?')\n",
"theta_e3, logs3 = GDMx(JMx, dJMx, thetaStartMx, XMx3, yMx3, alpha=0.03, eps=0.000001)\n",
"reglineMx(fig, hMx, theta_e3, XMx3)\n",
"threshold(fig, theta_e3) # pomarańczowa linia oznacza granicę między klasą \"1\" a klasą \"0\" wyznaczoną przez próg \"h(x) = 0.5\"\n",
"legend(fig)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
" * Krzywa regresji liniowej jest niezbyt dopasowana do danych klasyfikacyjnych.\n",
" * Zastosowanie progu $y = 0.5$ nie zawsze pomaga uzyskać sensowny rezultat.\n",
" * $h(x)$ może przyjmować wartości mniejsze od $0$ i większe od $1$ jak interpretować takie wyniki?\n",
"\n",
"Wniosek: w przypadku problemów klasyfikacyjnych regresja liniowa nie wydaje się najlepszym rozwiązaniem."
]
},
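{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The last point is easy to reproduce on a tiny hypothetical dataset (not the Iris data): a least-squares line fitted to four points labelled $0, 0, 1, 1$ already predicts values outside $[0, 1]$. The sketch uses `np.linalg.lstsq` rather than the `GDMx` helper, so it obtains the exact least-squares solution directly.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical toy data: one feature, binary labels\n",
"x = np.array([0.0, 1.0, 2.0, 3.0])\n",
"y = np.array([0.0, 0.0, 1.0, 1.0])\n",
"\n",
"# Design matrix [1, x]; theta = (theta_0, theta_1) minimises the squared error\n",
"X = np.column_stack([np.ones_like(x), x])\n",
"theta, *_ = np.linalg.lstsq(X, y, rcond=None)\n",
"\n",
"h = X @ theta\n",
"print(theta)  # approximately [-0.1  0.4]\n",
"print(h)      # approximately [-0.1  0.3  0.7  1.1]: below 0 and above 1\n",
"```"
]
},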
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Wprowadźmy zatem pewne modyfikacje do naszego modelu."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Zdefiniujmy następującą funkcję, którą będziemy nazywać funkcją *logistyczną* (albo *sigmoidalną*):"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"**Funkcja logistyczna (sigmoidalna)**:\n",
"\n",
"$$g(x) = \\dfrac{1}{1+e^{-x}}$$"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [],
"source": [
"# Funkjca logistycza\n",
"\n",
"def logistic(x):\n",
" return 1.0 / (1.0 + np.exp(-x))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"def plot_logistic():\n",
" x = np.linspace(-5,5,200)\n",
" y = logistic(x)\n",
"\n",
" fig = plt.figure(figsize=(7,5))\n",
" ax = fig.add_subplot(111)\n",
" plt.ylim(-.1,1.1)\n",
" ax.plot(x, y, linewidth='2')"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Wykres funkcji logistycznej $g(x) = \\dfrac{1}{1+e^{-x}}$:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [
{
"data": {
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Created with matplotlib (https://matplotlib.org/) -->\n<svg height=\"302.878125pt\" version=\"1.1\" viewBox=\"0 0 427.903125 302.878125\" width=\"427.903125pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n <defs>\n <style type=\"text/css\">\n*{stroke-linecap:butt;stroke-linejoin:round;}\n </style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 302.878125 \nL 427.903125 302.878125 \nL 427.903125 0 \nL 0 0 \nz\n\" style=\"fill:none;\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 30.103125 279 \nL 420.703125 279 \nL 420.703125 7.2 \nL 30.103125 7.2 \nz\n\" style=\"fill:#ffffff;\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <defs>\n <path d=\"M 0 0 \nL 0 3.5 \n\" id=\"m4f35b76038\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n </defs>\n <g>\n <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"83.366761\" xlink:href=\"#m4f35b76038\" y=\"279\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 4 -->\n <defs>\n <path d=\"M 10.59375 35.5 \nL 73.1875 35.5 \nL 73.1875 27.203125 \nL 10.59375 27.203125 \nz\n\" id=\"DejaVuSans-8722\"/>\n <path d=\"M 37.796875 64.3125 \nL 12.890625 25.390625 \nL 37.796875 25.390625 \nz\nM 35.203125 72.90625 \nL 47.609375 72.90625 \nL 47.609375 25.390625 \nL 58.015625 25.390625 \nL 58.015625 17.1875 \nL 47.609375 17.1875 \nL 47.609375 0 \nL 37.796875 0 \nL 37.796875 17.1875 \nL 4.890625 17.1875 \nL 4.890625 26.703125 \nz\n\" id=\"DejaVuSans-52\"/>\n </defs>\n <g transform=\"translate(75.995668 293.598437)scale(0.1 -0.1)\">\n <use xlink:href=\"#DejaVuSans-8722\"/>\n <use x=\"83.789062\" xlink:href=\"#DejaVuSans-52\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_2\">\n <g>\n <use style=\"stroke:#000000;stroke-width:0.8;\" 
x=\"154.384943\" xlink:href=\"#m4f35b76038\" y=\"279\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 2 -->\n <defs>\n <path d=\"M 19.1875 8.296875 \nL 53.609375 8.296875 \nL 53.609375 0 \nL 7.328125 0 \nL 7.328125 8.296875 \nQ 12.9375 14.109375 22.625 23.890625 \nQ 32.328125 33.6875 34.8125 36.53125 \nQ 39.546875 41.84375 41.421875 45.53125 \nQ 43.3125 49.21875 43.3125 52.78125 \nQ 43.3125 58.59375 39.234375 62.25 \nQ 35.15625 65.921875 28.609375 65.921875 \nQ 23.96875 65.921875 18.8125 64.3125 \nQ 13.671875 62.703125 7.8125 59.421875 \nL 7.8125 69.390625 \nQ 13.765625 71.78125 18.9375 73 \nQ 24.125 74.21875 28.421875 74.21875 \nQ 39.75 74.21875 46.484375 68.546875 \nQ 53.21875 62.890625 53.21875 53.421875 \nQ 53.21875 48.921875 51.53125 44.890625 \nQ 49.859375 40.875 45.40625 35.40625 \nQ 44.1875 33.984375 37.640625 27.21875 \nQ 31.109375 20.453125 19.1875 8.296875 \nz\n\" id=\"DejaVuSans-50\"/>\n </defs>\n <g transform=\"translate(147.013849 293.598437)scale(0.1 -0.1)\">\n <use xlink:href=\"#DejaVuSans-8722\"/>\n <use x=\"83.789062\" xlink:href=\"#DejaVuSans-50\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_3\">\n <g id=\"line2d_3\">\n <g>\n <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"225.403125\" xlink:href=\"#m4f35b76038\" y=\"279\"/>\n </g>\n </g>\n <g id=\"text_3\">\n <!-- 0 -->\n <defs>\n <path d=\"M 31.78125 66.40625 \nQ 24.171875 66.40625 20.328125 58.90625 \nQ 16.5 51.421875 16.5 36.375 \nQ 16.5 21.390625 20.328125 13.890625 \nQ 24.171875 6.390625 31.78125 6.390625 \nQ 39.453125 6.390625 43.28125 13.890625 \nQ 47.125 21.390625 47.125 36.375 \nQ 47.125 51.421875 43.28125 58.90625 \nQ 39.453125 66.40625 31.78125 66.40625 \nz\nM 31.78125 74.21875 \nQ 44.046875 74.21875 50.515625 64.515625 \nQ 56.984375 54.828125 56.984375 36.375 \nQ 56.984375 17.96875 50.515625 8.265625 \nQ 44.046875 -1.421875 31.78125 -1.421
"text/plain": [
"<Figure size 504x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_logistic()"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Funkcja logistyczna przekształca zbiór liczb rzeczywistych $\\mathbb{R}$ w przedział otwarty $(0, 1)$."
]
},
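{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A quick numerical check of properties that the rest of the lecture relies on: $g(0) = 0.5$, the symmetry $g(-x) = 1 - g(x)$, and the derivative $g'(x) = g(x)\\,(1 - g(x))$, which is what gives the logistic cost its simple gradient. This is a sketch; the grid and the finite-difference step `eps` are arbitrary choices.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def g(x):\n",
"    # logistic (sigmoid) function, same formula as the logistic function defined earlier\n",
"    return 1.0 / (1.0 + np.exp(-x))\n",
"\n",
"x = np.linspace(-5, 5, 11)\n",
"eps = 1e-6  # finite-difference step (arbitrary choice)\n",
"\n",
"# symmetry: g(-x) = 1 - g(x)\n",
"symmetric = np.allclose(g(-x), 1 - g(x))\n",
"# central finite difference vs. the closed-form derivative g(x) * (1 - g(x))\n",
"numeric_deriv = (g(x + eps) - g(x - eps)) / (2 * eps)\n",
"deriv_ok = np.allclose(numeric_deriv, g(x) * (1 - g(x)), atol=1e-8)\n",
"\n",
"print(g(0.0), symmetric, deriv_ok)  # 0.5 True True\n",
"```"
]
},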
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Funkcja regresji logistycznej dla pojedynczego przykładu o cechach wyrażonych wektorem $x$:\n",
"\n",
"$$h_\\theta(x) = g(\\theta^T \\, x) = \\dfrac{1}{1 + e^{-\\theta^T x}}$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Dla całej macierzy cech $X$:\n",
"\n",
"$$h_\\theta(X) = g(X \\, \\theta) = \\dfrac{1}{1 + e^{-X \\theta}}$$"
]
},
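{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The matrix form simply applies $g$ to each entry of $X\\theta$, i.e. it computes $h_\\theta(x)$ for every row of $X$ at once. The sketch below checks this with plain NumPy arrays (the notebook itself uses `np.matrix`, where `*` is matrix multiplication; with arrays the operator is `@`). The parameters are the values printed by gradient descent later in this notebook, and the two petal lengths (1.4 and 5.6) come from the table above; both are used only for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def g(z):\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"theta = np.array([5.02530461, -1.99174803])  # [theta_0, theta_1]\n",
"X = np.array([[1.0, 1.4],   # each row: [1, petal length]\n",
"              [1.0, 5.6]])\n",
"\n",
"h_matrix = g(X @ theta)                           # all examples at once\n",
"h_loop = np.array([g(theta @ x_i) for x_i in X])  # one example at a time\n",
"\n",
"print(np.allclose(h_matrix, h_loop))  # True\n",
"print(h_matrix > 0.5)                 # [ True False]\n",
"```"
]
},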
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"# Funkcja regresji logistcznej\n",
"def h(theta, X):\n",
" return 1.0/(1.0 + np.exp(-X * theta))"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Funkcja kosztu dla regresji logistycznej:\n",
"\n",
"$$J(\\theta) = -\\dfrac{1}{m} \\left( \\sum_{i=1}^{m} y^{(i)} \\log h_\\theta( x^{(i)} ) + \\left( 1 - y^{(i)} \\right) \\log \\left( 1 - h_\\theta (x^{(i)}) \\right) \\right)$$"
]
},
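{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"A vectorized sketch of this cost with plain NumPy arrays. Clipping $h_\\theta$ away from exactly $0$ and $1$ is an addition that the notebook's `J` function does not make: it keeps $\\log$ finite when the sigmoid saturates in floating point, and the bound `1e-12` is arbitrary. With $\\theta = 0$ every prediction equals $0.5$, so the cost is $\\log 2 \\approx 0.693$ whatever the labels are.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sigmoid(z):\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"def cross_entropy_cost(theta, X, y, clip=1e-12):\n",
"    # clip keeps log() finite when h is numerically 0 or 1\n",
"    h = np.clip(sigmoid(X @ theta), clip, 1 - clip)\n",
"    return -np.mean(y * np.log(h) + (1 - y) * np.log(1 - h))\n",
"\n",
"# Petal lengths and labels of the first four rows shown earlier\n",
"X = np.array([[1.0, 1.4], [1.0, 1.5], [1.0, 5.6], [1.0, 5.1]])\n",
"y = np.array([1.0, 1.0, 0.0, 0.0])\n",
"\n",
"cost_zero = cross_entropy_cost(np.zeros(2), X, y)\n",
"cost_fit = cross_entropy_cost(np.array([5.02530461, -1.99174803]), X, y)\n",
"print(cost_zero)  # log 2, about 0.693\n",
"print(cost_fit)   # much smaller: all four points are on the correct side\n",
"```"
]
},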
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Gradient dla regresji logistycznej (wersja macierzowa):\n",
"\n",
"$$\\nabla J(\\theta) = \\frac{1}{|\\vec y|} X^T \\left( h_\\theta(X) - \\vec y \\right)$$\n",
"\n",
"(Jedyna różnica między gradientem dla regresji logistycznej a gradientem dla regresji liniowej to postać $h_\\theta$)."
]
},
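{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"To see that the formula above is indeed the gradient of the cost $J(\\theta)$, the sketch below compares it with a central finite-difference approximation at an arbitrary point $\\theta$. It uses plain NumPy arrays and the first four rows of the data shown earlier; the test point and step size are arbitrary choices.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sigmoid(z):\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"def cost(theta, X, y):\n",
"    h = sigmoid(X @ theta)\n",
"    return -np.mean(y * np.log(h) + (1 - y) * np.log(1 - h))\n",
"\n",
"def grad(theta, X, y):\n",
"    # (1/m) * X^T (h_theta(X) - y)\n",
"    return X.T @ (sigmoid(X @ theta) - y) / len(y)\n",
"\n",
"X = np.array([[1.0, 1.4], [1.0, 1.5], [1.0, 5.6], [1.0, 5.1]])\n",
"y = np.array([1.0, 1.0, 0.0, 0.0])\n",
"theta = np.array([0.5, -0.3])  # arbitrary test point\n",
"eps = 1e-6\n",
"\n",
"# central difference along each coordinate direction e (rows of the identity matrix)\n",
"numeric = np.array([\n",
"    (cost(theta + eps * e, X, y) - cost(theta - eps * e, X, y)) / (2 * eps)\n",
"    for e in np.eye(len(theta))\n",
"])\n",
"print(np.allclose(numeric, grad(theta, X, y), atol=1e-7))  # True\n",
"```"
]
},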
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"# Funkcja kosztu dla regresji logistycznej\n",
"def J(h, theta, X, y):\n",
" m = len(y)\n",
" h_val = h(theta, X)\n",
" s1 = np.multiply(y, np.log(h_val))\n",
" s2 = np.multiply((1 - y), np.log(1 - h_val))\n",
" return -np.sum(s1 + s2, axis=0) / m"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [],
"source": [
"# Gradient dla regresji logistycznej\n",
"def dJ(h, theta, X, y):\n",
" return 1.0 / len(y) * (X.T * (h(theta, X) - y))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"# Metoda gradientu prostego dla regresji logistycznej\n",
"def GD(h, fJ, fdJ, theta, X, y, alpha=0.01, eps=10**-3, maxSteps=10000):\n",
" errorCurr = fJ(h, theta, X, y)\n",
" errors = [[errorCurr, theta]]\n",
" while True:\n",
" # oblicz nowe theta\n",
" theta = theta - alpha * fdJ(h, theta, X, y)\n",
" # raportuj poziom błędu\n",
" errorCurr, errorPrev = fJ(h, theta, X, y), errorCurr\n",
" # kryteria stopu\n",
" if abs(errorPrev - errorCurr) <= eps:\n",
" break\n",
" if len(errors) > maxSteps:\n",
" break\n",
" errors.append([errorCurr, theta]) \n",
" return theta, errors"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"error = [[0.05755617]]\n",
"theta = [[ 5.02530461]\n",
" [-1.99174803]]\n"
]
}
],
"source": [
"# Uruchomienie metody gradientu prostego dla regresji logistycznej\n",
"thetaBest, errors = GD(h, J, dJ, thetaStartMx, XMx3, yMx3, \n",
" alpha=0.1, eps=10**-7, maxSteps=1000)\n",
"print(\"error =\", errors[-1][0])\n",
"print(\"theta =\", thetaBest)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Logistic regression function (scalar version)\n",
"def scalar_logistic_regression_function(theta, x):\n",
"    return 1.0 / (1.0 + np.exp(-(theta.item(0) + theta.item(1) * x)))\n",
"\n",
"# Draw the decision threshold as a vertical dashed line\n",
"def threshold_val(fig, x_thr):\n",
"    ax = fig.axes[0]\n",
"    ax.plot([x_thr, x_thr], [-1, 2],\n",
"            color='orange', linestyle='dashed',\n",
"            label='threshold: $x={:.2F}$'.format(x_thr))\n",
"\n",
"# Plot the logistic regression curve\n",
"def logistic_regline(fig, theta, X):\n",
"    ax = fig.axes[0]\n",
"    x0 = np.min(X[:, 1]) - 1.0\n",
"    x1 = np.max(X[:, 1]) + 1.0\n",
"    Arg = np.arange(x0, x1, 0.1)\n",
"    Val = scalar_logistic_regression_function(theta, Arg)\n",
"    ax.plot(Arg, Val, linewidth=2)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 691.2x388.8 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = regdotsMx(XMx3, yMx3, xlabel='x', ylabel='Iris setosa?')\n",
"logistic_regline(fig, thetaBest, XMx3)\n",
"threshold_val(fig, 2.5)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"We interpret the value $h_\\theta(x)$ as the probability that the target variable takes the positive value, i.e. $y = 1$:\n",
"\n",
"$$ h_\\theta(x) = P(y = 1 \\, | \\, x; \\theta) $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"If $h_\\theta(x) > 0.5$, we predict $y = 1$ for that $x$.\n",
"Otherwise, we predict $y = 0$."
]
},
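{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Since the logistic function takes the value 0.5 exactly at 0, the rule $h_\\theta(x) > 0.5$ is equivalent to $\\theta_0 + \\theta_1 x > 0$; with a single feature, the decision boundary is therefore the point $x = -\\theta_0 / \\theta_1$. The sketch below plugs in the parameter values printed earlier by gradient descent for the one-feature model (copied by hand, so treat them as an assumption). It gives $x \\approx 2.52$, consistent with the threshold $x = 2.5$ passed to `threshold_val` above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# parameters copied by hand from the earlier GD output for the one-feature model\n",
"theta0_demo, theta1_demo = 5.02530461, -1.99174803\n",
"\n",
"# h(x) = 0.5  <=>  theta0 + theta1 * x = 0\n",
"x_boundary_demo = -theta0_demo / theta1_demo\n",
"print(round(x_boundary_demo, 3))\n",
"\n",
"# check: the logistic function is (numerically) 0.5 at the boundary\n",
"print(1.0 / (1.0 + np.exp(-(theta0_demo + theta1_demo * x_boundary_demo))))"
]
},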
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Why can the value of the logistic regression function be interpreted as a probability?\n",
"\n",
"You can read about it in external sources, e.g. https://towardsdatascience.com/logit-of-logistic-regression-understanding-the-fundamentals-f384152a33d1"
]
},
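{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The key observation is that the logistic function is the inverse of the *logit* (log-odds) function: if $p = h_\\theta(x)$, then $\\log \\frac{p}{1 - p} = \\theta^T x$. Logistic regression thus models the log-odds of $y = 1$ as a linear function of $x$, while $h_\\theta(x)$ itself always lies in $(0, 1)$. A short numerical check on arbitrary example values of $\\theta^T x$ (the array `z_demo` is hypothetical):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"z_demo = np.array([-3.0, -0.5, 0.0, 1.2, 4.0])    # example values of theta^T x\n",
"p_demo = 1.0 / (1.0 + np.exp(-z_demo))             # h_theta(x), always in (0, 1)\n",
"log_odds_demo = np.log(p_demo / (1.0 - p_demo))    # logit(p) = log-odds\n",
"\n",
"print(p_demo)\n",
"print(np.allclose(log_odds_demo, z_demo))          # the logit undoes the logistic function"
]
},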
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Binary logistic regression: more features"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"How should we proceed if there is more than one feature $x$?\n",
"\n",
"Let us now take all the features in the *Iris* data set."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"   petal length  petal width  sepal length  sepal width  Iris setosa?\n",
"0           1.4          0.2           5.2          3.4             1\n",
"1           1.5          0.4           5.1          3.7             1\n",
"2           5.6          2.4           6.7          3.1             0\n",
"3           5.1          2.0           6.5          3.2             0\n",
"4           4.5          1.7           4.9          2.5             0\n",
"5           5.1          1.6           6.0          2.7             0\n"
]
}
],
"source": [
"data_iris_setosa_multi = pandas.DataFrame()\n",
"data_iris_setosa_multi['petal length'] = data_iris['pl']  # \"pl\" stands for \"petal length\"\n",
"data_iris_setosa_multi['petal width'] = data_iris['pw']  # \"pw\" stands for \"petal width\"\n",
"data_iris_setosa_multi['sepal length'] = data_iris['sl']  # \"sl\" stands for \"sepal length\"\n",
"data_iris_setosa_multi['sepal width'] = data_iris['sw']  # \"sw\" stands for \"sepal width\"\n",
"# 'Gatunek' is the species column of iris.csv\n",
"data_iris_setosa_multi['Iris setosa?'] = data_iris['Gatunek'].apply(lambda x: 1 if x == 'Iris-setosa' else 0)\n",
"print(data_iris_setosa_multi[:6])"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import seaborn"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[1. 1.4 0.2 5.2 3.4]\n",
" [1. 1.5 0.4 5.1 3.7]\n",
" [1. 5.6 2.4 6.7 3.1]\n",
" [1. 5.1 2. 6.5 3.2]\n",
" [1. 4.5 1.7 4.9 2.5]\n",
" [1. 5.1 1.6 6. 2.7]]\n",
"[[1.]\n",
" [1.]\n",
" [0.]\n",
" [0.]\n",
" [0.]\n",
" [0.]]\n"
]
}
],
"source": [
"# Data preparation\n",
"m, n_plus_1 = data_iris_setosa_multi.values.shape\n",
"n = n_plus_1 - 1\n",
"Xn = data_iris_setosa_multi.values[:, 0:n].reshape(m, n)\n",
"\n",
"XMx4 = np.matrix(np.concatenate((np.ones((m, 1)), Xn), axis=1)).reshape(m, n_plus_1)\n",
"yMx4 = np.matrix(data_iris_setosa_multi.values[:, n]).reshape(m, 1)\n",
"\n",
"print(XMx4[:6])\n",
"print(yMx4[:6])"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"# Split the data into a training set and a test set\n",
"XTrain, XTest = XMx4[:100], XMx4[100:]\n",
"yTrain, yTest = yMx4[:100], yMx4[100:]\n",
"\n",
"# Initial parameter matrix (a 5x1 column of ones)\n",
"thetaTemp = np.ones(5).reshape(5,1)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"error = [[0.006797]]\n",
"theta = [[ 1.11414027]\n",
" [-2.89324615]\n",
" [-0.66543637]\n",
" [ 0.14887292]\n",
" [ 2.13284493]]\n"
]
}
],
"source": [
"thetaBest, errors = GD(h, J, dJ, thetaTemp, XTrain, yTrain, \n",
" alpha=0.1, eps=10**-7, maxSteps=1000)\n",
"print(\"error =\", errors[-1][0])\n",
"print(\"theta =\", thetaBest)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Decision function of logistic regression\n",
"\n",
"The decision function tells us when our algorithm predicts $y = 1$ and when it predicts $y = 0$:\n",
"\n",
"$$ c = \\left\\{ \n",
"\\begin{array}{ll}\n",
"1, & \\mbox{if } P(y=1 \\, | \\, x; \\theta) > 0.5 \\\\\n",
"0, & \\mbox{otherwise}\n",
"\\end{array}\\right.\n",
"$$\n",
"\n",
"$$ P(y=1 \\,| \\, x; \\theta) = h_\\theta(x) $$"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"theta = [[ 1.11414027]\n",
" [-2.89324615]\n",
" [-0.66543637]\n",
" [ 0.14887292]\n",
" [ 2.13284493]]\n",
"x0 = [[1. 6.3 1.8 7.3 2.9]]\n",
"h(x0) = 1.6061436959824844e-05\n",
"c(x0) = (0, 1.6061436959824844e-05) \n",
"\n"
]
}
],
"source": [
"def classifyBi(theta, X):\n",
"    prob = h(theta, X).item()\n",
"    return (1, prob) if prob > 0.5 else (0, prob)\n",
"\n",
"print(\"theta =\", thetaBest)\n",
"print(\"x0 =\", XTest[0])\n",
"print(\"h(x0) =\", h(thetaBest, XTest[0]).item())\n",
"print(\"c(x0) =\", classifyBi(thetaBest, XTest[0]), \"\\n\")"
]
},
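{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Because the logistic function is increasing and equals 0.5 at 0, the condition $h_\\theta(x) > 0.5$ is equivalent to $\\theta^T x > 0$, and the decision function can be applied to many examples at once. The sketch below uses plain NumPy arrays; the parameter values and the first row (`XTest[0]`) are copied by hand from the output above, and the second row is the first example of the data set (an *Iris setosa*). The first probability should come out at about 1.6e-05, as in the output above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# parameters and the first test example copied by hand from the outputs above\n",
"theta_demo = np.array([1.11414027, -2.89324615, -0.66543637, 0.14887292, 2.13284493])\n",
"X_demo = np.array([\n",
"    [1.0, 6.3, 1.8, 7.3, 2.9],   # XTest[0]\n",
"    [1.0, 1.4, 0.2, 5.2, 3.4],   # first example of the data set (Iris setosa)\n",
"])\n",
"\n",
"scores_demo = X_demo @ theta_demo                 # theta^T x for every row\n",
"probs_demo = 1.0 / (1.0 + np.exp(-scores_demo))   # h_theta(x) for every row\n",
"\n",
"# h(x) > 0.5 and theta^T x > 0 yield the same decisions\n",
"print(probs_demo)\n",
"print((probs_demo > 0.5).astype(int), (scores_demo > 0).astype(int))"
]
},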
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Let us now compute the accuracy of the model (more on this topic in the next lecture, which is devoted to evaluation methods)."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 <=> 0 -- prob: 0.0\n",
"1 <=> 1 -- prob: 0.9816\n",
"0 <=> 0 -- prob: 0.0001\n",
"0 <=> 0 -- prob: 0.0005\n",
"0 <=> 0 -- prob: 0.0001\n",
"1 <=> 1 -- prob: 0.9936\n",
"0 <=> 0 -- prob: 0.0059\n",
"0 <=> 0 -- prob: 0.0992\n",
"0 <=> 0 -- prob: 0.0001\n",
"0 <=> 0 -- prob: 0.0001\n",
"\n",
"Accuracy: 1.0\n"
]
}
],
"source": [
"acc = 0.0\n",
"for i in range(len(yTest)):\n",
"    cls, prob = classifyBi(thetaBest, XTest[i])\n",
"    if i < 10:\n",
"        print(int(yTest[i].item()), \"<=>\", cls, \"-- prob:\", round(prob, 4))\n",
"    acc += cls == yTest[i].item()\n",
"\n",
"print(\"\\nAccuracy:\", acc / len(XTest))"
]
},
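{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The loop above can also be written in vectorized form: apply the decision function to all examples at once and take the mean of the correct predictions. A minimal sketch on made-up labels and probabilities (the arrays `y_true_demo` and `probs_demo` are hypothetical, not the notebook's data):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# hypothetical true labels and predicted probabilities P(y = 1 | x)\n",
"y_true_demo = np.array([0, 1, 0, 0, 1, 1])\n",
"probs_demo = np.array([0.02, 0.91, 0.40, 0.65, 0.55, 0.08])\n",
"\n",
"y_pred_demo = (probs_demo > 0.5).astype(int)         # decision function for all examples\n",
"accuracy_demo = np.mean(y_pred_demo == y_true_demo)  # fraction of correct predictions\n",
"print(y_pred_demo, accuracy_demo)\n",
"\n",
"# With the notebook's own variables the analogous expression would presumably be\n",
"# np.mean((h(thetaBest, XTest) > 0.5) == yTest)"
]
},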
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## 3.2. Multiclass logistic regression"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Example: all the features and all 3 classes of the *Iris* data set."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>sl</th>\n",
" <th>sw</th>\n",
" <th>pl</th>\n",
" <th>pw</th>\n",
" <th>Gatunek</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>5.2</td>\n",
" <td>3.4</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5.1</td>\n",
" <td>3.7</td>\n",
" <td>1.5</td>\n",
" <td>0.4</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>6.7</td>\n",
" <td>3.1</td>\n",
" <td>5.6</td>\n",
" <td>2.4</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>6.5</td>\n",
" <td>3.2</td>\n",
" <td>5.1</td>\n",
" <td>2.0</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>4.9</td>\n",
" <td>2.5</td>\n",
" <td>4.5</td>\n",
" <td>1.7</td>\n",
" <td>Iris-virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>6.0</td>\n",
" <td>2.7</td>\n",
" <td>5.1</td>\n",
" <td>1.6</td>\n",
" <td>Iris-versicolor</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" sl sw pl pw Gatunek\n",
"0 5.2 3.4 1.4 0.2 Iris-setosa\n",
"1 5.1 3.7 1.5 0.4 Iris-setosa\n",
"2 6.7 3.1 5.6 2.4 Iris-virginica\n",
"3 6.5 3.2 5.1 2.0 Iris-virginica\n",
"4 4.9 2.5 4.5 1.7 Iris-virginica\n",
"5 6.0 2.7 5.1 1.6 Iris-versicolor"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas\n",
"data_iris = pandas.read_csv('iris.csv')\n",
"data_iris[:6]"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"X = [[1. 5.2 3.4 1.4 0.2]\n",
" [1. 5.1 3.7 1.5 0.4]\n",
" [1. 6.7 3.1 5.6 2.4]\n",
" [1. 6.5 3.2 5.1 2. ]]\n",
"y = [['Iris-setosa']\n",
" ['Iris-setosa']\n",
" ['Iris-virginica']\n",
" ['Iris-virginica']]\n"
]
}
],
"source": [
"# Data preparation\n",
"\n",
"import numpy as np\n",
"\n",
"features = ['sl', 'sw', 'pl', 'pw']\n",
"m = len(data_iris)\n",
"X = np.matrix(data_iris[features])\n",
"X0 = np.ones(m).reshape(m, 1)\n",
"X = np.hstack((X0, X))\n",
"y = np.matrix(data_iris[[\"Gatunek\"]]).reshape(m, 1)\n",
"\n",
"print(\"X = \", X[:4])\n",
"print(\"y = \", y[:4])"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Let us replace the text labels in $y$ with one-hot vectors. The columns follow the alphabetical order of the class names, which is the order produced by `indicatorMatrix` below:\n",
"\n",
"$$\n",
"\\begin{array}{ccc}\n",
"\\mbox{\"Iris-setosa\"} & \\mapsto & \\left[ \\begin{array}{ccc} 1 & 0 & 0 \\\\ \\end{array} \\right] \\\\\n",
"\\mbox{\"Iris-versicolor\"} & \\mapsto & \\left[ \\begin{array}{ccc} 0 & 1 & 0 \\\\ \\end{array} \\right] \\\\\n",
"\\mbox{\"Iris-virginica\"} & \\mapsto & \\left[ \\begin{array}{ccc} 0 & 0 & 1 \\\\ \\end{array} \\right] \\\\\n",
"\\end{array}\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Then, instead of the vector $y$, we obtain the matrix $Y$:\n",
"\n",
"$$\n",
"y \\; = \\;\n",
"\\left[\n",
"\\begin{array}{c}\n",
"y^{(1)} \\\\\n",
"y^{(2)} \\\\\n",
"y^{(3)} \\\\\n",
"y^{(4)} \\\\\n",
"y^{(5)} \\\\\n",
"\\vdots \\\\\n",
"\\end{array}\n",
"\\right]\n",
"\\; = \\;\n",
"\\left[\n",
"\\begin{array}{c}\n",
"\\mbox{\"Iris-setosa\"} \\\\\n",
"\\mbox{\"Iris-setosa\"} \\\\\n",
"\\mbox{\"Iris-virginica\"} \\\\\n",
"\\mbox{\"Iris-versicolor\"} \\\\\n",
"\\mbox{\"Iris-virginica\"} \\\\\n",
"\\vdots \\\\\n",
"\\end{array}\n",
"\\right]\n",
"\\quad \\mapsto \\quad\n",
"Y \\; = \\;\n",
"\\left[\n",
"\\begin{array}{ccc}\n",
"1 & 0 & 0 \\\\\n",
"1 & 0 & 0 \\\\\n",
"0 & 0 & 1 \\\\\n",
"0 & 1 & 0 \\\\\n",
"0 & 0 & 1 \\\\\n",
"\\vdots & \\vdots & \\vdots \\\\\n",
"\\end{array}\n",
"\\right]\n",
"$$"
]
},
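{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"For comparison, a more compact way to build the same kind of one-hot matrix with plain NumPy: `np.unique(..., return_inverse=True)` returns the sorted class names together with the class index of every label, and indexing the identity matrix `np.eye(k)` with those indices yields the one-hot rows. The labels below are a hypothetical sample in the same order as in the example above; the columns follow the alphabetical order of the class names."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"labels_demo = np.array(['Iris-setosa', 'Iris-setosa', 'Iris-virginica',\n",
"                        'Iris-versicolor', 'Iris-virginica'])\n",
"\n",
"# classes are sorted alphabetically; idx_demo[i] is the class index of labels_demo[i]\n",
"classes_demo, idx_demo = np.unique(labels_demo, return_inverse=True)\n",
"Y_demo = np.eye(len(classes_demo))[idx_demo]   # one one-hot row per label\n",
"\n",
"print(classes_demo)\n",
"print(Y_demo)"
]
},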
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"def mapY(y, cls):\n",
"    # binary column: 1 where the label equals cls, 0 elsewhere\n",
"    m = len(y)\n",
"    yBi = np.matrix(np.zeros(m)).reshape(m, 1)\n",
"    yBi[y == cls] = 1.\n",
"    return yBi\n",
"\n",
"def indicatorMatrix(y):\n",
"    # np.unique returns the class names sorted alphabetically,\n",
"    # so column i corresponds to the i-th class in that order\n",
"    classes = np.unique(y.tolist())\n",
"    m = len(y)\n",
"    k = len(classes)\n",
"    Y = np.matrix(np.zeros((m, k)))\n",
"    for i, cls in enumerate(classes):\n",
"        Y[:, i] = mapY(y, cls)\n",
"    return Y\n",
"\n",
"# one-hot matrix\n",
"Y = indicatorMatrix(y)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"# Split the data into a training set and a test set\n",
"XTrain, XTest = X[:100], X[100:]\n",
"YTrain, YTest = Y[:100], Y[100:]\n",
"\n",
"# Initial parameter matrix: let it consist of ones only\n",
"thetaTemp = np.ones(5).reshape(5,1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### From binary to multiclass logistic regression\n",
"\n",
"* The irises belong to three classes: _Iris-setosa_ (0), _Iris-versicolor_ (1), _Iris-virginica_ (2).\n",
"* We already know how to build binary classifiers of the form _Iris-setosa_ vs. _not Iris-setosa_ (the so-called *one-vs-all* scheme).\n",
"* We can therefore build three classifiers $h_{\\theta_1}, h_{\\theta_2}, h_{\\theta_3}$ (obtaining three sets of parameters $\\theta$) and choose the class with the highest probability."
]
},
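{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"If scikit-learn is installed, the same one-vs-all scheme can be cross-checked with `sklearn.multiclass.OneVsRestClassifier` wrapped around `sklearn.linear_model.LogisticRegression`. This is only a sketch under that assumption: it uses scikit-learn's bundled copy of the Iris data (`load_iris`, where the classes are numbered 0 = setosa, 1 = versicolor, 2 = virginica, as in the list above) and a random train/test split, so its accuracy is not directly comparable with the results in this notebook. Note also that `LogisticRegression` applies L2 regularization by default."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.multiclass import OneVsRestClassifier\n",
"\n",
"X_iris_demo, y_iris_demo = load_iris(return_X_y=True)\n",
"X_tr_demo, X_te_demo, y_tr_demo, y_te_demo = train_test_split(\n",
"    X_iris_demo, y_iris_demo, test_size=0.33, random_state=0)\n",
"\n",
"# one binary logistic regression per class (class vs. all the others)\n",
"ovr_demo = OneVsRestClassifier(LogisticRegression(max_iter=1000))\n",
"ovr_demo.fit(X_tr_demo, y_tr_demo)\n",
"\n",
"print(len(ovr_demo.estimators_))              # number of binary classifiers\n",
"print(ovr_demo.score(X_te_demo, y_te_demo))   # accuracy on the test split"
]
},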
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The *softmax* function, a generalization of the logistic function to more dimensions, will help us with this."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### The _softmax_ function\n",
"\n",
"The counterpart of the logistic function in multiclass logistic regression is the $\\mathrm{softmax}$ function:\n",
"\n",
"$$ \\textrm{softmax} \\colon \\mathbb{R}^k \\to [0,1]^k $$\n",
"\n",
"$$ \\textrm{softmax}(z_1,z_2,\\dots,z_k) = \\left( \\dfrac{e^{z_1}}{\\sum_{i=1}^{k}e^{z_i}}, \\dfrac{e^{z_2}}{\\sum_{i=1}^{k}e^{z_i}}, \\ldots, \\dfrac{e^{z_k}}{\\sum_{i=1}^{k}e^{z_i}} \\right) $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"$$ \\textrm{softmax}( \\left[ \\begin{array}{c} \\theta_1^T x \\\\ \\theta_2^T x \\\\ \\vdots \\\\ \\theta_k^T x \\end{array} \\right] ) = \\left[ \\begin{array}{c} P(y=1 \\, | \\, x;\\theta_1,\\ldots,\\theta_k) \\\\ P(y=2 \\, | \\, x;\\theta_1,\\ldots,\\theta_k) \\\\ \\vdots \\\\ P(y=k \\, | \\, x;\\theta_1,\\ldots,\\theta_k) \\end{array} \\right] $$"
]
},
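{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"In matrix form, we can stack the parameter vectors as the columns of a matrix $\\Theta = [\\theta_1, \\ldots, \\theta_k]$; then $X \\Theta$ contains all the scores $\\theta_j^T x^{(i)}$ at once (one row per example, one column per class), and softmax is applied to each row separately. A sketch with made-up numbers (`X_demo` and `Theta_demo` are hypothetical):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# hypothetical data: 2 examples, each with a bias term x_0 = 1 and 2 features\n",
"X_demo = np.array([[1.0, 0.5, -1.2],\n",
"                   [1.0, 2.0, 0.3]])\n",
"# hypothetical parameters: column j holds theta_j (3 parameters x 3 classes)\n",
"Theta_demo = np.array([[0.1, -0.3, 0.2],\n",
"                       [1.0, 0.0, -1.0],\n",
"                       [-0.5, 0.8, 0.4]])\n",
"\n",
"Z_demo = X_demo @ Theta_demo                                  # Z[i, j] = theta_j^T x_i\n",
"Z_shift_demo = Z_demo - Z_demo.max(axis=1, keepdims=True)     # for numerical stability\n",
"P_demo = np.exp(Z_shift_demo) / np.exp(Z_shift_demo).sum(axis=1, keepdims=True)\n",
"\n",
"print(P_demo)               # row i: P(y = j | x_i) for each class j\n",
"print(P_demo.sum(axis=1))   # each row sums to 1"
]
},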
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"# Softmax function in vectorized (matrix) form\n",
"def softmax(X):\n",
"    return np.exp(X) / np.sum(np.exp(X))"
]
},
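{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"The `softmax` function above exponentiates its argument directly, so for large inputs `np.exp` overflows (for example, `np.exp(1000.0)` is `inf`). A common remedy, sketched below as a hypothetical helper `softmax_stable` for 1-D arrays (it is not used elsewhere in the notebook), subtracts the maximum before exponentiating. The result does not change, because $\\frac{e^{z_i - c}}{\\sum_j e^{z_j - c}} = \\frac{e^{z_i}}{\\sum_j e^{z_j}}$ for any constant $c$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def softmax_stable(z):\n",
"    # subtracting max(z) multiplies numerator and denominator by the same\n",
"    # factor exp(-max(z)), so the result is unchanged but exp cannot overflow\n",
"    z = np.asarray(z, dtype=float)\n",
"    e = np.exp(z - np.max(z))\n",
"    return e / np.sum(e)\n",
"\n",
"z_demo = np.array([2.1, 0.5, 0.8, 0.9, 3.2])\n",
"print(np.allclose(softmax_stable(z_demo), np.exp(z_demo) / np.sum(np.exp(z_demo))))\n",
"\n",
"# the naive formula would give inf / inf = nan here\n",
"print(softmax_stable(np.array([1000.0, 1001.0, 1002.0])))"
]
},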
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"The values of the $\\mathrm{softmax}$ function sum to 1:"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9999999999999999\n"
]
}
],
"source": [
"Z = np.matrix([[2.1, 0.5, 0.8, 0.9, 3.2]])\n",
"P = softmax(Z)\n",
"print(np.sum(P)) "
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameter matrix theta obtained for class 0:\n",
" [[ 0.49207837]\n",
" [ 0.32297899]\n",
" [ 1.27518552]\n",
" [-2.10699503]\n",
" [-0.55887263]] \n",
"\n",
"Parameter matrix theta obtained for class 1:\n",
" [[ 0.49531628]\n",
" [ 0.36093408]\n",
" [-1.27157499]\n",
" [ 0.28948428]\n",
" [-0.38775578]] \n",
"\n",
"Parameter matrix theta obtained for class 2:\n",
" [[-0.14368885]\n",
" [-1.59697421]\n",
" [-2.03887846]\n",
" [ 2.40910042]\n",
" [ 2.54991032]] \n",
"\n"
]
}
],
"source": [
"# Train a separate binary (one-vs-all) classifier for each class.\n",
"\n",
"def trainMaxEnt(X, Y):\n",
"    n = X.shape[1]\n",
"    thetas = []\n",
"    for c in range(Y.shape[1]):\n",
"        YBi = Y[:,c]\n",
"        theta = np.matrix(np.random.random(n)).reshape(n,1)\n",
"        # Parameter matrix theta, computed separately for each class.\n",
"        thetaBest, errors = GD(h, J, dJ, theta,\n",
"                               X, YBi, alpha=0.1, eps=10**-4)\n",
"        thetas.append(thetaBest)\n",
"    return thetas\n",
"\n",
"# Theta matrices, one for each class\n",
"thetas = trainMaxEnt(XTrain, YTrain)\n",
"for c, theta in enumerate(thetas):\n",
"    print(f\"Parameter matrix theta obtained for class {c}:\\n\", theta, \"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Decision function of multiclass logistic regression\n",
"\n",
"$$ c = \\mathop{\\textrm{arg}\\,\\textrm{max}}_{i \\in \\{1, \\ldots ,k\\}} P(y=i|x;\\theta_1,\\ldots,\\theta_k) $$"
]
},
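{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"source": [
"Since $e^z$ is increasing and all the terms are divided by the same positive sum, softmax preserves the ordering of its arguments. The class with the highest probability is therefore also the class with the highest score $\\theta_i^T x$, so the argmax could be computed on the scores directly. The sketch below checks this on the scores printed for the first test example in the output of the next cell (copied by hand); it should reproduce the probabilities of about 0, 0.185 and 0.815 shown there."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "notes"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# scores theta_i^T x for the first test example, copied by hand from the next cell's output\n",
"scores_demo = np.array([-7.73217642, 0.56835819, 2.05282303])\n",
"probs_demo = np.exp(scores_demo) / np.sum(np.exp(scores_demo))\n",
"\n",
"print(np.round(probs_demo, 3))\n",
"print(np.argmax(probs_demo) == np.argmax(scores_demo))   # same class either way"
]
},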
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"For x = [[1. 7.3 2.9 6.3 1.8]]:\n",
"After applying regression: [-7.73217642 0.56835819 2.05282303]\n",
"Obtained probabilities: [0. 0.185 0.815]\n",
"Chosen class: 2\n",
"Computed y = 2\n",
"Expected y = 2\n",
"\n",
"For x = [[1. 4.8 3. 1.4 0.3]]:\n",
"After applying regression: [ 2.75047926 -1.29797382 -9.78808679]\n",
"Obtained probabilities: [0.983 0.017 0. ]\n",
"Chosen class: 0\n",
"Computed y = 0\n",
"Expected y = 0\n",
"\n",
"For x = [[1. 7.1 3. 5.9 2.1]]:\n",
"After applying regression: [-6.99411744 0.13689343 1.96966296]\n",
"Obtained probabilities: [0. 0.138 0.862]\n",
"Chosen class: 2\n",
"Computed y = 2\n",
"Expected y = 2\n",
"\n",
"For x = [[1. 5.9 3. 5.1 1.8]]:\n",
"After applying regression: [-5.52843441 -0.41148816 1.19377859]\n",
"Obtained probabilities: [0.001 0.167 0.832]\n",
"Chosen class: 2\n",
"Computed y = 2\n",
"Expected y = 2\n",
"\n",
"For x = [[1. 6.1 2.6 5.6 1.4]]:\n",
"After applying regression: [-6.80386129 0.4691731 1.87452121]\n",
"Obtained probabilities: [0. 0.197 0.803]\n",
"Chosen class: 2\n",
"Computed y = 2\n",
"Expected y = 2\n",
"\n"
]
}
],
"source": [
"def classify(thetas, X, debug=False):\n",
"    # linear score theta^T x of each one-vs-all classifier\n",
"    regs = np.array([(X*theta).item() for theta in thetas])\n",
"    if debug:\n",
"        print(\"After applying regression:\", regs)\n",
"    probs = softmax(regs)\n",
"    if debug:\n",
"        print(\"Obtained probabilities:\", np.around(probs, decimals=3))\n",
"    result = np.argmax(probs)\n",
"    if debug:\n",
"        print(\"Chosen class:\", result)\n",
"    return result\n",
"\n",
"for i in range(5):\n",
"    print(f\"For x = {XTest[i]}:\")\n",
"    YPredicted = classify(thetas, XTest[i], debug=True)\n",
"    print(f\"Computed y = {YPredicted}\")\n",
"    print(f\"Expected y = {np.argmax(YTest[i])}\")\n",
"    print()"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
},
"livereveal": {
"start_slideshow_at": "selected",
"theme": "white"
}
},
"nbformat": 4,
"nbformat_minor": 4
}