# Helper functions

def cost(theta, X, y):
    """Least-squares cost J(theta) in matrix form.

    Computes ``1 / (2m) * (X theta - y)^T (X theta - y)`` where ``m = len(y)``.

    Args:
        theta: Column vector of parameters, shape ``(k, 1)``.
        X: Design matrix, shape ``(m, k)``.
        y: Column vector of targets, shape ``(m, 1)``.

    Returns:
        The cost as a plain Python scalar.
    """
    m = len(y)
    # `@` is a true matrix product for both np.matrix and 2-D ndarray inputs,
    # whereas `*` would be element-wise for ndarrays.
    residual = X @ theta - y
    J = 1.0 / (2.0 * m) * (residual.T @ residual)
    return J.item()


def gradient(theta, X, y):
    """Gradient of the least-squares cost in matrix form.

    Args:
        theta: Column vector of parameters, shape ``(k, 1)``.
        X: Design matrix, shape ``(m, k)``.
        y: Column vector of targets, shape ``(m, 1)``.

    Returns:
        ``1 / m * X^T (X theta - y)`` as a ``(k, 1)`` column vector.
    """
    return 1.0 / len(y) * (X.T @ (X @ theta - y))


def gradient_descent(fJ, fdJ, theta, X, y, alpha=0.1, eps=10**-7, max_iter=None):
    """Batch gradient descent (matrix form).

    Args:
        fJ: Cost function called as ``fJ(theta, X, y)``; must return a scalar.
        fdJ: Gradient function called as ``fdJ(theta, X, y)``.
        theta: Initial parameter vector.
        X: Design matrix passed through to ``fJ`` / ``fdJ``.
        y: Target vector passed through to ``fJ`` / ``fdJ``.
        alpha: Learning rate (step size).
        eps: Stop once the absolute change of the cost between two
            consecutive steps is at most ``eps``.
        max_iter: Optional upper bound on the number of update steps.
            ``None`` (the default) means no bound.

    Returns:
        ``(theta, logs)`` where ``theta`` is the last parameter vector and
        ``logs`` is a list of ``[cost, theta]`` pairs: the starting point plus
        every step that did not terminate the loop.
    """
    divergence_threshold = 10**15
    current_cost = fJ(theta, X, y)
    logs = [[current_cost, theta]]
    step = 0
    while max_iter is None or step < max_iter:
        step += 1
        theta = theta - alpha * fdJ(theta, X, y)
        current_cost, prev_cost = fJ(theta, X, y), current_cost
        cost_change = abs(prev_cost - current_cost)
        # A NaN change compares False against every threshold, so without the
        # explicit finiteness test the loop would never terminate.
        if not np.isfinite(cost_change) or cost_change > divergence_threshold:
            print('Algorithm does not converge!')
            break
        if cost_change <= eps:
            break
        logs.append([current_cost, theta])
    return theta, logs


def plot_data(X, y, xlabel, ylabel):
    """Scatter plot of the data (matrix form).

    The second column of ``X`` (index 1) is used as the horizontal coordinate
    and ``y`` as the vertical coordinate.

    Args:
        X: Design matrix; column 1 holds the plotted feature.
        y: Target values, one per row of ``X``.
        xlabel: Label of the horizontal axis.
        ylabel: Label of the vertical axis.

    Returns:
        The created matplotlib figure.
    """
    fig = plt.figure(figsize=(16*.6, 9*.6))
    ax = fig.add_subplot(111)
    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)

    # Flatten to 1-D so np.matrix columns and plain ndarrays plot the same way.
    feature = np.asarray(X[:, 1]).ravel()
    target = np.asarray(y).ravel()
    ax.scatter(feature, target, c='r', s=50, label='Dane')

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.margins(.05, .05)
    # Set the limits on the axes object itself instead of the pyplot
    # "current axes" state.
    ax.set_ylim(target.min() - 1, target.max() + 1)
    ax.set_xlim(feature.min() - 1, feature.max() + 1)
    return fig
def plot_fun(fig, fun, X):
    """Draw the curve of ``fun`` on the first axes of ``fig``.

    The function is evaluated on a grid with step 0.1 spanning from
    ``min(X[:, 1]) - 1`` to ``max(X[:, 1]) + 1``.

    Args:
        fig: Matplotlib figure whose first axes receives the curve.
        fun: Callable accepting a 1-D array of arguments.
        X: Design matrix; column 1 determines the plotted range.

    Returns:
        The list of line objects returned by ``Axes.plot``.
    """
    ax = fig.axes[0]
    x0 = np.min(X[:, 1]) - 1.0
    x1 = np.max(X[:, 1]) + 1.0
    Arg = np.arange(x0, x1, 0.1)
    Val = fun(Arg)
    # linewidth is a numeric property; pass a number rather than a string.
    return ax.plot(Arg, Val, linewidth=2)


def MSE(Y_true, Y_pred):
    """Mean squared error between true and predicted values.

    Args:
        Y_true: Array-like of reference values.
        Y_pred: Array-like of predicted values.

    Returns:
        The mean of the squared element-wise differences.
    """
    y_true = np.asarray(Y_true)
    y_pred = np.asarray(Y_pred)
    if y_true.size == y_pred.size:
        # An (m, 1) column against an (m,) vector would broadcast to an
        # (m, m) table of all pairwise differences; flatten so the
        # comparison is element by element.
        y_true = y_true.ravel()
        y_pred = y_pred.ravel()
    return np.square(np.subtract(y_true, y_pred)).mean()


# Polynomial regression

def h_poly(Theta, x):
    """Evaluate the polynomial ``sum_i Theta[i] * x**i``.

    Args:
        Theta: Coefficients ordered from the constant term upwards; anything
            with a ``tolist()`` method (the callers in this file pass an
            ``(n + 1, 1)`` np.matrix).
        x: Argument(s) at which the polynomial is evaluated.

    Returns:
        The polynomial values for ``x``.
    """
    return sum(theta * np.power(x, i) for i, theta in enumerate(Theta.tolist()))


def get_poly_data(data, deg):
    """Build the polynomial design matrix and target vector from raw data.

    All columns except the last are features; the last column is the target.
    Every feature column is divided by its maximum, and for each power
    ``2..deg`` the powered columns are appended, each again divided by its
    own maximum.

    Args:
        data: 2-D array of shape ``(m, n + 1)``. It is not modified.
        deg: Polynomial degree.

    Returns:
        ``(X, y)`` where ``X`` is an np.matrix of shape ``(m, deg * n + 1)``
        whose first column is all ones, and ``y`` is an np.matrix of shape
        ``(m, 1)``.
    """
    data = np.asarray(data)
    m, n_plus_1 = data.shape
    n = n_plus_1 - 1

    # astype() returns a fresh float copy: the caller's array is left intact
    # and integer-typed input can be divided without a casting error.
    X1 = data[:, 0:n].astype(float)
    X1 /= np.amax(X1, axis=0)

    Xs = [np.ones((m, 1)), X1]

    for i in range(2, deg + 1):
        Xn = np.power(X1, i)
        Xn /= np.amax(Xn, axis=0)
        Xs.append(Xn)

    X = np.matrix(np.concatenate(Xs, axis=1)).reshape(m, deg * n + 1)

    y = np.matrix(data[:, -1]).reshape(m, 1)

    return X, y


def polynomial_regression(X, y, n):
    """Fit a polynomial of degree ``n`` with gradient descent.

    Args:
        X: Design matrix with ``n + 1`` columns.
        y: Target column vector.
        n: Polynomial degree; ``n + 1`` parameters are fitted, all starting
            at zero.

    Returns:
        A callable ``x -> h_poly(theta, x)`` using the fitted parameters.
    """
    theta_start = np.matrix(np.zeros((n + 1, 1)))
    theta, _ = gradient_descent(cost, gradient, theta_start, X, y)
    return lambda x: h_poly(theta, x)


def predict_values(model, data, n):
    """Run ``model`` on a data set and score it with the MSE.

    Args:
        model: Callable taking a list of feature values and returning the
            predictions.
        data: Array-like accepted by ``np.array`` with the target in the last
            column.
        n: Polynomial degree forwarded to ``get_poly_data``.

    Returns:
        ``(y, predictions, mse)``: the ``(m, 1)`` target matrix, the model
        output, and the mean squared error between them.
    """
    x, y = get_poly_data(np.array(data), n)
    # Column 1 of the design matrix is the normalised first feature.
    # NOTE(review): normalisation uses this data set's own maxima, not those
    # of the training data - confirm that is intended.
    preprocessed_x = np.asarray(x[:, 1]).ravel().tolist()
    predictions = model(preprocessed_x)  # evaluate the model only once
    # Compare as flat vectors so the column vector y does not broadcast
    # against the predictions.
    mse = MSE(np.asarray(y).ravel(), predictions)
    return y, predictions, mse


def plot_and_mse(data, data_test, n):
    """Fit a degree-``n`` polynomial on ``data``, plot it, print the test MSE.

    Args:
        data: Training data, target in the last column.
        data_test: Test data, target in the last column.
        n: Polynomial degree.
    """
    x, y = get_poly_data(np.array(data), n)
    model = polynomial_regression(x, y, n)

    fig = plot_data(x, y, xlabel='x', ylabel='y')
    # Reuse the fitted model instead of training a second, identical one.
    plot_fun(fig, model, x)

    y_true, Y_pred, mse = predict_values(model, data_test, n)
    print(f'Wielomian {n} stopnia, MSE = {mse}')
\n", " | sqrMetres | \n", "price | \n", "
---|---|---|
49 | \n", "37 | \n", "338000.00 | \n", "
1171 | \n", "90 | \n", "855000.00 | \n", "
368 | \n", "16 | \n", "399000.00 | \n", "
1206 | \n", "58 | \n", "359602.00 | \n", "
1500 | \n", "20 | \n", "424977.14 | \n", "
... | \n", "... | \n", "... | \n", "
50 | \n", "78 | \n", "420000.00 | \n", "
396 | \n", "52 | \n", "275000.00 | \n", "
1367 | \n", "55 | \n", "192750.00 | \n", "
771 | \n", "62 | \n", "558745.00 | \n", "
337 | \n", "55 | \n", "246330.00 | \n", "
1674 rows × 2 columns
\n", "\n", " | number_courses | \n", "time_study | \n", "Marks | \n", "
---|---|---|---|
0 | \n", "3 | \n", "4.508 | \n", "19.202 | \n", "
1 | \n", "4 | \n", "0.096 | \n", "7.734 | \n", "
2 | \n", "4 | \n", "3.133 | \n", "13.811 | \n", "
3 | \n", "6 | \n", "7.909 | \n", "53.018 | \n", "
4 | \n", "8 | \n", "7.811 | \n", "55.299 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "
95 | \n", "6 | \n", "3.561 | \n", "19.128 | \n", "
96 | \n", "3 | \n", "0.301 | \n", "5.609 | \n", "
97 | \n", "4 | \n", "7.163 | \n", "41.444 | \n", "
98 | \n", "7 | \n", "0.309 | \n", "12.027 | \n", "
99 | \n", "3 | \n", "6.335 | \n", "32.357 | \n", "
100 rows × 3 columns
\n", "\n", " | age | \n", "sex | \n", "bmi | \n", "children | \n", "smoker | \n", "region | \n", "charges | \n", "
---|---|---|---|---|---|---|---|
238 | \n", "19 | \n", "male | \n", "29.070 | \n", "0 | \n", "yes | \n", "northwest | \n", "17352.68030 | \n", "
809 | \n", "25 | \n", "male | \n", "25.840 | \n", "1 | \n", "no | \n", "northeast | \n", "3309.79260 | \n", "
1053 | \n", "47 | \n", "male | \n", "29.800 | \n", "3 | \n", "yes | \n", "southwest | \n", "25309.48900 | \n", "
1177 | \n", "40 | \n", "female | \n", "27.400 | \n", "1 | \n", "no | \n", "southwest | \n", "6496.88600 | \n", "
964 | \n", "52 | \n", "male | \n", "36.765 | \n", "2 | \n", "no | \n", "northwest | \n", "26467.09737 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
374 | \n", "20 | \n", "male | \n", "33.330 | \n", "0 | \n", "no | \n", "southeast | \n", "1391.52870 | \n", "
950 | \n", "57 | \n", "male | \n", "18.335 | \n", "0 | \n", "no | \n", "northeast | \n", "11534.87265 | \n", "
954 | \n", "34 | \n", "male | \n", "27.835 | \n", "1 | \n", "yes | \n", "northwest | \n", "20009.63365 | \n", "
521 | \n", "32 | \n", "female | \n", "44.220 | \n", "0 | \n", "no | \n", "southeast | \n", "3994.17780 | \n", "
963 | \n", "46 | \n", "male | \n", "24.795 | \n", "3 | \n", "no | \n", "northeast | \n", "9500.57305 | \n", "
1338 rows × 7 columns
\n", "