61 lines
1.7 KiB
Python
61 lines
1.7 KiB
Python
from typing import Iterator, Protocol, List, TypeVar, Tuple, Dict
|
|
|
|
from PatchType import PatchType
|
|
|
|
# A grid cell, addressed by a pair of integer (x, y) coordinates.
GridLocation = Tuple[int, int]

# Generic node identifier used in the graph protocol signatures.
Location = TypeVar('Location')
|
|
|
|
|
|
class Graph(Protocol):
    """Structural interface for objects that can list a node's neighbours."""

    def neighbors(self, id: Location) -> List[Location]:
        """Return the nodes adjacent to ``id``.

        This is a protocol stub with no behaviour of its own; calling it
        directly returns ``None``.
        """
        ...
|
|
|
|
|
|
class SquareGrid:
    """Rectangular grid of ``(x, y)`` cells with 4-way connectivity.

    A cell is valid when ``0 <= x < width`` and ``0 <= y < height``. Cells
    listed in ``walls`` are excluded from ``neighbors``. ``puddles`` and
    ``packingStations`` are stored here but not consulted by any method of
    this class.
    """

    def __init__(self, width: int, height: int):
        """Create a ``width`` by ``height`` grid whose feature lists are empty."""
        self.width = width
        self.height = height
        # Cells that cannot be entered (see ``passable``).
        self.walls: List[GridLocation] = []
        # Not used by this class's own methods.
        self.puddles: List[GridLocation] = []
        # (patch type, cell) pairs; not used by this class's own methods.
        self.packingStations: List[tuple[PatchType, GridLocation]] = []

    def in_bounds(self, id: GridLocation) -> bool:
        """Return True when ``id`` lies inside the grid rectangle."""
        col, row = id
        if not 0 <= col < self.width:
            return False
        return 0 <= row < self.height

    def passable(self, id: GridLocation) -> bool:
        """Return True unless ``id`` is one of the cells in ``self.walls``."""
        is_wall = id in self.walls
        return not is_wall

    def neighbors(self, id: GridLocation) -> Iterator[GridLocation]:
        """Lazily yield the in-bounds, non-wall cells orthogonally adjacent to ``id``.

        The base order is (x+1, y), (x-1, y), (x, y-1), (x, y+1). It is
        reversed when ``x + y`` is even, so the order alternates in a
        checkerboard pattern across the grid.
        NOTE(review): presumably this varies tie-breaking for search
        algorithms that consume the result -- confirm with callers.
        """
        x, y = id
        candidates = [(x + 1, y), (x - 1, y), (x, y - 1), (x, y + 1)]
        if (x + y) % 2 == 0:
            candidates = candidates[::-1]
        # Bounds are checked first, then walls; both filters are lazy.
        return filter(self.passable, filter(self.in_bounds, candidates))
|
|
|
|
|
|
class WeightedGraph(Graph, Protocol):
    """Protocol for graphs that also report a cost for each step.

    ``Protocol`` is listed explicitly as a base. A class that only inherits
    from a protocol class is an ordinary class, not a protocol, so type
    checkers would not match it structurally against classes that merely
    provide ``neighbors`` and ``cost``.
    """

    def cost(self, from_id: Location, to_id: Location) -> float:
        """Return the cost of moving from ``from_id`` to ``to_id``.

        This is a protocol stub with no behaviour of its own.
        """
        ...
|
|
|
|
|
|
class GridWithWeights(SquareGrid):
    """A ``SquareGrid`` whose cells may carry an individual entry cost."""

    def __init__(self, width: int, height: int):
        """Build the underlying grid and start with an empty weight table."""
        super().__init__(width, height)
        # Maps a cell to the cost of entering it; cells absent from the
        # table cost 1 (see ``cost``).
        self.weights: Dict[GridLocation, float] = {}

    def cost(self, from_node: GridLocation, to_node: GridLocation) -> float:
        """Return the cost of stepping onto ``to_node``.

        Only the destination is looked up; ``from_node`` is accepted but not
        used. Destinations without an entry in ``self.weights`` cost 1.
        """
        default_cost = 1
        weight_table = self.weights
        return weight_table.get(to_node, default_cost)
|
|
|
|
|
|
# utility functions for dealing with square grids
|
|
def from_id_width(id, width):
    """Convert a flat index into an ``(x, y)`` pair for rows of length ``width``.

    ``x`` is the remainder of ``id`` divided by ``width`` and ``y`` is the
    floored quotient.
    """
    row, col = divmod(id, width)
    return (col, row)
|
|
|
|
|
|
def inverse_y(height, y):
    """Return ``height - y``, i.e. ``y`` measured from the opposite edge."""
    flipped = height - y
    return flipped
|
|
|
|
|
|
# NOTE(review): a class named ``Graph`` with the same single method is also
# declared earlier in this file; this later declaration rebinds the
# module-level name ``Graph``.
class Graph(Protocol):
    """Protocol describing an object that can enumerate a node's neighbours."""

    def neighbors(self, id: Location) -> List[Location]:
        """Return the nodes adjacent to ``id`` (stub; no behaviour here)."""
        ...
|