umz21/lab/04_scikit-learn.ipynb
2022-03-24 10:35:07 +01:00

146 lines
4.2 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"## Uczenie maszynowe zastosowania\n",
"### Zajęcia laboratoryjne\n",
"# 4. Korzystanie z gotowych implementacji algorytmów na przykładzie pakietu *scikit-learn*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[Scikit-learn](https://scikit-learn.org) jest otwartoźródłową biblioteką programistyczną dla języka Python wspomagającą uczenie maszynowe. Zawiera implementacje wielu algorytmów uczenia maszynowego."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Poniżej przykład, jak stworzyć klasyfikator regresji liniowej wielu zmiennych z użyciem `scikit-learn`.\n",
"\n",
"Na podobnej zasadzie można korzystać z innych modeli dostępnych w bibliotece."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[289411.43360715]\n",
" [285930.72623304]\n",
" [229893.92602325]\n",
" [823267.1750005 ]\n",
" [821038.18583152]\n",
" [356875.19267371]\n",
" [409340.86981766]\n",
" [278401.700237 ]\n",
" [301680.27997255]\n",
" [281051.71865054]]\n",
"Błąd średniokwadratowy wynosi 39595039990.2324\n"
]
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from sklearn.linear_model import LinearRegression # Model regresji liniowej z biblioteki scikit-learn\n",
"\n",
"from sklearn.metrics import mean_squared_error\n",
"\n",
"\n",
"FEATURES = [\n",
" 'Powierzchnia w m2',\n",
" 'Liczba pokoi',\n",
" 'Liczba pięter w budynku',\n",
" 'Piętro',\n",
" 'Rok budowy',\n",
"]\n",
"\n",
"\n",
"def preprocess(data):\n",
" \"\"\"Wstępne przetworzenie danych\"\"\"\n",
" data = data.replace({'parter': 0, 'poddasze': 0}, regex=True)\n",
" data = data.applymap(np.nan_to_num) # Zamienia \"NaN\" na liczby\n",
" return data\n",
"\n",
"# Nazwy plików\n",
"dataset_filename = 'flats.tsv'\n",
"\n",
"# Wczytanie danych\n",
"data = pd.read_csv(dataset_filename, header=0, sep='\\t')\n",
"data = data[FEATURES + ['cena']] # wybór cech\n",
"data = preprocess(data) # wstępne przetworzenie danych\n",
"\n",
"# Podział danych na zbiory uczący i testowy\n",
"split_point = int(0.8 * len(data))\n",
"data_train = data[:split_point]\n",
"data_test = data[split_point:]\n",
"\n",
"# Uczenie modelu\n",
"y_train = pd.DataFrame(data_train['cena'])\n",
"x_train = pd.DataFrame(data_train[FEATURES])\n",
"model = LinearRegression() # definicja modelu\n",
"model.fit(x_train, y_train) # dopasowanie modelu\n",
"\n",
"# Predykcja wyników dla danych testowych\n",
"y_expected = pd.DataFrame(data_test['cena'])\n",
"x_test = pd.DataFrame(data_test[FEATURES])\n",
"y_predicted = model.predict(x_test) # predykcja wyników na podstawie modelu\n",
"\n",
"print(y_predicted[:10]) # Pierwsze 10 wyników\n",
"\n",
"# Ewaluacja\n",
"mse = mean_squared_error(y_predicted, y_expected) # Błąd średniokwadratowy na zbiorze testowym\n",
"\n",
"print(\"Błąd średniokwadratowy wynosi \", mse)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Biblioteka *scikit-learn* dostarcza również narzędzi do wstępnego przetwarzania danych, np. skalowania i normalizacji: https://scikit-learn.org/stable/modules/preprocessing.html"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
},
"livereveal": {
"start_slideshow_at": "selected",
"theme": "amu"
}
},
"nbformat": 4,
"nbformat_minor": 4
}