"""Graph protocols and square-grid graph implementations.

``SquareGrid`` models a ``width`` x ``height`` grid of integer ``(x, y)``
cells with 4-way connectivity and impassable walls; ``GridWithWeights``
adds a per-cell cost for entering a cell.
"""
from typing import Dict, Iterable, Iterator, List, Protocol, Tuple, TypeVar

from PatchType import PatchType

GridLocation = Tuple[int, int]
Location = TypeVar('Location')


class Graph(Protocol):
    """Structural interface for anything that can enumerate neighbours."""

    # Returns Iterable (not List) so that implementations returning a lazy
    # iterator, such as SquareGrid.neighbors, satisfy this protocol.
    def neighbors(self, id: Location) -> Iterable[Location]:
        """Return the nodes directly reachable from ``id``."""
        ...


class SquareGrid:
    """A ``width`` x ``height`` grid of cells with 4-way movement.

    ``walls`` lists impassable cells and is consulted by ``passable``.
    ``puddles`` and ``packingStations`` are only stored here; nothing in
    this class reads them.
    """

    def __init__(self, width: int, height: int):
        self.width = width
        self.height = height
        self.walls: List[GridLocation] = []
        self.puddles: List[GridLocation] = []
        # (patch type, cell) pairs; filled in by callers.
        self.packingStations: List[Tuple[PatchType, GridLocation]] = []

    def in_bounds(self, id: GridLocation) -> bool:
        """Return True if ``id`` lies within ``[0, width) x [0, height)``."""
        (x, y) = id
        return 0 <= x < self.width and 0 <= y < self.height

    def passable(self, id: GridLocation) -> bool:
        """Return True if ``id`` is not listed in ``walls``."""
        # NOTE: ``walls`` is a list, so this is a linear scan per call.
        return id not in self.walls

    def neighbors(self, id: GridLocation) -> Iterator[GridLocation]:
        """Return a lazy iterator over the in-bounds, passable neighbours of ``id``.

        Candidates are ``(x+1, y), (x-1, y), (x, y-1), (x, y+1)``. On cells
        where ``x + y`` is even the order is reversed, presumably to vary
        tie-breaking between adjacent cells. Filtering happens lazily, so
        the bounds and wall checks run while the result is iterated.
        """
        (x, y) = id
        neighbors = [(x + 1, y), (x - 1, y), (x, y - 1), (x, y + 1)]
        if (x + y) % 2 == 0:
            neighbors.reverse()
        results = filter(self.in_bounds, neighbors)
        results = filter(self.passable, results)
        return results


# Listing Protocol explicitly keeps WeightedGraph a protocol (PEP 544);
# without it, it would be a nominal class that GridWithWeights does not
# inherit from.
class WeightedGraph(Graph, Protocol):
    """Graph that also reports the cost of moving between two nodes."""

    def cost(self, from_id: Location, to_id: Location) -> float:
        """Return the cost of stepping from ``from_id`` to ``to_id``."""
        ...


class GridWithWeights(SquareGrid):
    """SquareGrid whose movement cost is looked up per destination cell."""

    def __init__(self, width: int, height: int):
        super().__init__(width, height)
        # Explicit per-cell costs; cells absent from the dict cost 1.
        self.weights: Dict[GridLocation, float] = {}

    def cost(self, from_node: GridLocation, to_node: GridLocation) -> float:
        """Return the cost of entering ``to_node``; ``from_node`` is ignored."""
        return self.weights.get(to_node, 1)


# utility functions for dealing with square grids
def from_id_width(id: int, width: int) -> GridLocation:
    """Convert a row-major linear index into an ``(x, y)`` cell.

    >>> from_id_width(7, 3)
    (1, 2)
    """
    return (id % width, id // width)


def inverse_y(height, y):
    """Return ``height - y``, i.e. ``y`` mirrored vertically.

    NOTE(review): for a row index ``y`` in ``[0, height)`` this yields values
    in ``[1, height]`` rather than ``[0, height)``. If a row index is
    intended, ``height - 1 - y`` may be what callers need; confirm before
    changing.
    """
    return height - y