# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. currentmodule:: jax.experimental.sparse

The :mod:`jax.experimental.sparse` module includes experimental support for sparse
matrix operations in JAX. It is under active development, and the API is subject
to change. The primary interfaces made available are the :class:`BCOO` sparse
array type and the :func:`sparsify` transform.

Batched-coordinate (BCOO) sparse matrices
-----------------------------------------
The main high-level sparse object currently available in JAX is the :class:`BCOO`,
or *batched coordinate* sparse array, which offers a compressed storage format
compatible with JAX transformations, in particular JIT (e.g. :func:`jax.jit`),
batching (e.g. :func:`jax.vmap`), and autodiff (e.g. :func:`jax.grad`).

Here is an example of creating a sparse array from a dense array:

>>> from jax.experimental import sparse
>>> import jax.numpy as jnp
>>> import numpy as np

>>> M = jnp.array([[0., 1., 0., 2.],
...                [3., 0., 0., 0.],
...                [0., 0., 4., 0.]])

>>> M_sp = sparse.BCOO.fromdense(M)

>>> M_sp
BCOO(float32[3, 4], nse=4)

Convert back to a dense array with the ``todense()`` method:

>>> M_sp.todense()
Array([[0., 1., 0., 2.],
       [3., 0., 0., 0.],
       [0., 0., 4., 0.]], dtype=float32)

The BCOO format is a somewhat modified version of the standard COO format, and
the underlying sparse representation can be seen in the ``data`` and ``indices``
attributes:

>>> M_sp.data  # Explicitly stored data
Array([1., 2., 3., 4.], dtype=float32)

>>> M_sp.indices  # Indices of the stored data
Array([[0, 1],
       [0, 3],
       [1, 0],
       [2, 2]], dtype=int32)

BCOO objects have familiar array-like attributes, as well as sparse-specific
attributes:

>>> M_sp.ndim
2

>>> M_sp.shape
(3, 4)

>>> M_sp.dtype
dtype('float32')

>>> M_sp.nse  # "number of specified elements"
4

BCOO objects also implement a number of array-like methods, allowing them to be
used directly within JAX programs. For example, here we compute the transposed
matrix-vector product:

>>> y = jnp.array([3., 6., 5.])

>>> M_sp.T @ y
Array([18.,  3., 20.,  6.], dtype=float32)

>>> M.T @ y  # Compare to dense version
Array([18.,  3., 20.,  6.], dtype=float32)

BCOO objects are designed to be compatible with JAX transforms, including
:func:`jax.jit`, :func:`jax.vmap`, :func:`jax.grad`, and others. For example:

>>> from jax import grad, jit

>>> def f(y):
...   return (M_sp.T @ y).sum()
...
>>> jit(grad(f))(y)
Array([3., 3., 4.], dtype=float32)
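A related caveat: because the sizes of the ``data`` and ``indices`` buffers
depend on the number of specified elements, constructing a :class:`BCOO` array
inside :func:`jax.jit` requires passing ``nse`` statically. A minimal sketch
(the ``fromdense_jit`` helper name is illustrative, and the sketch assumes the
four specified elements of ``M`` above):

>>> @jit
... def fromdense_jit(mat):
...   # nse (number of specified elements) must be known statically under jit
...   return sparse.BCOO.fromdense(mat, nse=4)
...
>>> fromdense_jit(M)
BCOO(float32[3, 4], nse=4)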
Note, however, that under normal circumstances :mod:`jax.numpy` and :mod:`jax.lax`
functions do not know how to handle sparse matrices, so attempting to compute
things like ``jnp.dot(M_sp.T, y)`` will result in an error (however, see the
next section).

Sparsify transform
------------------
An overarching goal of the JAX sparse implementation is to provide a means to
switch from dense to sparse computation seamlessly, without having to modify the
dense implementation. The sparse module accomplishes this through the
:func:`sparsify` transform.

Consider this function, which computes a more complicated result from a matrix
and a vector input:

>>> def f(M, v):
...   return 2 * jnp.dot(jnp.log1p(M.T), v) + 1
...
>>> f(M, y)
Array([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32)

Were we to pass a sparse matrix to this directly, it would result in an error,
because ``jnp`` functions do not recognize sparse inputs. However, with
:func:`sparsify`, we get a version of this function that does accept sparse
matrices:

>>> f_sp = sparse.sparsify(f)

>>> f_sp(M_sp, y)
Array([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32)
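Because the output of :func:`sparsify` is itself an ordinary traceable JAX
function, it composes with other transforms. For example, here is a sketch of
combining it with :func:`jax.jit`; the result should match ``f_sp(M_sp, y)``
up to floating-point rounding:

>>> f_sp_jit = jit(sparse.sparsify(f))
>>> f_sp_jit(M_sp, y)  # doctest: +SKIP
Array([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32)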
:func:`sparsify` supports a large number of the most common JAX primitives,
including:

- generalized (batched) matrix products & Einstein summations
  (:obj:`~jax.lax.dot_general_p`)
- zero-preserving elementwise binary operations (e.g. :obj:`~jax.lax.add_p`,
  :obj:`~jax.lax.mul_p`, etc.)
- zero-preserving elementwise unary operations (e.g. :obj:`~jax.lax.abs_p`,
  :obj:`~jax.lax.neg_p`, etc.)
- summation reductions (:obj:`~jax.lax.reduce_sum_p`)
- general indexing operations (:obj:`~jax.lax.slice_p`,
  :obj:`~jax.lax.dynamic_slice_p`, :obj:`~jax.lax.gather_p`)
- concatenation and stacking (:obj:`~jax.lax.concatenate_p`)
- transposition & reshaping (:obj:`~jax.lax.transpose_p`, :obj:`~jax.lax.reshape_p`,
  :obj:`~jax.lax.squeeze_p`, :obj:`~jax.lax.broadcast_in_dim_p`)
- some higher-order functions (:obj:`~jax.lax.cond_p`, :obj:`~jax.lax.while_p`,
  :obj:`~jax.lax.scan_p`)
- some simple 1D convolutions (:obj:`~jax.lax.conv_general_dilated_p`)

Nearly any :mod:`jax.numpy` function that lowers to these supported primitives
can be used within a sparsify transform to operate on sparse arrays. This set
of primitives is enough to enable relatively sophisticated sparse workflows,
as the next section will show.

Example: sparse logistic regression
-----------------------------------
As an example of a more complicated sparse workflow, let's consider a simple
logistic regression implemented in JAX. Notice that the following implementation
has no reference to sparsity:

>>> import functools
>>> from sklearn.datasets import make_classification
>>> from jax.scipy import optimize

>>> def sigmoid(x):
...   return 0.5 * (jnp.tanh(x / 2) + 1)
...
>>> def y_model(params, X):
...   return sigmoid(jnp.dot(X, params[1:]) + params[0])
...
>>> def loss(params, X, y):
...   y_hat = y_model(params, X)
...   return -jnp.mean(y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat))
...
>>> def fit_logreg(X, y):
...   params = jnp.zeros(X.shape[1] + 1)
...   result = optimize.minimize(functools.partial(loss, X=X, y=y),
...                              x0=params, method='BFGS')
...   return result.x

>>> X, y = make_classification(n_classes=2, random_state=1701)
>>> params_dense = fit_logreg(X, y)
>>> print(params_dense)  # doctest: +SKIP
[-0.7298445   0.29893667  1.0248291  -0.44436368  0.8785025  -0.7724008
 -0.62893456  0.2934014   0.82974285  0.16838408 -0.39774987 -0.5071844
  0.2028872   0.5227761  -0.3739224  -0.7104083   2.4212713   0.6310087
 -0.67060554  0.03139788 -0.05359547]

This returns the best-fit parameters of a dense logistic regression problem.
To fit the same model on sparse data, we can apply the :func:`sparsify`
transform:

>>> Xsp = sparse.BCOO.fromdense(X)  # Sparse version of the input
>>> fit_logreg_sp = sparse.sparsify(fit_logreg)  # Sparse-transformed fit function
>>> params_sparse = fit_logreg_sp(Xsp, y)
>>> print(params_sparse)  # doctest: +SKIP
[-0.72971725  0.29878938  1.0246326  -0.44430563  0.8784217  -0.77225566
 -0.6288222   0.29335397  0.8293481   0.16820715 -0.39764675 -0.5069753
  0.202579    0.522672   -0.3740134  -0.7102678   2.4209507   0.6310593
 -0.670236    0.03132951 -0.05356663]
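Up to the tolerance of the optimizer, the sparse fit recovers essentially the
same parameters as the dense fit. Since BFGS results can vary slightly across
platforms, a sketch of a sanity check with a loose tolerance (the ``atol``
value here is an illustrative choice):

>>> jnp.allclose(params_dense, params_sparse, atol=1e-2)  # doctest: +SKIP
Array(True, dtype=bool)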
"""

# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570

from jax.experimental.sparse.ad import (
    jacfwd as jacfwd,
    jacobian as jacobian,
    jacrev as jacrev,
    grad as grad,
    value_and_grad as value_and_grad,
)
from jax.experimental.sparse.bcoo import (
    bcoo_broadcast_in_dim as bcoo_broadcast_in_dim,
    bcoo_concatenate as bcoo_concatenate,
    bcoo_conv_general_dilated as bcoo_conv_general_dilated,
    bcoo_dot_general as bcoo_dot_general,
    bcoo_dot_general_p as bcoo_dot_general_p,
    bcoo_dot_general_sampled as bcoo_dot_general_sampled,
    bcoo_dot_general_sampled_p as bcoo_dot_general_sampled_p,
    bcoo_dynamic_slice as bcoo_dynamic_slice,
    bcoo_extract as bcoo_extract,
    bcoo_extract_p as bcoo_extract_p,
    bcoo_fromdense as bcoo_fromdense,
    bcoo_fromdense_p as bcoo_fromdense_p,
    bcoo_gather as bcoo_gather,
    bcoo_multiply_dense as bcoo_multiply_dense,
    bcoo_multiply_sparse as bcoo_multiply_sparse,
    bcoo_update_layout as bcoo_update_layout,
    bcoo_reduce_sum as bcoo_reduce_sum,
    bcoo_reshape as bcoo_reshape,
    bcoo_rev as bcoo_rev,
    bcoo_slice as bcoo_slice,
    bcoo_sort_indices as bcoo_sort_indices,
    bcoo_sort_indices_p as bcoo_sort_indices_p,
    bcoo_spdot_general_p as bcoo_spdot_general_p,
    bcoo_squeeze as bcoo_squeeze,
    bcoo_sum_duplicates as bcoo_sum_duplicates,
    bcoo_sum_duplicates_p as bcoo_sum_duplicates_p,
    bcoo_todense as bcoo_todense,
    bcoo_todense_p as bcoo_todense_p,
    bcoo_transpose as bcoo_transpose,
    bcoo_transpose_p as bcoo_transpose_p,
    BCOO as BCOO,
)
from jax.experimental.sparse.bcsr import (
    bcsr_broadcast_in_dim as bcsr_broadcast_in_dim,
    bcsr_concatenate as bcsr_concatenate,
    bcsr_dot_general as bcsr_dot_general,
    bcsr_dot_general_p as bcsr_dot_general_p,
    bcsr_extract as bcsr_extract,
    bcsr_extract_p as bcsr_extract_p,
    bcsr_fromdense as bcsr_fromdense,
    bcsr_fromdense_p as bcsr_fromdense_p,
    bcsr_sum_duplicates as bcsr_sum_duplicates,
    bcsr_todense as bcsr_todense,
    bcsr_todense_p as bcsr_todense_p,
    BCSR as BCSR,
)
from jax.experimental.sparse._base import (
    JAXSparse as JAXSparse
)
from jax.experimental.sparse.api import (
    empty as empty,
    eye as eye,
    todense as todense,
    todense_p as todense_p,
)
from jax.experimental.sparse.util import (
    CuSparseEfficiencyWarning as CuSparseEfficiencyWarning,
    SparseEfficiencyError as SparseEfficiencyError,
    SparseEfficiencyWarning as SparseEfficiencyWarning,
)
from jax.experimental.sparse.coo import (
    coo_fromdense as coo_fromdense,
    coo_fromdense_p as coo_fromdense_p,
    coo_matmat as coo_matmat,
    coo_matmat_p as coo_matmat_p,
    coo_matvec as coo_matvec,
    coo_matvec_p as coo_matvec_p,
    coo_todense as coo_todense,
    coo_todense_p as coo_todense_p,
    COO as COO,
)
from jax.experimental.sparse.csr import (
    csr_fromdense as csr_fromdense,
    csr_fromdense_p as csr_fromdense_p,
    csr_matmat as csr_matmat,
    csr_matmat_p as csr_matmat_p,
    csr_matvec as csr_matvec,
    csr_matvec_p as csr_matvec_p,
    csr_todense as csr_todense,
    csr_todense_p as csr_todense_p,
    CSC as CSC,
    CSR as CSR,
)
from jax.experimental.sparse.random import (
    random_bcoo as random_bcoo,
)
from jax.experimental.sparse.transform import (
    sparsify as sparsify,
    SparseTracer as SparseTracer,
)
from jax.experimental.sparse import linalg as linalg