A* works now ;)

This commit is contained in:
s473603 2023-05-04 15:47:01 +02:00
parent 10ebc02f8f
commit e270f6f519

View File

@ -2,6 +2,7 @@ import random
import time import time
from heapq import * from heapq import *
from enum import Enum, IntEnum from enum import Enum, IntEnum
from queue import PriorityQueue
from collections import deque from collections import deque
import pygame import pygame
@ -12,12 +13,12 @@ WHITE = (200, 200, 200)
BLUE = (46, 34, 240) BLUE = (46, 34, 240)
WINDOW_DIMENSIONS = 900 WINDOW_DIMENSIONS = 900
BLOCK_SIZE = 60 BLOCK_SIZE = 60
ROCKS_NUMBER = 15 ROCKS_NUMBER = 20
VEGETABLES_NUMBER = 20 VEGETABLES_NUMBER = 20
VEGETABLES = ('Potato', 'Broccoli', 'Carrot', 'Onion') VEGETABLES = ('Potato', 'Broccoli', 'Carrot', 'Onion')
BOARD_SIZE = int(WINDOW_DIMENSIONS / BLOCK_SIZE) BOARD_SIZE = int(WINDOW_DIMENSIONS / BLOCK_SIZE)
WATER_TANK_CAPACITY = 10 WATER_TANK_CAPACITY = 10
GAS_TANK_CAPACITY = 100 GAS_TANK_CAPACITY = 250
SPAWN_POINT = (0, 0) SPAWN_POINT = (0, 0)
tractor_image = pygame.transform.scale(pygame.image.load("images/tractor_image.png"), (BLOCK_SIZE, BLOCK_SIZE)) tractor_image = pygame.transform.scale(pygame.image.load("images/tractor_image.png"), (BLOCK_SIZE, BLOCK_SIZE))
@ -92,7 +93,7 @@ def draw_interface():
elif event.type == pygame.MOUSEBUTTONDOWN: elif event.type == pygame.MOUSEBUTTONDOWN:
startpoint = (tractor.x, tractor.y, tractor.direction) startpoint = (tractor.x, tractor.y, tractor.direction)
endpoint = get_click_mouse_pos() endpoint = get_click_mouse_pos()
a, c = graph1.dijkstra(startpoint, endpoint) a, c = graph1.a_star(startpoint, endpoint)
b = getRoad(startpoint, c, a) b = getRoad(startpoint, c, a)
movement(tractor, grid, b) movement(tractor, grid, b)
updateDisplay(tractor, grid) updateDisplay(tractor, grid)
@ -169,32 +170,31 @@ class Graph:
for direction in Direction: for direction in Direction:
self.graph[(x, y, direction)] = get_next_nodes(x, y, direction, grid) self.graph[(x, y, direction)] = get_next_nodes(x, y, direction, grid)
def dijkstra(self, start, goal): def a_star(self, start, goal):
# not finished yet https://www.youtube.com/watch?v=abHftC1GU6w # not finished yet https://www.youtube.com/watch?v=abHftC1GU6w
queue = [] queue = PriorityQueue()
heappush(queue, (0, start)) queue.put((0, start))
cost_visited = {start: 0} cost_visited = {start: 0}
visited = {start: None} visited = {start: None}
returnGoal = goal returnGoal = goal
h = lambda start, goal: abs(start[0] - goal[0]) + abs(start[1] - goal[1]) #heuristic function (manhattan distance)
while queue: while not queue.empty():
cur_cost, cur_node = heappop(queue) cur_cost, cur_node = queue.get()
if cur_node[0] == goal[0] and cur_node[1] == goal[1]: if cur_node[0] == goal[0] and cur_node[1] == goal[1]:
queue = []
returnGoal=cur_node returnGoal=cur_node
break break
next_nodes = self.graph[cur_node] next_nodes = self.graph[cur_node]
# print()
for next_node in next_nodes: for next_node in next_nodes:
neigh_cost, neigh_node = next_node neigh_cost, neigh_node = next_node
# print(neigh_node)
new_cost = cost_visited[cur_node] + neigh_cost new_cost = cost_visited[cur_node] + neigh_cost + h(neigh_node, goal)
if neigh_node not in cost_visited or new_cost < cost_visited[neigh_node]: if neigh_node not in cost_visited or new_cost < cost_visited[neigh_node]:
heappush(queue, (new_cost, neigh_node)) queue.put((new_cost, neigh_node))
cost_visited[neigh_node] = new_cost cost_visited[neigh_node] = new_cost - h(neigh_node, goal)
visited[neigh_node] = cur_node visited[neigh_node] = cur_node
# print(visited, returnGoal) # print(visited, returnGoal)
return visited, returnGoal return visited, returnGoal
@ -233,9 +233,9 @@ class Tractor:
self.x -= 1 self.x -= 1
if grid.grid[self.x][self.y] == types.ROCK: if grid.grid[self.x][self.y] == types.ROCK:
self.gas -= 5 self.gas -= 12
else: else:
self.gas -= 1 self.gas -= 2
return return
@ -253,9 +253,9 @@ def get_next_nodes(x, y, direction: Direction, grid: Grid):
if check_next_node(x + way[0], y + way[1]): if check_next_node(x + way[0], y + way[1]):
if grid.grid[x + way[0]][y + way[1]] == types.ROCK: if grid.grid[x + way[0]][y + way[1]] == types.ROCK:
# print(x, y, "to", x + way[0], y + way[1], 'costs 5') # print(x, y, "to", x + way[0], y + way[1], 'costs 5')
next_nodes.append((5, (x + way[0], y + way[1], new_direction))) next_nodes.append((12, (x + way[0], y + way[1], new_direction)))
else: else:
next_nodes.append((1, (x + way[0], y + way[1], new_direction))) next_nodes.append((2, (x + way[0], y + way[1], new_direction)))
# print(x,y, direction, next_nodes, '\n') # print(x,y, direction, next_nodes, '\n')
return next_nodes return next_nodes