[a_star_search_with_heap] added a_star_search and heap data structure, moved bfs algorithm to folder

This commit is contained in:
czorekk 2022-04-24 19:05:48 +02:00
parent 5691ef09df
commit 929377af27
6 changed files with 327 additions and 11 deletions

82
data_structures/heap.py Normal file
View File

@ -0,0 +1,82 @@
class HeapElement:
def __init__(self, value, index, compare_value):
self.value = value
self.index = index
self.compare_value = compare_value
class Heap:
def __init__(self):
self.array: list[HeapElement] = []
def length(self) -> int:
return len(self.array)
def contains(self, value) -> bool:
for item in self.array:
if(item.value == value):
return True
return False
def append(self, item, compare_value):
new_element = HeapElement(item, len(self.array), compare_value)
self.array.append(new_element)
self.sort_up(new_element)
def take_first(self):
first_item = self.array[0]
new_first_item = self.array.pop()
if(len(self.array) > 0):
new_first_item.index = 0
self.array[0] = new_first_item
self.sort_down(new_first_item)
return first_item.value
def sort_up(self, item: HeapElement):
parent_index = (item.index - 1)//2
if(parent_index < 0):
return
parent_item = self.array[parent_index]
if(item.compare_value < parent_item.compare_value):
self.swap_items(item, parent_item)
item.index = parent_index
self.sort_up(item)
def sort_down(self, item: HeapElement):
child_left_index = item.index * 2 + 1
child_right_index = item.index * 2 + 2
if (child_left_index < len(self.array)):
swap_index = child_left_index
if(child_right_index < len(self.array)):
child_left_item = self.array[child_left_index]
child_right_item = self.array[child_right_index]
if(child_left_item.compare_value > child_right_item.compare_value):
swap_index = child_right_index
self.swap_items(item, self.array[swap_index])
item.index = swap_index
self.sort_down(item)
def swap_items(self, item_a: HeapElement, item_b: HeapElement):
item_a.index, item_b.index = item_b.index, item_a.index
self.array[item_a.index] = item_a
self.array[item_b.index] = item_b
# some test code
# heap = Heap()
# heap.append(5, 5)
# heap.append(2, 2)
# heap.append(3, 3)
# heap.append(4, 4)
# heap.append(6, 6)
# heap.append(1, 1)
# print(heap.take_first())
# print("heap:")
# for item in heap.array:
# print(item.value)

29
main.py
View File

@ -3,15 +3,14 @@ from game_objects.player import Player
import pygame as pg import pygame as pg
import sys import sys
from os import path from os import path
import math
from map import * from map import *
# from agent import trashmaster
# from house import House
from settings import * from settings import *
from map import map from map import map
from map import map_utils from map import map_utils
from SearchBfs import * from path_search_algorthms import bfs
import math from path_search_algorthms import a_star
class Game(): class Game():
@ -23,7 +22,8 @@ class Game():
pg.display.set_caption("Trashmaster") pg.display.set_caption("Trashmaster")
self.load_data() self.load_data()
self.init_game() self.init_game()
self.init_bfs() # self.init_bfs()
self.init_a_star()
def init_game(self): def init_game(self):
# initialize all variables and do all the setup for a new game # initialize all variables and do all the setup for a new game
@ -39,12 +39,12 @@ class Game():
self.camera = map_utils.Camera(MAP_WIDTH_PX, MAP_HEIGHT_PX) self.camera = map_utils.Camera(MAP_WIDTH_PX, MAP_HEIGHT_PX)
# other # other
self.draw_debug = False self.debug_mode = False
def init_bfs(self): def init_bfs(self):
start_node = (0, 0) start_node = (0, 0)
target_node = (18, 18) target_node = (18, 18)
find_path = BreadthSearchAlgorithm(start_node, target_node, self.mapArray) find_path = bfs.BreadthSearchAlgorithm(start_node, target_node, self.mapArray)
path = find_path.bfs() path = find_path.bfs()
# print(path) # print(path)
realPath = [] realPath = []
@ -56,6 +56,15 @@ class Game():
nextNode = node[1] nextNode = node[1]
print(realPath) print(realPath)
def init_a_star(self):
# szukanie sciezki na sztywno i wyprintowanie wyniku (tablica stringow)
start_x = 0
start_y = 0
target_x = 6
target_y = 2
path = a_star.search_path(start_x, start_y, target_x, target_y, self.mapArray)
print(path)
def load_data(self): def load_data(self):
game_folder = path.dirname(__file__) game_folder = path.dirname(__file__)
img_folder = path.join(game_folder, 'resources/textures') img_folder = path.join(game_folder, 'resources/textures')
@ -88,12 +97,12 @@ class Game():
#rerender map #rerender map
map.render_tiles(self.roadTiles, self.screen, self.camera) map.render_tiles(self.roadTiles, self.screen, self.camera)
map.render_tiles(self.wallTiles, self.screen, self.camera, self.draw_debug) map.render_tiles(self.wallTiles, self.screen, self.camera, self.debug_mode)
#rerender additional sprites #rerender additional sprites
for sprite in self.agentSprites: for sprite in self.agentSprites:
self.screen.blit(sprite.image, self.camera.apply(sprite)) self.screen.blit(sprite.image, self.camera.apply(sprite))
if self.draw_debug: if self.debug_mode:
pg.draw.rect(self.screen, CYAN, self.camera.apply_rect(sprite.hit_rect), 1) pg.draw.rect(self.screen, CYAN, self.camera.apply_rect(sprite.hit_rect), 1)
#finally update screen #finally update screen
@ -107,7 +116,7 @@ class Game():
if event.key == pg.K_ESCAPE: if event.key == pg.K_ESCAPE:
self.quit() self.quit()
if event.key == pg.K_h: if event.key == pg.K_h:
self.draw_debug = not self.draw_debug self.debug_mode = not self.debug_mode
if event.type == pg.MOUSEBUTTONUP: if event.type == pg.MOUSEBUTTONUP:
pos = pg.mouse.get_pos() pos = pg.mouse.get_pos()
clicked_coords = [math.floor(pos[0] / TILESIZE), math.floor(pos[1] / TILESIZE)] clicked_coords = [math.floor(pos[0] / TILESIZE), math.floor(pos[1] / TILESIZE)]

View File

@ -4,7 +4,18 @@ import pygame as pg
from settings import * from settings import *
def get_tiles(): def get_tiles():
array = map_utils.generate_map() # array = map_utils.generate_map()
array = map_utils.get_blank_map_array()
array[1][1] = 1
array[1][2] = 1
array[1][3] = 1
array[1][4] = 1
array[1][5] = 1
array[1][6] = 1
array[2][5] = 1
pattern = map_pattern.get_pattern() pattern = map_pattern.get_pattern()
tiles = map_utils.get_sprites(array, pattern) tiles = map_utils.get_sprites(array, pattern)
return tiles, array return tiles, array

View File

@ -0,0 +1,112 @@
from data_structures.heap import Heap
from path_search_algorthms import a_star_utils as utils
def search_path(start_x: int, start_y: int, target_x: int, target_y: int, array: list[list[int]]) -> list[str]:
start_node = utils.Node(start_x, start_y, utils.Rotation.RIGHT)
target_node = utils.Node(target_x, target_y, utils.Rotation.NONE)
# heap version
# nodes for check
search_list = Heap()
search_list.append(start_node, 0)
# checked nodes
searched_list: list[(int, int)] = []
while (search_list.length() > 0):
node: utils.Node = search_list.take_first()
searched_list.append((node.x, node.y))
# check for target node
if ((node.x, node.y) == (target_x, target_y)):
return trace_path(node)
# neightbours processing
neighbours = utils.get_neighbours(node, searched_list, array)
for neighbour in neighbours:
# calculate new g cost for neightbour (start -> node -> neightbour)
new_neighbour_cost = node.g_cost + utils.get_neighbour_cost(node, neighbour)
if (new_neighbour_cost < neighbour.g_cost or not search_list.contains(neighbour)):
# replace cost and set parent node
neighbour.g_cost = new_neighbour_cost
neighbour.h_cost = utils.get_h_cost(neighbour, target_node)
neighbour.parent = node
# add to search
if(not search_list.contains(neighbour)):
search_list.append(neighbour, neighbour.f_cost())
# array version
# nodes for check
# search_list = [start_node]
# checked nodes
# searched_list: list[(int, int)] = []
# while (len(search_list) > 0):
# node = search_list[0]
# # find cheapest node in search_list
# for i in range(1, len(search_list)):
# if (search_list[i].f_cost() <= node.f_cost()):
# if(search_list[i].h_cost < node.h_cost):
# node = search_list[i]
# search_list.remove(node)
# searched_list.append((node.x, node.y))
# # check for target node
# if ((node.x, node.y) == (target_x, target_y)):
# return trace_path(node)
# # neightbours processing
# neighbours = utils.get_neighbours(node, searched_list, array)
# for neighbour in neighbours:
# # calculate new g cost for neightbour (start -> node -> neightbour)
# new_neighbour_cost = node.g_cost + utils.get_neighbour_cost(node, neighbour)
# if (new_neighbour_cost < neighbour.g_cost or neighbour not in search_list):
# # replace cost and set parent node
# neighbour.g_cost = new_neighbour_cost
# neighbour.h_cost = utils.get_h_cost(neighbour, target_node)
# neighbour.parent = node
# # add to search
# if(neighbour not in search_list):
# search_list.append(neighbour)
def trace_path(end_node: utils.Node) -> list[str]:
path = []
node = end_node
# set final rotation of end_node because we don't do it before
node.rotation = utils.get_needed_rotation(node.parent, node)
while (node.parent != 0):
move = utils.get_move(node.parent, node)
path += move
node = node.parent
# delete move on initial tile
path.pop()
# we found path from end, so we need to reverse it (get_move reverse move words)
path.reverse()
# last forward to destination
path.append("forward")
return path

View File

@ -0,0 +1,102 @@
from enum import Enum
from settings import *
ROAD_TILE = 0
class Rotation(Enum):
UP = 0
RIGHT = 1
DOWN = 2
LEFT = 3
NONE = 100
def __int__(self):
return self.value
class Node:
def __init__(self, x: int, y: int, rotation: Rotation):
self.x = x
self.y = y
self.g_cost = 0
self.h_cost = 0
self.parent = 0
self.rotation = rotation
def f_cost(self):
return self.g_cost + self.h_cost
def get_neighbours(node: Node, searched_list: list[Node], array: list[list[int]]) -> list[Node]:
neighbours = []
for offset_x in range (-1, 2):
for offset_y in range (-1, 2):
# don't look for cross neighbours
if(abs(offset_x) + abs(offset_y) == 1):
x = node.x + offset_x
y = node.y + offset_y
# prevent out of map coords
if (x >= 0 and x <= MAP_WIDTH and y >= 0 and y <= MAP_HEIGHT):
if(array[y][x] == ROAD_TILE and (x, y) not in searched_list):
neighbour = Node(x, y, Rotation.NONE)
neighbour.rotation = get_needed_rotation(node, neighbour)
neighbours.append(neighbour)
return neighbours
# move cost schema:
# - move from tile to tile: 10
# - add extra 10 (1 rotation) if it exists
def get_h_cost(start_node: Node, target_node: Node) -> int:
distance_x = abs(start_node.x - target_node.x)
distance_y = abs(start_node.y - target_node.y)
cost = (distance_x + distance_y) * 10
if(distance_x > 0 and distance_y > 0):
cost += 10
return cost
# move cost schema:
# - move from tile to tile: 10
# - every rotation 90*: 10
def get_neighbour_cost(start_node: Node, target_node: Node) -> int:
new_rotation = get_needed_rotation(start_node, target_node)
rotate_change = abs(get_rotate_change(start_node.rotation, new_rotation))
if (rotate_change == 0):
return 10
elif (rotate_change == 1 or rotate_change == 3):
return 20
else:
return 30
# translate rotation change to move
def get_move(start_node: Node, target_node: Node) -> list[str]:
rotate_change = get_rotate_change(start_node.rotation, target_node.rotation)
if (rotate_change == 0):
return ["forward"]
if (abs(rotate_change) == 2):
return ["right", "right", "forward"]
if (rotate_change < 0 or rotate_change == 3):
return ["right", "forward"]
else:
return ["left", "forward"]
# simple calc func
def get_rotate_change(rotationA: Rotation, rotationB: Rotation) -> int:
return int(rotationA) - int(rotationB)
# get new rotation for target_node as neighbour of start_node
def get_needed_rotation(start_node: Node, target_node: Node) -> Rotation:
if (start_node.x - target_node.x > 0):
return Rotation.LEFT
if (start_node.x - target_node.x < 0):
return Rotation.RIGHT
if (start_node.y - target_node.y > 0):
return Rotation.UP
if (start_node.y - target_node.y < 0):
return Rotation.DOWN