from dataclasses import dataclass

import numpy as np

from const import *
from typing import List, Dict, Tuple

import numpy.typing as npt


@dataclass
class Position:
    """A single grid cell addressed by its row and column index."""
    row: int  # row index of the cell
    col: int  # column index of the cell


@dataclass
class Area:
    """A rectangular region of the grid: an anchor cell plus a size."""
    # Anchor cell of the rectangle.
    # NOTE(review): presumably the top-left corner — confirm against consumers.
    position: Position
    width: int  # extent in columns (AREAS_TO_CROSS derives it from column constants)
    height: int  # extent in rows (AREAS_TO_CROSS derives it from row constants)


# Rectangular regions around the left knights' spawn, the castle and the right
# knights' spawn, built from the layout constants imported from `const`.
# NOTE(review): the literal 2 used for the castle entries is presumably the
# castle's side length in cells — confirm, and consider a named constant.
AREAS_TO_CROSS = [
    # strip above the left knights' spawn (rows 0 .. spawn's first row)
    Area(position=Position(row=0, col=0),
         width=KNIGHTS_SPAWN_WIDTH,
         height=LEFT_KNIGHTS_SPAWN_FIRST_ROW),

    # strip below the left knights' spawn (from the spawn's end to the last row)
    Area(position=Position(row=LEFT_KNIGHTS_SPAWN_FIRST_ROW + KNIGHTS_SPAWN_HEIGHT, col=0),
         width=KNIGHTS_SPAWN_WIDTH,
         height=ROWS - LEFT_KNIGHTS_SPAWN_FIRST_ROW - KNIGHTS_SPAWN_HEIGHT),

    # full-height band between the left knights' spawn and the castle
    Area(position=Position(row=0, col=KNIGHTS_SPAWN_WIDTH),
         width=CASTLE_SPAWN_FIRST_COL - KNIGHTS_SPAWN_WIDTH,
         height=ROWS),

    # strip above the castle (2 columns wide)
    Area(position=Position(row=0, col=CASTLE_SPAWN_FIRST_COL),
         width=2,
         height=CASTLE_SPAWN_FIRST_ROW),

    # strip below the castle (starts 2 rows after the castle's first row)
    Area(position=Position(row=CASTLE_SPAWN_FIRST_ROW + 2, col=CASTLE_SPAWN_FIRST_COL),
         width=2,
         height=ROWS - CASTLE_SPAWN_FIRST_ROW - 2),

    # full-height band between the castle and the right knights' spawn
    Area(position=Position(row=0, col=CASTLE_SPAWN_FIRST_COL + 2),
         width=RIGHT_KNIGHTS_SPAWN_FIRST_COL - CASTLE_SPAWN_FIRST_COL - 2,
         height=ROWS),

    # strip above the right knights' spawn (rows 0 .. spawn's first row)
    Area(position=Position(row=0, col=RIGHT_KNIGHTS_SPAWN_FIRST_COL),
         width=KNIGHTS_SPAWN_WIDTH,
         height=RIGHT_KNIGHTS_SPAWN_FIRST_ROW),

    # strip below the right knights' spawn (from the spawn's end to the last row)
    Area(position=Position(row=RIGHT_KNIGHTS_SPAWN_FIRST_ROW + KNIGHTS_SPAWN_HEIGHT, col=RIGHT_KNIGHTS_SPAWN_FIRST_COL),
         width=KNIGHTS_SPAWN_WIDTH,
         height=ROWS - RIGHT_KNIGHTS_SPAWN_FIRST_ROW - KNIGHTS_SPAWN_HEIGHT),
]


def dfs(grid: npt.NDArray, visited: Dict[Tuple[int, int], bool], position: Position, rows: int, cols: int) -> None:
    """Flood-fill ``visited`` from ``position`` over 4-connected tracked cells.

    ``visited`` maps ``(row, col)`` keys to a flag and is updated in place.
    Only cells that are already keys of ``visited`` are traversed, and a cell
    whose flag is already ``True`` is never expanded again.

    Args:
        grid: Not read by the traversal; kept for interface compatibility.
        visited: Flags of the tracked cells; reachable ones are set to ``True``.
        position: Start cell. It is always marked visited, even when it was
            not a key of ``visited`` beforehand.
        rows: Number of rows; neighbours outside ``[0, rows)`` are ignored.
        cols: Number of columns; neighbours outside ``[0, cols)`` are ignored.
    """
    # Same neighbour order as before: left, right, down, up.
    offsets = ((0, -1), (0, 1), (1, 0), (-1, 0))

    start = (position.row, position.col)
    visited[start] = True

    # An explicit stack replaces recursion so that large connected regions
    # cannot exceed the interpreter's recursion limit.
    pending = [start]
    while pending:
        row, col = pending.pop()
        for d_row, d_col in offsets:
            key = (row + d_row, col + d_col)
            # Honour the caller-supplied dimensions instead of module globals.
            if not (0 <= key[0] < rows and 0 <= key[1] < cols):
                continue
            # Untracked cells default to True here, so they are skipped.
            if not visited.get(key, True):
                visited[key] = True
                pending.append(key)


def get_islands(grid: npt.NDArray, positions: List[Position], rows: int = ROWS, cols: int = COLUMNS) -> List[Position]:
    """Return one root position for every island among ``positions``.

    An island is a maximal group of the given positions that are connected
    through 4-neighbour steps. The root of an island is the first of its
    members encountered in ``positions``.

    Args:
        grid: Forwarded unchanged to ``dfs``.
        positions: Cells to group into islands.
        rows: Row count used for bounds checks during the traversal.
        cols: Column count used for bounds checks during the traversal.

    Returns:
        The island roots, in the order their islands were discovered.
    """
    visited: Dict[Tuple[int, int], bool] = {
        (position.row, position.col): False for position in positions
    }

    roots = []
    for position in positions:
        # Already reached from an earlier root: belongs to a known island.
        if visited[(position.row, position.col)]:
            continue
        dfs(grid, visited, position, rows, cols)
        roots.append(position)

    return roots


def find_neighbours(grid: npt.NDArray, col: int, row: int) -> List[Position]:
    """Collect the in-bounds 4-neighbours of a cell that hold the grass value.

    Note the argument order: the column comes before the row.

    Args:
        grid: Map indexed as ``grid[row][col]``.
        col: Column of the centre cell.
        row: Row of the centre cell.

    Returns:
        Matching neighbours in the order up, down, left, right.
    """
    grass = MAP_ALIASES.get('GRASS')
    found = []

    for row_step, col_step in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        n_row = row + row_step
        n_col = col + col_step

        # Skip anything that falls off the map before touching the grid.
        if n_row < 0 or n_row >= ROWS:
            continue
        if n_col < 0 or n_col >= COLUMNS:
            continue

        if grid[n_row][n_col] == grass:
            found.append(Position(row=n_row, col=n_col))

    return found