umz21/lab/04_scikit-learn.ipynb
2022-03-24 10:35:07 +01:00

4.2 KiB
Raw Blame History

Uczenie maszynowe zastosowania

Zajęcia laboratoryjne

4. Korzystanie z gotowych implementacji algorytmów na przykładzie pakietu _scikit-learn

Scikit-learn jest otwartoźródłową biblioteką programistyczną dla języka Python wspomagającą uczenie maszynowe. Zawiera implementacje wielu algorytmów uczenia maszynowego.

Poniżej przykład, jak stworzyć klasyfikator regresji liniowej wielu zmiennych z użyciem scikit-learn.

Na podobnej zasadzie można korzystać z innych modeli dostępnych w bibliotece.

import numpy as np
import pandas as pd

from sklearn.linear_model import LinearRegression  # Model regresji liniowej z biblioteki scikit-learn

from sklearn.metrics import mean_squared_error


FEATURES = [
    'Powierzchnia w m2',
    'Liczba pokoi',
    'Liczba pięter w budynku',
    'Piętro',
    'Rok budowy',
]


def preprocess(data):
    """Wstępne przetworzenie danych"""
    data = data.replace({'parter': 0, 'poddasze': 0}, regex=True)
    data = data.applymap(np.nan_to_num)  # Zamienia "NaN" na liczby
    return data

# Nazwy plików
dataset_filename = 'flats.tsv'

# Wczytanie danych
data = pd.read_csv(dataset_filename, header=0, sep='\t')
data = data[FEATURES + ['cena']]  # wybór cech
data = preprocess(data)  # wstępne przetworzenie danych

# Podział danych na zbiory uczący i testowy
split_point = int(0.8 * len(data))
data_train = data[:split_point]
data_test = data[split_point:]

# Uczenie modelu
y_train = pd.DataFrame(data_train['cena'])
x_train = pd.DataFrame(data_train[FEATURES])
model = LinearRegression()  # definicja modelu
model.fit(x_train, y_train)  # dopasowanie modelu

# Predykcja wyników dla danych testowych
y_expected = pd.DataFrame(data_test['cena'])
x_test = pd.DataFrame(data_test[FEATURES])
y_predicted = model.predict(x_test)  # predykcja wyników na podstawie modelu

print(y_predicted[:10])  # Pierwsze 10 wyników

# Ewaluacja
mse = mean_squared_error(y_predicted, y_expected)  # Błąd średniokwadratowy na zbiorze testowym

print("Błąd średniokwadratowy wynosi ", mse)
[[289411.43360715]
 [285930.72623304]
 [229893.92602325]
 [823267.1750005 ]
 [821038.18583152]
 [356875.19267371]
 [409340.86981766]
 [278401.700237  ]
 [301680.27997255]
 [281051.71865054]]
Błąd średniokwadratowy wynosi  39595039990.2324

Biblioteka _scikit-learn dostarcza również narzędzi do wstępnego przetwarzania danych, np. skalowania i normalizacji: https://scikit-learn.org/stable/modules/preprocessing.html