[a_star_search_with_heap] added a_star_search and heap data structure, moved bfs algorithm to folder
This commit is contained in:
parent
5691ef09df
commit
929377af27
82
data_structures/heap.py
Normal file
82
data_structures/heap.py
Normal file
@ -0,0 +1,82 @@
|
||||
class HeapElement:
|
||||
def __init__(self, value, index, compare_value):
|
||||
self.value = value
|
||||
self.index = index
|
||||
self.compare_value = compare_value
|
||||
|
||||
class Heap:
|
||||
def __init__(self):
|
||||
self.array: list[HeapElement] = []
|
||||
|
||||
def length(self) -> int:
|
||||
return len(self.array)
|
||||
|
||||
def contains(self, value) -> bool:
|
||||
for item in self.array:
|
||||
if(item.value == value):
|
||||
return True
|
||||
return False
|
||||
|
||||
def append(self, item, compare_value):
|
||||
new_element = HeapElement(item, len(self.array), compare_value)
|
||||
self.array.append(new_element)
|
||||
self.sort_up(new_element)
|
||||
|
||||
def take_first(self):
|
||||
first_item = self.array[0]
|
||||
new_first_item = self.array.pop()
|
||||
if(len(self.array) > 0):
|
||||
new_first_item.index = 0
|
||||
self.array[0] = new_first_item
|
||||
self.sort_down(new_first_item)
|
||||
return first_item.value
|
||||
|
||||
def sort_up(self, item: HeapElement):
|
||||
parent_index = (item.index - 1)//2
|
||||
if(parent_index < 0):
|
||||
return
|
||||
|
||||
parent_item = self.array[parent_index]
|
||||
|
||||
if(item.compare_value < parent_item.compare_value):
|
||||
self.swap_items(item, parent_item)
|
||||
item.index = parent_index
|
||||
self.sort_up(item)
|
||||
|
||||
def sort_down(self, item: HeapElement):
|
||||
child_left_index = item.index * 2 + 1
|
||||
child_right_index = item.index * 2 + 2
|
||||
|
||||
if (child_left_index < len(self.array)):
|
||||
swap_index = child_left_index
|
||||
|
||||
if(child_right_index < len(self.array)):
|
||||
child_left_item = self.array[child_left_index]
|
||||
child_right_item = self.array[child_right_index]
|
||||
if(child_left_item.compare_value > child_right_item.compare_value):
|
||||
swap_index = child_right_index
|
||||
|
||||
self.swap_items(item, self.array[swap_index])
|
||||
item.index = swap_index
|
||||
self.sort_down(item)
|
||||
|
||||
def swap_items(self, item_a: HeapElement, item_b: HeapElement):
|
||||
item_a.index, item_b.index = item_b.index, item_a.index
|
||||
self.array[item_a.index] = item_a
|
||||
self.array[item_b.index] = item_b
|
||||
|
||||
# some test code
|
||||
|
||||
# heap = Heap()
|
||||
# heap.append(5, 5)
|
||||
# heap.append(2, 2)
|
||||
# heap.append(3, 3)
|
||||
# heap.append(4, 4)
|
||||
# heap.append(6, 6)
|
||||
# heap.append(1, 1)
|
||||
# print(heap.take_first())
|
||||
# print("heap:")
|
||||
# for item in heap.array:
|
||||
# print(item.value)
|
||||
|
||||
|
29
main.py
29
main.py
@ -3,15 +3,14 @@ from game_objects.player import Player
|
||||
import pygame as pg
|
||||
import sys
|
||||
from os import path
|
||||
import math
|
||||
|
||||
from map import *
|
||||
# from agent import trashmaster
|
||||
# from house import House
|
||||
from settings import *
|
||||
from map import map
|
||||
from map import map_utils
|
||||
from SearchBfs import *
|
||||
import math
|
||||
from path_search_algorthms import bfs
|
||||
from path_search_algorthms import a_star
|
||||
|
||||
|
||||
class Game():
|
||||
@ -23,7 +22,8 @@ class Game():
|
||||
pg.display.set_caption("Trashmaster")
|
||||
self.load_data()
|
||||
self.init_game()
|
||||
self.init_bfs()
|
||||
# self.init_bfs()
|
||||
self.init_a_star()
|
||||
|
||||
def init_game(self):
|
||||
# initialize all variables and do all the setup for a new game
|
||||
@ -39,12 +39,12 @@ class Game():
|
||||
self.camera = map_utils.Camera(MAP_WIDTH_PX, MAP_HEIGHT_PX)
|
||||
|
||||
# other
|
||||
self.draw_debug = False
|
||||
self.debug_mode = False
|
||||
|
||||
def init_bfs(self):
|
||||
start_node = (0, 0)
|
||||
target_node = (18, 18)
|
||||
find_path = BreadthSearchAlgorithm(start_node, target_node, self.mapArray)
|
||||
find_path = bfs.BreadthSearchAlgorithm(start_node, target_node, self.mapArray)
|
||||
path = find_path.bfs()
|
||||
# print(path)
|
||||
realPath = []
|
||||
@ -56,6 +56,15 @@ class Game():
|
||||
nextNode = node[1]
|
||||
print(realPath)
|
||||
|
||||
def init_a_star(self):
|
||||
# szukanie sciezki na sztywno i wyprintowanie wyniku (tablica stringow)
|
||||
start_x = 0
|
||||
start_y = 0
|
||||
target_x = 6
|
||||
target_y = 2
|
||||
path = a_star.search_path(start_x, start_y, target_x, target_y, self.mapArray)
|
||||
print(path)
|
||||
|
||||
def load_data(self):
|
||||
game_folder = path.dirname(__file__)
|
||||
img_folder = path.join(game_folder, 'resources/textures')
|
||||
@ -88,12 +97,12 @@ class Game():
|
||||
|
||||
#rerender map
|
||||
map.render_tiles(self.roadTiles, self.screen, self.camera)
|
||||
map.render_tiles(self.wallTiles, self.screen, self.camera, self.draw_debug)
|
||||
map.render_tiles(self.wallTiles, self.screen, self.camera, self.debug_mode)
|
||||
|
||||
#rerender additional sprites
|
||||
for sprite in self.agentSprites:
|
||||
self.screen.blit(sprite.image, self.camera.apply(sprite))
|
||||
if self.draw_debug:
|
||||
if self.debug_mode:
|
||||
pg.draw.rect(self.screen, CYAN, self.camera.apply_rect(sprite.hit_rect), 1)
|
||||
|
||||
#finally update screen
|
||||
@ -107,7 +116,7 @@ class Game():
|
||||
if event.key == pg.K_ESCAPE:
|
||||
self.quit()
|
||||
if event.key == pg.K_h:
|
||||
self.draw_debug = not self.draw_debug
|
||||
self.debug_mode = not self.debug_mode
|
||||
if event.type == pg.MOUSEBUTTONUP:
|
||||
pos = pg.mouse.get_pos()
|
||||
clicked_coords = [math.floor(pos[0] / TILESIZE), math.floor(pos[1] / TILESIZE)]
|
||||
|
13
map/map.py
13
map/map.py
@ -4,7 +4,18 @@ import pygame as pg
|
||||
from settings import *
|
||||
|
||||
def get_tiles():
|
||||
array = map_utils.generate_map()
|
||||
# array = map_utils.generate_map()
|
||||
array = map_utils.get_blank_map_array()
|
||||
|
||||
array[1][1] = 1
|
||||
array[1][2] = 1
|
||||
array[1][3] = 1
|
||||
array[1][4] = 1
|
||||
array[1][5] = 1
|
||||
array[1][6] = 1
|
||||
|
||||
array[2][5] = 1
|
||||
|
||||
pattern = map_pattern.get_pattern()
|
||||
tiles = map_utils.get_sprites(array, pattern)
|
||||
return tiles, array
|
||||
|
112
path_search_algorthms/a_star.py
Normal file
112
path_search_algorthms/a_star.py
Normal file
@ -0,0 +1,112 @@
|
||||
from data_structures.heap import Heap
|
||||
from path_search_algorthms import a_star_utils as utils
|
||||
|
||||
def search_path(start_x: int, start_y: int, target_x: int, target_y: int, array: list[list[int]]) -> list[str]:
|
||||
|
||||
start_node = utils.Node(start_x, start_y, utils.Rotation.RIGHT)
|
||||
target_node = utils.Node(target_x, target_y, utils.Rotation.NONE)
|
||||
|
||||
# heap version
|
||||
|
||||
# nodes for check
|
||||
search_list = Heap()
|
||||
search_list.append(start_node, 0)
|
||||
|
||||
# checked nodes
|
||||
searched_list: list[(int, int)] = []
|
||||
|
||||
while (search_list.length() > 0):
|
||||
node: utils.Node = search_list.take_first()
|
||||
|
||||
searched_list.append((node.x, node.y))
|
||||
|
||||
# check for target node
|
||||
if ((node.x, node.y) == (target_x, target_y)):
|
||||
return trace_path(node)
|
||||
|
||||
# neightbours processing
|
||||
neighbours = utils.get_neighbours(node, searched_list, array)
|
||||
for neighbour in neighbours:
|
||||
|
||||
# calculate new g cost for neightbour (start -> node -> neightbour)
|
||||
new_neighbour_cost = node.g_cost + utils.get_neighbour_cost(node, neighbour)
|
||||
|
||||
if (new_neighbour_cost < neighbour.g_cost or not search_list.contains(neighbour)):
|
||||
|
||||
# replace cost and set parent node
|
||||
neighbour.g_cost = new_neighbour_cost
|
||||
neighbour.h_cost = utils.get_h_cost(neighbour, target_node)
|
||||
neighbour.parent = node
|
||||
|
||||
# add to search
|
||||
if(not search_list.contains(neighbour)):
|
||||
search_list.append(neighbour, neighbour.f_cost())
|
||||
|
||||
# array version
|
||||
|
||||
# nodes for check
|
||||
# search_list = [start_node]
|
||||
|
||||
# checked nodes
|
||||
# searched_list: list[(int, int)] = []
|
||||
|
||||
# while (len(search_list) > 0):
|
||||
# node = search_list[0]
|
||||
|
||||
# # find cheapest node in search_list
|
||||
# for i in range(1, len(search_list)):
|
||||
# if (search_list[i].f_cost() <= node.f_cost()):
|
||||
# if(search_list[i].h_cost < node.h_cost):
|
||||
# node = search_list[i]
|
||||
|
||||
# search_list.remove(node)
|
||||
# searched_list.append((node.x, node.y))
|
||||
|
||||
# # check for target node
|
||||
# if ((node.x, node.y) == (target_x, target_y)):
|
||||
# return trace_path(node)
|
||||
|
||||
# # neightbours processing
|
||||
# neighbours = utils.get_neighbours(node, searched_list, array)
|
||||
# for neighbour in neighbours:
|
||||
|
||||
# # calculate new g cost for neightbour (start -> node -> neightbour)
|
||||
# new_neighbour_cost = node.g_cost + utils.get_neighbour_cost(node, neighbour)
|
||||
|
||||
# if (new_neighbour_cost < neighbour.g_cost or neighbour not in search_list):
|
||||
|
||||
# # replace cost and set parent node
|
||||
# neighbour.g_cost = new_neighbour_cost
|
||||
# neighbour.h_cost = utils.get_h_cost(neighbour, target_node)
|
||||
# neighbour.parent = node
|
||||
|
||||
# # add to search
|
||||
# if(neighbour not in search_list):
|
||||
# search_list.append(neighbour)
|
||||
|
||||
def trace_path(end_node: utils.Node) -> list[str]:
|
||||
path = []
|
||||
node = end_node
|
||||
|
||||
# set final rotation of end_node because we don't do it before
|
||||
node.rotation = utils.get_needed_rotation(node.parent, node)
|
||||
|
||||
while (node.parent != 0):
|
||||
move = utils.get_move(node.parent, node)
|
||||
path += move
|
||||
node = node.parent
|
||||
|
||||
# delete move on initial tile
|
||||
path.pop()
|
||||
|
||||
# we found path from end, so we need to reverse it (get_move reverse move words)
|
||||
path.reverse()
|
||||
|
||||
# last forward to destination
|
||||
path.append("forward")
|
||||
|
||||
return path
|
||||
|
||||
|
||||
|
||||
|
102
path_search_algorthms/a_star_utils.py
Normal file
102
path_search_algorthms/a_star_utils.py
Normal file
@ -0,0 +1,102 @@
|
||||
from enum import Enum
|
||||
|
||||
from settings import *
|
||||
|
||||
ROAD_TILE = 0
|
||||
|
||||
class Rotation(Enum):
|
||||
UP = 0
|
||||
RIGHT = 1
|
||||
DOWN = 2
|
||||
LEFT = 3
|
||||
NONE = 100
|
||||
|
||||
def __int__(self):
|
||||
return self.value
|
||||
|
||||
class Node:
|
||||
def __init__(self, x: int, y: int, rotation: Rotation):
|
||||
self.x = x
|
||||
self.y = y
|
||||
self.g_cost = 0
|
||||
self.h_cost = 0
|
||||
self.parent = 0
|
||||
self.rotation = rotation
|
||||
|
||||
def f_cost(self):
|
||||
return self.g_cost + self.h_cost
|
||||
|
||||
def get_neighbours(node: Node, searched_list: list[Node], array: list[list[int]]) -> list[Node]:
|
||||
neighbours = []
|
||||
for offset_x in range (-1, 2):
|
||||
for offset_y in range (-1, 2):
|
||||
# don't look for cross neighbours
|
||||
if(abs(offset_x) + abs(offset_y) == 1):
|
||||
x = node.x + offset_x
|
||||
y = node.y + offset_y
|
||||
# prevent out of map coords
|
||||
if (x >= 0 and x <= MAP_WIDTH and y >= 0 and y <= MAP_HEIGHT):
|
||||
if(array[y][x] == ROAD_TILE and (x, y) not in searched_list):
|
||||
neighbour = Node(x, y, Rotation.NONE)
|
||||
neighbour.rotation = get_needed_rotation(node, neighbour)
|
||||
neighbours.append(neighbour)
|
||||
return neighbours
|
||||
|
||||
# move cost schema:
|
||||
# - move from tile to tile: 10
|
||||
# - add extra 10 (1 rotation) if it exists
|
||||
def get_h_cost(start_node: Node, target_node: Node) -> int:
|
||||
distance_x = abs(start_node.x - target_node.x)
|
||||
distance_y = abs(start_node.y - target_node.y)
|
||||
cost = (distance_x + distance_y) * 10
|
||||
|
||||
if(distance_x > 0 and distance_y > 0):
|
||||
cost += 10
|
||||
|
||||
return cost
|
||||
|
||||
# move cost schema:
|
||||
# - move from tile to tile: 10
|
||||
# - every rotation 90*: 10
|
||||
def get_neighbour_cost(start_node: Node, target_node: Node) -> int:
|
||||
new_rotation = get_needed_rotation(start_node, target_node)
|
||||
rotate_change = abs(get_rotate_change(start_node.rotation, new_rotation))
|
||||
if (rotate_change == 0):
|
||||
return 10
|
||||
elif (rotate_change == 1 or rotate_change == 3):
|
||||
return 20
|
||||
else:
|
||||
return 30
|
||||
|
||||
# translate rotation change to move
|
||||
def get_move(start_node: Node, target_node: Node) -> list[str]:
|
||||
rotate_change = get_rotate_change(start_node.rotation, target_node.rotation)
|
||||
if (rotate_change == 0):
|
||||
return ["forward"]
|
||||
if (abs(rotate_change) == 2):
|
||||
return ["right", "right", "forward"]
|
||||
if (rotate_change < 0 or rotate_change == 3):
|
||||
return ["right", "forward"]
|
||||
else:
|
||||
return ["left", "forward"]
|
||||
|
||||
# simple calc func
|
||||
def get_rotate_change(rotationA: Rotation, rotationB: Rotation) -> int:
|
||||
return int(rotationA) - int(rotationB)
|
||||
|
||||
# get new rotation for target_node as neighbour of start_node
|
||||
def get_needed_rotation(start_node: Node, target_node: Node) -> Rotation:
|
||||
if (start_node.x - target_node.x > 0):
|
||||
return Rotation.LEFT
|
||||
if (start_node.x - target_node.x < 0):
|
||||
return Rotation.RIGHT
|
||||
if (start_node.y - target_node.y > 0):
|
||||
return Rotation.UP
|
||||
if (start_node.y - target_node.y < 0):
|
||||
return Rotation.DOWN
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user