# Machine_learning_2023/AI_brain/rotate_and_go_aStar.py

import heapq
from domain.world import World


class State:
    def __init__(self, x: int, y: int, direction=(1, 0), entity=None):
        self.x = x
        self.y = y
        self.direction = direction
        self.entity = entity

    def __hash__(self):
        # Must agree with __eq__, which compares direction as well as position.
        return hash((self.x, self.y, self.direction))

    def __eq__(self, other):
        return (
            self.x == other.x
            and self.y == other.y
            and self.direction == other.direction
        )

    def heuristic(self, goal_state) -> int:
        # Manhattan distance to the goal cell.
        return abs(self.x - goal_state.x) + abs(self.y - goal_state.y)


class Node:
    def __init__(self, state: State, g_score: int, goal_state: State):
        self.state = state
        self.g_score = g_score
        self.f_score = g_score + state.heuristic(goal_state)
        self.parent = None
        self.action = None

    def __lt__(self, other):
        return self.f_score < other.f_score


def action_sequence(node: Node):
    actions = []
    while node.parent:
        actions.append(node.action)
        node = node.parent
    actions.reverse()
    return actions


class RotateAndGoAStar:
    def __init__(self, world: World, start_state: State, goal_state: State):
        self.world = world
        self.start_state = start_state
        self.goal_state = goal_state
        self.fringe = []
        self.enqueued_states = set()
        self.explored = set()
        self.g_scores = {}  # best-known accumulated path cost per enqueued state
        self.actions = []
        self.cost = 0

    def get_g_score(self, state) -> int:
        # Best accumulated path cost recorded so far for this state.
        return self.g_scores[state]
    def search(self) -> bool:
        heapq.heappush(self.fringe, Node(self.start_state, 0, self.goal_state))
        self.enqueued_states.add(self.start_state)
        self.g_scores[self.start_state] = 0
        while self.fringe:
            elem: Node = heapq.heappop(self.fringe)
            if self.is_goal(elem.state):
                self.actions = action_sequence(elem)
                self.cost = elem.g_score
                return True
            self.explored.add(elem.state)
            for action, state in self.successors(elem.state):
                if state in self.explored:
                    continue
                new_g_score = elem.g_score + self.world.get_cost(state.x, state.y)
                if state not in self.enqueued_states:
                    next_node = Node(state, new_g_score, self.goal_state)
                    next_node.action = action
                    next_node.parent = elem
                    heapq.heappush(self.fringe, next_node)
                    self.enqueued_states.add(state)
                    self.g_scores[state] = new_g_score
                elif new_g_score < self.get_g_score(state):
                    # Found a cheaper path to an already-enqueued state:
                    # update its node in place and restore the heap invariant.
                    self.g_scores[state] = new_g_score
                    for node in self.fringe:
                        if node.state == state:
                            node.g_score = new_g_score
                            node.f_score = new_g_score + node.state.heuristic(
                                self.goal_state
                            )
                            node.parent = elem
                            node.action = action
                            heapq.heapify(self.fringe)
                            break
        return False

    def successors(self, state: State):
        # Rotating in place ("RR"/"RL") is always possible: each is a 90-degree
        # turn of the direction vector. Moving forward ("GO") is offered only
        # when the target cell accepts the move.
        new_successors = [
            ("RR", State(state.x, state.y, (-state.direction[1], state.direction[0]))),
            ("RL", State(state.x, state.y, (state.direction[1], -state.direction[0]))),
        ]
        next_x = state.x + state.direction[0]
        next_y = state.y + state.direction[1]
        if self.world.accepted_move(next_x, next_y):
            new_successors.append(("GO", State(next_x, next_y, state.direction)))
        return new_successors

    def is_goal(self, state: State) -> bool:
        # The goal is a position; the facing direction does not matter.
        return state.x == self.goal_state.x and state.y == self.goal_state.y

    def number_of_moves_forward(self) -> int:
        return self.actions.count("GO")
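

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. GridWorld is a
    # hypothetical stand-in for domain.world.World: the planner only calls
    # get_cost(x, y) and accepted_move(x, y), so any object with those two
    # methods works. (Running standalone still requires the module-level
    # domain.world import to resolve.)
    class GridWorld:
        def __init__(self, width: int, height: int):
            self.width = width
            self.height = height

        def get_cost(self, x: int, y: int) -> int:
            # Uniform step cost; a real World may vary this per cell.
            return 1

        def accepted_move(self, x: int, y: int) -> bool:
            return 0 <= x < self.width and 0 <= y < self.height

    planner = RotateAndGoAStar(GridWorld(5, 5), State(0, 0), State(4, 3))
    if planner.search():
        # Expect an optimal action list such as
        # ['GO', 'GO', 'GO', 'GO', 'RR', 'GO', 'GO', 'GO'] with cost 8
        # (7 forward moves plus one rotation, each costing 1 here).
        print(planner.actions, planner.cost, planner.number_of_moves_forward())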