import os
import time

import pandas as pd
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn import preprocessing
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import ParameterGrid

TRAIN_DATA_DIR = "datasets_train_raw"
TEST_DATA_DIR = "datasets_test"


def invoke_and_measure(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` and measure how long the call takes.

    Args:
        func: Callable to invoke.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        A ``(result, elapsed_seconds)`` tuple: ``result`` is whatever ``func``
        returned and ``elapsed_seconds`` is the elapsed time of the call as a
        float number of seconds.

    Any exception raised by ``func`` propagates unchanged (no timing is returned).
    """
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # system clock, which can jump (NTP/manual adjustments) and is coarse on
    # some platforms, making short timings unreliable or even negative.
    start = time.perf_counter()
    result = func(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed


def load_dataset(directory, frac=0.01, random_state=42):
    """Load every data file in *directory* and return a random sample of the rows.

    Each regular file is parsed as whitespace-delimited text; its first line is
    skipped and the columns get the fixed names listed below. All files are
    concatenated into one DataFrame (with a fresh integer index) and a seeded
    random fraction of its rows is returned.

    Args:
        directory: Path of the directory holding the data files. Subdirectories
            are ignored.
        frac: Fraction of the concatenated rows to return (default 0.01).
        random_state: Seed passed to ``DataFrame.sample`` (default 42).

    Returns:
        A ``pandas.DataFrame`` with columns ``tbid`` ... ``collapsed``.

    Raises:
        ValueError: If *directory* contains no regular files to load.
    """
    column_names = ["tbid", "tphys", "r", "vr", "vt", "ik1", "ik2", "sm1", "sm2",
                    "a", "e", "collapsed"]

    # Sort the names: os.listdir order is arbitrary and filesystem-dependent,
    # and the concatenation order decides which rows the seeded sample picks.
    # Sorting keeps the sample reproducible across machines.
    candidates = (os.path.join(directory, name) for name in sorted(os.listdir(directory)))
    file_paths = [path for path in candidates if os.path.isfile(path)]
    if not file_paths:
        raise ValueError(f"No data files found in {directory!r}")

    # sep=r"\s+" is the documented equivalent of the deprecated
    # delim_whitespace=True (deprecated in pandas 2.2).
    frames = [pd.read_csv(path, sep=r"\s+", skiprows=1, names=column_names)
              for path in file_paths]
    return pd.concat(frames, ignore_index=True).sample(frac=frac, random_state=random_state)


data_train = load_dataset(TRAIN_DATA_DIR)
data_test = load_dataset(TEST_DATA_DIR)

# Leakage sanity check: find rows present in both sets with identical values in
# every shared column (pd.merge joins on all common columns when `on` is not
# given). An inner merge yields exactly those rows directly, without first
# materialising the full outer union of both frames.
# NOTE(review): load_dataset returns a random sample of each directory, so this
# only detects overlap between the two samples, not between the full datasets.
overlap_rows = pd.merge(data_train, data_test, how="inner")
if overlap_rows.empty:
    print("There are no overlapping rows between train and test datasets.")
else:
    print("Train and test datasets have following overlapping rows: ")
    print(overlap_rows)

# Drop the first column (tbid) and use the last column as the label; every
# column in between is a feature.
X_train = data_train.iloc[:, 1:-1].values
y_train = data_train.iloc[:, -1].values

# Fit the label encoder on the training labels only, then reuse that same
# mapping for the test labels. Re-fitting on the test labels could assign
# different integer codes when the test sample contains a different set of
# classes, silently corrupting every metric. transform() instead raises
# ValueError on a label never seen in training.
lab = preprocessing.LabelEncoder()
y_train_transformed = lab.fit_transform(y_train)

X_test = data_test.iloc[:, 1:-1].values
y_test = data_test.iloc[:, -1].values
y_test_transformed = lab.transform(y_test)

# Benchmark configuration: each entry names a classifier, holds the estimator
# instance to evaluate, and maps hyperparameter names to candidate values
# (expanded into every combination by ParameterGrid in the evaluation loop).
# Estimators with internal randomness are seeded so repeated runs report the
# same scores, matching the fixed seed used when sampling the datasets.
classifiers_and_parameters = [
    {
        "name": "Nearest Neighbors",
        "classifier": KNeighborsClassifier(),
        "parameters": {
            "n_neighbors": [3, 5, 10, 50]
        }
    },
    {
        "name": "Decision Tree",
        "classifier": DecisionTreeClassifier(random_state=42),
        "parameters": {
            "max_depth": [10, 20, 50]
        }
    },
    {
        "name": "Random Forest",
        "classifier": RandomForestClassifier(random_state=42),
        "parameters": {
            "max_depth": [10, 20, 50],
            "n_estimators": [10, 50, 100],
            "max_features": ['sqrt', 'log2']
        }
    },
    {
        "name": "Naive Bayes",
        "classifier": GaussianNB(),
        "parameters": {
            "var_smoothing": [1e-09, 1e-08, 1e-07]
        }
    },
    {
        "name": "QDA",
        "classifier": QuadraticDiscriminantAnalysis(),
        "parameters": {
            "reg_param": [0.0, 0.5, 1.0]
        }
    },
    {
        "name": "Gradient Boosting",
        "classifier": GradientBoostingClassifier(random_state=42),
        "parameters": {
            "learning_rate": [0.01, 0.05, 0.1],
            "n_estimators": [50, 100, 200]
        }
    }
]

# Fit the scaler on the training features only and apply that same
# transformation to the test features, so no test statistics leak into training.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Evaluate every hyperparameter combination of every classifier on the test set.
for item in classifiers_and_parameters:
    name = item["name"]
    clf = item["classifier"]
    param_grid = ParameterGrid(item["parameters"])

    for params in param_grid:
        # The same estimator instance is reconfigured and refitted for each
        # combination; every combination sets all keys of the grid.
        clf.set_params(**params)
        _, fit_time = invoke_and_measure(clf.fit, X_train_scaled, y_train_transformed)
        y_pred, pred_time = invoke_and_measure(clf.predict, X_test_scaled)

        # precision/recall/F1 use scikit-learn's default binary averaging (they
        # score the class encoded as 1). zero_division=0 reports 0 -- the value
        # the default "warn" setting also returns -- without emitting an
        # UndefinedMetricWarning when a model never predicts that class.
        accuracy = accuracy_score(y_test_transformed, y_pred)
        precision = precision_score(y_test_transformed, y_pred, zero_division=0)
        recall = recall_score(y_test_transformed, y_pred, zero_division=0)
        f1 = f1_score(y_test_transformed, y_pred, zero_division=0)
        print(
            f"{name} with params {params}: accuracy={accuracy * 100:.2f}% precision={precision * 100:.2f}% "
            f"recall={recall * 100:.2f}% f1={f1 * 100:.2f}% "
            f"train_time={fit_time:.5f}s predict_time={pred_time:.5f}s")