a_star #21

Merged
s464961 merged 15 commits from a_star into master 2022-04-27 19:50:51 +02:00
Showing only changes of commit 286164c1dd - Show all commits

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import heapq
from dataclasses import dataclass, field
from typing import Tuple, Optional, List
@ -33,7 +34,7 @@ class State:
class Node:
state: State
parent: Optional[Node]
action: str
action: Optional[str]
cost: int = field(init=False)
depth: int = field(init=False)
@ -48,8 +49,8 @@ class Node:
return hash(self.state)
def expand(node: Node) -> List[Node]:
return [child_node(node=node, action=action) for action in actions(node.state)]
def expand(node: Node, grid: List[List[str]]) -> List[Node]:
return [child_node(node=node, action=action) for action in actions(node.state, grid)]
def child_node(node: Node, action: str) -> Node:
@ -147,5 +148,37 @@ def f(current_node: Node, goal: Tuple[int, int]) -> int:
return current_node.cost + h(state=current_node.state, goal=goal)
def get_path_from_start(node: Node) -> List[str]:
path = [node]
while node.parent is not None:
node = node.parent
path.append(node.action)
path.reverse()
return path
def a_star(state: State, grid: List[List[str]], goals: List[Tuple[int, int]]) -> List[str]:
node = Node(state=state, parent=None, action=None)
frontier = list()
heapq.heappush(frontier, (f(node, goals[0]), node))
explored = set()
while frontier:
r, node = heapq.heappop(frontier)
if goal_test(node.state, goals):
return get_path_from_start(node)
explored.add(node.state)
for child in expand(node, grid):
p = f(child, goals[0])
if child.state not in explored and (p, child) not in frontier:
heapq.heappush(frontier, (p, child))
elif (r, child) in frontier and r > p:
heapq.heappush(frontier, (p, child))
return []