2022-04-11 19:18:03 +02:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
from dataclasses import dataclass, field
|
2022-04-11 19:49:32 +02:00
|
|
|
from typing import Tuple, Optional, List
|
2022-04-11 19:18:03 +02:00
|
|
|
|
2022-04-11 21:52:21 +02:00
|
|
|
from common.constants import ROWS, COLUMNS
|
|
|
|
|
|
|
|
# Grid cell value that marks a walkable (unblocked) field.
FREE_FIELD = ' '

# Absolute headings the agent can face.
LEFT, RIGHT, UP, DOWN = 'LEFT', 'RIGHT', 'UP', 'DOWN'

# Actions available to the agent.
TURN_LEFT, TURN_RIGHT, FORWARD = 'TURN_LEFT', 'TURN_RIGHT', 'FORWARD'

# Heading -> (row delta, column delta) for a single forward step.
# Row indices grow downward, so UP decreases the row.
directions = {
    LEFT: (0, -1),
    RIGHT: (0, 1),
    UP: (-1, 0),
    DOWN: (1, 0),
}
|
2022-04-11 19:49:32 +02:00
|
|
|
|
2022-04-11 19:18:03 +02:00
|
|
|
|
|
|
|
@dataclass(unsafe_hash=True)
class State:
    """Agent configuration: the board cell it occupies plus its heading.

    Hashable, so states can be stored in sets/dicts (e.g. visited sets) and
    hashed by wrapping search nodes. A plain ``@dataclass`` sets
    ``__hash__ = None`` and would make every ``hash(state)`` raise TypeError.
    ``unsafe_hash`` is used instead of ``frozen`` so code that assigns to the
    fields keeps working. Do not mutate a State after it has been placed in
    a set or used as a dict key.
    """

    position: Tuple[int, int]  # (row, col), used as grid[row][col]
    direction: str  # heading: one of UP / DOWN / LEFT / RIGHT
|
2022-04-11 19:18:03 +02:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Node:
    """Search-tree node wrapping a State.

    ``cost`` and ``depth`` are derived from the parent chain: 0 for a root
    (no parent), otherwise the parent's cost + 1. Every action therefore
    costs 1.

    Equality and hashing consider only the wrapped state, so two nodes that
    reach the same state by different paths compare equal. That lets them be
    de-duplicated in sets and dicts.
    """

    state: State
    # Excluded from repr: printing it would recurse through every ancestor.
    parent: Optional[Node] = field(repr=False)
    action: str
    cost: int = field(init=False)
    depth: int = field(init=False)

    def __post_init__(self) -> None:
        self.cost = 0 if self.parent is None else self.parent.cost + 1
        self.depth = self.cost

    def __eq__(self, other: object) -> bool:
        # Let Python fall back to its default handling for foreign types
        # instead of raising AttributeError on ``other.state``.
        if not isinstance(other, Node):
            return NotImplemented
        return self.state == other.state

    def __hash__(self) -> int:
        # Hash the state's components rather than the state object itself.
        # A non-frozen dataclass State is unhashable unless it opts in.
        # Equal states have equal components, so this stays consistent
        # with __eq__.
        return hash((self.state.position, self.state.direction))
|
|
|
|
|
|
|
|
|
2022-04-11 20:04:53 +02:00
|
|
|
def expand(node: Node, grid: List[List[str]]) -> List[Node]:
    """Return one child node per action applicable in ``node.state``.

    :param node: the node to expand.
    :param grid: the board, indexed grid[row][col]; forwarded to actions(),
        which requires it to rule out moves into blocked cells.
    :return: child nodes in the order actions() lists the actions.
    """
    return [
        child_node(node=node, action=action)
        for action in actions(node.state, grid)
    ]
|
|
|
|
|
|
|
|
|
2022-04-11 22:52:58 +02:00
|
|
|
def child_node(node: Node, action: str) -> Node:
    """Build the successor of *node* obtained by applying *action*.

    The new node records *node* as its parent and *action* as the step that
    produced it. Its cost is derived by Node itself.
    """
    return Node(
        state=result(state=node.state, action=action),
        parent=node,
        action=action,
    )
|
2022-04-11 19:49:32 +02:00
|
|
|
|
|
|
|
|
2022-04-11 22:52:58 +02:00
|
|
|
def next_position(current_position: Tuple[int, int], direction: str) -> Tuple[int, int]:
    """Return the cell one step away from *current_position* along *direction*.

    No bounds check is performed; the result may lie off the board.
    """
    row, col = current_position
    d_row, d_col = directions[direction]
    return d_row + row, d_col + col
|
2022-04-11 21:52:21 +02:00
|
|
|
|
|
|
|
|
|
|
|
def valid_move(position: Tuple[int, int], grid: List[List[str]]) -> bool:
    """Return True if *position* lies on *grid* and its cell is FREE_FIELD.

    Coordinates outside the grid are reported as invalid. Without the range
    checks, a negative index would silently wrap to the opposite edge and a
    too-large index would raise IndexError.

    :param position: (row, col) to test.
    :param grid: board indexed as grid[row][col].
    """
    row, col = position
    if not 0 <= row < len(grid):
        return False
    if not 0 <= col < len(grid[row]):
        return False
    return grid[row][col] == FREE_FIELD
|
|
|
|
|
|
|
|
|
2022-04-11 22:52:58 +02:00
|
|
|
def actions(state: State, grid: List[List[str]]) -> List[str]:
    """Return the actions applicable in *state* on *grid*.

    Turning is always allowed. FORWARD is allowed only when the cell ahead
    lies inside the ROWS x COLUMNS board and is FREE_FIELD.

    :param state: the agent's current position and heading.
    :param grid: board indexed as grid[row][col].
    :return: a subset of [FORWARD, TURN_LEFT, TURN_RIGHT], in that order.
    """
    possible_actions = [FORWARD, TURN_LEFT, TURN_RIGHT]
    row, col = state.position
    direction = state.direction

    # Stepping off the board is never allowed.
    at_edge = (
        (direction == UP and row == 0)
        or (direction == DOWN and row == ROWS - 1)
        or (direction == LEFT and col == 0)
        or (direction == RIGHT and col == COLUMNS - 1)
    )

    # The grid is consulted only when the cell ahead is on the board
    # (short-circuit on at_edge). A blocked cell ahead also forbids FORWARD.
    if at_edge or not valid_move(next_position(state.position, direction), grid):
        remove_forward(possible_actions)

    return possible_actions
|
|
|
|
|
|
|
|
|
2022-04-11 22:52:58 +02:00
|
|
|
def remove_forward(possible_actions: List[str]) -> None:
    """Remove the first FORWARD from *possible_actions* in place, if any.

    A missing FORWARD is not an error, so the call is safe to repeat.
    """
    try:
        possible_actions.remove(FORWARD)
    except ValueError:
        # Already absent: nothing to do.
        pass
|
2022-04-11 21:52:21 +02:00
|
|
|
|
|
|
|
|
2022-04-11 23:53:50 +02:00
|
|
|
def result(state: State, action: str) -> State:
    """Return the state produced by applying *action* in *state*.

    - TURN_LEFT / TURN_RIGHT change the heading by a quarter turn without
      moving.
    - FORWARD moves one cell along the current heading. No bounds or
      obstacle check happens here.
    - Any other action, or a heading other than UP/DOWN/LEFT/RIGHT, yields
      an unchanged copy.

    *state* itself is never modified; a new State is always returned.
    """
    # Heading after turning, keyed by the heading before the turn.
    after_left_turn = {UP: LEFT, LEFT: DOWN, DOWN: RIGHT, RIGHT: UP}
    after_right_turn = {UP: RIGHT, RIGHT: DOWN, DOWN: LEFT, LEFT: UP}

    position, heading = state.position, state.direction
    if heading in after_left_turn:
        if action == TURN_LEFT:
            heading = after_left_turn[heading]
        elif action == TURN_RIGHT:
            heading = after_right_turn[heading]
        elif action == FORWARD:
            position = next_position(position, heading)
    return State(position, heading)
|
2022-04-11 19:49:32 +02:00
|
|
|
|
|
|
|
|
|
|
|
def goal_test(state: State, goal_list: List[Tuple[int, int]]) -> bool:
    """Return True when the agent stands on any cell listed in *goal_list*.

    Only the position matters; the heading is ignored.
    """
    current = state.position
    return current in goal_list
|
2022-04-11 19:49:32 +02:00
|
|
|
|
|
|
|
|
2022-04-11 21:52:21 +02:00
|
|
|
def h(state: State, goal: Tuple[int, int]) -> int:
    """Heuristic: Manhattan distance from the agent's position to *goal*.

    Heading is ignored.
    """
    (row, col), (goal_row, goal_col) = state.position, goal
    return abs(row - goal_row) + abs(col - goal_col)
|
2022-04-11 19:49:32 +02:00
|
|
|
|
|
|
|
|
2022-04-11 21:52:21 +02:00
|
|
|
def f(current_node: Node, goal: Tuple[int, int]) -> int:
    """A*-style evaluation f(n) = g(n) + h(n).

    g(n) is the node's accumulated path cost and h(n) the Manhattan-distance
    heuristic from its position to *goal*.
    """
    path_cost = current_node.cost
    estimate = h(state=current_node.state, goal=goal)
    return path_cost + estimate
|