271 lines
10 KiB
Python
271 lines
10 KiB
Python
|
# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
# This file is modified from:
|
||
|
# https://github.com/google/jax/blob/jax-v0.4.20/jax/_src/flatten_util.py
|
||
|
# ==============================================================================
|
||
|
# Copyright 2018 The JAX Authors.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Integration with JAX."""
|
||
|
|
||
|
from __future__ import annotations
|
||
|
|
||
|
import itertools
|
||
|
import warnings
|
||
|
from types import FunctionType
|
||
|
from typing import Any, Callable
|
||
|
from typing_extensions import TypeAlias # Python 3.10+
|
||
|
|
||
|
import jax.numpy as jnp
|
||
|
from jax import Array, lax
|
||
|
from jax._src import dtypes
|
||
|
from jax.typing import ArrayLike
|
||
|
|
||
|
from optree.ops import tree_flatten, tree_unflatten
|
||
|
from optree.typing import PyTreeSpec, PyTreeTypeVar
|
||
|
from optree.utils import safe_zip, total_order_sorted
|
||
|
|
||
|
|
||
|
__all__ = ['ArrayLikeTree', 'ArrayTree', 'tree_ravel']
|
||
|
|
||
|
|
||
|
ArrayLikeTree: TypeAlias = PyTreeTypeVar('ArrayLikeTree', ArrayLike) # type: ignore[valid-type]
|
||
|
ArrayTree: TypeAlias = PyTreeTypeVar('ArrayTree', Array) # type: ignore[valid-type]
|
||
|
|
||
|
|
||
|
# Vendor from https://github.com/google/jax/blob/jax-v0.4.20/jax/_src/util.py
|
||
|
class HashablePartial: # pragma: no cover
|
||
|
"""A hashable version of :class:`functools.partial`."""
|
||
|
|
||
|
func: FunctionType
|
||
|
args: tuple[Any, ...]
|
||
|
kwargs: dict[str, Any]
|
||
|
|
||
|
def __init__(self, func: FunctionType | HashablePartial, *args: Any, **kwargs: Any) -> None:
|
||
|
"""Construct a :class:`HashablePartial` instance."""
|
||
|
if not callable(func):
|
||
|
raise TypeError(f'Expected a callable, got {func!r}.')
|
||
|
|
||
|
if isinstance(func, HashablePartial):
|
||
|
self.func = func.func
|
||
|
self.args = func.args + args
|
||
|
self.kwargs = {**func.kwargs, **kwargs}
|
||
|
elif isinstance(func, FunctionType):
|
||
|
self.func = func # type: ignore[assignment]
|
||
|
self.args = args
|
||
|
self.kwargs = kwargs
|
||
|
else:
|
||
|
raise TypeError(f'Expected a function, got {func!r}.')
|
||
|
|
||
|
def __eq__(self, other: object) -> bool:
|
||
|
return (
|
||
|
type(other) is HashablePartial # pylint: disable=unidiomatic-typecheck
|
||
|
and self.func.__code__ == other.func.__code__ # type: ignore[attr-defined]
|
||
|
and self.args == other.args
|
||
|
and self.kwargs == other.kwargs
|
||
|
)
|
||
|
|
||
|
def __hash__(self) -> int:
|
||
|
return hash(
|
||
|
(
|
||
|
self.func.__code__, # type: ignore[attr-defined]
|
||
|
self.args,
|
||
|
tuple(total_order_sorted(self.kwargs.items(), key=lambda kv: kv[0])),
|
||
|
),
|
||
|
)
|
||
|
|
||
|
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||
|
return self.func(*self.args, *args, **self.kwargs, **kwargs)
|
||
|
|
||
|
|
||
|
try: # noqa: SIM105 # pragma: no cover
|
||
|
# pylint: disable=ungrouped-imports
|
||
|
from jax._src.util import HashablePartial # type: ignore[assignment] # noqa: F811,RUF100
|
||
|
except ImportError: # pragma: no cover
|
||
|
pass
|
||
|
|
||
|
|
||
|
def tree_ravel(
|
||
|
tree: ArrayLikeTree,
|
||
|
is_leaf: Callable[[Any], bool] | None = None,
|
||
|
*,
|
||
|
none_is_leaf: bool = False,
|
||
|
namespace: str = '',
|
||
|
) -> tuple[Array, Callable[[Array], ArrayTree]]:
|
||
|
r"""Ravel (flatten) a pytree of arrays down to a 1D array.
|
||
|
|
||
|
>>> tree = {
|
||
|
... 'layer1': {
|
||
|
... 'weight': jnp.arange(0, 6, dtype=jnp.float32).reshape((2, 3)),
|
||
|
... 'bias': jnp.arange(6, 8, dtype=jnp.float32).reshape((2,)),
|
||
|
... },
|
||
|
... 'layer2': {
|
||
|
... 'weight': jnp.arange(8, 10, dtype=jnp.float32).reshape((1, 2)),
|
||
|
... 'bias': jnp.arange(10, 11, dtype=jnp.float32).reshape((1,))
|
||
|
... },
|
||
|
... }
|
||
|
>>> tree
|
||
|
{'layer1': {'weight': Array([[0., 1., 2.],
|
||
|
[3., 4., 5.]], dtype=float32),
|
||
|
'bias': Array([6., 7.], dtype=float32)},
|
||
|
'layer2': {'weight': Array([[8., 9.]], dtype=float32),
|
||
|
'bias': Array([10.], dtype=float32)}}
|
||
|
>>> flat, unravel_func = tree_ravel(tree)
|
||
|
>>> flat
|
||
|
Array([ 6., 7., 0., 1., 2., 3., 4., 5., 10., 8., 9.], dtype=float32)
|
||
|
>>> unravel_func(flat)
|
||
|
{'layer1': {'weight': Array([[0., 1., 2.],
|
||
|
[3., 4., 5.]], dtype=float32),
|
||
|
'bias': Array([6., 7.], dtype=float32)},
|
||
|
'layer2': {'weight': Array([[8., 9.]], dtype=float32),
|
||
|
'bias': Array([10.], dtype=float32)}}
|
||
|
|
||
|
Args:
|
||
|
tree (pytree): a pytree of arrays and scalars to ravel.
|
||
|
is_leaf (callable, optional): An optionally specified function that will be called at each
|
||
|
flattening step. It should return a boolean, with :data:`True` stopping the traversal
|
||
|
and the whole subtree being treated as a leaf, and :data:`False` indicating the
|
||
|
flattening should traverse the current object.
|
||
|
none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
|
||
|
:data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
|
||
|
treespec rather than in the leaves list and :data:`None` will be remain in the result
|
||
|
pytree. (default: :data:`False`)
|
||
|
namespace (str, optional): The registry namespace used for custom pytree node types.
|
||
|
(default: :const:`''`, i.e., the global namespace)
|
||
|
|
||
|
Returns:
|
||
|
A pair ``(array, unravel_func)`` where the first element is a 1D array representing the
|
||
|
flattened and concatenated leaf values, with ``dtype`` determined by promoting the
|
||
|
``dtype``\s of leaf values, and the second element is a callable for unflattening a 1D array
|
||
|
of the same length back to a pytree of the same structure as the input ``tree``. If the
|
||
|
input pytree is empty (i.e. has no leaves) then as a convention a 1D empty array of the
|
||
|
default dtype is returned in the first component of the output.
|
||
|
"""
|
||
|
leaves, treespec = tree_flatten(
|
||
|
tree,
|
||
|
is_leaf=is_leaf,
|
||
|
none_is_leaf=none_is_leaf,
|
||
|
namespace=namespace,
|
||
|
)
|
||
|
flat, unravel_flat = _ravel_leaves(leaves)
|
||
|
return flat, HashablePartial(_tree_unravel, treespec, unravel_flat) # type: ignore[arg-type]
|
||
|
|
||
|
|
||
|
ravel_pytree = tree_ravel
|
||
|
|
||
|
|
||
|
def _tree_unravel(
|
||
|
treespec: PyTreeSpec,
|
||
|
unravel_flat: Callable[[Array], list[ArrayLike]],
|
||
|
flat: Array,
|
||
|
) -> ArrayTree:
|
||
|
return tree_unflatten(treespec, unravel_flat(flat))
|
||
|
|
||
|
|
||
|
def _ravel_leaves(
|
||
|
leaves: list[ArrayLike],
|
||
|
) -> tuple[Array, Callable[[Array], list[ArrayLike]]]:
|
||
|
if not leaves:
|
||
|
return (jnp.array([]), _unravel_empty)
|
||
|
|
||
|
from_dtypes = tuple(dtypes.dtype(leaf) for leaf in leaves)
|
||
|
to_dtype = dtypes.result_type(*from_dtypes)
|
||
|
sizes = tuple(jnp.size(leaf) for leaf in leaves)
|
||
|
shapes = tuple(jnp.shape(leaf) for leaf in leaves)
|
||
|
indices = tuple(itertools.accumulate(sizes))
|
||
|
|
||
|
if all(dt == to_dtype for dt in from_dtypes):
|
||
|
# Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`.
|
||
|
# See https://github.com/google/jax/issues/7809.
|
||
|
raveled = jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])
|
||
|
return (
|
||
|
raveled,
|
||
|
HashablePartial(_unravel_leaves_single_dtype, indices, shapes), # type: ignore[arg-type]
|
||
|
)
|
||
|
|
||
|
# When there is more than one distinct input dtype, we perform type conversions and produce a
|
||
|
# dtype-specific unravel function.
|
||
|
raveled = jnp.concatenate(
|
||
|
[jnp.ravel(lax.convert_element_type(leaf, to_dtype)) for leaf in leaves],
|
||
|
)
|
||
|
return (
|
||
|
raveled,
|
||
|
HashablePartial(_unravel_leaves, indices, shapes, from_dtypes, to_dtype), # type: ignore[arg-type]
|
||
|
)
|
||
|
|
||
|
|
||
|
def _unravel_empty(flat: Array) -> list[ArrayLike]:
|
||
|
if jnp.shape(flat) != (0,): # type: ignore[comparison-overlap]
|
||
|
raise ValueError(
|
||
|
f'The unravel function expected an array of shape {(0,)}, '
|
||
|
f'got shape {jnp.shape(flat)}.',
|
||
|
)
|
||
|
|
||
|
return []
|
||
|
|
||
|
|
||
|
def _unravel_leaves_single_dtype(
|
||
|
indices: tuple[int, ...],
|
||
|
shapes: tuple[tuple[int, ...]],
|
||
|
flat: Array,
|
||
|
) -> list[Array]:
|
||
|
if jnp.shape(flat) != (indices[-1],): # type: ignore[comparison-overlap]
|
||
|
raise ValueError(
|
||
|
f'The unravel function expected an array of shape {(indices[-1],)}, '
|
||
|
f'got shape {jnp.shape(flat)}.',
|
||
|
)
|
||
|
|
||
|
chunks = jnp.split(flat, indices[:-1])
|
||
|
return [chunk.reshape(shape) for chunk, shape in safe_zip(chunks, shapes)]
|
||
|
|
||
|
|
||
|
def _unravel_leaves(
|
||
|
indices: tuple[int, ...],
|
||
|
shapes: tuple[tuple[int, ...]],
|
||
|
from_dtypes: tuple[jnp.dtype, ...],
|
||
|
to_dtype: jnp.dtype,
|
||
|
flat: Array,
|
||
|
) -> list[Array]:
|
||
|
if jnp.shape(flat) != (indices[-1],): # type: ignore[comparison-overlap]
|
||
|
raise ValueError(
|
||
|
f'The unravel function expected an array of shape {(indices[-1],)}, '
|
||
|
f'got shape {jnp.shape(flat)}.',
|
||
|
)
|
||
|
array_dtype = dtypes.dtype(flat)
|
||
|
if array_dtype != to_dtype:
|
||
|
raise ValueError(
|
||
|
f'The unravel function expected an array of dtype {to_dtype}, '
|
||
|
f'got dtype {array_dtype}.',
|
||
|
)
|
||
|
|
||
|
chunks = jnp.split(flat, indices[:-1])
|
||
|
with warnings.catch_warnings():
|
||
|
warnings.simplefilter('ignore') # ignore complex-to-real cast warning
|
||
|
return [
|
||
|
lax.convert_element_type(chunk.reshape(shape), dtype)
|
||
|
for chunk, shape, dtype in safe_zip(chunks, shapes, from_dtypes)
|
||
|
]
|