[a_star_search_with_heap] added a_star_search and heap data structure, moved bfs algorithm to folder
This commit is contained in:
parent
5691ef09df
commit
929377af27
82
data_structures/heap.py
Normal file
82
data_structures/heap.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
class HeapElement:
|
||||||
|
def __init__(self, value, index, compare_value):
|
||||||
|
self.value = value
|
||||||
|
self.index = index
|
||||||
|
self.compare_value = compare_value
|
||||||
|
|
||||||
|
class Heap:
|
||||||
|
def __init__(self):
|
||||||
|
self.array: list[HeapElement] = []
|
||||||
|
|
||||||
|
def length(self) -> int:
|
||||||
|
return len(self.array)
|
||||||
|
|
||||||
|
def contains(self, value) -> bool:
|
||||||
|
for item in self.array:
|
||||||
|
if(item.value == value):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def append(self, item, compare_value):
|
||||||
|
new_element = HeapElement(item, len(self.array), compare_value)
|
||||||
|
self.array.append(new_element)
|
||||||
|
self.sort_up(new_element)
|
||||||
|
|
||||||
|
def take_first(self):
|
||||||
|
first_item = self.array[0]
|
||||||
|
new_first_item = self.array.pop()
|
||||||
|
if(len(self.array) > 0):
|
||||||
|
new_first_item.index = 0
|
||||||
|
self.array[0] = new_first_item
|
||||||
|
self.sort_down(new_first_item)
|
||||||
|
return first_item.value
|
||||||
|
|
||||||
|
def sort_up(self, item: HeapElement):
|
||||||
|
parent_index = (item.index - 1)//2
|
||||||
|
if(parent_index < 0):
|
||||||
|
return
|
||||||
|
|
||||||
|
parent_item = self.array[parent_index]
|
||||||
|
|
||||||
|
if(item.compare_value < parent_item.compare_value):
|
||||||
|
self.swap_items(item, parent_item)
|
||||||
|
item.index = parent_index
|
||||||
|
self.sort_up(item)
|
||||||
|
|
||||||
|
def sort_down(self, item: HeapElement):
|
||||||
|
child_left_index = item.index * 2 + 1
|
||||||
|
child_right_index = item.index * 2 + 2
|
||||||
|
|
||||||
|
if (child_left_index < len(self.array)):
|
||||||
|
swap_index = child_left_index
|
||||||
|
|
||||||
|
if(child_right_index < len(self.array)):
|
||||||
|
child_left_item = self.array[child_left_index]
|
||||||
|
child_right_item = self.array[child_right_index]
|
||||||
|
if(child_left_item.compare_value > child_right_item.compare_value):
|
||||||
|
swap_index = child_right_index
|
||||||
|
|
||||||
|
self.swap_items(item, self.array[swap_index])
|
||||||
|
item.index = swap_index
|
||||||
|
self.sort_down(item)
|
||||||
|
|
||||||
|
def swap_items(self, item_a: HeapElement, item_b: HeapElement):
|
||||||
|
item_a.index, item_b.index = item_b.index, item_a.index
|
||||||
|
self.array[item_a.index] = item_a
|
||||||
|
self.array[item_b.index] = item_b
|
||||||
|
|
||||||
|
# some test code
|
||||||
|
|
||||||
|
# heap = Heap()
|
||||||
|
# heap.append(5, 5)
|
||||||
|
# heap.append(2, 2)
|
||||||
|
# heap.append(3, 3)
|
||||||
|
# heap.append(4, 4)
|
||||||
|
# heap.append(6, 6)
|
||||||
|
# heap.append(1, 1)
|
||||||
|
# print(heap.take_first())
|
||||||
|
# print("heap:")
|
||||||
|
# for item in heap.array:
|
||||||
|
# print(item.value)
|
||||||
|
|
||||||
|
|
29
main.py
29
main.py
@ -3,15 +3,14 @@ from game_objects.player import Player
|
|||||||
import pygame as pg
|
import pygame as pg
|
||||||
import sys
|
import sys
|
||||||
from os import path
|
from os import path
|
||||||
|
import math
|
||||||
|
|
||||||
from map import *
|
from map import *
|
||||||
# from agent import trashmaster
|
|
||||||
# from house import House
|
|
||||||
from settings import *
|
from settings import *
|
||||||
from map import map
|
from map import map
|
||||||
from map import map_utils
|
from map import map_utils
|
||||||
from SearchBfs import *
|
from path_search_algorthms import bfs
|
||||||
import math
|
from path_search_algorthms import a_star
|
||||||
|
|
||||||
|
|
||||||
class Game():
|
class Game():
|
||||||
@ -23,7 +22,8 @@ class Game():
|
|||||||
pg.display.set_caption("Trashmaster")
|
pg.display.set_caption("Trashmaster")
|
||||||
self.load_data()
|
self.load_data()
|
||||||
self.init_game()
|
self.init_game()
|
||||||
self.init_bfs()
|
# self.init_bfs()
|
||||||
|
self.init_a_star()
|
||||||
|
|
||||||
def init_game(self):
|
def init_game(self):
|
||||||
# initialize all variables and do all the setup for a new game
|
# initialize all variables and do all the setup for a new game
|
||||||
@ -39,12 +39,12 @@ class Game():
|
|||||||
self.camera = map_utils.Camera(MAP_WIDTH_PX, MAP_HEIGHT_PX)
|
self.camera = map_utils.Camera(MAP_WIDTH_PX, MAP_HEIGHT_PX)
|
||||||
|
|
||||||
# other
|
# other
|
||||||
self.draw_debug = False
|
self.debug_mode = False
|
||||||
|
|
||||||
def init_bfs(self):
|
def init_bfs(self):
|
||||||
start_node = (0, 0)
|
start_node = (0, 0)
|
||||||
target_node = (18, 18)
|
target_node = (18, 18)
|
||||||
find_path = BreadthSearchAlgorithm(start_node, target_node, self.mapArray)
|
find_path = bfs.BreadthSearchAlgorithm(start_node, target_node, self.mapArray)
|
||||||
path = find_path.bfs()
|
path = find_path.bfs()
|
||||||
# print(path)
|
# print(path)
|
||||||
realPath = []
|
realPath = []
|
||||||
@ -56,6 +56,15 @@ class Game():
|
|||||||
nextNode = node[1]
|
nextNode = node[1]
|
||||||
print(realPath)
|
print(realPath)
|
||||||
|
|
||||||
|
def init_a_star(self):
|
||||||
|
# szukanie sciezki na sztywno i wyprintowanie wyniku (tablica stringow)
|
||||||
|
start_x = 0
|
||||||
|
start_y = 0
|
||||||
|
target_x = 6
|
||||||
|
target_y = 2
|
||||||
|
path = a_star.search_path(start_x, start_y, target_x, target_y, self.mapArray)
|
||||||
|
print(path)
|
||||||
|
|
||||||
def load_data(self):
|
def load_data(self):
|
||||||
game_folder = path.dirname(__file__)
|
game_folder = path.dirname(__file__)
|
||||||
img_folder = path.join(game_folder, 'resources/textures')
|
img_folder = path.join(game_folder, 'resources/textures')
|
||||||
@ -88,12 +97,12 @@ class Game():
|
|||||||
|
|
||||||
#rerender map
|
#rerender map
|
||||||
map.render_tiles(self.roadTiles, self.screen, self.camera)
|
map.render_tiles(self.roadTiles, self.screen, self.camera)
|
||||||
map.render_tiles(self.wallTiles, self.screen, self.camera, self.draw_debug)
|
map.render_tiles(self.wallTiles, self.screen, self.camera, self.debug_mode)
|
||||||
|
|
||||||
#rerender additional sprites
|
#rerender additional sprites
|
||||||
for sprite in self.agentSprites:
|
for sprite in self.agentSprites:
|
||||||
self.screen.blit(sprite.image, self.camera.apply(sprite))
|
self.screen.blit(sprite.image, self.camera.apply(sprite))
|
||||||
if self.draw_debug:
|
if self.debug_mode:
|
||||||
pg.draw.rect(self.screen, CYAN, self.camera.apply_rect(sprite.hit_rect), 1)
|
pg.draw.rect(self.screen, CYAN, self.camera.apply_rect(sprite.hit_rect), 1)
|
||||||
|
|
||||||
#finally update screen
|
#finally update screen
|
||||||
@ -107,7 +116,7 @@ class Game():
|
|||||||
if event.key == pg.K_ESCAPE:
|
if event.key == pg.K_ESCAPE:
|
||||||
self.quit()
|
self.quit()
|
||||||
if event.key == pg.K_h:
|
if event.key == pg.K_h:
|
||||||
self.draw_debug = not self.draw_debug
|
self.debug_mode = not self.debug_mode
|
||||||
if event.type == pg.MOUSEBUTTONUP:
|
if event.type == pg.MOUSEBUTTONUP:
|
||||||
pos = pg.mouse.get_pos()
|
pos = pg.mouse.get_pos()
|
||||||
clicked_coords = [math.floor(pos[0] / TILESIZE), math.floor(pos[1] / TILESIZE)]
|
clicked_coords = [math.floor(pos[0] / TILESIZE), math.floor(pos[1] / TILESIZE)]
|
||||||
|
13
map/map.py
13
map/map.py
@ -4,7 +4,18 @@ import pygame as pg
|
|||||||
from settings import *
|
from settings import *
|
||||||
|
|
||||||
def get_tiles():
|
def get_tiles():
|
||||||
array = map_utils.generate_map()
|
# array = map_utils.generate_map()
|
||||||
|
array = map_utils.get_blank_map_array()
|
||||||
|
|
||||||
|
array[1][1] = 1
|
||||||
|
array[1][2] = 1
|
||||||
|
array[1][3] = 1
|
||||||
|
array[1][4] = 1
|
||||||
|
array[1][5] = 1
|
||||||
|
array[1][6] = 1
|
||||||
|
|
||||||
|
array[2][5] = 1
|
||||||
|
|
||||||
pattern = map_pattern.get_pattern()
|
pattern = map_pattern.get_pattern()
|
||||||
tiles = map_utils.get_sprites(array, pattern)
|
tiles = map_utils.get_sprites(array, pattern)
|
||||||
return tiles, array
|
return tiles, array
|
||||||
|
112
path_search_algorthms/a_star.py
Normal file
112
path_search_algorthms/a_star.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
from data_structures.heap import Heap
|
||||||
|
from path_search_algorthms import a_star_utils as utils
|
||||||
|
|
||||||
|
def search_path(start_x: int, start_y: int, target_x: int, target_y: int, array: list[list[int]]) -> list[str]:
|
||||||
|
|
||||||
|
start_node = utils.Node(start_x, start_y, utils.Rotation.RIGHT)
|
||||||
|
target_node = utils.Node(target_x, target_y, utils.Rotation.NONE)
|
||||||
|
|
||||||
|
# heap version
|
||||||
|
|
||||||
|
# nodes for check
|
||||||
|
search_list = Heap()
|
||||||
|
search_list.append(start_node, 0)
|
||||||
|
|
||||||
|
# checked nodes
|
||||||
|
searched_list: list[(int, int)] = []
|
||||||
|
|
||||||
|
while (search_list.length() > 0):
|
||||||
|
node: utils.Node = search_list.take_first()
|
||||||
|
|
||||||
|
searched_list.append((node.x, node.y))
|
||||||
|
|
||||||
|
# check for target node
|
||||||
|
if ((node.x, node.y) == (target_x, target_y)):
|
||||||
|
return trace_path(node)
|
||||||
|
|
||||||
|
# neightbours processing
|
||||||
|
neighbours = utils.get_neighbours(node, searched_list, array)
|
||||||
|
for neighbour in neighbours:
|
||||||
|
|
||||||
|
# calculate new g cost for neightbour (start -> node -> neightbour)
|
||||||
|
new_neighbour_cost = node.g_cost + utils.get_neighbour_cost(node, neighbour)
|
||||||
|
|
||||||
|
if (new_neighbour_cost < neighbour.g_cost or not search_list.contains(neighbour)):
|
||||||
|
|
||||||
|
# replace cost and set parent node
|
||||||
|
neighbour.g_cost = new_neighbour_cost
|
||||||
|
neighbour.h_cost = utils.get_h_cost(neighbour, target_node)
|
||||||
|
neighbour.parent = node
|
||||||
|
|
||||||
|
# add to search
|
||||||
|
if(not search_list.contains(neighbour)):
|
||||||
|
search_list.append(neighbour, neighbour.f_cost())
|
||||||
|
|
||||||
|
# array version
|
||||||
|
|
||||||
|
# nodes for check
|
||||||
|
# search_list = [start_node]
|
||||||
|
|
||||||
|
# checked nodes
|
||||||
|
# searched_list: list[(int, int)] = []
|
||||||
|
|
||||||
|
# while (len(search_list) > 0):
|
||||||
|
# node = search_list[0]
|
||||||
|
|
||||||
|
# # find cheapest node in search_list
|
||||||
|
# for i in range(1, len(search_list)):
|
||||||
|
# if (search_list[i].f_cost() <= node.f_cost()):
|
||||||
|
# if(search_list[i].h_cost < node.h_cost):
|
||||||
|
# node = search_list[i]
|
||||||
|
|
||||||
|
# search_list.remove(node)
|
||||||
|
# searched_list.append((node.x, node.y))
|
||||||
|
|
||||||
|
# # check for target node
|
||||||
|
# if ((node.x, node.y) == (target_x, target_y)):
|
||||||
|
# return trace_path(node)
|
||||||
|
|
||||||
|
# # neightbours processing
|
||||||
|
# neighbours = utils.get_neighbours(node, searched_list, array)
|
||||||
|
# for neighbour in neighbours:
|
||||||
|
|
||||||
|
# # calculate new g cost for neightbour (start -> node -> neightbour)
|
||||||
|
# new_neighbour_cost = node.g_cost + utils.get_neighbour_cost(node, neighbour)
|
||||||
|
|
||||||
|
# if (new_neighbour_cost < neighbour.g_cost or neighbour not in search_list):
|
||||||
|
|
||||||
|
# # replace cost and set parent node
|
||||||
|
# neighbour.g_cost = new_neighbour_cost
|
||||||
|
# neighbour.h_cost = utils.get_h_cost(neighbour, target_node)
|
||||||
|
# neighbour.parent = node
|
||||||
|
|
||||||
|
# # add to search
|
||||||
|
# if(neighbour not in search_list):
|
||||||
|
# search_list.append(neighbour)
|
||||||
|
|
||||||
|
def trace_path(end_node: utils.Node) -> list[str]:
|
||||||
|
path = []
|
||||||
|
node = end_node
|
||||||
|
|
||||||
|
# set final rotation of end_node because we don't do it before
|
||||||
|
node.rotation = utils.get_needed_rotation(node.parent, node)
|
||||||
|
|
||||||
|
while (node.parent != 0):
|
||||||
|
move = utils.get_move(node.parent, node)
|
||||||
|
path += move
|
||||||
|
node = node.parent
|
||||||
|
|
||||||
|
# delete move on initial tile
|
||||||
|
path.pop()
|
||||||
|
|
||||||
|
# we found path from end, so we need to reverse it (get_move reverse move words)
|
||||||
|
path.reverse()
|
||||||
|
|
||||||
|
# last forward to destination
|
||||||
|
path.append("forward")
|
||||||
|
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
102
path_search_algorthms/a_star_utils.py
Normal file
102
path_search_algorthms/a_star_utils.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from settings import *
|
||||||
|
|
||||||
|
ROAD_TILE = 0
|
||||||
|
|
||||||
|
class Rotation(Enum):
|
||||||
|
UP = 0
|
||||||
|
RIGHT = 1
|
||||||
|
DOWN = 2
|
||||||
|
LEFT = 3
|
||||||
|
NONE = 100
|
||||||
|
|
||||||
|
def __int__(self):
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
class Node:
|
||||||
|
def __init__(self, x: int, y: int, rotation: Rotation):
|
||||||
|
self.x = x
|
||||||
|
self.y = y
|
||||||
|
self.g_cost = 0
|
||||||
|
self.h_cost = 0
|
||||||
|
self.parent = 0
|
||||||
|
self.rotation = rotation
|
||||||
|
|
||||||
|
def f_cost(self):
|
||||||
|
return self.g_cost + self.h_cost
|
||||||
|
|
||||||
|
def get_neighbours(node: Node, searched_list: list[Node], array: list[list[int]]) -> list[Node]:
|
||||||
|
neighbours = []
|
||||||
|
for offset_x in range (-1, 2):
|
||||||
|
for offset_y in range (-1, 2):
|
||||||
|
# don't look for cross neighbours
|
||||||
|
if(abs(offset_x) + abs(offset_y) == 1):
|
||||||
|
x = node.x + offset_x
|
||||||
|
y = node.y + offset_y
|
||||||
|
# prevent out of map coords
|
||||||
|
if (x >= 0 and x <= MAP_WIDTH and y >= 0 and y <= MAP_HEIGHT):
|
||||||
|
if(array[y][x] == ROAD_TILE and (x, y) not in searched_list):
|
||||||
|
neighbour = Node(x, y, Rotation.NONE)
|
||||||
|
neighbour.rotation = get_needed_rotation(node, neighbour)
|
||||||
|
neighbours.append(neighbour)
|
||||||
|
return neighbours
|
||||||
|
|
||||||
|
# move cost schema:
|
||||||
|
# - move from tile to tile: 10
|
||||||
|
# - add extra 10 (1 rotation) if it exists
|
||||||
|
def get_h_cost(start_node: Node, target_node: Node) -> int:
|
||||||
|
distance_x = abs(start_node.x - target_node.x)
|
||||||
|
distance_y = abs(start_node.y - target_node.y)
|
||||||
|
cost = (distance_x + distance_y) * 10
|
||||||
|
|
||||||
|
if(distance_x > 0 and distance_y > 0):
|
||||||
|
cost += 10
|
||||||
|
|
||||||
|
return cost
|
||||||
|
|
||||||
|
# move cost schema:
|
||||||
|
# - move from tile to tile: 10
|
||||||
|
# - every rotation 90*: 10
|
||||||
|
def get_neighbour_cost(start_node: Node, target_node: Node) -> int:
|
||||||
|
new_rotation = get_needed_rotation(start_node, target_node)
|
||||||
|
rotate_change = abs(get_rotate_change(start_node.rotation, new_rotation))
|
||||||
|
if (rotate_change == 0):
|
||||||
|
return 10
|
||||||
|
elif (rotate_change == 1 or rotate_change == 3):
|
||||||
|
return 20
|
||||||
|
else:
|
||||||
|
return 30
|
||||||
|
|
||||||
|
# translate rotation change to move
|
||||||
|
def get_move(start_node: Node, target_node: Node) -> list[str]:
|
||||||
|
rotate_change = get_rotate_change(start_node.rotation, target_node.rotation)
|
||||||
|
if (rotate_change == 0):
|
||||||
|
return ["forward"]
|
||||||
|
if (abs(rotate_change) == 2):
|
||||||
|
return ["right", "right", "forward"]
|
||||||
|
if (rotate_change < 0 or rotate_change == 3):
|
||||||
|
return ["right", "forward"]
|
||||||
|
else:
|
||||||
|
return ["left", "forward"]
|
||||||
|
|
||||||
|
# simple calc func
|
||||||
|
def get_rotate_change(rotationA: Rotation, rotationB: Rotation) -> int:
|
||||||
|
return int(rotationA) - int(rotationB)
|
||||||
|
|
||||||
|
# get new rotation for target_node as neighbour of start_node
|
||||||
|
def get_needed_rotation(start_node: Node, target_node: Node) -> Rotation:
|
||||||
|
if (start_node.x - target_node.x > 0):
|
||||||
|
return Rotation.LEFT
|
||||||
|
if (start_node.x - target_node.x < 0):
|
||||||
|
return Rotation.RIGHT
|
||||||
|
if (start_node.y - target_node.y > 0):
|
||||||
|
return Rotation.UP
|
||||||
|
if (start_node.y - target_node.y < 0):
|
||||||
|
return Rotation.DOWN
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user