import heapq
from os import path
from settings import *


class Problem:
    """Grid-navigation search problem.

    A state is a tuple of row tuples of single-character cells. The agent is
    the first cell holding one of '>', '<', 'v', '^'; the glyph is its
    heading. The agent may step onto '.' cells and 'p' cells (presumably
    puddles -- TODO confirm). Any other character blocks forward motion.
    The goal is reached only when the whole state equals ``goal``.
    """

    # (row delta, column delta) of one forward step for each heading.
    _STEP = {'>': (0, 1), '<': (0, -1), 'v': (1, 0), '^': (-1, 0)}
    # Heading after turning left / right from each heading.
    _TURN = {
        'Left': {'>': '^', '^': '<', '<': 'v', 'v': '>'},
        'Right': {'>': 'v', 'v': '<', '<': '^', '^': '>'},
    }
    # Cells the agent is allowed to step onto.
    _PASSABLE = ('.', 'p')

    def __init__(self, initial, goal):
        self.initial = initial
        self.goal = goal

    def _find_agent(self, state):
        """Return (row, column, heading) of the first agent glyph in row-major order, or None."""
        for r, row in enumerate(state):
            for c, cell in enumerate(row):
                if cell in self._STEP:
                    return r, c, cell
        return None

    def _ahead(self, state, agent):
        """Return (row, column) of the cell in front of *agent*, or None if it is off the map."""
        r, c, heading = agent
        dr, dc = self._STEP[heading]
        nr, nc = r + dr, c + dc
        # Bounds come from the state itself, so ragged or non-square maps are safe.
        if 0 <= nr < len(state) and 0 <= nc < len(state[nr]):
            return nr, nc
        return None

    def actions(self, state):
        """Return the applicable actions, in the order 'Left', 'Right', 'Forward'."""
        moves = []
        if self.turn_left(state):
            moves.append('Left')
        if self.turn_right(state):
            moves.append('Right')
        if self.move_forward(state):
            moves.append('Forward')
        return moves

    def turn_left(self, state):
        """Turning left is always allowed."""
        return True

    def turn_right(self, state):
        """Turning right is always allowed."""
        return True

    def move_forward(self, state):
        """Return True if the cell in front of the agent exists and is passable.

        Returns False when the state contains no agent glyph.
        """
        agent = self._find_agent(state)
        if agent is None:
            return False
        target = self._ahead(state, agent)
        return target is not None and state[target[0]][target[1]] in self._PASSABLE

    def turn_me_or_move(self, state, do_it):
        """Apply *do_it* ('Left', 'Right' or 'Forward') and return the new map.

        Returns a fresh list of row lists. The input is never mutated. If no
        agent is present, or *do_it* is not a known action, an unchanged copy
        is returned.

        Raises:
            ValueError: if 'Forward' would leave the map. Passability is not
                re-checked here; ``actions`` is expected to have filtered it.
        """
        grid = [list(row) for row in state]
        agent = self._find_agent(grid)
        if agent is None:
            return grid
        r, c, heading = agent
        if do_it in self._TURN:
            grid[r][c] = self._TURN[do_it][heading]
        elif do_it == 'Forward':
            target = self._ahead(grid, agent)
            if target is None:
                # Without this guard, negative indices would silently wrap around.
                raise ValueError("cannot move 'Forward' off the edge of the map")
            # NOTE(review): the vacated cell always becomes '.', so leaving a
            # 'p' cell erases it from the state -- confirm this is intended.
            grid[r][c] = '.'
            grid[target[0]][target[1]] = heading
        return grid

    def result(self, state, action):
        """Return the hashable (tuple-of-tuples) state produced by *action*.

        Raises:
            ValueError: for an action other than 'Left', 'Right' or 'Forward'.
        """
        if action not in ('Left', 'Right', 'Forward'):
            raise ValueError("unknown action: {!r}".format(action))
        return tuple(map(tuple, self.turn_me_or_move(state, action)))

    def goal_test(self, state):
        """Return True when *state* equals the goal map exactly."""
        return self.goal == state

    def path_cost(self, c, state1, action, state2, in_puddle1=False, in_puddle2=False):
        """Return the cost of reaching state2 from state1: every action costs 1.

        The puddle flags are accepted but do not affect the cost yet; they
        default to False so callers that do not track puddles still work.
        """
        return c + 1

    def h(self, node):
        """Heuristic: Manhattan distance from the agent in node.state to the agent in the goal.

        The estimate is admissible because every action costs 1 and moves the
        agent at most one cell. Returns 0 if either map has no agent.
        """
        here = self._find_agent(node.state)
        there = self._find_agent(self.goal)
        if here is None or there is None:
            return 0
        return abs(here[0] - there[0]) + abs(here[1] - there[1])

    


class Node:
    """A node in the search tree: a state plus how it was reached.

    Equality and hashing depend only on ``state``. Ordering (``__lt__``)
    also compares states, so heapq can break ties between equal f values.
    """

    def __init__(self, state, parent=None, action=None, path_cost=0):
        """Create a search tree Node, derived from a parent by an action."""
        self.state = state
        self.parent = parent
        self.action = action
        self.path_cost = path_cost
        # Puddle flag passed to problem.path_cost; nothing in this module sets it
        # to True yet -- TODO wire it up if puddles should cost more.
        self.in_puddle = False

    def __repr__(self):
        return "<Node {}>".format(self.state)

    def __lt__(self, other):
        # heapq falls back to comparing items when their priorities are equal;
        # without this method that comparison raises TypeError.
        if not isinstance(other, Node):
            return NotImplemented
        return self.state < other.state

    def expand(self, problem):
        """List the nodes reachable in one step from this node."""
        return [self.child_node(problem, action)
                for action in problem.actions(self.state)]

    def child_node(self, problem, action):
        """Build the child reached by *action*, with its accumulated path cost.

        problem.path_cost receives (c, state1, action, state2, in_puddle1,
        in_puddle2). The two flags are this node's and the child's in_puddle.
        """
        next_state = problem.result(self.state, action)
        next_node = Node(next_state, self, action)
        next_node.path_cost = problem.path_cost(
            self.path_cost, self.state, action, next_state,
            self.in_puddle, next_node.in_puddle)
        return next_node

    def where_am_i(self):
        """Store the agent's grid position in self.row / self.column and return it.

        Scans self.state in row-major order and stops at the first agent glyph
        ('>', '<', '^', 'v'). Returns (row, column), or None (leaving the
        attributes unset) when there is no agent.
        """
        for r, row in enumerate(self.state):
            for c, cell in enumerate(row):
                if cell in ('>', '<', '^', 'v'):
                    self.row, self.column = r, c
                    return r, c
        return None

    def __eq__(self, other):
        return isinstance(other, Node) and self.state == other.state

    def __hash__(self):
        # Hash the stored state rather than the node object, so nodes with the
        # same state can be found quickly in a hash table.
        return hash(self.state)


class PriorityQueue:
    """A Queue in which the minimum (or maximum) element (as determined by f and
    order) is returned first.
    If order is 'min', the item with minimum f(x) is
    returned first; if order is 'max', then it is the item with maximum f(x).
    Items with equal keys are returned in insertion order (FIFO).
    Also supports dict-like lookup."""

    def __init__(self, order='min', f=lambda x: x):
        # Heap entries are (key, insertion_seq, item). The sequence number
        # breaks ties, so heapq never compares items (which may be unorderable).
        self.heap = []
        self._seq = 0
        if order == 'min':
            self.f = f
        elif order == 'max':  # now item with max f(x)
            self.f = lambda x: -f(x)  # will be popped first
        else:
            raise ValueError("Order must be either 'min' or 'max'.")

    def append(self, item):
        """Insert item at its correct position."""
        heapq.heappush(self.heap, (self.f(item), self._seq, item))
        self._seq += 1

    def extend(self, items):
        """Insert each item in items at its correct position."""
        for item in items:
            self.append(item)

    def pop(self):
        """Pop and return the item (with min or max f(x) value)
        depending on the order."""
        if self.heap:
            return heapq.heappop(self.heap)[2]
        else:
            raise Exception('Trying to pop from empty PriorityQueue.')

    def __len__(self):
        """Return the number of items currently in the PriorityQueue."""
        return len(self.heap)

    def __contains__(self, key):
        """Return True if an item equal to key is in the PriorityQueue."""
        return any(item == key for _, _, item in self.heap)

    def __getitem__(self, key):
        """Return the stored key value (f, negated for 'max') of the first item equal to key.
        Raises KeyError if key is not present."""
        for value, _, item in self.heap:
            if item == key:
                return value
        raise KeyError(str(key) + " is not in the priority queue")

    def __delitem__(self, key):
        """Delete the first occurrence of key.
        Raises KeyError if key is not present."""
        index = next((i for i, (_, _, item) in enumerate(self.heap) if item == key), None)
        if index is None:
            raise KeyError(str(key) + " is not in the priority queue")
        del self.heap[index]
        heapq.heapify(self.heap)


class Astar:
    """Best-first (A*-style) graph search over a Problem, plus map loading."""

    @staticmethod
    def best_first_graph_search(problem, f, display=False):
        """Search the nodes with the lowest f scores first.

        You supply the function f(node) to minimize. With f = path cost + an
        admissible heuristic this is A*; with f = path cost it is uniform-cost
        search. f is evaluated whenever a node is compared against or pushed
        onto the frontier; values are not cached.

        Returns:
            The first goal Node popped from the frontier, or None if the
            frontier empties without reaching the goal.
        """
        node = Node(problem.initial)
        # The frontier is ordered by f (g + h for A*).
        frontier = PriorityQueue('min', f)
        frontier.append(node)
        explored = set()
        while frontier:
            node = frontier.pop()
            if problem.goal_test(node.state):
                if display:
                    print(len(explored), "paths have been expanded and",
                          len(frontier), "paths remain in the frontier")
                return node
            explored.add(node.state)
            for child in node.expand(problem):
                if child.state in explored:
                    continue
                if child not in frontier:
                    frontier.append(child)
                elif f(child) < frontier[child]:
                    # Found a cheaper path to a state already queued: replace it.
                    del frontier[child]
                    frontier.append(child)
        return None

    @staticmethod
    def loadMap(map_name=''):
        """Read a text map stored next to this module.

        Returns a list of the file's lines with trailing newlines removed.
        """
        map_folder = path.dirname(__file__)
        with open(path.join(map_folder, map_name), 'rt', encoding='utf-8') as f:
            return [line.rstrip('\n') for line in f]

    @staticmethod
    def run(*, use_heuristic=False):
        """Search from 'map.txt' to 'goal_map.txt', print the result node and return it.

        By default f(node) = node.path_cost (uniform-cost search, i.e. A* with
        h = 0). With use_heuristic=True, f(node) = node.path_cost +
        problem.h(node), which requires problem.h to return a number.
        """
        initial_map = tuple(map(tuple, Astar.loadMap('map.txt')))
        goal_map = tuple(map(tuple, Astar.loadMap('goal_map.txt')))
        problem = Problem(initial_map, goal_map)

        if use_heuristic:
            def f(node):
                return node.path_cost + problem.h(node)
        else:
            def f(node):
                return node.path_cost

        result = Astar.best_first_graph_search(problem, f)
        print(result)
        return result