125 lines
3.9 KiB
Python
125 lines
3.9 KiB
Python
|
import math
|
||
|
import queue
|
||
|
from dataclasses import dataclass, field
|
||
|
from typing import Any
|
||
|
|
||
|
from domain.world import World
|
||
|
|
||
|
|
||
|
class State:
    """A pose on the grid: a cell position plus a facing direction.

    Attributes:
        x: Horizontal cell coordinate.
        y: Vertical cell coordinate.
        direction: ``(dx, dy)`` pair giving the facing direction;
            defaults to ``(1, 0)``.
    """

    def __init__(self, x, y, direction=(1, 0)):
        self.x = x
        self.y = y
        self.direction = direction

    def __repr__(self):
        return f"{type(self).__name__}({self.x!r}, {self.y!r}, {self.direction!r})"

    def __hash__(self):
        # Only the position is hashed. States that compare equal always share
        # a position, so the hash/eq contract holds; states differing only in
        # direction land in the same bucket and are told apart by __eq__.
        return hash((self.x, self.y))

    def __eq__(self, other):
        # Returning NotImplemented (instead of touching other.x) lets
        # comparisons with unrelated objects such as None evaluate to False
        # rather than raising AttributeError.
        if not isinstance(other, State):
            return NotImplemented
        return (
            self.x == other.x
            and self.y == other.y
            and self.direction == other.direction
        )
|
||
|
|
||
|
|
||
|
class Node:
    """Search-tree node: a state plus the bookkeeping attached to it.

    Only ``state`` is supplied at construction. ``parent``, ``action``,
    ``g`` and ``h`` all start as ``None`` and are expected to be filled in
    by whoever builds the tree.
    """

    def __init__(self, state: State):
        self.state = state
        # Link back towards the root and the action taken to get here.
        self.parent, self.action = None, None
        # Cost and heuristic values; unset until assigned by the caller.
        self.g, self.h = None, None
|
||
|
|
||
|
|
||
|
@dataclass(order=True)
class PrioritizedItem:
    """Priority-queue entry that is ordered and compared by ``priority`` only.

    ``item`` is excluded from comparisons, so the payload never needs to be
    orderable itself.
    """

    priority: int
    item: Any = field(compare=False)

    def __iter__(self):
        """Yield ``priority`` then ``item`` so an entry unpacks like a pair."""
        yield self.priority
        yield self.item
|
||
|
|
||
|
|
||
|
def action_sequence(node: "Node"):
    """Return the actions leading from the root of the search tree to *node*.

    Follows ``parent`` links from *node* back to the node that has no
    parent, collecting each visited node's ``action`` along the way.

    Args:
        node: The node whose path should be reconstructed.

    Returns:
        A list of actions in root-to-node order. The root's own action is
        not included, so the list is empty when *node* is the root.
    """
    actions = []
    while node.parent is not None:
        actions.append(node.action)
        node = node.parent
    # Actions were gathered walking backwards; flip them into execution order.
    actions.reverse()
    return actions
|
||
|
|
||
|
|
||
|
class RotateAndGoAStar:
    """A*-style search over (position, direction) states on a grid world.

    Three actions are generated for a state: ``"RR"`` and ``"RL"`` rotate in
    place, and ``"GO"`` steps one cell in the facing direction when the world
    accepts that move. The goal is reached as soon as the position matches
    the goal position; the facing direction at the goal is ignored.

    Attributes:
        world: Object providing ``accepted_move(x, y)`` and ``get_cost(x, y)``.
        start_state: State the search starts from.
        goal_state: State whose position is the target.
        fringe: Priority queue of ``PrioritizedItem(f, node)`` entries.
        enqueued_states: Maps each state waiting in the fringe to its f value.
        explored: States that have already been expanded.
        actions: Action list of the found path; filled in by ``search``.
    """

    def __init__(self, world: World, start_state: State, goal_state: State):
        self.world = world
        self.start_state = start_state
        self.goal_state = goal_state
        self.fringe = queue.PriorityQueue()
        self.enqueued_states = {}
        self.explored = set()
        self.actions = []

    def _heuristic(self, state: State):
        """Return the squared straight-line distance from *state* to the goal."""
        # NOTE(review): a squared distance can exceed the real remaining cost,
        # in which case the path found is not guaranteed to be the cheapest --
        # confirm this is intended.
        dx = state.x - self.goal_state.x
        dy = state.y - self.goal_state.y
        return dx ** 2 + dy ** 2

    def search(self):
        """Run the search from ``start_state`` towards ``goal_state``.

        Returns:
            ``True`` when a goal position was reached, in which case
            ``self.actions`` holds the action sequence; ``False`` when the
            fringe ran empty first.
        """
        start_node = Node(self.start_state)
        # The root has no path behind it, so its accumulated cost is zero.
        start_node.g = 0
        start_node.h = self._heuristic(self.start_state)
        self.fringe.put(PrioritizedItem(start_node.h, start_node))

        while not self.fringe.empty():
            _, elem = self.fringe.get()

            # enqueued_states is keyed by State, so the entry has to be
            # removed by the node's state, not by the node itself.
            self.enqueued_states.pop(elem.state, None)

            if self.is_goal(elem.state):
                self.actions = action_sequence(elem)
                return True

            self.explored.add(elem.state)

            for action, state in self.successors(elem.state):
                # Expanded states are never put back on the fringe.
                if state in self.explored:
                    continue

                step_cost = (
                    abs(elem.state.x - state.x)
                    + abs(elem.state.y - state.y)
                    + self.world.get_cost(state.x, state.y)
                )

                next_node = Node(state)
                next_node.action = action
                next_node.parent = elem
                # g is the cost of the whole path so far, not just this step.
                next_node.g = elem.g + step_cost
                next_node.h = self._heuristic(state)
                f = next_node.g + next_node.h

                queued_f = self.enqueued_states.get(state)
                if queued_f is None:
                    self.fringe.put(PrioritizedItem(f, next_node))
                elif queued_f > f:
                    # A cheaper route to an already queued state was found.
                    self.add_existed(next_node, f)
                else:
                    continue
                self.enqueued_states[state] = f

        return False

    def add_existed(self, node: Node, f: int):
        """Replace the queued entry for ``node.state`` with *node* at priority *f*.

        Entries are taken off the fringe until one with the same state is
        found; that entry is dropped, *node* is queued in its place and the
        other removed entries are put back. If no entry matches, *node* is
        still queued and every removed entry is restored.
        """
        old = []
        while not self.fringe.empty():
            e = self.fringe.get()
            if e.item.state == node.state:
                break
            old.append(e)
        self.fringe.put(PrioritizedItem(f, node))
        for e in old:
            self.fringe.put(e)

    def successors(self, state: State):
        """Return the ``(action, next_state)`` pairs reachable from *state*.

        Both rotations are always available; ``"GO"`` is added only when
        ``world.accepted_move`` accepts the cell in front of *state*.
        """
        dx, dy = state.direction[0], state.direction[1]
        new_successors = [
            # rotate right
            ("RR", State(state.x, state.y, (-dy, dx))),
            # rotate left
            ("RL", State(state.x, state.y, (dy, -dx))),
        ]
        ahead_x, ahead_y = state.x + dx, state.y + dy
        if self.world.accepted_move(ahead_x, ahead_y):
            new_successors.append(
                ("GO", State(ahead_x, ahead_y, state.direction))
            )
        return new_successors

    def is_goal(self, state: State) -> bool:
        """Return whether *state* is at the goal position (direction ignored)."""
        return (
            state.x == self.goal_state.x
            and state.y == self.goal_state.y
        )
|