# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import dis
import inspect

from dataclasses import dataclass
from typing import Union

from . import DimList

_vmap_levels = []


@dataclass
class LevelInfo:
    level: int
    alive: bool = True


class Dim:
    def __init__(self, name: str, size: Union[None, int] = None):
        self.name = name
        self._size = None
        self._vmap_level = None
        if size is not None:
            self.size = size

    def __del__(self):
        if self._vmap_level is not None:
            _vmap_active_levels[self._vmap_stack].alive = False  # noqa: F821
            while (
                not _vmap_levels[-1].alive
                and current_level() == _vmap_levels[-1].level  # noqa: F821
            ):
                _vmap_decrement_nesting()  # noqa: F821
                _vmap_levels.pop()

    @property
    def size(self):
        assert self.is_bound
        return self._size

    @size.setter
    def size(self, size: int):
        from . import DimensionBindError

        if self._size is None:
            self._size = size
            self._vmap_level = _vmap_increment_nesting(size, "same")  # noqa: F821
            self._vmap_stack = len(_vmap_levels)
            _vmap_levels.append(LevelInfo(self._vmap_level))

        elif self._size != size:
            raise DimensionBindError(
                f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}"
            )

    @property
    def is_bound(self):
        return self._size is not None

    def __repr__(self):
        return self.name


def extract_name(inst):
    assert inst.opname == "STORE_FAST" or inst.opname == "STORE_NAME"
    return inst.argval


_cache = {}


def dims(lists=0):
    frame = inspect.currentframe()
    assert frame is not None
    calling_frame = frame.f_back
    assert calling_frame is not None
    code, lasti = calling_frame.f_code, calling_frame.f_lasti
    key = (code, lasti)
    if key not in _cache:
        first = lasti // 2 + 1
        instructions = list(dis.get_instructions(calling_frame.f_code))
        unpack = instructions[first]

        if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME":
            # just a single dim, not a list
            name = unpack.argval
            ctor = Dim if lists == 0 else DimList
            _cache[key] = lambda: ctor(name=name)
        else:
            assert unpack.opname == "UNPACK_SEQUENCE"
            ndims = unpack.argval
            names = tuple(
                extract_name(instructions[first + 1 + i]) for i in range(ndims)
            )
            first_list = len(names) - lists
            _cache[key] = lambda: tuple(
                Dim(n) if i < first_list else DimList(name=n)
                for i, n in enumerate(names)
            )
    return _cache[key]()


def _dim_set(positional, arg):
    def convert(a):
        if isinstance(a, Dim):
            return a
        else:
            assert isinstance(a, int)
            return positional[a]

    if arg is None:
        return positional
    elif not isinstance(arg, (Dim, int)):
        return tuple(convert(a) for a in arg)
    else:
        return (convert(arg),)