WMICraft/algorithms/a_star.py

199 lines
5.7 KiB
Python
Raw Normal View History

2022-04-11 19:18:03 +02:00
from __future__ import annotations
2022-04-12 00:21:30 +02:00
import heapq
2022-04-11 19:18:03 +02:00
from dataclasses import dataclass, field
2022-04-11 19:49:32 +02:00
from typing import Tuple, Optional, List
2022-04-11 19:18:03 +02:00
2022-04-11 21:52:21 +02:00
from common.constants import ROWS, COLUMNS
2022-04-13 18:35:02 +02:00
# Cell symbols that valid_move() accepts as walkable.
EMPTY_FIELDS = ['s', 'g', ' ']

# Absolute headings an agent can face.
LEFT, RIGHT, UP, DOWN = 'LEFT', 'RIGHT', 'UP', 'DOWN'

# Actions an agent can take in one search step.
TURN_LEFT, TURN_RIGHT, FORWARD = 'TURN_LEFT', 'TURN_RIGHT', 'FORWARD'

# (row delta, column delta) of a single FORWARD step for each heading.
directions = {
    LEFT: (0, -1),
    RIGHT: (0, 1),
    UP: (-1, 0),
    DOWN: (1, 0),
}
2022-04-11 19:49:32 +02:00
2022-04-11 19:18:03 +02:00
@dataclass
class State:
    """Agent pose used by the search: a grid cell plus the heading it faces.

    Two states are equal when both position and direction match. The hash is
    built from the same two fields, so states can be stored in sets and used
    as dict keys consistently with that equality.
    """

    position: Tuple[int, int]  # (row, col) cell
    direction: str  # expected to be one of LEFT / RIGHT / UP / DOWN

    def __eq__(self, other: object) -> bool:
        # Returning NotImplemented (instead of raising AttributeError) lets
        # comparisons such as `state == None` fall back to Python's default.
        if not isinstance(other, State):
            return NotImplemented
        return other.position == self.position and self.direction == other.direction

    def __lt__(self, other: object) -> bool:
        # Orders by position only; Node.__lt__ delegates here.
        if not isinstance(other, State):
            return NotImplemented
        return self.position < other.position

    def __hash__(self) -> int:
        # Include the direction so the four headings at one cell do not all
        # collide; still consistent with __eq__.
        return hash((self.position, self.direction))
2022-04-11 19:18:03 +02:00
@dataclass
class Node:
    """Search-tree node: a State plus its parent link and accumulated cost.

    ``cost`` and ``depth`` are not constructor arguments; ``__post_init__``
    derives them from the parent (if any) and the grid cell under the state.
    """

    state: State
    parent: Optional[Node]  # None for the root node
    action: Optional[str]  # action applied to the parent to reach this node; None at the root
    grid: List[List[str]]
    cost: int = field(init=False)
    depth: int = field(init=False)

    def __lt__(self, node: object) -> bool:
        # Ordering is delegated to the wrapped states.
        if not isinstance(node, Node):
            return NotImplemented
        return self.state < node.state

    def __post_init__(self) -> None:
        position = self.state.position
        # A node whose cell is 'g' adds 1 to the path cost; any other cell adds 2.
        step_cost = 1 if self.grid[position[0]][position[1]] == 'g' else 2
        if self.parent is None:
            self.cost = step_cost
            self.depth = 0
        else:
            self.cost = self.parent.cost + step_cost
            self.depth = self.parent.depth + 1

    def __hash__(self) -> int:
        return hash(self.state)
2022-04-12 00:21:30 +02:00
def expand(node: Node, grid: List[List[str]]) -> List[Node]:
    """Return one child Node per action that ``actions`` allows from *node*.

    Children are listed in the same order as the actions.
    """
    legal_moves = actions(node.state, grid)
    return [child_node(node=node, action=move, grid=grid) for move in legal_moves]
2022-04-11 20:04:53 +02:00
2022-04-13 18:35:02 +02:00
def child_node(node: Node, action: str, grid: List[List[str]]) -> Node:
    """Build the successor of *node* obtained by applying *action*.

    The resulting Node computes its own cost and depth from *node*.
    """
    return Node(
        state=result(state=node.state, action=action),
        parent=node,
        action=action,
        grid=grid,
    )
2022-04-11 19:49:32 +02:00
2022-04-11 22:52:58 +02:00
def next_position(current_position: Tuple[int, int], direction: str) -> Tuple[int, int]:
    """Return the cell one step from *current_position* towards *direction*.

    No bounds check is done here. Raises KeyError for a direction that is not
    a key of ``directions``.
    """
    d_row, d_col = directions[direction]
    row, col = current_position
    return d_row + row, d_col + col
2022-04-11 21:52:21 +02:00
def valid_move(position: Tuple[int, int], grid: List[List[str]]) -> bool:
    """Return True when the grid cell at *position* is one of EMPTY_FIELDS.

    The position is used directly as list indices, so callers must keep it
    inside the grid (negative values would wrap around).
    """
    row, col = position
    cell = grid[row][col]
    return cell in EMPTY_FIELDS
2022-04-11 21:52:21 +02:00
2022-04-11 22:52:58 +02:00
def actions(state: State, grid: List[List[str]]) -> List[str]:
    """List the actions available in *state*.

    Turning left and right are always offered. FORWARD is offered only when
    the agent is not facing the edge of the ROWS x COLUMNS board and the cell
    ahead passes ``valid_move``. Order: FORWARD (if allowed), TURN_LEFT,
    TURN_RIGHT.
    """
    row, col = state.position
    heading = state.direction

    facing_edge = (
        (heading == UP and row == 0)
        or (heading == DOWN and row == ROWS - 1)
        or (heading == LEFT and col == 0)
        or (heading == RIGHT and col == COLUMNS - 1)
    )
    # The cell ahead is inspected only when it lies on the board.
    can_advance = not facing_edge and valid_move(next_position(state.position, heading), grid)

    turns = [TURN_LEFT, TURN_RIGHT]
    return [FORWARD] + turns if can_advance else turns
2022-04-11 22:52:58 +02:00
def remove_forward(possible_actions: List[str]) -> None:
    """Remove the first FORWARD entry from *possible_actions* in place, if present."""
    try:
        possible_actions.remove(FORWARD)
    except ValueError:
        pass  # FORWARD was not in the list; nothing to do
2022-04-11 21:52:21 +02:00
2022-04-11 23:53:50 +02:00
def result(state: State, action: str) -> State:
    """Return the State reached by applying *action* to *state*.

    A new State is always returned; *state* is not modified. Turns change
    only the heading, FORWARD changes only the position. An unknown heading
    or action yields an unchanged copy.
    """
    heading = state.direction
    successor = State(state.position, heading)

    if heading not in (UP, DOWN, LEFT, RIGHT):
        return successor

    # Heading after a quarter turn, keyed by the current heading.
    after_left_turn = {UP: LEFT, LEFT: DOWN, DOWN: RIGHT, RIGHT: UP}
    after_right_turn = {UP: RIGHT, RIGHT: DOWN, DOWN: LEFT, LEFT: UP}

    if action == TURN_LEFT:
        successor.direction = after_left_turn[heading]
    elif action == TURN_RIGHT:
        successor.direction = after_right_turn[heading]
    elif action == FORWARD:
        successor.position = next_position(state.position, heading)
    return successor
2022-04-11 19:49:32 +02:00
def goal_test(state: State, goal_list: List[Tuple[int, int]]) -> bool:
    """Return True when *state* stands on any cell listed in *goal_list*."""
    current_cell = state.position
    return current_cell in goal_list
2022-04-11 19:49:32 +02:00
2022-04-11 21:52:21 +02:00
def h(state: State, goal: Tuple[int, int]) -> int:
    """Heuristic: Manhattan distance from ``state.position`` to *goal*."""
    (row, col), (goal_row, goal_col) = state.position, goal
    return abs(row - goal_row) + abs(col - goal_col)
2022-04-11 19:49:32 +02:00
2022-04-11 21:52:21 +02:00
def f(current_node: Node, goal: Tuple[int, int]) -> int:
    """A* evaluation f(n) = g(n) + h(n).

    g(n) is the path cost stored on the node; h(n) is the heuristic ``h``
    towards *goal*.
    """
    g_value = current_node.cost
    h_value = h(state=current_node.state, goal=goal)
    return g_value + h_value
2022-04-12 00:05:47 +02:00
2022-04-12 00:21:30 +02:00
def get_path_from_start(node: Node) -> List[str]:
    """Return the actions leading from the root of *node*'s tree to *node*.

    Walks the ``parent`` links up to the root and reverses the collected
    actions. Nodes without an action (the root, whose action is None) are
    skipped, so a root node yields ``[]`` rather than ``[None]``.
    """
    path: List[str] = []
    current: Optional[Node] = node
    while current is not None:
        if current.action is not None:
            path.append(current.action)
        current = current.parent
    path.reverse()
    return path
2022-04-12 00:05:47 +02:00
def a_star(state: State, grid: List[List[str]], goals: List[Tuple[int, int]]) -> List[str]:
    """Return a cheapest action sequence from *state* to any cell in *goals*.

    A* graph search over State values (position + heading). A node's priority
    is its accumulated cost plus the Manhattan distance to the *nearest* goal.
    A FORWARD step changes that distance by at most 1 and turns do not change
    it, while every node adds a cost of at least 1 (see ``Node.__post_init__``).
    The heuristic is therefore consistent, and the first goal state popped has
    been reached by a cheapest path. Ties between equal priorities are resolved
    in insertion order.

    Args:
        state: starting State.
        grid: map rows, used by ``Node`` for step costs and by ``expand``
            for walkability.
        goals: target cells; reaching any one of them ends the search.

    Returns:
        The actions to perform, or ``[]`` when ``goals`` is empty, when the
        start already satisfies the goal test, or when no goal is reachable.
    """
    if not goals:
        return []
    if goal_test(state, goals):
        return []

    def priority(candidate: Node) -> int:
        # min over goals of (cost + h) == cost + distance to the nearest goal.
        return min(f(candidate, goal) for goal in goals)

    start = Node(state=state, parent=None, action=None, grid=grid)
    # Heap entries are (priority, insertion counter, node). The counter
    # settles ties so heapq never has to compare Node objects, whose
    # dataclass equality walks the whole parent chain and the grid.
    counter = 0
    frontier = [(priority(start), counter, start)]
    best_cost = {start.state: start.cost}  # cheapest known path cost per state
    explored = set()

    while frontier:
        _, _, node = heapq.heappop(frontier)
        if node.state in explored:
            # Stale entry: this state was already expanded via a path that
            # was at least as cheap.
            continue
        if goal_test(node.state, goals):
            return get_path_from_start(node)
        explored.add(node.state)
        for child in expand(node, grid):
            if child.state in explored:
                continue
            known_cost = best_cost.get(child.state)
            if known_cost is not None and known_cost <= child.cost:
                continue  # an equal or cheaper path to this state is queued
            best_cost[child.state] = child.cost
            counter += 1
            heapq.heappush(frontier, (priority(child), counter, child))
    return []