Projekt_Sztuczna_Inteligencja/algorithms/search/a_star.py

235 lines
7.2 KiB
Python
Raw Normal View History

from __future__ import annotations
from heapq import heappush, heappop, heapify
from typing import List
import itertools
import ctypes
from project_constants import Direction, Action
from minefield import Minefield
# Temporary goal for testing: the (row, column) target used by the heuristic
# and the goal tests. graphsearch() overwrites it (via `global GOAL`) whenever
# both tox and toy are given.
GOAL = (2, 6)
class State:
    """Agent pose explored by the search: grid position plus facing direction.

    ``row`` and ``column`` index ``minefield.matrix[row][column]``;
    ``direction`` is the Direction the agent is facing.
    """

    def __init__(self, row, column, direction: Direction):
        self.row, self.column, self.direction = row, column, direction
class Node:
    """Search-tree node wrapping a State.

    Attributes:
        state: the State this node represents.
        parent: the Node this one was expanded from (None for the root).
        action: the Action that led from ``parent`` to this node (None for the root).
        cost: accumulated path cost from the root (0 unless given).
    """

    def __init__(self, state: State, parent: Node = None, action: Action = None, cost=0):
        self.state, self.parent, self.action, self.cost = state, parent, action, cost
def get_node_cost(node: Node, minefield: Minefield):
    """Return the path cost of ``node``: its parent's cost plus the last step's cost.

    A GO step costs ``cost.value`` of the minefield tile being entered; any
    other action (e.g. a rotation) costs 1. ``node.parent`` must be set, so
    this is not meant to be called on the root node.
    """
    row, column = node.state.row, node.state.column
    if node.action == Action.GO:
        step_cost = minefield.matrix[row][column].cost.value
    else:
        step_cost = 1
    return node.parent.cost + step_cost
def get_estimated_cost(node: Node, goal=None):
    """Return the Manhattan-distance heuristic from ``node``'s position to the goal.

    Args:
        node: node whose ``state.row`` / ``state.column`` are measured.
        goal: optional ``(row, column)`` target. Defaults to the module-level
            ``GOAL``, read at call time so that a goal set by graphsearch()
            is honoured.

    Returns:
        ``|row - goal_row| + |column - goal_column|``.
    """
    if goal is None:
        goal = GOAL
    return abs(node.state.row - goal[0]) + abs(node.state.column - goal[1])
def tile_goal_test(state: State, goal=None):
    """Return True when ``state`` stands on the goal tile (direction is ignored).

    Args:
        state: the pose to test; only ``row`` and ``column`` are used.
        goal: optional ``(row, column)`` target. Defaults to the module-level
            ``GOAL``, read at call time so that a goal set by graphsearch()
            is honoured.
    """
    if goal is None:
        goal = GOAL
    # tuple() so that a goal given as a list compares equal as well
    return (state.row, state.column) == tuple(goal)
def mine_goal_test(state: State):
    """Return True when the agent is next to the GOAL tile and facing it.

    Succeeds only when ``state`` is on one of the four orthogonal neighbours
    of ``GOAL`` and its direction points towards ``GOAL``.
    """
    # (row offset, column offset) of the agent relative to GOAL -> the
    # direction it must face to look at the goal tile
    facing_towards_goal = {
        (0, -1): Direction.RIGHT,
        (0, 1): Direction.LEFT,
        (-1, 0): Direction.DOWN,
        (1, 0): Direction.UP,
    }
    offset = (state.row - GOAL[0], state.column - GOAL[1])
    if offset not in facing_towards_goal:
        return False
    return state.direction == facing_towards_goal[offset]
def get_successors(state: State, minefield: Minefield):
    """List the ``(action, next_state)`` pairs reachable from ``state``.

    Rotating left and right are always available, in that order. Moving
    forward (GO) is appended only when ``minefield.is_valid_move`` accepts the
    tile straight ahead.
    """
    row, column, facing = state.row, state.column, state.direction
    successors = [
        (Action.ROTATE_LEFT, State(row, column, facing.previous())),
        (Action.ROTATE_RIGHT, State(row, column, facing.next())),
    ]
    ahead = go(row, column, facing)
    if minefield.is_valid_move(ahead[0], ahead[1]):
        successors.append((Action.GO, State(ahead[0], ahead[1], facing)))
    return successors
def _state_key(state: State):
    """Return the hashable (row, column, direction) key identifying ``state``."""
    return state.row, state.column, state.direction


def _show_no_solution_popup():
    """Tell the user that no path to the goal exists.

    Uses a native Windows message box when possible. ``ctypes.windll`` only
    exists on Windows, so other platforms fall back to printing the message.
    """
    try:
        message_box = ctypes.windll.user32.MessageBoxW
    except AttributeError:
        print("GAME OVER: Brak rozwiązania")
        return
    message_box(0, "Brak rozwiązania", "GAME OVER", 1)


def _build_actions_sequence(node: Node):
    """Return the actions leading from the root to ``node``, in execution order."""
    actions_sequence = []
    while node is not None:
        # the root's action is None; never put it into the sequence
        if node.action is not None:
            actions_sequence.append(node.action)
        node = node.parent
    actions_sequence.reverse()
    return actions_sequence


def graphsearch(initial_state: State,
                minefield: Minefield,
                fringe: List[Node] = None,
                explored: List[Node] = None,
                target_type: str = "tile",
                tox: int = None,
                toy: int = None,
                with_data=False):
    """A* graph search from ``initial_state`` to the module-level goal.

    Args:
        initial_state: starting pose of the agent.
        minefield: grid that provides tile costs and move validity.
        fringe: optional heap list for the open set (a new list when None).
            NOTE(review): the heap actually holds ``[priority, count, node]``
            entries created by add_node, not bare Nodes.
        explored: optional list that receives every expanded Node.
        target_type: "mine" to stop next to the goal while facing it;
            anything else stops on the goal tile itself.
        tox, toy: when both are given, they replace the global GOAL.
        with_data: when True, a found solution is returned as
            ``(actions, final_state, cost)``.

    Returns:
        The list of actions from the start to the goal. If the start already
        satisfies the goal, the list is empty (with_data: ``([], state, 0)``).
        If no solution exists, or the goal tile holds an active mine, a popup
        or message is shown and ``[]`` is returned, even when with_data is True.
    """
    # reset the global priority-queue helpers so no state leaks between searches
    global entry_finder
    global counter
    entry_finder = {}
    counter = itertools.count()

    global GOAL
    if tox is not None and toy is not None:
        GOAL = (tox, toy)

    goal_test = mine_goal_test if target_type == "mine" else tile_goal_test

    goal_tile = minefield.matrix[GOAL[0]][GOAL[1]]
    if goal_tile.mine is not None and goal_tile.mine.active:
        _show_no_solution_popup()
        return []

    # fringe and explored initialization
    if fringe is None:
        fringe = []
    heapify(fringe)
    if explored is None:
        explored = []
    explored_states = set()
    fringe_states = set()

    # root node
    add_node(fringe, Node(initial_state), 0)
    fringe_states.add(_state_key(initial_state))

    while True:
        # pop_node skips entries invalidated by update_priority and returns
        # None once the heap has no live entries left -> no solution
        element = pop_node(fringe)
        if element is None:
            _show_no_solution_popup()
            return []
        element_key = _state_key(element.state)
        fringe_states.remove(element_key)

        if goal_test(element.state):
            actions_sequence = _build_actions_sequence(element)
            if with_data:
                return actions_sequence, element.state, element.cost
            return actions_sequence

        # mark as explored so the state is never expanded again (prevents cycles)
        explored.append(element)
        explored_states.add(element_key)

        for action, successor_state in get_successors(element.state, minefield):
            new_node = Node(state=successor_state, parent=element, action=action)
            new_node.cost = get_node_cost(new_node, minefield)
            priority = new_node.cost + get_estimated_cost(new_node)
            successor_key = _state_key(successor_state)
            if successor_key not in fringe_states and successor_key not in explored_states:
                add_node(fringe, new_node, priority)
                fringe_states.add(successor_key)
            elif successor_key in fringe_states and entry_finder[successor_key][0] > priority:
                # cheaper path to a state that is still waiting in the fringe
                update_priority(fringe, new_node, priority)
# TEMPORARY METHOD
def go(row, column, direction):
    """Return the (row, column) one step ahead of ``(row, column)`` when facing ``direction``.

    RIGHT/LEFT change the column by +1/-1; UP/DOWN change the row by -1/+1.
    An unrecognised direction yields an empty tuple.
    """
    if direction == Direction.RIGHT:
        return row, column + 1
    if direction == Direction.LEFT:
        return row, column - 1
    if direction == Direction.UP:
        return row - 1, column
    if direction == Direction.DOWN:
        return row + 1, column
    return tuple()
# PRIORITY QUEUE HANDLER
# Lazy-deletion heap helpers used by add_node / pop_node / update_priority.
# graphsearch() rebinds entry_finder and counter at the start of every search.
entry_finder = {}  # mapping of (row, column, direction) state keys to their live heap entries
REMOVED = '<removed-node>'  # placeholder written into an entry's node slot once that entry is superseded
counter = itertools.count()  # unique sequence count; tie-breaker so equal priorities never compare Node objects
2021-05-09 02:06:13 +02:00
def add_node(heap, node: Node, priority):
    """Push ``node`` onto ``heap`` with ``priority`` and register it in entry_finder.

    Each heap entry is a mutable ``[priority, count, node]`` list. The unique
    ``count`` breaks ties, and the list stays mutable so update_priority can
    later mark it REMOVED.
    """
    state = node.state
    entry = [priority, next(counter), node]
    entry_finder[(state.row, state.column, state.direction)] = entry
    heappush(heap, entry)
def pop_node(heap):
    """Pop and return the live node with the lowest priority.

    Entries marked REMOVED by update_priority are discarded along the way.
    The returned node's key is dropped from entry_finder. Returns None when
    the heap runs out of live entries.
    """
    while heap:
        *_, node = heappop(heap)
        if node is REMOVED:
            continue
        state = node.state
        del entry_finder[(state.row, state.column, state.direction)]
        return node
    return None
def update_priority(heap, new_node, new_priority):
    """Replace the queued entry for ``new_node``'s state with a new priority.

    The old entry stays in the heap but its node slot is set to REMOVED, so
    pop_node will skip it. ``new_node`` is then pushed with ``new_priority``.
    """
    state = new_node.state
    stale_entry = entry_finder.pop((state.row, state.column, state.direction))
    stale_entry[2] = REMOVED
    add_node(heap, new_node, new_priority)