"""
Contains utility functions for working with nested python data structures.

A *pytree* is a Python nested data structure. It is a tree in the sense that
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
Python values. Furthermore, a pytree should not contain reference cycles.

pytrees are useful for working with nested collections of Tensors. For example,
one can use `tree_map` to map a function over all Tensors inside some nested
collection of Tensors and `tree_leaves` to get a flat list of all Tensors
inside some nested collection. pytrees are helpful for implementing nested
collection support for PyTorch APIs.

This pytree implementation is not very performant due to Python overhead.
To improve performance, parts of the implementation can be moved to C++.
"""
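
# Illustrative example (not from the original file; a minimal sketch): for a small
# nested structure, the helpers defined below behave roughly as follows.
#
#     tree = {"x": (1, 2), "y": [3]}
#     tree_leaves(tree)                    # [1, 2, 3]
#     tree_map(lambda v: v * 10, tree)     # {"x": (10, 20), "y": [30]}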

import dataclasses
import importlib
import json
import sys
import threading
import types
import warnings
from collections import defaultdict, deque, namedtuple, OrderedDict
from typing import (
    Any,
    Callable,
    cast,
    DefaultDict,
    Deque,
    Dict,
    FrozenSet,
    Generic,
    Hashable,
    Iterable,
    List,
    Mapping,
    NamedTuple,
    Optional,
    OrderedDict as GenericOrderedDict,
    overload,
    Protocol,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
)


__all__ = [
    "PyTree",
    "Context",
    "FlattenFunc",
    "UnflattenFunc",
    "DumpableContext",
    "ToDumpableContextFn",
    "FromDumpableContextFn",
    "TreeSpec",
    "LeafSpec",
    "keystr",
    "key_get",
    "register_pytree_node",
    "tree_flatten",
    "tree_flatten_with_path",
    "tree_unflatten",
    "tree_leaves",
    "tree_leaves_with_path",
    "tree_structure",
    "tree_map",
    "tree_map_with_path",
    "tree_map_",
    "tree_map_only",
    "tree_map_only_",
    "tree_all",
    "tree_any",
    "tree_all_only",
    "tree_any_only",
    "treespec_dumps",
    "treespec_loads",
    "treespec_pprint",
]


T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")
R = TypeVar("R")


DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL = 1
NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND"


class KeyEntry(Protocol):
    def __hash__(self) -> int:
        ...

    def __eq__(self, other: object) -> bool:
        ...

    def __str__(self) -> str:
        ...

    def get(self, parent: Any) -> Any:
        ...


Context = Any
PyTree = Any
FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
DumpableContext = Any  # Any json dumpable text
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
ToStrFunc = Callable[["TreeSpec", List[str]], str]
MaybeFromStrFunc = Callable[[str], Optional[Tuple[Any, Context, str]]]
KeyPath = Tuple[KeyEntry, ...]
FlattenWithKeysFunc = Callable[[PyTree], Tuple[List[Tuple[KeyEntry, Any]], Any]]


# A NodeDef holds three callables:
# - flatten_fn should take the collection and return a flat list of values.
#   It can also return some context that is used in reconstructing the
#   collection.
# - unflatten_fn should take a flat list of values and some context
#   (returned by flatten_fn). It returns the collection by reconstructing
#   it from the list and the context.
# - flatten_with_keys_fn is a callable that takes a pytree and returns a
#   list of (keypath, value) pairs and a context.
class NodeDef(NamedTuple):
    type: Type[Any]
    flatten_fn: FlattenFunc
    unflatten_fn: UnflattenFunc
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc]
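
# Illustrative sketch (not part of the original file): the flatten/unflatten
# contract, shown with the builtin dict handlers defined later in this module.
#
#     _dict_flatten({"a": 1, "b": 2})        # ([1, 2], ["a", "b"])  -> (values, context)
#     _dict_unflatten([1, 2], ["a", "b"])    # {"a": 1, "b": 2}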


_NODE_REGISTRY_LOCK = threading.Lock()
SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {}


# _SerializeNodeDef holds the following:
# - typ: the type of the node (e.g., "Dict", "List", etc)
# - serialized_type_name: the fully qualified name of the type, e.g. "collections.OrderedDict"
# - to_dumpable_context takes a TreeSpec, and returns a serialized string format of the
#   context, and the version number
# - from_dumpable_context takes in a string representation of the context, and the
#   version, and returns the deserialized context
class _SerializeNodeDef(NamedTuple):
    typ: Type[Any]
    serialized_type_name: str
    to_dumpable_context: Optional[ToDumpableContextFn]
    from_dumpable_context: Optional[FromDumpableContextFn]


SUPPORTED_SERIALIZED_TYPES: Dict[Type[Any], _SerializeNodeDef] = {}
SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {}


def register_pytree_node(
    cls: Type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
) -> None:
    """Register a container-like type as a pytree node.

    Args:
        cls: the type to register
        flatten_fn: A callable that takes a pytree and returns a flattened
            representation of the pytree and additional context to represent the
            flattened pytree.
        unflatten_fn: A callable that takes a flattened version of the pytree,
            additional context, and returns an unflattened pytree.
        serialized_type_name: A keyword argument used to specify the fully qualified
            name used when serializing the tree spec.
        to_dumpable_context: An optional keyword argument to specify how to
            convert the context of the pytree to a custom json dumpable
            representation. This is used for json serialization, which is being
            used in torch.export right now.
        from_dumpable_context: An optional keyword argument to specify how to
            convert the custom json dumpable representation of the context
            back to the original context. This is used for json deserialization,
            which is being used in torch.export right now.
        flatten_with_keys_fn: An optional keyword argument to specify how to
            access each pytree leaf's keypath when flattening and tree-mapping.
            Like ``flatten_fn``, but in place of a List[leaf], it should return
            a List[(keypath, leaf)].
    """
    with _NODE_REGISTRY_LOCK:
        if cls in SUPPORTED_NODES:
            raise ValueError(f"{cls} is already registered as pytree node.")

    _private_register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
        flatten_with_keys_fn=flatten_with_keys_fn,
    )

    try:
        from . import _cxx_pytree as cxx
    except ImportError:
        pass
    else:
        cxx._private_register_pytree_node(
            cls,
            flatten_fn,
            unflatten_fn,
            serialized_type_name=serialized_type_name,
            to_dumpable_context=to_dumpable_context,
            from_dumpable_context=from_dumpable_context,
        )
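
# Illustrative sketch (not part of the original file): registering a hypothetical
# container type. ``Point`` and its fields exist only to show the expected
# flatten/unflatten contract; real code would pick its own type and serialized name.
#
#     @dataclasses.dataclass
#     class Point:
#         x: Any
#         y: Any
#
#     register_pytree_node(
#         Point,
#         lambda p: ([p.x, p.y], None),
#         lambda values, _: Point(*values),
#         serialized_type_name="mypkg.Point",
#         flatten_with_keys_fn=lambda p: (
#             [(GetAttrKey("x"), p.x), (GetAttrKey("y"), p.y)],
#             None,
#         ),
#     )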


def _register_pytree_node(
    cls: Type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    to_str_fn: Optional[ToStrFunc] = None,  # deprecated
    maybe_from_str_fn: Optional[MaybeFromStrFunc] = None,  # deprecated
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
) -> None:
    """Register a container-like type as a pytree node for the Python pytree only.

    Args:
        cls: the type to register
        flatten_fn: A callable that takes a pytree and returns a flattened
            representation of the pytree and additional context to represent the
            flattened pytree.
        unflatten_fn: A callable that takes a flattened version of the pytree,
            additional context, and returns an unflattened pytree.
        serialized_type_name: A keyword argument used to specify the fully qualified
            name used when serializing the tree spec.
        to_dumpable_context: An optional keyword argument to specify how to
            convert the context of the pytree to a custom json dumpable
            representation. This is used for json serialization, which is being
            used in torch.export right now.
        from_dumpable_context: An optional keyword argument to specify how to
            convert the custom json dumpable representation of the context
            back to the original context. This is used for json deserialization,
            which is being used in torch.export right now.
        flatten_with_keys_fn: An optional keyword argument to specify how to
            access each pytree leaf's keypath when flattening and tree-mapping.
            Like ``flatten_fn``, but in place of a List[leaf], it should return
            a List[(keypath, leaf)].
    """
    warnings.warn(
        "torch.utils._pytree._register_pytree_node is deprecated. "
        "Please use torch.utils._pytree.register_pytree_node instead.",
        stacklevel=2,
    )

    if to_str_fn is not None or maybe_from_str_fn is not None:
        warnings.warn(
            "to_str_fn and maybe_from_str_fn are deprecated. "
            "Please use to_dumpable_context and from_dumpable_context instead."
        )

    _private_register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
        flatten_with_keys_fn=flatten_with_keys_fn,
    )


def _private_register_pytree_node(
    cls: Type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
) -> None:
    """This is an internal function that is used to register a pytree node type
    for the Python pytree only. End-users should use :func:`register_pytree_node`
    instead.
    """
    with _NODE_REGISTRY_LOCK:
        if cls in SUPPORTED_NODES:
            # TODO: change this warning to an error after OSS/internal stabilize
            warnings.warn(
                f"{cls} is already registered as pytree node. "
                "Overwriting the previous registration.",
            )

        node_def = NodeDef(cls, flatten_fn, unflatten_fn, flatten_with_keys_fn)
        SUPPORTED_NODES[cls] = node_def

        if (to_dumpable_context is None) ^ (from_dumpable_context is None):
            raise ValueError(
                f"Both to_dumpable_context and from_dumpable_context for {cls} must "
                "be None or registered."
            )

        if serialized_type_name is None:
            serialized_type_name = NO_SERIALIZED_TYPE_NAME_FOUND

        serialize_node_def = _SerializeNodeDef(
            cls,
            serialized_type_name,
            to_dumpable_context,
            from_dumpable_context,
        )
        SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def
        SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls


@dataclasses.dataclass(frozen=True)
class SequenceKey(Generic[T]):
    idx: int

    def __str__(self) -> str:
        return f"[{self.idx!r}]"

    def get(self, sequence: Sequence[T]) -> T:
        return sequence[self.idx]


K = TypeVar("K", bound=Hashable)


@dataclasses.dataclass(frozen=True)
class MappingKey(Generic[K, T]):
    key: K

    def __str__(self) -> str:
        return f"[{self.key!r}]"

    def get(self, mapping: Mapping[K, T]) -> T:
        return mapping[self.key]


@dataclasses.dataclass(frozen=True)
class GetAttrKey:
    name: str

    def __str__(self) -> str:
        return f".{self.name}"

    def get(self, obj: Any) -> Any:
        return getattr(obj, self.name)


def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]:
    return list(d), None


def _tuple_flatten_with_keys(
    d: Tuple[Any, ...]
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    values, context = _tuple_flatten(d)
    return [(SequenceKey(i), v) for i, v in enumerate(values)], context


def _tuple_unflatten(values: Iterable[Any], context: Context) -> Tuple[Any, ...]:
    return tuple(values)


def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
    return d, None


def _list_flatten_with_keys(d: List[Any]) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    values, context = _list_flatten(d)
    return [(SequenceKey(i), v) for i, v in enumerate(values)], context


def _list_unflatten(values: Iterable[Any], context: Context) -> List[Any]:
    return list(values)


def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
    return list(d.values()), list(d.keys())


def _dict_flatten_with_keys(
    d: Dict[Any, Any]
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    values, context = _dict_flatten(d)
    return [(MappingKey(k), v) for k, v in zip(context, values)], context


def _dict_unflatten(values: Iterable[Any], context: Context) -> Dict[Any, Any]:
    return dict(zip(context, values))


def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]:
    return list(d), type(d)


def _namedtuple_flatten_with_keys(
    d: NamedTuple,
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    values, context = _namedtuple_flatten(d)
    return (
        [(GetAttrKey(field), v) for field, v in zip(context._fields, values)],
        context,
    )


def _namedtuple_unflatten(values: Iterable[Any], context: Context) -> NamedTuple:
    return cast(NamedTuple, context(*values))


def _namedtuple_serialize(context: Context) -> DumpableContext:
    json_namedtuple = {
        "class_name": context.__name__,
        "fields": context._fields,
    }
    return json_namedtuple


def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context:
    class_name = dumpable_context["class_name"]
    assert isinstance(class_name, str)
    context = namedtuple(class_name, dumpable_context["fields"])  # type: ignore[misc]
    return context


def _ordereddict_flatten(d: GenericOrderedDict[Any, Any]) -> Tuple[List[Any], Context]:
    return list(d.values()), list(d.keys())


def _ordereddict_flatten_with_keys(
    d: GenericOrderedDict[Any, Any]
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    values, context = _ordereddict_flatten(d)
    return [(MappingKey(k), v) for k, v in zip(context, values)], context


def _ordereddict_unflatten(
    values: Iterable[Any],
    context: Context,
) -> GenericOrderedDict[Any, Any]:
    return OrderedDict((key, value) for key, value in zip(context, values))


_odict_flatten = _ordereddict_flatten
_odict_unflatten = _ordereddict_unflatten


def _defaultdict_flatten(d: DefaultDict[Any, Any]) -> Tuple[List[Any], Context]:
    values, dict_context = _dict_flatten(d)
    return values, [d.default_factory, dict_context]


def _defaultdict_flatten_with_keys(
    d: DefaultDict[Any, Any]
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    values, context = _defaultdict_flatten(d)
    _, dict_context = context
    return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context


def _defaultdict_unflatten(
    values: Iterable[Any],
    context: Context,
) -> DefaultDict[Any, Any]:
    default_factory, dict_context = context
    return defaultdict(default_factory, _dict_unflatten(values, dict_context))


def _defaultdict_serialize(context: Context) -> DumpableContext:
    default_factory, dict_context = context
    json_defaultdict = {
        "default_factory_module": default_factory.__module__,
        "default_factory_name": default_factory.__qualname__,
        "dict_context": dict_context,
    }
    return json_defaultdict


def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context:
    assert isinstance(dumpable_context, dict)
    assert set(dumpable_context) == {
        "default_factory_module",
        "default_factory_name",
        "dict_context",
    }

    default_factory_module = dumpable_context["default_factory_module"]
    default_factory_name = dumpable_context["default_factory_name"]
    assert isinstance(default_factory_module, str)
    assert isinstance(default_factory_name, str)
    module = importlib.import_module(default_factory_module)
    default_factory = getattr(module, default_factory_name)

    dict_context = dumpable_context["dict_context"]
    return [default_factory, dict_context]


def _deque_flatten(d: Deque[Any]) -> Tuple[List[Any], Context]:
    return list(d), d.maxlen


def _deque_flatten_with_keys(
    d: Deque[Any],
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    values, context = _deque_flatten(d)
    return [(SequenceKey(i), v) for i, v in enumerate(values)], context


def _deque_unflatten(values: Iterable[Any], context: Context) -> Deque[Any]:
    return deque(values, maxlen=context)


_private_register_pytree_node(
    tuple,
    _tuple_flatten,
    _tuple_unflatten,
    serialized_type_name="builtins.tuple",
    flatten_with_keys_fn=_tuple_flatten_with_keys,
)
_private_register_pytree_node(
    list,
    _list_flatten,
    _list_unflatten,
    serialized_type_name="builtins.list",
    flatten_with_keys_fn=_list_flatten_with_keys,
)
_private_register_pytree_node(
    dict,
    _dict_flatten,
    _dict_unflatten,
    serialized_type_name="builtins.dict",
    flatten_with_keys_fn=_dict_flatten_with_keys,
)
_private_register_pytree_node(
    namedtuple,  # type: ignore[arg-type]
    _namedtuple_flatten,
    _namedtuple_unflatten,
    serialized_type_name="collections.namedtuple",
    to_dumpable_context=_namedtuple_serialize,
    from_dumpable_context=_namedtuple_deserialize,
    flatten_with_keys_fn=_namedtuple_flatten_with_keys,
)
_private_register_pytree_node(
    OrderedDict,
    _ordereddict_flatten,
    _ordereddict_unflatten,
    serialized_type_name="collections.OrderedDict",
    flatten_with_keys_fn=_ordereddict_flatten_with_keys,
)
_private_register_pytree_node(
    defaultdict,
    _defaultdict_flatten,
    _defaultdict_unflatten,
    serialized_type_name="collections.defaultdict",
    to_dumpable_context=_defaultdict_serialize,
    from_dumpable_context=_defaultdict_deserialize,
    flatten_with_keys_fn=_defaultdict_flatten_with_keys,
)
_private_register_pytree_node(
    deque,
    _deque_flatten,
    _deque_unflatten,
    serialized_type_name="collections.deque",
    flatten_with_keys_fn=_deque_flatten_with_keys,
)


STANDARD_DICT_TYPES: FrozenSet[type] = frozenset(
    {dict, OrderedDict, defaultdict},
)
BUILTIN_TYPES: FrozenSet[type] = frozenset(
    {tuple, list, dict, namedtuple, OrderedDict, defaultdict, deque},  # type: ignore[arg-type]
)


# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
def _is_namedtuple_instance(tree: Any) -> bool:
    typ = type(tree)
    bases = typ.__bases__
    if len(bases) != 1 or bases[0] != tuple:
        return False
    fields = getattr(typ, "_fields", None)
    if not isinstance(fields, tuple):
        return False
    return all(type(entry) == str for entry in fields)


def _get_node_type(tree: Any) -> Any:
    if _is_namedtuple_instance(tree):
        return namedtuple
    return type(tree)


# A leaf is defined as anything that is not a Node.
def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -> bool:
    return (is_leaf is not None and is_leaf(tree)) or _get_node_type(
        tree
    ) not in SUPPORTED_NODES


# A TreeSpec represents the structure of a pytree. It holds:
# - type: the type of the root Node of the pytree
# - context: some context that is useful in unflattening the pytree
# - children_specs: specs for each child of the root Node
# - num_leaves: the number of leaves
@dataclasses.dataclass
class TreeSpec:
    type: Any
    context: Context
    children_specs: List["TreeSpec"]

    num_nodes: int = dataclasses.field(init=False)
    num_leaves: int = dataclasses.field(init=False)
    num_children: int = dataclasses.field(init=False)

    def __post_init__(self) -> None:
        self.num_nodes = 1 + sum(spec.num_nodes for spec in self.children_specs)
        self.num_leaves = sum(spec.num_leaves for spec in self.children_specs)
        self.num_children = len(self.children_specs)

    def __repr__(self, indent: int = 0) -> str:
        repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, ["
        children_specs_str: str = ""
        if self.num_children > 0:
            indent += 2
            children_specs_str += self.children_specs[0].__repr__(indent)
            children_specs_str += "," if self.num_children > 1 else ""
            children_specs_str += ",".join(
                [
                    "\n" + " " * indent + child.__repr__(indent)
                    for child in self.children_specs[1:]
                ]
            )
        repr_suffix: str = f"{children_specs_str}])"
        return repr_prefix + repr_suffix

    def is_leaf(self) -> bool:
        return self.num_nodes == 1 and self.num_leaves == 1

    def _flatten_up_to_helper(self, tree: PyTree, subtrees: List[PyTree]) -> None:
        if self.is_leaf():
            subtrees.append(tree)
            return

        node_type = _get_node_type(tree)
        if self.type not in BUILTIN_TYPES:
            # Always require custom node types to match exactly
            if node_type != self.type:
                raise ValueError(
                    f"Type mismatch; "
                    f"expected {self.type!r}, but got {node_type!r}.",
                )
            flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
            child_pytrees, context = flatten_fn(tree)
            if len(child_pytrees) != self.num_children:
                raise ValueError(
                    f"Node arity mismatch; "
                    f"expected {self.num_children}, but got {len(child_pytrees)}.",
                )
            if context != self.context:
                raise ValueError(
                    f"Node context mismatch for custom node type {self.type!r}.",
                )
        else:
            # For builtin dictionary types, we allow some flexibility
            # Otherwise, we require exact matches
            both_standard_dict = (
                self.type in STANDARD_DICT_TYPES and node_type in STANDARD_DICT_TYPES
            )
            if node_type != self.type and not both_standard_dict:
                raise ValueError(
                    f"Node type mismatch; "
                    f"expected {self.type!r}, but got {node_type!r}.",
                )
            if len(tree) != self.num_children:
                raise ValueError(
                    f"Node arity mismatch; "
                    f"expected {self.num_children}, but got {len(tree)}.",
                )

            if both_standard_dict:  # dictionary types are compatible with each other
                dict_context = (
                    self.context
                    if self.type is not defaultdict
                    # ignore mismatch of `default_factory` for defaultdict
                    else self.context[1]
                )
                expected_keys = dict_context
                got_key_set = set(tree)
                expected_key_set = set(expected_keys)
                if got_key_set != expected_key_set:
                    missing_keys = expected_key_set.difference(got_key_set)
                    extra_keys = got_key_set.difference(expected_key_set)
                    message = ""
                    if missing_keys:
                        message += f"; missing key(s): {missing_keys}"
                    if extra_keys:
                        message += f"; extra key(s): {extra_keys}"
                    raise ValueError(f"Node keys mismatch{message}.")
                child_pytrees = [tree[key] for key in expected_keys]
            else:
                flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
                child_pytrees, context = flatten_fn(tree)
                if (
                    context != self.context
                    and self.type is not deque  # ignore mismatch of `maxlen` for deque
                ):
                    raise ValueError(
                        f"Node context mismatch for node type {self.type!r}; "
                        f"expected {self.context!r}, but got {context!r}.",  # namedtuple type mismatch
                    )

        for child_pytree, child_spec in zip(child_pytrees, self.children_specs):
            child_spec._flatten_up_to_helper(child_pytree, subtrees)

    def flatten_up_to(self, tree: PyTree) -> List[PyTree]:
        subtrees: List[PyTree] = []
        self._flatten_up_to_helper(tree, subtrees)
        return subtrees

    def unflatten(self, leaves: Iterable[Any]) -> PyTree:
        if not isinstance(leaves, (list, tuple)):
            leaves = list(leaves)
        if len(leaves) != self.num_leaves:
            raise ValueError(
                f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
                f"but the spec refers to a pytree that holds {self.num_leaves} "
                f"items ({self}).",
            )
        if self.is_leaf():
            return leaves[0]

        unflatten_fn = SUPPORTED_NODES[self.type].unflatten_fn

        # Recursively unflatten the children
        start = 0
        end = 0
        child_pytrees = []
        for child_spec in self.children_specs:
            end += child_spec.num_leaves
            child_pytrees.append(child_spec.unflatten(leaves[start:end]))
            start = end

        return unflatten_fn(child_pytrees, self.context)


class LeafSpec(TreeSpec):
    def __init__(self) -> None:
        super().__init__(None, None, [])

    def __post_init__(self) -> None:
        self.num_nodes = 1
        self.num_leaves = 1
        self.num_children = 0

    def __repr__(self, indent: int = 0) -> str:
        return "*"


# All leaves are equivalent, so represent with a single object to save on
# object construction time
_LEAF_SPEC = LeafSpec()


def _tree_flatten_helper(
    tree: PyTree,
    leaves: List[Any],
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> TreeSpec:
    if _is_leaf(tree, is_leaf=is_leaf):
        leaves.append(tree)
        return _LEAF_SPEC

    node_type = _get_node_type(tree)
    flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
    child_pytrees, context = flatten_fn(tree)

    # Recursively flatten the children
    children_specs = [
        _tree_flatten_helper(child, leaves, is_leaf=is_leaf) for child in child_pytrees
    ]

    return TreeSpec(node_type, context, children_specs)


def tree_flatten(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Tuple[List[Any], TreeSpec]:
    """Flattens a pytree into a list of values and a TreeSpec that can be used
    to reconstruct the pytree.
    """
    leaves: List[Any] = []
    spec = _tree_flatten_helper(tree, leaves, is_leaf=is_leaf)
    return leaves, spec


def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
    """Given a list of values and a TreeSpec, builds a pytree.
    This is the inverse operation of `tree_flatten`.
    """
    if not isinstance(treespec, TreeSpec):
        raise TypeError(
            f"tree_unflatten(leaves, treespec): Expected `treespec` to be "
            f"instance of TreeSpec but got item of type {type(treespec)}.",
        )
    return treespec.unflatten(leaves)
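
# Illustrative sketch (not part of the original file): a flatten/unflatten
# round trip. The values are arbitrary examples.
#
#     leaves, spec = tree_flatten({"a": (1, 2), "b": [3]})
#     leaves                                   # [1, 2, 3]
#     tree_unflatten([x * 10 for x in leaves], spec)
#                                              # {"a": (10, 20), "b": [30]}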


def _tree_leaves_helper(
    tree: PyTree,
    leaves: List[Any],
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> None:
    if _is_leaf(tree, is_leaf=is_leaf):
        leaves.append(tree)
        return

    node_type = _get_node_type(tree)
    flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
    child_pytrees, _ = flatten_fn(tree)

    # Recursively flatten the children
    for child in child_pytrees:
        _tree_leaves_helper(child, leaves, is_leaf=is_leaf)


def tree_leaves(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Any]:
    """Get a list of leaves of a pytree."""
    leaves: List[Any] = []
    _tree_leaves_helper(tree, leaves, is_leaf=is_leaf)
    return leaves


def tree_structure(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> TreeSpec:
    """Get the TreeSpec for a pytree."""
    return tree_flatten(tree, is_leaf=is_leaf)[1]


def tree_map(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Map a multi-input function over pytree args to produce a new pytree.

    See also :func:`tree_map_`.

    >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
    {'x': 8, 'y': (43, 65)}
    >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
    {'x': False, 'y': (False, False), 'z': True}

    If multiple inputs are given, the structure of the tree is taken from the first input;
    subsequent inputs need only have ``tree`` as a prefix:

    >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
    [[5, 7, 9], [6, 1, 2]]

    Args:
        func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
            argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree will be
            treated as a leaf. Otherwise, the default pytree registry will be used to determine
            whether a node is a leaf. If the function is not specified, the default pytree
            registry will be used.

    Returns:
        A new pytree with the same structure as ``tree`` but with the value at each leaf given by
        ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs``
        is the tuple of values at corresponding nodes in ``rests``.
    """
    leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    return treespec.unflatten(map(func, *flat_args))


def tree_map_(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree.

    See also :func:`tree_map`.

    Args:
        func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
            argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree will be
            treated as a leaf. Otherwise, the default pytree registry will be used to determine
            whether a node is a leaf. If the function is not specified, the default pytree
            registry will be used.

    Returns:
        The original ``tree``, where ``func(x, *xs)`` has been called at each leaf for its side
        effect (the return value is ignored); ``x`` is the value at the corresponding leaf in
        ``tree`` and ``xs`` is the tuple of values at corresponding nodes in ``rests``.
    """
    leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    tuple(map(func, *flat_args))  # consume and exhaust the iterable
    return tree


Type2 = Tuple[Type[T], Type[S]]
Type3 = Tuple[Type[T], Type[S], Type[U]]
if sys.version_info >= (3, 10):
    TypeAny = Union[Type[Any], Tuple[Type[Any], ...], types.UnionType]
else:
    TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]

Fn2 = Callable[[Union[T, S]], R]
Fn3 = Callable[[Union[T, S, U]], R]
Fn = Callable[[T], R]
FnAny = Callable[[Any], R]

MapOnlyFn = Callable[[T], Callable[[Any], Any]]


# These specializations help with type inference on the lambda passed to this
# function
@overload
def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
    ...


@overload
def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
    ...


@overload
def map_only(__type_or_types_or_pred: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
    ...


# This specialization is needed for the implementations below that call
# map_only() with an argument that is not statically a single concrete type.
@overload
def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]:
    ...


@overload
def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]:
    ...


def map_only(
    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]]
) -> MapOnlyFn[FnAny[Any]]:
    """
    Suppose you are writing a tree_map over tensors, leaving everything
    else unchanged. Ordinarily you would have to write:

        def go(t):
            if isinstance(t, Tensor):
                return ...
            else:
                return t

    With this function, you only need to write:

        @map_only(Tensor)
        def go(t):
            return ...

    You can also directly use 'tree_map_only'.
    """
    if isinstance(__type_or_types_or_pred, (type, tuple)) or (
        sys.version_info >= (3, 10)
        and isinstance(__type_or_types_or_pred, types.UnionType)
    ):

        def pred(x: Any) -> bool:
            return isinstance(x, __type_or_types_or_pred)  # type: ignore[arg-type]

    elif callable(__type_or_types_or_pred):
        pred = __type_or_types_or_pred  # type: ignore[assignment]
    else:
        raise TypeError("Argument must be a type, a tuple of types, or a callable.")

    def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]:
        # @functools.wraps(func)  # torch dynamo doesn't support this yet
        def wrapped(x: T) -> Any:
            if pred(x):
                return func(x)
            return x

        return wrapped

    return wrapper


@overload
def tree_map_only(
    __type_or_types_or_pred: Type[T],
    func: Fn[T, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only(
    __type_or_types_or_pred: Type2[T, S],
    func: Fn2[T, S, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only(
    __type_or_types_or_pred: Type3[T, S, U],
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only(
    __type_or_types_or_pred: Callable[[Any], bool],
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


def tree_map_only(
    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
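
# Illustrative sketch (not part of the original file): only leaves matching the
# given type (or predicate) are transformed; other leaves pass through unchanged.
#
#     tree_map_only(int, lambda x: x * 2, {"a": 1, "b": "s"})
#     # {"a": 2, "b": "s"}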


@overload
def tree_map_only_(
    __type_or_types_or_pred: Type[T],
    func: Fn[T, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only_(
    __type_or_types_or_pred: Type2[T, S],
    func: Fn2[T, S, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only_(
    __type_or_types_or_pred: Type3[T, S, U],
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only_(
    __type_or_types_or_pred: Callable[[Any], bool],
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


def tree_map_only_(
    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    return tree_map_(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)


def tree_all(
    pred: Callable[[Any], bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    flat_args = tree_leaves(tree, is_leaf=is_leaf)
    return all(map(pred, flat_args))


def tree_any(
    pred: Callable[[Any], bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    flat_args = tree_leaves(tree, is_leaf=is_leaf)
    return any(map(pred, flat_args))


@overload
def tree_all_only(
    __type_or_types: Type[T],
    pred: Fn[T, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


@overload
def tree_all_only(
    __type_or_types: Type2[T, S],
    pred: Fn2[T, S, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


@overload
def tree_all_only(
    __type_or_types: Type3[T, S, U],
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


def tree_all_only(
    __type_or_types: TypeAny,
    pred: FnAny[bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    flat_args = tree_leaves(tree, is_leaf=is_leaf)
    return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))


@overload
def tree_any_only(
    __type_or_types: Type[T],
    pred: Fn[T, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


@overload
def tree_any_only(
    __type_or_types: Type2[T, S],
    pred: Fn2[T, S, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


@overload
def tree_any_only(
    __type_or_types: Type3[T, S, U],
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


def tree_any_only(
    __type_or_types: TypeAny,
    pred: FnAny[bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    flat_args = tree_leaves(tree, is_leaf=is_leaf)
    return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))
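
# Illustrative sketch (not part of the original file): the *_only predicates are
# evaluated only on leaves of the given type(s); other leaves are skipped.
#
#     tree_all_only(int, lambda x: x > 0, {"a": 1, "b": "s"})     # True
#     tree_any_only(str, lambda x: x == "s", {"a": 1, "b": "s"})  # True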


# Broadcasts a pytree to the provided TreeSpec and returns the flattened
# values. If this is not possible, then this function returns None.
#
# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]),
# this function would return [0, 0]. This is useful for part of the vmap implementation:
# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be
# broadcastable to the tree structure of `inputs` and we use
# _broadcast_to_and_flatten to check this.
def _broadcast_to_and_flatten(
    tree: PyTree,
    treespec: TreeSpec,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[List[Any]]:
    assert isinstance(treespec, TreeSpec)

    if _is_leaf(tree, is_leaf=is_leaf):
        return [tree] * treespec.num_leaves
    if treespec.is_leaf():
        return None
    node_type = _get_node_type(tree)
    if node_type != treespec.type:
        return None

    flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
    child_pytrees, ctx = flatten_fn(tree)

    # Check if the Node is different from the spec
    if len(child_pytrees) != treespec.num_children or ctx != treespec.context:
        return None

    # Recursively flatten the children
    result: List[Any] = []
    for child, child_spec in zip(child_pytrees, treespec.children_specs):
        flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf)
        if flat is not None:
            result += flat
        else:
            return None

    return result


@dataclasses.dataclass
class _TreeSpecSchema:
    """
    _TreeSpecSchema is the schema used to serialize the TreeSpec.
    It contains the following fields:
    - type: A string name of the type. null for the case of a LeafSpec.
    - context: Any format which is json dumpable
    - children_spec: A list of children serialized specs.
    """

    type: Optional[str]
    context: DumpableContext
    children_spec: List["_TreeSpecSchema"]


class _ProtocolFn(NamedTuple):
    treespec_to_json: Callable[[TreeSpec], DumpableContext]
    json_to_treespec: Callable[[DumpableContext], TreeSpec]


_SUPPORTED_PROTOCOLS: Dict[int, _ProtocolFn] = {}


def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
    if treespec.is_leaf():
        return _TreeSpecSchema(None, None, [])

    if treespec.type not in SUPPORTED_SERIALIZED_TYPES:
        raise NotImplementedError(
            f"Serializing {treespec.type} in pytree is not registered.",
        )

    serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type]

    serialized_type_name = serialize_node_def.serialized_type_name

    if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND:
        raise NotImplementedError(
            f"No registered serialization name for {treespec.type} found. "
            "Please update your _register_pytree_node call with a `serialized_type_name` kwarg."
        )

    if serialize_node_def.to_dumpable_context is None:
        try:
            serialized_context = json.dumps(treespec.context)
        except TypeError as e:
            raise TypeError(
                "Unable to serialize context. "
                "Please make the context json dump-able, or register a "
                "custom serializer using _register_pytree_node."
            ) from e
    else:
        serialized_context = serialize_node_def.to_dumpable_context(treespec.context)

    child_schemas = [_treespec_to_json(child) for child in treespec.children_specs]

    return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas)


def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec:
    if (
        json_schema["type"] is None
        and json_schema["context"] is None
        and len(json_schema["children_spec"]) == 0
    ):
        return _LEAF_SPEC

    if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE:
        raise NotImplementedError(
            f'Deserializing {json_schema["type"]} in pytree is not registered.',
        )

    typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]]
    serialize_node_def = SUPPORTED_SERIALIZED_TYPES[typ]

    if serialize_node_def.from_dumpable_context is None:
        try:
            context = json.loads(json_schema["context"])
        except TypeError as ex:
            raise TypeError(
                "Unable to deserialize context. "
                "Please make the context json load-able, or register a "
                "custom serializer using _register_pytree_node.",
            ) from ex
    else:
        context = serialize_node_def.from_dumpable_context(json_schema["context"])

    children_specs = []
    for child_string in json_schema["children_spec"]:
        children_specs.append(_json_to_treespec(child_string))

    return TreeSpec(typ, context, children_specs)


_SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec)


def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
    if not isinstance(treespec, TreeSpec):
        raise TypeError(
            f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of "
            f"TreeSpec but got item of type {type(treespec)}.",
        )

    if protocol is None:
        protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL

    if protocol in _SUPPORTED_PROTOCOLS:
        json_spec = _SUPPORTED_PROTOCOLS[protocol].treespec_to_json(treespec)
    else:
        raise ValueError(
            f"Unknown protocol {protocol}. "
            f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}",
        )

    str_spec = json.dumps((protocol, dataclasses.asdict(json_spec)))
    return str_spec


def treespec_loads(serialized: str) -> TreeSpec:
    protocol, json_schema = json.loads(serialized)

    if protocol in _SUPPORTED_PROTOCOLS:
        return _SUPPORTED_PROTOCOLS[protocol].json_to_treespec(json_schema)
    raise ValueError(
        f"Unknown protocol {protocol}. "
        f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}",
    )
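
# Illustrative sketch (not part of the original file): a TreeSpec can be
# round-tripped through its JSON string form.
#
#     spec = tree_structure({"a": [1, 2]})
#     s = treespec_dumps(spec)       # a string like '[1, {...}]' (protocol, schema) as JSON
#     treespec_loads(s) == spec      # True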


class _DummyLeaf:
    def __repr__(self) -> str:
        return "*"


def treespec_pprint(treespec: TreeSpec) -> str:
    dummy_tree = tree_unflatten(
        [_DummyLeaf() for _ in range(treespec.num_leaves)],
        treespec,
    )
    return repr(dummy_tree)


# TODO(angelayi): remove this function after OSS/internal stabilize
def pytree_to_str(treespec: TreeSpec) -> str:
    warnings.warn("pytree_to_str is deprecated. Please use treespec_dumps")
    return treespec_dumps(treespec)


# TODO(angelayi): remove this function after OSS/internal stabilize
def str_to_pytree(json: str) -> TreeSpec:
    warnings.warn("str_to_pytree is deprecated. Please use treespec_loads")
    return treespec_loads(json)


def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> List[Any]:
    """Get a flat list of arguments to this function

    A slightly faster version of tree_leaves((args, kwargs))
    """
    leaves: List[Any] = []
    for a in args:
        _tree_leaves_helper(a, leaves)
    for a in kwargs.values():
        _tree_leaves_helper(a, leaves)
    return leaves


def tree_flatten_with_path(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Tuple[List[Tuple[KeyPath, Any]], TreeSpec]:
    """Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path.

    Args:
        tree: a pytree to flatten. If it contains a custom type, that type must be
            registered with an appropriate ``flatten_with_keys_fn`` when it is
            registered with :func:`register_pytree_node`.
        is_leaf: An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree will be
            treated as a leaf. Otherwise, the default pytree registry will be used to determine
            whether a node is a leaf. If the function is not specified, the default pytree
            registry will be used.
    Returns:
        A tuple where the first element is a list of (key path, leaf) pairs, and the
        second element is a :class:`TreeSpec` representing the structure of the flattened
        tree.
    """
    _, treespec = tree_flatten(tree, is_leaf)
    return list(_generate_key_paths((), tree, is_leaf)), treespec


def tree_leaves_with_path(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Tuple[KeyPath, Any]]:
    """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.

    Args:
        tree: a pytree. If it contains a custom type, that type must be
            registered with an appropriate ``flatten_with_keys_fn`` when it is
            registered with :func:`register_pytree_node`.
        is_leaf: An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree will be
            treated as a leaf. Otherwise, the default pytree registry will be used to determine
            whether a node is a leaf. If the function is not specified, the default pytree
            registry will be used.
    Returns:
        A list of (key path, leaf) pairs.
    """
    return list(_generate_key_paths((), tree, is_leaf))


def _generate_key_paths(
    key_path: KeyPath,
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Iterable[Tuple[KeyPath, Any]]:
    if is_leaf and is_leaf(tree):
        yield key_path, tree
        return

    node_type = _get_node_type(tree)
    handler = SUPPORTED_NODES.get(node_type)
    if not handler:
        # This is a leaf
        yield key_path, tree
        return

    flatten_with_keys = handler.flatten_with_keys_fn
    if flatten_with_keys:
        key_children, _ = flatten_with_keys(tree)
        for k, c in key_children:
            yield from _generate_key_paths((*key_path, k), c, is_leaf)
    else:
        # We registered this pytree but didn't add a flatten_with_keys_fn, complain.
        raise ValueError(
            f"Did not find a flatten_with_keys_fn for type: {node_type}. "
            "Please pass a flatten_with_keys_fn argument to register_pytree_node."
        )


def tree_map_with_path(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Like :func:`tree_map`, but the provided callable takes an additional key path argument.

    Args:
        func: A function that takes ``2 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees. The first positional argument
            to ``func`` is the key path of the leaf in question. The second
            positional argument is the value of the leaf.
        tree: A pytree to be mapped over, with each leaf providing the second positional
            argument to function ``func``.
        rests: A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf: An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree will be
            treated as a leaf. Otherwise, the default pytree registry will be used to determine
            whether a node is a leaf. If the function is not specified, the default pytree
            registry will be used.

    Returns:
        A new pytree with the same structure as ``tree`` but with the value at each leaf given by
        ``func(keypath, x, *xs)`` where ``keypath`` is the key path at the
        corresponding leaf in ``tree``, ``x`` is the value at that leaf, and
        ``xs`` is the tuple of values at corresponding nodes in ``rests``.
    """
    keypath_leaves, treespec = tree_flatten_with_path(tree, is_leaf)
    keypath_leaves = list(zip(*keypath_leaves))
    all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests]
    return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))


def keystr(kp: KeyPath) -> str:
    """Given a key path, return a pretty-printed representation."""
    return "".join([str(k) for k in kp])


def key_get(obj: Any, kp: KeyPath) -> Any:
    """Given an object and a key path, return the value at the key path."""
    for k in kp:
        obj = k.get(obj)
    return obj