2868 lines
119 KiB
Python
2868 lines
119 KiB
Python
# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""OpTree: Optimized PyTree Utilities."""
|
|
|
|
# pylint: disable=too-many-lines
|
|
|
|
from __future__ import annotations
|
|
|
|
import difflib
|
|
import functools
|
|
import itertools
|
|
import textwrap
|
|
from collections import OrderedDict, defaultdict, deque
|
|
from typing import Any, Callable, Iterable, Mapping, overload
|
|
|
|
from optree import _C
|
|
from optree.registry import (
|
|
AttributeKeyPathEntry,
|
|
FlattenedKeyPathEntry,
|
|
KeyPath,
|
|
KeyPathEntry,
|
|
PyTreeNodeRegistryEntry,
|
|
register_keypaths,
|
|
register_pytree_node,
|
|
)
|
|
from optree.typing import (
|
|
CustomTreeNode,
|
|
MetaData,
|
|
NamedTuple,
|
|
PyTree,
|
|
PyTreeSpec,
|
|
S,
|
|
T,
|
|
U,
|
|
is_namedtuple_instance,
|
|
is_structseq_instance,
|
|
namedtuple_fields,
|
|
)
|
|
from optree.typing import structseq as PyStructSequence # noqa: N812
|
|
from optree.typing import structseq_fields
|
|
|
|
|
|
__all__ = [
    # Module-level constants.
    'MAX_RECURSION_DEPTH', 'NONE_IS_NODE', 'NONE_IS_LEAF',
    # Flattening, unflattening, and leaf queries.
    'tree_flatten', 'tree_flatten_with_path', 'tree_unflatten', 'tree_iter', 'tree_leaves',
    'tree_structure', 'tree_paths', 'tree_is_leaf', 'all_leaves',
    # Mapping over leaves.
    'tree_map', 'tree_map_', 'tree_map_with_path', 'tree_map_with_path_', 'tree_replace_nones',
    # Transpose and broadcast utilities.
    'tree_transpose', 'tree_transpose_map', 'tree_transpose_map_with_path',
    'tree_broadcast_prefix', 'broadcast_prefix', 'tree_broadcast_common', 'broadcast_common',
    'tree_broadcast_map', 'tree_broadcast_map_with_path',
    # Reductions over leaves.
    'tree_reduce', 'tree_sum', 'tree_max', 'tree_min', 'tree_all', 'tree_any',
    # Remaining tree utilities and ``treespec_*`` helpers.
    'tree_flatten_one_level',
    'treespec_paths', 'treespec_entries', 'treespec_entry', 'treespec_children', 'treespec_child',
    'treespec_is_leaf', 'treespec_is_strict_leaf', 'treespec_is_prefix', 'treespec_is_suffix',
    'treespec_leaf', 'treespec_none', 'treespec_tuple', 'treespec_list', 'treespec_dict',
    'treespec_namedtuple', 'treespec_ordereddict', 'treespec_defaultdict', 'treespec_deque',
    'treespec_structseq', 'treespec_from_collection',
    'prefix_errors',
]
|
|
|
|
# Re-exported from the C extension module ``_C`` so Python callers can read the traversal limit.
MAX_RECURSION_DEPTH: int = _C.MAX_RECURSION_DEPTH  # 1000
"""Maximum recursion depth for pytree traversal. It is 1000.

This limit prevents infinite recursion from causing an overflow of the C stack
and crashing Python.
"""
# Named boolean aliases, presumably meant to be passed as the ``none_is_leaf`` argument of the
# functions below for readability (e.g. ``none_is_leaf=NONE_IS_LEAF``).
NONE_IS_NODE: bool = False  # literal constant
"""Literal constant that treats :data:`None` as a pytree non-leaf node."""
NONE_IS_LEAF: bool = True  # literal constant
"""Literal constant that treats :data:`None` as a pytree leaf node."""
|
|
|
|
|
|
def tree_flatten(
    tree: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> tuple[list[T], PyTreeSpec]:
    """Split a pytree into its leaves and a treespec describing its shape.

    See also :func:`tree_flatten_with_path` and :func:`tree_unflatten`.

    Leaves are listed in a deterministic order that matches a left-to-right, depth-first walk of
    the tree.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_flatten(tree)  # doctest: +IGNORE_WHITESPACE
    (
        [1, 2, 3, 4, 5],
        PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
    )
    >>> tree_flatten(tree, none_is_leaf=True)  # doctest: +IGNORE_WHITESPACE
    (
        [1, 2, 3, 4, None, 5],
        PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
    )
    >>> tree_flatten(1)
    ([1], PyTreeSpec(*))
    >>> tree_flatten(None)
    ([], PyTreeSpec(None))
    >>> tree_flatten(None, none_is_leaf=True)
    ([None], PyTreeSpec(*, NoneIsLeaf))

    The unordered mappings :class:`dict` and :class:`collections.defaultdict` are traversed in
    **sorted** key order. Use :class:`collections.OrderedDict` to keep insertion order instead.

    >>> from collections import OrderedDict
    >>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
    >>> tree_flatten(tree)  # doctest: +IGNORE_WHITESPACE
    (
        [2, 3, 4, 1, 5],
        PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', None), ('d', *)]))
    )
    >>> tree_flatten(tree, none_is_leaf=True)  # doctest: +IGNORE_WHITESPACE
    (
        [2, 3, 4, 1, None, 5],
        PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', *), ('d', *)]), NoneIsLeaf)
    )

    Args:
        tree (pytree): The pytree to flatten.
        is_leaf (callable, optional): Optional predicate called at every flattening step. A
            :data:`True` result stops descending and treats the whole subtree as one leaf;
            :data:`False` lets the traversal continue into the current object.
        none_is_leaf (bool, optional): Whether :data:`None` counts as a leaf. When :data:`False`,
            :data:`None` is an arity-0 internal node, so it is recorded in the treespec and not
            in the leaves list. (default: :data:`False`)
        namespace (str, optional): Registry namespace to consult for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A pair ``(leaves, treespec)``: the list of leaf values followed by the treespec that
        captures the structure of the pytree.
    """
    # The C extension does the whole traversal; hand its result back unchanged.
    flattened = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    return flattened
|
|
|
|
|
|
def tree_flatten_with_path(
    tree: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> tuple[list[tuple[Any, ...]], list[T], PyTreeSpec]:
    """Flatten a pytree, also reporting the path that leads to each leaf.

    See also :func:`tree_flatten`, :func:`tree_paths`, and :func:`treespec_paths`.

    Paths and leaves share one deterministic order: a left-to-right, depth-first walk of the
    tree.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_flatten_with_path(tree)  # doctest: +IGNORE_WHITESPACE
    (
        [('a',), ('b', 0), ('b', 1, 0), ('b', 1, 1), ('d',)],
        [1, 2, 3, 4, 5],
        PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
    )
    >>> tree_flatten_with_path(tree, none_is_leaf=True)  # doctest: +IGNORE_WHITESPACE
    (
        [('a',), ('b', 0), ('b', 1, 0), ('b', 1, 1), ('c',), ('d',)],
        [1, 2, 3, 4, None, 5],
        PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
    )
    >>> tree_flatten_with_path(1)
    ([()], [1], PyTreeSpec(*))
    >>> tree_flatten_with_path(None)
    ([], [], PyTreeSpec(None))
    >>> tree_flatten_with_path(None, none_is_leaf=True)
    ([()], [None], PyTreeSpec(*, NoneIsLeaf))

    The unordered mappings :class:`dict` and :class:`collections.defaultdict` are traversed in
    **sorted** key order. Use :class:`collections.OrderedDict` to keep insertion order instead.

    >>> from collections import OrderedDict
    >>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
    >>> tree_flatten_with_path(tree)  # doctest: +IGNORE_WHITESPACE
    (
        [('b', 0), ('b', 1, 0), ('b', 1, 1), ('a',), ('d',)],
        [2, 3, 4, 1, 5],
        PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', None), ('d', *)]))
    )
    >>> tree_flatten_with_path(tree, none_is_leaf=True)  # doctest: +IGNORE_WHITESPACE
    (
        [('b', 0), ('b', 1, 0), ('b', 1, 1), ('a',), ('c',), ('d',)],
        [2, 3, 4, 1, None, 5],
        PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', *), ('d', *)]), NoneIsLeaf)
    )

    Args:
        tree (pytree): The pytree to flatten.
        is_leaf (callable, optional): Optional predicate called at every flattening step. A
            :data:`True` result stops descending and treats the whole subtree as one leaf;
            :data:`False` lets the traversal continue into the current object.
        none_is_leaf (bool, optional): Whether :data:`None` counts as a leaf. When :data:`False`,
            :data:`None` is an arity-0 internal node, so it is recorded in the treespec and not
            in the leaves list. (default: :data:`False`)
        namespace (str, optional): Registry namespace to consult for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A triple ``(paths, leaves, treespec)``. ``paths`` holds one tuple of indices/keys per
        leaf, ``leaves`` holds the leaf values, and ``treespec`` describes the pytree structure.
    """
    # Paths, leaves, and treespec are all produced in a single native traversal.
    flattened = _C.flatten_with_path(tree, is_leaf, none_is_leaf, namespace)
    return flattened
|
|
|
|
|
|
def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[T]) -> PyTree[T]:
    """Rebuild a pytree by placing ``leaves`` into the structure recorded by ``treespec``.

    This undoes :func:`tree_flatten`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> leaves, treespec = tree_flatten(tree)
    >>> tree == tree_unflatten(treespec, leaves)
    True

    Args:
        treespec (PyTreeSpec): The structure to rebuild.
        leaves (iterable): The leaf values to insert. Their count must equal the number of
            leaves described by ``treespec``.

    Returns:
        A pytree shaped like ``treespec`` whose leaves are taken from ``leaves``.
    """
    # Delegate to the treespec itself, which knows how to reassemble its own structure.
    rebuilt = treespec.unflatten(leaves)
    return rebuilt
|
|
|
|
|
|
def tree_iter(
    tree: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> Iterable[T]:
    """Return an iterator that yields the leaves of a pytree.

    See also :func:`tree_flatten` and :func:`tree_leaves`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> list(tree_iter(tree))
    [1, 2, 3, 4, 5]
    >>> list(tree_iter(tree, none_is_leaf=True))
    [1, 2, 3, 4, None, 5]
    >>> list(tree_iter(1))
    [1]
    >>> list(tree_iter(None))
    []
    >>> list(tree_iter(None, none_is_leaf=True))
    [None]

    Args:
        tree (pytree): The pytree whose leaves are iterated.
        is_leaf (callable, optional): Optional predicate called at every flattening step. A
            :data:`True` result stops descending and treats the whole subtree as one leaf;
            :data:`False` lets the traversal continue into the current object.
        none_is_leaf (bool, optional): Whether :data:`None` counts as a leaf. When :data:`False`,
            :data:`None` is an arity-0 internal node, so it is recorded in the treespec and not
            in the leaves list. (default: :data:`False`)
        namespace (str, optional): Registry namespace to consult for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        An iterator over the leaf values.
    """
    # ``PyTreeIter`` is the native iterator type provided by the C extension.
    leaf_iterator = _C.PyTreeIter(tree, is_leaf, none_is_leaf, namespace)
    return leaf_iterator
|
|
|
|
|
|
def tree_leaves(
    tree: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> list[T]:
    """Collect the leaves of a pytree into a list.

    See also :func:`tree_flatten` and :func:`tree_iter`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_leaves(tree)
    [1, 2, 3, 4, 5]
    >>> tree_leaves(tree, none_is_leaf=True)
    [1, 2, 3, 4, None, 5]
    >>> tree_leaves(1)
    [1]
    >>> tree_leaves(None)
    []
    >>> tree_leaves(None, none_is_leaf=True)
    [None]

    Args:
        tree (pytree): The pytree to flatten.
        is_leaf (callable, optional): Optional predicate called at every flattening step. A
            :data:`True` result stops descending and treats the whole subtree as one leaf;
            :data:`False` lets the traversal continue into the current object.
        none_is_leaf (bool, optional): Whether :data:`None` counts as a leaf. When :data:`False`,
            :data:`None` is an arity-0 internal node, so it is recorded in the treespec and not
            in the leaves list. (default: :data:`False`)
        namespace (str, optional): Registry namespace to consult for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        The list of leaf values.
    """
    # Flatten once and keep only the leaves; the treespec is discarded.
    leaves, _ = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    return leaves
|
|
|
|
|
|
def tree_structure(
    tree: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTreeSpec:
    """Compute the treespec that describes the structure of a pytree.

    See also :func:`tree_flatten`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_structure(tree)
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
    >>> tree_structure(tree, none_is_leaf=True)
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
    >>> tree_structure(1)
    PyTreeSpec(*)
    >>> tree_structure(None)
    PyTreeSpec(None)
    >>> tree_structure(None, none_is_leaf=True)
    PyTreeSpec(*, NoneIsLeaf)

    Args:
        tree (pytree): The pytree to flatten.
        is_leaf (callable, optional): Optional predicate called at every flattening step. A
            :data:`True` result stops descending and treats the whole subtree as one leaf;
            :data:`False` lets the traversal continue into the current object.
        none_is_leaf (bool, optional): Whether :data:`None` counts as a leaf. When :data:`False`,
            :data:`None` is an arity-0 internal node, so it is recorded in the treespec and not
            in the leaves list. (default: :data:`False`)
        namespace (str, optional): Registry namespace to consult for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        The treespec capturing the structure of the pytree.
    """
    # Flatten once and keep only the treespec; the leaves are discarded.
    _, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    return treespec
|
|
|
|
|
|
def tree_paths(
    tree: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> list[tuple[Any, ...]]:
    """List the path (a tuple of indices/keys) leading to each leaf of a pytree.

    See also :func:`tree_flatten`, :func:`tree_flatten_with_path`, and :func:`treespec_paths`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_paths(tree)
    [('a',), ('b', 0), ('b', 1, 0), ('b', 1, 1), ('d',)]
    >>> tree_paths(tree, none_is_leaf=True)
    [('a',), ('b', 0), ('b', 1, 0), ('b', 1, 1), ('c',), ('d',)]
    >>> tree_paths(1)
    [()]
    >>> tree_paths(None)
    []
    >>> tree_paths(None, none_is_leaf=True)
    [()]

    Args:
        tree (pytree): The pytree to flatten.
        is_leaf (callable, optional): Optional predicate called at every flattening step. A
            :data:`True` result stops descending and treats the whole subtree as one leaf;
            :data:`False` lets the traversal continue into the current object.
        none_is_leaf (bool, optional): Whether :data:`None` counts as a leaf. When :data:`False`,
            :data:`None` is an arity-0 internal node, so it is recorded in the treespec and not
            in the leaves list. (default: :data:`False`)
        namespace (str, optional): Registry namespace to consult for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        One path per leaf, each path being a tuple of the indices or keys along the way.
    """
    # Only the paths are needed; leaves and treespec from the same traversal are dropped.
    paths, _, _ = _C.flatten_with_path(tree, is_leaf, none_is_leaf, namespace)
    return paths
|
|
|
|
|
|
def tree_is_leaf(
    tree: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> bool:
    """Test whether the given object is a leaf node.

    See also :func:`tree_flatten`, :func:`tree_leaves`, and :func:`all_leaves`.

    >>> tree_is_leaf(1)
    True
    >>> tree_is_leaf(None)
    False
    >>> tree_is_leaf(None, none_is_leaf=True)
    True
    >>> tree_is_leaf({'a': 1, 'b': (2, 3)})
    False

    Args:
        tree (pytree): A pytree to check if it is a leaf node.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than a leaf. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A boolean indicating whether ``tree`` itself is a leaf node.
    """
    return _C.is_leaf(tree, is_leaf, none_is_leaf, namespace)  # type: ignore[arg-type]
|
|
|
|
|
|
def all_leaves(
    iterable: Iterable[T],
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> bool:
    """Test whether every element in the given iterable is a leaf.

    See also :func:`tree_flatten`, :func:`tree_leaves`, and :func:`tree_is_leaf`.

    >>> tree = {'a': [1, 2, 3]}
    >>> all_leaves(tree_leaves(tree))
    True
    >>> all_leaves([tree])
    False
    >>> all_leaves([1, 2, None, 3])
    False
    >>> all_leaves([1, 2, None, 3], none_is_leaf=True)
    True

    Note that this function checks the elements produced by iterating the input object with
    :func:`iter`. For a dictionary ``d``, ``iter(d)`` yields the keys of the dictionary, not the
    values.

    >>> list({'a': 1, 'b': (2, 3)})
    ['a', 'b']
    >>> all_leaves({'a': 1, 'b': (2, 3)})
    True

    This function is useful in advanced cases. For example, if a library allows arbitrary map
    operations on a flat list of leaves, it may want to check whether the result is still a flat
    list of leaves.

    Args:
        iterable (iterable): An iterable whose elements are checked for being leaves.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than a leaf. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A boolean indicating if all elements in the input iterable are leaves.
    """
    return _C.all_leaves(iterable, is_leaf, none_is_leaf, namespace)
|
|
|
|
|
|
def tree_map(
    func: Callable[..., U],
    tree: PyTree[T],
    *rests: PyTree[S],
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[U]:
    """Map a multi-input function over pytree args to produce a new pytree.

    See also :func:`tree_map_`, :func:`tree_map_with_path`, :func:`tree_map_with_path_`,
    and :func:`tree_broadcast_map`.

    >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
    {'x': 8, 'y': (43, 65)}
    >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64), 'z': None})
    {'x': 8, 'y': (43, 65), 'z': None}
    >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
    {'x': False, 'y': (False, False), 'z': None}
    >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}, none_is_leaf=True)
    {'x': False, 'y': (False, False), 'z': True}

    If multiple inputs are given, the structure of the tree is taken from the first input;
    subsequent inputs need only have ``tree`` as a prefix:

    >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
    [[5, 7, 9], [6, 1, 2]]

    Args:
        func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
            argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A new pytree with the same structure as ``tree`` but with the value at each leaf given by
        ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs``
        is the tuple of values at corresponding nodes in ``rests``.
    """
    leaves, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    # Each extra tree is flattened only down to ``tree``'s structure, so its subtrees at
    # positions where ``tree`` has a leaf are passed to ``func`` whole.
    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    return treespec.unflatten(map(func, *flat_args))
|
|
|
|
|
|
def tree_map_(
    func: Callable[..., Any],
    tree: PyTree[T],
    *rests: PyTree[S],
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[T]:
    """Like :func:`tree_map`, but call ``func`` on each leaf only for its side effects and return the original tree.

    See also :func:`tree_map`, :func:`tree_map_with_path`, and :func:`tree_map_with_path_`.

    >>> seen = []
    >>> tree_map_(seen.append, {'x': 7, 'y': (42, 64)})
    {'x': 7, 'y': (42, 64)}
    >>> seen
    [7, 42, 64]

    Args:
        func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
            argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list, and ``func`` is not called on it.
            (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        The original ``tree`` object itself (not a copy). ``func(x, *xs)`` is called once per leaf
        ``x`` of ``tree``, with ``xs`` the tuple of values at the corresponding nodes in
        ``rests``; its return values are discarded.
    """
    leaves, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    deque(map(func, *flat_args), maxlen=0)  # consume and exhaust the iterable
    return tree
|
|
|
|
|
|
def tree_map_with_path(
    func: Callable[..., U],
    tree: PyTree[T],
    *rests: PyTree[S],
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[U]:
    """Map a multi-input function over pytree args as well as the tree paths to produce a new pytree.

    See also :func:`tree_map`, :func:`tree_map_`, and :func:`tree_map_with_path_`.

    >>> tree_map_with_path(lambda p, x: (len(p), x), {'x': 7, 'y': (42, 64)})
    {'x': (1, 7), 'y': ((2, 42), (2, 64))}
    >>> tree_map_with_path(lambda p, x: x + len(p), {'x': 7, 'y': (42, 64), 'z': None})
    {'x': 8, 'y': (44, 66), 'z': None}
    >>> tree_map_with_path(lambda p, x: p, {'x': 7, 'y': (42, 64), 'z': {1.5: None}})
    {'x': ('x',), 'y': (('y', 0), ('y', 1)), 'z': {1.5: None}}
    >>> tree_map_with_path(lambda p, x: p, {'x': 7, 'y': (42, 64), 'z': {1.5: None}}, none_is_leaf=True)
    {'x': ('x',), 'y': (('y', 0), ('y', 1)), 'z': {1.5: ('z', 1.5)}}

    Args:
        func (callable): A function that takes ``2 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees with extra paths.
        tree (pytree): A pytree to be mapped over, with each leaf providing the second positional
            argument and the corresponding path providing the first positional argument to function
            ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A new pytree with the same structure as ``tree`` but with the value at each leaf given by
        ``func(p, x, *xs)`` where ``(p, x)`` are the path and value at the corresponding leaf in
        ``tree`` and ``xs`` is the tuple of values at corresponding nodes in ``rests``.
    """
    paths, leaves, treespec = _C.flatten_with_path(tree, is_leaf, none_is_leaf, namespace)
    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    # The path is passed as the first positional argument, ahead of the leaf values.
    return treespec.unflatten(map(func, paths, *flat_args))
|
|
|
|
|
|
def tree_map_with_path_(
    func: Callable[..., Any],
    tree: PyTree[T],
    *rests: PyTree[S],
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[T]:
    """Like :func:`tree_map_with_path`, but call ``func`` only for its side effects and return the original tree.

    See also :func:`tree_map`, :func:`tree_map_`, and :func:`tree_map_with_path`.

    >>> seen = []
    >>> tree_map_with_path_(lambda p, x: seen.append((p, x)), {'x': 7, 'y': (42, 64)})
    {'x': 7, 'y': (42, 64)}
    >>> seen
    [(('x',), 7), (('y', 0), 42), (('y', 1), 64)]

    Args:
        func (callable): A function that takes ``2 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees with extra paths.
        tree (pytree): A pytree to be mapped over, with each leaf providing the second positional
            argument and the corresponding path providing the first positional argument to function
            ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list, and ``func`` is not called on it.
            (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        The original ``tree`` object itself (not a copy). ``func(p, x, *xs)`` is called once per
        leaf, where ``(p, x)`` are the path and value of the leaf in ``tree`` and ``xs`` is the
        tuple of values at the corresponding nodes in ``rests``; its return values are discarded.
    """
    paths, leaves, treespec = _C.flatten_with_path(tree, is_leaf, none_is_leaf, namespace)
    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    deque(map(func, paths, *flat_args), maxlen=0)  # consume and exhaust the iterable
    return tree
|
|
|
|
|
|
def tree_replace_nones(sentinel: Any, tree: PyTree[T] | None, namespace: str = '') -> PyTree[T]:
    """Return a copy of ``tree`` in which every :data:`None` is swapped for ``sentinel``.

    See also :func:`tree_flatten` and :func:`tree_map`.

    >>> tree_replace_nones(0, {'a': 1, 'b': None, 'c': (2, None)})
    {'a': 1, 'b': 0, 'c': (2, 0)}
    >>> tree_replace_nones(0, None)
    0

    Args:
        sentinel (object): The replacement for each :data:`None`.
        tree (pytree): The pytree to transform.
        namespace (str, optional): Registry namespace to consult for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A new pytree shaped like ``tree`` with each :data:`None` replaced by ``sentinel``.
    """
    if tree is not None:
        # Flatten with ``none_is_leaf=True`` so each ``None`` is visited as a leaf and can be
        # substituted; all other leaves pass through unchanged.
        return tree_map(
            lambda leaf: sentinel if leaf is None else leaf,
            tree,
            none_is_leaf=True,
            namespace=namespace,
        )
    # A bare ``None`` root is replaced directly.
    return sentinel
|
|
|
|
|
|
def tree_transpose(
    outer_treespec: PyTreeSpec,
    inner_treespec: PyTreeSpec,
    tree: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
) -> PyTree[T]:  # PyTree[PyTree[T]]
    """Transform a tree having tree structure (outer, inner) into one having structure (inner, outer).

    See also :func:`tree_flatten`, :func:`tree_structure`, and :func:`tree_transpose_map`.

    >>> outer_treespec = tree_structure({'a': 1, 'b': 2, 'c': (3, 4)})
    >>> outer_treespec
    PyTreeSpec({'a': *, 'b': *, 'c': (*, *)})
    >>> inner_treespec = tree_structure((1, 2))
    >>> inner_treespec
    PyTreeSpec((*, *))
    >>> tree = {'a': (1, 2), 'b': (3, 4), 'c': ((5, 6), (7, 8))}
    >>> tree_transpose(outer_treespec, inner_treespec, tree)
    ({'a': 1, 'b': 3, 'c': (5, 7)}, {'a': 2, 'b': 4, 'c': (6, 8)})

    For performance reasons, this function only checks the number of leaves in the input pytree,
    not its structure. The leaves of ``tree`` are taken in their flattened order and regrouped
    according to the leaf counts of the (outer, inner) structures. The caller is responsible for
    ensuring that the input pytree has a prefix structure of ``outer_treespec`` followed by a
    prefix structure of ``inner_treespec``. Otherwise, the result may be incorrect.

    >>> tree_transpose(outer_treespec, inner_treespec, list(range(1, 9)))
    ({'a': 1, 'b': 3, 'c': (5, 7)}, {'a': 2, 'b': 4, 'c': (6, 8)})

    Args:
        outer_treespec (PyTreeSpec): A treespec object representing the outer structure of the pytree.
        inner_treespec (PyTreeSpec): A treespec object representing the inner structure of the pytree.
        tree (pytree): The pytree to transpose.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.

    Returns:
        A new pytree shaped like ``inner_treespec`` whose every leaf is a pytree shaped like
        ``outer_treespec``.

    Raises:
        ValueError: If the two treespecs disagree on ``none_is_leaf``, if either of them has no
            leaves, or if both carry non-empty namespaces that differ.
        TypeError: If the leaf count of ``tree`` is not the product of the leaf counts of
            ``outer_treespec`` and ``inner_treespec``.
    """
    if outer_treespec.none_is_leaf != inner_treespec.none_is_leaf:
        raise ValueError('Tree structures must have the same none_is_leaf value.')

    num_outer = outer_treespec.num_leaves
    num_inner = inner_treespec.num_leaves
    if num_outer == 0 or num_inner == 0:
        raise ValueError('Tree structures must have at least one leaf.')

    outer_namespace = outer_treespec.namespace
    inner_namespace = inner_treespec.namespace
    if outer_namespace and inner_namespace and outer_namespace != inner_namespace:
        raise ValueError(
            'Tree structures must have the same namespace, '
            f'got {outer_namespace!r} vs. {inner_namespace!r}.',
        )

    leaves, treespec = tree_flatten(
        tree,
        is_leaf=is_leaf,
        none_is_leaf=outer_treespec.none_is_leaf,
        # An empty namespace on one side defers to the other side's namespace.
        namespace=outer_namespace or inner_namespace,
    )
    total = num_outer * num_inner
    if treespec.num_leaves != total:
        expected_treespec = outer_treespec.compose(inner_treespec)
        raise TypeError(f'Tree structure mismatch; expected: {expected_treespec}, got: {treespec}.')

    # Chunk the flat leaves into ``num_outer`` rows of ``num_inner`` consecutive leaves each;
    # zipping the rows then yields one column per inner leaf.
    rows = [leaves[start : start + num_inner] for start in range(0, total, num_inner)]
    return inner_treespec.unflatten(  # type: ignore[arg-type]
        outer_treespec.unflatten(column) for column in zip(*rows)
    )
|
|
|
|
|
|
def tree_transpose_map(
    func: Callable[..., PyTree[U]],
    tree: PyTree[T],
    *rests: PyTree[S],
    inner_treespec: PyTreeSpec | None = None,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[U]:  # PyTree[PyTree[U]]
    """Map a multi-input function over pytree args to produce a new pytree with transposed structure.

    See also :func:`tree_map`, :func:`tree_map_with_path`, and :func:`tree_transpose`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
    >>> tree_transpose_map(  # doctest: +IGNORE_WHITESPACE
    ...     lambda x: {'identity': x, 'double': 2 * x},
    ...     tree,
    ... )
    {
        'identity': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)},
        'double': {'b': (4, [6, 8]), 'a': 2, 'c': (10, 12)}
    }
    >>> tree_transpose_map(  # doctest: +IGNORE_WHITESPACE
    ...     lambda x: {'identity': x, 'double': (x, x)},
    ...     tree,
    ... )
    {
        'identity': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)},
        'double': (
            {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)},
            {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
        )
    }
    >>> tree_transpose_map(  # doctest: +IGNORE_WHITESPACE
    ...     lambda x: {'identity': x, 'double': (x, x)},
    ...     tree,
    ...     inner_treespec=tree_structure({'identity': 0, 'double': 0}),
    ... )
    {
        'identity': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)},
        'double': {'b': ((2, 2), [(3, 3), (4, 4)]), 'a': (1, 1), 'c': ((5, 5), (6, 6))}
    }

    Args:
        func (callable): A function of ``1 + len(rests)`` arguments, applied at the corresponding
            leaves of the pytrees.
        tree (pytree): The pytree to map over; each of its leaves is passed as the first positional
            argument of ``func``.
        rests (tuple of pytree): Further pytrees, each with the same structure as ``tree`` or with
            ``tree`` as a prefix.
        inner_treespec (PyTreeSpec, optional): The treespec describing the inner structure of the
            result. When omitted, it is inferred from the value ``func`` returns for the first
            leaf. (default: :data:`None`)
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A nested pytree shaped like ``inner_treespec`` whose every leaf is a pytree shaped like
        ``tree``. Those subtrees are assembled from ``func(x, *xs)``, where ``x`` is a leaf of
        ``tree`` and ``xs`` holds the values at the corresponding nodes of ``rests``.

    Raises:
        ValueError: If ``tree`` has no leaves, or if the inner structure (given or inferred) has
            no leaves.
    """
    leaves, outer_treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    if outer_treespec.num_leaves == 0:
        raise ValueError(f'The outer structure must have at least one leaf. Got: {outer_treespec}.')

    # One column per input pytree, each aligned with the leaves of ``tree``.
    columns = [leaves, *(outer_treespec.flatten_up_to(other) for other in rests)]
    outputs = [func(*leaf_args) for leaf_args in zip(*columns)]

    if inner_treespec is None:
        # The first output is guaranteed to exist: the outer structure has at least one leaf.
        inner_treespec = tree_structure(
            outputs[0],
            is_leaf=is_leaf,  # type: ignore[arg-type]
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        )
    if inner_treespec.num_leaves == 0:
        raise ValueError(f'The inner structure must have at least one leaf. Got: {inner_treespec}.')

    per_output = [inner_treespec.flatten_up_to(output) for output in outputs]
    return inner_treespec.unflatten(  # type: ignore[arg-type]
        outer_treespec.unflatten(column) for column in zip(*per_output)
    )
|
|
|
|
|
|
def tree_transpose_map_with_path(
    func: Callable[..., PyTree[U]],
    tree: PyTree[T],
    *rests: PyTree[S],
    inner_treespec: PyTreeSpec | None = None,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[U]:  # PyTree[PyTree[U]]
    """Map a multi-input function over pytree args as well as the tree paths to produce a new pytree with transposed structure.

    See also :func:`tree_map_with_path`, :func:`tree_transpose_map`, and :func:`tree_transpose`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
    >>> tree_transpose_map_with_path(  # doctest: +IGNORE_WHITESPACE
    ...     lambda p, x: {'depth': len(p), 'value': x},
    ...     tree,
    ... )
    {
        'depth': {'b': (2, [3, 3]), 'a': 1, 'c': (2, 2)},
        'value': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
    }
    >>> tree_transpose_map_with_path(  # doctest: +IGNORE_WHITESPACE
    ...     lambda p, x: {'path': p, 'value': x},
    ...     tree,
    ...     inner_treespec=tree_structure({'path': 0, 'value': 0}),
    ... )
    {
        'path': {'b': (('b', 0), [('b', 1, 0), ('b', 1, 1)]), 'a': ('a',), 'c': (('c', 0), ('c', 1))},
        'value': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
    }

    Args:
        func (callable): A function that takes ``2 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees with extra paths.
        tree (pytree): A pytree to be mapped over, with each leaf providing the second positional
            argument and the corresponding path providing the first positional argument to function
            ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        inner_treespec (PyTreeSpec, optional): The treespec object representing the inner structure
            of the result pytree. If not specified, the inner structure is inferred from the result
            of the function ``func`` on the first leaf. (default: :data:`None`)
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A new nested pytree with the same structure as ``inner_treespec`` but with the value at each
        leaf has the same structure as ``tree``. The subtree at each leaf is given by the result of
        function ``func(p, x, *xs)`` where ``(p, x)`` are the path and value at the corresponding
        leaf in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in ``rests``.

    Raises:
        ValueError: If ``tree`` has no leaves, or if the inner structure (given or inferred) has
            no leaves.
    """  # pylint: disable=line-too-long
    paths, leaves, outer_treespec = _C.flatten_with_path(tree, is_leaf, none_is_leaf, namespace)
    if outer_treespec.num_leaves == 0:
        raise ValueError(f'The outer structure must have at least one leaf. Got: {outer_treespec}.')
    flat_args = [leaves] + [outer_treespec.flatten_up_to(r) for r in rests]
    # ``func`` receives the leaf path first, then the leaf values from ``tree`` and ``rests``.
    outputs = list(map(func, paths, *flat_args))

    if inner_treespec is None:
        # The first output exists because the outer structure has at least one leaf.
        inner_treespec = tree_structure(
            outputs[0],
            is_leaf=is_leaf,  # type: ignore[arg-type]
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        )
    if inner_treespec.num_leaves == 0:
        raise ValueError(f'The inner structure must have at least one leaf. Got: {inner_treespec}.')

    grouped = [inner_treespec.flatten_up_to(o) for o in outputs]
    transposed = zip(*grouped)
    subtrees = map(outer_treespec.unflatten, transposed)
    return inner_treespec.unflatten(subtrees)  # type: ignore[arg-type]
|
|
|
|
|
|
def tree_broadcast_prefix(
    prefix_tree: PyTree[T],
    full_tree: PyTree[S],
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[T]:  # PyTree[PyTree[T]]
    """Return a pytree of same structure of ``full_tree`` with broadcasted subtrees in ``prefix_tree``.

    See also :func:`broadcast_prefix`, :func:`tree_broadcast_common`, and :func:`treespec_is_prefix`.

    ``prefix_tree`` is a prefix of ``full_tree`` when ``full_tree`` can be obtained by substituting
    suitable **subtrees** for the leaves of ``prefix_tree``.

    The result has the same structure as ``full_tree``. Each leaf of ``prefix_tree`` is copied into
    every leaf position of the matching subtree of ``full_tree``.

    >>> tree_broadcast_prefix(1, [2, 3, 4])
    [1, 1, 1]
    >>> tree_broadcast_prefix([1, 2, 3], [4, 5, 6])
    [1, 2, 3]
    >>> tree_broadcast_prefix([1, 2, 3], [4, 5, 6, 7])
    Traceback (most recent call last):
        ...
    ValueError: list arity mismatch; expected: 3, got: 4; list: [4, 5, 6, 7].
    >>> tree_broadcast_prefix([1, 2, 3], [4, 5, (6, 7)])
    [1, 2, (3, 3)]
    >>> tree_broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}])
    [1, 2, {'a': 3, 'b': 3, 'c': (None, 3)}]
    >>> tree_broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}], none_is_leaf=True)
    [1, 2, {'a': 3, 'b': 3, 'c': (3, 3)}]

    Args:
        prefix_tree (pytree): A pytree whose structure is a prefix of ``full_tree``.
        full_tree (pytree): A pytree whose structure has ``prefix_tree`` as a prefix.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A pytree structured like ``full_tree`` whose leaves are replicated from ``prefix_tree``.
    """

    def replicate(leaf: T, subtree: PyTree[S]) -> PyTree[T]:
        # Rebuild the structure of ``subtree`` with each of its leaves replaced by ``leaf``.
        spec = tree_structure(
            subtree,
            is_leaf=is_leaf,  # type: ignore[arg-type]
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        )
        return spec.unflatten([leaf] * spec.num_leaves)

    # NOTE: when ``prefix_tree`` is not a prefix of ``full_tree``, the ``tree_map`` call below
    # raises ``ValueError``. ``prefix_errors`` could be used to locate the disagreements and
    # produce more precise messages; that is not done here.
    return tree_map(
        replicate,  # type: ignore[arg-type]
        prefix_tree,
        full_tree,
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
|
|
|
|
|
|
def broadcast_prefix(
    prefix_tree: PyTree[T],
    full_tree: PyTree[S],
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> list[T]:
    """Return a list of broadcasted leaves in ``prefix_tree`` to match the number of leaves in ``full_tree``.

    See also :func:`tree_broadcast_prefix`, :func:`broadcast_common`, and :func:`treespec_is_prefix`.

    ``prefix_tree`` is a prefix of ``full_tree`` when ``full_tree`` can be obtained by substituting
    suitable **subtrees** for the leaves of ``prefix_tree``.

    The result is a flat list as long as the leaf list of ``full_tree``. Each leaf of
    ``prefix_tree`` is repeated once per leaf of the matching subtree of ``full_tree``.

    >>> broadcast_prefix(1, [2, 3, 4])
    [1, 1, 1]
    >>> broadcast_prefix([1, 2, 3], [4, 5, 6])
    [1, 2, 3]
    >>> broadcast_prefix([1, 2, 3], [4, 5, 6, 7])
    Traceback (most recent call last):
        ...
    ValueError: list arity mismatch; expected: 3, got: 4; list: [4, 5, 6, 7].
    >>> broadcast_prefix([1, 2, 3], [4, 5, (6, 7)])
    [1, 2, 3, 3]
    >>> broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}])
    [1, 2, 3, 3, 3]
    >>> broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}], none_is_leaf=True)
    [1, 2, 3, 3, 3, 3]

    Args:
        prefix_tree (pytree): A pytree whose structure is a prefix of ``full_tree``.
        full_tree (pytree): A pytree whose structure has ``prefix_tree`` as a prefix.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        The leaves of ``prefix_tree``, each repeated to cover the leaves of ``full_tree``.
    """
    broadcasted: list[T] = []

    def collect(leaf: T, subtree: PyTree[S]) -> None:
        # Repeat ``leaf`` once for every leaf hanging under the matching node of ``full_tree``.
        count = tree_structure(
            subtree,
            is_leaf=is_leaf,  # type: ignore[arg-type]
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        ).num_leaves
        broadcasted.extend([leaf] * count)

    # NOTE: when ``prefix_tree`` is not a prefix of ``full_tree``, ``tree_map_`` raises
    # ``ValueError``; ``prefix_errors`` could give more precise diagnostics but is not used here.
    tree_map_(
        collect,
        prefix_tree,
        full_tree,
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    return broadcasted
|
|
|
|
|
|
def tree_broadcast_common(
    tree: PyTree[T],
    other_tree: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> tuple[PyTree[T], PyTree[T]]:
    """Return two pytrees of common suffix structure of ``tree`` and ``other_tree`` with broadcasted subtrees.

    See also :func:`broadcast_common`, :func:`tree_broadcast_prefix`, and :func:`treespec_is_prefix`.

    ``suffix_tree`` is a suffix of ``tree`` when ``suffix_tree`` can be obtained by substituting
    suitable **subtrees** for the leaves of ``tree``.

    Both returned pytrees share one structure: the common suffix structure of ``tree`` and
    ``other_tree``. Their leaves are replicated from ``tree`` and ``other_tree`` respectively, once
    per leaf of the matching subtree in that suffix structure.

    >>> tree_broadcast_common(1, [2, 3, 4])
    ([1, 1, 1], [2, 3, 4])
    >>> tree_broadcast_common([1, 2, 3], [4, 5, 6])
    ([1, 2, 3], [4, 5, 6])
    >>> tree_broadcast_common([1, 2, 3], [4, 5, 6, 7])
    Traceback (most recent call last):
        ...
    ValueError: list arity mismatch; expected: 3, got: 4.
    >>> tree_broadcast_common([1, (2, 3), 4], [5, 6, (7, 8)])
    ([1, (2, 3), (4, 4)], [5, (6, 6), (7, 8)])
    >>> tree_broadcast_common([1, {'a': (2, 3)}, 4], [5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}])
    ([1, {'a': (2, 3)}, {'a': 4, 'b': 4, 'c': (None, 4)}],
     [5, {'a': (6, 6)}, {'a': 7, 'b': 8, 'c': (None, 9)}])
    >>> tree_broadcast_common([1, {'a': (2, 3)}, 4], [5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}], none_is_leaf=True)
    ([1, {'a': (2, 3)}, {'a': 4, 'b': 4, 'c': (4, 4)}],
     [5, {'a': (6, 6)}, {'a': 7, 'b': 8, 'c': (None, 9)}])
    >>> tree_broadcast_common([1, None], [None, 2])
    ([None, None], [None, None])
    >>> tree_broadcast_common([1, None], [None, 2], none_is_leaf=True)
    ([1, None], [None, 2])

    Args:
        tree (pytree): A pytree that shares a common suffix structure with ``other_tree``.
        other_tree (pytree): A pytree that shares a common suffix structure with ``tree``.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A pair of pytrees, both with the common suffix structure of ``tree`` and ``other_tree``,
        holding the broadcasted leaves of ``tree`` and ``other_tree`` respectively.
    """
    leaves, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    other_leaves, other_treespec = _C.flatten(other_tree, is_leaf, none_is_leaf, namespace)
    suffix_treespec = treespec.broadcast_to_common_suffix(other_treespec)

    # Materialize the common suffix structure with placeholder leaves so that ``flatten_up_to``
    # can cut it into the subtree sitting under each leaf of either input.
    placeholder: T = object()  # type: ignore[assignment]
    suffix_tree: PyTree[T] = suffix_treespec.unflatten(
        itertools.repeat(placeholder, suffix_treespec.num_leaves),
    )

    def expand(spec: PyTreeSpec, flat: list[T]) -> PyTree[T]:
        # Swap each leaf in ``flat`` for its matching suffix subtree filled with copies of that
        # leaf, then reassemble the result according to ``spec``.
        expanded = []
        for leaf, subtree in zip(flat, spec.flatten_up_to(suffix_tree)):
            subtreespec = tree_structure(
                subtree,
                is_leaf=is_leaf,
                none_is_leaf=none_is_leaf,
                namespace=namespace,
            )
            expanded.append(subtreespec.unflatten(itertools.repeat(leaf, subtreespec.num_leaves)))
        return spec.unflatten(expanded)

    return expand(treespec, leaves), expand(other_treespec, other_leaves)
|
|
|
|
|
|
def broadcast_common(
    tree: PyTree[T],
    other_tree: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> tuple[list[T], list[T]]:
    """Return two lists of leaves in ``tree`` and ``other_tree`` broadcasted to match the number of leaves in the common suffix structure.

    See also :func:`tree_broadcast_common`, :func:`broadcast_prefix`, and :func:`treespec_is_prefix`.

    If a ``suffix_tree`` is a suffix of a ``tree``, this means the ``suffix_tree`` can be
    constructed by replacing the leaves of ``tree`` with appropriate **subtrees**.

    This function returns two flat lists of leaves of equal length. That length is the number of
    leaves in the common suffix structure of ``tree`` and ``other_tree``. The leaves are replicated
    from ``tree`` and ``other_tree``. The number of replicas is determined by the corresponding
    subtree in the suffix structure.

    >>> broadcast_common(1, [2, 3, 4])
    ([1, 1, 1], [2, 3, 4])
    >>> broadcast_common([1, 2, 3], [4, 5, 6])
    ([1, 2, 3], [4, 5, 6])
    >>> broadcast_common([1, 2, 3], [4, 5, 6, 7])
    Traceback (most recent call last):
        ...
    ValueError: list arity mismatch; expected: 3, got: 4.
    >>> broadcast_common([1, (2, 3), 4], [5, 6, (7, 8)])
    ([1, 2, 3, 4, 4], [5, 6, 6, 7, 8])
    >>> broadcast_common([1, {'a': (2, 3)}, 4], [5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}])
    ([1, 2, 3, 4, 4, 4], [5, 6, 6, 7, 8, 9])
    >>> broadcast_common([1, {'a': (2, 3)}, 4], [5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}], none_is_leaf=True)
    ([1, 2, 3, 4, 4, 4, 4], [5, 6, 6, 7, 8, None, 9])
    >>> broadcast_common([1, None], [None, 2])
    ([], [])
    >>> broadcast_common([1, None], [None, 2], none_is_leaf=True)
    ([1, None], [None, 2])

    Args:
        tree (pytree): A pytree that shares a common suffix structure with ``other_tree``.
        other_tree (pytree): A pytree that shares a common suffix structure with ``tree``.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        Two lists of leaves in ``tree`` and ``other_tree`` broadcasted to match the number of leaves
        in the common suffix structure.
    """  # pylint: disable=line-too-long
    broadcasted_tree, other_broadcasted_tree = tree_broadcast_common(
        tree,
        other_tree,
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )

    broadcasted_leaves: list[T] = []
    other_broadcasted_leaves: list[T] = []

    def add_leaves(x: T, y: T) -> None:
        # Both broadcasted trees share one structure, so their leaves arrive here pairwise.
        broadcasted_leaves.append(x)
        other_broadcasted_leaves.append(y)

    tree_map_(
        add_leaves,
        broadcasted_tree,
        other_broadcasted_tree,
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    return broadcasted_leaves, other_broadcasted_leaves
|
|
|
|
|
|
# pylint: disable-next=too-many-locals
def tree_broadcast_map(
    func: Callable[..., U],
    tree: PyTree[T],
    *rests: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[U]:
    """Map a multi-input function over pytree args to produce a new pytree.

    See also :func:`tree_broadcast_map_with_path`, :func:`tree_map`, :func:`tree_map_`,
    and :func:`tree_map_with_path`.

    With a single input, this function behaves exactly like :func:`tree_map`:

    >>> tree_broadcast_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
    {'x': 8, 'y': (43, 65)}
    >>> tree_broadcast_map(lambda x: x + 1, {'x': 7, 'y': (42, 64), 'z': None})
    {'x': 8, 'y': (43, 65), 'z': None}
    >>> tree_broadcast_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
    {'x': False, 'y': (False, False), 'z': None}
    >>> tree_broadcast_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}, none_is_leaf=True)
    {'x': False, 'y': (False, False), 'z': True}

    With several inputs, every input is first broadcasted to the common suffix structure of all
    of them:

    >>> tree_broadcast_map(lambda x, y: x * y, [5, 6, (3, 4)], [{'a': 7, 'b': 9}, [1, 2], 8])
    [{'a': 35, 'b': 45}, [6, 12], (24, 32)]

    Args:
        func (callable): A function of ``1 + len(rests)`` arguments, applied at the corresponding
            leaves of the pytrees.
        tree (pytree): The pytree to map over; each of its (possibly broadcasted) leaves is passed
            as the first positional argument of ``func``.
        rests (tuple of pytree): Further pytrees that share a common suffix structure with each
            other and with ``tree``.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A new pytree with the common suffix structure of ``tree`` and ``rests``, whose leaves are
        ``func(x, *xs)``, with ``x`` the (possibly broadcasted) leaf of ``tree`` and ``xs`` the
        (possibly broadcasted) leaves of ``rests`` at the same position.
    """
    aligned_tree = tree
    aligned_rests = list(rests)
    # The first sweep narrows ``aligned_tree`` down to the common suffix of all inputs, but rests
    # visited early were only aligned against an intermediate structure; the second sweep
    # re-broadcasts each rest against the final one. With no rests, both sweeps are no-ops and
    # the call below reduces to a plain ``tree_map(func, tree, ...)``.
    for _sweep in range(2):
        for position, rest in enumerate(rests):
            aligned_tree, aligned_rests[position] = tree_broadcast_common(
                aligned_tree,
                rest,
                is_leaf=is_leaf,
                none_is_leaf=none_is_leaf,
                namespace=namespace,
            )

    return tree_map(
        func,
        aligned_tree,
        *aligned_rests,
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
|
|
|
|
|
|
# pylint: disable-next=too-many-locals
def tree_broadcast_map_with_path(
    func: Callable[..., U],
    tree: PyTree[T],
    *rests: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[U]:
    """Map a multi-input function over pytree args as well as the tree paths to produce a new pytree.

    See also :func:`tree_broadcast_map`, :func:`tree_map`, :func:`tree_map_`,
    and :func:`tree_map_with_path`.

    If only one input is provided, this function is the same as :func:`tree_map_with_path`:

    >>> tree_broadcast_map_with_path(lambda p, x: (len(p), x), {'x': 7, 'y': (42, 64)})
    {'x': (1, 7), 'y': ((2, 42), (2, 64))}
    >>> tree_broadcast_map_with_path(lambda p, x: x + len(p), {'x': 7, 'y': (42, 64), 'z': None})
    {'x': 8, 'y': (44, 66), 'z': None}
    >>> tree_broadcast_map_with_path(lambda p, x: p, {'x': 7, 'y': (42, 64), 'z': {1.5: None}})
    {'x': ('x',), 'y': (('y', 0), ('y', 1)), 'z': {1.5: None}}
    >>> tree_broadcast_map_with_path(lambda p, x: p, {'x': 7, 'y': (42, 64), 'z': {1.5: None}}, none_is_leaf=True)
    {'x': ('x',), 'y': (('y', 0), ('y', 1)), 'z': {1.5: ('z', 1.5)}}

    If multiple inputs are given, all input trees will be broadcasted to the common suffix structure
    of all inputs:

    >>> tree_broadcast_map_with_path(lambda p, x, y: (p, x * y), [5, 6, (3, 4)], [{'a': 7, 'b': 9}, [1, 2], 8])
    [{'a': ((0, 'a'), 35), 'b': ((0, 'b'), 45)},
     [((1, 0), 6), ((1, 1), 12)],
     (((2, 0), 24), ((2, 1), 32))]

    Args:
        func (callable): A function that takes ``2 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees with extra paths.
        tree (pytree): A pytree to be mapped over, with each leaf providing the second positional
            argument and the corresponding path providing the first positional argument to function
            ``func``.
        rests (tuple of pytree): A tuple of pytrees, they should have a common suffix structure with
            each other and with ``tree``.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A new pytree with the structure as the common suffix structure of ``tree`` and ``rests`` but
        with the value at each leaf given by ``func(p, x, *xs)`` where ``(p, x)`` are the path and
        value at the corresponding leaf (may be broadcasted) in ``tree`` and ``xs`` is the tuple of
        values at corresponding leaves (may be broadcasted) in ``rests``.
    """
    if not rests:
        return tree_map_with_path(
            func,
            tree,
            is_leaf=is_leaf,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        )

    broadcasted_tree = tree
    broadcasted_rests = list(rests)
    # The first pass narrows ``broadcasted_tree`` to the common suffix of all inputs; the second
    # pass re-broadcasts every rest against that final structure (rests handled early in the
    # first pass were only aligned with an intermediate structure).
    for _ in range(2):
        for i, rest in enumerate(rests):
            broadcasted_tree, broadcasted_rests[i] = tree_broadcast_common(
                broadcasted_tree,
                rest,
                is_leaf=is_leaf,
                none_is_leaf=none_is_leaf,
                namespace=namespace,
            )

    return tree_map_with_path(
        func,
        broadcasted_tree,
        *broadcasted_rests,
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
|
|
|
|
|
|
# pylint: disable-next=too-few-public-methods
class MissingSentinel:  # pragma: no cover
    """Type of the private placeholder that marks an omitted optional argument.

    A dedicated object is compared by identity (instead of using :data:`None`) so that
    :data:`None` stays a legitimate caller-supplied value for optional parameters.
    """

    def __repr__(self) -> str:
        label = '<MISSING>'
        return label


__MISSING: T = MissingSentinel()  # type: ignore[valid-type]
# Only the single instance is needed; remove the helper class from the module namespace.
del MissingSentinel
|
|
|
|
|
|
@overload
def tree_reduce(
    func: Callable[[T, T], T],
    tree: PyTree[T],
    *,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> T:  # pragma: no cover
    ...


@overload
def tree_reduce(
    func: Callable[[T, S], T],
    tree: PyTree[S],
    initial: T = __MISSING,
    *,
    is_leaf: Callable[[S], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> T:  # pragma: no cover
    ...


def tree_reduce(
    func,
    tree,
    initial=__MISSING,
    *,
    is_leaf=None,
    none_is_leaf=False,
    namespace='',
):
    """Traverse a pytree and fold its leaves with ``func`` in left-to-right depth-first order.

    See also :func:`tree_leaves` and :func:`tree_sum`.

    >>> tree_reduce(lambda x, y: x + y, {'x': 1, 'y': (2, 3)})
    6
    >>> tree_reduce(lambda x, y: x + y, {'x': 1, 'y': (2, None), 'z': 3}) # `None` is a non-leaf node with arity 0 by default
    6
    >>> tree_reduce(lambda x, y: x and y, {'x': 1, 'y': (2, None), 'z': 3})
    3
    >>> tree_reduce(lambda x, y: x and y, {'x': 1, 'y': (2, None), 'z': 3}, none_is_leaf=True)
    None

    Args:
        func (callable): A function that takes two arguments (the accumulated value and the next
            leaf) and returns the new accumulated value.
        tree (pytree): A pytree to be traversed.
        initial (object, optional): An initial value to be used for the reduction. If not provided,
            the first leaf value is used as the initial value.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and is never passed to ``func``.
            (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        The result of reducing the leaves of the pytree using ``func``.

    Raises:
        TypeError: If ``tree`` has no leaves and ``initial`` is not provided (raised by
            :func:`functools.reduce`).
    """  # pylint: disable=line-too-long
    leaves = tree_leaves(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
    # Forward ``initial`` only when the caller supplied it; ``None`` is a valid explicit start value.
    reduce_args = (leaves,) if initial is __MISSING else (leaves, initial)
    return functools.reduce(func, *reduce_args)
|
|
|
|
|
|
def tree_sum(
    tree: PyTree[T],
    start: T = 0,  # type: ignore[assignment]
    *,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> T:
    """Sum ``start`` and leaf values in ``tree`` in left-to-right depth-first order and return the total.

    See also :func:`tree_leaves` and :func:`tree_reduce`.

    >>> tree_sum({'x': 1, 'y': (2, 3)})
    6
    >>> tree_sum({'x': 1, 'y': (2, None), 'z': 3})  # `None` is a non-leaf node with arity 0 by default
    6
    >>> tree_sum({'x': 1, 'y': (2, None), 'z': 3}, none_is_leaf=True)
    Traceback (most recent call last):
        ...
    TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'
    >>> tree_sum({'x': 'a', 'y': ('b', None), 'z': 'c'}, start='')
    'abc'
    >>> tree_sum({'x': b'b', 'y': b'c'}, start=bytearray(b'a'))
    bytearray(b'abc')
    >>> tree_sum({'x': [1], 'y': ([2], [None]), 'z': [3]}, start=[], is_leaf=lambda x: isinstance(x, list))
    [1, 2, None, 3]

    Args:
        tree (pytree): A pytree to be traversed.
        start (object, optional): An initial value to be used for the sum. For :class:`str`,
            :class:`bytes`, and :class:`bytearray` values the leaves are concatenated and the
            result has the same type as ``start``. (default: :data:`0`)
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and is not added to the total.
            (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        The total sum of ``start`` and leaf values in ``tree``.
    """  # pylint: disable=line-too-long
    leaves = tree_leaves(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
    # sum() rejects str, bytes, and bytearray values for its `start` parameter, so concatenate with
    # the matching join method instead. The separator decides the result type, so a bytearray
    # `start` yields a bytearray, just as `start + leaf + ...` would.
    if isinstance(start, str):
        return ''.join([start, *leaves])  # type: ignore[list-item,return-value]
    if isinstance(start, bytearray):
        return bytearray().join([start, *leaves])  # type: ignore[list-item,return-value]
    if isinstance(start, bytes):
        return b''.join([start, *leaves])  # type: ignore[list-item,return-value]
    return sum(leaves, start)  # type: ignore[call-overload]
|
|
|
|
|
|
@overload
def tree_max(
    tree: PyTree[T],
    *,
    is_leaf: Callable[[T], bool] | None = None,
    key: Callable[[T], Any] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> T:  # pragma: no cover
    ...


@overload
def tree_max(
    tree: PyTree[T],
    *,
    default: T = __MISSING,
    key: Callable[[T], Any] | None = None,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> T:  # pragma: no cover
    ...


def tree_max(tree, *, default=__MISSING, key=None, is_leaf=None, none_is_leaf=False, namespace=''):
    """Return the largest leaf value in ``tree``.

    See also :func:`tree_leaves` and :func:`tree_min`.

    >>> tree_max({})
    Traceback (most recent call last):
        ...
    ValueError: max() arg is an empty sequence
    >>> tree_max({}, default=0)
    0
    >>> tree_max({'x': 0, 'y': (2, 1)})
    2
    >>> tree_max({'x': 0, 'y': (2, 1)}, key=lambda x: -x)
    0
    >>> tree_max({'a': None})  # `None` is a non-leaf node with arity 0 by default
    Traceback (most recent call last):
        ...
    ValueError: max() arg is an empty sequence
    >>> tree_max({'a': None}, default=0)  # `None` is a non-leaf node with arity 0 by default
    0
    >>> tree_max({'a': None}, none_is_leaf=True)
    None
    >>> tree_max(None)  # `None` is a non-leaf node with arity 0 by default
    Traceback (most recent call last):
        ...
    ValueError: max() arg is an empty sequence
    >>> tree_max(None, default=0)
    0
    >>> tree_max(None, none_is_leaf=True)
    None

    Args:
        tree (pytree): A pytree to be traversed.
        default (object, optional): The value to return if ``tree`` has no leaves. If ``tree`` has
            no leaves and ``default`` is not specified, a :exc:`ValueError` is raised.
        key (callable or None, optional): A one-argument ordering function like that used for
            :meth:`list.sort`.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and is never compared. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        The maximum leaf value in ``tree``.
    """
    leaves = tree_leaves(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
    # Forward only the options the caller actually set: leaving out ``default`` keeps max()'s
    # ValueError on empty input, and never passing ``key=None`` keeps Python 3.7 compatible.
    options = {}
    if default is not __MISSING:
        options['default'] = default
    if key is not None:
        options['key'] = key
    return max(leaves, **options)
|
|
|
|
|
|
@overload
def tree_min(
    tree: PyTree[T],
    *,
    key: Callable[[T], Any] | None = None,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> T:  # pragma: no cover
    ...


@overload
def tree_min(
    tree: PyTree[T],
    *,
    default: T = __MISSING,
    key: Callable[[T], Any] | None = None,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> T:  # pragma: no cover
    ...


def tree_min(tree, *, default=__MISSING, key=None, is_leaf=None, none_is_leaf=False, namespace=''):
    """Return the smallest leaf value in ``tree``.

    See also :func:`tree_leaves` and :func:`tree_max`.

    >>> tree_min({})
    Traceback (most recent call last):
        ...
    ValueError: min() arg is an empty sequence
    >>> tree_min({}, default=0)
    0
    >>> tree_min({'x': 0, 'y': (2, 1)})
    0
    >>> tree_min({'x': 0, 'y': (2, 1)}, key=lambda x: -x)
    2
    >>> tree_min({'a': None})  # `None` is a non-leaf node with arity 0 by default
    Traceback (most recent call last):
        ...
    ValueError: min() arg is an empty sequence
    >>> tree_min({'a': None}, default=0)  # `None` is a non-leaf node with arity 0 by default
    0
    >>> tree_min({'a': None}, none_is_leaf=True)
    None
    >>> tree_min(None)  # `None` is a non-leaf node with arity 0 by default
    Traceback (most recent call last):
        ...
    ValueError: min() arg is an empty sequence
    >>> tree_min(None, default=0)
    0
    >>> tree_min(None, none_is_leaf=True)
    None

    Args:
        tree (pytree): A pytree to be traversed.
        default (object, optional): The value to return if ``tree`` has no leaves. If ``tree`` has
            no leaves and ``default`` is not specified, a :exc:`ValueError` is raised.
        key (callable or None, optional): A one-argument ordering function like that used for
            :meth:`list.sort`.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and is never compared. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        The minimum leaf value in ``tree``.
    """
    leaves = tree_leaves(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
    # Forward only the options the caller actually set: leaving out ``default`` keeps min()'s
    # ValueError on empty input, and never passing ``key=None`` keeps Python 3.7 compatible.
    options = {}
    if default is not __MISSING:
        options['default'] = default
    if key is not None:
        options['key'] = key
    return min(leaves, **options)
|
|
|
|
|
|
def tree_all(
    tree: PyTree[T],
    *,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> bool:
    """Check that every leaf in ``tree`` is truthy (vacuously :data:`True` for an empty ``tree``).

    See also :func:`tree_leaves` and :func:`tree_any`.

    >>> tree_all({})
    True
    >>> tree_all({'x': 1, 'y': (2, 3)})
    True
    >>> tree_all({'x': 1, 'y': (2, None), 'z': 3})  # `None` is a non-leaf node by default
    True
    >>> tree_all({'x': 1, 'y': (2, None), 'z': 3}, none_is_leaf=True)
    False
    >>> tree_all(None)  # `None` is a non-leaf node by default
    True
    >>> tree_all(None, none_is_leaf=True)
    False

    Args:
        tree (pytree): A pytree to be traversed.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and is not tested. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        :data:`True` if all leaves in ``tree`` are true, or if ``tree`` is empty.
        Otherwise, :data:`False`.
    """
    # Stop at the first falsy leaf so the remaining leaves are never visited.
    for leaf in tree_iter(
        tree,  # type: ignore[arg-type]
        is_leaf=is_leaf,  # type: ignore[arg-type]
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    ):
        if not leaf:
            return False
    return True
|
|
|
|
|
|
def tree_any(
    tree: PyTree[T],
    *,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> bool:
    """Test whether any leaf in ``tree`` is true (or :data:`False` if ``tree`` is empty).

    See also :func:`tree_leaves` and :func:`tree_all`.

    >>> tree_any({})
    False
    >>> tree_any({'x': 0, 'y': (2, 0)})
    True
    >>> tree_any({'a': None})  # `None` is a non-leaf node with arity 0 by default
    False
    >>> tree_any({'a': None}, none_is_leaf=True)  # `None` is evaluated as false
    False
    >>> tree_any(None)  # `None` is a non-leaf node with arity 0 by default
    False
    >>> tree_any(None, none_is_leaf=True)  # `None` is evaluated as false
    False

    Args:
        tree (pytree): A pytree to be traversed.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and is not tested. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        :data:`True` if any leaf in ``tree`` is true, otherwise, :data:`False`. If ``tree`` is
        empty, return :data:`False`.
    """
    # any() short-circuits on the first truthy leaf produced by the lazy iterator.
    return any(
        tree_iter(
            tree,  # type: ignore[arg-type]
            is_leaf=is_leaf,  # type: ignore[arg-type]
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        ),
    )
|
|
|
|
|
|
def tree_flatten_one_level(
    tree: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> tuple[
    list[PyTree[T]],
    MetaData,
    tuple[Any, ...],
    Callable[[MetaData, list[PyTree[T]]], PyTree[T]],
]:
    """Flatten the pytree one level, returning a 4-tuple of children, auxiliary data, path entries, and an unflatten function.

    See also :func:`tree_flatten`, :func:`tree_flatten_with_path`.

    >>> children, metadata, entries, unflatten_func = tree_flatten_one_level({'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5})
    >>> children, metadata, entries
    ([1, (2, [3, 4]), None, 5], ['a', 'b', 'c', 'd'], ('a', 'b', 'c', 'd'))
    >>> unflatten_func(metadata, children)
    {'a': 1, 'b': (2, [3, 4]), 'c': None, 'd': 5}
    >>> children, metadata, entries, unflatten_func = tree_flatten_one_level([{'a': 1, 'b': (2, 3)}, (4, 5)])
    >>> children, metadata, entries
    ([{'a': 1, 'b': (2, 3)}, (4, 5)], None, (0, 1))
    >>> unflatten_func(metadata, children)
    [{'a': 1, 'b': (2, 3)}, (4, 5)]

    Args:
        tree (pytree): A pytree to be traversed.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`True`,
            passing :data:`None` as ``tree`` raises :exc:`ValueError`; if :data:`False`,
            :data:`None` is a non-leaf node with arity 0. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A 4-tuple ``(children, metadata, entries, unflatten_func)``. The first element is a list of
        one-level children of the pytree node. The second element is the auxiliary data used to
        reconstruct the pytree node. The third element is a tuple of path entries to the children.
        The fourth element is a function that can be used to unflatten the auxiliary data and
        children back to the pytree node.

    Raises:
        ValueError: If ``tree`` is treated as a leaf (via ``none_is_leaf`` or ``is_leaf``), or no
            node handler is found for its type in ``namespace``.
        RuntimeError: If a custom flatten function returns something other than a 2- or 3-tuple,
            or its number of entries differs from its number of children.
    """  # pylint: disable=line-too-long
    node_type = type(tree)
    treated_as_leaf = (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree))  # type: ignore[unreachable,arg-type]

    # The registry is only consulted for non-leaf inputs; a missing handler means "leaf" as well.
    handler: PyTreeNodeRegistryEntry | None = None
    if not treated_as_leaf:
        handler = register_pytree_node.get(node_type, namespace=namespace)  # type: ignore[attr-defined]
    if not handler:
        raise ValueError(f'Cannot flatten leaf-type: {node_type} (node: {tree!r}).')

    # Custom flatten functions may omit the path entries (2-tuple) or provide them (3-tuple).
    flat = tuple(handler.flatten_func(tree))  # type: ignore[arg-type]
    if len(flat) == 2:
        raw_children, metadata = flat
        raw_entries = None
    elif len(flat) == 3:
        raw_children, metadata, raw_entries = flat
    else:
        raise RuntimeError(
            f'PyTree custom flatten function for type {node_type} should return a 2- or 3-tuple, '
            f'got {len(flat)}.',
        )

    children = list(raw_children)  # type: ignore[arg-type]
    # Without explicit entries, children are addressed by their positional index.
    entries = tuple(range(len(children))) if raw_entries is None else tuple(raw_entries)
    if len(entries) != len(children):
        raise RuntimeError(
            f'PyTree custom flatten function for type {node_type} returned inconsistent '
            f'number of children ({len(children)}) and number of entries ({len(entries)}).',
        )
    return children, metadata, entries, handler.unflatten_func  # type: ignore[return-value]
|
|
|
|
|
|
def treespec_paths(treespec: PyTreeSpec) -> list[tuple[Any, ...]]:
    """Collect the path from the root to every leaf of ``treespec``.

    This is a functional spelling of :meth:`PyTreeSpec.paths`; each path is a tuple of entries.

    See also :func:`tree_flatten_with_path`, :func:`tree_paths`, and :meth:`PyTreeSpec.paths`.
    """
    leaf_paths = treespec.paths()
    return leaf_paths
|
|
|
|
|
|
def treespec_entries(treespec: PyTreeSpec) -> list[Any]:
    """List the one-level entries that lead from ``treespec`` to each of its direct children.

    This is a functional spelling of :meth:`PyTreeSpec.entries`.

    See also :func:`treespec_entry`, :func:`treespec_paths`, :func:`treespec_children`,
    and :meth:`PyTreeSpec.entries`.
    """
    child_entries = treespec.entries()
    return child_entries
|
|
|
|
|
|
def treespec_entry(treespec: PyTreeSpec, index: int) -> Any:
    """Look up the single entry of ``treespec`` at position ``index``.

    This is a functional spelling of :meth:`PyTreeSpec.entry`.

    See also :func:`treespec_entries`, :func:`treespec_children`, and :meth:`PyTreeSpec.entry`.
    """
    selected = treespec.entry(index)
    return selected
|
|
|
|
|
|
def treespec_children(treespec: PyTreeSpec) -> list[PyTreeSpec]:
    """List the treespecs of the direct children of ``treespec``.

    This is a functional spelling of :meth:`PyTreeSpec.children`.

    See also :func:`treespec_child`, :func:`treespec_paths`, :func:`treespec_entries`,
    and :meth:`PyTreeSpec.children`.
    """
    child_specs = treespec.children()
    return child_specs
|
|
|
|
|
|
def treespec_child(treespec: PyTreeSpec, index: int) -> PyTreeSpec:
    """Look up the treespec of the direct child of ``treespec`` at position ``index``.

    This is a functional spelling of :meth:`PyTreeSpec.child`.

    See also :func:`treespec_children`, :func:`treespec_entries`, and :meth:`PyTreeSpec.child`.
    """
    child_spec = treespec.child(index)
    return child_spec
|
|
|
|
|
|
def treespec_is_leaf(treespec: PyTreeSpec, strict: bool = True) -> bool:
    """Return whether the treespec is a leaf that has no children.

    See also :func:`treespec_is_strict_leaf` and :meth:`PyTreeSpec.is_leaf`.

    This function is equivalent to ``treespec.is_leaf(strict=strict)``. If ``strict=True``, it will
    return :data:`True` if and only if the treespec represents a strict leaf. If ``strict=False``,
    it will return :data:`True` if the treespec represents a strict leaf or :data:`None` or an empty
    container (e.g., an empty tuple).

    >>> treespec_is_leaf(tree_structure(1))
    True
    >>> treespec_is_leaf(tree_structure((1, 2)))
    False
    >>> treespec_is_leaf(tree_structure(None))
    False
    >>> treespec_is_leaf(tree_structure(None), strict=False)
    True
    >>> treespec_is_leaf(tree_structure(None, none_is_leaf=False))
    False
    >>> treespec_is_leaf(tree_structure(None, none_is_leaf=True))
    True
    >>> treespec_is_leaf(tree_structure(()))
    False
    >>> treespec_is_leaf(tree_structure(()), strict=False)
    True
    >>> treespec_is_leaf(tree_structure([]))
    False
    >>> treespec_is_leaf(tree_structure([]), strict=False)
    True

    Args:
        treespec (PyTreeSpec): A treespec.
        strict (bool, optional): If :data:`True`, only a strict leaf (a single node holding one
            leaf) counts. If :data:`False`, any single-node treespec counts, which additionally
            includes :data:`None` and empty containers (e.g., an empty tuple).
            (default: :data:`True`)

    Returns:
        :data:`True` if the treespec represents a leaf that has no children, otherwise, :data:`False`.
    """
    if strict:
        # A strict leaf is a single node that carries exactly one leaf value.
        return treespec.num_nodes == 1 and treespec.num_leaves == 1
    # Non-strict: any single-node treespec (leaf, `None` node, or empty container) has no children.
    return treespec.num_nodes == 1
|
|
|
|
|
|
def treespec_is_strict_leaf(treespec: PyTreeSpec) -> bool:
    """Check whether ``treespec`` describes exactly one leaf and nothing else.

    See also :func:`treespec_is_leaf` and :meth:`PyTreeSpec.is_leaf`.

    The treespec's own ``none_is_leaf`` setting is honored. This is the same test as
    ``treespec.is_leaf(strict=True)``: the result is :data:`True` only for a strict leaf, so
    :data:`None` nodes (with ``none_is_leaf=False``) and empty containers do not qualify.

    >>> treespec_is_strict_leaf(tree_structure(1))
    True
    >>> treespec_is_strict_leaf(tree_structure((1, 2)))
    False
    >>> treespec_is_strict_leaf(tree_structure(None))
    False
    >>> treespec_is_strict_leaf(tree_structure(None, none_is_leaf=False))
    False
    >>> treespec_is_strict_leaf(tree_structure(None, none_is_leaf=True))
    True
    >>> treespec_is_strict_leaf(tree_structure(()))
    False
    >>> treespec_is_strict_leaf(tree_structure([]))
    False

    Args:
        treespec (PyTreeSpec): A treespec.

    Returns:
        :data:`True` if the treespec represents a strict leaf, otherwise, :data:`False`.
    """
    has_single_node = treespec.num_nodes == 1
    # The single node must also carry a leaf value (rules out `None` nodes and empty containers).
    return has_single_node and treespec.num_leaves == 1
|
|
|
|
|
|
def treespec_is_prefix(
    treespec: PyTreeSpec,
    other_treespec: PyTreeSpec,
    strict: bool = False,
) -> bool:
    """Return whether ``treespec`` is a prefix of ``other_treespec``.

    See also :func:`treespec_is_suffix` and :meth:`PyTreeSpec.is_prefix`.

    Args:
        treespec (PyTreeSpec): The candidate prefix treespec.
        other_treespec (PyTreeSpec): The treespec to compare against.
        strict (bool, optional): Forwarded unchanged to :meth:`PyTreeSpec.is_prefix`; presumably
            requests a strict (proper) prefix check — see that method's documentation.
            (default: :data:`False`)

    Returns:
        The result of ``treespec.is_prefix(other_treespec, strict=strict)``.
    """
    return treespec.is_prefix(other_treespec, strict=strict)
|
|
|
|
|
|
def treespec_is_suffix(
    treespec: PyTreeSpec,
    other_treespec: PyTreeSpec,
    strict: bool = False,
) -> bool:
    """Return whether ``treespec`` is a suffix of ``other_treespec``.

    See also :func:`treespec_is_prefix` and :meth:`PyTreeSpec.is_suffix`.

    Args:
        treespec (PyTreeSpec): The candidate suffix treespec.
        other_treespec (PyTreeSpec): The treespec to compare against.
        strict (bool, optional): Forwarded unchanged to :meth:`PyTreeSpec.is_suffix`; presumably
            requests a strict (proper) suffix check — see that method's documentation.
            (default: :data:`False`)

    Returns:
        The result of ``treespec.is_suffix(other_treespec, strict=strict)``.
    """
    return treespec.is_suffix(other_treespec, strict=strict)
|
|
|
|
|
|
def treespec_leaf(
    *,
    none_is_leaf: bool = False,
    namespace: str = '',  # unused
) -> PyTreeSpec:
    """Build the treespec of a single leaf node.

    See also :func:`tree_structure`, :func:`treespec_none`, and :func:`treespec_tuple`.

    >>> treespec_leaf()
    PyTreeSpec(*)
    >>> treespec_leaf(none_is_leaf=True)
    PyTreeSpec(*, NoneIsLeaf)
    >>> treespec_leaf(none_is_leaf=False) == treespec_leaf(none_is_leaf=True)
    False
    >>> treespec_leaf() == tree_structure(1)
    True
    >>> treespec_leaf(none_is_leaf=True) == tree_structure(1, none_is_leaf=True)
    True
    >>> treespec_leaf(none_is_leaf=True) == tree_structure(None, none_is_leaf=True)
    True
    >>> treespec_leaf(none_is_leaf=True) == tree_structure(None, none_is_leaf=False)
    False
    >>> treespec_leaf(none_is_leaf=True) == treespec_none(none_is_leaf=True)
    True
    >>> treespec_leaf(none_is_leaf=True) == treespec_none(none_is_leaf=False)
    False
    >>> treespec_leaf(none_is_leaf=False) == treespec_none(none_is_leaf=True)
    False
    >>> treespec_leaf(none_is_leaf=False) == treespec_none(none_is_leaf=False)
    False

    Args:
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types. It is
            marked unused for leaf treespecs but is still forwarded to the extension.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A treespec representing a leaf node.
    """
    # `_C.make_leaf` takes its options positionally: (none_is_leaf, namespace).
    return _C.make_leaf(none_is_leaf, namespace)
|
|
|
|
|
|
def treespec_none(
    *,
    none_is_leaf: bool = False,
    namespace: str = '',  # unused
) -> PyTreeSpec:
    """Build the treespec of a :data:`None` node.

    With ``none_is_leaf=True`` the result is indistinguishable from :func:`treespec_leaf`.

    See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_tuple`.

    >>> treespec_none()
    PyTreeSpec(None)
    >>> treespec_none(none_is_leaf=True)
    PyTreeSpec(*, NoneIsLeaf)
    >>> treespec_none(none_is_leaf=False) == treespec_none(none_is_leaf=True)
    False
    >>> treespec_none() == tree_structure(None)
    True
    >>> treespec_none() == tree_structure(1)
    False
    >>> treespec_none(none_is_leaf=True) == tree_structure(1, none_is_leaf=True)
    True
    >>> treespec_none(none_is_leaf=True) == tree_structure(None, none_is_leaf=True)
    True
    >>> treespec_none(none_is_leaf=True) == tree_structure(None, none_is_leaf=False)
    False
    >>> treespec_none(none_is_leaf=True) == treespec_leaf(none_is_leaf=True)
    True
    >>> treespec_none(none_is_leaf=False) == treespec_leaf(none_is_leaf=True)
    False
    >>> treespec_none(none_is_leaf=True) == treespec_leaf(none_is_leaf=False)
    False
    >>> treespec_none(none_is_leaf=False) == treespec_leaf(none_is_leaf=False)
    False

    Args:
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types. It is
            marked unused for :data:`None` treespecs but is still forwarded to the extension.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A treespec representing a :data:`None` node.
    """
    # `_C.make_none` takes its options positionally: (none_is_leaf, namespace).
    return _C.make_none(none_is_leaf, namespace)
|
|
|
|
|
|
def treespec_tuple(
    iterable: Iterable[PyTreeSpec] = (),
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTreeSpec:
    """Make a tuple treespec from a list of child treespecs.

    See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.

    >>> treespec_tuple([treespec_leaf(), treespec_leaf()])
    PyTreeSpec((*, *))
    >>> treespec_tuple([treespec_leaf(), treespec_leaf(), treespec_none()])
    PyTreeSpec((*, *, None))
    >>> treespec_tuple()
    PyTreeSpec(())
    >>> treespec_tuple([treespec_leaf(), treespec_tuple([treespec_leaf(), treespec_leaf()])])
    PyTreeSpec((*, (*, *)))
    >>> treespec_tuple([treespec_leaf(), tree_structure({'a': 1, 'b': 2})])
    PyTreeSpec((*, {'a': *, 'b': *}))
    >>> treespec_tuple([treespec_leaf(), tree_structure({'a': 1, 'b': 2}, none_is_leaf=True)])
    Traceback (most recent call last):
        ...
    ValueError: Expected treespec(s) with `node_is_leaf=False`.

    Args:
        iterable (iterable of PyTreeSpec, optional): An iterable of child treespecs. They must have
            the same ``none_is_leaf`` and ``namespace`` values.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A treespec representing a tuple node with the given children.

    Raises:
        ValueError: If a child treespec's ``none_is_leaf`` setting does not match (see the
            doctest above).
    """
    # The extension infers the node kind from the collection type, so materialize a tuple.
    return _C.make_from_collection(
        tuple(iterable),  # type: ignore[arg-type]
        none_is_leaf,
        namespace,
    )
|
|
|
|
|
|
def treespec_list(
    iterable: Iterable[PyTreeSpec] = (),
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTreeSpec:
    """Make a list treespec from a list of child treespecs.

    See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.

    >>> treespec_list([treespec_leaf(), treespec_leaf()])
    PyTreeSpec([*, *])
    >>> treespec_list([treespec_leaf(), treespec_leaf(), treespec_none()])
    PyTreeSpec([*, *, None])
    >>> treespec_list()
    PyTreeSpec([])
    >>> treespec_list([treespec_leaf(), treespec_tuple([treespec_leaf(), treespec_leaf()])])
    PyTreeSpec([*, (*, *)])
    >>> treespec_list([treespec_leaf(), tree_structure({'a': 1, 'b': 2})])
    PyTreeSpec([*, {'a': *, 'b': *}])
    >>> treespec_list([treespec_leaf(), tree_structure({'a': 1, 'b': 2}, none_is_leaf=True)])
    Traceback (most recent call last):
        ...
    ValueError: Expected treespec(s) with `node_is_leaf=False`.

    Args:
        iterable (iterable of PyTreeSpec, optional): An iterable of child treespecs. They must have
            the same ``none_is_leaf`` and ``namespace`` values.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A treespec representing a list node with the given children.

    Raises:
        ValueError: If a child treespec's ``none_is_leaf`` setting does not match (see the
            doctest above).
    """
    # The extension infers the node kind from the collection type, so materialize a list.
    return _C.make_from_collection(
        list(iterable),  # type: ignore[arg-type]
        none_is_leaf,
        namespace,
    )
|
|
|
|
|
|
def treespec_dict(
    mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (),
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
    **kwargs: PyTreeSpec,
) -> PyTreeSpec:
    """Make a dict treespec from a dict of child treespecs.

    See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.

    >>> treespec_dict({'a': treespec_leaf(), 'b': treespec_leaf()})
    PyTreeSpec({'a': *, 'b': *})
    >>> treespec_dict([('b', treespec_leaf()), ('c', treespec_leaf()), ('a', treespec_none())])
    PyTreeSpec({'a': None, 'b': *, 'c': *})
    >>> treespec_dict()
    PyTreeSpec({})
    >>> treespec_dict(a=treespec_leaf(), b=treespec_tuple([treespec_leaf(), treespec_leaf()]))
    PyTreeSpec({'a': *, 'b': (*, *)})
    >>> treespec_dict({'a': treespec_leaf(), 'b': tree_structure([1, 2])})
    PyTreeSpec({'a': *, 'b': [*, *]})
    >>> treespec_dict({'a': treespec_leaf(), 'b': tree_structure([1, 2], none_is_leaf=True)})
    Traceback (most recent call last):
        ...
    ValueError: Expected treespec(s) with `node_is_leaf=False`.

    Args:
        mapping (mapping of PyTreeSpec, optional): A mapping (or an iterable of key-value pairs)
            of child treespecs. They must share the same ``none_is_leaf`` and ``namespace``
            values.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)
        **kwargs (PyTreeSpec): Additional child treespecs given as keyword arguments, merged in
            the same way as :class:`dict` does.

    Returns:
        A treespec representing a dict node with the given children.

    Raises:
        ValueError: If a child treespec's ``none_is_leaf`` setting does not match (see the
            example above).
    """
    # Same merge semantics as the ``dict`` constructor: positional entries first, then kwargs.
    children = dict(mapping, **kwargs)
    return _C.make_from_collection(children, none_is_leaf, namespace)  # type: ignore[arg-type]
|
|
|
|
|
|
def treespec_namedtuple(
    namedtuple: NamedTuple[PyTreeSpec],  # type: ignore[type-arg]
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTreeSpec:
    """Make a namedtuple treespec from a namedtuple of child treespecs.

    See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.

    >>> from collections import namedtuple
    >>> Point = namedtuple('Point', ['x', 'y'])
    >>> treespec_namedtuple(Point(x=treespec_leaf(), y=treespec_leaf()))
    PyTreeSpec(Point(x=*, y=*))
    >>> treespec_namedtuple(Point(x=treespec_leaf(), y=treespec_tuple([treespec_leaf(), treespec_leaf()])))
    PyTreeSpec(Point(x=*, y=(*, *)))
    >>> treespec_namedtuple(Point(x=treespec_leaf(), y=tree_structure([1, 2])))
    PyTreeSpec(Point(x=*, y=[*, *]))
    >>> treespec_namedtuple(Point(x=treespec_leaf(), y=tree_structure([1, 2], none_is_leaf=True)))
    Traceback (most recent call last):
        ...
    ValueError: Expected treespec(s) with `node_is_leaf=False`.

    Args:
        namedtuple (namedtuple of PyTreeSpec): A namedtuple of child treespecs. They must share the
            same ``none_is_leaf`` and ``namespace`` values.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A treespec representing a namedtuple node with the given children.

    Raises:
        ValueError: If ``namedtuple`` is not a namedtuple instance, or if a child treespec's
            ``none_is_leaf`` setting does not match (see the example above).
    """
    # Reject plain tuples (and other sequences) up front with a descriptive message.
    if not is_namedtuple_instance(namedtuple):
        raise ValueError(f'Expected a namedtuple of PyTreeSpec(s), got {namedtuple!r}.')
    return _C.make_from_collection(
        namedtuple,  # type: ignore[arg-type]
        none_is_leaf,
        namespace,
    )
|
|
|
|
|
|
def treespec_ordereddict(
    mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (),
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
    **kwargs: PyTreeSpec,
) -> PyTreeSpec:
    """Make an OrderedDict treespec from an OrderedDict of child treespecs.

    See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.

    >>> treespec_ordereddict({'a': treespec_leaf(), 'b': treespec_leaf()})
    PyTreeSpec(OrderedDict([('a', *), ('b', *)]))
    >>> treespec_ordereddict([('b', treespec_leaf()), ('c', treespec_leaf()), ('a', treespec_none())])
    PyTreeSpec(OrderedDict([('b', *), ('c', *), ('a', None)]))
    >>> treespec_ordereddict()
    PyTreeSpec(OrderedDict([]))
    >>> treespec_ordereddict(a=treespec_leaf(), b=treespec_tuple([treespec_leaf(), treespec_leaf()]))
    PyTreeSpec(OrderedDict([('a', *), ('b', (*, *))]))
    >>> treespec_ordereddict({'a': treespec_leaf(), 'b': tree_structure([1, 2])})
    PyTreeSpec(OrderedDict([('a', *), ('b', [*, *])]))
    >>> treespec_ordereddict({'a': treespec_leaf(), 'b': tree_structure([1, 2], none_is_leaf=True)})
    Traceback (most recent call last):
        ...
    ValueError: Expected treespec(s) with `node_is_leaf=False`.

    Args:
        mapping (mapping of PyTreeSpec, optional): A mapping (or an iterable of key-value pairs)
            of child treespecs. Insertion order is kept. They must share the same
            ``none_is_leaf`` and ``namespace`` values.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)
        **kwargs (PyTreeSpec): Additional child treespecs given as keyword arguments, appended
            after the entries of ``mapping``.

    Returns:
        A treespec representing an OrderedDict node with the given children.

    Raises:
        ValueError: If a child treespec's ``none_is_leaf`` setting does not match (see the
            example above).
    """
    children = OrderedDict(mapping, **kwargs)
    return _C.make_from_collection(children, none_is_leaf, namespace)  # type: ignore[arg-type]
|
|
|
|
|
|
def treespec_defaultdict(
    default_factory: Callable[[], Any] | None = None,
    mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (),
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
    **kwargs: PyTreeSpec,
) -> PyTreeSpec:
    """Make a defaultdict treespec from a defaultdict of child treespecs.

    See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.

    >>> treespec_defaultdict(int, {'a': treespec_leaf(), 'b': treespec_leaf()})
    PyTreeSpec(defaultdict(<class 'int'>, {'a': *, 'b': *}))
    >>> treespec_defaultdict(int, [('b', treespec_leaf()), ('c', treespec_leaf()), ('a', treespec_none())])
    PyTreeSpec(defaultdict(<class 'int'>, {'a': None, 'b': *, 'c': *}))
    >>> treespec_defaultdict()
    PyTreeSpec(defaultdict(None, {}))
    >>> treespec_defaultdict(int)
    PyTreeSpec(defaultdict(<class 'int'>, {}))
    >>> treespec_defaultdict(int, a=treespec_leaf(), b=treespec_tuple([treespec_leaf(), treespec_leaf()]))
    PyTreeSpec(defaultdict(<class 'int'>, {'a': *, 'b': (*, *)}))
    >>> treespec_defaultdict(int, {'a': treespec_leaf(), 'b': tree_structure([1, 2])})
    PyTreeSpec(defaultdict(<class 'int'>, {'a': *, 'b': [*, *]}))
    >>> treespec_defaultdict(int, {'a': treespec_leaf(), 'b': tree_structure([1, 2], none_is_leaf=True)})
    Traceback (most recent call last):
        ...
    ValueError: Expected treespec(s) with `node_is_leaf=False`.

    Args:
        default_factory (callable or None, optional): A factory function that will be used to create
            a missing value. It is stored on the resulting node. (default: :data:`None`)
        mapping (mapping of PyTreeSpec, optional): A mapping (or an iterable of key-value pairs)
            of child treespecs. They must share the same ``none_is_leaf`` and ``namespace``
            values.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)
        **kwargs (PyTreeSpec): Additional child treespecs given as keyword arguments.

    Returns:
        A treespec representing a defaultdict node with the given children.

    Raises:
        ValueError: If a child treespec's ``none_is_leaf`` setting does not match (see the
            example above).
    """
    children = defaultdict(default_factory, mapping, **kwargs)
    return _C.make_from_collection(children, none_is_leaf, namespace)  # type: ignore[arg-type]
|
|
|
|
|
|
def treespec_deque(
    iterable: Iterable[PyTreeSpec] = (),
    maxlen: int | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTreeSpec:
    """Make a deque treespec from a deque of child treespecs.

    See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.

    >>> treespec_deque([treespec_leaf(), treespec_leaf()])
    PyTreeSpec(deque([*, *]))
    >>> treespec_deque([treespec_leaf(), treespec_leaf(), treespec_none()], maxlen=5)
    PyTreeSpec(deque([*, *, None], maxlen=5))
    >>> treespec_deque()
    PyTreeSpec(deque([]))
    >>> treespec_deque([treespec_leaf(), treespec_tuple([treespec_leaf(), treespec_leaf()])])
    PyTreeSpec(deque([*, (*, *)]))
    >>> treespec_deque([treespec_leaf(), tree_structure({'a': 1, 'b': 2})], maxlen=5)
    PyTreeSpec(deque([*, {'a': *, 'b': *}], maxlen=5))
    >>> treespec_deque([treespec_leaf(), tree_structure({'a': 1, 'b': 2}, none_is_leaf=True)], maxlen=5)
    Traceback (most recent call last):
        ...
    ValueError: Expected treespec(s) with `node_is_leaf=False`.

    Args:
        iterable (iterable of PyTreeSpec, optional): An iterable of child treespecs. They must
            share the same ``none_is_leaf`` and ``namespace`` values.
        maxlen (int or None, optional): The maximum size of a deque or :data:`None` if unbounded.
            (default: :data:`None`)
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A treespec representing a deque node with the given children.

    Raises:
        ValueError: If a child treespec's ``none_is_leaf`` setting does not match (see the
            example above).
    """
    # Built with the regular ``deque`` constructor, so ``maxlen`` keeps only the last items.
    children = deque(iterable, maxlen=maxlen)
    return _C.make_from_collection(children, none_is_leaf, namespace)  # type: ignore[arg-type]
|
|
|
|
|
|
def treespec_structseq(
    structseq: PyStructSequence[PyTreeSpec],
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTreeSpec:
    """Make a PyStructSequence treespec from a PyStructSequence of child treespecs.

    See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.

    Args:
        structseq (PyStructSequence of PyTreeSpec): A PyStructSequence of child treespecs. They must
            share the same ``none_is_leaf`` and ``namespace`` values.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A treespec representing a PyStructSequence node with the given children.

    Raises:
        ValueError: If ``structseq`` is not a PyStructSequence instance.
    """
    if is_structseq_instance(structseq):
        return _C.make_from_collection(
            structseq,  # type: ignore[arg-type]
            none_is_leaf,
            namespace,
        )
    raise ValueError(f'Expected a PyStructSequence of PyTreeSpec(s), got {structseq!r}.')
|
|
|
|
|
|
def treespec_from_collection(
    collection: CustomTreeNode[PyTreeSpec],
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTreeSpec:
    """Make a treespec from a collection of child treespecs.

    See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.

    >>> treespec_from_collection(None)
    PyTreeSpec(None)
    >>> treespec_from_collection(None, none_is_leaf=True)
    PyTreeSpec(*, NoneIsLeaf)
    >>> treespec_from_collection(object())
    PyTreeSpec(*)
    >>> treespec_from_collection([treespec_leaf(), treespec_none()])
    PyTreeSpec([*, None])
    >>> treespec_from_collection({'a': treespec_leaf(), 'b': treespec_none()})
    PyTreeSpec({'a': *, 'b': None})
    >>> treespec_from_collection(deque([treespec_leaf(), tree_structure({'a': 1, 'b': 2})], maxlen=5))
    PyTreeSpec(deque([*, {'a': *, 'b': *}], maxlen=5))
    >>> treespec_from_collection({'a': treespec_leaf(), 'b': (treespec_leaf(), treespec_none())})
    Traceback (most recent call last):
        ...
    ValueError: Expected a(n) dict of PyTreeSpec(s), got {'a': PyTreeSpec(*), 'b': (PyTreeSpec(*), PyTreeSpec(None))}.
    >>> treespec_from_collection([treespec_leaf(), tree_structure({'a': 1, 'b': 2}, none_is_leaf=True)])
    Traceback (most recent call last):
        ...
    ValueError: Expected treespec(s) with `node_is_leaf=False`.

    Args:
        collection (collection of PyTreeSpec): A collection of child treespecs. They must share the
            same ``none_is_leaf`` and ``namespace`` values.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A treespec representing the same structure of the collection with the given children.

    Raises:
        ValueError: If the direct children of ``collection`` are not all treespecs, or if their
            ``none_is_leaf`` setting does not match (see the examples above).
    """
    return _C.make_from_collection(
        collection,
        none_is_leaf,
        namespace,
    )
|
|
|
|
|
|
def prefix_errors(
    prefix_tree: PyTree[T],
    full_tree: PyTree[S],
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> list[Callable[[str], ValueError]]:
    """Return a list of errors that would be raised by :func:`broadcast_prefix`.

    Each entry is a factory: call it with a display name for the prefix tree to obtain the
    :class:`ValueError` describing one structural mismatch between ``prefix_tree`` and
    ``full_tree``. An empty list means no mismatch was found.
    """
    # Start the recursive check at the root, i.e. with an empty key path.
    error_factories = _prefix_error(
        KeyPath(),
        prefix_tree,
        full_tree,
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    return list(error_factories)
|
|
|
|
|
|
# Built-in mapping node types that `_prefix_error` treats as interchangeable and matches child-by-key.
STANDARD_DICT_TYPES = frozenset((dict, OrderedDict, defaultdict))
|
|
|
|
|
|
# pylint: disable-next=too-many-locals
def _prefix_error(
    key_path: KeyPath,
    prefix_tree: PyTree[T],
    full_tree: PyTree[S],
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> Iterable[Callable[[str], ValueError]]:
    """Yield error factories describing where ``prefix_tree`` is not a prefix of ``full_tree``.

    Each yielded item takes a display name for the prefix tree and returns a :class:`ValueError`.
    Once a node-level mismatch (type, dict keys, number of children, or metadata) is found at some
    key path, that subtree's descendants are not inspected further.

    Args:
        key_path (KeyPath): The key path from the root to the subtrees being compared.
        prefix_tree (pytree): The candidate prefix subtree.
        full_tree (pytree): The corresponding subtree of the full pytree.
        is_leaf (callable, optional): Forwarded to :func:`tree_is_leaf` and :func:`_child_keys`.
        none_is_leaf (bool, optional): Forwarded to the flattening helpers.
            (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Yields:
        Callables mapping a tree name to the :class:`ValueError` for one mismatch.
    """
    # NOTE: Every message below interpolates ``name`` directly inside the lambda's f-string.
    # Chaining ``.format(name=name)`` onto an f-string would re-parse the already-interpolated
    # reprs (key paths, dict keys, metadata). Any literal ``{``/``}`` in them would then make
    # building the error itself fail with ``KeyError``/``IndexError``/``ValueError``. Examples:
    # a dict metadata repr, a ``frozenset`` key, or a string key such as ``'a{b}'``.

    # A leaf is a valid prefix of any tree
    if tree_is_leaf(prefix_tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace):
        return

    # The subtrees may disagree because their roots are of different types:
    prefix_tree_type = type(prefix_tree)
    full_tree_type = type(full_tree)
    both_standard_dict = (
        prefix_tree_type in STANDARD_DICT_TYPES  # type: ignore[comparison-overlap]
        and full_tree_type in STANDARD_DICT_TYPES  # type: ignore[comparison-overlap]
    )
    both_deque = prefix_tree_type is deque and full_tree_type is deque  # type: ignore[comparison-overlap]
    if prefix_tree_type is not full_tree_type and (
        # Special handling for dictionary types
        not both_standard_dict
    ):
        yield lambda name: ValueError(
            f'pytree structure error: different types at key path\n'
            f' {name}{key_path.pprint()}\n'
            f'At that key path, the prefix pytree {name} has a subtree of type\n'
            f' {prefix_tree_type}\n'
            f'but at the same key path the full pytree has a subtree of different type\n'
            f' {full_tree_type}.',
        )
        return  # don't look for more errors in this subtree

    # Or they may disagree if their roots have different numbers of children (note that because both
    # prefix_tree and full_tree have the same type at this point, and because prefix_tree is not a
    # leaf, each can be flattened once):
    prefix_tree_children, prefix_tree_metadata, _, __ = tree_flatten_one_level(
        prefix_tree,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    full_tree_children, full_tree_metadata, _, __ = tree_flatten_one_level(
        full_tree,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    # Special handling for dictionary types
    if both_standard_dict:
        # A defaultdict's metadata is ``(default_factory, keys)``; other dicts store the keys directly.
        prefix_tree_keys = (
            prefix_tree_metadata
            if prefix_tree_type is not defaultdict  # type: ignore[comparison-overlap]
            else prefix_tree_metadata[1]  # type: ignore[index]
        )
        full_tree_keys = (
            full_tree_metadata
            if full_tree_type is not defaultdict  # type: ignore[comparison-overlap]
            else full_tree_metadata[1]  # type: ignore[index]
        )
        prefix_tree_keys_set = set(prefix_tree_keys)
        full_tree_keys_set = set(full_tree_keys)
        if prefix_tree_keys_set != full_tree_keys_set:
            # Use a sort that tolerates mixed key types (e.g. ``1`` and ``'a'``) so that reporting
            # the mismatch cannot itself raise ``TypeError``.
            missing_keys = _sorted_keys(prefix_tree_keys_set.difference(full_tree_keys_set))
            extra_keys = _sorted_keys(full_tree_keys_set.difference(prefix_tree_keys_set))
            key_difference = ''
            if missing_keys:
                key_difference += f'\nmissing key(s):\n {missing_keys}'
            if extra_keys:
                key_difference += f'\nextra key(s):\n {extra_keys}'
            yield lambda name: ValueError(
                f'pytree structure error: different pytree keys at key path\n'
                f' {name}{key_path.pprint()}\n'
                f'At that key path, the prefix pytree {name} has a subtree of type\n'
                f' {prefix_tree_type}\n'
                f'with {len(prefix_tree_keys)} key(s)\n'
                f' {prefix_tree_keys}\n'
                f'but at the same key path the full pytree has a subtree of type\n'
                f' {full_tree_type}\n'
                f'but with {len(full_tree_keys)} key(s)\n'
                f' {full_tree_keys}{key_difference}',
            )
            return  # don't look for more errors in this subtree

        # If the keys agree, we should ensure that the children are in the same order:
        full_tree_children = [full_tree[k] for k in prefix_tree_keys]  # type: ignore[index]

    if len(prefix_tree_children) != len(full_tree_children):
        yield lambda name: ValueError(
            f'pytree structure error: different numbers of pytree children at key path\n'
            f' {name}{key_path.pprint()}\n'
            f'At that key path, the prefix pytree {name} has a subtree of type\n'
            f' {prefix_tree_type}\n'
            f'with {len(prefix_tree_children)} children, '
            f'but at the same key path the full pytree has a subtree of the same '
            f'type but with {len(full_tree_children)} children.',
        )
        return  # don't look for more errors in this subtree

    # Or they may disagree if their roots have different pytree metadata:
    if (
        prefix_tree_metadata != full_tree_metadata
        and (not both_deque)  # ignore maxlen mismatch for deque
        and (
            # Special handling for dictionary types already done in the keys check above
            not both_standard_dict
        )
    ):
        prefix_tree_metadata_repr = repr(prefix_tree_metadata)
        full_tree_metadata_repr = repr(full_tree_metadata)
        metadata_diff = textwrap.indent(
            '\n'.join(
                difflib.ndiff(
                    prefix_tree_metadata_repr.splitlines(),
                    full_tree_metadata_repr.splitlines(),
                ),
            ),
            prefix=' ',
        )
        yield lambda name: ValueError(
            f'pytree structure error: different pytree metadata at key path\n'
            f' {name}{key_path.pprint()}\n'
            f'At that key path, the prefix pytree {name} has a subtree of type\n'
            f' {prefix_tree_type}\n'
            f'with metadata\n'
            f' {prefix_tree_metadata_repr}\n'
            f'but at the same key path the full pytree has a subtree of the same '
            f'type but with metadata\n'
            f' {full_tree_metadata_repr}\n'
            f'so the diff in the metadata at these pytree nodes is\n'
            f'{metadata_diff}',
        )
        return  # don't look for more errors in this subtree

    # If the root types and numbers of children agree, there must be an error in a subtree,
    # so recurse:
    keys = _child_keys(
        prefix_tree,
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    keys_ = _child_keys(
        full_tree,
        is_leaf=is_leaf,  # type: ignore[arg-type]
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    assert keys == keys_ or (
        # Special handling for dictionary types already done in the keys check above
        both_standard_dict
    ), f'equal pytree nodes gave different keys: {keys} and {keys_}'
    # pylint: disable-next=invalid-name
    for k, t1, t2 in zip(keys, prefix_tree_children, full_tree_children):
        yield from _prefix_error(
            key_path + k,
            t1,
            t2,
            is_leaf=is_leaf,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        )


def _sorted_keys(keys: Iterable[Any]) -> list[Any]:
    """Sort dict keys for display, tolerating mutually unorderable key types.

    Plain :func:`sorted` is tried first. If it raises :class:`TypeError` (e.g. for a mix of
    :class:`int` and :class:`str` keys), keys are ordered by their fully qualified type name and
    then by :func:`repr`, which is deterministic for the common hashable key types.
    """
    keys = list(keys)
    try:
        return sorted(keys)
    except TypeError:
        return sorted(
            keys,
            key=lambda key: (f'{type(key).__module__}.{type(key).__qualname__}', repr(key)),
        )
|
|
|
|
|
|
def _child_keys(
    tree: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> list[KeyPathEntry]:
    """Return the key path entries for the immediate children of the non-leaf node ``tree``.

    Lookup order:
        1. a handler registered via ``register_keypaths`` for ``type(tree)``;
        2. attribute entries for PyStructSequence instances, then for namedtuple instances;
        3. positional :class:`FlattenedKeyPathEntry` indices, one per child of the treespec.
    """
    treespec = tree_structure(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
    assert not treespec_is_strict_leaf(treespec), 'treespec must be a non-leaf node'

    keypath_handler = register_keypaths.get(type(tree))  # type: ignore[attr-defined]
    if keypath_handler:
        return list(keypath_handler(tree))

    # PyStructSequence and namedtuple are recognized heuristically; structseq is checked first.
    if is_structseq_instance(tree):
        field_names = structseq_fields(tree)  # type: ignore[arg-type]
    elif is_namedtuple_instance(tree):
        field_names = namedtuple_fields(tree)  # type: ignore[arg-type]
    else:
        return [FlattenedKeyPathEntry(index) for index in range(treespec.num_children)]
    return [AttributeKeyPathEntry(field_name) for field_name in field_names]
|