import numpy as np

# Helper functions for the polynomial-regression notebook.

# Filled in by plot_and_mse(): maps a polynomial degree n to a list of
# [cost, iteration index] pairs taken from the gradient-descent log.
cost_functions = dict()


def cost(theta, X, y):
    """Least-squares cost in matrix form: J = 1/(2m) * (X theta - y)^T (X theta - y).

    Uses ``@`` so that both ``np.matrix`` inputs (what this notebook builds)
    and plain 2-D ndarrays are multiplied as matrices.

    Returns:
        The cost as a Python scalar (``.item()`` of the 1x1 product).
    """
    m = len(y)
    residual = X @ theta - y
    J = 1.0 / (2.0 * m) * (residual.T @ residual)
    return J.item()


def gradient(theta, X, y):
    """Gradient of ``cost`` with respect to theta: 1/m * X^T (X theta - y)."""
    return 1.0 / len(y) * (X.T @ (X @ theta - y))


def gradient_descent(fJ, fdJ, theta, X, y, alpha=0.1, eps=10**-7, max_iter=None):
    """Batch gradient descent (matrix version).

    Args:
        fJ: cost function, called as ``fJ(theta, X, y)``.
        fdJ: gradient function, called as ``fdJ(theta, X, y)``.
        theta: starting parameters.
        X, y: design matrix and targets, passed through to ``fJ`` / ``fdJ``.
        alpha: step size.
        eps: stop once the cost changes by at most this much between steps.
        max_iter: optional cap on the number of update steps; ``None``
            (default) keeps the original "run until a stop test fires" loop.

    Returns:
        ``(theta, logs)`` where ``logs`` is a list of ``[cost, theta]`` pairs:
        the starting point plus every step that did not trigger a stop test
        (so the final, converged step is returned but not logged).
    """
    current_cost = fJ(theta, X, y)
    logs = [[current_cost, theta]]
    iteration = 0
    while max_iter is None or iteration < max_iter:
        iteration += 1
        theta = theta - alpha * fdJ(theta, X, y)
        current_cost, prev_cost = fJ(theta, X, y), current_cost
        change = abs(prev_cost - current_cost)
        # A NaN cost makes every comparison below False, which used to spin
        # this loop forever; treat any non-finite cost as divergence.
        if not np.isfinite(current_cost) or change > 10**15:
            print('Algorithm does not converge!')
            break
        if change <= eps:
            break
        logs.append([current_cost, theta])
    return theta, logs


def _new_axes():
    """Create the 9.6 x 5.4 inch figure with a single axes used by the data plots."""
    # Imported here so the numeric helpers in this block stay usable in an
    # environment without a plotting backend.
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(16 * .6, 9 * .6))
    ax = fig.add_subplot(111)
    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)
    return fig, ax


def plot_data(X, y, xlabel, ylabel):
    """Scatter-plot ``y`` against column 1 of the design matrix ``X``.

    Column 0 is skipped (``get_poly_data`` puts the column of ones there).
    Axis limits are padded by 1 on every side. Returns the figure.
    """
    fig, ax = _new_axes()
    xs = np.asarray(X[:, 1]).ravel()
    ys = np.asarray(y).ravel()
    ax.scatter(xs, ys, c='r', s=50, label='Dane')

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.margins(.05, .05)
    ax.set_ylim(ys.min() - 1, ys.max() + 1)
    ax.set_xlim(xs.min() - 1, xs.max() + 1)
    return fig


def plot_data_cost(X, y, xlabel, ylabel):
    """Scatter-plot ``y`` against ``X`` as given (no column selection).

    Same layout as ``plot_data``; axis limits are padded by 1 on every side.
    NOTE(review): judging by the name this is meant for the cost histories
    stored in ``cost_functions`` - confirm against the calling cells.
    """
    fig, ax = _new_axes()
    xs = np.asarray(X).ravel()
    ys = np.asarray(y).ravel()
    ax.scatter(xs, ys, c='r', s=50, label='Dane')

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.margins(.05, .05)
    ax.set_ylim(ys.min() - 1, ys.max() + 1)
    ax.set_xlim(xs.min() - 1, xs.max() + 1)
    return fig


def plot_fun(fig, fun, X):
    """Draw ``fun`` on the first axes of ``fig``.

    ``fun`` is evaluated on a grid with step 0.1 spanning column 1 of ``X``
    padded by 1.0 on both sides. Returns the list of lines from ``ax.plot``.
    """
    ax = fig.axes[0]
    column = np.asarray(X[:, 1])
    grid = np.arange(column.min() - 1.0, column.max() + 1.0, 0.1)
    return ax.plot(grid, fun(grid), linewidth=2)


def _is_vector(arr):
    """True when ``arr`` has at most one axis longer than 1."""
    return sum(dim != 1 for dim in arr.shape) <= 1


def MSE(Y_true, Y_pred):
    """Mean squared error between targets and predictions.

    An (m, 1) column compared with an (m,) array would broadcast to an
    (m, m) grid of all pairwise differences; when both inputs are vectors of
    equal size but different shape they are flattened first so the
    comparison is element-wise. Other shapes keep normal NumPy broadcasting.
    """
    y_true = np.asarray(Y_true)
    y_pred = np.asarray(Y_pred)
    if (y_true.shape != y_pred.shape and y_true.size == y_pred.size
            and _is_vector(y_true) and _is_vector(y_pred)):
        y_true, y_pred = y_true.ravel(), y_pred.ravel()
    return np.square(y_true - y_pred).mean()


# Polynomial regression

def h_poly(Theta, x):
    """Evaluate the polynomial sum_i Theta[i] * x**i.

    ``Theta`` is read through ``np.asarray`` so each coefficient is a NumPy
    value (a length-1 row for the (k, 1) parameter matrix used here) and the
    product with ``np.power(x, i)`` is always arithmetic.
    """
    coefficients = np.asarray(Theta)
    return sum(coeff * np.power(x, i) for i, coeff in enumerate(coefficients))


def get_poly_data(data, deg):
    """Build the polynomial design matrix and target column from ``data``.

    Args:
        data: 2-D array-like; every column except the last is a feature,
            the last column is the target.
        deg: highest power to generate.

    Returns:
        ``(X, y)`` as ``np.matrix`` objects. ``X`` has shape
        ``(m, deg * n + 1)``: a column of ones, the features divided by their
        column maxima, then powers 2..deg of those scaled features (each
        again divided by its column maximum). ``y`` has shape ``(m, 1)``.

    The input is never modified: scaling is done out of place on a float
    array, so integer input works and the caller's array keeps its values.
    """
    data = np.asarray(data, dtype=float)
    m, n_plus_1 = data.shape
    n = n_plus_1 - 1

    features = data[:, 0:n]
    X1 = features / np.amax(features, axis=0)

    Xs = [np.ones((m, 1)), X1]

    for i in range(2, deg + 1):
        Xn = np.power(X1, i)
        Xs.append(Xn / np.amax(Xn, axis=0))

    # The reshape doubles as a check that the column count is what we expect.
    X = np.matrix(np.concatenate(Xs, axis=1)).reshape(m, deg * n + 1)

    y = np.matrix(data[:, -1]).reshape(m, 1)

    return X, y


def polynomial_regression(X, y, n):
    """Fit a degree-``n`` polynomial by gradient descent from an all-zero start.

    Returns:
        ``(model, logs)``: ``model(x)`` evaluates the fitted polynomial via
        ``h_poly``; ``logs`` is the history returned by ``gradient_descent``.
    """
    theta_start = np.matrix(np.zeros((n + 1, 1)))
    theta, logs = gradient_descent(cost, gradient, theta_start, X, y)

    def model(x):
        return h_poly(theta, x)

    return model, logs


def predict_values(model, data, n):
    """Run ``model`` on ``data`` and score it.

    Returns:
        ``(y, y_pred, mse)``: the (m, 1) target matrix from ``get_poly_data``,
        the predictions for the scaled first feature (column 1 of the design
        matrix), and their mean squared error.

    NOTE(review): ``get_poly_data`` rescales ``data`` by its own column
    maxima, so evaluation data is not scaled with the constants the model
    was fitted on - confirm this is intended.
    """
    x, y = get_poly_data(data, n)
    x_scaled = np.asarray(x[:, 1]).ravel()
    y_pred = model(x_scaled)  # evaluate once, reuse for the score
    return y, y_pred, MSE(y, y_pred)


def plot_and_mse(data, data_test, n):
    """Fit a degree-``n`` polynomial on ``data``, plot it, and print the MSE on ``data_test``.

    Side effects: stores the cost history under ``cost_functions[n]``,
    creates a matplotlib figure, and prints one line with the MSE.
    """
    x, y = get_poly_data(data, n)
    model, logs = polynomial_regression(x, y, n)
    cost_functions[n] = [[entry[0], step] for step, entry in enumerate(logs)]

    fig = plot_data(x, y, xlabel='x', ylabel='y')
    plot_fun(fig, model, x)

    _, _, mse = predict_values(model, data_test, n)
    print(f'Wielomian {n} stopnia, MSE = {mse}')
\n", " | sqrMetres | \n", "price | \n", "
---|---|---|
160 | \n", "44 | \n", "349668.0 | \n", "
1066 | \n", "54 | \n", "260000.0 | \n", "
679 | \n", "65 | \n", "348000.0 | \n", "
1589 | \n", "97 | \n", "579000.0 | \n", "
132 | \n", "60 | \n", "295120.0 | \n", "
... | \n", "... | \n", "... | \n", "
894 | \n", "68 | \n", "390000.0 | \n", "
937 | \n", "78 | \n", "329000.0 | \n", "
368 | \n", "16 | \n", "399000.0 | \n", "
1278 | \n", "51 | \n", "460499.0 | \n", "
1670 | \n", "53 | \n", "339000.0 | \n", "
1674 rows × 2 columns
\n", "\n", " | number_courses | \n", "time_study | \n", "Marks | \n", "
---|---|---|---|
0 | \n", "3 | \n", "4.508 | \n", "19.202 | \n", "
1 | \n", "4 | \n", "0.096 | \n", "7.734 | \n", "
2 | \n", "4 | \n", "3.133 | \n", "13.811 | \n", "
3 | \n", "6 | \n", "7.909 | \n", "53.018 | \n", "
4 | \n", "8 | \n", "7.811 | \n", "55.299 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "
95 | \n", "6 | \n", "3.561 | \n", "19.128 | \n", "
96 | \n", "3 | \n", "0.301 | \n", "5.609 | \n", "
97 | \n", "4 | \n", "7.163 | \n", "41.444 | \n", "
98 | \n", "7 | \n", "0.309 | \n", "12.027 | \n", "
99 | \n", "3 | \n", "6.335 | \n", "32.357 | \n", "
100 rows × 3 columns
\n", "\n", " | age | \n", "sex | \n", "bmi | \n", "children | \n", "smoker | \n", "region | \n", "charges | \n", "
---|---|---|---|---|---|---|---|
1235 | \n", "26 | \n", "male | \n", "31.065 | \n", "0 | \n", "no | \n", "northwest | \n", "2699.56835 | \n", "
234 | \n", "39 | \n", "male | \n", "24.510 | \n", "2 | \n", "no | \n", "northwest | \n", "6710.19190 | \n", "
540 | \n", "34 | \n", "female | \n", "38.000 | \n", "3 | \n", "no | \n", "southwest | \n", "6196.44800 | \n", "
1256 | \n", "51 | \n", "female | \n", "36.385 | \n", "3 | \n", "no | \n", "northwest | \n", "11436.73815 | \n", "
1257 | \n", "54 | \n", "female | \n", "27.645 | \n", "1 | \n", "no | \n", "northwest | \n", "11305.93455 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
752 | \n", "64 | \n", "male | \n", "37.905 | \n", "0 | \n", "no | \n", "northwest | \n", "14210.53595 | \n", "
954 | \n", "34 | \n", "male | \n", "27.835 | \n", "1 | \n", "yes | \n", "northwest | \n", "20009.63365 | \n", "
1137 | \n", "26 | \n", "female | \n", "22.230 | \n", "0 | \n", "no | \n", "northwest | \n", "3176.28770 | \n", "
106 | \n", "19 | \n", "female | \n", "28.400 | \n", "1 | \n", "no | \n", "southwest | \n", "2331.51900 | \n", "
274 | \n", "25 | \n", "male | \n", "27.550 | \n", "0 | \n", "no | \n", "northwest | \n", "2523.16950 | \n", "
1338 rows × 7 columns
\n", "