fix: actions

This commit is contained in:
korzepadawid 2022-04-11 22:52:58 +02:00
parent b68013a0cd
commit fbbb521a4b
2 changed files with 36 additions and 33 deletions

View File

@ -1,40 +1,39 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum, unique
from typing import Tuple, Optional, List from typing import Tuple, Optional, List
from common.constants import ROWS, COLUMNS from common.constants import ROWS, COLUMNS
FREE_FIELD = ' ' FREE_FIELD = ' '
LEFT = 'LEFT'
RIGHT = 'RIGHT'
UP = 'UP'
DOWN = 'DOWN'
directions = {
LEFT: (0, -1),
RIGHT: (0, 1),
UP: (-1, 0),
DOWN: (1, 0)
}
@unique TURN_LEFT = 'TURN_LEFT'
class Direction(Enum): TURN_RIGHT = 'TURN_RIGHT'
LEFT = (0, -1) FORWARD = 'FORWARD'
RIGHT = (0, 1)
UP = (-1, 0)
DOWN = (1, 0)
@unique
class Action(Enum):
TURN_LEFT = 'TURN_LEFT'
TURN_RIGHT = 'TURN_RIGHT'
FORWARD = 'FORWARD'
@dataclass @dataclass
class State: class State:
position: Tuple[int, int] position: Tuple[int, int]
direction: Direction direction: str
@dataclass @dataclass
class Node: class Node:
state: State state: State
parent: Optional[Node] parent: Optional[Node]
action: Action action: str
cost: int = field(init=False) cost: int = field(init=False)
depth: int = field(init=False) depth: int = field(init=False)
@ -53,15 +52,15 @@ def expand(node: Node) -> List[Node]:
return [child_node(node=node, action=action) for action in actions(node.state)] return [child_node(node=node, action=action) for action in actions(node.state)]
def child_node(node: Node, action: Action) -> Node: def child_node(node: Node, action: str) -> Node:
next_state = result(state=node.state, action=action) next_state = result(state=node.state, action=action)
return Node(state=next_state, parent=node, action=action) return Node(state=next_state, parent=node, action=action)
def next_position(current_position: Tuple[int, int], direction: Direction) -> Tuple[int, int]: def next_position(current_position: Tuple[int, int], direction: str) -> Tuple[int, int]:
x1, y1 = direction.value next_row, next_col = directions[direction]
x2, y2 = current_position row, col = current_position
return x1 + x2, y1 + y2 return next_row + row, next_col + col
def valid_move(position: Tuple[int, int], grid: List[List[str]]) -> bool: def valid_move(position: Tuple[int, int], grid: List[List[str]]) -> bool:
@ -69,32 +68,32 @@ def valid_move(position: Tuple[int, int], grid: List[List[str]]) -> bool:
return grid[row][col] == FREE_FIELD return grid[row][col] == FREE_FIELD
def actions(state: State, grid: List[List[str]]) -> List[Action]: def actions(state: State, grid: List[List[str]]) -> List[str]:
possible_actions = [Action.FORWARD, Action.TURN_LEFT, Action.TURN_RIGHT] possible_actions = [FORWARD, TURN_LEFT, TURN_RIGHT]
row, col = state.position row, col = state.position
direction = state.direction direction = state.direction
if direction == Direction.UP and row == 0: if direction == UP and row == 0:
remove_forward(possible_actions) remove_forward(possible_actions)
if direction == Direction.DOWN and row == ROWS - 1: if direction == DOWN and row == ROWS - 1:
remove_forward(possible_actions) remove_forward(possible_actions)
if direction == Direction.LEFT and col == 0: if direction == LEFT and col == 0:
remove_forward(possible_actions) remove_forward(possible_actions)
if direction == Direction.RIGHT and col == COLUMNS - 1: if direction == RIGHT and col == COLUMNS - 1:
remove_forward(possible_actions) remove_forward(possible_actions)
if not valid_move(next_position(state.position, direction), grid): if FORWARD not in possible_actions and not valid_move(next_position(state.position, direction), grid):
remove_forward(possible_actions) remove_forward(possible_actions)
return possible_actions return possible_actions
def remove_forward(possible_actions: List[Action]) -> None: def remove_forward(possible_actions: List[str]) -> None:
if Action.FORWARD in possible_actions: if FORWARD in possible_actions:
possible_actions.remove(Action.FORWARD) possible_actions.remove(FORWARD)
def result(state: State, action: Action, grid: List[List[str]]) -> State: def result(state: State, action: str, grid: List[List[str]]) -> State:
pass pass
@ -106,7 +105,7 @@ def h(state: State, goal: Tuple[int, int]) -> int:
"""heuristics that calculates Manhattan distance between current position and goal""" """heuristics that calculates Manhattan distance between current position and goal"""
x1, y1 = state.position x1, y1 = state.position
x2, y2 = goal x2, y2 = goal
return abs(x1 - x2) + abs(y1 - y2) # Manhattan distance return abs(x1 - x2) + abs(y1 - y2)
def f(current_node: Node, goal: Tuple[int, int]) -> int: def f(current_node: Node, goal: Tuple[int, int]) -> int:

View File

@ -2,6 +2,7 @@ import random
import pygame import pygame
import algorithms.a_star as a_s
from algorithms.bfs import graphsearch, State from algorithms.bfs import graphsearch, State
from common.constants import * from common.constants import *
from common.helpers import castle_neighbors from common.helpers import castle_neighbors
@ -110,6 +111,9 @@ class Level:
castle_cords = (self.list_castles[0].position[0], self.list_castles[0].position[1]) castle_cords = (self.list_castles[0].position[0], self.list_castles[0].position[1])
goal_list = castle_neighbors(self.map, castle_cords[0], castle_cords[1]) # list of castle neighbors goal_list = castle_neighbors(self.map, castle_cords[0], castle_cords[1]) # list of castle neighbors
print(knight_pos_x, knight_pos_y)
st = a_s.State((knight_pos_x, knight_pos_y), a_s.UP)
print(f'Actions: {a_s.actions(st, self.map)}')
action_list = graphsearch(state, self.map, goal_list) action_list = graphsearch(state, self.map, goal_list)
print(action_list) print(action_list)