import tensorflow as tf
from keras import layers
from keras.models import save_model
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.observers import MongoObserver

# Stworzenie obiektu klasy Experiment do śledzenia przebiegu regresji narzędziem Sacred
ex = Experiment(save_git_info=False)

# Dodanie obserwatora FileObserver
ex.observers.append(FileStorageObserver('runs'))

#Dodanie obserwatora Mongo
ex.observers.append(MongoObserver(url='mongodb://admin:IUM_2021@172.17.0.1:27017', db_name='sacred'))

# Przykładowa modyfikowalna z Sacred konfiguracja wybranych parametrów treningu
@ex.config
def config():
    epochs = 100
    units = 1
    learning_rate = 0.1


# Reszta kodu wrzucona do udekorowanej funkcji train do wywołania przez Sacred, żeby coś było capture'owane
@ex.capture
def train(epochs, units, learning_rate, _run):

    # Wczytanie danych
    data_train = pd.read_csv('lego_sets_clean_train.csv')
    data_test = pd.read_csv('lego_sets_clean_test.csv')

    # Wydzielenie zbiorów dla predykcji ceny zestawu na podstawie liczby klocków, którą zawiera
    train_piece_counts = np.array(data_train['piece_count'])
    train_prices = np.array(data_train['list_price'])
    test_piece_counts = np.array(data_test['piece_count'])
    test_prices = np.array(data_test['list_price'])

    # Normalizacja
    normalizer = layers.Normalization(input_shape=[1, ], axis=None)
    normalizer.adapt(train_piece_counts)

    # Inicjalizacja
    model = tf.keras.Sequential([
        normalizer,
        layers.Dense(units=units)
    ])

    # Kompilacja
    model.compile(
        optimizer=tf.optimizers.Adam(learning_rate=learning_rate),
        loss='mean_absolute_error'
    )

    # Trening
    history = model.fit(
        train_piece_counts,
        train_prices,
        epochs=epochs,
        verbose=0,
        validation_split=0.2
    )

    # Wykonanie predykcji na danych ze zbioru testującego
    y_pred = model.predict(test_piece_counts)

    # Zapis predykcji do pliku
    results = pd.DataFrame(
        {'test_set_piece_count': test_piece_counts.tolist(), 'predicted_price': [round(a[0], 2) for a in y_pred.tolist()]})
    results.to_csv('lego_reg_results.csv', index=False, header=True)

    # Zapis modelu do pliku standardowo poprzez metodę kerasa i poprzez metodę obiektu Experiment z Sacred
    model.save('lego_reg_model')
    ex.add_artifact('lego_reg_model/saved_model.pb')

    # Przykładowo zwracamy loss ostatniej epoki w charakterze wyników, żeby było widoczne w plikach zapisanych przez obserwator
    hist = pd.DataFrame(history.history)
    hist['epoch'] = history.epoch
    _run.log_scalar('final.training.loss', hist['loss'].iloc[-1])

@ex.automain
def main(units, learning_rate):
    train()