# Copyright 2021 The JAX Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

.. currentmodule:: jax.experimental.sparse

The :mod:`jax.experimental.sparse` module includes experimental support for sparse matrix
operations in JAX. It is under active development, and the API is subject to change. The
primary interfaces made available are the :class:`BCOO` sparse array type, and the
:func:`sparsify` transform.

Batched-coordinate (BCOO) sparse matrices
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The main high-level sparse object currently available in JAX is the :class:`BCOO`,
or *batched coordinate* sparse array, which offers a compressed storage format compatible
with JAX transformations, in particular JIT (e.g. :func:`jax.jit`), batching
(e.g. :func:`jax.vmap`) and autodiff (e.g. :func:`jax.grad`).

Here is an example of creating a sparse array from a dense array:

>>> from jax.experimental import sparse
>>> import jax.numpy as jnp
>>> import numpy as np
>>> M = jnp.array([[0., 1., 0., 2.],
...                [3., 0., 0., 0.],
...                [0., 0., 4., 0.]])
>>> M_sp = sparse.BCOO.fromdense(M)
>>> M_sp
BCOO(float32[3, 4], nse=4)

Convert back to a dense array with the ``todense()`` method:

>>> M_sp.todense()
Array([[0., 1., 0., 2.],
       [3., 0., 0., 0.],
       [0., 0., 4., 0.]], dtype=float32)

The BCOO format is a slightly modified version of the standard COO format. The dense
arrays that back it are exposed through the ``data`` and ``indices`` attributes:

>>> M_sp.data  # Explicitly stored data
Array([1., 2., 3., 4.], dtype=float32)
>>> M_sp.indices  # Indices of the stored data
Array([[0, 1],
       [0, 3],
       [1, 0],
       [2, 2]], dtype=int32)
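
To make the coordinate layout concrete, the sketch below rebuilds the dense matrix from the ``data`` and ``indices`` values printed above, using plain NumPy. It only illustrates how each row of ``indices`` addresses one stored value; it is not how JAX implements ``todense()``, and the variable names are chosen for this illustration.

```python
import numpy as np

# Values copied from the M_sp.data and M_sp.indices output shown above.
data = np.array([1., 2., 3., 4.], dtype=np.float32)
indices = np.array([[0, 1],
                    [0, 3],
                    [1, 0],
                    [2, 2]], dtype=np.int32)
shape = (3, 4)

# Each row of `indices` is the (row, column) position of one entry of `data`.
rows, cols = indices.T

# Scatter the stored values into a zero-filled dense array.
# np.add.at accumulates, so a repeated index would contribute the sum of its values.
dense = np.zeros(shape, dtype=data.dtype)
np.add.at(dense, (rows, cols), data)
print(dense)
```

The result matches the matrix ``M`` used to build ``M_sp``.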

BCOO objects have familiar array-like attributes, as well as sparse-specific attributes:

>>> M_sp.ndim
2
>>> M_sp.shape
(3, 4)
>>> M_sp.dtype
dtype('float32')
>>> M_sp.nse  # "number of specified elements"
4

BCOO objects also implement a number of array-like methods, so you can use them
directly within JAX programs. For example, here we compute the transposed matrix-vector
product:

>>> y = jnp.array([3., 6., 5.])
>>> M_sp.T @ y
Array([18.,  3., 20.,  6.], dtype=float32)
>>> M.T @ y  # Compare to dense version
Array([18.,  3., 20.,  6.], dtype=float32)

BCOO objects are designed to be compatible with JAX transforms, including :func:`jax.jit`,
:func:`jax.vmap`, :func:`jax.grad`, and others. For example:

>>> from jax import grad, jit
>>> def f(y):
...   return (M_sp.T @ y).sum()
>>> jit(grad(f))(y)
Array([3., 3., 4.], dtype=float32)
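
The gradient above can be checked by hand: ``f`` is linear in ``y``, so the derivative with respect to ``y[i]`` is the sum of row ``i`` of ``M``. The following self-contained sketch repeats the setup from this section and compares the sparse gradient against the dense row sums. The names ``sparse_grad`` and ``row_sums`` are introduced here for illustration only.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[0., 1., 0., 2.],
               [3., 0., 0., 0.],
               [0., 0., 4., 0.]])
M_sp = sparse.BCOO.fromdense(M)
y = jnp.array([3., 6., 5.])

def f(y):
  return (M_sp.T @ y).sum()

# f(y) = sum_j sum_i M[i, j] * y[i], so df/dy[i] is the i-th row sum of M.
sparse_grad = jax.jit(jax.grad(f))(y)
row_sums = M.sum(axis=1)
print(sparse_grad, row_sums)
```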

Note, however, that under normal circumstances :mod:`jax.numpy` and :mod:`jax.lax` functions
do not know how to handle sparse matrices, so attempting to compute things like
``jnp.dot(M_sp.T, y)`` will result in an error (but see the next section).

Sparsify transform
~~~~~~~~~~~~~~~~~~
An overarching goal of the JAX sparse implementation is to provide a means to switch from
dense to sparse computation seamlessly, without having to modify the dense implementation.
The :mod:`jax.experimental.sparse` module accomplishes this through the :func:`sparsify` transform.

Consider this function, which computes a more complicated result from a matrix and a vector input:

>>> def f(M, v):
...   return 2 * jnp.dot(jnp.log1p(M.T), v) + 1
>>> f(M, y)
Array([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32)

Were we to pass a sparse matrix to this function directly, it would raise an error, because ``jnp``
functions do not recognize sparse inputs. However, with :func:`sparsify`, we get a version of
this function that does accept sparse matrices:

>>> f_sp = sparse.sparsify(f)
>>> f_sp(M_sp, y)
Array([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32)
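
Because BCOO arrays are compatible with :func:`jax.jit`, a sparsified function can also be compiled. The sketch below is self-contained: it redefines a function that reproduces the values printed above, wraps it with both transforms, and checks the sparse result against the dense one. The names ``f_sp_jit``, ``dense_result`` and ``sparse_result`` are introduced here for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[0., 1., 0., 2.],
               [3., 0., 0., 0.],
               [0., 0., 4., 0.]])
M_sp = sparse.BCOO.fromdense(M)
y = jnp.array([3., 6., 5.])

def f(M, v):
  # log1p maps zero to zero, so it can be applied to the stored values only.
  return 2 * jnp.dot(jnp.log1p(M.T), v) + 1

# Sparsify first, then compile the sparse-aware function.
f_sp_jit = jax.jit(sparse.sparsify(f))

dense_result = f(M, y)
sparse_result = f_sp_jit(M_sp, y)
print(dense_result)
print(sparse_result)
```

The dense function ``f`` is left untouched; only the wrapper changes, which is the point of the transform.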

Support for :func:`sparsify` includes a large number of the most common primitives, including:

- generalized (batched) matrix products & Einstein summations (:obj:`~jax.lax.dot_general_p`)
- zero-preserving elementwise binary operations (e.g. :obj:`~jax.lax.add_p`, :obj:`~jax.lax.mul_p`, etc.)
- zero-preserving elementwise unary operations (e.g. :obj:`~jax.lax.abs_p`, :obj:`~jax.lax.neg_p`, etc.)
- summation reductions (:obj:`~jax.lax.reduce_sum_p`)
- general indexing operations (:obj:`~jax.lax.slice_p`, :obj:`~jax.lax.dynamic_slice_p`, :obj:`~jax.lax.gather_p`)
- concatenation and stacking (:obj:`~jax.lax.concatenate_p`)
- transposition & reshaping (:obj:`~jax.lax.transpose_p`, :obj:`~jax.lax.reshape_p`,
  :obj:`~jax.lax.squeeze_p`, :obj:`~jax.lax.broadcast_in_dim_p`)
- some higher-order functions (:obj:`~jax.lax.cond_p`, :obj:`~jax.lax.while_p`, :obj:`~jax.lax.scan_p`)
- some simple 1D convolutions (:obj:`~jax.lax.conv_general_dilated_p`)

Nearly any :mod:`jax.numpy` function that lowers to these supported primitives can be used
within a sparsify transform to operate on sparse arrays. This set of primitives is enough
to enable relatively sophisticated sparse workflows, as the next section will show.

Example: sparse logistic regression
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
As an example of a more complicated sparse workflow, let's consider a simple logistic regression
implemented in JAX. Notice that the following implementation makes no reference to sparsity:

>>> import functools
>>> from sklearn.datasets import make_classification
>>> from jax.scipy import optimize
>>> def sigmoid(x):
...   return 0.5 * (jnp.tanh(x / 2) + 1)
>>> def y_model(params, X):
...   return sigmoid(jnp.dot(X, params[1:]) + params[0])
>>> def loss(params, X, y):
...   y_hat = y_model(params, X)
...   return -jnp.mean(y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat))
>>> def fit_logreg(X, y):
...   params = jnp.zeros(X.shape[1] + 1)
...   result = optimize.minimize(functools.partial(loss, X=X, y=y),
...                              x0=params, method='BFGS')
...   return result.x
>>> X, y = make_classification(n_classes=2, random_state=1701)
>>> params_dense = fit_logreg(X, y)
>>> print(params_dense)  # doctest: +SKIP
[-0.7298445   0.29893667  1.0248291  -0.44436368  0.8785025  -0.7724008
 -0.62893456  0.2934014   0.82974285  0.16838408 -0.39774987 -0.5071844
  0.2028872   0.5227761  -0.3739224  -0.7104083   2.4212713   0.6310087
 -0.67060554  0.03139788 -0.05359547]

This returns the best-fit parameters of a dense logistic regression problem.
To fit the same model on sparse data, we can apply the :func:`sparsify` transform:

>>> Xsp = sparse.BCOO.fromdense(X)  # Sparse version of the input
>>> fit_logreg_sp = sparse.sparsify(fit_logreg)  # Sparse-transformed fit function
>>> params_sparse = fit_logreg_sp(Xsp, y)
>>> print(params_sparse)  # doctest: +SKIP
[-0.72971725  0.29878938  1.0246326  -0.44430563  0.8784217  -0.77225566
 -0.6288222   0.29335397  0.8293481   0.16820715 -0.39764675 -0.5069753
  0.202579    0.522672   -0.3740134  -0.7102678   2.4209507   0.6310593
 -0.670236    0.03132951 -0.05356663]
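
The dense and sparse fits are not bit-for-bit identical, but they agree closely. The sketch below quantifies this using the two parameter vectors exactly as printed above (copied by hand, so it needs neither scikit-learn nor a rerun of the optimizer). The tolerance of ``1e-3`` is a choice made for this illustration, not a documented guarantee.

```python
import numpy as np

# Values copied from the printed dense and sparse parameter vectors above.
params_dense = np.array([
    -0.7298445, 0.29893667, 1.0248291, -0.44436368, 0.8785025, -0.7724008,
    -0.62893456, 0.2934014, 0.82974285, 0.16838408, -0.39774987, -0.5071844,
    0.2028872, 0.5227761, -0.3739224, -0.7104083, 2.4212713, 0.6310087,
    -0.67060554, 0.03139788, -0.05359547])
params_sparse = np.array([
    -0.72971725, 0.29878938, 1.0246326, -0.44430563, 0.8784217, -0.77225566,
    -0.6288222, 0.29335397, 0.8293481, 0.16820715, -0.39764675, -0.5069753,
    0.202579, 0.522672, -0.3740134, -0.7102678, 2.4209507, 0.6310593,
    -0.670236, 0.03132951, -0.05356663])

# Largest componentwise disagreement between the two fits.
abs_diff = np.abs(params_dense - params_sparse)
print("max abs difference:", abs_diff.max())
print("agree within 1e-3:", np.allclose(params_dense, params_sparse, atol=1e-3))
```

Every component differs by well under ``1e-3``, which is consistent with the two runs solving the same problem through different floating-point code paths.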
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 &
from jax.experimental.sparse.ad import (
    jacfwd as jacfwd,
    jacobian as jacobian,
    jacrev as jacrev,
    grad as grad,
    value_and_grad as value_and_grad,
)
from jax.experimental.sparse.bcoo import (
    bcoo_broadcast_in_dim as bcoo_broadcast_in_dim,
    bcoo_concatenate as bcoo_concatenate,
    bcoo_conv_general_dilated as bcoo_conv_general_dilated,
    bcoo_dot_general as bcoo_dot_general,
    bcoo_dot_general_p as bcoo_dot_general_p,
    bcoo_dot_general_sampled as bcoo_dot_general_sampled,
    bcoo_dot_general_sampled_p as bcoo_dot_general_sampled_p,
    bcoo_dynamic_slice as bcoo_dynamic_slice,
    bcoo_extract as bcoo_extract,
    bcoo_extract_p as bcoo_extract_p,
    bcoo_fromdense as bcoo_fromdense,
    bcoo_fromdense_p as bcoo_fromdense_p,
    bcoo_gather as bcoo_gather,
    bcoo_multiply_dense as bcoo_multiply_dense,
    bcoo_multiply_sparse as bcoo_multiply_sparse,
    bcoo_update_layout as bcoo_update_layout,
    bcoo_reduce_sum as bcoo_reduce_sum,
    bcoo_reshape as bcoo_reshape,
    bcoo_rev as bcoo_rev,
    bcoo_slice as bcoo_slice,
    bcoo_sort_indices as bcoo_sort_indices,
    bcoo_sort_indices_p as bcoo_sort_indices_p,
    bcoo_spdot_general_p as bcoo_spdot_general_p,
    bcoo_squeeze as bcoo_squeeze,
    bcoo_sum_duplicates as bcoo_sum_duplicates,
    bcoo_sum_duplicates_p as bcoo_sum_duplicates_p,
    bcoo_todense as bcoo_todense,
    bcoo_todense_p as bcoo_todense_p,
    bcoo_transpose as bcoo_transpose,
    bcoo_transpose_p as bcoo_transpose_p,
)
from jax.experimental.sparse.bcsr import (
    bcsr_broadcast_in_dim as bcsr_broadcast_in_dim,
    bcsr_concatenate as bcsr_concatenate,
    bcsr_dot_general as bcsr_dot_general,
    bcsr_dot_general_p as bcsr_dot_general_p,
    bcsr_extract as bcsr_extract,
    bcsr_extract_p as bcsr_extract_p,
    bcsr_fromdense as bcsr_fromdense,
    bcsr_fromdense_p as bcsr_fromdense_p,
    bcsr_sum_duplicates as bcsr_sum_duplicates,
    bcsr_todense as bcsr_todense,
    bcsr_todense_p as bcsr_todense_p,
)
from jax.experimental.sparse._base import (
    JAXSparse as JAXSparse
)
from jax.experimental.sparse.api import (
    empty as empty,
    eye as eye,
    todense as todense,
    todense_p as todense_p,
)
from jax.experimental.sparse.util import (
    CuSparseEfficiencyWarning as CuSparseEfficiencyWarning,
    SparseEfficiencyError as SparseEfficiencyError,
    SparseEfficiencyWarning as SparseEfficiencyWarning,
)
from jax.experimental.sparse.coo import (
    coo_fromdense as coo_fromdense,
    coo_fromdense_p as coo_fromdense_p,
    coo_matmat as coo_matmat,
    coo_matmat_p as coo_matmat_p,
    coo_matvec as coo_matvec,
    coo_matvec_p as coo_matvec_p,
    coo_todense as coo_todense,
    coo_todense_p as coo_todense_p,
)
from jax.experimental.sparse.csr import (
    csr_fromdense as csr_fromdense,
    csr_fromdense_p as csr_fromdense_p,
    csr_matmat as csr_matmat,
    csr_matmat_p as csr_matmat_p,
    csr_matvec as csr_matvec,
    csr_matvec_p as csr_matvec_p,
    csr_todense as csr_todense,
    csr_todense_p as csr_todense_p,
)
from jax.experimental.sparse.random import random_bcoo as random_bcoo
from jax.experimental.sparse.transform import (
    sparsify as sparsify,
    SparseTracer as SparseTracer,
)
from jax.experimental.sparse import linalg as linalg