{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Analiza danych w Pythonie: sklearn\n", "\n", "### Tomasz Dwojak\n", "\n", "### 3 czerwca 2018" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ " * Pierwsza część: pandas\n", " * Druga część: sklearn" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Przypomnienie z UMZ\n", " * przygotowanie i czyszczenie danych\n", " * wybór i trening modelu\n", " * tuning\n", " * ewaluacja" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "import sklearn\n", "import pandas as pd\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "data = pd.read_csv(\"./gapminder.csv\", index_col=0)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
female_BMImale_BMIgdppopulationunder5mortalitylife_expectancyfertility
Afghanistan21.0740220.620581311.026528741.0110.452.86.20
Albania25.6572626.446578644.02968026.017.976.81.76
Algeria26.3684124.5962012314.034811059.029.575.52.73
Angola23.4843122.250837103.019842251.0192.056.76.43
Antigua and Barbuda27.5054525.7660225736.085350.010.975.52.16
\n", "
" ], "text/plain": [ " female_BMI male_BMI gdp population \\\n", "Afghanistan 21.07402 20.62058 1311.0 26528741.0 \n", "Albania 25.65726 26.44657 8644.0 2968026.0 \n", "Algeria 26.36841 24.59620 12314.0 34811059.0 \n", "Angola 23.48431 22.25083 7103.0 19842251.0 \n", "Antigua and Barbuda 27.50545 25.76602 25736.0 85350.0 \n", "\n", " under5mortality life_expectancy fertility \n", "Afghanistan 110.4 52.8 6.20 \n", "Albania 17.9 76.8 1.76 \n", "Algeria 29.5 75.5 2.73 \n", "Angola 192.0 56.7 6.43 \n", "Antigua and Barbuda 10.9 75.5 2.16 " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.head()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "# Target: life expectancy; features: every remaining column\n", "y = data['life_expectancy']\n", "X = data.drop('life_expectancy', axis=1)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "# Hold out 20% of the rows for evaluation; fixed random_state keeps the split reproducible\n", "train_X, test_X, train_y, test_y = \\\n", "    train_test_split(X, y, test_size=0.2, random_state=123, shuffle=True)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "data": { "text/plain": [ "LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.linear_model import LinearRegression\n", "model = LinearRegression()\n", "# Fit on the training split only: fitting on the full X, y would expose the held-out\n", "# test rows to the model and make the test-set metrics over-optimistic.\n", "model.fit(train_X, train_y)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "data": { "text/plain": [ "array([67.56279809, 76.25840076, 50.21126326, 59.21303855, 72.06348723])" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Predict life expectancy for the held-out rows and preview the first five values\n", "predicted = model.predict(test_X)\n", "predicted[:5]" ] }, { "cell_type": "code", 
"execution_count": 18, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RMSE: 3.5179543848147863\n" ] } ], "source": [ "from sklearn.metrics import mean_squared_error\n", "# sklearn metrics take (y_true, y_pred) in that order\n", "rmse = np.sqrt(mean_squared_error(test_y, predicted))\n", "print(\"RMSE:\", rmse)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "data": { "text/plain": [ "0.795295000468209" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Score the model on the held-out test split; the trailing bare name displays the value\n", "r2 = model.score(test_X, test_y)\n", "r2" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "#### API\n", " * model\n", " * `fit`\n", " * `predict`" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "female_BMI: -1.18\n", "male_BMI: 1.46\n", "gdp: 5.11e-05\n", "population: 7.21e-10\n", "under5mortality: -0.159\n", "fertility: 0.421\n" ] } ], "source": [ "# Pair each feature column with its fitted coefficient (3 significant digits)\n", "for name, coef in zip(train_X.columns, model.coef_):\n", "    print(\"{}: {:.3}\".format(name, coef))" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/lib/python3.6/site-packages/ipykernel_launcher.py:2: FutureWarning: reshape is deprecated and will raise in a subsequent release. Please use .values.reshape(...) 
instead\n", " \n" ] }, { "data": { "text/plain": [ "LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False)" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model2 = LinearRegression()\n", "# Single-feature model. Series.reshape is deprecated, so reshape the underlying NumPy\n", "# values into the 2-D (n_samples, 1) array that fit() expects.\n", "model2.fit(train_X['male_BMI'].values.reshape(-1, 1), train_y)" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "data": { "text/plain": [ "0.5852413468462743" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model2.intercept_" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/lib/python3.6/site-packages/ipykernel_launcher.py:5: FutureWarning: reshape is deprecated and will raise in a subsequent release. Please use .values.reshape(...) instead\n", " \"\"\"\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from matplotlib import pyplot as plt\n", "%matplotlib inline\n", "\n", "# predict() needs a 2-D (n_samples, 1) array; reshape the NumPy values because\n", "# Series.reshape is deprecated.\n", "male_bmi_2d = train_X['male_BMI'].values.reshape(-1, 1)\n", "\n", "# Training points in green, fitted single-feature regression line in black\n", "plt.scatter(train_X['male_BMI'], train_y, color='g')\n", "plt.plot(train_X['male_BMI'], model2.predict(male_bmi_2d), color='k')\n", "plt.xlabel('male_BMI')\n", "plt.ylabel('life_expectancy')\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }