from __future__ import annotations

from heapq import heappush, heappop, heapify
from typing import List
import itertools
import ctypes

from project_constants import Direction, Action
from minefield import Minefield

# Target cell as (row, column). This default is a temporary value for testing;
# graphsearch() overwrites it when called with both tox and toy.
GOAL = (2, 6)


class State:
    """Agent pose on the minefield: a grid cell plus the facing direction."""

    def __init__(self, row, column, direction: Direction):
        # Plain attributes; the search builds (row, column, direction)
        # tuples from them to identify states.
        self.row, self.column, self.direction = row, column, direction


class Node:
    """Search-tree node wrapping a State for the graph search.

    Attributes:
        state: the State this node represents.
        parent: the Node this one was expanded from (None for the root).
        action: the Action that led from parent to this node (None for the root).
        cost: accumulated path cost from the root (defaults to 0).
    """

    def __init__(self, state: State, parent: Node | None = None,
                 action: Action | None = None, cost=0):
        self.state, self.parent = state, parent
        self.action, self.cost = action, cost


def get_node_cost(node: Node, minefield: Minefield, mine_cost: int = 500):
    """Return the accumulated path cost (g) of *node*.

    The cost is the parent's accumulated cost plus the price of the action
    that produced *node*:

    * any action other than ``Action.GO`` (i.e. a rotation) costs 1;
    * ``Action.GO`` onto a tile whose ``mine`` is not None costs *mine_cost*;
    * ``Action.GO`` onto any other tile costs ``tile.cost.value``.

    :param node: node whose cost is computed; ``node.state`` gives the tile
        reached and ``node.parent`` supplies the cost accumulated so far.
    :param minefield: board whose ``matrix[row][column]`` tiles are read.
    :param mine_cost: penalty for stepping onto a mined tile. The default of
        500 is presumably chosen to dominate ordinary tile costs so the
        search avoids mines when it can.
    :return: accumulated cost from the root to *node*; 0 for a root node
        (one without a parent).
    """
    # A root node has no incoming action, so its path cost is 0. Previously
    # this crashed on ``None.cost``.
    if node.parent is None:
        return 0

    base_cost = node.parent.cost

    if node.action != Action.GO:
        return base_cost + 1

    tile = minefield.matrix[node.state.row][node.state.column]
    # NOTE(review): any mine is penalised here, even when ``tile.mine.active``
    # is false, whereas graphsearch() only rejects *active* mines on the goal
    # tile. Confirm this asymmetry is intended.
    if tile.mine is not None:
        return base_cost + mine_cost
    return base_cost + tile.cost.value


def get_estimated_cost(node: Node):
    """Heuristic h(n): Manhattan distance from the node's cell to GOAL.

    Only grid steps are counted; rotations are ignored.
    """
    goal_row, goal_column = GOAL[0], GOAL[1]
    row_gap = node.state.row - goal_row
    column_gap = node.state.column - goal_column
    return abs(row_gap) + abs(column_gap)


def tile_goal_test(state: State):
    """Return True when the agent stands on the GOAL cell (direction ignored)."""
    return (state.row, state.column) == GOAL


def mine_goal_test(state: State):
    """Return True when the agent is next to GOAL and facing it.

    The agent must stand on one of the four orthogonal neighbours of GOAL
    with its direction pointing at GOAL; every other pose yields False.
    """
    # offset of the agent from GOAL -> direction it must face to look at GOAL
    facing_goal = {
        (0, -1): Direction.RIGHT,   # agent is left of the goal
        (0, 1): Direction.LEFT,     # agent is right of the goal
        (-1, 0): Direction.DOWN,    # agent is above the goal
        (1, 0): Direction.UP,       # agent is below the goal
    }
    offset = (state.row - GOAL[0], state.column - GOAL[1])
    required = facing_goal.get(offset)
    if required is None:
        return False
    return bool(state.direction == required)


def get_successors(state: State, minefield: Minefield):
    """List the (action, resulting State) pairs reachable from *state*.

    Both rotations are always offered. GO is offered only when
    ``minefield.is_valid_move`` accepts the cell in front of the agent.
    The order is ROTATE_LEFT, ROTATE_RIGHT, then GO.
    """
    row, column, facing = state.row, state.column, state.direction
    moves = [
        (Action.ROTATE_LEFT, State(row, column, facing.previous())),
        (Action.ROTATE_RIGHT, State(row, column, facing.next())),
    ]

    ahead = go(row, column, facing)
    if minefield.is_valid_move(ahead[0], ahead[1]):
        moves.append((Action.GO, State(ahead[0], ahead[1], facing)))

    return moves


def _notify_no_solution():
    """Tell the user that no path to the goal exists.

    Shows a Windows message box when ``ctypes.windll`` is available. On other
    platforms ``ctypes`` has no ``windll`` attribute, so this falls back to
    printing the message instead of raising ``AttributeError``.
    """
    windll = getattr(ctypes, "windll", None)
    if windll is not None:
        windll.user32.MessageBoxW(0, "Brak rozwiązania", "GAME OVER", 1)
    else:
        print("GAME OVER: Brak rozwiązania")


def _state_key(state: State):
    """Return the hashable identity of *state*: (row, column, direction)."""
    return state.row, state.column, state.direction


def _extract_actions(goal_node: Node):
    """Return the actions leading from the root to *goal_node*, in order.

    Actions that are None (the root's) are skipped, so a goal satisfied by
    the initial state yields ``[]`` rather than ``[None]``.
    """
    actions = []
    node = goal_node
    while node is not None:
        if node.action is not None:
            actions.append(node.action)
        node = node.parent
    actions.reverse()
    return actions


def graphsearch(initial_state: State,
                minefield: Minefield,
                fringe: List[Node] = None,
                explored: List[Node] = None,
                target_type: str = "tile",
                tox: int = None,
                toy: int = None):
    """Run A* graph search from *initial_state* and return the action plan.

    Nodes are ordered by ``g + h``: g comes from get_node_cost() and h from
    get_estimated_cost(). States that have already been expanded are never
    reopened. A state still waiting in the fringe is re-queued when a cheaper
    path to it is found.

    :param initial_state: agent pose the search starts from.
    :param minefield: board passed to get_successors() and get_node_cost().
    :param fringe: optional list used as the heap of ``[priority, count, node]``
        entries; a new list is used when omitted.
    :param explored: optional list that receives every expanded Node.
    :param target_type: ``"mine"`` stops next to GOAL while facing it
        (mine_goal_test); any other value stops on GOAL (tile_goal_test).
    :param tox: goal row. Together with *toy* it overwrites the module-level
        GOAL, and the change persists after the call.
    :param toy: goal column, see *tox*.
    :return: list of Actions from the start to the goal. Returns ``[]`` when
        the start already satisfies the goal, or when no solution exists
        (the latter also notifies the user).
    """
    # reset global priority queue helpers so each search starts clean
    global entry_finder
    global counter
    entry_finder = {}
    counter = itertools.count()

    global GOAL
    if tox is not None and toy is not None:
        GOAL = (tox, toy)

    if target_type == "mine":
        goal_test = mine_goal_test
    else:
        goal_test = tile_goal_test
        # standing on an active mine can never be a valid plan
        goal_mine = minefield.matrix[GOAL[0]][GOAL[1]].mine
        if goal_mine is not None and goal_mine.active:
            _notify_no_solution()
            return []

    # fringe and explored initialization
    if fringe is None:
        fringe = []
    heapify(fringe)  # no-op for a new list; enforces the heap invariant otherwise
    if explored is None:
        explored = []

    explored_states = set()
    fringe_states = set()

    # root node
    add_node(fringe, Node(initial_state), 0)
    fringe_states.add(_state_key(initial_state))

    while fringe:
        element = pop_node(fringe)
        if element is None:
            # only lazily-deleted (REMOVED) entries were left in the heap
            break

        fringe_states.remove(_state_key(element.state))

        if goal_test(element.state):
            return _extract_actions(element)

        # mark as expanded (prevents infinite cycles)
        explored.append(element)
        explored_states.add(_state_key(element.state))

        for action, successor in get_successors(element.state, minefield):
            new_node = Node(state=successor, parent=element, action=action)
            new_node.cost = get_node_cost(new_node, minefield)
            priority = new_node.cost + get_estimated_cost(new_node)
            key = _state_key(successor)

            if key not in fringe_states and key not in explored_states:
                add_node(fringe, new_node, priority)
                fringe_states.add(key)
            elif key in fringe_states and entry_finder[key][0] > priority:
                # a cheaper path to a queued state: replace its entry
                update_priority(fringe, new_node, priority)

    # fringe exhausted -> solution not found
    _notify_no_solution()
    return []


# TEMPORARY METHOD
def go(row, column, direction):
    """Return the (row, column) cell one step ahead of (row, column) in *direction*.

    UP decrements the row and DOWN increments it; LEFT/RIGHT change the
    column. An unrecognised direction yields an empty tuple.
    """
    if direction == Direction.RIGHT:
        return row, column + 1
    if direction == Direction.LEFT:
        return row, column - 1
    if direction == Direction.UP:
        return row - 1, column
    if direction == Direction.DOWN:
        return row + 1, column
    return tuple()


# PRIORITY QUEUE HANDLER
# Module-level state for the heap-based priority queue used by graphsearch();
# graphsearch() resets entry_finder and counter at the start of every search.
entry_finder = {}               # maps (row, column, direction) -> live heap entry [priority, count, node]
REMOVED = '<removed-node>'      # stored in an entry's node slot to mark it removed (lazy deletion)
counter = itertools.count()     # unique sequence count; tie-breaker so equal priorities never compare nodes


def add_node(heap, node: Node, priority):
    """Push *node* onto *heap* with *priority* and register it in entry_finder.

    Each entry is a mutable ``[priority, sequence, node]`` list. The unique
    sequence number breaks priority ties, so nodes themselves are never compared.
    """
    state = node.state
    entry = [priority, next(counter), node]
    entry_finder[state.row, state.column, state.direction] = entry
    heappush(heap, entry)


def pop_node(heap):
    """Pop and return the lowest-priority live node, or None if none remain.

    Entries whose node slot holds REMOVED (superseded by update_priority)
    are discarded along the way. A returned node is also unregistered from
    entry_finder.
    """
    while len(heap) > 0:
        _priority, _count, node = heappop(heap)
        if node is REMOVED:
            continue
        state = node.state
        del entry_finder[state.row, state.column, state.direction]
        return node
    return None


def update_priority(heap, new_node, new_priority):
    """Re-queue new_node's state with *new_node* at *new_priority*.

    The old entry stays in the heap, but its node slot is overwritten with
    REMOVED so pop_node() skips it. Raises KeyError if the state is not
    currently registered in entry_finder.
    """
    state = new_node.state
    stale = entry_finder.pop((state.row, state.column, state.direction))
    stale[-1] = REMOVED
    add_node(heap, new_node, new_priority)