"""A* search over (position, heading) states on a 2-D character grid.

The agent can step one cell forward along its heading or turn 90 degrees in
place; every action costs 1.
"""
from __future__ import annotations

import heapq
from dataclasses import dataclass, field
from typing import Tuple, Optional, List

from common.constants import ROWS, COLUMNS

# Content of a grid cell the agent is allowed to enter.
FREE_FIELD = ' '

# Headings.
LEFT = 'LEFT'
RIGHT = 'RIGHT'
UP = 'UP'
DOWN = 'DOWN'

# Actions.
TURN_LEFT = 'TURN_LEFT'
TURN_RIGHT = 'TURN_RIGHT'
FORWARD = 'FORWARD'

# (row delta, column delta) of a single step along each heading.
directions = {
    LEFT: (0, -1),
    RIGHT: (0, 1),
    UP: (-1, 0),
    DOWN: (1, 0),
}

# Heading obtained after a 90-degree turn, keyed by the current heading.
_AFTER_TURN_LEFT = {UP: LEFT, LEFT: DOWN, DOWN: RIGHT, RIGHT: UP}
_AFTER_TURN_RIGHT = {UP: RIGHT, RIGHT: DOWN, DOWN: LEFT, LEFT: UP}


@dataclass
class State:
    """Agent configuration: a (row, column) position plus a heading."""

    position: Tuple[int, int]
    direction: str

    def __lt__(self, state: State) -> bool:
        # Ordering is only needed so heapq can break ties between entries
        # with equal priority; comparing positions is sufficient for that.
        return self.position < state.position

    def __hash__(self) -> int:
        # Hash both fields so it matches the dataclass-generated __eq__ and
        # the four headings of one cell do not all collide.
        return hash((self.position, self.direction))


@dataclass
class Node:
    """Search-tree node: a state, the node/action that led to it, and its cost."""

    state: State
    parent: Optional[Node]
    action: Optional[str]
    cost: int = field(init=False)
    depth: int = field(init=False)

    def __post_init__(self) -> None:
        # Every action costs 1, so cost and depth coincide.
        self.cost = 0 if self.parent is None else self.parent.cost + 1
        self.depth = self.cost

    def __lt__(self, node: Node) -> bool:
        return self.state < node.state

    def __eq__(self, other: object) -> bool:
        # Nodes are identified by their state, regardless of how it was reached.
        if not isinstance(other, Node):
            return NotImplemented
        return self.state == other.state

    def __hash__(self) -> int:
        return hash(self.state)


def expand(node: Node, grid: List[List[str]]) -> List[Node]:
    """Return one child node per action available in ``node.state``."""
    return [child_node(node=node, action=action) for action in actions(node.state, grid)]


def child_node(node: Node, action: str) -> Node:
    """Return the node reached by applying ``action`` to ``node``."""
    next_state = result(state=node.state, action=action)
    return Node(state=next_state, parent=node, action=action)


def next_position(current_position: Tuple[int, int], direction: str) -> Tuple[int, int]:
    """Return the cell one step away from ``current_position`` along ``direction``.

    Raises:
        KeyError: if ``direction`` is not one of the four headings.
    """
    row_step, col_step = directions[direction]
    row, col = current_position
    return row + row_step, col + col_step


def valid_move(position: Tuple[int, int], grid: List[List[str]]) -> bool:
    """Return True when ``position`` lies inside ``grid`` and holds FREE_FIELD."""
    row, col = position
    # Explicit bounds checks: a negative index would otherwise silently wrap
    # around to the opposite edge, and a too-large one would raise IndexError.
    if not 0 <= row < len(grid):
        return False
    if not 0 <= col < len(grid[row]):
        return False
    return grid[row][col] == FREE_FIELD


def actions(state: State, grid: List[List[str]]) -> List[str]:
    """Return the actions applicable in ``state``.

    Turning is always possible. FORWARD is dropped when the agent faces the
    ROWS x COLUMNS border or when the cell ahead is not a free field.
    """
    possible_actions = [FORWARD, TURN_LEFT, TURN_RIGHT]
    row, col = state.position
    direction = state.direction

    facing_border = (
        (direction == UP and row == 0)
        or (direction == DOWN and row == ROWS - 1)
        or (direction == LEFT and col == 0)
        or (direction == RIGHT and col == COLUMNS - 1)
    )
    # Short-circuit order matters: next_position() is only evaluated for a
    # known heading, and the cell ahead is only inspected away from the border.
    blocked = (
        direction not in directions
        or facing_border
        or not valid_move(next_position(state.position, direction), grid)
    )
    if blocked:
        remove_forward(possible_actions)
    return possible_actions


def remove_forward(possible_actions: List[str]) -> None:
    """Remove FORWARD from ``possible_actions`` in place, if it is present."""
    if FORWARD in possible_actions:
        possible_actions.remove(FORWARD)


def result(state: State, action: str) -> State:
    """Return a new state produced by applying ``action`` to ``state``.

    ``state`` itself is not modified. An unrecognised action or heading
    yields an unchanged copy.
    """
    next_state = State(state.position, state.direction)
    heading = state.direction
    if action == TURN_LEFT:
        next_state.direction = _AFTER_TURN_LEFT.get(heading, heading)
    elif action == TURN_RIGHT:
        next_state.direction = _AFTER_TURN_RIGHT.get(heading, heading)
    elif action == FORWARD and heading in directions:
        next_state.position = next_position(state.position, heading)
    return next_state


def goal_test(state: State, goal_list: List[Tuple[int, int]]) -> bool:
    """Return True when the position of ``state`` is one of the goals (heading is ignored)."""
    return state.position in goal_list


def h(state: State, goal: Tuple[int, int]) -> int:
    """heuristics that calculates Manhattan distance between current position and goal"""
    x1, y1 = state.position
    x2, y2 = goal
    return abs(x1 - x2) + abs(y1 - y2)


def f(current_node: Node, goal: Tuple[int, int]) -> int:
    """f(n) = g(n) + h(n), g stands for current cost, h for heuristics"""
    return current_node.cost + h(state=current_node.state, goal=goal)


def get_path_from_start(node: Node) -> List[str]:
    """Return the actions leading from the root of the search tree to ``node``.

    The root has no action, so the result is empty when ``node`` is the root.
    """
    path = []
    # Walk up to the root collecting each node's own action, then flip the
    # list so that it reads start -> node.
    while node.parent is not None:
        path.append(node.action)
        node = node.parent
    path.reverse()
    return path


def a_star(state: State, grid: List[List[str]], goals: List[Tuple[int, int]]) -> List[str]:
    """Find a cheapest action sequence from ``state`` to any position in ``goals``.

    Args:
        state: starting position and heading.
        grid: 2-D grid of cell markers; only FREE_FIELD cells can be entered.
        goals: acceptable target positions.

    Returns:
        The list of actions to execute, or an empty list when ``goals`` is
        empty, no goal is reachable, or the start already is a goal.
    """
    if not goals:
        return []

    def priority(candidate: Node) -> int:
        # Estimate towards the nearest goal so the heuristic stays admissible
        # when several goals are given.
        return min(f(candidate, goal) for goal in goals)

    start = Node(state=state, parent=None, action=None)
    frontier = [(priority(start), start)]
    # Cheapest cost at which each state has been queued so far; lets us skip
    # worse duplicates without scanning the heap.
    cheapest = {start.state: start.cost}
    explored = set()

    while frontier:
        _, node = heapq.heappop(frontier)
        if goal_test(node.state, goals):
            return get_path_from_start(node)
        if node.state in explored:
            # Stale heap entry: this state was already expanded via a path
            # that is at least as cheap.
            continue
        explored.add(node.state)

        for child in expand(node, grid):
            child_state = child.state
            if child_state in explored:
                continue
            known_cost = cheapest.get(child_state)
            if known_cost is not None and known_cost <= child.cost:
                continue
            cheapest[child_state] = child.cost
            heapq.heappush(frontier, (priority(child), child))

    return []