diff --git a/AI_brain/rotate_and_go_a_star.py b/AI_brain/rotate_and_go_a_star.py new file mode 100644 index 0000000..33e9736 --- /dev/null +++ b/AI_brain/rotate_and_go_a_star.py @@ -0,0 +1,124 @@ +import math +import queue +from dataclasses import dataclass, field +from typing import Any + +from domain.world import World + + +class State: + def __init__(self, x, y, direction=(1, 0)): + self.x = x + self.y = y + self.direction = direction + + def __hash__(self): + return hash((self.x, self.y)) + + def __eq__(self, other): + return (self.x == other.x and self.y == other.y + and self.direction == other.direction) + + +class Node: + def __init__(self, state: State): + self.state = state + self.parent = None + self.action = None + self.g = None + self.h = None + + +@dataclass(order=True) +class PrioritizedItem: + priority: int + item: Any = field(compare=False) + + def __iter__(self): + return iter((self.priority, self.item)) + + +def action_sequence(node: Node): + actions = [] + while node.parent: + actions.append(node.action) + node = node.parent + + print(node.g) + actions.reverse() + return actions + + +class RotateAndGoAStar: + def __init__(self, world: World, start_state: State, goal_state: State): + self.world = world + self.start_state = start_state + self.goal_state = goal_state + self.fringe = queue.PriorityQueue() + self.enqueued_states = {} + self.explored = set() + self.actions = [] + + def search(self): + h = abs(self.start_state.x - self.goal_state.x) ** 2 + abs(self.start_state.y - self.goal_state.y) ** 2 + self.fringe.put(PrioritizedItem(h, Node(self.start_state))) + + while not self.fringe.empty(): + priority, elem = self.fringe.get() + + self.enqueued_states.pop(elem, 0) + + if self.is_goal(elem.state): + self.actions = action_sequence(elem) + return True + + self.explored.add(elem.state) + + for (action, state) in self.successors(elem.state): + next_node = Node(state) + next_node.action = action + next_node.parent = elem + next_node.g = abs(elem.state.x - state.x) + abs(elem.state.y - state.y) + self.world.get_cost(state.x, state.y) + if next_node.g > 100: + print(str(state.x) + ":" + str(state.y)) + next_node.h = abs(state.x - self.goal_state.x) ** 2 + abs(state.y - self.goal_state.y) ** 2 + f = next_node.g + next_node.h + + if state not in self.enqueued_states and state not in self.explored: + self.fringe.put(PrioritizedItem(f, next_node)) + self.enqueued_states[state] = f + elif self.enqueued_states.get(state, -math.inf) > f: + self.add_existed(next_node, f) + self.enqueued_states.pop(state, 0) + self.enqueued_states[state] = f + + return False + + def add_existed(self, node: Node, f: int): + old = [] + while not self.fringe.empty(): + e = self.fringe.get() + if e.item.state == node.state: + break + old.append(e) + self.fringe.put(PrioritizedItem(f, node)) + for e in old: + self.fringe.put(e) + + def successors(self, state: State): + new_successors = [ + # rotate right + ("RR", State(state.x, state.y, (-state.direction[1], state.direction[0]))), + # rotate left + ("RL", State(state.x, state.y, (state.direction[1], -state.direction[0]))), + ] + if self.world.accepted_move(state.x + state.direction[0], state.y + state.direction[1]): + new_successors.append( + ("GO", State(state.x + state.direction[0], state.y + state.direction[1], state.direction))) + return new_successors + + def is_goal(self, state: State) -> bool: + return ( + state.x == self.goal_state.x + and state.y == self.goal_state.y + ) diff --git a/domain/world.py b/domain/world.py index def2082..1324bf8 100644 --- a/domain/world.py +++ b/domain/world.py @@ -3,6 +3,7 @@ from domain.entities.entity import Entity class World: def __init__(self, width: int, height: int) -> object: + self.costs = [[0 for j in range(height)] for i in range(width)] self.width = width self.height = height self.dust = [[[] for j in range(height)] for i in range(width)] @@ -47,3 +48,6 @@ class World: return False return True + + def get_cost(self, x, y): + return self.costs[x][y] diff --git a/main.py b/main.py index 0bdd36a..b950899 100644 --- a/main.py +++ b/main.py @@ -12,7 +12,8 @@ from domain.entities.docking_station import Doc_Station from domain.world import World from view.renderer import Renderer # from AI_brain.movement import GoAnyDirectionBFS, State -from AI_brain.rotate_and_go_bfs import RotateAndGoBFS, State +# from AI_brain.rotate_and_go_bfs import RotateAndGoBFS, State +from AI_brain.rotate_and_go_a_star import RotateAndGoAStar, State config = configparser.ConfigParser() @@ -50,7 +51,8 @@ class Main: end_state = State(self.world.doc_station.x, self.world.doc_station.y) # path_searcher = GoAnyDirectionBFS(self.world, start_state, end_state) - path_searcher = RotateAndGoBFS(self.world, start_state, end_state) + # path_searcher = RotateAndGoBFS(self.world, start_state, end_state) + path_searcher = RotateAndGoAStar(self.world, start_state, end_state) if not path_searcher.search(): print("No solution") exit(0) @@ -146,6 +148,16 @@ def generate_world(tiles_x: int, tiles_y: int) -> World: world.add_entity(Entity(3, 4, "PLANT2")) world.add_entity(Entity(8, 8, "PLANT2")) world.add_entity(Entity(9, 3, "PLANT3")) + + # TEST + world.costs[9][3] = 1000 + world.costs[8][3] = 1000 + world.costs[7][3] = 1000 + world.costs[6][3] = 1000 + world.costs[5][3] = 1000 + world.costs[4][3] = 1000 + world.costs[3][3] = 1000 + return world