umz21/lab/07_scikit-learn.ipynb
2021-03-05 08:21:24 +01:00

3.4 KiB
Raw Blame History

Uczenie maszynowe 2019/2020 laboratoria

27/28 kwietnia 2020

7. Korzystanie z gotowych implementacji algorytmów na przykładzie pakietu _scikit-learn

Scikit-learn jest otwartoźródłową biblioteką programistyczną dla języka Python wspomagającą uczenie maszynowe. Zawiera implementacje wielu algorytmów uczenia maszynowego.

Poniżej przykład, jak stworzyć klasyfikator regresji liniowej wielu zmiennych z użyciem scikit-learn.

Na podobnej zasadzie można korzystać z innych modeli dostępnych w bibliotece.

#! /usr/bin/env python3
# -*- coding: utf-8 -*-

# Regresja liniowa wielu zmiennych

import csv
import numpy
import pandas
import sys

from sklearn import linear_model  # Model regresji liniowej z biblioteki scikit-learn


FEATURES = [
    'Powierzchnia w m2',
    'Liczba pokoi',
    'Liczba pięter w budynku',
    'Piętro',
    'Rok budowy',
]


def preprocess(data):
    """Wstępne przetworzenie danych"""
    data = data.replace({'parter': 0, 'poddasze': 0}, regex=True)
    data = data.applymap(numpy.nan_to_num)  # Zamienia "NaN" na liczby
    return data

# Nazwy plików
input_filename = 'flats-test.tsv'
output_filename = 'flats-predicted.tsv'
trainset_filename = 'flats-train.tsv'

# Wczytanie danych uczących
data = pandas.read_csv(trainset_filename, header=0, sep='\t')
columns = data.columns[1:]  # wszystkie kolumny oprócz pierwszej ("cena")
data = data[FEATURES + ['cena']]  # wybór cech
data = preprocess(data)  # wstępne przetworzenie danych
y = pandas.DataFrame(data['cena'])
x = pandas.DataFrame(data[FEATURES])
model = linear_model.LinearRegression()  # definicja modelu
model.fit(x, y)  # dopasowanie modelu

# Wczytanie danych testowych
data = pandas.read_csv(input_filename, header=None, sep='\t', names=columns)
x = pandas.DataFrame(data[FEATURES])  # wybór cech
x = preprocess(x)  # wstępne przetworzenie danych
y = model.predict(x)  # przewidywania modelu

# Zapis wyników do pliku
pandas.DataFrame(y).to_csv(output_filename, index=None, header=None, sep='\t')