2022-04-08 18:16:40 +02:00
|
|
|
import tensorflow as tf
|
|
|
|
from keras import layers
|
2022-04-24 17:13:01 +02:00
|
|
|
from keras.models import save_model
|
2022-04-08 18:16:40 +02:00
|
|
|
import pandas as pd
|
|
|
|
import numpy as np
|
|
|
|
import matplotlib.pyplot as plt
|
2022-04-24 18:05:57 +02:00
|
|
|
import sys
|
|
|
|
|
|
|
|
# Read the number of training epochs from the first command-line argument.
# Exit with a short usage message (status 1) instead of a bare
# IndexError/ValueError traceback when the argument is missing or is not
# an integer.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} EPOCHS_NUM")
try:
    EPOCHS_NUM = int(sys.argv[1])
except ValueError:
    sys.exit(f"EPOCHS_NUM must be an integer, got {sys.argv[1]!r}")
|
2022-04-08 18:16:40 +02:00
|
|
|
|
|
|
|
# Load the training and test splits from CSV files in the working directory
# (training file first, then the test file).
data_train, data_test = (
    pd.read_csv(csv_path)
    for csv_path in ('lego_sets_clean_train.csv', 'lego_sets_clean_test.csv')
)
|
|
|
|
|
|
|
|
# Extract the arrays used to predict a set's price from the number of
# pieces it contains: 'piece_count' is the input, 'list_price' the target.
train_piece_counts, train_prices = (
    np.array(data_train[column]) for column in ('piece_count', 'list_price')
)
test_piece_counts, test_prices = (
    np.array(data_test[column]) for column in ('piece_count', 'list_price')
)
|
|
|
|
|
|
|
|
# Normalization: build the preprocessing layer for a single input feature
# and fit its statistics on the training piece counts.
normalizer = layers.Normalization(axis=None, input_shape=[1])
normalizer.adapt(data=train_piece_counts)
|
|
|
|
|
|
|
|
# Model initialisation: the adapted normalization layer followed by a
# single-unit dense layer.
model = tf.keras.Sequential()
model.add(normalizer)
model.add(layers.Dense(units=1))
|
|
|
|
|
|
|
|
# Compilation: mean absolute error loss, Adam optimizer with a 0.1
# learning rate.
model.compile(loss='mean_absolute_error', optimizer=tf.optimizers.Adam(learning_rate=0.1))
|
|
|
|
|
|
|
|
# Training: fit price against piece count for EPOCHS_NUM epochs without
# console output, passing validation_split=0.2 so part of the training
# data is used for validation.
history = model.fit(
    x=train_piece_counts,
    y=train_prices,
    validation_split=0.2,
    verbose=0,
    epochs=EPOCHS_NUM,
)
|
|
|
|
|
|
|
|
# Simple evaluation: store the value returned by evaluating the model on
# the test set under the 'model' key.
test_results = {}
test_results['model'] = model.evaluate(x=test_piece_counts, y=test_prices, verbose=0)
|
|
|
|
|
|
|
|
# Run many predictions: 6901 evenly spaced piece counts from 100 to 7000.
x = tf.linspace(start=100, stop=7000, num=6901)
y = model.predict(x=x)
|
|
|
|
|
|
|
|
# Save the predictions to a CSV file: one row per probed piece count.
# `y` holds one single-element row per input, so its first column is the
# predicted price.
results = pd.DataFrame(
    {
        "input_piece_count": x.numpy().tolist(),
        "predicted_price": y[:, 0].tolist(),
    }
)
results.to_csv('lego_reg_results.csv', header=True, index=False)
|
|
|
|
|
|
|
|
# Save the trained model to disk under 'lego_reg_model'.
model.save(filepath='lego_reg_model')
|
2022-04-08 18:16:40 +02:00
|
|
|
|
|
|
|
# Optional statistics and plots. Disabled: this code does not run.
# Uncomment the lines below to print the evaluation result and the last
# training epochs, and to plot the training data against the predictions.
#
# print(test_results)
#
# hist = pd.DataFrame(history.history)
# hist['epoch'] = history.epoch
# print(hist.tail())
#
# plt.scatter(train_piece_counts, train_prices, label='Data')
# plt.plot(x, y, color='k', label='Predictions')
# plt.xlabel('pieces')
# plt.ylabel('price')
# plt.legend()
# plt.show()
|