feat: a_star implementation

This commit is contained in:
korzepadawid 2022-04-12 00:21:30 +02:00
parent 544c5276d5
commit 286164c1dd

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import heapq
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Tuple, Optional, List from typing import Tuple, Optional, List
@ -33,7 +34,7 @@ class State:
class Node: class Node:
state: State state: State
parent: Optional[Node] parent: Optional[Node]
action: str action: Optional[str]
cost: int = field(init=False) cost: int = field(init=False)
depth: int = field(init=False) depth: int = field(init=False)
@ -48,8 +49,8 @@ class Node:
return hash(self.state) return hash(self.state)
def expand(node: Node) -> List[Node]: def expand(node: Node, grid: List[List[str]]) -> List[Node]:
return [child_node(node=node, action=action) for action in actions(node.state)] return [child_node(node=node, action=action) for action in actions(node.state, grid)]
def child_node(node: Node, action: str) -> Node: def child_node(node: Node, action: str) -> Node:
@ -147,5 +148,37 @@ def f(current_node: Node, goal: Tuple[int, int]) -> int:
return current_node.cost + h(state=current_node.state, goal=goal) return current_node.cost + h(state=current_node.state, goal=goal)
def get_path_from_start(node: Node) -> List[str]:
path = [node]
while node.parent is not None:
node = node.parent
path.append(node.action)
path.reverse()
return path
def a_star(state: State, grid: List[List[str]], goals: List[Tuple[int, int]]) -> List[str]: def a_star(state: State, grid: List[List[str]], goals: List[Tuple[int, int]]) -> List[str]:
node = Node(state=state, parent=None, action=None)
frontier = list()
heapq.heappush(frontier, (f(node, goals[0]), node))
explored = set()
while frontier:
r, node = heapq.heappop(frontier)
if goal_test(node.state, goals):
return get_path_from_start(node)
explored.add(node.state)
for child in expand(node, grid):
p = f(child, goals[0])
if child.state not in explored and (p, child) not in frontier:
heapq.heappush(frontier, (p, child))
elif (r, child) in frontier and r > p:
heapq.heappush(frontier, (p, child))
return [] return []