diff --git a/main.py b/main.py index 9f46b88..23b002d 100644 --- a/main.py +++ b/main.py @@ -7,7 +7,7 @@ from pyglet.gl import * # for blocky textures # other files of this project import project_constants as const import minefield as mf -import searching_algorithms.bfs as bfs +import searching_algorithms.a_star as a_star from display_assets import blit_graphics @@ -27,8 +27,8 @@ def main(): minefield = mf.Minefield(const.MAP_RANDOM_10x10) # get sequence of actions found by BFS algorithm - action_sequence = bfs.graphsearch( - initial_state=bfs.State( + action_sequence = a_star.graphsearch( + initial_state=a_star.State( row=minefield.agent.position[0], column=minefield.agent.position[1], direction=const.Direction.UP), diff --git a/searching_algorithms/a_star.py b/searching_algorithms/a_star.py new file mode 100644 index 0000000..a1e36f1 --- /dev/null +++ b/searching_algorithms/a_star.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from heapq import heappush, heappop, heapify +from typing import List +import itertools +import ctypes + +from project_constants import Direction, Action +from minefield import Minefield + +# temporary goal for testing +GOAL = (2, 6) + + +class State: + def __init__(self, row, column, direction: Direction): + self.row = row + self.column = column + self.direction = direction + + +class Node: + def __init__(self, state: State, parent: Node = None, action: Action = None, weight=0): + self.state = state + self.parent = parent + self.action = action + + if not weight: + self.weight = self._get_weight() + else: + self.weight = weight + + def _get_weight(self): + weight = 0 + if self.parent is not None: + weight += self.parent.weight + + heuristics = abs(self.state.row - GOAL[0]) + abs(self.state.column - GOAL[1]) + weight += heuristics + + return weight + + +def goal_test(state: State): + if (state.row, state.column) == GOAL: + return True + return False + + +def get_successors(state: State, minefield: Minefield): + successors = list() + + state_left = State(state.row, state.column, state.direction.previous()) + successors.append((Action.ROTATE_LEFT, state_left)) + + state_right = State(state.row, state.column, state.direction.next()) + successors.append((Action.ROTATE_RIGHT, state_right)) + + target = go(state.row, state.column, state.direction) + + if minefield.is_valid_move(target[0], target[1]): + state_go = State(target[0], target[1], state.direction) + successors.append((Action.GO, state_go)) + + return successors + + +def graphsearch(initial_state: State, minefield: Minefield, fringe: List[Node] = None, explored: List[Node] = None): + # reset global priority queue helpers + global entry_finder + global counter + entry_finder = {} + counter = itertools.count() + + # fringe and explored initialization + if fringe is None: + fringe = list() + heapify(fringe) + if explored is None: + explored = list() + + explored_states = set() + fringe_states = set() + + # root Node + add_node(fringe, Node(initial_state)) + fringe_states.add((initial_state.row, initial_state.column, initial_state.direction)) + + while True: + # fringe empty -> solution not found + if not any(fringe): + ctypes.windll.user32.MessageBoxW(0, "Brak rozwiązania", "GAME OVER", 1) + return [] + + # get first element from fringe + element = pop_node(fringe) + if element is None: + ctypes.windll.user32.MessageBoxW(0, "Brak rozwiązania", "GAME OVER", 1) + return [] + + fringe_states.remove((element.state.row, element.state.column, element.state.direction)) + + # if solution was found, prepare and return actions sequence + if goal_test(element.state): + actions_sequence = [element.action] + parent = element.parent + + while parent is not None: + # root's action will be None, don't add it + if parent.action is not None: + actions_sequence.append(parent.action) + parent = parent.parent + + actions_sequence.reverse() + return actions_sequence + + # add current node to explored (prevents infinite cycles) + explored.append(element) + explored_states.add((element.state.row, element.state.column, element.state.direction)) + + # loop through every possible next action + for successor in get_successors(element.state, minefield): + + new_node = Node(state=successor[1], + parent=element, + action=successor[0]) + successor_state = (successor[1].row, successor[1].column, successor[1].direction) + + if successor_state not in fringe_states and \ + successor_state not in explored_states: + + add_node(fringe, new_node) + fringe_states.add((new_node.state.row, new_node.state.column, new_node.state.direction)) + + # update weight if it's lower + elif successor_state in fringe and entry_finder[successor_state][0] > new_node.weight: + update_priority(fringe, new_node) + + else: + del new_node + + +# TEMPORARY METHOD +def go(row, column, direction): + target = tuple() + + if direction == Direction.RIGHT: + target = row, column + 1 + elif direction == Direction.LEFT: + target = row, column - 1 + elif direction == Direction.UP: + target = row - 1, column + elif direction == Direction.DOWN: + target = row + 1, column + + return target + + +# PRIORITY QUEUE HANDLER +entry_finder = {} # mapping of states to entries in a heap +REMOVED = '' # placeholder for a removed nodes +counter = itertools.count() # unique sequence count + + +def add_node(heap, node: Node): + count = next(counter) + entry = [node.weight, count, node] + entry_finder[node.state] = entry + heappush(heap, entry) + + +def pop_node(heap): + while heap: + priority, count, node = heappop(heap) + if node is not REMOVED: + del entry_finder[node.state] + return node + return None + + +def update_priority(heap, new_node): + old_entry = entry_finder.pop(new_node.state) + old_entry[-1] = REMOVED + add_node(heap, new_node)