# Copyright 2021 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import itertools from typing import Any, Callable, Sequence, Tuple, Union import jax from jax._src import core from jax import tree_util from jax._src.api_util import _ensure_index, _ensure_index_tuple from jax.util import safe_zip from jax._src.util import split_list, wraps from jax._src.traceback_util import api_boundary from jax.experimental.sparse._base import JAXSparse is_sparse = lambda x: isinstance(x, JAXSparse) def flatten_fun_for_sparse_ad(fun, argnums: Union[int, Tuple[int]], args: Tuple[Any]): argnums_tup = _ensure_index_tuple(argnums) assert all(0 <= argnum < len(args) for argnum in argnums_tup) # We do a two-step flattening to figure out how argnums maps to args_flat. # First, flatten arguments to a list containing sparse and dense objects. args_flat1, tree1 = tree_util.tree_flatten(args, is_leaf=is_sparse) *leaf_argnums1, end = split_list(range(tree1.num_leaves), [child.num_leaves for child in tree1.children()]) assert not end argnums_flat1 = list(itertools.chain.from_iterable( nums for i, nums in enumerate(leaf_argnums1) if i in argnums_tup)) # Next, fully flatten to a list of dense buffers. args_flat, tree2 = tree_util.tree_flatten(args_flat1) *leaf_argnums2, end = split_list(range(tree2.num_leaves), [child.num_leaves for child in tree2.children()]) assert not end # For sparse args, we only mark the first buffer (the data) for differentiation. leaf_argnums2 = [nums[:1] if is_sparse(arg) else nums for arg, nums in safe_zip(args_flat1, leaf_argnums2)] argnums_flat = tuple(itertools.chain.from_iterable( nums for i, nums in enumerate(leaf_argnums2) if i in argnums_flat1)) def fun_flat(*args_flat, **kwargs): args = tree_util.tree_unflatten(tree1, tree_util.tree_unflatten(tree2, args_flat)) return fun(*args, **kwargs) def reconstruct(i, grad_out): bufs, tree = tree_util.tree_flatten(args_flat1[i]) f_recons = lambda g: tree_util.tree_unflatten(tree, [g, *bufs[1:]]) for _ in range(grad_out.ndim - bufs[0].ndim): f_recons = jax.vmap(f_recons) return f_recons(grad_out) def postprocess_gradients(grads_out): out = [reconstruct(*args) for args in safe_zip(argnums_flat1, grads_out)] return out[0] if isinstance(argnums, int) else out return fun_flat, argnums_flat, args_flat, postprocess_gradients def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0, has_aux=False, **kwargs) -> Callable[..., Tuple[Any, Any]]: """Sparse-aware version of :func:`jax.value_and_grad` Arguments and return values are the same as :func:`jax.value_and_grad`, but when taking the gradient with respect to a :class:`jax.experimental.sparse` array, the gradient is computed in the subspace defined by the array's sparsity pattern. Example: >>> from jax.experimental import sparse >>> X = sparse.BCOO.fromdense(jnp.arange(6.)) >>> y = jnp.ones(6) >>> sparse.value_and_grad(lambda X, y: X @ y)(X, y) (Array(15., dtype=float32), BCOO(float32[6], nse=5)) """ raw_value_and_grad_fun = jax.value_and_grad(fun, argnums=argnums, has_aux=has_aux, **kwargs) argnums = core.concrete_or_error(_ensure_index, argnums) @wraps(fun, docstr=raw_value_and_grad_fun.__doc__, argnums=argnums) @api_boundary def value_and_grad_fun(*args, **kwargs): fun_flat, argnums_flat, args_flat, postprocess_gradients = flatten_fun_for_sparse_ad(fun, argnums, args) val_out, grad_out = jax.value_and_grad(fun_flat, argnums=argnums_flat, has_aux=has_aux, **kwargs)(*args_flat) return val_out, postprocess_gradients(grad_out) return value_and_grad_fun def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0, has_aux=False, **kwargs) -> Callable: """Sparse-aware version of :func:`jax.grad` Arguments and return values are the same as :func:`jax.grad`, but when taking the gradient with respect to a :class:`jax.experimental.sparse` array, the gradient is computed in the subspace defined by the array's sparsity pattern. Example: >>> from jax.experimental import sparse >>> X = sparse.BCOO.fromdense(jnp.arange(6.)) >>> y = jnp.ones(6) >>> sparse.grad(lambda X, y: X @ y)(X, y) BCOO(float32[6], nse=5) """ raw_grad_fun = jax.grad(fun, argnums=argnums, **kwargs) argnums = core.concrete_or_error(_ensure_index, argnums) @wraps(fun, docstr=raw_grad_fun.__doc__, argnums=argnums) @api_boundary def grad_fun(*args, **kwargs): fun_flat, argnums_flat, args_flat, postprocess_gradients = flatten_fun_for_sparse_ad(fun, argnums, args) out = jax.grad(fun_flat, argnums=argnums_flat, has_aux=has_aux, **kwargs)(*args_flat) if has_aux: return postprocess_gradients(out[0]), out[1] return postprocess_gradients(out) return grad_fun def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False, **kwargs) -> Callable: """Sparse-aware version of :func:`jax.jacfwd` Arguments and return values are the same as :func:`jax.jacfwd`, but when taking the gradient with respect to a :class:`jax.experimental.sparse` array, the gradient is computed in the subspace defined by the array's sparsity pattern. Currently this is only implemented for dense outputs. """ raw_jacfwd_fun = jax.jacfwd(fun, argnums=argnums, **kwargs) argnums = core.concrete_or_error(_ensure_index, argnums) @wraps(fun, docstr=raw_jacfwd_fun.__doc__, argnums=argnums) @api_boundary def jacfwd_fun(*args, **kwargs): fun_flat, argnums_flat, args_flat, postprocess_gradients = flatten_fun_for_sparse_ad(fun, argnums, args) out = jax.jacfwd(fun_flat, argnums=argnums_flat, has_aux=has_aux, **kwargs)(*args_flat) if has_aux: return postprocess_gradients(out[0]), out[1] return postprocess_gradients(out) return jacfwd_fun def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False, **kwargs) -> Callable: """Sparse-aware version of :func:`jax.jacrev` Arguments and return values are the same as :func:`jax.jacrev`, but when taking the gradient with respect to a :class:`jax.experimental.sparse` array, the gradient is computed in the subspace defined by the array's sparsity pattern. Currently this is only implemented for dense outputs. """ raw_jacrev_fun = jax.jacrev(fun, argnums=argnums, **kwargs) argnums = core.concrete_or_error(_ensure_index, argnums) @wraps(fun, docstr=raw_jacrev_fun.__doc__, argnums=argnums) @api_boundary def jacrev_fun(*args, **kwargs): fun_flat, argnums_flat, args_flat, postprocess_gradients = flatten_fun_for_sparse_ad(fun, argnums, args) out = jax.jacrev(fun_flat, argnums=argnums_flat, has_aux=has_aux, **kwargs)(*args_flat) if has_aux: return postprocess_gradients(out[0]), out[1] return postprocess_gradients(out) return jacrev_fun jacobian = jacrev