{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Regresja wielomianowa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "slideshow": {
     "slide_type": "notes"
    }
   },
   "outputs": [],
   "source": [
    "import ipywidgets as widgets\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "slideshow": {
     "slide_type": "notes"
    }
   },
   "outputs": [],
   "source": [
    "# Helper functions\n",
    "\n",
    "def cost(theta, X, y):\n",
    "    \"\"\"Cost J(theta) = 1/(2m) * (X theta - y)^T (X theta - y), matrix version.\n",
    "\n",
    "    X, theta and y are expected to be np.matrix objects (as built by\n",
    "    get_poly_data), so `*` is matrix multiplication. Returns a Python float.\n",
    "    \"\"\"\n",
    "    m = len(y)\n",
    "    residual = X * theta - y\n",
    "    J = 1.0 / (2.0 * m) * (residual.T * residual)\n",
    "    return J.item()\n",
    "\n",
    "def gradient(theta, X, y):\n",
    "    \"\"\"Gradient of `cost` with respect to theta: 1/m * X^T (X theta - y).\"\"\"\n",
    "    return 1.0 / len(y) * (X.T * (X * theta - y))\n",
    "\n",
    "def gradient_descent(fJ, fdJ, theta, X, y, alpha=0.1, eps=10**-7, max_iter=None):\n",
    "    \"\"\"Batch gradient descent, matrix version.\n",
    "\n",
    "    Repeats theta <- theta - alpha * fdJ(theta, X, y) until the absolute change\n",
    "    of fJ(theta, X, y) between two consecutive steps is at most `eps`.\n",
    "\n",
    "    Args:\n",
    "        fJ: cost function, fJ(theta, X, y) -> float.\n",
    "        fdJ: gradient of fJ, returning a matrix shaped like theta.\n",
    "        theta: initial parameters.\n",
    "        X, y: data passed unchanged to fJ and fdJ.\n",
    "        alpha: learning rate.\n",
    "        eps: stopping threshold on the change of the cost.\n",
    "        max_iter: optional limit on the number of steps; None (default) means\n",
    "            iterate until convergence or divergence.\n",
    "\n",
    "    Returns:\n",
    "        (theta, logs): the final parameters and a list of [cost, theta] pairs\n",
    "        that starts with the initial state and, unless the run diverged, ends\n",
    "        with the returned theta.\n",
    "    \"\"\"\n",
    "    current_cost = fJ(theta, X, y)\n",
    "    logs = [[current_cost, theta]]\n",
    "    steps = 0\n",
    "    while max_iter is None or steps < max_iter:\n",
    "        steps += 1\n",
    "        theta = theta - alpha * fdJ(theta, X, y)\n",
    "        current_cost, prev_cost = fJ(theta, X, y), current_cost\n",
    "        # A non-finite cost or a huge jump in cost signals divergence\n",
    "        # (typically alpha is too large); a NaN cost would otherwise never\n",
    "        # satisfy either stopping test and the loop would not terminate.\n",
    "        if not np.isfinite(current_cost) or abs(prev_cost - current_cost) > 10**15:\n",
    "            print('Algorithm does not converge!')\n",
    "            break\n",
    "        logs.append([current_cost, theta])\n",
    "        if abs(prev_cost - current_cost) <= eps:\n",
    "            break\n",
    "    return theta, logs\n",
    "\n",
    "def plot_data(X, y, xlabel, ylabel):\n",
    "    \"\"\"Scatter plot of the data, matrix version.\n",
    "\n",
    "    Plots column 1 of X (column 0 is the constant column of ones built by\n",
    "    get_poly_data) against y and returns the new figure, so that curves can\n",
    "    be added with `plot_fun`.\n",
    "    \"\"\"\n",
    "    xs = np.asarray(X[:, 1]).ravel()\n",
    "    ys = np.asarray(y).ravel()\n",
    "    fig = plt.figure(figsize=(16*.6, 9*.6))\n",
    "    ax = fig.add_subplot(111)\n",
    "    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)\n",
    "    ax.scatter(xs, ys, c='r', s=50, label='Dane')\n",
    "\n",
    "    ax.set_xlabel(xlabel)\n",
    "    ax.set_ylabel(ylabel)\n",
    "    ax.margins(.05, .05)\n",
    "    # Set the limits on this figure's own axes instead of relying on\n",
    "    # pyplot's implicit current-axes state.\n",
    "    ax.set_ylim(ys.min() - 1, ys.max() + 1)\n",
    "    ax.set_xlim(xs.min() - 1, xs.max() + 1)\n",
    "    return fig\n",
    "\n",
    "def plot_fun(fig, fun, X, step=0.1):\n",
    "    \"\"\"Plot the function `fun` on the first axes of `fig`.\n",
    "\n",
    "    `fun` is sampled every `step` over the range of column 1 of X widened by\n",
    "    1 on each side (the same x-range that `plot_data` sets).\n",
    "    Returns the list of lines created by `ax.plot`.\n",
    "    \"\"\"\n",
    "    ax = fig.axes[0]\n",
    "    x0 = np.min(X[:, 1]) - 1.0\n",
    "    x1 = np.max(X[:, 1]) + 1.0\n",
    "    Arg = np.arange(x0, x1, step)\n",
    "    Val = fun(Arg)\n",
    "    return ax.plot(Arg, Val, linewidth=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def MSE(Y_true, Y_pred):\n",
    "    \"\"\"Mean squared error between true and predicted values.\n",
    "\n",
    "    Both inputs are flattened first, so an (m, 1) column such as an np.matrix\n",
    "    and a flat (m,) array are compared element-wise instead of being\n",
    "    broadcast against each other into an (m, m) matrix.\n",
    "    \"\"\"\n",
    "    y_true = np.asarray(Y_true).ravel()\n",
    "    y_pred = np.asarray(Y_pred).ravel()\n",
    "    return np.square(np.subtract(y_true, y_pred)).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "outputs": [],
   "source": [
    "# Polynomial regression functions\n",
    "\n",
    "def h_poly(Theta, x):\n",
    "    \"\"\"Polynomial hypothesis h(x) = sum_i Theta[i] * x**i.\n",
    "\n",
    "    Theta holds the coefficients in increasing order of power (e.g. the\n",
    "    (n + 1, 1) np.matrix found by gradient descent); x may be a scalar, list\n",
    "    or array and is evaluated element-wise.\n",
    "    \"\"\"\n",
    "    coefficients = np.asarray(Theta, dtype=float).ravel()\n",
    "    x = np.asarray(x, dtype=float)\n",
    "    return sum(theta * np.power(x, i) for i, theta in enumerate(coefficients))\n",
    "\n",
    "def get_poly_data(data, deg, reference=None):\n",
    "    \"\"\"Build polynomial features of degree `deg` from `data`.\n",
    "\n",
    "    The last column of `data` is the target y; the other n columns are the\n",
    "    features. Each feature is divided by its column maximum, and each power\n",
    "    2..deg of the scaled features is again divided by its column maximum.\n",
    "    `data` itself is never modified.\n",
    "\n",
    "    Args:\n",
    "        data: 2-D array-like, features first and target in the last column.\n",
    "        deg: polynomial degree (0 gives only the column of ones).\n",
    "        reference: optional 2-D array-like with the same columns (e.g. the\n",
    "            training set) whose maxima are used for the scaling, so that test\n",
    "            data is put on the same scale as the data a model was fitted on.\n",
    "            Defaults to `data` itself.\n",
    "\n",
    "    Returns:\n",
    "        (X, y): X is an np.matrix of shape (m, deg * n + 1) whose first column\n",
    "        is all ones; y is an np.matrix of shape (m, 1).\n",
    "    \"\"\"\n",
    "    data = np.asarray(data, dtype=float)\n",
    "    ref = data if reference is None else np.asarray(reference, dtype=float)\n",
    "    m, n_plus_1 = data.shape\n",
    "    n = n_plus_1 - 1\n",
    "\n",
    "    # Out-of-place division: the caller's array is left untouched and\n",
    "    # integer-typed input works (in-place `/=` on an int view would fail).\n",
    "    ref_max = np.amax(ref[:, 0:n], axis=0)\n",
    "    X1 = data[:, 0:n] / ref_max\n",
    "    ref_X1 = ref[:, 0:n] / ref_max\n",
    "\n",
    "    Xs = [np.ones((m, 1))]\n",
    "    if deg >= 1:\n",
    "        Xs.append(X1)\n",
    "    for i in range(2, deg + 1):\n",
    "        Xs.append(np.power(X1, i) / np.amax(np.power(ref_X1, i), axis=0))\n",
    "\n",
    "    X = np.matrix(np.concatenate(Xs, axis=1))\n",
    "    y = np.matrix(data[:, -1]).reshape(m, 1)\n",
    "    return X, y\n",
    "\n",
    "\n",
    "def polynomial_regression(X, y, n):\n",
    "    \"\"\"Fit a degree-n polynomial with gradient descent and return it as a function.\n",
    "\n",
    "    X is expected to have n + 1 columns, i.e. to come from\n",
    "    get_poly_data(data, n) for data with a single feature.\n",
    "    \"\"\"\n",
    "    theta_start = np.matrix(np.zeros((n + 1, 1)))\n",
    "    theta, _ = gradient_descent(cost, gradient, theta_start, X, y)\n",
    "    return lambda x: h_poly(theta, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_values(model, data, n, reference=None):\n",
    "    \"\"\"Evaluate `model` on `data` prepared as degree-n polynomial features.\n",
    "\n",
    "    Args:\n",
    "        model: callable returned by `polynomial_regression`.\n",
    "        data: array-like whose last column is the target.\n",
    "        n: polynomial degree.\n",
    "        reference: optional training data whose column maxima are used to\n",
    "            scale the features of `data` (see `get_poly_data`); by default\n",
    "            `data` is scaled by its own maxima.\n",
    "\n",
    "    Returns:\n",
    "        (y, y_pred, mse): true targets as an (m, 1) np.matrix, the model's\n",
    "        predictions, and their mean squared error.\n",
    "    \"\"\"\n",
    "    x, y = get_poly_data(np.array(data), n, reference)\n",
    "    # The model is a polynomial in the scaled first-order feature (column 1).\n",
    "    x_feature = np.asarray(x[:, 1]).ravel()\n",
    "    y_pred = model(x_feature)\n",
    "    return y, y_pred, MSE(y, y_pred)\n",
    "\n",
    "def plot_and_mse(data, data_test, n):\n",
    "    \"\"\"Fit a degree-n polynomial on `data`, plot it, and print the test MSE.\n",
    "\n",
    "    Test features are scaled with the training data's column maxima so the\n",
    "    model is evaluated on inputs on the same scale it was trained on.\n",
    "    \"\"\"\n",
    "    train = np.array(data)\n",
    "    x, y = get_poly_data(train, n)\n",
    "    model = polynomial_regression(x, y, n)\n",
    "\n",
    "    fig = plot_data(x, y, xlabel='x', ylabel='y')\n",
    "    # Reuse the fitted model instead of running gradient descent a second time.\n",
    "    plot_fun(fig, model, x)\n",
    "\n",
    "    _, _, mse = predict_values(model, data_test, n, reference=train)\n",
    "    print(f'Wielomian {n} stopnia, MSE = {mse}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "slideshow": {
     "slide_type": "notes"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [ "
\n", " | sqrMetres | \n", "price | \n", "
---|---|---|
470 | \n", "40 | \n", "1140000.0 | \n", "
1171 | \n", "90 | \n", "855000.0 | \n", "
1128 | \n", "37 | \n", "288405.0 | \n", "
254 | \n", "49 | \n", "290000.0 | \n", "
508 | \n", "91 | \n", "375606.0 | \n", "
... | \n", "... | \n", "... | \n", "
389 | \n", "56 | \n", "325000.0 | \n", "
1403 | \n", "69 | \n", "399000.0 | \n", "
957 | \n", "94 | \n", "595000.0 | \n", "
356 | \n", "53 | \n", "339200.0 | \n", "
160 | \n", "44 | \n", "349668.0 | \n", "
1674 rows × 2 columns
\n", "\n", " | number_courses | \n", "time_study | \n", "Marks | \n", "
---|---|---|---|
0 | \n", "3 | \n", "4.508 | \n", "19.202 | \n", "
1 | \n", "4 | \n", "0.096 | \n", "7.734 | \n", "
2 | \n", "4 | \n", "3.133 | \n", "13.811 | \n", "
3 | \n", "6 | \n", "7.909 | \n", "53.018 | \n", "
4 | \n", "8 | \n", "7.811 | \n", "55.299 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "
95 | \n", "6 | \n", "3.561 | \n", "19.128 | \n", "
96 | \n", "3 | \n", "0.301 | \n", "5.609 | \n", "
97 | \n", "4 | \n", "7.163 | \n", "41.444 | \n", "
98 | \n", "7 | \n", "0.309 | \n", "12.027 | \n", "
99 | \n", "3 | \n", "6.335 | \n", "32.357 | \n", "
100 rows × 3 columns
\n", "\n", " | age | \n", "sex | \n", "bmi | \n", "children | \n", "smoker | \n", "region | \n", "charges | \n", "
---|---|---|---|---|---|---|---|
955 | \n", "31 | \n", "male | \n", "39.490 | \n", "1 | \n", "no | \n", "southeast | \n", "3875.73410 | \n", "
644 | \n", "43 | \n", "male | \n", "35.310 | \n", "2 | \n", "no | \n", "southeast | \n", "18806.14547 | \n", "
1210 | \n", "36 | \n", "male | \n", "30.875 | \n", "1 | \n", "no | \n", "northwest | \n", "5373.36425 | \n", "
260 | \n", "58 | \n", "female | \n", "25.200 | \n", "0 | \n", "no | \n", "southwest | \n", "11837.16000 | \n", "
740 | \n", "45 | \n", "male | \n", "24.035 | \n", "2 | \n", "no | \n", "northeast | \n", "8604.48365 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
808 | \n", "18 | \n", "male | \n", "30.140 | \n", "0 | \n", "no | \n", "southeast | \n", "1131.50660 | \n", "
301 | \n", "53 | \n", "female | \n", "22.610 | \n", "3 | \n", "yes | \n", "northeast | \n", "24873.38490 | \n", "
664 | \n", "64 | \n", "female | \n", "22.990 | \n", "0 | \n", "yes | \n", "southeast | \n", "27037.91410 | \n", "
989 | \n", "24 | \n", "female | \n", "20.520 | \n", "0 | \n", "yes | \n", "northeast | \n", "14571.89080 | \n", "
1121 | \n", "46 | \n", "male | \n", "38.170 | \n", "2 | \n", "no | \n", "southeast | \n", "8347.16430 | \n", "
1338 rows × 7 columns
\n", "