um-projekt/regression.ipynb
2021-06-15 12:56:28 +02:00


import pandas as pd
from pandas import DataFrame
from sklearn import preprocessing
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn import ensemble
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model

Project goal

The goal of the project is to build several models that predict the price of a car from features such as:
- production year
- mileage
- brand
- engine type
- engine capacity

Loading the data

The dataset is a list of cars together with their most important features. Dataset size: 47,930 rows × 5 columns.

col_names = ["price", "mileage", "year", "brand", "engine_type", "engine_cap"]
col_names_in = ["mileage", "year", "brand", "engine_type", "engine_cap"]
df_train = pd.read_csv(
    "train/train.tsv", error_bad_lines=False, header=None, sep="\t", names=col_names
)
df = df_train
test = pd.read_csv(
    "dev-0/in.tsv", error_bad_lines=False, header=None, sep="\t", names=col_names_in
)


test_expected = pd.read_csv("dev-0/expected.tsv", error_bad_lines=False, header=None, sep="\t")
df.head(10)
mileage year brand engine_type engine_cap
0 29.077465 1.0060 volvo benzyna 960.0
1 19.800027 1.0080 kia diesel 418.8
2 27.916642 1.0075 toyota diesel 420.0
3 32.976864 1.0075 skoda diesel 480.0
4 38.932205 1.0060 renault diesel 600.0
5 32.089045 1.0070 opel diesel 390.0
6 31.997055 1.0065 mercedes-benz diesel 900.0
7 26.138626 1.0075 ford benzyna 300.0
8 38.142843 1.0015 seat diesel 570.0
9 26.715448 1.0060 mercedes-benz diesel 642.9
df
mileage year brand engine_type engine_cap
0 29.077465 1.0060 volvo benzyna 960.0
1 19.800027 1.0080 kia diesel 418.8
2 27.916642 1.0075 toyota diesel 420.0
3 32.976864 1.0075 skoda diesel 480.0
4 38.932205 1.0060 renault diesel 600.0
... ... ... ... ... ...
47997 25.530471 1.0055 mini benzyna 479.4
47998 41.875698 1.0020 mercedes-benz diesel 644.4
47999 40.061463 1.0025 mercedes-benz diesel 506.7
48000 40.809827 1.0010 mercedes-benz diesel 644.4
48001 38.337702 1.0035 mercedes-benz diesel 896.1

47930 rows × 5 columns

Data preprocessing

1. Outliers

First, outliers were removed: cars priced below one thousand and cars with more than 900,000 km of mileage.

Y_test = test_expected[0]


# Keep the 35 most frequent brands; drop rows with anomalous values
brands = df.brand.value_counts()[:35].index.tolist()
indexes = df_train[(df_train.price < 1000) & (df_train.price > 1)].index
df_train.drop(indexes, inplace=True)

index = df_train[(df_train.mileage > 900000)].index
df_train.drop(index, inplace=True)

Y_train = df_train["price"]
df_train.drop("price", axis=1, inplace=True)


2. Normalization of numeric features

Numeric features such as year, mileage, and engine capacity were rescaled to significantly reduce their magnitude (simple power/linear transforms followed by a RobustScaler).

3. Lowercasing manufacturer names

Manufacturer names were converted to lowercase.

4. Creating 'dummies'

A 0/1 indicator column was created for each brand.

5. Degree-2 polynomial features

Generated with sklearn.preprocessing (PolynomialFeatures).

def preprocess_data(df: DataFrame, brands: list) -> DataFrame:
    """Prepare the dataset for the regression models."""

    # Collapse rare brands into a single "0" bucket, shrink numeric ranges
    df.brand = df.brand.apply(lambda x: x if x in brands else "0")
    df["year"] = df.year / 2000
    df["mileage"] = df.mileage ** 0.3
    df["engine_cap"] = df.engine_cap * 0.3
    df["brand"] = df["brand"].str.lower()

    # One-hot encode the categorical columns
    df = pd.get_dummies(df, columns=["brand", "engine_type"])

    scaler = preprocessing.RobustScaler()
    df[["mileage", "year", "engine_cap"]] = scaler.fit_transform(
        df[["mileage", "year", "engine_cap"]]
    )

    poly = PolynomialFeatures(2, interaction_only=True)
    df = poly.fit_transform(df)

    return df


X_train = preprocess_data(df_train, brands)
# Note: the scaler and PolynomialFeatures are re-fitted on the test set here;
# ideally they would be fitted on the training data only.
X_test = preprocess_data(test, brands)
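As a small illustration of what PolynomialFeatures(2, interaction_only=True) produces (a sketch on made-up data, independent of the project's dataset):

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures

# Two toy features: [a, b] = [2, 3]
X = np.array([[2.0, 3.0]])

# interaction_only=True keeps the bias term, the original features,
# and pairwise products, but drops pure powers such as a**2 and b**2
poly = PolynomialFeatures(2, interaction_only=True)
X_poly = poly.fit_transform(X)

print(X_poly)  # [[1. 2. 3. 6.]]  -> columns: 1, a, b, a*b
```

With more input columns the number of pairwise interaction terms grows quadratically, which is why the transformed matrices here are much wider than the original five features.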

Linear regression

Linear regression implemented with the sklearn library.

RMSE: 22065.84

MSE: 486901471.27

# Load model and fit data
lm_model = LinearRegression()
lm_model.fit(X_train, Y_train)

# Predict
lr_test_predicted = lm_model.predict(X_test)

# Evaluate predictions
print("RMSE: ", mean_squared_error(Y_test, lr_test_predicted, squared=False))
print("MSE: ", mean_squared_error(Y_test, lr_test_predicted))
RMSE:  22065.843996457512
MSE:  486901471.276
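As a quick arithmetic sanity check on the two numbers above: with squared=False, mean_squared_error returns the square root of the MSE, so the reported RMSE and MSE must agree.

```python
import math

mse = 486901471.276          # MSE reported above for linear regression
rmse = 22065.843996457512    # RMSE reported above

# squared=False simply returns sqrt(MSE)
assert math.isclose(math.sqrt(mse), rmse, rel_tol=1e-9)
print(math.sqrt(mse))
```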

Neural network (Keras)

batch size: 64, epochs: 100, 3 hidden ReLU layers, optimizer: adam, loss: mean squared error

RMSE: 18558.15

MSE: 344404977.04

input_layer = Input(shape=(X_train.shape[1],))
dense_layer_1 = Dense(100, activation='relu')(input_layer)
dense_layer_2 = Dense(50, activation='relu')(dense_layer_1)
dense_layer_3 = Dense(25, activation='relu')(dense_layer_2)
output = Dense(1)(dense_layer_3)

model = Model(inputs=input_layer, outputs=output)
model.compile(loss="mean_squared_error", optimizer="adam", metrics=["mean_squared_error"])


model.fit(X_train, Y_train, batch_size=64, epochs=100, verbose=1, validation_split=0.2)
y_pred = model.predict(X_test)

print(f"RMSE: {mean_squared_error(Y_test, y_pred, squared=False)}")
print(f"MSE: {mean_squared_error(Y_test, y_pred)}")
Epoch 1/100
600/600 [==============================] - 5s 7ms/step - loss: 2673221721.0250 - mean_squared_error: 2673221721.0250 - val_loss: 551377856.0000 - val_mean_squared_error: 551377856.0000
Epoch 2/100
600/600 [==============================] - 3s 6ms/step - loss: 442345591.4276 - mean_squared_error: 442345591.4276 - val_loss: 423978976.0000 - val_mean_squared_error: 423978976.0000
...
Epoch 99/100
600/600 [==============================] - 5s 8ms/step - loss: 259659058.3428 - mean_squared_error: 259659058.3428 - val_loss: 253970368.0000 - val_mean_squared_error: 253970368.0000
Epoch 100/100
600/600 [==============================] - 5s 8ms/step - loss: 251161536.1864 - mean_squared_error: 251161536.1864 - val_loss: 253605840.0000 - val_mean_squared_error: 253605840.0000
RMSE: 18558.15122927888
MSE: 344404977.04878515

Gradient Boosting Regressor

Regression using gradient boosting, with sklearn's GradientBoostingRegressor (an ensemble of decision trees fitted sequentially to the residuals).

RMSE: 19705.96

MSE: 388324934.97

"""
Gradient Boosting Regressor
"""

model = ensemble.GradientBoostingRegressor()
model.fit(X_train, Y_train)

gradient_predicted = model.predict(X_test)
print(f"RMSE: {mean_squared_error(Y_test, gradient_predicted, squared=False)}")
print(f"MSE: {mean_squared_error(Y_test, gradient_predicted)}")
RMSE: 19705.961914565338
MSE: 388324934.9782996
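The GradientBoostingRegressor above runs with default hyperparameters. A minimal sketch (on synthetic data, with illustrative parameter values that were not tuned for this project) of how n_estimators, learning_rate, and max_depth could be adjusted:

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error

rng = np.random.default_rng(0)

# Synthetic regression problem: y = 3*x0 - 2*x1 + noise
X = rng.normal(size=(500, 2))
y = 3 * X[:, 0] - 2 * X[:, 1] + rng.normal(scale=0.1, size=500)

# More, shallower trees with a smaller learning rate often generalize better
model = GradientBoostingRegressor(n_estimators=300, learning_rate=0.05, max_depth=2)
model.fit(X[:400], y[:400])

rmse = np.sqrt(mean_squared_error(y[400:], model.predict(X[400:])))
print(f"holdout RMSE: {rmse:.3f}")
```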

Summary

1. Neural network (RMSE ≈ 18,558)

2. Gradient Boosting Regressor (RMSE ≈ 19,706)

3. Linear regression (RMSE ≈ 22,066)

The best results were achieved by the neural network. Gradient boosting comes second and plain linear regression third: boosting fits each successive weak learner to the gradient of the loss, which points the model in the direction in which it should improve.