"""A mypy_ plugin for managing a number of platform-specific annotations. Its functionality can be split into three distinct parts: * Assigning the (platform-dependent) precisions of certain `~numpy.number` subclasses, including the likes of `~numpy.int_`, `~numpy.intp` and `~numpy.longlong`. See the documentation on :ref:`scalar types ` for a comprehensive overview of the affected classes. Without the plugin the precision of all relevant classes will be inferred as `~typing.Any`. * Removing all extended-precision `~numpy.number` subclasses that are unavailable for the platform in question. Most notably this includes the likes of `~numpy.float128` and `~numpy.complex256`. Without the plugin *all* extended-precision types will, as far as mypy is concerned, be available to all platforms. * Assigning the (platform-dependent) precision of `~numpy.ctypeslib.c_intp`. Without the plugin the type will default to `ctypes.c_int64`. .. versionadded:: 1.22 Examples -------- To enable the plugin, one must add it to their mypy `configuration file`_: .. code-block:: ini [mypy] plugins = numpy.typing.mypy_plugin .. _mypy: http://mypy-lang.org/ .. 
_configuration file: https://mypy.readthedocs.io/en/stable/config_file.html """ from __future__ import annotations from collections.abc import Iterable from typing import Final, TYPE_CHECKING, Callable import numpy as np try: import mypy.types from mypy.types import Type from mypy.plugin import Plugin, AnalyzeTypeContext from mypy.nodes import MypyFile, ImportFrom, Statement from mypy.build import PRI_MED _HookFunc = Callable[[AnalyzeTypeContext], Type] MYPY_EX: None | ModuleNotFoundError = None except ModuleNotFoundError as ex: MYPY_EX = ex __all__: list[str] = [] def _get_precision_dict() -> dict[str, str]: names = [ ("_NBitByte", np.byte), ("_NBitShort", np.short), ("_NBitIntC", np.intc), ("_NBitIntP", np.intp), ("_NBitInt", np.int_), ("_NBitLongLong", np.longlong), ("_NBitHalf", np.half), ("_NBitSingle", np.single), ("_NBitDouble", np.double), ("_NBitLongDouble", np.longdouble), ] ret = {} for name, typ in names: n: int = 8 * typ().dtype.itemsize ret[f'numpy._typing._nbit.{name}'] = f"numpy._{n}Bit" return ret def _get_extended_precision_list() -> list[str]: extended_names = [ "uint128", "uint256", "int128", "int256", "float80", "float96", "float128", "float256", "complex160", "complex192", "complex256", "complex512", ] return [i for i in extended_names if hasattr(np, i)] def _get_c_intp_name() -> str: # Adapted from `np.core._internal._getintp_ctype` char = np.dtype('p').char if char == 'i': return "c_int" elif char == 'l': return "c_long" elif char == 'q': return "c_longlong" else: return "c_long" #: A dictionary mapping type-aliases in `numpy._typing._nbit` to #: concrete `numpy.typing.NBitBase` subclasses. _PRECISION_DICT: Final = _get_precision_dict() #: A list with the names of all extended precision `np.number` subclasses. 
_EXTENDED_PRECISION_LIST: Final = _get_extended_precision_list()

#: The name of the ctypes equivalent of `np.intp`
_C_INTP: Final = _get_c_intp_name()


def _hook(ctx: AnalyzeTypeContext) -> Type:
    """Swap a `numpy._typing._nbit` alias for a concrete ``NBitBase`` subclass.

    The last dotted component of the analyzed type's name is looked up in
    `_PRECISION_DICT`, and the resulting fully qualified name is resolved
    through the analyzer API carried by `ctx`.
    """
    unanalyzed, _, analyzer = ctx
    alias = unanalyzed.name.split(".")[-1]
    concrete = _PRECISION_DICT[f"numpy._typing._nbit.{alias}"]
    return analyzer.named_type(concrete)


if TYPE_CHECKING or MYPY_EX is None:
    def _index(iterable: Iterable[Statement], id: str) -> int:
        """Return the position of the first statement whose ``id`` is `id`.

        Statements without an ``id`` attribute never match. A `ValueError`
        is raised when no statement in `iterable` matches.
        """
        hits = (
            position
            for position, statement in enumerate(iterable)
            if getattr(statement, "id", None) == id
        )
        first = next(hits, None)
        if first is None:
            raise ValueError("Failed to identify a `ImportFrom` instance "
                             f"with the following id: {id!r}")
        return first

    def _override_imports(
        file: MypyFile,
        module: str,
        imports: list[tuple[str, None | str]],
    ) -> None:
        """Swap the first `module`-based import of `file` for new `imports`.

        A fresh ``from module import ...`` statement is built from `imports`
        and written over the first matching statement in both ``file.defs``
        and ``file.imports``.
        """
        replacement = ImportFrom(module, 0, names=imports)
        replacement.is_top_level = True

        # Both statement lists must be patched, `file.defs` first.
        for statements in [file.defs, file.imports]:  # type: list[Statement]
            position = _index(statements, module)
            statements[position] = replacement

    class _NumpyPlugin(Plugin):
        """A mypy plugin for handling various numpy-specific typing tasks."""

        def get_type_analyze_hook(self, fullname: str) -> None | _HookFunc:
            """Set the precision of platform-specific `numpy.number`
            subclasses.

            For example: `numpy.int_`, `numpy.longlong` and `numpy.longdouble`.

            Returns `_hook` when `fullname` is a key of `_PRECISION_DICT`
            and ``None`` otherwise.
            """
            return _hook if fullname in _PRECISION_DICT else None

        def get_additional_deps(
            self, file: MypyFile
        ) -> list[tuple[int, str, int]]:
            """Handle all import-based overrides.

            * Import platform-specific extended-precision `numpy.number`
              subclasses (*e.g.* `numpy.float96`, `numpy.float128` and
              `numpy.complex256`).
            * Import the appropriate `ctypes` equivalent to `numpy.intp`.

            Every file gets a single ``(PRI_MED, file.fullname, -1)`` entry
            back; only ``numpy`` and ``numpy.ctypeslib`` have their imports
            rewritten along the way.
            """
            module_name = file.fullname

            if module_name == "numpy":
                available = [(name, name) for name in _EXTENDED_PRECISION_LIST]
                _override_imports(
                    file,
                    "numpy._typing._extended_precision",
                    imports=available,
                )
            elif module_name == "numpy.ctypeslib":
                _override_imports(
                    file,
                    "ctypes",
                    imports=[(_C_INTP, "_c_intp")],
                )

            return [(PRI_MED, module_name, -1)]

    def plugin(version: str) -> type[_NumpyPlugin]:
        """An entry-point for mypy."""
        return _NumpyPlugin

else:
    def plugin(version: str) -> type[_NumpyPlugin]:
        """An entry-point for mypy.

        mypy could not be imported, so the stored import error is re-raised.
        """
        raise MYPY_EX