umz21/lab/2007_scikit-learn.ipynb
2021-03-02 08:32:40 +01:00

119 lines
3.4 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"## Uczenie maszynowe 2019/2020 laboratoria\n",
"### 27/28 kwietnia 2020\n",
"# 7. Korzystanie z gotowych implementacji algorytmów na przykładzie pakietu *scikit-learn*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[Scikit-learn](https://scikit-learn.org) jest otwartoźródłową biblioteką programistyczną dla języka Python wspomagającą uczenie maszynowe. Zawiera implementacje wielu algorytmów uczenia maszynowego."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Poniżej przykład, jak stworzyć klasyfikator regresji liniowej wielu zmiennych z użyciem `scikit-learn`.\n",
"\n",
"Na podobnej zasadzie można korzystać z innych modeli dostępnych w bibliotece."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"#! /usr/bin/env python3\n",
"# -*- coding: utf-8 -*-\n",
"\n",
"# Regresja liniowa wielu zmiennych\n",
"\n",
"import csv\n",
"import numpy\n",
"import pandas\n",
"import sys\n",
"\n",
"from sklearn import linear_model # Model regresji liniowej z biblioteki scikit-learn\n",
"\n",
"\n",
"FEATURES = [\n",
" 'Powierzchnia w m2',\n",
" 'Liczba pokoi',\n",
" 'Liczba pięter w budynku',\n",
" 'Piętro',\n",
" 'Rok budowy',\n",
"]\n",
"\n",
"\n",
"def preprocess(data):\n",
" \"\"\"Wstępne przetworzenie danych\"\"\"\n",
" data = data.replace({'parter': 0, 'poddasze': 0}, regex=True)\n",
" data = data.applymap(numpy.nan_to_num) # Zamienia \"NaN\" na liczby\n",
" return data\n",
"\n",
"# Nazwy plików\n",
"input_filename = 'flats-test.tsv'\n",
"output_filename = 'flats-predicted.tsv'\n",
"trainset_filename = 'flats-train.tsv'\n",
"\n",
"# Wczytanie danych uczących\n",
"data = pandas.read_csv(trainset_filename, header=0, sep='\\t')\n",
"columns = data.columns[1:] # wszystkie kolumny oprócz pierwszej (\"cena\")\n",
"data = data[FEATURES + ['cena']] # wybór cech\n",
"data = preprocess(data) # wstępne przetworzenie danych\n",
"y = pandas.DataFrame(data['cena'])\n",
"x = pandas.DataFrame(data[FEATURES])\n",
"model = linear_model.LinearRegression() # definicja modelu\n",
"model.fit(x, y) # dopasowanie modelu\n",
"\n",
"# Wczytanie danych testowych\n",
"data = pandas.read_csv(input_filename, header=None, sep='\\t', names=columns)\n",
"x = pandas.DataFrame(data[FEATURES]) # wybór cech\n",
"x = preprocess(x) # wstępne przetworzenie danych\n",
"y = model.predict(x) # przewidywania modelu\n",
"\n",
"# Zapis wyników do pliku\n",
"pandas.DataFrame(y).to_csv(output_filename, index=None, header=None, sep='\\t')"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
},
"livereveal": {
"start_slideshow_at": "selected",
"theme": "amu"
}
},
"nbformat": 4,
"nbformat_minor": 4
}