From 286164c1dd30d8006287191090b07136bf98f588 Mon Sep 17 00:00:00 2001 From: korzepadawid Date: Tue, 12 Apr 2022 00:21:30 +0200 Subject: [PATCH] feat: a_star implementation --- algorithms/a_star.py | 39 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/algorithms/a_star.py b/algorithms/a_star.py index c657c1e..d6f630a 100644 --- a/algorithms/a_star.py +++ b/algorithms/a_star.py @@ -1,5 +1,6 @@ from __future__ import annotations +import heapq from dataclasses import dataclass, field from typing import Tuple, Optional, List @@ -33,7 +34,7 @@ class State: class Node: state: State parent: Optional[Node] - action: str + action: Optional[str] cost: int = field(init=False) depth: int = field(init=False) @@ -48,8 +49,8 @@ class Node: return hash(self.state) -def expand(node: Node) -> List[Node]: - return [child_node(node=node, action=action) for action in actions(node.state)] +def expand(node: Node, grid: List[List[str]]) -> List[Node]: + return [child_node(node=node, action=action) for action in actions(node.state, grid)] def child_node(node: Node, action: str) -> Node: @@ -147,5 +148,37 @@ def f(current_node: Node, goal: Tuple[int, int]) -> int: return current_node.cost + h(state=current_node.state, goal=goal) +def get_path_from_start(node: Node) -> List[str]: + path = [node] + + while node.parent is not None: + node = node.parent + path.append(node.action) + + path.reverse() + return path + + def a_star(state: State, grid: List[List[str]], goals: List[Tuple[int, int]]) -> List[str]: + node = Node(state=state, parent=None, action=None) + + frontier = list() + heapq.heappush(frontier, (f(node, goals[0]), node)) + explored = set() + + while frontier: + r, node = heapq.heappop(frontier) + + if goal_test(node.state, goals): + return get_path_from_start(node) + + explored.add(node.state) + + for child in expand(node, grid): + p = f(child, goals[0]) + if child.state not in explored and (p, child) not in frontier: + heapq.heappush(frontier, (p, child)) + elif (r, child) in frontier and r > p: + heapq.heappush(frontier, (p, child)) + return []