umz21/lab/04_scikit-learn.ipynb
2022-03-24 12:05:36 +01:00

154 lines
4.7 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"## Uczenie maszynowe zastosowania\n",
"### Zajęcia laboratoryjne\n",
"# 4. Korzystanie z gotowych implementacji algorytmów na przykładzie pakietu *scikit-learn*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[Scikit-learn](https://scikit-learn.org) jest otwartoźródłową biblioteką programistyczną dla języka Python wspomagającą uczenie maszynowe. Zawiera implementacje wielu algorytmów uczenia maszynowego."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Poniżej przykład, jak stworzyć klasyfikator regresji liniowej wielu zmiennych z użyciem `scikit-learn`.\n",
"\n",
"Na podobnej zasadzie można korzystać z innych modeli dostępnych w bibliotece."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[302322.47270869]\n",
" [283694.74995925]\n",
" [276290.72977935]\n",
" [477362.89530745]\n",
" [420862.62245119]\n",
" [312510.3868097 ]\n",
" [362445.20969959]\n",
" [335753.83506582]\n",
" [759239.88142398]\n",
" [684376.72797254]]\n",
"Błąd średniokwadratowy wynosi 29811493540.217434\n"
]
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from sklearn.linear_model import LinearRegression # Model regresji liniowej z biblioteki scikit-learn\n",
"\n",
"from sklearn.metrics import mean_squared_error\n",
"\n",
"\n",
"FEATURES = [\n",
" 'Powierzchnia w m2',\n",
" 'Liczba pokoi',\n",
" 'Liczba pięter w budynku',\n",
" 'Piętro',\n",
" 'Rok budowy',\n",
" 'ładne w opisie'\n",
"]\n",
"\n",
"\n",
"def preprocess(data):\n",
" \"\"\"Wstępne przetworzenie danych, np. zamiana wartości tekstowych na liczby\"\"\"\n",
" data = data.replace({'parter': 0, 'poddasze': 0}, regex=True)\n",
" data = data.applymap(np.nan_to_num) # Zamienia \"NaN\" na liczby\n",
" return data\n",
"\n",
"# Nazwy plików\n",
"dataset_filename = 'flats.tsv'\n",
"\n",
"# Wczytanie danych\n",
"data = pd.read_csv(dataset_filename, header=0, sep='\\t')\n",
"\n",
"# Jeżeli chcemy, możemy stworzyć nową cechę (kolumnę) na podstawie istniejącej\n",
"# Poniższa cecha mówi, czy kolumna \"opis\" zawiera słowo \"ładne\"\n",
"data['ładne w opisie'] = data['opis'].apply(\n",
" lambda x: True if 'ładne' in str(x) else False)\n",
"\n",
"data = data[FEATURES + ['cena']] # wybór cech\n",
"data = data[(data[\"Powierzchnia w m2\"] < 10000) & (data[\"cena\"] > 1000)] # Odrzucenie obserwacji odstających\n",
"data = preprocess(data) # wstępne przetworzenie danych\n",
"\n",
"# Podział danych na zbiory uczący i testowy\n",
"split_point = int(0.8 * len(data))\n",
"data_train = data[:split_point]\n",
"data_test = data[split_point:]\n",
"\n",
"# Uczenie modelu\n",
"y_train = pd.DataFrame(data_train['cena'])\n",
"x_train = pd.DataFrame(data_train[FEATURES])\n",
"model = LinearRegression() # definicja modelu\n",
"model.fit(x_train, y_train) # dopasowanie modelu\n",
"\n",
"# Predykcja wyników dla danych testowych\n",
"y_expected = pd.DataFrame(data_test['cena'])\n",
"x_test = pd.DataFrame(data_test[FEATURES])\n",
"y_predicted = model.predict(x_test) # predykcja wyników na podstawie modelu\n",
"\n",
"print(y_predicted[:10]) # Pierwsze 10 wyników\n",
"\n",
"# Ewaluacja\n",
"mse = mean_squared_error(y_predicted, y_expected) # Błąd średniokwadratowy na zbiorze testowym\n",
"\n",
"print(\"Błąd średniokwadratowy wynosi \", mse)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Biblioteka *scikit-learn* dostarcza również narzędzi do wstępnego przetwarzania danych, np. skalowania i normalizacji: https://scikit-learn.org/stable/modules/preprocessing.html"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
},
"livereveal": {
"start_slideshow_at": "selected",
"theme": "amu"
}
},
"nbformat": 4,
"nbformat_minor": 4
}