# Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/experimental/sparse/ad.py
# 2023-06-19 00:49:18 +02:00
#
# 178 lines
# 7.4 KiB
# Python

# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Any, Callable, Sequence, Tuple, Union
import jax
from jax._src import core
from jax import tree_util
from jax._src.api_util import _ensure_index, _ensure_index_tuple
from jax.util import safe_zip
from jax._src.util import split_list, wraps
from jax._src.traceback_util import api_boundary
from jax.experimental.sparse._base import JAXSparse
def is_sparse(x):
  """Return True when ``x`` is an instance of :class:`JAXSparse`."""
  # A named function (rather than a lambda bound to a name) shows up as
  # ``is_sparse`` in tracebacks and can carry a docstring.
  return isinstance(x, JAXSparse)
def flatten_fun_for_sparse_ad(fun, argnums: Union[int, Tuple[int]], args: Tuple[Any]):
  """Flatten ``fun`` and ``args`` to dense buffers for sparse-aware autodiff.

  Args:
    fun: callable invoked as ``fun(*args, **kwargs)``.
    argnums: index, or tuple of indices, into ``args`` selecting the arguments
      to differentiate. Every index must satisfy ``0 <= argnum < len(args)``.
    args: tuple of positional arguments; entries may be (or contain) sparse
      objects as recognized by ``is_sparse``.

  Returns:
    A 4-tuple ``(fun_flat, argnums_flat, args_flat, postprocess_gradients)``:

    * ``fun_flat`` accepts the fully flattened buffers (plus keyword
      arguments), rebuilds the original arguments and calls ``fun``.
    * ``argnums_flat`` is a tuple of positions in ``args_flat`` to
      differentiate; for a sparse argument only its first buffer is listed.
    * ``args_flat`` is the list of fully flattened buffers.
    * ``postprocess_gradients`` takes the sequence of gradients computed with
      respect to ``argnums_flat`` and rebuilds, for each differentiated leaf,
      an object with that leaf's structure. It returns the first such object
      when ``argnums`` is an ``int`` and the list of all of them otherwise.
  """
  argnums_tup = _ensure_index_tuple(argnums)
  assert all(0 <= argnum < len(args) for argnum in argnums_tup)

  # We do a two-step flattening to figure out how argnums maps to args_flat.
  # First, flatten arguments to a list containing sparse and dense objects.
  args_flat1, tree1 = tree_util.tree_flatten(args, is_leaf=is_sparse)
  *leaf_argnums1, end = split_list(range(tree1.num_leaves),
                                   [child.num_leaves for child in tree1.children()])
  assert not end
  argnums_flat1 = list(itertools.chain.from_iterable(
      nums for i, nums in enumerate(leaf_argnums1) if i in argnums_tup))

  # Next, fully flatten to a list of dense buffers.
  args_flat, tree2 = tree_util.tree_flatten(args_flat1)
  *leaf_argnums2, end = split_list(range(tree2.num_leaves),
                                   [child.num_leaves for child in tree2.children()])
  assert not end

  # For sparse args, we only mark the first buffer (the data) for differentiation.
  leaf_argnums2 = [nums[:1] if is_sparse(arg) else nums
                   for arg, nums in safe_zip(args_flat1, leaf_argnums2)]
  argnums_flat = tuple(itertools.chain.from_iterable(
      nums for i, nums in enumerate(leaf_argnums2) if i in argnums_flat1))

  def fun_flat(*flat_bufs, **kwargs):
    # Undo both flattening steps (innermost first) before calling ``fun``.
    args = tree_util.tree_unflatten(tree1, tree_util.tree_unflatten(tree2, flat_bufs))
    return fun(*args, **kwargs)

  def reconstruct(i, grad_out):
    # Rebuild an object shaped like args_flat1[i], with its first buffer
    # replaced by the gradient and all remaining buffers reused unchanged.
    bufs, tree = tree_util.tree_flatten(args_flat1[i])
    f_recons = lambda g: tree_util.tree_unflatten(tree, [g, *bufs[1:]])
    # Python scalars are valid leaves but have no ``.ndim`` attribute, so the
    # rank is queried through jax.numpy.ndim, which handles scalars and arrays.
    # Any extra leading gradient dimensions are mapped over with vmap.
    for _ in range(grad_out.ndim - jax.numpy.ndim(bufs[0])):
      f_recons = jax.vmap(f_recons)
    return f_recons(grad_out)

  def postprocess_gradients(grads_out):
    out = [reconstruct(*args) for args in safe_zip(argnums_flat1, grads_out)]
    return out[0] if isinstance(argnums, int) else out

  return fun_flat, argnums_flat, args_flat, postprocess_gradients
def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
                   has_aux=False, **kwargs) -> Callable[..., Tuple[Any, Any]]:
  """Sparse-aware version of :func:`jax.value_and_grad`

  Arguments and return values are the same as :func:`jax.value_and_grad`, but when
  taking the gradient with respect to a :class:`jax.experimental.sparse` array, the
  gradient is computed in the subspace defined by the array's sparsity pattern.

  Args:
    fun: function to differentiate.
    argnums: index or sequence of indices of the positional arguments to
      differentiate with respect to.
    has_aux: forwarded unchanged to :func:`jax.value_and_grad`.
    **kwargs: additional options forwarded to :func:`jax.value_and_grad`.

  Returns:
    A function taking the same arguments as ``fun`` (keyword arguments are
    passed through to ``fun``) and returning ``(value, gradient)``.

  Example:

    >>> from jax.experimental import sparse
    >>> X = sparse.BCOO.fromdense(jnp.arange(6.))
    >>> y = jnp.ones(6)
    >>> sparse.value_and_grad(lambda X, y: X @ y)(X, y)
    (Array(15., dtype=float32), BCOO(float32[6], nse=5))
  """
  # Built only so its docstring can be reused by ``wraps`` below.
  raw_value_and_grad_fun = jax.value_and_grad(fun, argnums=argnums, has_aux=has_aux, **kwargs)
  argnums = core.concrete_or_error(_ensure_index, argnums)

  @wraps(fun, docstr=raw_value_and_grad_fun.__doc__, argnums=argnums)
  @api_boundary
  def value_and_grad_fun(*args, **call_kwargs):
    # ``call_kwargs`` is deliberately named differently from the outer
    # ``kwargs``: the outer ones configure the transform, these go to ``fun``.
    fun_flat, argnums_flat, args_flat, postprocess_gradients = flatten_fun_for_sparse_ad(fun, argnums, args)
    val_out, grad_out = jax.value_and_grad(
        fun_flat, argnums=argnums_flat, has_aux=has_aux, **kwargs)(*args_flat, **call_kwargs)
    return val_out, postprocess_gradients(grad_out)

  return value_and_grad_fun
def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
         has_aux=False, **kwargs) -> Callable:
  """Sparse-aware version of :func:`jax.grad`

  Arguments and return values are the same as :func:`jax.grad`, but when taking
  the gradient with respect to a :class:`jax.experimental.sparse` array, the
  gradient is computed in the subspace defined by the array's sparsity pattern.

  Args:
    fun: function to differentiate.
    argnums: index or sequence of indices of the positional arguments to
      differentiate with respect to.
    has_aux: forwarded unchanged to :func:`jax.grad`; when true the wrapped
      function returns ``(gradient, aux)``.
    **kwargs: additional options forwarded to :func:`jax.grad`.

  Returns:
    A function taking the same arguments as ``fun`` (keyword arguments are
    passed through to ``fun``) and returning the gradient.

  Example:

    >>> from jax.experimental import sparse
    >>> X = sparse.BCOO.fromdense(jnp.arange(6.))
    >>> y = jnp.ones(6)
    >>> sparse.grad(lambda X, y: X @ y)(X, y)
    BCOO(float32[6], nse=5)
  """
  # Built only so its docstring can be reused by ``wraps`` below.
  raw_grad_fun = jax.grad(fun, argnums=argnums, has_aux=has_aux, **kwargs)
  argnums = core.concrete_or_error(_ensure_index, argnums)

  @wraps(fun, docstr=raw_grad_fun.__doc__, argnums=argnums)
  @api_boundary
  def grad_fun(*args, **call_kwargs):
    # ``call_kwargs`` is deliberately named differently from the outer
    # ``kwargs``: the outer ones configure the transform, these go to ``fun``.
    fun_flat, argnums_flat, args_flat, postprocess_gradients = flatten_fun_for_sparse_ad(fun, argnums, args)
    out = jax.grad(
        fun_flat, argnums=argnums_flat, has_aux=has_aux, **kwargs)(*args_flat, **call_kwargs)
    if has_aux:
      return postprocess_gradients(out[0]), out[1]
    return postprocess_gradients(out)

  return grad_fun
def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
           has_aux: bool = False, **kwargs) -> Callable:
  """Sparse-aware version of :func:`jax.jacfwd`

  Arguments and return values are the same as :func:`jax.jacfwd`, but when taking
  the gradient with respect to a :class:`jax.experimental.sparse` array, the
  gradient is computed in the subspace defined by the array's sparsity pattern.
  Currently this is only implemented for dense outputs.

  Args:
    fun: function to differentiate.
    argnums: index or sequence of indices of the positional arguments to
      differentiate with respect to.
    has_aux: forwarded unchanged to :func:`jax.jacfwd`; when true the wrapped
      function returns ``(jacobian, aux)``.
    **kwargs: additional options forwarded to :func:`jax.jacfwd`.

  Returns:
    A function taking the same arguments as ``fun`` (keyword arguments are
    passed through to ``fun``) and returning the Jacobian.
  """
  # Built only so its docstring can be reused by ``wraps`` below.
  raw_jacfwd_fun = jax.jacfwd(fun, argnums=argnums, **kwargs)
  argnums = core.concrete_or_error(_ensure_index, argnums)

  @wraps(fun, docstr=raw_jacfwd_fun.__doc__, argnums=argnums)
  @api_boundary
  def jacfwd_fun(*args, **call_kwargs):
    # ``call_kwargs`` is deliberately named differently from the outer
    # ``kwargs``: the outer ones configure the transform, these go to ``fun``.
    fun_flat, argnums_flat, args_flat, postprocess_gradients = flatten_fun_for_sparse_ad(fun, argnums, args)
    out = jax.jacfwd(
        fun_flat, argnums=argnums_flat, has_aux=has_aux, **kwargs)(*args_flat, **call_kwargs)
    if has_aux:
      return postprocess_gradients(out[0]), out[1]
    return postprocess_gradients(out)

  return jacfwd_fun
def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
           has_aux: bool = False, **kwargs) -> Callable:
  """Sparse-aware version of :func:`jax.jacrev`

  Arguments and return values are the same as :func:`jax.jacrev`, but when taking
  the gradient with respect to a :class:`jax.experimental.sparse` array, the
  gradient is computed in the subspace defined by the array's sparsity pattern.
  Currently this is only implemented for dense outputs.

  Args:
    fun: function to differentiate.
    argnums: index or sequence of indices of the positional arguments to
      differentiate with respect to.
    has_aux: forwarded unchanged to :func:`jax.jacrev`; when true the wrapped
      function returns ``(jacobian, aux)``.
    **kwargs: additional options forwarded to :func:`jax.jacrev`.

  Returns:
    A function taking the same arguments as ``fun`` (keyword arguments are
    passed through to ``fun``) and returning the Jacobian.
  """
  # Built only so its docstring can be reused by ``wraps`` below.
  raw_jacrev_fun = jax.jacrev(fun, argnums=argnums, **kwargs)
  argnums = core.concrete_or_error(_ensure_index, argnums)

  @wraps(fun, docstr=raw_jacrev_fun.__doc__, argnums=argnums)
  @api_boundary
  def jacrev_fun(*args, **call_kwargs):
    # ``call_kwargs`` is deliberately named differently from the outer
    # ``kwargs``: the outer ones configure the transform, these go to ``fun``.
    fun_flat, argnums_flat, args_flat, postprocess_gradients = flatten_fun_for_sparse_ad(fun, argnums, args)
    out = jax.jacrev(
        fun_flat, argnums=argnums_flat, has_aux=has_aux, **kwargs)(*args_flat, **call_kwargs)
    if has_aux:
      return postprocess_gradients(out[0]), out[1]
    return postprocess_gradients(out)

  return jacrev_fun
# `jacobian` is exposed as an alias of `jacrev`.
jacobian = jacrev