40 KiB
import pandas as pd
from pandas import DataFrame
from sklearn import preprocessing
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn import ensemble
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
Cel projektu
Celem projektu jest stworzenie różnych modeli, których zadanie polega na predykcji cen poszczególnych samochodów na podstawie danych takich jak: - rok produkcji - przebieg - marka - rodzaj silnika - pojemność silnika
Wczytywanie danych
Zbiór zawiera listę samochodów, wraz z ich najważniejszymi cechami. Rozmiar zbioru: 47930 wierszy × 5 kolumn
# Column layouts: training rows carry the target price, dev-set rows do not.
col_names = ["price", "mileage", "year", "brand", "engine_type", "engine_cap"]
col_names_in = ["mileage", "year", "brand", "engine_type", "engine_cap"]


def _read_tsv(path, names=None):
    """Read a headerless tab-separated file, skipping malformed rows."""
    return pd.read_csv(path, error_bad_lines=False, header=None, sep="\t", names=names)


df_train = _read_tsv("train/train.tsv", names=col_names)
df = df_train  # alias: in-place edits on df_train are visible through df
test = _read_tsv("dev-0/in.tsv", names=col_names_in)
test_expected = _read_tsv("dev-0/expected.tsv")
df.head(10)
mileage | year | brand | engine_type | engine_cap | |
---|---|---|---|---|---|
0 | 29.077465 | 1.0060 | volvo | benzyna | 960.0 |
1 | 19.800027 | 1.0080 | kia | diesel | 418.8 |
2 | 27.916642 | 1.0075 | toyota | diesel | 420.0 |
3 | 32.976864 | 1.0075 | skoda | diesel | 480.0 |
4 | 38.932205 | 1.0060 | renault | diesel | 600.0 |
5 | 32.089045 | 1.0070 | opel | diesel | 390.0 |
6 | 31.997055 | 1.0065 | mercedes-benz | diesel | 900.0 |
7 | 26.138626 | 1.0075 | ford | benzyna | 300.0 |
8 | 38.142843 | 1.0015 | seat | diesel | 570.0 |
9 | 26.715448 | 1.0060 | mercedes-benz | diesel | 642.9 |
df
mileage | year | brand | engine_type | engine_cap | |
---|---|---|---|---|---|
0 | 29.077465 | 1.0060 | volvo | benzyna | 960.0 |
1 | 19.800027 | 1.0080 | kia | diesel | 418.8 |
2 | 27.916642 | 1.0075 | toyota | diesel | 420.0 |
3 | 32.976864 | 1.0075 | skoda | diesel | 480.0 |
4 | 38.932205 | 1.0060 | renault | diesel | 600.0 |
... | ... | ... | ... | ... | ... |
47997 | 25.530471 | 1.0055 | mini | benzyna | 479.4 |
47998 | 41.875698 | 1.0020 | mercedes-benz | diesel | 644.4 |
47999 | 40.061463 | 1.0025 | mercedes-benz | diesel | 506.7 |
48000 | 40.809827 | 1.0010 | mercedes-benz | diesel | 644.4 |
48001 | 38.337702 | 1.0035 | mercedes-benz | diesel | 896.1 |
47930 rows × 5 columns
Preprocessing danych
1. Dane odstające
Na początku zostały usunięte dane odstające, takie jak auta, których cena jest poniżej tysiąca, lub których przebieg jest wyższy niż 900 000 km.
Y_test = test_expected[0]

# Whitelist of the 35 most frequent manufacturers; everything else is
# folded into a catch-all category during preprocessing.
brands = df.brand.value_counts()[:35].index.tolist()

# Outlier removal (in place, so the df alias stays in sync):
# implausibly cheap cars (1 < price < 1000) and cars with an extreme
# odometer reading (> 900 000 km) are dropped.
price_outliers = (df_train.price < 1000) & (df_train.price > 1)
mileage_outliers = df_train.mileage > 900000
df_train.drop(df_train[price_outliers | mileage_outliers].index, inplace=True)

# Split off the regression target; the remaining columns are the features.
Y_train = df_train["price"]
df_train.drop("price", axis=1, inplace=True)
2. Normalizacja danych liczbowych
Dane takie jak rok, przebieg czy pojemność silnika zostały znacząco zredukowane
3. Lowercase nazw producentów
Nazwy producentów zostały zapisane wyłącznie małymi literami
4. Utworzenie 'dummies'
Zostały utworzone kolumny dla każdej z marek przyjmujące wartość (0,1)
5. Utworzenie wielomianu stopnia 2
Z wykorzystaniem biblioteki sklearn.preprocessing
def preprocess_data(df: DataFrame, brands: list):
    """Turn the raw car table into a numeric feature matrix.

    Steps: fold manufacturers outside *brands* into a "0" bucket, rescale
    the numeric columns, one-hot-encode brand and engine type, robust-scale
    the numeric columns, and expand with degree-2 interaction terms.

    Note: mutates *df* in place before the dummy expansion.
    Returns a numpy array — PolynomialFeatures drops the DataFrame type,
    so the original ``-> DataFrame`` annotation was incorrect.
    """
    # Rare brands collapse into a single "0" category.
    df.brand = df.brand.apply(lambda x: x if x in brands else "0")
    # Crude manual rescaling of the numeric columns.
    df["year"] = df.year / 2000
    df["mileage"] = df.mileage ** 0.3
    df["engine_cap"] = df.engine_cap * 0.3
    df["brand"] = df["brand"].str.lower()
    # NOTE(review): the dummy columns depend on the categories present in
    # *df*; if train and test contain different brand/engine sets the two
    # matrices end up misaligned — consider reindexing to a fixed column
    # list shared between the calls.
    df = pd.get_dummies(df, columns=["brand", "engine_type"])
    # Bug fix: "year" was listed twice in the scaled-column selection,
    # redundantly fitting and assigning the same column a second time.
    scaler = preprocessing.RobustScaler()
    df[["mileage", "year", "engine_cap"]] = scaler.fit_transform(
        df[["mileage", "year", "engine_cap"]]
    )
    # Degree-2 interaction features only (no squared terms), plus bias column.
    poly = PolynomialFeatures(2, interaction_only=True)
    df = poly.fit_transform(df)
    return df
# Build the design matrices for the train and dev sets using the same
# brand whitelist (see NOTE in preprocess_data about column alignment).
X_train = preprocess_data(df_train, brands)
X_test = preprocess_data(test, brands)
Regresja liniowa
Implementacja regresji liniowej za pomocą biblioteki sklearn
RMSE: 22065.84
MSE: 486901471.27
# Ordinary-least-squares baseline, fit on the full training matrix.
lm_model = LinearRegression().fit(X_train, Y_train)

# Score the dev-set predictions against the expected prices.
lr_test_predicted = lm_model.predict(X_test)
print("RMSE: ", mean_squared_error(Y_test, lr_test_predicted, squared=False))
print("MSE: ", mean_squared_error(Y_test, lr_test_predicted))
RMSE: 22065.843996457512 MSE: 486901471.276
Sieć neuronowa (Keras)
batch size: 64, epochs: 100, 3 x ReLU optimizer: adam, loss: mean squared error
RMSE: 18558.15
MSE: 344404977.04
# Feed-forward regression network: 100 -> 50 -> 25 ReLU units, linear output.
# Fix: the original wrote shape=(X_train.shape[1]) — the missing comma makes
# it a plain int rather than the documented 1-tuple form.
input_layer = Input(shape=(X_train.shape[1],))
dense_layer_1 = Dense(100, activation='relu')(input_layer)
dense_layer_2 = Dense(50, activation='relu')(dense_layer_1)
dense_layer_3 = Dense(25, activation='relu')(dense_layer_2)
# Single linear unit: the predicted price.
output = Dense(1)(dense_layer_3)

model = Model(inputs=input_layer, outputs=output)
model.compile(loss="mean_squared_error", optimizer="adam", metrics=["mean_squared_error"])
# 20% of the training rows are held out for per-epoch validation.
model.fit(X_train, Y_train, batch_size=64, epochs=100, verbose=1, validation_split=0.2)

y_pred = model.predict(X_test)
print(f"RMSE: {mean_squared_error(Y_test, y_pred, squared=False)}")
print(f"MSE: {mean_squared_error(Y_test, y_pred)}")
Epoch 1/100 600/600 [==============================] - 5s 7ms/step - loss: 2673221721.0250 - mean_squared_error: 2673221721.0250 - val_loss: 551377856.0000 - val_mean_squared_error: 551377856.0000 Epoch 2/100 600/600 [==============================] - 3s 6ms/step - loss: 442345591.4276 - mean_squared_error: 442345591.4276 - val_loss: 423978976.0000 - val_mean_squared_error: 423978976.0000 Epoch 3/100 600/600 [==============================] - 3s 6ms/step - loss: 378405340.5923 - mean_squared_error: 378405340.5923 - val_loss: 378884192.0000 - val_mean_squared_error: 378884192.0000 Epoch 4/100 600/600 [==============================] - 4s 7ms/step - loss: 384769043.7271 - mean_squared_error: 384769043.7271 - val_loss: 357286848.0000 - val_mean_squared_error: 357286848.0000 Epoch 5/100 600/600 [==============================] - 4s 6ms/step - loss: 343957554.7155 - mean_squared_error: 343957554.7155 - val_loss: 340152512.0000 - val_mean_squared_error: 340152512.0000 Epoch 6/100 600/600 [==============================] - 4s 7ms/step - loss: 337887713.4376 - mean_squared_error: 337887713.4376 - val_loss: 331659872.0000 - val_mean_squared_error: 331659872.0000 Epoch 7/100 600/600 [==============================] - 4s 6ms/step - loss: 313163787.1281 - mean_squared_error: 313163787.1281 - val_loss: 323213504.0000 - val_mean_squared_error: 323213504.0000 Epoch 8/100 600/600 [==============================] - 4s 6ms/step - loss: 314339570.3428 - mean_squared_error: 314339570.3428 - val_loss: 317251488.0000 - val_mean_squared_error: 317251488.0000 Epoch 9/100 600/600 [==============================] - 5s 9ms/step - loss: 303178196.5524 - mean_squared_error: 303178196.5524 - val_loss: 314496736.0000 - val_mean_squared_error: 314496736.0000 Epoch 10/100 600/600 [==============================] - 5s 8ms/step - loss: 313794526.9351 - mean_squared_error: 313794526.9351 - val_loss: 310654176.0000 - val_mean_squared_error: 310654176.0000 Epoch 11/100 600/600 
[==============================] - 4s 7ms/step - loss: 284679367.4542 - mean_squared_error: 284679367.4542 - val_loss: 304685248.0000 - val_mean_squared_error: 304685248.0000 Epoch 12/100 600/600 [==============================] - 4s 7ms/step - loss: 311546194.3161 - mean_squared_error: 311546194.3161 - val_loss: 304376256.0000 - val_mean_squared_error: 304376256.0000 Epoch 13/100 600/600 [==============================] - 4s 7ms/step - loss: 286383306.4892 - mean_squared_error: 286383306.4892 - val_loss: 303079392.0000 - val_mean_squared_error: 303079392.0000 Epoch 14/100 600/600 [==============================] - 4s 7ms/step - loss: 312419505.4110 - mean_squared_error: 312419505.4110 - val_loss: 296362720.0000 - val_mean_squared_error: 296362720.0000 Epoch 15/100 600/600 [==============================] - 4s 7ms/step - loss: 281224970.5957 - mean_squared_error: 281224970.5957 - val_loss: 295127040.0000 - val_mean_squared_error: 295127040.0000 Epoch 16/100 600/600 [==============================] - 5s 8ms/step - loss: 300456786.4759 - mean_squared_error: 300456786.4759 - val_loss: 291579264.0000 - val_mean_squared_error: 291579264.0000 Epoch 17/100 600/600 [==============================] - 4s 6ms/step - loss: 271273312.1864 - mean_squared_error: 271273312.1864 - val_loss: 293092064.0000 - val_mean_squared_error: 293092064.0000 Epoch 18/100 600/600 [==============================] - 4s 6ms/step - loss: 274466717.5241 - mean_squared_error: 274466717.5241 - val_loss: 291955424.0000 - val_mean_squared_error: 291955424.0000 Epoch 19/100 600/600 [==============================] - 4s 6ms/step - loss: 280078536.3328 - mean_squared_error: 280078536.3328 - val_loss: 286574528.0000 - val_mean_squared_error: 286574528.0000 Epoch 20/100 600/600 [==============================] - 4s 7ms/step - loss: 283455574.6290 - mean_squared_error: 283455574.6290 - val_loss: 283341472.0000 - val_mean_squared_error: 283341472.0000 Epoch 21/100 600/600 [==============================] - 4s 
6ms/step - loss: 269008367.3078 - mean_squared_error: 269008367.3078 - val_loss: 287479776.0000 - val_mean_squared_error: 287479776.0000 Epoch 22/100 600/600 [==============================] - 4s 6ms/step - loss: 285307228.4326 - mean_squared_error: 285307228.4326 - val_loss: 281901632.0000 - val_mean_squared_error: 281901632.0000 Epoch 23/100 600/600 [==============================] - 4s 6ms/step - loss: 270041985.1448 - mean_squared_error: 270041985.1448 - val_loss: 285430688.0000 - val_mean_squared_error: 285430688.0000 Epoch 24/100 600/600 [==============================] - 4s 7ms/step - loss: 287381889.8902 - mean_squared_error: 287381889.8902 - val_loss: 283002208.0000 - val_mean_squared_error: 283002208.0000 Epoch 25/100 600/600 [==============================] - 4s 7ms/step - loss: 290092397.6839 - mean_squared_error: 290092397.6839 - val_loss: 281590592.0000 - val_mean_squared_error: 281590592.0000 Epoch 26/100 600/600 [==============================] - 4s 7ms/step - loss: 287577090.5291 - mean_squared_error: 287577090.5291 - val_loss: 277902464.0000 - val_mean_squared_error: 277902464.0000 Epoch 27/100 600/600 [==============================] - 4s 7ms/step - loss: 272385114.1165 - mean_squared_error: 272385114.1165 - val_loss: 280177056.0000 - val_mean_squared_error: 280177056.0000 Epoch 28/100 600/600 [==============================] - 4s 7ms/step - loss: 257438328.8785 - mean_squared_error: 257438328.8785 - val_loss: 284091104.0000 - val_mean_squared_error: 284091104.0000 Epoch 29/100 600/600 [==============================] - 3s 6ms/step - loss: 276722888.6256 - mean_squared_error: 276722888.6256 - val_loss: 277816032.0000 - val_mean_squared_error: 277816032.0000 Epoch 30/100 600/600 [==============================] - 3s 6ms/step - loss: 271698972.8586 - mean_squared_error: 271698972.8586 - val_loss: 281744256.0000 - val_mean_squared_error: 281744256.0000 Epoch 31/100 600/600 [==============================] - 4s 7ms/step - loss: 277800460.3261 - 
mean_squared_error: 277800460.3261 - val_loss: 275767552.0000 - val_mean_squared_error: 275767552.0000 Epoch 32/100 600/600 [==============================] - 5s 8ms/step - loss: 255387858.4093 - mean_squared_error: 255387858.4093 - val_loss: 274004512.0000 - val_mean_squared_error: 274004512.0000 Epoch 33/100 600/600 [==============================] - 4s 7ms/step - loss: 267013800.2263 - mean_squared_error: 267013800.2263 - val_loss: 274496832.0000 - val_mean_squared_error: 274496832.0000 Epoch 34/100 600/600 [==============================] - 4s 6ms/step - loss: 270699375.3611 - mean_squared_error: 270699375.3611 - val_loss: 276478944.0000 - val_mean_squared_error: 276478944.0000 Epoch 35/100 600/600 [==============================] - 4s 6ms/step - loss: 279694474.7288 - mean_squared_error: 279694474.7288 - val_loss: 279028160.0000 - val_mean_squared_error: 279028160.0000 Epoch 36/100 600/600 [==============================] - 4s 7ms/step - loss: 270710719.1747 - mean_squared_error: 270710719.1747 - val_loss: 273949600.0000 - val_mean_squared_error: 273949600.0000 Epoch 37/100 600/600 [==============================] - 5s 8ms/step - loss: 272804902.7354 - mean_squared_error: 272804902.7354 - val_loss: 274979104.0000 - val_mean_squared_error: 274979104.0000 Epoch 38/100 600/600 [==============================] - 5s 8ms/step - loss: 254984751.6805 - mean_squared_error: 254984751.6805 - val_loss: 278099008.0000 - val_mean_squared_error: 278099008.0000 Epoch 39/100 600/600 [==============================] - 4s 7ms/step - loss: 263644632.1597 - mean_squared_error: 263644632.1597 - val_loss: 275570400.0000 - val_mean_squared_error: 275570400.0000 Epoch 40/100 600/600 [==============================] - 4s 7ms/step - loss: 283981970.5824 - mean_squared_error: 283981970.5824 - val_loss: 269600896.0000 - val_mean_squared_error: 269600896.0000 Epoch 41/100 600/600 [==============================] - 4s 6ms/step - loss: 263011782.5225 - mean_squared_error: 263011782.5225 - 
val_loss: 270043744.0000 - val_mean_squared_error: 270043744.0000 Epoch 42/100 600/600 [==============================] - 4s 6ms/step - loss: 275432014.8286 - mean_squared_error: 275432014.8286 - val_loss: 268776480.0000 - val_mean_squared_error: 268776480.0000 Epoch 43/100 600/600 [==============================] - 3s 6ms/step - loss: 260651440.1864 - mean_squared_error: 260651440.1864 - val_loss: 275194144.0000 - val_mean_squared_error: 275194144.0000 Epoch 44/100 600/600 [==============================] - 4s 7ms/step - loss: 257748764.5125 - mean_squared_error: 257748764.5125 - val_loss: 270911072.0000 - val_mean_squared_error: 270911072.0000 Epoch 45/100 600/600 [==============================] - 3s 6ms/step - loss: 266450056.8918 - mean_squared_error: 266450056.8918 - val_loss: 270361472.0000 - val_mean_squared_error: 270361472.0000 Epoch 46/100 600/600 [==============================] - 3s 5ms/step - loss: 267280017.8369 - mean_squared_error: 267280017.8369 - val_loss: 268170224.0000 - val_mean_squared_error: 268170224.0000 Epoch 47/100 600/600 [==============================] - 4s 6ms/step - loss: 270953393.1048 - mean_squared_error: 270953393.1048 - val_loss: 266962048.0000 - val_mean_squared_error: 266962048.0000 Epoch 48/100 600/600 [==============================] - 5s 8ms/step - loss: 261569597.8436 - mean_squared_error: 261569597.8436 - val_loss: 270642752.0000 - val_mean_squared_error: 270642752.0000 Epoch 49/100 600/600 [==============================] - 5s 9ms/step - loss: 252863808.0799 - mean_squared_error: 252863808.0799 - val_loss: 264875584.0000 - val_mean_squared_error: 264875584.0000 Epoch 50/100 600/600 [==============================] - 4s 7ms/step - loss: 269732835.3677 - mean_squared_error: 269732835.3677 - val_loss: 265078368.0000 - val_mean_squared_error: 265078368.0000 Epoch 51/100 600/600 [==============================] - 4s 7ms/step - loss: 277777046.5225 - mean_squared_error: 277777046.5225 - val_loss: 265569424.0000 - 
val_mean_squared_error: 265569424.0000 Epoch 52/100 600/600 [==============================] - 4s 7ms/step - loss: 259421935.3611 - mean_squared_error: 259421935.3611 - val_loss: 263121728.0000 - val_mean_squared_error: 263121728.0000 Epoch 53/100 600/600 [==============================] - 4s 7ms/step - loss: 246818920.4126 - mean_squared_error: 246818920.4126 - val_loss: 268283376.0000 - val_mean_squared_error: 268283376.0000 Epoch 54/100 600/600 [==============================] - 4s 7ms/step - loss: 262059519.1747 - mean_squared_error: 262059519.1747 - val_loss: 264587712.0000 - val_mean_squared_error: 264587712.0000 Epoch 55/100 600/600 [==============================] - 5s 8ms/step - loss: 251146320.5857 - mean_squared_error: 251146320.5857 - val_loss: 264188048.0000 - val_mean_squared_error: 264188048.0000 Epoch 56/100 600/600 [==============================] - 5s 8ms/step - loss: 277728213.8569 - mean_squared_error: 277728213.8569 - val_loss: 265315792.0000 - val_mean_squared_error: 265315792.0000 Epoch 57/100 600/600 [==============================] - 5s 8ms/step - loss: 273021954.3694 - mean_squared_error: 273021954.3694 - val_loss: 265453232.0000 - val_mean_squared_error: 265453232.0000 Epoch 58/100 600/600 [==============================] - 5s 7ms/step - loss: 235758602.8619 - mean_squared_error: 235758602.8619 - val_loss: 267418880.0000 - val_mean_squared_error: 267418880.0000 Epoch 59/100 600/600 [==============================] - 4s 7ms/step - loss: 253989512.0932 - mean_squared_error: 253989512.0932 - val_loss: 263675520.0000 - val_mean_squared_error: 263675520.0000 Epoch 60/100 600/600 [==============================] - 4s 7ms/step - loss: 262297644.0067 - mean_squared_error: 262297644.0067 - val_loss: 260217264.0000 - val_mean_squared_error: 260217264.0000 Epoch 61/100 600/600 [==============================] - 5s 8ms/step - loss: 265082348.2596 - mean_squared_error: 265082348.2596 - val_loss: 262431664.0000 - val_mean_squared_error: 262431664.0000 
Epoch 62/100 600/600 [==============================] - 5s 8ms/step - loss: 257870115.1681 - mean_squared_error: 257870115.1681 - val_loss: 262881312.0000 - val_mean_squared_error: 262881312.0000 Epoch 63/100 600/600 [==============================] - 4s 7ms/step - loss: 240457727.4143 - mean_squared_error: 240457727.4143 - val_loss: 261733392.0000 - val_mean_squared_error: 261733392.0000 Epoch 64/100 600/600 [==============================] - 5s 8ms/step - loss: 277940927.2013 - mean_squared_error: 277940927.2013 - val_loss: 260641392.0000 - val_mean_squared_error: 260641392.0000 Epoch 65/100 600/600 [==============================] - 5s 9ms/step - loss: 243245328.8120 - mean_squared_error: 243245328.8120 - val_loss: 266007856.0000 - val_mean_squared_error: 266007856.0000 Epoch 66/100 600/600 [==============================] - 5s 8ms/step - loss: 280276759.2679 - mean_squared_error: 280276759.2679 - val_loss: 260283456.0000 - val_mean_squared_error: 260283456.0000 Epoch 67/100 600/600 [==============================] - 5s 8ms/step - loss: 256323983.0948 - mean_squared_error: 256323983.0948 - val_loss: 260369344.0000 - val_mean_squared_error: 260369344.0000 Epoch 68/100 600/600 [==============================] - 6s 10ms/step - loss: 257920354.0499 - mean_squared_error: 257920354.0499 - val_loss: 259491872.0000 - val_mean_squared_error: 259491872.0000 Epoch 69/100 600/600 [==============================] - 5s 8ms/step - loss: 248803245.8436 - mean_squared_error: 248803245.8436 - val_loss: 260577376.0000 - val_mean_squared_error: 260577376.0000 Epoch 70/100 600/600 [==============================] - 5s 9ms/step - loss: 262326159.4676 - mean_squared_error: 262326159.4676 - val_loss: 261333040.0000 - val_mean_squared_error: 261333040.0000 Epoch 71/100 600/600 [==============================] - 5s 8ms/step - loss: 244443427.6473 - mean_squared_error: 244443427.6473 - val_loss: 260864640.0000 - val_mean_squared_error: 260864640.0000 Epoch 72/100 600/600 
[==============================] - 5s 8ms/step - loss: 251957369.0782 - mean_squared_error: 251957369.0782 - val_loss: 260814608.0000 - val_mean_squared_error: 260814608.0000 Epoch 73/100 600/600 [==============================] - 5s 7ms/step - loss: 258517712.4792 - mean_squared_error: 258517712.4792 - val_loss: 258114464.0000 - val_mean_squared_error: 258114464.0000 Epoch 74/100 600/600 [==============================] - 5s 8ms/step - loss: 219879441.5308 - mean_squared_error: 219879441.5308 - val_loss: 262360560.0000 - val_mean_squared_error: 262360560.0000 Epoch 75/100 600/600 [==============================] - 5s 8ms/step - loss: 260794823.4143 - mean_squared_error: 260794823.4143 - val_loss: 256498032.0000 - val_mean_squared_error: 256498032.0000 Epoch 76/100 600/600 [==============================] - 5s 9ms/step - loss: 239282565.3511 - mean_squared_error: 239282565.3511 - val_loss: 262098640.0000 - val_mean_squared_error: 262098640.0000 Epoch 77/100 600/600 [==============================] - 5s 8ms/step - loss: 257229255.3478 - mean_squared_error: 257229255.3478 - val_loss: 259628224.0000 - val_mean_squared_error: 259628224.0000 Epoch 78/100 600/600 [==============================] - 5s 9ms/step - loss: 244059795.5541 - mean_squared_error: 244059795.5541 - val_loss: 259398896.0000 - val_mean_squared_error: 259398896.0000 Epoch 79/100 600/600 [==============================] - 5s 8ms/step - loss: 255108445.9767 - mean_squared_error: 255108445.9767 - val_loss: 259564784.0000 - val_mean_squared_error: 259564784.0000 Epoch 80/100 600/600 [==============================] - 5s 9ms/step - loss: 252858985.0250 - mean_squared_error: 252858985.0250 - val_loss: 261315536.0000 - val_mean_squared_error: 261315536.0000 Epoch 81/100 600/600 [==============================] - 6s 9ms/step - loss: 251545596.1664 - mean_squared_error: 251545596.1664 - val_loss: 259559184.0000 - val_mean_squared_error: 259559184.0000 Epoch 82/100 600/600 [==============================] - 6s 
10ms/step - loss: 253448548.1464 - mean_squared_error: 253448548.1464 - val_loss: 257081360.0000 - val_mean_squared_error: 257081360.0000 Epoch 83/100 600/600 [==============================] - 7s 11ms/step - loss: 223692804.4592 - mean_squared_error: 223692804.4592 - val_loss: 260599392.0000 - val_mean_squared_error: 260599392.0000 Epoch 84/100 600/600 [==============================] - 6s 10ms/step - loss: 238604269.9767 - mean_squared_error: 238604269.9767 - val_loss: 259515504.0000 - val_mean_squared_error: 259515504.0000 Epoch 85/100 600/600 [==============================] - 4s 6ms/step - loss: 239357600.3993 - mean_squared_error: 239357600.3993 - val_loss: 258469696.0000 - val_mean_squared_error: 258469696.0000 Epoch 86/100 600/600 [==============================] - 4s 7ms/step - loss: 250585435.0483 - mean_squared_error: 250585435.0483 - val_loss: 257148032.0000 - val_mean_squared_error: 257148032.0000 Epoch 87/100 600/600 [==============================] - 4s 6ms/step - loss: 241135506.1564 - mean_squared_error: 241135506.1564 - val_loss: 255790992.0000 - val_mean_squared_error: 255790992.0000 Epoch 88/100 600/600 [==============================] - 4s 7ms/step - loss: 236478667.7005 - mean_squared_error: 236478667.7005 - val_loss: 255462960.0000 - val_mean_squared_error: 255462960.0000 Epoch 89/100 600/600 [==============================] - 4s 7ms/step - loss: 255623276.9917 - mean_squared_error: 255623276.9917 - val_loss: 256298672.0000 - val_mean_squared_error: 256298672.0000 Epoch 90/100 600/600 [==============================] - 5s 8ms/step - loss: 225196806.3095 - mean_squared_error: 225196806.3095 - val_loss: 258884368.0000 - val_mean_squared_error: 258884368.0000 Epoch 91/100 600/600 [==============================] - 5s 8ms/step - loss: 240176409.9700 - mean_squared_error: 240176409.9700 - val_loss: 257321488.0000 - val_mean_squared_error: 257321488.0000 Epoch 92/100 600/600 [==============================] - 5s 9ms/step - loss: 239258905.3710 - 
mean_squared_error: 239258905.3710 - val_loss: 260538288.0000 - val_mean_squared_error: 260538288.0000 Epoch 93/100 600/600 [==============================] - 5s 9ms/step - loss: 245292192.2130 - mean_squared_error: 245292192.2130 - val_loss: 255793472.0000 - val_mean_squared_error: 255793472.0000 Epoch 94/100 600/600 [==============================] - 5s 8ms/step - loss: 243730631.4542 - mean_squared_error: 243730631.4542 - val_loss: 255217168.0000 - val_mean_squared_error: 255217168.0000 Epoch 95/100 600/600 [==============================] - 5s 8ms/step - loss: 241191757.3378 - mean_squared_error: 241191757.3378 - val_loss: 258455520.0000 - val_mean_squared_error: 258455520.0000 Epoch 96/100 600/600 [==============================] - 5s 8ms/step - loss: 234345852.6057 - mean_squared_error: 234345852.6057 - val_loss: 259143584.0000 - val_mean_squared_error: 259143584.0000 Epoch 97/100 600/600 [==============================] - 5s 8ms/step - loss: 237245952.2130 - mean_squared_error: 237245952.2130 - val_loss: 256133552.0000 - val_mean_squared_error: 256133552.0000 Epoch 98/100 600/600 [==============================] - 5s 8ms/step - loss: 246080581.8835 - mean_squared_error: 246080581.8835 - val_loss: 255931216.0000 - val_mean_squared_error: 255931216.0000 Epoch 99/100 600/600 [==============================] - 5s 8ms/step - loss: 259659058.3428 - mean_squared_error: 259659058.3428 - val_loss: 253970368.0000 - val_mean_squared_error: 253970368.0000 Epoch 100/100 600/600 [==============================] - 5s 8ms/step - loss: 251161536.1864 - mean_squared_error: 251161536.1864 - val_loss: 253605840.0000 - val_mean_squared_error: 253605840.0000 RMSE: 18558.15122927888 MSE: 344404977.04878515
Gradient Boosting Regressor
Regresja liniowa z wykorzystaniem technik wzmocnienia gradientowego z wykorzystaniem sklearn
RMSE: 19705.96
MSE: 388324934.97
"""
Gradient Boosting Regressor
"""
model = ensemble.GradientBoostingRegressor()
model.fit(X_train, Y_train)
gradient_predicted = model.predict(X_test)
print(f"RMSE: {mean_squared_error(Y_test, gradient_predicted, squared=False)}")
print(f"MSE: {mean_squared_error(Y_test, gradient_predicted)}")
RMSE: 19705.961914565338 MSE: 388324934.9782996
Podsumowanie
1. Sieć neuronowa
2. Gradient Boosting Regressor
3. Regresja liniowa
Najlepsze wyniki zostały osiągnięte przez model sieci neuronowej. Na drugim miejscu plasuje się metoda wzmocnienia gradientowego, a na trzecim regresja liniowa. Przewaga wzmocnienia gradientowego nad zwykłą regresją liniową wynika z iteracyjnego korygowania błędów, które wskazuje kierunek, w którym model ma się poprawiać.