424 lines
14 KiB
Python
424 lines
14 KiB
Python
# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Typing utilities for OpTree."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import types
|
|
from collections.abc import Hashable
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
DefaultDict,
|
|
Deque,
|
|
Dict,
|
|
ForwardRef,
|
|
Generic,
|
|
Iterable,
|
|
List,
|
|
NoReturn,
|
|
Optional,
|
|
Sequence,
|
|
Tuple,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
from typing_extensions import NamedTuple # Generic NamedTuple: Python 3.11+
|
|
from typing_extensions import OrderedDict # Generic OrderedDict: Python 3.7.2+
|
|
from typing_extensions import Self # Python 3.11+
|
|
from typing_extensions import TypeAlias # Python 3.10+
|
|
from typing_extensions import Final, Protocol, runtime_checkable # Python 3.8+
|
|
|
|
from optree import _C
|
|
from optree._C import PyTreeKind, PyTreeSpec
|
|
|
|
|
|
# Names exported by ``from optree.typing import *``.
__all__ = [
    # Tree specification types imported from `optree._C`, plus the `PyTreeDef` alias
    'PyTreeSpec',
    'PyTreeDef',
    'PyTreeKind',
    # Generic pytree type constructors and the custom node protocol
    'PyTree',
    'PyTreeTypeVar',
    'CustomTreeNode',
    'Children',
    'MetaData',
    'FlattenFunc',
    'UnflattenFunc',
    # namedtuple helpers
    'is_namedtuple',
    'is_namedtuple_instance',
    'is_namedtuple_class',
    'namedtuple_fields',
    # PyStructSequence helpers
    'is_structseq',
    'is_structseq_instance',
    'is_structseq_class',
    'structseq_fields',
    # Type variables defined in this module
    'T',
    'S',
    'U',
    'KT',
    'VT',
    # Re-exported names from `typing` / `typing_extensions`
    'Iterable',
    'Sequence',
    'List',
    'Tuple',
    'NamedTuple',
    'Dict',
    'OrderedDict',
    'DefaultDict',
    'Deque',
]
|
|
|
|
|
|
PyTreeDef = PyTreeSpec  # alias

# Generic type variables, exported through `__all__`.
T = TypeVar('T')
S = TypeVar('S')
U = TypeVar('U')
KT = TypeVar('KT')  # presumably "key type" -- confirm against usages outside this module
VT = TypeVar('VT')  # presumably "value type" -- confirm against usages outside this module


# The children of a node: any iterable of `T`.
Children = Iterable[T]
# Type variable for node metadata, bounded by `Hashable`.
_MetaData = TypeVar('_MetaData', bound=Hashable)
# Metadata is optional: either a hashable value or `None`.
MetaData = Optional[_MetaData]
|
|
|
|
|
|
@runtime_checkable
class CustomTreeNode(Protocol[T]):
    """Structural protocol that user-defined pytree node classes implement.

    A conforming class provides :meth:`tree_flatten` to break an instance into its
    children plus metadata, and :meth:`tree_unflatten` to rebuild an instance from them.
    """

    def tree_flatten(
        self,
    ) -> (
        # Two-element form: `range(num_children)` is used as the path entries
        tuple[Children[T], MetaData]
        |
        # Three-element form: path entries are supplied explicitly (or `None`)
        tuple[Children[T], MetaData, Iterable[Any] | None]
    ):
        """Split this node into ``(children, metadata)`` or ``(children, metadata, entries)``."""

    @classmethod
    def tree_unflatten(cls, metadata: MetaData, children: Children[T]) -> CustomTreeNode[T]:
        """Rebuild a node of this class from its metadata and children."""
|
|
|
|
|
|
# Runtime class of a subscripted `typing.Union[...]` object; used by `PyTree.__class_getitem__`
# in an `isinstance` check to recognise parameters that are already union aliases.
_GenericAlias = type(Union[int, str])
|
|
|
|
|
|
def _tp_cache(func: Callable) -> Callable:
|
|
import functools # pylint: disable=import-outside-toplevel
|
|
|
|
cached = functools.lru_cache()(func)
|
|
|
|
@functools.wraps(func)
|
|
def inner(*args: Any, **kwds: Any) -> Any:
|
|
try:
|
|
return cached(*args, **kwds)
|
|
except TypeError:
|
|
# All real errors (not unhashable args) are raised below.
|
|
return func(*args, **kwds)
|
|
|
|
return inner
|
|
|
|
|
|
class PyTree(Generic[T]):  # pylint: disable=too-few-public-methods
    """Generic PyTree type.

    Subscripting the class builds a recursive ``typing.Union`` alias over the leaf type
    and the supported container types:

    >>> import torch
    >>> from optree.typing import PyTree
    >>> TensorTree = PyTree[torch.Tensor]
    >>> TensorTree  # doctest: +IGNORE_WHITESPACE
    typing.Union[torch.Tensor,
                 typing.Tuple[ForwardRef('PyTree[torch.Tensor]'), ...],
                 typing.List[ForwardRef('PyTree[torch.Tensor]')],
                 typing.Dict[typing.Any, ForwardRef('PyTree[torch.Tensor]')],
                 typing.Deque[ForwardRef('PyTree[torch.Tensor]')],
                 optree.typing.CustomTreeNode[ForwardRef('PyTree[torch.Tensor]')]]
    """

    @_tp_cache
    def __class_getitem__(
        cls,
        item: T | tuple[T] | tuple[T, str | None],
    ) -> TypeAlias:
        """Build the PyTree union alias for ``item``.

        ``item`` is either a bare leaf type or a ``(leaf_type, name)`` pair, where
        ``name`` (a string or ``None``) is used as the forward-reference name of the
        recursive alias.

        Raises:
            TypeError: If ``item`` is a tuple that is not of length 2, or if the
                second element is neither ``None`` nor a string.
        """
        # Normalize a bare parameter into a `(parameter, name)` pair
        pair = item if isinstance(item, tuple) else (item, None)
        if len(pair) != 2 or not (pair[1] is None or isinstance(pair[1], str)):
            raise TypeError(
                f'{cls.__name__}[...] only supports a tuple of 2 items, '
                f'a parameter and a string of type name, got {pair!r}.',
            )
        param, name = pair

        already_pytree = (
            isinstance(param, _GenericAlias)
            and param.__origin__ is Union  # type: ignore[attr-defined]
            and hasattr(param, '__pytree_args__')
        )
        if already_pytree:
            return param  # PyTree[PyTree[T]] -> PyTree[T]

        # Work out the text of the self-referencing forward reference
        if name is not None:
            ref_text = name
        elif isinstance(param, TypeVar):
            ref_text = f'{cls.__name__}[{param.__name__}]'
        elif isinstance(param, type):
            if param.__module__ == 'builtins':
                typename = param.__qualname__
            else:
                try:
                    typename = f'{param.__module__}.{param.__qualname__}'
                except AttributeError:
                    typename = f'{param.__module__}.{param.__name__}'
            ref_text = f'{cls.__name__}[{typename}]'
        else:
            ref_text = f'{cls.__name__}[{param!r}]'
        recurse_ref = ForwardRef(ref_text)

        pytree_alias = Union[
            param,  # type: ignore[valid-type]
            Tuple[recurse_ref, ...],  # type: ignore[valid-type] # Tuple, NamedTuple, PyStructSequence
            List[recurse_ref],  # type: ignore[valid-type]
            Dict[Any, recurse_ref],  # type: ignore[valid-type] # Dict, OrderedDict, DefaultDict
            Deque[recurse_ref],  # type: ignore[valid-type]
            CustomTreeNode[recurse_ref],  # type: ignore[valid-type]
        ]
        # Tag the alias so that subscripting `PyTree` with it again returns it unchanged
        pytree_alias.__pytree_args__ = pair  # type: ignore[attr-defined]
        return pytree_alias

    def __new__(cls) -> NoReturn:  # pylint: disable=arguments-differ
        """Always raise ``TypeError``: this class cannot be instantiated."""
        raise TypeError('Cannot instantiate special typing classes.')

    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> NoReturn:
        """Always raise ``TypeError``: this class cannot be subclassed."""
        raise TypeError('Cannot subclass special typing classes.')

    def __copy__(self) -> PyTree:
        """Return the object itself (shallow copy is the identity)."""
        return self

    def __deepcopy__(self, memo: dict[int, Any]) -> PyTree:
        """Return the object itself (deep copy is the identity)."""
        return self
|
|
|
|
|
|
class PyTreeTypeVar:
    """Type variable for PyTree.

    Calling the class returns the same kind of alias as ``PyTree[param, name]``, with
    ``name`` used for the recursive forward references:

    >>> import torch
    >>> from optree.typing import PyTreeTypeVar
    >>> TensorTree = PyTreeTypeVar('TensorTree', torch.Tensor)
    >>> TensorTree  # doctest: +IGNORE_WHITESPACE
    typing.Union[torch.Tensor,
                 typing.Tuple[ForwardRef('TensorTree'), ...],
                 typing.List[ForwardRef('TensorTree')],
                 typing.Dict[typing.Any, ForwardRef('TensorTree')],
                 typing.Deque[ForwardRef('TensorTree')],
                 optree.typing.CustomTreeNode[ForwardRef('TensorTree')]]
    """

    @_tp_cache
    def __new__(cls, name: str, param: type) -> TypeAlias:
        """Return ``PyTree[param, name]`` after checking that ``name`` is a string.

        Raises:
            TypeError: If ``name`` is not a ``str``.
        """
        if isinstance(name, str):
            return PyTree[param, name]  # type: ignore[misc,valid-type]
        raise TypeError(f'{cls.__name__} only supports a string of type name, got {name!r}.')

    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> NoReturn:
        """Always raise ``TypeError``: this class cannot be subclassed."""
        raise TypeError('Cannot subclass special typing classes.')

    def __copy__(self) -> TypeAlias:
        """Return the object itself (shallow copy is the identity)."""
        return self

    def __deepcopy__(self, memo: dict[int, Any]) -> TypeAlias:
        """Return the object itself (deep copy is the identity)."""
        return self
|
|
|
|
|
|
# Signature of a flatten function: takes a custom node and returns either
# `(children, metadata)` or `(children, metadata, entries)`, where `entries` may be `None`.
FlattenFunc = Callable[
    [CustomTreeNode[T]],
    Union[
        Tuple[Children[T], MetaData],
        Tuple[Children[T], MetaData, Optional[Iterable[Any]]],
    ],
]
# Signature of an unflatten function: takes `(metadata, children)` and returns a custom node.
UnflattenFunc = Callable[[MetaData, Children[T]], CustomTreeNode[T]]
|
|
|
|
|
|
def is_namedtuple(obj: object | type) -> bool:
    """Check whether ``obj`` is a namedtuple class or an instance of one.

    A class is tested as-is; any other object is tested through its type.
    """
    if isinstance(obj, type):
        return is_namedtuple_class(obj)
    return is_namedtuple_class(type(obj))
|
|
|
|
|
|
def is_namedtuple_instance(obj: object) -> bool:
    """Check whether the type of ``obj`` is a namedtuple class."""
    obj_type = type(obj)
    return is_namedtuple_class(obj_type)
|
|
|
|
|
|
def is_namedtuple_class(cls: type) -> bool:
    """Check whether ``cls`` looks like a namedtuple class.

    The check is structural. ``cls`` must be a class derived from ``tuple`` whose
    ``_fields`` attribute is a tuple containing only exact ``str`` objects, and it
    must expose callable ``_make`` and ``_asdict`` attributes.
    """
    if not (isinstance(cls, type) and issubclass(cls, tuple)):
        return False

    field_names = getattr(cls, '_fields', None)
    if not isinstance(field_names, tuple):
        return False

    # Field names must be exactly `str`, not a subclass of it
    # pylint: disable-next=unidiomatic-typecheck
    if any(type(field) is not str for field in field_names):  # noqa: E721
        return False

    return callable(getattr(cls, '_make', None)) and callable(getattr(cls, '_asdict', None))
|
|
|
|
|
|
def namedtuple_fields(obj: tuple | type[tuple]) -> tuple[str, ...]:
    """Return the ``_fields`` tuple of a namedtuple class or instance.

    Raises:
        TypeError: If ``obj`` is a class that is not a namedtuple class, or an
            object whose type is not a namedtuple class.
    """
    given_class = isinstance(obj, type)
    cls = obj if given_class else type(obj)
    if not is_namedtuple_class(cls):
        # The message depends on whether a class or an instance was passed in
        if given_class:
            raise TypeError(f'Expected a collections.namedtuple type, got {cls!r}.')
        raise TypeError(f'Expected an instance of collections.namedtuple type, got {obj!r}.')
    return cls._fields  # type: ignore[attr-defined]
|
|
|
|
|
|
# Covariant type variable used as the element type of the generic `structseq` stub.
_T_co = TypeVar('_T_co', covariant=True)
|
|
|
|
|
|
class _StructSequenceMeta(type):
    """Metaclass that routes ``issubclass`` / ``isinstance`` checks to the structseq predicates."""

    def __subclasscheck__(cls, subclass: type) -> bool:
        """Report whether ``subclass`` is a PyStructSequence type.

        >>> import time
        >>> issubclass(time.struct_time, structseq)
        True
        >>> class MyTuple(tuple):
        ...     n_fields = 2
        ...     n_sequence_fields = 2
        ...     n_unnamed_fields = 0
        >>> issubclass(MyTuple, structseq)
        False
        """
        verdict = is_structseq_class(subclass)
        return verdict

    def __instancecheck__(cls, instance: Any) -> bool:
        """Report whether ``instance`` is a PyStructSequence instance.

        >>> import sys
        >>> isinstance(sys.float_info, structseq)
        True
        >>> isinstance((1, 2), structseq)
        False
        """
        verdict = is_structseq_instance(instance)
        return verdict
|
|
|
|
|
|
# Reference: https://github.com/python/typeshed/blob/main/stdlib/_typeshed/__init__.pyi
# `PyStructSequence` is a CPython-internal type that resembles a NamedTuple but differs
# from it in subtle ways. Concrete `structseq` classes cannot be subclassed, so they are
# all decorated with `@final`.
# pylint: disable-next=invalid-name,missing-class-docstring
class structseq(tuple, Generic[_T_co], metaclass=_StructSequenceMeta):  # type: ignore[misc] # noqa: N801
    """Generic stub standing in for CPython's ``PyStructSequence`` type.

    ``isinstance`` and ``issubclass`` checks against this class are handled by its metaclass.
    """

    # Declared for type checkers only; no values are assigned here
    n_fields: Final[int]  # type: ignore[misc] # pylint: disable=invalid-name
    n_sequence_fields: Final[int]  # type: ignore[misc] # pylint: disable=invalid-name
    n_unnamed_fields: Final[int]  # type: ignore[misc] # pylint: disable=invalid-name

    def __init_subclass__(cls) -> NoReturn:
        """Always raise ``TypeError``: this class cannot be subclassed."""
        raise TypeError("type 'structseq' is not an acceptable base type")

    # pylint: disable-next=unused-argument,redefined-builtin
    def __new__(cls, sequence: Iterable[_T_co], dict: dict[str, Any] = ...) -> Self:
        """Always raise ``NotImplementedError``: the stub cannot be instantiated."""
        raise NotImplementedError
|
|
|
|
|
|
# Drop the helper metaclass name from the module namespace; `structseq` was already
# created with it as its metaclass above.
del _StructSequenceMeta
|
|
|
|
|
|
def is_structseq(obj: object | type) -> bool:
    """Check whether ``obj`` is a PyStructSequence class or an instance of one.

    A class is tested as-is; any other object is tested through its type.
    """
    if isinstance(obj, type):
        return is_structseq_class(obj)
    return is_structseq_class(type(obj))
|
|
|
|
|
|
def is_structseq_instance(obj: object) -> bool:
    """Check whether the type of ``obj`` is a PyStructSequence class."""
    obj_type = type(obj)
    return is_structseq_class(obj_type)
|
|
|
|
|
|
# Type flag bit that is set when a type allows subclassing (see CPython's Include/object.h).
# `is_structseq_class` tests `cls.__flags__` against it.
Py_TPFLAGS_BASETYPE = _C.Py_TPFLAGS_BASETYPE  # (1UL << 10)
|
|
|
|
|
|
def is_structseq_class(cls: type) -> bool:
    """Check whether ``cls`` is a PyStructSequence class.

    ``cls`` qualifies when it is a class whose only direct base is ``tuple``, it has
    integer ``n_fields``, ``n_sequence_fields`` and ``n_unnamed_fields`` attributes,
    and its type flags do not permit subclassing.
    """
    if not isinstance(cls, type):
        return False

    # Require direct inheritance from `tuple`, which is stricter than `issubclass(cls, tuple)`
    if cls.__bases__ != (tuple,):
        return False

    # The PyStructSequence members must all be present as integers
    for member in ('n_fields', 'n_sequence_fields', 'n_unnamed_fields'):
        if not isinstance(getattr(cls, member, None), int):
            return False

    # The type must not allow subclassing
    return not (cls.__flags__ & Py_TPFLAGS_BASETYPE)
|
|
|
|
|
|
def structseq_fields(obj: tuple | type[tuple]) -> tuple[str, ...]:
    """Return the sequence field names of a PyStructSequence class or instance.

    The names are those of the member descriptors found in the class namespace, in
    definition order, limited to the first ``n_sequence_fields`` of them.

    Raises:
        TypeError: If ``obj`` is a class that is not a PyStructSequence class, or an
            object whose type is not a PyStructSequence class.
    """
    given_class = isinstance(obj, type)
    cls = obj if given_class else type(obj)
    if not is_structseq_class(cls):
        # The message depends on whether a class or an instance was passed in
        if given_class:
            raise TypeError(f'Expected a PyStructSequence type, got {cls!r}.')
        raise TypeError(f'Expected an instance of PyStructSequence type, got {obj!r}.')

    n_sequence_fields: int = cls.n_sequence_fields  # type: ignore[attr-defined]
    descriptor_names = [
        name
        for name, member in vars(cls).items()
        if isinstance(member, types.MemberDescriptorType)
    ]
    # Clamp at zero so that a non-positive count yields no fields
    return tuple(descriptor_names[: max(n_sequence_fields, 0)])
|
|
|
|
|
|
# Ensure that the behavior is consistent with C++ implementation
|
|
# pylint: disable-next=wrong-import-position,ungrouped-imports
|
|
from optree._C import (
|
|
is_namedtuple,
|
|
is_namedtuple_class,
|
|
is_namedtuple_instance,
|
|
is_structseq,
|
|
is_structseq_class,
|
|
is_structseq_instance,
|
|
namedtuple_fields,
|
|
structseq_fields,
|
|
)
|