import heapq

from domain.world import World


class State:
    def __init__(self, x, y, direction=(1, 0), entity=None):
        self.x = x
        self.y = y
        self.direction = direction
        self.entity = entity

    def __hash__(self):
        # Hash the same fields that __eq__ compares so equal states share a hash.
        return hash((self.x, self.y, self.direction))

    def __eq__(self, other):
        return (
            self.x == other.x
            and self.y == other.y
            and self.direction == other.direction
        )

    def heuristic(self, goal_state):
        # Manhattan distance to the goal position.
        return abs(self.x - goal_state.x) + abs(self.y - goal_state.y)


class Node:
    def __init__(self, state: State, g_score: int, goal_state: State):
        self.state = state
        self.g_score = g_score
        self.f_score = g_score + state.heuristic(goal_state)
        self.parent = None
        self.action = None

    def __lt__(self, other):
        # Heap ordering: the node with the lower f-score is expanded first.
        return self.f_score < other.f_score


def action_sequence(node: Node):
    # Follow parent links back to the start node and return the actions in order.
    actions = []
    while node.parent:
        actions.append(node.action)
        node = node.parent
    actions.reverse()
    return actions


class RotateAndGoAStar:
    def __init__(self, world: World, start_state: State, goal_state: State):
        self.world = world
        self.start_state = start_state
        self.goal_state = goal_state
        self.fringe = []
        self.enqueued_states = set()
        self.explored = set()
        self.actions = []

    def get_g_score(self, state):
        # Best g-score currently recorded for this state on the fringe;
        # used to decide whether a newly found path is an improvement.
        for node in self.fringe:
            if node.state == state:
                return node.g_score
        return float("inf")

    def search(self):
        heapq.heappush(
            self.fringe, Node(self.start_state, 0, self.goal_state)
        )
        self.enqueued_states.add(self.start_state)

        while self.fringe:
            elem = heapq.heappop(self.fringe)
            if self.is_goal(elem.state):
                self.actions = action_sequence(elem)
                return True
            self.explored.add(elem.state)

            for action, state in self.successors(elem.state):
                if state in self.explored:
                    continue

                new_g_score = elem.g_score + self.world.get_cost(state.x, state.y)
                if state not in self.enqueued_states:
                    next_node = Node(state, new_g_score, self.goal_state)
                    next_node.action = action
                    next_node.parent = elem
                    heapq.heappush(self.fringe, next_node)
                    self.enqueued_states.add(state)
                elif new_g_score < self.get_g_score(state):
                    # A cheaper path to an already enqueued state: update the
                    # node in place and restore the heap invariant.
                    for node in self.fringe:
                        if node.state == state:
                            node.g_score = new_g_score
                            node.f_score = (
                                new_g_score + node.state.heuristic(self.goal_state)
                            )
                            node.parent = elem
                            node.action = action
                            heapq.heapify(self.fringe)
                            break

        return False

    def successors(self, state: State):
        # Rotations stay on the current cell and turn the direction vector by
        # 90 degrees; moving forward is only offered if the world accepts it.
        new_successors = [
            ("RR", State(state.x, state.y, (-state.direction[1], state.direction[0]))),
            ("RL", State(state.x, state.y, (state.direction[1], -state.direction[0]))),
        ]
        next_x = state.x + state.direction[0]
        next_y = state.y + state.direction[1]
        if self.world.accepted_move(next_x, next_y):
            new_successors.append(
                ("GO", State(next_x, next_y, state.direction))
            )
        return new_successors

    def is_goal(self, state: State) -> bool:
        return (
            state.x == self.goal_state.x
            and state.y == self.goal_state.y
        )
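

# --- Usage sketch (illustrative, not part of the original module) ---
# Assumes World exposes the accepted_move(x, y) and get_cost(x, y) methods used
# above; the bare World() constructor call and the concrete coordinates below
# are hypothetical placeholders.
if __name__ == "__main__":
    world = World()  # hypothetical constructor; the real one may require map data
    start = State(0, 0, direction=(1, 0))
    goal = State(4, 3)
    planner = RotateAndGoAStar(world, start, goal)
    if planner.search():
        # planner.actions is a list such as ["GO", "GO", "RR", "GO"], describing
        # rotations ("RR"/"RL") and forward moves ("GO") from start to goal.
        print(planner.actions)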