test/Ans.ipynb

232 lines
38 KiB
Plaintext
Raw Permalink Normal View History

2022-11-15 13:06:57 +01:00
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format = \"svg\"\n",
"\n",
"from IPython.display import display, Math, Latex\n",
"\n",
"# Load the data set; column names are supplied here instead of being read from the file.\n",
"data = pd.read_csv(\"fires_thefts.csv\", names=[\"x\", \"y\"])\n",
"\n",
"# Expose both columns as NumPy arrays for the functions defined below.\n",
"x, y = (data[column].to_numpy() for column in (\"x\", \"y\"))\n",
"\n",
"# Hypothesis: a linear function of one variable\n",
"def h(theta, x):\n",
"    \"\"\"Evaluate the linear model theta[0] + theta[1] * x.\n",
"\n",
"    theta is read at positions 0 (intercept) and 1 (slope); x is the input\n",
"    the model is evaluated at.\n",
"    \"\"\"\n",
"    intercept = theta[0]\n",
"    slope = theta[1]\n",
"    return intercept + slope * x\n",
"\n",
"# Cost function\n",
"def J(h, theta, x, y):\n",
"    \"\"\"Return half the mean squared error of hypothesis h over the samples.\n",
"\n",
"    h is called as h(theta, x[i]) for every i in range(len(y)); the squared\n",
"    differences from y[i] are summed and scaled by 1 / (2 * len(y)).\n",
"    \"\"\"\n",
"    n_samples = len(y)\n",
"    scale = 1.0 / (2 * n_samples)\n",
"    squared_errors = 0\n",
"    for idx in range(n_samples):\n",
"        residual = h(theta, x[idx]) - y[idx]\n",
"        squared_errors = squared_errors + residual ** 2\n",
"    return scale * squared_errors\n",
"\n",
"# Render a matrix as LaTeX\n",
"def LatexMatrix(matrix):\n",
"    \"\"\"Return LaTeX source for a 2-D matrix as a bracketed, right-aligned array.\n",
"\n",
"    Every entry is formatted with four decimal places, entries of a row are\n",
"    joined with an ampersand and every row ends with a LaTeX row break.\n",
"\n",
"    matrix may be a 2-D numpy.ndarray or numpy.matrix; anything else must\n",
"    convert to a 2-D array through np.asarray. Inputs that are not 2-D raise\n",
"    ValueError when the shape is unpacked.\n",
"    \"\"\"\n",
"    # Work on a plain ndarray: iterating over a row of an np.matrix yields the\n",
"    # whole (still 2-D) row again, so .item() failed for matrices wider than\n",
"    # one column.\n",
"    cells = np.asarray(matrix)\n",
"    _, n_cols = cells.shape\n",
"    ltx = r\"\\left[\\begin{array}\" + \"{\" + (\"r\" * n_cols) + \"}\"\n",
"    for row in cells:\n",
"        ltx += r\" & \".join(\"%.4f\" % value.item() for value in row) + r\" \\\\ \"\n",
"    ltx += r\"\\end{array}\\right]\"\n",
"    return ltx\n",
"\n",
"def gradient_descent(h, cost_fun, theta, x, y, alpha, eps):\n",
"    \"\"\"Fit a two-parameter model by batch gradient descent.\n",
"\n",
"    Args:\n",
"        h: hypothesis, called as h(theta, x[i]).\n",
"        cost_fun: cost function, called as cost_fun(h, theta, x, y).\n",
"        theta: initial parameters; positions 0 and 1 are updated.\n",
"        x, y: indexable samples; len(y) entries are used.\n",
"        alpha: learning rate.\n",
"        eps: stop once the cost changes by at most eps between two steps.\n",
"\n",
"    Returns:\n",
"        (theta, history): the last computed parameters and a list of\n",
"        [cost, theta] pairs starting with the initial point. The step that\n",
"        triggers termination is not appended to history.\n",
"    \"\"\"\n",
"    import math  # local import: only needed for the divergence guard below\n",
"\n",
"    current_cost = cost_fun(h, theta, x, y)\n",
"    # Keep cost and parameter values so they can be plotted later\n",
"    history = [[current_cost, theta]]\n",
"    m = len(y)\n",
"    step = alpha / float(m)  # loop-invariant factor\n",
"    while True:\n",
"        # Residuals are computed once per step and shared by both partial derivatives\n",
"        errors = [h(theta, x[i]) - y[i] for i in range(m)]\n",
"        # Simultaneous update: both components are derived from the old theta\n",
"        theta = [\n",
"            theta[0] - step * sum(errors),\n",
"            theta[1] - step * sum(errors[i] * x[i] for i in range(m)),\n",
"        ]\n",
"        prev_cost = current_cost\n",
"        try:\n",
"            current_cost = cost_fun(h, theta, x, y)\n",
"            diverged = not math.isfinite(current_cost)\n",
"        except OverflowError:\n",
"            break\n",
"        # NumPy scalars overflow to inf/nan instead of raising OverflowError;\n",
"        # nan never satisfies the eps test, so without this check a diverging\n",
"        # run would never terminate.\n",
"        if diverged:\n",
"            break\n",
"        if abs(prev_cost - current_cost) <= eps:\n",
"            break\n",
"        history.append([current_cost, theta])\n",
"    return theta, history\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle \\large\\textrm{Wynik:}\\quad \\theta = \\left[\\begin{array}{r}16.9446 \\\\ 1.3160 \\\\ \\end{array}\\right] \\quad J(\\theta) = 180.4105 \\quad \\textrm{po 5369 iteracjach}$"
],
"text/plain": [
"<IPython.core.display.Math object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Fit the model starting from theta = [0, 0]\n",
"best_theta, history = gradient_descent(h, J, [0.0, 0.0], x, y, alpha=0.003, eps=0.000001)\n",
"\n",
"# history[-1] belongs to the step before the returned parameters, so the cost\n",
"# shown next to theta is evaluated at best_theta itself.\n",
"best_cost = J(h, best_theta, x, y)\n",
"\n",
"display(\n",
"    Math(\n",
"        r\"\\large\\textrm{Wynik:}\\quad \\theta = \"\n",
"        + LatexMatrix(np.array(best_theta).reshape(2, 1))\n",
"        + (r\" \\quad J(\\theta) = %.4f\" % best_cost)\n",
"        + r\" \\quad \\textrm{po %d iteracjach}\" % len(history)\n",
"    )\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"eps= 1.0, cost= 231.741, steps= 4\n",
"eps= 0.1, cost= 226.569, steps= 52\n",
"eps= 0.01, cost= 185.031, steps= 1115\n",
"eps= 0.001, cost= 180.872, steps= 2179\n",
"eps= 0.0001, cost= 180.456, steps= 3242\n",
"eps= 1e-05, cost= 180.415, steps= 4306\n",
"eps= 1e-06, cost= 180.411, steps= 5369\n",
"eps= 1e-07, cost= 180.410, steps= 6433\n",
"eps= 1e-08, cost= 180.410, steps= 7496\n",
"eps= 1e-09, cost= 180.410, steps= 8560\n",
"eps= 1e-10, cost= 180.410, steps= 9623\n",
"eps= 1e-11, cost= 180.410, steps= 10687\n"
]
}
],
"source": [
"\n",
"# Repeat the fit for tolerances 1e0 ... 1e-11, recording the last logged cost\n",
"# and the history length of every run.\n",
"alpha = 0.003\n",
"epss = [10.0 ** -n for n in range(12)]\n",
"costs, lengths = [], []\n",
"for eps in epss:\n",
"    theta_best, history = gradient_descent(h, J, [0.0, 0.0], x, y, alpha, eps)\n",
"    cost, steps = history[-1][0], len(history)\n",
"    print(f\"{eps=:7}, {cost=:15.3f}, {steps=:6}\")\n",
"    costs.append(cost)\n",
"    lengths.append(steps)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def eps_cost_steps_plot(eps, costs, steps):\n",
"    \"\"\"Plot the step count and the cost against eps on twin y-axes.\n",
"\n",
"    The x-axis (eps) is logarithmic; steps are drawn in green on the first\n",
"    axes and costs in orange on the twinned axes.\n",
"    \"\"\"\n",
"    _, steps_axis = plt.subplots()\n",
"    cost_axis = steps_axis.twinx()  # second y-axis sharing the same x-axis\n",
"\n",
"    steps_axis.plot(eps, steps, \"--s\", color=\"green\")\n",
"    cost_axis.plot(eps, costs, \":o\", color=\"orange\")\n",
"\n",
"    steps_axis.set(xscale=\"log\", xlabel=\"eps\")\n",
"    steps_axis.set_ylabel(\"liczba kroków\", color=\"green\")\n",
"    cost_axis.set_ylabel(\"koszt\", color=\"orange\")\n",
"    plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"463.77625pt\" height=\"310.86825pt\" viewBox=\"0 0 463.77625 310.86825\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2022-10-31T17:08:50.563047</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.6.1, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 310.86825 \nL 463.77625 310.86825 \nL 463.77625 0 \nL 0 0 \nz\n\" style=\"fill: #ffffff\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 59.690625 273.312 \nL 416.810625 273.312 \nL 416.810625 7.2 \nL 59.690625 7.2 \nz\n\" style=\"fill: #ffffff\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <defs>\n <path id=\"m958b897915\" d=\"M 0 0 \nL 0 3.5 \n\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n </defs>\n <g>\n <use xlink:href=\"#m958b897915\" x=\"105.437402\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- $\\mathdefault{10^{-10}}$ -->\n <g transform=\"translate(91.487402 287.910437) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-31\" d=\"M 794 531 \nL 1825 531 \nL 1825 4091 \nL 703 3866 \nL 703 4441 \nL 1819 4666 \nL 2450 4666 \nL 2450 531 \nL 3481 531 \nL 3481 0 \nL 794 0 \nL 794 531 \nz\n\" transform=\"scale(0.015625)\"/>\n <path 
id=\"DejaVuSans-30\" d=\"M 2034 4250 \nQ 1547 4250 1301 3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" transform=\"scale(0.015625)\"/>\n <path id=\"DejaVuSans-2212\" d=\"M 678 2272 \nL 4684 2272 \nL 4684 1741 \nL 678 1741 \nL 678 2272 \nz\n\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-31\" transform=\"translate(0 0.765625)\"/>\n <use xlink:href=\"#DejaVuSans-30\" transform=\"translate(63.623047 0.765625)\"/>\n <use xlink:href=\"#DejaVuSans-2212\" transform=\"translate(128.203125 39.046875) scale(0.7)\"/>\n <use xlink:href=\"#DejaVuSans-31\" transform=\"translate(186.855469 39.046875) scale(0.7)\"/>\n <use xlink:href=\"#DejaVuSans-30\" transform=\"translate(231.391602 39.046875) scale(0.7)\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_2\">\n <g>\n <use xlink:href=\"#m958b897915\" x=\"164.465501\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- $\\mathdefault{10^{-8}}$ -->\n <g transform=\"translate(152.715501 287.910437) scale(0.1 -0.1)\">\n <defs>\n <path id=\"DejaVuSans-38\" d=\"M 2034 2216 \nQ 1584 2216 1326 1975 \nQ 1069 1734 1069 1313 \nQ 1069 891 1326 650 \nQ 1584 409 2034 409 \nQ 2484 409 2743 651 \nQ 3003 894 3003 1313 \nQ 3003 1734 2745 1975 \nQ 2488 2216 2034 2216 \nz\nM 1403 2484 \nQ 997 2584 770 2862 \nQ 544 3141 544 3541 \nQ 544 4100 942 4425 \nQ 1341 4750 2034 4750 \nQ 2731 4750 3128 4425 \nQ 3525 4100 3525 3541 \nQ 3525 3141 3298 2862 \nQ 3072 2584 2669 2484 \nQ 3125 2378 3379 2068 \nQ
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot history length and last logged cost for every tolerance in epss\n",
"eps_cost_steps_plot(eps=epss, costs=costs, steps=lengths)\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[16.835521154474677, 1.3214970549417684]\n",
"Liczba pozarów - 50 Przewidywalna liczba włamań - 82.91037390156309\n",
"Liczba pozarów - 100 Przewidywalna liczba włamań - 148.98522664865152\n",
"Liczba pozarów - 200 Przewidywalna liczba włamań - 281.13493214282835\n"
]
}
],
"source": [
"# Evaluate the fitted hypothesis for a few sample inputs\n",
"example_x = [50, 100, 200]\n",
"print(best_theta)\n",
"example_y = [h(best_theta, fires) for fires in example_x]\n",
"for fires, thefts in zip(example_x, example_y):\n",
"    print(f\"Liczba pozarów - {fires} \"\n",
"          f\"Przewidywalna liczba włamań - {thefts}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.8 ('venv': venv)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "f37486876ad4b243625dcab03485f0edb2a22cb7fa9db711ceb1161e85adf5f1"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}