170 lines
31 KiB
Plaintext
170 lines
31 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "11f16ac1",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Część podstawowa"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"id": "38487b50",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Predicted amount of thefts for 50 fires: 82.70999487819813\n",
|
||
|
"Predicted amount of thefts for 100 fires: 148.45251499453076\n",
|
||
|
"Predicted amount of thefts for 200 fires: 279.93755522719596\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"import pandas as pd\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"import numpy as np\n",
|
||
|
"\n",
|
||
|
"# Passing `names=` also makes pandas read the file as having no header row.\n",
"data = pd.read_csv('fires_thefts.csv', names = ['fires', 'thefts'])\n",
"\n",
"# Selecting a single column gives a Series, whose array is already 1-D.\n",
"x = data['fires'].to_numpy()\n",
"y = data['thefts'].to_numpy()\n",
|
||
|
"\n",
|
||
|
"def gradient_descent(h, cost_fun, theta, x, y, alpha, eps, max_steps = 1000000):\n",
"    \"\"\"Fit the two parameters of the hypothesis `h` with batch gradient descent.\n",
"\n",
"    h(theta, x_i) returns the prediction for one sample; cost_fun(h, theta, x, y)\n",
"    returns the cost of theta on the whole data set. theta is the initial\n",
"    [theta0, theta1]; x and y are indexable sequences of equal length; alpha is\n",
"    the learning rate. Iteration stops when the cost changes by at most eps\n",
"    between consecutive steps, or after max_steps updates.\n",
"\n",
"    Returns (theta, log): the final parameters and a list of [cost, theta] pairs\n",
"    that starts with the initial parameters and ends with the returned ones.\n",
"    \"\"\"\n",
"    current_cost = cost_fun(h, theta, x, y)\n",
"    log = [[current_cost, theta]]\n",
"    m = len(y)\n",
"    for _ in range(max_steps):\n",
"        # Both partial derivatives use the same residuals, so evaluate h once per sample.\n",
"        residuals = [h(theta, x[i]) - y[i] for i in range(m)]\n",
"        theta = [\n",
"            theta[0] - alpha/float(m) * sum(residuals),\n",
"            theta[1] - alpha/float(m) * sum(r * x[i] for i, r in enumerate(residuals)),\n",
"        ]\n",
"        prev_cost = current_cost\n",
"        current_cost = cost_fun(h, theta, x, y)\n",
"        # Log before the convergence test so the returned theta is recorded too.\n",
"        log.append([current_cost, theta])\n",
"        # A diverging run produces inf/nan costs; nan never satisfies <=, so such\n",
"        # a run keeps iterating until max_steps.\n",
"        if abs(prev_cost - current_cost) <= eps:\n",
"            break\n",
"    return theta, log\n",
|
||
|
"\n",
|
||
|
"def J(h, theta, x, y):\n",
"    \"\"\"Training cost of theta on (x, y): half the mean squared residual.\"\"\"\n",
"    n_samples = len(y)\n",
"    squared_errors = ((h(theta, x[k]) - y[k]) ** 2 for k in range(n_samples))\n",
"    return 1.0 / (2 * n_samples) * sum(squared_errors)\n",
|
||
|
"\n",
|
||
|
"def h(theta, x):\n",
"    \"\"\"Linear hypothesis: intercept theta[0] plus slope theta[1] times x.\"\"\"\n",
"    intercept, slope = theta[0], theta[1]\n",
"    return intercept + slope * x\n",
|
||
|
"\n",
|
||
|
"def mse(expected, predicted):\n",
"    \"\"\"Mean squared error between two equal-length sequences.\n",
"\n",
"    Unlike the training cost J there is no 1/2 factor: that factor only\n",
"    simplifies the gradient and does not belong in an evaluation metric.\n",
"    Raises ValueError (a subclass of Exception) when the lengths differ.\n",
"    \"\"\"\n",
"    m = len(expected)\n",
"    if len(predicted) != m:\n",
"        # Message reads: 'Vectors have different lengths!'\n",
"        raise ValueError('Wektory mają różne długości!')\n",
"    return 1.0 / m * sum((expected[i] - predicted[i])**2 for i in range(m))\n",
|
||
|
"\n",
|
||
|
"best_theta, log = gradient_descent(\n",
"    h, J, [0.0, 0.0], x, y, alpha=0.001, eps=1e-7, max_steps=1_000_000\n",
")\n",
"\n",
"predicted_50, predicted_100, predicted_200 = (h(best_theta, n_fires) for n_fires in (50, 100, 200))\n",
"for n_fires, prediction in ((50, predicted_50), (100, predicted_100), (200, predicted_200)):\n",
"    print(f'Predicted amount of thefts for {n_fires} fires: {prediction}')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "a2126b66",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Część zaawansowana"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"id": "e98cdca2",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"/var/folders/lm/cbc3n48n4x94zd3vf6zbbly40000gn/T/ipykernel_46784/756364182.py:32: RuntimeWarning: overflow encountered in scalar power\n",
|
||
|
" return 1.0 / (2 * m) * sum((h(theta, x[i]) - y[i])**2 for i in range(m))\n",
|
||
|
"/var/folders/lm/cbc3n48n4x94zd3vf6zbbly40000gn/T/ipykernel_46784/756364182.py:25: RuntimeWarning: invalid value encountered in scalar subtract\n",
|
||
|
" if abs(prev_cost - current_cost) <= eps:\n",
|
||
|
"/var/folders/lm/cbc3n48n4x94zd3vf6zbbly40000gn/T/ipykernel_46784/756364182.py:20: RuntimeWarning: overflow encountered in scalar add\n",
|
||
|
" theta[1] - alpha/float(m) * sum((h(theta, x[i]) - y[i]) * x[i]\n",
|
||
|
"/var/folders/lm/cbc3n48n4x94zd3vf6zbbly40000gn/T/ipykernel_46784/756364182.py:20: RuntimeWarning: overflow encountered in scalar multiply\n",
|
||
|
" theta[1] - alpha/float(m) * sum((h(theta, x[i]) - y[i]) * x[i]\n",
|
||
|
"/var/folders/lm/cbc3n48n4x94zd3vf6zbbly40000gn/T/ipykernel_46784/756364182.py:20: RuntimeWarning: invalid value encountered in scalar subtract\n",
|
||
|
" theta[1] - alpha/float(m) * sum((h(theta, x[i]) - y[i]) * x[i]\n",
|
||
|
"/var/folders/lm/cbc3n48n4x94zd3vf6zbbly40000gn/T/ipykernel_46784/756364182.py:32: RuntimeWarning: overflow encountered in scalar add\n",
|
||
|
" return 1.0 / (2 * m) * sum((h(theta, x[i]) - y[i])**2 for i in range(m))\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqoAAAJpCAYAAABl+RBwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABJTklEQVR4nO3de1hVZd7/8c9GBTQFEpBDKpqZh1BCLaUytTTPRdPpcSztMJ1+WhrlGNWjMzaNNmXlk44OllqNZtmMNmXpmIYdREOUJrM8DYWZQJqCiqKx1++PZMFWQFBh3bLer+val+6919r7u3cr+Phd930vj2VZlgAAAADD+DldAAAAAFAegioAAACMRFAFAACAkQiqAAAAMBJBFQAAAEYiqAIAAMBIBFUAAAAYiaAKAAAAIxFUAQAAYCSCKgAAAIxUZ4PqJ598oqFDhyo6Oloej0dLliyp1v5btmxRnz59FBERocDAQF144YV66qmndOzYMZ/tFi1apPbt2yswMFCdOnXSBx98UOFrPvDAA/J4PHrppZdO4xMBAAC4S50NqocOHVJcXJxmzJhxWvs3aNBAI0aM0L///W9t2bJFL730kmbPnq2JEyfa26xZs0bDhg3TPffco40bNyoxMVGJiYnatGnTSa+3ePFirV27VtHR0af9mQAAANzEY1mW5XQRNc3j8Wjx4sVKTEy0HysqKtKTTz6pN998U/v371dsbKyeffZZ9e7du8LXSUpKUnp6uj799FNJ0m233aZDhw7p/ffft7fp0aOHLr30Us2aNct+bNeuXerevbuWL1+uwYMHa+zYsRo7duzZ/pgAAAB1Sp3tqJ7K6NGjlZaWpoULF+o///mPbrnlFg0YMEDbtm0rd/vt27dr2bJl6tWrl/1YWlqa+vbt67Nd//79lZaWZt/3er264447NG7cOF1yySU182EAAADqIFcG1ezsbM2dO1eLFi1Sz5491aZNGz322GO66qqrNHfuXJ9tr7jiCgUGBqpt27bq2bOnJk2aZD+Xk5OjiIgIn+0jIiKUk5Nj33/22WdVv359PfzwwzX7oQAAAOqY+k4X4ISvvvpKxcXFuvjii30eLyoqUmhoqM9jb731lg4cOKAvv/xS48aN0/PPP6/f//73VXqfjIwMTZs2TRs2bJDH4zlr9QMAALiBK4PqwYMHVa9ePWVkZKhevXo+zzVu3NjnfosWLSRJHTt2VHFxse677z49+uijqlevniIjI5Wbm+uzfW5uriIjIyVJn376qfLy8tSyZUv7+eLiYj366KN66aWX9N1339XApwMAAKgbXBlU4+PjVVxcrLy8PPXs2bPK+3m9Xh07dkxer1f16tVTQkKCVq5c6TMxasWKFUpISJAk3XHHHeWOYb3jjjt01113nZXPAgAAUFfV2aB68OBBbd++3b6flZWlzMxMNW3aVBdffLGGDx+uESNGaOrUqYqPj9dPP/2klStXqnPnzho8eLDmz5+vBg0aqFOnTgoICND69euVnJys2267TQ0aNJAkjRkzRr169dLUqVM1ePBgLVy4UOvXr1dKSookKTQ09KShBA0aNFBkZKTatWtXe18GAADAOajOBtX169erT58+9v2kpCRJ0siRIzVv3jzNnTtXf/rTn/Too49q165dCgsLU48ePTRkyBBJUv369fXss89q69atsixLMTExGj16tB555BH7Na+44gotWLBATz31lJ544gm1bdtWS5YsUWxsbO1+WAAAgDrIFeuoAgAA4NzjyuWpAAAAYD6CKgAAAIxUp8aoer1e/fjjj2rSpAnrlgIAABjIsiwdOHBA0dHR8vOrvGdap4Lqjz/+aK97CgAAAHPt3LlTzZs3r3SbOhVUmzRpIunXDx4UFORwNQAAADhRQUGBWrRoYee2ytSpoFpyuj8oKIigCgAAYLCqDNNkMhUAAACMZFRQbdWqlTwez0m3UaNGOV0aAAAAaplRp/7T09NVXFxs39+0aZP69e
unW265xcGqAAAA4ASjgmp4eLjP/SlTpqhNmzbq1auXQxUBAOAuxcXFOnbsmNNl4BxWr1491a9f/6wsFWpUUC3r6NGj+vvf/66kpKQKP2hRUZGKiors+wUFBbVVHgAAdc7Bgwf1ww8/iKur40w1atRIUVFR8vf3P6PXMTaoLlmyRPv379edd95Z4TaTJ0/WH//4x9orCgCAOqq4uFg//PCDGjVqpPDwcC6cg9NiWZaOHj2qn376SVlZWWrbtu0pF/WvjMcy9J9N/fv3l7+/v957770Ktymvo9qiRQvl5+ezPBUAANVw5MgRZWVlqVWrVmrYsKHT5eAcV1hYqO+//16tW7dWYGCgz3MFBQUKDg6uUl4zsqP6/fff66OPPtI///nPSrcLCAhQQEBALVUFAEDdRycVZ8OZdFF9XuesvMpZNnfuXDVr1kyDBw92uhQAAAA4xLig6vV6NXfuXI0cOVL16xvZ8AUAAEAtMC6ofvTRR8rOztbdd9/tdCkAAOAc991338nj8SgzM7PK+8ybN08hISE1VhOqzriget1118myLF188cVOlwIAAGCURYsWqX379goMDFSnTp30wQcfVLr97t279dvf/lYXX3yx/Pz8NHbs2Nop9CwxLqgCAADgZGvWrNGwYcN0zz33aOPGjUpMTFRiYqI2bdpU4T5FRUUKDw/XU089pbi4uFqs9uwgqAIAgJNYlqXCo784cqvOypnLli3TVVddpZCQEIWGhmrIkCHasWNHhdunpqbK4/Fo6dKl6ty5swIDA9WjR49yw97y5cvVoUMHNW7cWAMGDNDu3bvt59LT09WvXz+FhYUpODhYvXr10oYNG6r3JVfTtGnTNGDAAI0bN04dOnTQ008/rS5dumj69OkV7tOqVStNmzZNI0aMUHBwcI3WVxOYrQQAAE5y+FixOk5Y7sh7b57UX438qxZRDh06pKSkJHXu3FkHDx7UhAkTdOONNyozM7PSJZLGjRunadOmKTIyUk888YSGDh2qrVu3qkGDBpJ+XQf0+eef1xtvvCE/Pz/dfvvteuyxxzR//nxJ0oEDBzRy5Ei9/PLLsixLU6dO1aBBg7Rt2zY1adKk3PecP3++7r///ko/z4cffqiePXuW+1xaWpqSkpJ8Huvfv7+WLFlS6WueywiqAADgnHXTTTf53J8zZ47Cw8O1efNmxcbGVrjfxIkT1a9fP0nSa6+9pubNm2vx4sW69dZbJUnHjh3TrFmz1KZNG0nS6NGjNWnSJHv/a665xuf1UlJSFBISotWrV2vIkCHlvuf111+v7t27V/p5Lrjgggqfy8nJUUREhM9jERERysnJqfQ1z2UEVQAAcJKGDepp86T+jr13VW3btk0TJkzQunXrtGfPHnm9XklSdnZ2pUE1ISHB/nvTpk3Vrl07ffPNN/ZjjRo1skOqJEVFRSkvL8++n5ubq6eeekqpqanKy8tTcXGxCgsLlZ2dXeF7NmnSpMJuK8pHUAUAACfxeDxVPv3upKFDhyomJkazZ89WdHS0vF6vYmNjdfTo0TN63ZIhACU8Ho/P2NmRI0dq7969mjZtmmJiYhQQEKCEhIRK3/dMT/1HRkYqNzfX57Hc3FxFRkae6uOcs8w/AgEAAMqxd+9ebdmyRbNnz7bD3WeffValfdeuXauWLVtKkvbt26etW7eqQ4cOVX7vzz//XH/96181aNAgSdLOnTu1Z8+eSvc501P/CQkJWrlypc8SUytWrPDpDtc1BFUAAHBOOv/88xUaGqqUlBRFRUUpOztbjz/+eJX2nTRpkkJDQxUREaEnn3xSYWFhSkxMrPJ7t23bVm+88Ya6deumgoICjRs3Tg0bNqx0nzM99T9mzBj16tVLU6dO1eDBg7Vw4UKtX79eKSkp9jbJycnatWuXXn/9dfuxkosdHDx4UD/99JMyMzPl7++vjh07nnYttYXlqQAAkqRB0z7VgJc+0Z6DRU6XAlSJn5+fFi5cqIyMDMXGxuqRRx7Rc889V6
V9p0yZojFjxqhr167KycnRe++9J39//yq/96uvvqp9+/apS5cuuuOOO/Twww+rWbNmp/tRquSKK67QggULlJKSori
|
||
|
"text/plain": [
|
||
|
"<Figure size 800x700 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"N_PLOT_STEPS = 200  # number of initial steps shown on the plot\n",
"ALPHAS = [0.1, 0.01, 0.001]\n",
"\n",
"# The RuntimeWarnings in this cell's output show the cost overflowing for at\n",
"# least one learning rate (presumably 0.1, the largest) - i.e. divergence.\n",
"runs = {\n",
"    alpha: gradient_descent(h, J, [0.0, 0.0], x, y, alpha=alpha, eps=0.0000001, max_steps=1000)\n",
"    for alpha in ALPHAS\n",
"}\n",
"(best_theta_01, log_01), (best_theta_001, log_001), (best_theta_0001, log_0001) = runs.values()\n",
"\n",
"fig, ax = plt.subplots(figsize=(8, 7))\n",
"for alpha, (_, run_log) in runs.items():\n",
"    # Slice rather than index steps 0..N-1: a run that stopped early is drawn\n",
"    # shorter instead of raising IndexError.\n",
"    costs = [cost for cost, _ in run_log[:N_PLOT_STEPS]]\n",
"    ax.plot(range(len(costs)), costs, label=f'alpha = {alpha}')\n",
"ax.legend(loc='best')\n",
"ax.set_xlabel('krok')\n",
"ax.set_ylabel(r'$J(\\theta)$')\n",
"plt.show()"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3 (ipykernel)",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.10.6"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|