# Intelegentny_Pszczelarz/.venv/Lib/site-packages/numpy/typing/tests/test_generic_alias.py
# 2023-06-19 00:49:18 +02:00
#
# 189 lines
# 6.9 KiB
# Python

from __future__ import annotations
import sys
import copy
import types
import pickle
import weakref
from typing import TypeVar, Any, Union, Callable
import pytest
import numpy as np
from numpy._typing._generic_alias import _GenericAlias
from typing_extensions import Unpack
# Type variables used to parametrize the aliases under test.
ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)
T1, T2 = TypeVar("T1"), TypeVar("T2")

# The numpy backport under test, i.e. ``np.ndarray[Any, np.dtype[ScalarType]]``
# spelled with `_GenericAlias`.
DType = _GenericAlias(np.dtype, (ScalarType,))
NDArray = _GenericAlias(np.ndarray, (Any, DType))

# NOTE: The `npt._GenericAlias` *class* is not entirely stable on
# Python >= 3.11. At runtime this is harmless, because the backport is only
# used on Python 3.8. On >= 3.9 the tests still need it, so they can check
# that its semantics match the `types.GenericAlias` it replaces.
# xref numpy/numpy#21526
if sys.version_info < (3, 9):
    # No `types.GenericAlias` exists to compare against. These are
    # placeholders; the tests only use the `*_ref` objects behind
    # Python >= 3.9 version checks.
    DType_ref = Any
    NDArray_ref = Any
    FuncType = Callable[["_GenericAlias"], Any]
else:
    # Reference objects built from the stdlib `types.GenericAlias`.
    DType_ref = types.GenericAlias(np.dtype, (ScalarType,))
    NDArray_ref = types.GenericAlias(np.ndarray, (Any, DType_ref))
    FuncType = Callable[["_GenericAlias | types.GenericAlias"], Any]

# Sorted attribute names of ``np.ndarray``, minus
# ``_GenericAlias._ATTR_EXCEPTIONS``; used to parametrize the getattr test.
GETATTR_NAMES = sorted(
    set(dir(np.ndarray)) - _GenericAlias._ATTR_EXCEPTIONS
)

# A read-only, single-element int64 array passed as ``buffer=`` when the
# aliases are called like the ``np.ndarray`` constructor.
BUFFER = np.array([1], dtype=np.int64)
BUFFER.flags.writeable = False
def _get_subclass_mro(base: type) -> tuple[type, ...]:
class Subclass(base): # type: ignore[misc,valid-type]
pass
return Subclass.__mro__[1:]
class TestGenericAlias:
    """Tests for `numpy._typing._generic_alias._GenericAlias`.

    Most tests apply the same operation to the numpy backport (`NDArray`)
    and, on Python versions that provide it, to the `types.GenericAlias`
    reference (`NDArray_ref`), then assert that both give the same result.
    """

    @pytest.mark.parametrize("name,func", [
        ("__init__", lambda n: n),
        ("__init__", lambda n: _GenericAlias(np.ndarray, Any)),
        ("__init__", lambda n: _GenericAlias(np.ndarray, (Any,))),
        ("__init__", lambda n: _GenericAlias(np.ndarray, (Any, Any))),
        ("__init__", lambda n: _GenericAlias(np.ndarray, T1)),
        ("__init__", lambda n: _GenericAlias(np.ndarray, (T1,))),
        ("__init__", lambda n: _GenericAlias(np.ndarray, (T1, T2))),
        ("__origin__", lambda n: n.__origin__),
        ("__args__", lambda n: n.__args__),
        ("__parameters__", lambda n: n.__parameters__),
        ("__mro_entries__", lambda n: n.__mro_entries__([object])),
        ("__hash__", lambda n: hash(n)),
        ("__repr__", lambda n: repr(n)),
        ("__getitem__", lambda n: n[np.float64]),
        ("__getitem__", lambda n: n[ScalarType][np.float64]),
        ("__getitem__", lambda n: n[Union[np.int64, ScalarType]][np.float64]),
        ("__getitem__", lambda n: n[Union[T1, T2]][np.float32, np.float64]),
        ("__eq__", lambda n: n == n),
        ("__ne__", lambda n: n != np.ndarray),
        ("__call__", lambda n: n((1,), np.int64, BUFFER)),
        ("__call__", lambda n: n(shape=(1,), dtype=np.int64, buffer=BUFFER)),
        ("subclassing", lambda n: _get_subclass_mro(n)),
        ("pickle", lambda n: n == pickle.loads(pickle.dumps(n))),
    ])
    def test_pass(self, name: str, func: FuncType) -> None:
        """Compare `types.GenericAlias` with its numpy-based backport.

        Check whether ``func`` runs as intended and that both `GenericAlias`
        and `_GenericAlias` return the same result.
        """
        value = func(NDArray)

        if sys.version_info >= (3, 9):
            value_ref = func(NDArray_ref)
            assert value == value_ref

    @pytest.mark.parametrize("name,func", [
        ("__copy__", lambda n: n == copy.copy(n)),
        ("__deepcopy__", lambda n: n == copy.deepcopy(n)),
    ])
    def test_copy(self, name: str, func: FuncType) -> None:
        """Compare ``copy.copy``/``copy.deepcopy`` between both alias types.

        The reference comparison only runs on Python 3.9.8+ within the 3.9
        series, or on 3.10.1+ (xref bpo-45167).
        """
        value = func(NDArray)

        # xref bpo-45167
        GE_398 = (
            sys.version_info[:2] == (3, 9) and sys.version_info >= (3, 9, 8)
        )
        if GE_398 or sys.version_info >= (3, 10, 1):
            value_ref = func(NDArray_ref)
            assert value == value_ref

    def test_dir(self) -> None:
        """Check that ``dir()`` of the backport matches `types.GenericAlias`.

        Skipped (returns early) on Python < 3.9, where no reference exists.
        """
        value = dir(NDArray)
        if sys.version_info < (3, 9):
            return

        # A number of attributes only exist in `types.GenericAlias` in >= 3.11,
        # so they are dropped from the backport's listing on older versions.
        if sys.version_info < (3, 11, 0, "beta", 3):
            value.remove("__typing_unpacked_tuple_args__")
        if sys.version_info < (3, 11, 0, "beta", 1):
            value.remove("__unpacked__")
        assert value == dir(NDArray_ref)

    @pytest.mark.parametrize("name,func,dev_version", [
        ("__iter__", lambda n: len(list(n)), ("beta", 1)),
        ("__iter__", lambda n: next(iter(n)), ("beta", 1)),
        ("__unpacked__", lambda n: n.__unpacked__, ("beta", 1)),
        ("Unpack", lambda n: Unpack[n], ("beta", 1)),

        # The right operand should now have `__unpacked__ = True`,
        # and they are thus no longer equivalent
        ("__ne__", lambda n: n != next(iter(n)), ("beta", 1)),

        # >= beta3
        ("__typing_unpacked_tuple_args__",
         lambda n: n.__typing_unpacked_tuple_args__, ("beta", 3)),

        # >= beta4
        ("__class__", lambda n: n.__class__ == type(n), ("beta", 4)),
    ])
    def test_py311_features(
        self,
        name: str,
        func: FuncType,
        dev_version: tuple[str, int],
    ) -> None:
        """Test Python 3.11 features.

        ``dev_version`` is the ``(releaselevel, serial)`` of the first 3.11
        pre-release in which `types.GenericAlias` gained the feature; the
        reference comparison only runs from that version onward.
        """
        value = func(NDArray)

        if sys.version_info >= (3, 11, 0, *dev_version):
            value_ref = func(NDArray_ref)
            assert value == value_ref

    def test_weakref(self) -> None:
        """Test ``__weakref__``."""
        value = weakref.ref(NDArray)()

        if sys.version_info >= (3, 9, 1):  # xref bpo-42332
            value_ref = weakref.ref(NDArray_ref)()
            assert value == value_ref

    @pytest.mark.parametrize("name", GETATTR_NAMES)
    def test_getattr(self, name: str) -> None:
        """Test that `getattr` wraps around the underlying type,
        aka ``__origin__``.
        """
        value = getattr(NDArray, name)
        value_ref1 = getattr(np.ndarray, name)

        if sys.version_info >= (3, 9):
            value_ref2 = getattr(NDArray_ref, name)
            assert value == value_ref1 == value_ref2
        else:
            assert value == value_ref1

    @pytest.mark.parametrize("name,exc_type,func", [
        ("__getitem__", TypeError, lambda n: n[()]),
        ("__getitem__", TypeError, lambda n: n[Any, Any]),
        ("__getitem__", TypeError, lambda n: n[Any][Any]),
        ("isinstance", TypeError, lambda n: isinstance(np.array(1), n)),
        ("issubclass", TypeError, lambda n: issubclass(np.ndarray, n)),
        ("setattr", AttributeError, lambda n: setattr(n, "__origin__", int)),
        ("setattr", AttributeError, lambda n: setattr(n, "test", int)),
        ("getattr", AttributeError, lambda n: getattr(n, "test")),
    ])
    def test_raise(
        self,
        name: str,
        exc_type: type[BaseException],
        func: FuncType,
    ) -> None:
        """Test operations that are supposed to raise.

        Both the backport and, on Python >= 3.9, the `types.GenericAlias`
        reference must raise ``exc_type``.
        """
        with pytest.raises(exc_type):
            func(NDArray)

        if sys.version_info >= (3, 9):
            with pytest.raises(exc_type):
                func(NDArray_ref)