4.2 KiB
4.2 KiB
Uczenie maszynowe – zastosowania
Zajęcia laboratoryjne
4. Korzystanie z gotowych implementacji algorytmów na przykładzie pakietu _scikit-learn
Scikit-learn jest otwartoźródłową biblioteką programistyczną dla języka Python wspomagającą uczenie maszynowe. Zawiera implementacje wielu algorytmów uczenia maszynowego.
Poniżej przykład, jak stworzyć klasyfikator regresji liniowej wielu zmiennych z użyciem scikit-learn
.
Na podobnej zasadzie można korzystać z innych modeli dostępnych w bibliotece.
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression # Model regresji liniowej z biblioteki scikit-learn
from sklearn.metrics import mean_squared_error
FEATURES = [
'Powierzchnia w m2',
'Liczba pokoi',
'Liczba pięter w budynku',
'Piętro',
'Rok budowy',
]
def preprocess(data):
"""Wstępne przetworzenie danych"""
data = data.replace({'parter': 0, 'poddasze': 0}, regex=True)
data = data.applymap(np.nan_to_num) # Zamienia "NaN" na liczby
return data
# Nazwy plików
dataset_filename = 'flats.tsv'
# Wczytanie danych
data = pd.read_csv(dataset_filename, header=0, sep='\t')
data = data[FEATURES + ['cena']] # wybór cech
data = preprocess(data) # wstępne przetworzenie danych
# Podział danych na zbiory uczący i testowy
split_point = int(0.8 * len(data))
data_train = data[:split_point]
data_test = data[split_point:]
# Uczenie modelu
y_train = pd.DataFrame(data_train['cena'])
x_train = pd.DataFrame(data_train[FEATURES])
model = LinearRegression() # definicja modelu
model.fit(x_train, y_train) # dopasowanie modelu
# Predykcja wyników dla danych testowych
y_expected = pd.DataFrame(data_test['cena'])
x_test = pd.DataFrame(data_test[FEATURES])
y_predicted = model.predict(x_test) # predykcja wyników na podstawie modelu
print(y_predicted[:10]) # Pierwsze 10 wyników
# Ewaluacja
mse = mean_squared_error(y_predicted, y_expected) # Błąd średniokwadratowy na zbiorze testowym
print("Błąd średniokwadratowy wynosi ", mse)
[[289411.43360715] [285930.72623304] [229893.92602325] [823267.1750005 ] [821038.18583152] [356875.19267371] [409340.86981766] [278401.700237 ] [301680.27997255] [281051.71865054]] Błąd średniokwadratowy wynosi 39595039990.2324
Biblioteka _scikit-learn dostarcza również narzędzi do wstępnego przetwarzania danych, np. skalowania i normalizacji: https://scikit-learn.org/stable/modules/preprocessing.html