Fixed cost and priority calculation

This commit is contained in:
s452645 2021-05-09 02:06:13 +02:00
parent 7d16dcc616
commit 36258b11d0
3 changed files with 44 additions and 25 deletions

View File

@ -6,7 +6,6 @@ import random
import project_constants as const import project_constants as const
import minefield as mf import minefield as mf
import searching_algorithms.a_star as a_star import searching_algorithms.a_star as a_star
import searching_algorithms.bfs as bfs
from display_assets import blit_graphics from display_assets import blit_graphics
from ui.input_box import * from ui.input_box import *
from ui.button import * from ui.button import *
@ -94,8 +93,8 @@ def main():
to_x = int(input1.get_input()) to_x = int(input1.get_input())
to_y = int(input2.get_input()) to_y = int(input2.get_input())
action_sequence = bfs.graphsearch( action_sequence = a_star.graphsearch(
initial_state=bfs.State( initial_state=a_star.State(
row=minefield.agent.position[0], row=minefield.agent.position[0],
column=minefield.agent.position[1], column=minefield.agent.position[1],
direction=minefield.agent.direction), direction=minefield.agent.direction),

View File

@ -17,7 +17,7 @@ V_NAME_OF_WINDOW = "MineFusion TM"
DIR_ASSETS = os.path.join("resources", "assets") DIR_ASSETS = os.path.join("resources", "assets")
V_FPS = 60 V_FPS = 60
ACTION_INTERVAL = 1 # interval between two actions in seconds ACTION_INTERVAL = 0.5 # interval between two actions in seconds
V_TILE_SIZE = 60 V_TILE_SIZE = 60
V_GRID_VER_TILES = 10 # vertical (number of rows) V_GRID_VER_TILES = 10 # vertical (number of rows)

View File

@ -20,25 +20,30 @@ class State:
class Node: class Node:
def __init__(self, state: State, parent: Node = None, action: Action = None, weight=0): def __init__(self, state: State, parent: Node = None, action: Action = None, cost=0):
self.state = state self.state = state
self.parent = parent self.parent = parent
self.action = action self.action = action
self.cost = cost
if not weight:
self.weight = self._get_weight() def get_node_cost(node: Node, minefield: Minefield):
row = node.state.row
column = node.state.column
# Rotation Cost??
if node.action != Action.GO:
return node.parent.cost + 1
# if Tile considered its mine in cost calculation, this code would be priettier
if minefield.matrix[row][column].mine is not None:
return node.parent.cost + 10
else: else:
self.weight = weight return node.parent.cost + minefield.matrix[row][column].cost.value
def _get_weight(self):
weight = 0
if self.parent is not None:
weight += self.parent.weight
heuristics = abs(self.state.row - GOAL[0]) + abs(self.state.column - GOAL[1]) def get_estimated_cost(node: Node):
weight += heuristics return abs(node.state.row - GOAL[0]) + abs(node.state.column - GOAL[1])
return weight
def goal_test(state: State): def goal_test(state: State):
@ -65,13 +70,23 @@ def get_successors(state: State, minefield: Minefield):
return successors return successors
def graphsearch(initial_state: State, minefield: Minefield, fringe: List[Node] = None, explored: List[Node] = None): def graphsearch(initial_state: State,
minefield: Minefield,
fringe: List[Node] = None,
explored: List[Node] = None,
tox: int = None,
toy: int = None):
# reset global priority queue helpers # reset global priority queue helpers
global entry_finder global entry_finder
global counter global counter
entry_finder = {} entry_finder = {}
counter = itertools.count() counter = itertools.count()
global GOAL
if tox is not None and toy is not None:
GOAL = (tox, toy)
# fringe and explored initialization # fringe and explored initialization
if fringe is None: if fringe is None:
fringe = list() fringe = list()
@ -83,18 +98,20 @@ def graphsearch(initial_state: State, minefield: Minefield, fringe: List[Node] =
fringe_states = set() fringe_states = set()
# root Node # root Node
add_node(fringe, Node(initial_state)) add_node(fringe, Node(initial_state), 0)
fringe_states.add((initial_state.row, initial_state.column, initial_state.direction)) fringe_states.add((initial_state.row, initial_state.column, initial_state.direction))
while True: while True:
# fringe empty -> solution not found # fringe empty -> solution not found
if not any(fringe): if not any(fringe):
# TODO: cross-platform popup, move to separate function
ctypes.windll.user32.MessageBoxW(0, "Brak rozwiązania", "GAME OVER", 1) ctypes.windll.user32.MessageBoxW(0, "Brak rozwiązania", "GAME OVER", 1)
return [] return []
# get first element from fringe # get first element from fringe
element = pop_node(fringe) element = pop_node(fringe)
if element is None: if element is None:
# TODO: cross-platform popup, move to separate function
ctypes.windll.user32.MessageBoxW(0, "Brak rozwiązania", "GAME OVER", 1) ctypes.windll.user32.MessageBoxW(0, "Brak rozwiązania", "GAME OVER", 1)
return [] return []
@ -124,17 +141,20 @@ def graphsearch(initial_state: State, minefield: Minefield, fringe: List[Node] =
new_node = Node(state=successor[1], new_node = Node(state=successor[1],
parent=element, parent=element,
action=successor[0]) action=successor[0])
new_node.cost = get_node_cost(new_node, minefield)
priority = new_node.cost + get_estimated_cost(new_node)
successor_state = (successor[1].row, successor[1].column, successor[1].direction) successor_state = (successor[1].row, successor[1].column, successor[1].direction)
if successor_state not in fringe_states and \ if successor_state not in fringe_states and \
successor_state not in explored_states: successor_state not in explored_states:
add_node(fringe, new_node) add_node(fringe, new_node, priority)
fringe_states.add((new_node.state.row, new_node.state.column, new_node.state.direction)) fringe_states.add((new_node.state.row, new_node.state.column, new_node.state.direction))
# update weight if it's lower # update weight if it's lower
elif successor_state in fringe and entry_finder[successor_state][0] > new_node.weight: elif successor_state in fringe and entry_finder[successor_state][0] > priority:
update_priority(fringe, new_node) update_priority(fringe, new_node, priority)
else: else:
del new_node del new_node
@ -162,9 +182,9 @@ REMOVED = '<removed-node>' # placeholder for a removed nodes
counter = itertools.count() # unique sequence count counter = itertools.count() # unique sequence count
def add_node(heap, node: Node): def add_node(heap, node: Node, priority):
count = next(counter) count = next(counter)
entry = [node.weight, count, node] entry = [priority, count, node]
entry_finder[node.state] = entry entry_finder[node.state] = entry
heappush(heap, entry) heappush(heap, entry)
@ -178,7 +198,7 @@ def pop_node(heap):
return None return None
def update_priority(heap, new_node): def update_priority(heap, new_node, new_priority):
old_entry = entry_finder.pop(new_node.state) old_entry = entry_finder.pop(new_node.state)
old_entry[-1] = REMOVED old_entry[-1] = REMOVED
add_node(heap, new_node) add_node(heap, new_node, new_priority)