3RNN/Lib/site-packages/optree/typing.py
2024-05-26 19:49:15 +02:00

424 lines
14 KiB
Python

# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Typing utilities for OpTree."""
from __future__ import annotations
import types
from collections.abc import Hashable
from typing import (
Any,
Callable,
DefaultDict,
Deque,
Dict,
ForwardRef,
Generic,
Iterable,
List,
NoReturn,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
from typing_extensions import NamedTuple # Generic NamedTuple: Python 3.11+
from typing_extensions import OrderedDict # Generic OrderedDict: Python 3.7.2+
from typing_extensions import Self # Python 3.11+
from typing_extensions import TypeAlias # Python 3.10+
from typing_extensions import Final, Protocol, runtime_checkable # Python 3.8+
from optree import _C
from optree._C import PyTreeKind, PyTreeSpec
__all__ = [
    # Tree specifications and the PyTree type constructors.
    'PyTreeSpec', 'PyTreeDef', 'PyTreeKind', 'PyTree', 'PyTreeTypeVar',
    # Protocol and callable signatures for user-defined pytree nodes.
    'CustomTreeNode', 'Children', 'MetaData', 'FlattenFunc', 'UnflattenFunc',
    # namedtuple introspection helpers.
    'is_namedtuple', 'is_namedtuple_instance', 'is_namedtuple_class', 'namedtuple_fields',
    # PyStructSequence introspection helpers.
    'is_structseq', 'is_structseq_instance', 'is_structseq_class', 'structseq_fields',
    # Generic type variables.
    'T', 'S', 'U', 'KT', 'VT',
    # Re-exported container type aliases.
    'Iterable', 'Sequence', 'List', 'Tuple', 'NamedTuple',
    'Dict', 'OrderedDict', 'DefaultDict', 'Deque',
]
PyTreeDef = PyTreeSpec  # alternative name for ``PyTreeSpec``

# Generic type variables exported for annotating pytree utilities.
T, S, U = TypeVar('T'), TypeVar('S'), TypeVar('U')
KT, VT = TypeVar('KT'), TypeVar('VT')  # conventionally used for key / value types

# Children of a custom pytree node: any iterable of subtrees.
Children = Iterable[T]

# Auxiliary data of a custom node: a type variable bounded by ``Hashable``, or ``None``.
_MetaData = TypeVar('_MetaData', bound=Hashable)
MetaData = Optional[_MetaData]
@runtime_checkable
class CustomTreeNode(Protocol[T]):
    """The abstract base class for custom pytree nodes.

    A class matches this protocol when it provides a ``tree_flatten`` method and a
    ``tree_unflatten`` classmethod. The protocol is ``runtime_checkable``, so
    ``isinstance`` checks look for those members.
    """

    def tree_flatten(
        self,
    ) -> (
        # 2-tuple: children are addressed by ``range(num_children)`` path entries.
        tuple[Children[T], MetaData]
        # 3-tuple: an explicit iterable of path entries (or ``None``) is supplied.
        | tuple[Children[T], MetaData, Iterable[Any] | None]
    ):
        """Split this node into its children and its auxiliary metadata."""

    @classmethod
    def tree_unflatten(cls, metadata: MetaData, children: Children[T]) -> CustomTreeNode[T]:
        """Rebuild a node of this class from its metadata and children."""
_GenericAlias = type(Union[int, str])
def _tp_cache(func: Callable) -> Callable:
import functools # pylint: disable=import-outside-toplevel
cached = functools.lru_cache()(func)
@functools.wraps(func)
def inner(*args: Any, **kwds: Any) -> Any:
try:
return cached(*args, **kwds)
except TypeError:
# All real errors (not unhashable args) are raised below.
return func(*args, **kwds)
return inner
class PyTree(Generic[T]):  # pylint: disable=too-few-public-methods
    """Generic PyTree type.

    ``PyTree[X]`` does not produce a class instance. Subscripting returns a
    ``typing.Union`` of the leaf type ``X`` and the containers a pytree can be
    built from (tuples, lists, dicts, deques and ``CustomTreeNode``). Each
    container holds a forward reference back to the tree type. ``PyTree[X, name]``
    uses ``name`` as the text of that forward reference.

    >>> import torch
    >>> from optree.typing import PyTree
    >>> TensorTree = PyTree[torch.Tensor]
    >>> TensorTree  # doctest: +IGNORE_WHITESPACE
    typing.Union[torch.Tensor,
                 typing.Tuple[ForwardRef('PyTree[torch.Tensor]'), ...],
                 typing.List[ForwardRef('PyTree[torch.Tensor]')],
                 typing.Dict[typing.Any, ForwardRef('PyTree[torch.Tensor]')],
                 typing.Deque[ForwardRef('PyTree[torch.Tensor]')],
                 optree.typing.CustomTreeNode[ForwardRef('PyTree[torch.Tensor]')]]
    """

    @_tp_cache
    def __class_getitem__(
        cls,
        item: T | tuple[T] | tuple[T, str | None],
    ) -> TypeAlias:
        """Instantiate a PyTree type with the given type.

        ``item`` is either the leaf type alone or a ``(leaf_type, name)`` pair, where
        ``name`` is a string or ``None``. If ``leaf_type`` is itself an alias produced
        here, it is returned unchanged.

        Raises:
            TypeError: If ``item`` is a tuple whose length is not 2, or whose second
                element is neither ``None`` nor a string.
        """
        args = item if isinstance(item, tuple) else (item, None)
        if len(args) != 2 or not (args[1] is None or isinstance(args[1], str)):
            raise TypeError(
                f'{cls.__name__}[...] only supports a tuple of 2 items, '
                f'a parameter and a string of type name, got {args!r}.',
            )
        param, name = args

        already_pytree = (
            isinstance(param, _GenericAlias)
            and param.__origin__ is Union  # type: ignore[attr-defined]
            and hasattr(param, '__pytree_args__')
        )
        if already_pytree:
            return param  # PyTree[PyTree[T]] -> PyTree[T]

        if name is not None:
            recurse_ref = ForwardRef(name)
        else:
            # Spell the leaf type inside ``<ClassName>[...]`` for the self-reference.
            if isinstance(param, TypeVar):
                label = param.__name__
            elif isinstance(param, type):
                if param.__module__ == 'builtins':
                    label = param.__qualname__
                else:
                    try:
                        label = f'{param.__module__}.{param.__qualname__}'
                    except AttributeError:
                        label = f'{param.__module__}.{param.__name__}'
            else:
                label = repr(param)
            recurse_ref = ForwardRef(f'{cls.__name__}[{label}]')

        pytree_alias = Union[
            param,  # type: ignore[valid-type]
            Tuple[recurse_ref, ...],  # type: ignore[valid-type] # Tuple, NamedTuple, PyStructSequence
            List[recurse_ref],  # type: ignore[valid-type]
            Dict[Any, recurse_ref],  # type: ignore[valid-type] # Dict, OrderedDict, DefaultDict
            Deque[recurse_ref],  # type: ignore[valid-type]
            CustomTreeNode[recurse_ref],  # type: ignore[valid-type]
        ]
        # Tag the alias so that nesting it inside another PyTree[...] is detected above.
        pytree_alias.__pytree_args__ = args  # type: ignore[attr-defined]
        return pytree_alias

    def __new__(cls) -> NoReturn:  # pylint: disable=arguments-differ
        """Refuse construction; ``PyTree`` is only meant to be subscripted."""
        raise TypeError('Cannot instantiate special typing classes.')

    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> NoReturn:
        """Refuse subclassing of this special typing construct."""
        raise TypeError('Cannot subclass special typing classes.')

    def __copy__(self) -> PyTree:
        """Return ``self``; the object is treated as immutable."""
        return self

    def __deepcopy__(self, memo: dict[int, Any]) -> PyTree:
        """Return ``self``; the object is treated as immutable."""
        return self
class PyTreeTypeVar:
    """Type variable for PyTree.

    ``PyTreeTypeVar(name, param)`` does not create an instance. It returns
    ``PyTree[param, name]``, the recursive union for ``param`` whose forward
    references are spelled ``name``. If ``param`` is already such an alias, it is
    returned unchanged.

    >>> import torch
    >>> from optree.typing import PyTreeTypeVar
    >>> TensorTree = PyTreeTypeVar('TensorTree', torch.Tensor)
    >>> TensorTree  # doctest: +IGNORE_WHITESPACE
    typing.Union[torch.Tensor,
                 typing.Tuple[ForwardRef('TensorTree'), ...],
                 typing.List[ForwardRef('TensorTree')],
                 typing.Dict[typing.Any, ForwardRef('TensorTree')],
                 typing.Deque[ForwardRef('TensorTree')],
                 optree.typing.CustomTreeNode[ForwardRef('TensorTree')]]
    """

    @_tp_cache
    def __new__(cls, name: str, param: type) -> TypeAlias:
        """Build the recursive PyTree alias for ``param`` labelled ``name``.

        Raises:
            TypeError: If ``name`` is not a string.
        """
        if isinstance(name, str):
            return PyTree[param, name]  # type: ignore[misc,valid-type]
        raise TypeError(f'{cls.__name__} only supports a string of type name, got {name!r}.')

    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> NoReturn:
        """Refuse subclassing of this special typing construct."""
        raise TypeError('Cannot subclass special typing classes.')

    def __copy__(self) -> TypeAlias:
        """Return ``self``; the object is treated as immutable."""
        return self

    def __deepcopy__(self, memo: dict[int, Any]) -> TypeAlias:
        """Return ``self``; the object is treated as immutable."""
        return self
# Possible return values of a flatten function: ``(children, metadata)`` or
# ``(children, metadata, path_entries)``, where ``path_entries`` may be ``None``.
_FlattenResultNoEntries = Tuple[Children[T], MetaData]
_FlattenResultWithEntries = Tuple[Children[T], MetaData, Optional[Iterable[Any]]]

# Signature of a function that flattens a custom pytree node.
FlattenFunc = Callable[
    [CustomTreeNode[T]],
    Union[_FlattenResultNoEntries, _FlattenResultWithEntries],
]
# Signature of a function that rebuilds a custom pytree node from ``(metadata, children)``.
UnflattenFunc = Callable[[MetaData, Children[T]], CustomTreeNode[T]]

# The helper aliases above are only building blocks; keep them out of the namespace.
del _FlattenResultNoEntries, _FlattenResultWithEntries
def is_namedtuple(obj: object | type) -> bool:
    """Tell whether ``obj`` is a namedtuple class or an instance of one.

    Classes are checked directly; any other object is checked through its type.
    """
    if isinstance(obj, type):
        return is_namedtuple_class(obj)
    return is_namedtuple_class(type(obj))
def is_namedtuple_instance(obj: object) -> bool:
    """Tell whether ``obj`` is an instance of a namedtuple class (classes themselves are not)."""
    obj_type = type(obj)
    return is_namedtuple_class(obj_type)
def is_namedtuple_class(cls: type) -> bool:
    """Tell whether ``cls`` is a namedtuple class.

    This is a duck-typed test, so classes built by ``collections.namedtuple`` or
    ``typing.NamedTuple`` qualify. ``cls`` must be a ``tuple`` subclass whose
    ``_fields`` is a tuple of exact ``str`` objects, and it must provide callable
    ``_make`` and ``_asdict`` attributes. Non-class arguments yield ``False``.
    """
    if not (isinstance(cls, type) and issubclass(cls, tuple)):
        return False
    fields = getattr(cls, '_fields', None)
    if not isinstance(fields, tuple):
        return False
    # Exact ``str`` type is required; ``str`` subclasses are rejected.
    if any(type(field) is not str for field in fields):  # noqa: E721 # pylint: disable=unidiomatic-typecheck
        return False
    return callable(getattr(cls, '_make', None)) and callable(getattr(cls, '_asdict', None))
def namedtuple_fields(obj: tuple | type[tuple]) -> tuple[str, ...]:
    """Return the field names of a namedtuple.

    ``obj`` may be a namedtuple class or an instance of one; the class's
    ``_fields`` tuple is returned.

    Raises:
        TypeError: If ``obj`` is neither a namedtuple class nor a namedtuple instance.
    """
    given_class = isinstance(obj, type)
    cls = obj if given_class else type(obj)
    if is_namedtuple_class(cls):
        return cls._fields  # type: ignore[attr-defined]
    if given_class:
        raise TypeError(f'Expected a collections.namedtuple type, got {cls!r}.')
    raise TypeError(f'Expected an instance of collections.namedtuple type, got {obj!r}.')
_T_co = TypeVar('_T_co', covariant=True)


class _StructSequenceMeta(type):
    # Private metaclass for ``structseq``: routes ``isinstance``/``issubclass``
    # through the duck-typed PyStructSequence predicates defined in this module.

    def __instancecheck__(cls, instance: Any) -> bool:
        """Return whether the object is a PyStructSequence instance.

        >>> import sys
        >>> isinstance(sys.float_info, structseq)
        True
        >>> isinstance((1, 2), structseq)
        False
        """
        verdict = is_structseq_instance(instance)
        return verdict

    def __subclasscheck__(cls, subclass: type) -> bool:
        """Return whether the class is a PyStructSequence type.

        A plain ``tuple`` subclass that merely defines the three size attributes is
        rejected, because it still allows subclassing.

        >>> import time
        >>> issubclass(time.struct_time, structseq)
        True
        >>> class MyTuple(tuple):
        ...     n_fields = 2
        ...     n_sequence_fields = 2
        ...     n_unnamed_fields = 0
        >>> issubclass(MyTuple, structseq)
        False
        """
        verdict = is_structseq_class(subclass)
        return verdict


# Reference: https://github.com/python/typeshed/blob/main/stdlib/_typeshed/__init__.pyi
# ``structseq`` stands in for CPython's internal PyStructSequence type, which is like,
# but subtly different from, a NamedTuple. Real PyStructSequence classes cannot be
# subclassed (typeshed marks them ``@final``), and this stub refuses subclassing too.
# pylint: disable-next=invalid-name,missing-class-docstring
class structseq(tuple, Generic[_T_co], metaclass=_StructSequenceMeta):  # type: ignore[misc] # noqa: N801
    """Generic typing stub for CPython's ``PyStructSequence`` type.

    ``isinstance``/``issubclass`` checks against it are answered by the duck-typed
    PyStructSequence predicates. The class itself can be neither subclassed nor
    constructed.
    """

    n_fields: Final[int]  # type: ignore[misc] # pylint: disable=invalid-name
    n_sequence_fields: Final[int]  # type: ignore[misc] # pylint: disable=invalid-name
    n_unnamed_fields: Final[int]  # type: ignore[misc] # pylint: disable=invalid-name

    def __init_subclass__(cls) -> NoReturn:
        """Refuse subclassing, mirroring CPython's error message for such types."""
        raise TypeError("type 'structseq' is not an acceptable base type")

    # pylint: disable-next=unused-argument,redefined-builtin
    def __new__(cls, sequence: Iterable[_T_co], dict: dict[str, Any] = ...) -> Self:
        """Always fail; this stub exists only for typing and ``isinstance`` checks."""
        raise NotImplementedError


# The metaclass is only needed to build ``structseq``; keep it out of the module namespace.
del _StructSequenceMeta
def is_structseq(obj: object | type) -> bool:
    """Tell whether ``obj`` is a PyStructSequence class or an instance of one.

    Classes are checked directly; any other object is checked through its type.
    """
    if isinstance(obj, type):
        return is_structseq_class(obj)
    return is_structseq_class(type(obj))
def is_structseq_instance(obj: object) -> bool:
    """Tell whether ``obj`` is an instance of a PyStructSequence class (classes themselves are not)."""
    obj_type = type(obj)
    return is_structseq_class(obj_type)
# Bit of ``type.__flags__`` that is set when a type allows subclassing
# (``Py_TPFLAGS_BASETYPE`` = ``1UL << 10`` in CPython's Include/object.h),
# read from the C extension so both implementations agree.
Py_TPFLAGS_BASETYPE: Final[int] = _C.Py_TPFLAGS_BASETYPE
def is_structseq_class(cls: type) -> bool:
    """Tell whether ``cls`` is a PyStructSequence class such as ``time.struct_time``.

    This is a duck-typed test. ``cls`` must be a class that derives directly from
    ``tuple`` and has integer ``n_fields``, ``n_sequence_fields`` and
    ``n_unnamed_fields`` attributes. Its type flags must also forbid subclassing.
    Non-class arguments yield ``False``.
    """
    if not isinstance(cls, type):
        return False
    # Require ``tuple`` as the sole direct base rather than any tuple subclass.
    if cls.__bases__ != (tuple,):
        return False
    for size_attr in ('n_fields', 'n_sequence_fields', 'n_unnamed_fields'):
        if not isinstance(getattr(cls, size_attr, None), int):
            return False
    # Ordinary Python classes allow subclassing, which rules out look-alikes.
    return not cls.__flags__ & Py_TPFLAGS_BASETYPE
def structseq_fields(obj: tuple | type[tuple]) -> tuple[str, ...]:
    """Return the field names of a PyStructSequence.

    ``obj`` may be the class or an instance. The names are those of the first
    ``n_sequence_fields`` member descriptors in ``vars(cls)``, in iteration order.

    Raises:
        TypeError: If ``obj`` is neither a PyStructSequence class nor an instance of one.
    """
    given_class = isinstance(obj, type)
    cls = obj if given_class else type(obj)
    if not is_structseq_class(cls):
        if given_class:
            raise TypeError(f'Expected a PyStructSequence type, got {cls!r}.')
        raise TypeError(f'Expected an instance of PyStructSequence type, got {obj!r}.')

    limit: int = cls.n_sequence_fields  # type: ignore[attr-defined]
    descriptor_names = (
        attr_name
        for attr_name, attr_value in vars(cls).items()
        if isinstance(attr_value, types.MemberDescriptorType)
    )
    # zip() stops once range(limit) is exhausted, so at most ``limit`` names are taken.
    return tuple(attr_name for _, attr_name in zip(range(limit), descriptor_names))
# Ensure that the behavior is consistent with C++ implementation
# pylint: disable-next=wrong-import-position,ungrouped-imports
from optree._C import (
is_namedtuple,
is_namedtuple_class,
is_namedtuple_instance,
is_structseq,
is_structseq_class,
is_structseq_instance,
namedtuple_fields,
structseq_fields,
)