wozek-projekt/a_star.py

157 lines
4.3 KiB
Python

import heapq
from num_map import num_matrix
# Grid dimensions used for bounds checking of 'go' moves.
MAX_ROWS = 15
MAX_COLS = 25
# Cost of a 'go' move into a cell, keyed by that cell's value in num_matrix;
# check_cost falls back to 1 for any value not listed here.
# NOTE(review): the meaning of the 'd' / 's' / 'r' cell codes is defined by
# num_map, and 999 presumably marks a (nearly) impassable cell. Confirm.
mapping = dict(
    d=2,
    s=20,
    r=999,
)
class State:
    """
    Pose of the agent on the grid: a cell (row, column) plus a facing direction.

    Directions
    UP: 0
    RIGHT: 1
    DOWN: 2
    LEFT: 3
    """
    def __init__(self, row, column, direction):
        self.direction = direction
        self.row = row
        self.column = column

    def rotate_left(self):
        """Return the direction index after a 90-degree left turn (self is not modified)."""
        return (self.direction - 1) % 4

    def rotate_right(self):
        """Return the direction index after a 90-degree right turn (self is not modified)."""
        return (self.direction + 1) % 4

    def __eq__(self, state: "State"):
        # Returning NotImplemented (instead of raising AttributeError) lets
        # comparisons with None or other types fall back to Python's default.
        if not isinstance(state, State):
            return NotImplemented
        return (state.row, state.column, state.direction) == (self.row, self.column, self.direction)

    def __lt__(self, state: "State"):
        # Ordering by position only; used to break ties between heap entries.
        if not isinstance(state, State):
            return NotImplemented
        return (self.row, self.column) < (state.row, state.column)

    def __hash__(self):
        # Hash the same fields __eq__ compares, so states that differ only by
        # direction do not all collide in sets/dicts.
        return hash((self.row, self.column, self.direction))

    def __repr__(self):
        return f"{type(self).__name__}(row={self.row!r}, column={self.column!r}, direction={self.direction!r})"
class Node:
    """
    Search-tree node used by A*.

    Wraps a State together with the parent node it was reached from, the
    action that produced it (None for the root) and the accumulated path
    cost g(n) (the root defaults to 1).
    """
    def __init__(self, state: "State", parent=None, action=None, cost=1):
        self.state, self.parent, self.action, self.cost = state, parent, action, cost

    def __lt__(self, node):
        # Delegate ordering to the wrapped states; heapq uses this when f(n) ties.
        mine, theirs = self.state, node.state
        return mine < theirs
def h(state: State, goal: tuple[int, int]):
    """
    Heuristic: Manhattan distance from the state's cell to the goal cell.

    The facing direction of the state is not taken into account.
    """
    here_row, here_column = state.row, state.column
    goal_row, goal_column = goal
    row_gap = abs(here_row - goal_row)
    column_gap = abs(here_column - goal_column)
    return row_gap + column_gap
def f(curr_node: Node, goal: tuple[int, int]):
    """
    A* evaluation function f(n) = g(n) + h(n).

    g(n) is the node's accumulated cost; h(n) is the Manhattan-distance
    heuristic from the node's state to ``goal``.
    """
    path_cost = curr_node.cost
    estimate = h(state=curr_node.state, goal=goal)
    return path_cost + estimate
def goal_test(goal_list, state: State):
    """
    Return True when ``state`` stands on the goal cell, ignoring its direction.

    Despite its name, ``goal_list`` is compared against a ``(row, column)``
    tuple, so it should be passed as a tuple (a list never compares equal).
    """
    position = (state.row, state.column)
    return bool(position == goal_list)
def is_valid_move(target_row, target_column, *, max_rows=None, max_cols=None):
    """
    Check whether the cell (target_row, target_column) lies inside the grid.

    Valid rows are ``0 <= row < max_rows`` and valid columns ``0 <= column < max_cols``.

    :param target_row: row index of the destination cell.
    :param target_column: column index of the destination cell.
    :param max_rows: grid height; defaults to the module constant MAX_ROWS.
    :param max_cols: grid width; defaults to the module constant MAX_COLS.
    :return: True if the cell is inside the grid, otherwise False.
    """
    if max_rows is None:
        max_rows = MAX_ROWS
    if max_cols is None:
        max_cols = MAX_COLS
    # Column 0 is a real cell just like row 0; the old strict `0 < column`
    # made the leftmost column unreachable.
    return 0 <= target_row < max_rows and 0 <= target_column < max_cols
def get_successor(state: "State"):
    """
    List the (action, next_state) pairs reachable from ``state`` in one step.

    Both rotations are always offered (in place, new direction). A 'go' step
    one cell forward in the current direction is added only when the target
    cell passes is_valid_move.
    """
    # Row/column offset of a single forward step for each facing direction.
    forward_offsets = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}
    options = [
        ('rotate_left', State(row=state.row, column=state.column, direction=state.rotate_left())),
        ('rotate_right', State(row=state.row, column=state.column, direction=state.rotate_right())),
    ]
    offset = forward_offsets.get(state.direction)
    if offset is None:
        return options
    delta_row, delta_column = offset
    next_row = state.row + delta_row
    next_column = state.column + delta_column
    if is_valid_move(target_row=next_row, target_column=next_column):
        options.append(('go', State(row=next_row, column=next_column, direction=state.direction)))
    return options
def get_path_from_start(node: "Node"):
    """
    Rebuild the sequence of actions that leads from the search root to ``node``.

    Follows ``parent`` links back to the root and collects each node's
    ``action``. The root has ``action=None`` and contributes nothing, so when
    the start already is the goal the result is ``[]`` rather than ``[None]``.

    :param node: last node of the path (typically the goal node).
    :return: list of actions ordered from the start to ``node``.
    """
    actions = []
    while node is not None:
        if node.action is not None:
            actions.append(node.action)
        node = node.parent
    actions.reverse()
    return actions
def check_cost(row: int, col: int, action):
    """
    Cost of performing ``action`` that ends in cell (row, col).

    Rotations always cost 1. Any other action costs ``mapping[value]`` for
    the value stored in ``num_matrix[row][col]``, or 1 when that value has
    no entry in ``mapping``.
    """
    turning = action in ('rotate_left', 'rotate_right')
    if turning:
        return 1
    cell_value = num_matrix[row][col]
    return mapping.get(cell_value, 1)
def a_star(state: State, goal: tuple[int, int]):
    """
    Search for a cheapest action sequence from ``state`` to the ``goal`` cell with A*.

    Nodes are ordered by f(n) = g(n) + h(n). The goal is reached as soon as a
    popped state stands on ``goal``; its final direction is irrelevant.

    :param state: start state (row, column, direction).
    :param goal: target ``(row, column)``.
    :return: list of actions ('rotate_left', 'rotate_right', 'go') from start
             to goal, or None when the goal cannot be reached.
    """
    start = Node(state=state, parent=None, action=None)
    fringe = [(f(start, goal), start)]
    # Cheapest g(n) found so far for every state that has been queued. It
    # replaces the previous `(p, child) in fringe` check, which compared fresh
    # Node objects by identity (always "not found") and scanned the heap in O(n).
    best_cost = {start.state: start.cost}
    explored_states = set()
    while fringe:
        _, node = heapq.heappop(fringe)
        if node.state in explored_states:
            # Stale heap entry: this state was already expanded via a cheaper route.
            continue
        if goal_test(goal, node.state):
            return get_path_from_start(node)
        explored_states.add(node.state)
        for action, next_state in get_successor(node.state):
            if next_state in explored_states:
                continue
            movement_cost = check_cost(row=next_state.row, col=next_state.column, action=action) + node.cost
            if movement_cost >= best_cost.get(next_state, float('inf')):
                # Already queued with an equal or cheaper path cost.
                continue
            best_cost[next_state] = movement_cost
            child = Node(state=next_state, parent=node, action=action, cost=movement_cost)
            heapq.heappush(fringe, (f(child, goal=goal), child))
    return None