{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "## Uczenie maszynowe – zastosowania\n", "### Zajęcia laboratoryjne\n", "# 4. Korzystanie z gotowych implementacji algorytmów na przykładzie pakietu *scikit-learn*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Scikit-learn](https://scikit-learn.org) jest otwartoźródłową biblioteką programistyczną dla języka Python wspomagającą uczenie maszynowe. Zawiera implementacje wielu algorytmów uczenia maszynowego." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Poniżej przykład, jak stworzyć klasyfikator regresji liniowej wielu zmiennych z użyciem `scikit-learn`.\n", "\n", "Na podobnej zasadzie można korzystać z innych modeli dostępnych w bibliotece." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[302322.47270869]\n", " [283694.74995925]\n", " [276290.72977935]\n", " [477362.89530745]\n", " [420862.62245119]\n", " [312510.3868097 ]\n", " [362445.20969959]\n", " [335753.83506582]\n", " [759239.88142398]\n", " [684376.72797254]]\n", "Błąd średniokwadratowy wynosi 29811493540.217434\n" ] } ], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "from sklearn.linear_model import LinearRegression # Model regresji liniowej z biblioteki scikit-learn\n", "\n", "from sklearn.metrics import mean_squared_error\n", "\n", "\n", "FEATURES = [\n", " 'Powierzchnia w m2',\n", " 'Liczba pokoi',\n", " 'Liczba pięter w budynku',\n", " 'Piętro',\n", " 'Rok budowy',\n", " 'ładne w opisie'\n", "]\n", "\n", "\n", "def preprocess(data):\n", " \"\"\"Wstępne przetworzenie danych, np. zamiana wartości tekstowych na liczby\"\"\"\n", " data = data.replace({'parter': 0, 'poddasze': 0}, regex=True)\n", " data = data.applymap(np.nan_to_num) # Zamienia \"NaN\" na liczby\n", " return data\n", "\n", "# Nazwy plików\n", "dataset_filename = 'flats.tsv'\n", "\n", "# Wczytanie danych\n", "data = pd.read_csv(dataset_filename, header=0, sep='\\t')\n", "\n", "# Jeżeli chcemy, możemy stworzyć nową cechę (kolumnę) na podstawie istniejącej\n", "# Poniższa cecha mówi, czy kolumna \"opis\" zawiera słowo \"ładne\"\n", "data['ładne w opisie'] = data['opis'].apply(\n", " lambda x: True if 'ładne' in str(x) else False)\n", "\n", "data = data[FEATURES + ['cena']] # wybór cech\n", "data = data[(data[\"Powierzchnia w m2\"] < 10000) & (data[\"cena\"] > 1000)] # Odrzucenie obserwacji odstających\n", "data = preprocess(data) # wstępne przetworzenie danych\n", "\n", "# Podział danych na zbiory uczący i testowy\n", "split_point = int(0.8 * len(data))\n", "data_train = data[:split_point]\n", "data_test = data[split_point:]\n", "\n", "# Uczenie modelu\n", "y_train = pd.DataFrame(data_train['cena'])\n", "x_train = pd.DataFrame(data_train[FEATURES])\n", "model = LinearRegression() # definicja modelu\n", "model.fit(x_train, y_train) # dopasowanie modelu\n", "\n", "# Predykcja wyników dla danych testowych\n", "y_expected = pd.DataFrame(data_test['cena'])\n", "x_test = pd.DataFrame(data_test[FEATURES])\n", "y_predicted = model.predict(x_test) # predykcja wyników na podstawie modelu\n", "\n", "print(y_predicted[:10]) # Pierwsze 10 wyników\n", "\n", "# Ewaluacja\n", "mse = mean_squared_error(y_predicted, y_expected) # Błąd średniokwadratowy na zbiorze testowym\n", "\n", "print(\"Błąd średniokwadratowy wynosi \", mse)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Biblioteka *scikit-learn* dostarcza również narzędzi do wstępnego przetwarzania danych, np. skalowania i normalizacji: https://scikit-learn.org/stable/modules/preprocessing.html" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" }, "livereveal": { "start_slideshow_at": "selected", "theme": "amu" } }, "nbformat": 4, "nbformat_minor": 4 }