import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelBinarizer
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split



class MyNeuralNetwork(nn.Module):
    """Small feed-forward binary classifier.

    Architecture: ``Linear(in_features, hidden_features) -> ReLU ->
    Linear(hidden_features, 1) -> Sigmoid``, so ``forward`` returns one value
    in (0, 1) per input row, which is what ``nn.BCELoss`` expects.

    Args:
        in_features: Number of input features per row. The default of 12 is
            presumably the column count produced by ``load_data`` — confirm.
        hidden_features: Width of the single hidden layer.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, *args, in_features: int = 12, hidden_features: int = 64, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # The attribute names fc1/fc2 become the state_dict keys that get
        # saved to disk, so they are kept stable.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_features, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape ``(..., in_features)`` to sigmoid outputs of shape ``(..., 1)``."""
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return self.sigmoid(x)

def prepare_df_for_nn(df: pd.DataFrame) -> pd.DataFrame:
    """Turn a raw loan DataFrame into an all-numeric frame for the network.

    All steps modify ``df`` in place; the same object is also returned.

    1. Drop the first column whose name contains ``"id"`` (case-insensitive),
       if there is one. Only the first match is dropped.
    2. ``reset_index()``, which inserts the previous index as a new leading
       column named ``"index"``.
    3. Replace every object/string column by integer class codes (``Int16``):
       the column's distinct values are sorted and numbered from 0, and each
       row gets the number of its value. For a two-valued column this is the
       same 0/1 encoding ``LabelBinarizer`` gives (later value in sort order
       -> 1). Missing values become ``<NA>``.

    NOTE(review): the ``"index"`` column from step 2 is kept and ends up as a
    model feature in ``load_data`` (which uses every column but the last).
    Dropping it changes the feature count ``MyNeuralNetwork`` is built for —
    confirm before changing.
    """
    id_columns: list[str] = [column for column in df.columns.to_list() if 'id' in str(column).lower()]
    if id_columns:
        df.drop(id_columns[0], inplace=True, axis=1)
    df.reset_index(inplace=True)
    for column in df.columns:
        dtype = df[column].dtype
        if dtype == object or isinstance(dtype, pd.StringDtype):
            values = df[column]
            # One integer code per row keeps exactly one value per row.
            # Flattening a one-hot matrix would instead yield n*k values for
            # k > 2 classes, which cannot line up with the n rows.
            codes = pd.Series(pd.Categorical(values).codes, index=df.index, dtype=pd.Int16Dtype())
            # Categorical marks missing values with code -1; expose them as <NA>.
            df[column] = codes.mask(values.isna())
    return df

def load_data(path: str) -> TensorDataset:
    """Load the CSV file at ``path`` as a float32 ``TensorDataset``.

    The frame is passed through ``prepare_df_for_nn``. Its last column becomes
    the label tensor (shape ``(n_rows,)``) and every other column becomes the
    feature tensor (shape ``(n_rows, n_columns - 1)``).

    Args:
        path: CSV file to read (e.g. ``"home_loan_train.csv"`` or
            ``"home_loan_val.csv"``).

    Returns:
        ``TensorDataset(features, labels)``, both ``torch.float32``.
    """
    # Read the caller's file. A hard-coded name here would silently load the
    # training split for every caller, including validation.
    frame: pd.DataFrame = pd.read_csv(path)
    prepared: pd.DataFrame = prepare_df_for_nn(frame)
    x: np.ndarray = prepared.iloc[:, :-1].values.astype(float)
    y: np.ndarray = prepared.iloc[:, -1].values.astype(float)
    x_tensor: torch.Tensor = torch.tensor(x, dtype=torch.float32)
    y_tensor: torch.Tensor = torch.tensor(y, dtype=torch.float32)
    return TensorDataset(x_tensor, y_tensor)

def train(epochs: int, dataloader_train: DataLoader, dataloader_val: DataLoader):
    model: MyNeuralNetwork = MyNeuralNetwork()
    criterion: nn.BCELoss = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(epochs):
        total_correct_train = 0
        total_samples_train = 0
        total_correct_val = 0
        total_samples_val = 0

        for inputs, labels in dataloader_train:
            outputs = model(inputs)
            labels = labels.reshape((labels.shape[0], 1))
            loss = criterion(outputs, labels)
            predicted_labels = (outputs > 0.5).float()
            total_correct_train += (predicted_labels == labels).sum().item()
            total_samples_train += labels.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            for inputs, labels in dataloader_val:
                outputs_val = model(inputs)
                predicted_labels_val = (outputs_val > 0.5).float()
                labels = labels.reshape((labels.shape[0], 1))
                total_correct_val += (predicted_labels_val == labels).sum().item()
                total_samples_val += labels.size(0)

        accuracy_val = total_correct_val / total_samples_val
        accuracy_train = total_correct_train / total_samples_train
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}, Accuracy train: {accuracy_train:.4f}, Accuracy val: {accuracy_val:.4f}")
        
    return model

def create_dataset(
    source_path: str = '/Users/wojciechbatruszewicz/InformatykaStudia/SEMESTR8/IUM/ZADANIA/createDataset/loan_sanction_train.csv',
    output_dir: str = '.',
    random_state: int = 1,
) -> None:
    """Split a labelled loan CSV into train/test/val CSV files with min-max scaled numeric columns.

    Rows are split roughly 80/10/10 (train/test/val) with two
    ``train_test_split`` calls. Rows containing any missing value are then
    dropped from each split. Numeric columns (as detected on the train split)
    are scaled by a ``MinMaxScaler`` fitted on the train split only, and the
    same fitted scaler is applied to test and val so all three share one scale.

    Args:
        source_path: Labelled CSV to split. Defaults to the original author's
            local path.
        output_dir: Directory receiving ``home_loan_train.csv``,
            ``home_loan_test.csv`` and ``home_loan_val.csv``.
        random_state: Seed passed to both ``train_test_split`` calls.
    """
    from pathlib import Path

    home_loan = pd.read_csv(source_path)

    train_df, holdout_df = train_test_split(home_loan, test_size=0.2, random_state=random_state)
    test_df, val_df = train_test_split(holdout_df, test_size=0.5, random_state=random_state)

    # Drop incomplete rows before scaling so the scaler only sees rows that are
    # kept. .copy() gives independent frames that are safe to assign into.
    train_df = train_df.dropna().copy()
    test_df = test_df.dropna().copy()
    val_df = val_df.dropna().copy()

    numeric_cols = train_df.select_dtypes(include='number').columns
    if len(numeric_cols) > 0:
        scaler = MinMaxScaler()
        # Fit on train only. Re-fitting on test/val would leak their min/max
        # and put each split on a different scale from the one the model was
        # trained on.
        train_df[numeric_cols] = scaler.fit_transform(train_df[numeric_cols])
        for split in (test_df, val_df):
            if not split.empty:  # transform() rejects 0-row input
                split[numeric_cols] = scaler.transform(split[numeric_cols])

    out_dir = Path(output_dir)
    train_df.to_csv(out_dir / 'home_loan_train.csv', index=False)
    test_df.to_csv(out_dir / 'home_loan_test.csv', index=False)
    val_df.to_csv(out_dir / 'home_loan_val.csv', index=False)

def main() -> None:
    """Train the classifier on the CSV splits in the working directory and save its weights.

    The training set comes from ``load_data("home_loan_train.csv")`` (shuffled
    by its DataLoader) and the validation set from
    ``load_data("home_loan_val.csv")``. Training runs for 20 epochs with
    batches of 32, and the resulting ``state_dict`` is written to ``model.pt``.
    """
    # Uncomment to regenerate the CSV splits before training:
    # create_dataset()
    training_set = load_data("home_loan_train.csv")
    validation_set = load_data("home_loan_val.csv")

    batch_size: int = 32
    epochs: int = 20
    trained_model = train(
        epochs,
        DataLoader(training_set, batch_size=batch_size, shuffle=True),
        DataLoader(validation_set, batch_size=batch_size),
    )
    torch.save(trained_model.state_dict(), 'model.pt')


if __name__ == "__main__":
    main()