import os
import sys

import mlflow
import pandas as pd
import tensorflow as tf
from sklearn.metrics import accuracy_score

# Honor the standard MLFLOW_TRACKING_URI environment variable when it is set, so
# the script can log to another tracking server without editing code; otherwise
# fall back to a local tracking server on port 5000.
mlflow.set_tracking_uri(os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000"))

# Columns whose values are joined into one whitespace-separated "document" per row.
FEATURE_COLUMNS = ['review_aroma', 'review_appearance', 'review_palate', 'review_taste']
# Tokenizer vocabulary size; must match the Embedding layer's input_dim.
VOCAB_SIZE = 10000
# Every encoded row is padded/truncated to this many tokens.
MAX_SEQUENCE_LENGTH = 100
# A review_overall score at or above this value is the positive (1) class.
POSITIVE_THRESHOLD = 3


def _parse_args(argv):
    """Return ``(epochs, batch_size)`` parsed from ``argv[1]`` and ``argv[2]``.

    Raises:
        SystemExit: with a usage message when arguments are missing or are not
            integers (instead of an opaque IndexError / ValueError).
    """
    if len(argv) < 3:
        prog = argv[0] if argv else "train.py"
        raise SystemExit("usage: {0} EPOCHS BATCH_SIZE".format(prog))
    try:
        return int(argv[1]), int(argv[2])
    except ValueError as err:
        raise SystemExit("EPOCHS and BATCH_SIZE must be integers: {0}".format(err)) from err


def _row_texts(df):
    """Return one string per row: the FEATURE_COLUMNS values joined by spaces.

    Iterating a DataFrame directly yields its column labels, not its rows, so
    the feature columns must be flattened row-wise before tokenizing.
    NOTE(review): if these columns hold numeric scores, the Tokenizer's default
    filters strip '.', so e.g. "4.5" becomes the tokens "4" and "5" — confirm
    this featurization is intended.
    """
    rows = df[FEATURE_COLUMNS].itertuples(index=False, name=None)
    return [' '.join(str(value) for value in row) for row in rows]


def _encode(tokenizer, texts):
    """Tokenize *texts* and pad them to MAX_SEQUENCE_LENGTH for the Embedding layer."""
    sequences = tokenizer.texts_to_sequences(texts)
    return tf.keras.preprocessing.sequence.pad_sequences(sequences, maxlen=MAX_SEQUENCE_LENGTH)


def _binary_labels(scores):
    """Map review_overall scores to 0/1 using POSITIVE_THRESHOLD."""
    return (scores >= POSITIVE_THRESHOLD).astype(int)


def main():
    """Train a binary review classifier and log params/accuracy to MLflow.

    Command line: ``EPOCHS BATCH_SIZE``. Reads ./beer_reviews_train.csv and
    ./beer_reviews_test.csv. Train and test data go through identical
    preprocessing (row text -> tokens -> padded sequence) and identical label
    binarization, so the sigmoid/binary-crossentropy model sees 0/1 targets.
    """
    epochs, batch_size = _parse_args(sys.argv)

    train_data = pd.read_csv('./beer_reviews_train.csv')
    train_texts = _row_texts(train_data)
    y_train = _binary_labels(train_data['review_overall'])

    # The tokenizer is fit on training text only and reused for the test set.
    tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=VOCAB_SIZE)
    tokenizer.fit_on_texts(train_texts)
    X_train_pad = _encode(tokenizer, train_texts)

    with mlflow.start_run() as run:
        print("MLflow run experiment_id: {0}".format(run.info.experiment_id))
        print("MLflow run artifact_uri: {0}".format(run.info.artifact_uri))

        # Log hyper-parameters before training so they survive a failed fit.
        mlflow.log_param("epochs", epochs)
        mlflow.log_param("batch_size", batch_size)

        model = tf.keras.Sequential([
            tf.keras.layers.Embedding(input_dim=VOCAB_SIZE, output_dim=16,
                                      input_length=MAX_SEQUENCE_LENGTH),
            tf.keras.layers.GlobalAveragePooling1D(),
            tf.keras.layers.Dense(16, activation='relu'),
            tf.keras.layers.Dense(1, activation='sigmoid')
        ])

        model.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['accuracy'])

        print("epochs={0} batch_size={1}".format(epochs, batch_size))
        model.fit(X_train_pad, y_train, epochs=epochs, batch_size=batch_size,
                  validation_split=0.1)

        test_data = pd.read_csv('./beer_reviews_test.csv')
        X_test_pad = _encode(tokenizer, _row_texts(test_data))
        y_test = _binary_labels(test_data['review_overall'])

        # Sigmoid outputs are probabilities; threshold at 0.5 for class labels.
        predictions = model.predict(X_test_pad).flatten()
        predicted_labels = (predictions >= 0.5).astype(int)

        accuracy = accuracy_score(y_test, predicted_labels)
        mlflow.log_metric("accuracy", accuracy)

if __name__ == "__main__":
    # Run the training job only when executed as a script, not on import.
    main()