import os
import time

import pandas as pd
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn import preprocessing
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import ParameterGrid

TRAIN_DATA_DIR = "datasets_train_raw"
TEST_DATA_DIR = "datasets_test"
COLUMN_NAMES = ["tbid", "tphys", "r", "vr", "vt", "ik1", "ik2", "sm1", "sm2", "a", "e", "collapsed"]


def invoke_and_measure(func, *args, **kwargs):
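    """Call func(*args, **kwargs) and return a (result, elapsed_seconds) tuple."""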
    start_time = time.time()
    result = func(*args, **kwargs)
    end_time = time.time()

    elapsed_time = end_time - start_time
    return result, elapsed_time


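# Read every training file, concatenate them, and shuffle the rows (fixed seed for reproducibility).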
train_df_list = []
for file in os.listdir(TRAIN_DATA_DIR):
    file_path = os.path.join(TRAIN_DATA_DIR, file)
    df, load_time = invoke_and_measure(pd.read_csv, file_path, sep=r"\s+", skiprows=1,
                                       names=COLUMN_NAMES)
    train_df_list.append(df)
data_train = pd.concat(train_df_list, ignore_index=True).sample(frac=1, random_state=42)

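# Read and shuffle the test files in the same way.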
test_df_list = []
for file in os.listdir(TEST_DATA_DIR):
    file_path = os.path.join(TEST_DATA_DIR, file)
    df, load_time = invoke_and_measure(pd.read_csv, file_path, sep=r"\s+", skiprows=1,
                                       names=COLUMN_NAMES)
    test_df_list.append(df)

data_test = pd.concat(test_df_list, ignore_index=True).sample(frac=1, random_state=42)

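# Leakage check: an outer merge with indicator=True marks rows that appear in both train and test.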
merged_data = pd.merge(data_train, data_test, indicator=True, how='outer')
overlap_rows = merged_data[merged_data['_merge'] == 'both']
if overlap_rows.empty:
    print("There are no overlapping rows between train and test datasets.")
else:
    print("Train and test datasets have following overlapping rows: ")
    print(overlap_rows)

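# Features: every column except the id (tbid) and the label (collapsed); the label is the last column.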
X_train = data_train.iloc[:, 1:-1].values
y_train = data_train.iloc[:, -1].values

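# Encode the class labels as integers; the test labels are transformed with the same fitted encoder.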
lab = preprocessing.LabelEncoder()
y_train_transformed = lab.fit_transform(y_train)

X_test = data_test.iloc[:, 1:-1].values
y_test = data_test.iloc[:, -1].values
y_test_transformed = lab.transform(y_test)

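# Classifiers to compare, each with a small hyperparameter grid to sweep.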
classifiers_and_parameters = [
    {
        "name": "Nearest Neighbors",
        "classifier": KNeighborsClassifier(),
        "parameters": {
            "n_neighbors": [3, 5, 10, 50]
        }
    },
    {
        "name": "Decision Tree",
        "classifier": DecisionTreeClassifier(),
        "parameters": {
            "max_depth": [10, 20, 50]
        }
    },
    {
        "name": "Random Forest",
        "classifier": RandomForestClassifier(),
        "parameters": {
            "max_depth": [10, 20, 50],
            "n_estimators": [10, 50, 100],
            "max_features": ['sqrt', 'log2']
        }
    },
    {
        "name": "Naive Bayes",
        "classifier": GaussianNB(),
        "parameters": {
            "var_smoothing": [1e-09, 1e-08, 1e-07]
        }
    },
    {
        "name": "QDA",
        "classifier": QuadraticDiscriminantAnalysis(),
        "parameters": {
            "reg_param": [0.0, 0.5, 1.0]
        }
    },
    {
        "name": "Gradient Boosting",
        "classifier": GradientBoostingClassifier(),
        "parameters": {
            "learning_rate": [0.01, 0.05, 0.1],
            "n_estimators": [50, 100, 200]
        }
    }
]

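# Evaluate every hyperparameter combination of every classifier, timing fit and predict along the way.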
for item in classifiers_and_parameters:
    name = item["name"]
    clf = item["classifier"]
    param_grid = ParameterGrid(item["parameters"])

    for params in param_grid:
        clf.set_params(**params)
        _, fit_time = invoke_and_measure(clf.fit, X_train, y_train_transformed)
        y_pred, pred_time = invoke_and_measure(clf.predict, X_test)

        accuracy = accuracy_score(y_test_transformed, y_pred)
        precision = precision_score(y_test_transformed, y_pred)
        recall = recall_score(y_test_transformed, y_pred)
        f1 = f1_score(y_test_transformed, y_pred)
        print(
            f"{name} with params {params}: accuracy={accuracy * 100:.2f}% precision={precision * 100:.2f}% "
            f"recall={recall * 100:.2f}% f1={f1 * 100:.2f}% "
            f"train_time={fit_time:.5f}s predict_time={pred_time:.5f}s")