2022-05-12 22:30:38 +02:00
|
|
|
import tensorflow as tf
|
|
|
|
from keras import layers
|
|
|
|
from keras.models import save_model
|
|
|
|
import pandas as pd
|
|
|
|
import numpy as np
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
import mlflow
|
2022-05-13 00:06:48 +02:00
|
|
|
import mlflow.keras
|
2022-05-12 22:30:38 +02:00
|
|
|
from urllib.parse import urlparse
|
2022-05-13 14:46:06 +02:00
|
|
|
import sys
|
2022-05-12 22:30:38 +02:00
|
|
|
|
|
|
|
|
2022-05-13 14:46:06 +02:00
|
|
|
def _parse_training_params(argv=None):
    """Parse the optional positional CLI arguments that control training.

    Args:
        argv: Argument list *without* the program name. Defaults to
            ``sys.argv[1:]``.

    Returns:
        Tuple ``(epochs, units, learning_rate)``. Any argument that is not
        supplied falls back to 100, 1 and 0.1 respectively.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    epochs = int(args[0]) if len(args) > 0 else 100
    units = int(args[1]) if len(args) > 1 else 1
    learning_rate = float(args[2]) if len(args) > 2 else 0.1
    return epochs, units, learning_rate


def _load_split(path):
    """Read a cleaned LEGO CSV and return (piece_count, list_price) arrays."""
    frame = pd.read_csv(path)
    return np.array(frame['piece_count']), np.array(frame['list_price'])


def _build_model(train_piece_counts, units, learning_rate):
    """Build and compile a Normalization + Dense regression model.

    The Normalization layer is adapted to the training piece counts before
    the model is assembled. The model is compiled with Adam and an MAE loss.
    """
    normalizer = layers.Normalization(input_shape=[1, ], axis=None)
    normalizer.adapt(train_piece_counts)
    model = tf.keras.Sequential([
        normalizer,
        layers.Dense(units=units),
    ])
    model.compile(
        optimizer=tf.optimizers.Adam(learning_rate=learning_rate),
        loss='mean_absolute_error',
    )
    return model


def _save_predictions(model, test_piece_counts, path):
    """Predict prices for the test piece counts and write them to a CSV file."""
    y_pred = model.predict(test_piece_counts)
    results = pd.DataFrame(
        {'test_set_piece_count': test_piece_counts.tolist(),
         # Only the first output unit of each prediction is reported,
         # rounded to 2 decimal places.
         'predicted_price': [round(a[0], 2) for a in y_pred.tolist()]})
    results.to_csv(path, index=False, header=True)


def _log_model_to_mlflow(model, train_piece_counts):
    """Log the Keras model to MLflow with an inferred signature and example.

    The model is registered as 'TFLegoModel' only when the tracking URI
    scheme is not 'file'; otherwise it is logged under 'model' without
    registration.
    """
    signature = mlflow.models.signature.infer_signature(
        train_piece_counts, model.predict(train_piece_counts))
    # The example must agree with the inferred signature: a 1-D array in the
    # training dtype. A 0-d array such as np.array(500) does not match it.
    input_example = np.array([500], dtype=train_piece_counts.dtype)
    tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme
    if tracking_url_type_store != 'file':
        mlflow.keras.log_model(model, 'lego-model',
                               registered_model_name='TFLegoModel',
                               signature=signature,
                               input_example=input_example)
    else:
        mlflow.keras.log_model(model, 'model', signature=signature,
                               input_example=input_example)


def train():
    """Train a price-from-piece-count regression model and track it in MLflow.

    Optional positional CLI arguments: epochs (default 100), units
    (default 1) and learning_rate (default 0.1).

    Within a single MLflow run the function:
    - writes test-set predictions to ``lego_reg_results.csv``;
    - saves the model to ``lego_reg_model``;
    - logs the parameters, the test MAE and the model.

    The run belongs to experiment ``s449288`` on the configured tracking
    server.
    """
    epochs, units, learning_rate = _parse_training_params()

    # MLflow tracking server and experiment name
    mlflow.set_tracking_uri("http://172.17.0.1:5000")
    mlflow.set_experiment('s449288')

    # Attach the whole training pipeline to one MLflow run
    with mlflow.start_run() as run:
        print('MLflow run experiment_id: {0}'.format(run.info.experiment_id))
        print('MLflow run artifact_uri: {0}'.format(run.info.artifact_uri))

        # Log hyper-parameters up front so they are recorded even if a
        # later step (data loading, training, saving) fails.
        mlflow.log_param('epochs', epochs)
        mlflow.log_param('units', units)
        mlflow.log_param('learning_rate', learning_rate)

        # Predict a set's price from the number of pieces it contains
        train_piece_counts, train_prices = _load_split('lego_sets_clean_train.csv')
        test_piece_counts, test_prices = _load_split('lego_sets_clean_test.csv')

        model = _build_model(train_piece_counts, units, learning_rate)
        model.fit(
            train_piece_counts,
            train_prices,
            epochs=epochs,
            verbose=0,
            validation_split=0.2,
        )

        _save_predictions(model, test_piece_counts, 'lego_reg_results.csv')
        model.save('lego_reg_model')

        # Test-set MAE for MLflow (same computation as in evaluate.py).
        # The model is compiled with only a loss, so evaluate() returns it
        # as a single value.
        mae = model.evaluate(test_piece_counts, test_prices, verbose=0)
        mlflow.log_metric("mae", mae)

        _log_model_to_mlflow(model, train_piece_counts)
|
2022-05-12 22:30:38 +02:00
|
|
|
|
|
|
|
|
2022-05-13 14:46:06 +02:00
|
|
|
if __name__ == "__main__":
    # Run the training pipeline only when executed as a script, not on import.
    train()
|