from typing import List, Tuple

import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier

from common.helpers import castle_neighbors, find_neighbours
from models.castle import Castle
from models.knight import Knight
from models.monster import Monster


def manhattan_distance(p1: Tuple[int, int], p2: Tuple[int, int]) -> int:
    """Return the Manhattan (L1) distance between two 2-component points.

    Both arguments are unpacked into exactly two components, so passing a
    sequence of any other length raises ``ValueError``.
    """
    (first_a, second_a), (first_b, second_b) = p1, p2
    delta_first = abs(first_a - first_b)
    delta_second = abs(second_a - second_b)
    return delta_first + delta_second


def parse_hp(hp: int) -> int:
    """Clamp a hit-point value so that it is never negative.

    Positive values are returned unchanged; zero and negative values yield 0.
    """
    return hp if hp > 0 else 0


def parse_idx_of_opp_or_monster(s: str) -> int:
    """Convert the 1-based numeric suffix of a label into a 0-based index.

    The whole run of trailing decimal digits is used, so ``'opp10'`` maps to
    ``9`` (reading only the last character would give ``-1``). Single-digit
    labels such as ``'mob2'`` map to ``1``.

    Args:
        s: Label ending in a number, e.g. ``'opp3'``.

    Returns:
        The trailing number minus one.

    Raises:
        ValueError: If ``s`` does not end with at least one decimal digit
            (this includes the empty string).
    """
    # Walk backwards over the trailing digits to find where the number starts.
    start = len(s)
    while start > 0 and s[start - 1].isdecimal():
        start -= 1

    digits = s[start:]
    if not digits:
        raise ValueError(f"label {s!r} does not end with a numeric index")
    return int(digits) - 1


class DecisionTree:
    """Decision-tree policy that chooses which target a knight should approach.

    On construction a ``DecisionTreeClassifier`` (entropy criterion) is fitted
    on a semicolon-delimited CSV file. The ``goal`` column holds the target
    labels (encoded with a ``LabelEncoder``); every other column is a feature.
    """

    # Number of monster / opponent slots that predict_move feeds into the
    # feature vector; shorter lists cannot be turned into a prediction.
    MONSTER_SLOTS = 2
    OPPONENT_SLOTS = 4

    def __init__(self, dataset_path: str = 'learning/dataset_tree.csv') -> None:
        """Load the training data and fit the classifier.

        Args:
            dataset_path: Path of the semicolon-delimited training CSV. The
                default is resolved relative to the current working directory.
        """
        data_frame = pd.read_csv(dataset_path, delimiter=';')
        unlabeled_goals = data_frame['goal']
        self.goals_label_encoder = LabelEncoder()
        self.goals = self.goals_label_encoder.fit_transform(unlabeled_goals)
        self.train_set = data_frame.drop('goal', axis='columns')
        self.model = DecisionTreeClassifier(criterion='entropy')
        # Fitting on ``.values`` makes the model purely positional: feature
        # order at prediction time must match the CSV column order.
        self.model.fit(self.train_set.values, self.goals)

    def predict_move(self, grid: List[List[str]], current_knight: Knight, castle: Castle, monsters: List[Monster],
                     opponents: List[Knight]) -> \
            List[Tuple[int, int]]:
        """Predict the knight's goal and return the cells next to that goal.

        Args:
            grid: Game grid, forwarded unchanged to the neighbour helpers.
            current_knight: Knight for which the move is predicted.
            castle: The castle (label ``'tower'`` in the model's output).
            monsters: Monsters on the map; the first ``MONSTER_SLOTS`` are used as features.
            opponents: Opposing knights; the first ``OPPONENT_SLOTS`` are used as features.

        Returns:
            The result of ``castle_neighbors`` when the prediction is
            ``'tower'``, otherwise the result of ``find_neighbours`` for the
            predicted opponent or monster.

        Raises:
            IndexError: If fewer monsters or opponents are supplied than the
                feature vector needs.
        """
        if len(monsters) < self.MONSTER_SLOTS or len(opponents) < self.OPPONENT_SLOTS:
            # Same exception type the bare list indexing below would raise, with a usable message.
            raise IndexError(
                f'predict_move needs at least {self.MONSTER_SLOTS} monsters and {self.OPPONENT_SLOTS} opponents, '
                f'got {len(monsters)} and {len(opponents)}')

        knight_position = current_knight.position
        distance_to_castle = manhattan_distance(knight_position, castle.position)

        # (distance from the knight, hp clamped at 0) for every monster / opponent.
        monsters_parsed = [
            (manhattan_distance(knight_position, monster.position), parse_hp(monster.current_hp))
            for monster in monsters
        ]
        opponents_parsed = [
            (manhattan_distance(knight_position, opponent.position), parse_hp(opponent.health_bar.current_hp))
            for opponent in opponents
        ]

        prediction = self.get_prediction(tower_dist=distance_to_castle, tower_hp=castle.current_hp,
                                         mob1_dist=monsters_parsed[0][0], mob1_hp=monsters_parsed[0][1],
                                         mob2_dist=monsters_parsed[1][0], mob2_hp=monsters_parsed[1][1],
                                         opp1_dist=opponents_parsed[0][0], opp1_hp=opponents_parsed[0][1],
                                         opp2_dist=opponents_parsed[1][0], opp2_hp=opponents_parsed[1][1],
                                         opp3_dist=opponents_parsed[2][0], opp3_hp=opponents_parsed[2][1],
                                         opp4_dist=opponents_parsed[3][0], opp4_hp=opponents_parsed[3][1],
                                         agent_hp=current_knight.health_bar.current_hp)
        # Kept from the original implementation: the chosen goal is echoed to stdout.
        print(prediction)

        if prediction == 'tower':  # 'tower' is the dataset's label for the castle
            return castle_neighbors(grid, castle_bottom_right_row=castle.position[0],
                                    castle_bottom_right_col=castle.position[1])

        # Any other label is treated as an opponent ('opp...') or, failing that, a monster.
        # NOTE(review): position components are passed as (index 1, index 0) here but as
        # (index 0, index 1) for the castle above -- confirm against the helper signatures.
        idx = parse_idx_of_opp_or_monster(prediction)
        target = opponents[idx] if prediction.startswith('opp') else monsters[idx]
        return find_neighbours(grid, target.position[1], target.position[0])

    def get_prediction(self, tower_dist: int, mob1_dist: int, mob2_dist: int, opp1_dist: int, opp2_dist: int,
                       opp3_dist: int, opp4_dist: int, agent_hp: int, tower_hp: int, mob1_hp: int, mob2_hp: int,
                       opp1_hp: int, opp2_hp: int, opp3_hp: int, opp4_hp: int) -> str:
        """Run the classifier on one feature vector and return the decoded goal label.

        The features are passed to the model positionally in the order of this
        method's parameters. NOTE(review): this presumably mirrors the column
        order of the training CSV (without ``goal``) -- verify against the dataset.

        Returns:
            The first element of the label encoder's inverse transform of the model output.
        """
        features = [tower_dist, mob1_dist, mob2_dist, opp1_dist, opp2_dist, opp3_dist, opp4_dist, agent_hp,
                    tower_hp, mob1_hp, mob2_hp, opp1_hp, opp2_hp, opp3_hp, opp4_hp]
        prediction = self.model.predict([features])
        return self.goals_label_encoder.inverse_transform(prediction)[0]