# Copyright 2021 The JAX Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
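# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and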
# limitations under the License.
"""Sparse test utilities."""
import functools
from typing import Any, Callable, Sequence, Union
import numpy as np
import jax
from jax import lax
from jax._src import test_util as jtu
from jax._src.typing import DTypeLike
from jax import tree_util
from jax.util import safe_zip, split_list
from jax.experimental import sparse
import jax.numpy as jnp
def is_sparse(x):
  """Report whether ``x`` is a JAX sparse array.

  Intended as an ``is_leaf`` predicate for ``tree_util`` functions so that
  sparse arrays are treated as single leaves rather than flattened.

  Returns:
    True if ``x`` is an instance of ``sparse.JAXSparse`` (or a subclass),
    False otherwise.
  """
  sparse_base_cls = sparse.JAXSparse
  return isinstance(x, sparse_base_cls)
class SparseTestCase(jtu.JaxTestCase):
  """JAX test case with helpers for checking sparse ops against dense ones."""

  def assertSparseArraysEquivalent(self, x, y, *, check_dtypes=True, atol=None,
                                   rtol=None, canonicalize_dtypes=True, err_msg=''):
    """Assert that two (possibly sparse) pytrees are equivalent.

    Both inputs are flattened with ``tree_util``; the tree structures must be
    equal and the flattened leaves must be close according to
    ``assertAllClose``. Sparse arrays are therefore compared leaf-by-leaf,
    not by their densified values.
    """
    x_bufs, x_tree = tree_util.tree_flatten(x)
    y_bufs, y_tree = tree_util.tree_flatten(y)

    self.assertEqual(x_tree, y_tree)
    self.assertAllClose(x_bufs, y_bufs, check_dtypes=check_dtypes, atol=atol, rtol=rtol,
                        canonicalize_dtypes=canonicalize_dtypes, err_msg=err_msg)

  def _CheckAgainstDense(self, dense_op, sparse_op, args_maker, check_jit=True,
                         check_dtypes=True, tol=None, atol=None, rtol=None,
                         canonicalize_dtypes=True):
    """Check an operation against a dense equivalent.

    Args:
      dense_op: reference implementation operating on dense arrays.
      sparse_op: implementation under test; may accept and return sparse arrays.
      args_maker: zero-argument callable returning the argument list. Sparse
        arguments are densified with ``sparse.todense`` before being passed to
        ``dense_op``.
      check_jit: if True, also check that ``jax.jit(sparse_op)`` gives a result
        equivalent (leaf-by-leaf) to the un-jitted ``sparse_op``.
      check_dtypes: forwarded to ``assertAllClose``.
      tol: used in place of ``atol`` / ``rtol`` when those are falsy
        (``None`` or ``0``).
      atol, rtol: absolute / relative tolerances.
      canonicalize_dtypes: forwarded to ``assertAllClose``.
    """
    sparse_args = args_maker()
    dense_args = tree_util.tree_map(sparse.todense, sparse_args, is_leaf=is_sparse)
    expected = dense_op(*dense_args)

    sparse_ans = sparse_op(*sparse_args)
    # Densify any sparse outputs so they can be compared with the dense result.
    actual = tree_util.tree_map(sparse.todense, sparse_ans, is_leaf=is_sparse)

    self.assertAllClose(expected, actual, check_dtypes=check_dtypes,
                        atol=atol or tol, rtol=rtol or tol,
                        canonicalize_dtypes=canonicalize_dtypes)
    if check_jit:
      sparse_ans_jit = jax.jit(sparse_op)(*sparse_args)
      self.assertSparseArraysEquivalent(sparse_ans, sparse_ans_jit,
                                        atol=atol or tol, rtol=rtol or tol)

  def _CheckGradsSparse(self, dense_fun, sparse_fun, args_maker, *,
                        argnums=None, modes=('fwd', 'rev'), atol=None, rtol=None):
    """Check that Jacobians of ``sparse_fun`` match those of ``dense_fun``.

    Both functions are wrapped to take the flattened leaves of
    ``args_maker()``. Differentiation is with respect to the first leaf of
    each argument's pytree; for sparse arguments this is presumably the data
    buffer (depends on the sparse type's flattening -- confirm).

    Args:
      dense_fun: reference function taking dense arrays.
      sparse_fun: function under test; its sparse outputs are densified.
      args_maker: zero-argument callable returning the argument list.
      argnums: optional indices of the arguments to differentiate; by default
        every argument is differentiated.
      modes: any subset of ``'fwd'`` (``jax.jacfwd``) and ``'rev'``
        (``jax.jacrev``).
      atol, rtol: tolerances forwarded to ``assertAllClose``.
    """
    assert all(mode in ['fwd', 'rev'] for mode in modes), f"invalid modes: {modes}"
    args = args_maker()
    args_flat, tree = tree_util.tree_flatten(args)
    # Index in ``args_flat`` of the first leaf of each top-level argument.
    num_bufs = [len(tree_util.tree_flatten(arg)[0]) for arg in args]
    argnums_flat = np.cumsum([0, *num_bufs[:-1]]).tolist()
    if argnums is not None:
      argnums_flat = [argnums_flat[n] for n in argnums]

    def dense_fun_flat(*args_flat):
      unflat_args = tree_util.tree_unflatten(tree, args_flat)
      args_dense = tree_util.tree_map(sparse.todense, unflat_args, is_leaf=is_sparse)
      return dense_fun(*args_dense)

    def sparse_fun_flat(*args_flat):
      out = sparse_fun(*tree_util.tree_unflatten(tree, args_flat))
      return tree_util.tree_map(sparse.todense, out, is_leaf=is_sparse)

    if 'rev' in modes:
      result_de = jax.jacrev(dense_fun_flat, argnums=argnums_flat)(*args_flat)
      result_sp = jax.jacrev(sparse_fun_flat, argnums=argnums_flat)(*args_flat)
      self.assertAllClose(result_de, result_sp, atol=atol, rtol=rtol)

    if 'fwd' in modes:
      result_de = jax.jacfwd(dense_fun_flat, argnums=argnums_flat)(*args_flat)
      result_sp = jax.jacfwd(sparse_fun_flat, argnums=argnums_flat)(*args_flat)
      self.assertAllClose(result_de, result_sp, atol=atol, rtol=rtol)

  def _random_bdims(self, *args):
    """Return one random integer in ``[0, arg]`` (inclusive) for each ``arg``.

    Draws use the test case's ``self.rng()``.
    """
    rng = self.rng()
    return [rng.randint(0, arg + 1) for arg in args]

  def _CheckBatchingSparse(self, dense_fun, sparse_fun, args_maker, *, batch_size=3, bdims=None,
                           check_jit=False, check_dtypes=True, tol=None, atol=None, rtol=None,
                           canonicalize_dtypes=True):
    """Check ``vmap(sparse_fun)`` against ``vmap(dense_fun)`` on batched inputs.

    Batched arguments are built by drawing ``batch_size`` independent argument
    lists from ``args_maker`` and concatenating each argument, after inserting
    a new axis, along its entry in ``bdims``. A ``None`` entry in ``bdims``
    leaves that argument unbatched (the first draw is used). When ``bdims`` is
    omitted, an axis is drawn per argument from ``[0, n_batch]`` for sparse
    arrays and from ``[0, ndim]`` for dense arrays.

    Remaining keyword arguments are forwarded to ``_CheckAgainstDense``.
    """
    if bdims is None:
      bdims = self._random_bdims(*(arg.n_batch if is_sparse(arg) else arg.ndim
                                   for arg in args_maker()))

    def concat(args, bdim):
      return sparse.sparsify(functools.partial(lax.concatenate, dimension=bdim))(args)

    def expand(arg, bdim):
      return sparse.sparsify(functools.partial(lax.expand_dims, dimensions=[bdim]))(arg)

    def batched_args_maker():
      # Transpose: one tuple per argument position, holding batch_size draws.
      args = list(zip(*(args_maker() for _ in range(batch_size))))
      return [arg[0] if bdim is None else concat([expand(x, bdim) for x in arg], bdim)
              for arg, bdim in safe_zip(args, bdims)]

    self._CheckAgainstDense(jax.vmap(dense_fun, bdims), jax.vmap(sparse_fun, bdims),
                            batched_args_maker, check_dtypes=check_dtypes, tol=tol,
                            atol=atol, rtol=rtol, check_jit=check_jit,
                            canonicalize_dtypes=canonicalize_dtypes)
def _rand_sparse(shape: Sequence[int], dtype: DTypeLike, *,
                 rng: np.random.RandomState, rand_method: Callable[..., Any],
                 nse: Union[int, float], n_batch: int, n_dense: int,
                 sparse_format: str) -> Union[sparse.BCOO, sparse.BCSR]:
  """Generate a random BCOO or BCSR array.

  Args:
    shape: full shape, laid out as ``(*batch_shape, *sparse_shape, *dense_shape)``.
    dtype: dtype passed to the data generator.
    rng: NumPy random state used for index generation and passed to
      ``rand_method``.
    rand_method: called as ``rand_method(rng)``; the result is called as
      ``data_rng(data_shape, dtype)`` to produce the stored values.
    nse: number of specified elements per batch. A value in ``[0, 1)`` is
      treated as a fraction of the number of sparse positions and rounded up.
    n_batch: number of leading batch dimensions.
    n_dense: number of trailing dense dimensions.
    sparse_format: ``'bcoo'`` or ``'bcsr'``.

  Returns:
    A ``sparse.BCOO`` or ``sparse.BCSR`` of the given ``shape``. Indices are
    drawn independently at random, so duplicate indices may occur.

  Raises:
    ValueError: if ``sparse_format`` is unsupported, any dimension count is
      negative, or a BCSR array is requested without exactly 2 sparse
      dimensions.
  """
  if sparse_format not in ['bcoo', 'bcsr']:
    raise ValueError(f"Sparse format {sparse_format} not supported.")

  n_sparse = len(shape) - n_batch - n_dense

  if n_sparse < 0 or n_batch < 0 or n_dense < 0:
    raise ValueError(f"Invalid parameters: {shape=} {n_batch=} {n_sparse=}")

  if sparse_format == 'bcsr' and n_sparse != 2:
    raise ValueError("bcsr array must have 2 sparse dimensions; "
                     f"{n_sparse} is given.")

  batch_shape, sparse_shape, dense_shape = split_list(shape,
                                                      [n_batch, n_sparse])
  if 0 <= nse < 1:
    # Fractional nse: scale by the number of positions in the sparse dims.
    nse = int(np.ceil(nse * np.prod(sparse_shape)))
  data_rng = rand_method(rng)
  data_shape = (*batch_shape, nse, *dense_shape)
  data = jnp.array(data_rng(data_shape, dtype))

  if sparse_format == 'bcoo':
    index_shape = (*batch_shape, nse, n_sparse)
    indices = jnp.array(
      rng.randint(0, sparse_shape, size=index_shape, dtype=np.int32))  # type: ignore[arg-type]
    return sparse.BCOO((data, indices), shape=shape)

  index_shape = (*batch_shape, nse)
  indptr_shape = (*batch_shape, sparse_shape[0] + 1)
  indices = jnp.array(
    rng.randint(0, sparse_shape[1], size=index_shape, dtype=np.int32))  # type: ignore[arg-type]
  # Sorting draws from [0, nse] yields a non-decreasing row-pointer array.
  indptr = jnp.sort(
    rng.randint(0, nse + 1, size=indptr_shape, dtype=np.int32), axis=-1)  # type: ignore[call-overload]
  # Pin the first row pointer to 0. JAX arrays are immutable, hence ``.at``.
  indptr = indptr.at[..., 0].set(0)
  return sparse.BCSR((data, indices, indptr), shape=shape)
def rand_bcoo(rng: np.random.RandomState,
              rand_method: Callable[..., Any]=jtu.rand_default,
              nse: Union[int, float]=0.5, n_batch: int=0, n_dense: int=0):
  """Generates a random BCOO array.

  Returns a callable ``f(shape, dtype)`` that builds a ``sparse.BCOO`` with
  ``n_batch`` leading batch dimensions, ``n_dense`` trailing dense dimensions
  and ``nse`` specified elements per batch (a value in ``[0, 1)`` is treated
  as a fraction of the number of sparse positions).
  """
  return functools.partial(_rand_sparse, rng=rng, rand_method=rand_method,
                           nse=nse, n_batch=n_batch, n_dense=n_dense,
                           sparse_format='bcoo')
def rand_bcsr(rng: np.random.RandomState,
              rand_method: Callable[..., Any]=jtu.rand_default,
              nse: Union[int, float]=0.5, n_batch: int=0, n_dense: int=0):
  """Generates a random BCSR array.

  Returns a callable ``f(shape, dtype)`` that builds a ``sparse.BCSR`` with
  ``n_batch`` leading batch dimensions, ``n_dense`` trailing dense dimensions
  and ``nse`` specified elements per batch (a value in ``[0, 1)`` is treated
  as a fraction of the number of sparse positions). ``shape`` must leave
  exactly 2 sparse dimensions, otherwise ``f`` raises ``ValueError``.
  """
  return functools.partial(_rand_sparse, rng=rng, rand_method=rand_method,
                           nse=nse, n_batch=n_batch, n_dense=n_dense,
                           sparse_format='bcsr')