299 lines
11 KiB
Python
299 lines
11 KiB
Python
![]() |
# Copyright 2021 The JAX Authors.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
"""
|
||
|
.. currentmodule:: jax.experimental.sparse
|
||
|
|
||
|
The :mod:`jax.experimental.sparse` module includes experimental support for sparse matrix
|
||
|
operations in JAX. It is under active development, and the API is subject to change. The
|
||
|
primary interfaces made available are the :class:`BCOO` sparse array type, and the
|
||
|
:func:`sparsify` transform.
|
||
|
|
||
|
Batched-coordinate (BCOO) sparse matrices
|
||
|
-----------------------------------------
|
||
|
The main high-level sparse object currently available in JAX is the :class:`BCOO`,
|
||
|
or *batched coordinate* sparse array, which offers a compressed storage format compatible
|
||
|
with JAX transformations, in particular JIT (e.g. :func:`jax.jit`), batching
|
||
|
(e.g. :func:`jax.vmap`) and autodiff (e.g. :func:`jax.grad`).
|
||
|
|
||
|
Here is an example of creating a sparse array from a dense array:
|
||
|
|
||
|
>>> from jax.experimental import sparse
|
||
|
>>> import jax.numpy as jnp
|
||
|
>>> import numpy as np
|
||
|
|
||
|
>>> M = jnp.array([[0., 1., 0., 2.],
|
||
|
... [3., 0., 0., 0.],
|
||
|
... [0., 0., 4., 0.]])
|
||
|
|
||
|
>>> M_sp = sparse.BCOO.fromdense(M)
|
||
|
|
||
|
>>> M_sp
|
||
|
BCOO(float32[3, 4], nse=4)
|
||
|
|
||
|
Convert back to a dense array with the ``todense()`` method:
|
||
|
|
||
|
>>> M_sp.todense()
|
||
|
Array([[0., 1., 0., 2.],
|
||
|
[3., 0., 0., 0.],
|
||
|
[0., 0., 4., 0.]], dtype=float32)
|
||
|
|
||
|
The BCOO format is a somewhat modified version of the standard COO format, and the dense
|
||
|
representation can be seen in the ``data`` and ``indices`` attributes:
|
||
|
|
||
|
>>> M_sp.data # Explicitly stored data
|
||
|
Array([1., 2., 3., 4.], dtype=float32)
|
||
|
|
||
|
>>> M_sp.indices # Indices of the stored data
|
||
|
Array([[0, 1],
|
||
|
[0, 3],
|
||
|
[1, 0],
|
||
|
[2, 2]], dtype=int32)
|
||
|
|
||
|
BCOO objects have familiar array-like attributes, as well as sparse-specific attributes:
|
||
|
|
||
|
>>> M_sp.ndim
|
||
|
2
|
||
|
|
||
|
>>> M_sp.shape
|
||
|
(3, 4)
|
||
|
|
||
|
>>> M_sp.dtype
|
||
|
dtype('float32')
|
||
|
|
||
|
>>> M_sp.nse # "number of specified elements"
|
||
|
4
|
||
|
|
||
|
BCOO objects also implement a number of array-like methods, to allow you to use them
|
||
|
directly within jax programs. For example, here we compute the transposed matrix-vector
|
||
|
product:
|
||
|
|
||
|
>>> y = jnp.array([3., 6., 5.])
|
||
|
|
||
|
>>> M_sp.T @ y
|
||
|
Array([18., 3., 20., 6.], dtype=float32)
|
||
|
|
||
|
>>> M.T @ y # Compare to dense version
|
||
|
Array([18., 3., 20., 6.], dtype=float32)
|
||
|
|
||
|
BCOO objects are designed to be compatible with JAX transforms, including :func:`jax.jit`,
|
||
|
:func:`jax.vmap`, :func:`jax.grad`, and others. For example:
|
||
|
|
||
|
>>> from jax import grad, jit
|
||
|
|
||
|
>>> def f(y):
|
||
|
... return (M_sp.T @ y).sum()
|
||
|
...
|
||
|
>>> jit(grad(f))(y)
|
||
|
Array([3., 3., 4.], dtype=float32)
|
||
|
|
||
|
Note, however, that under normal circumstances :mod:`jax.numpy` and :mod:`jax.lax` functions
|
||
|
do not know how to handle sparse matrices, so attempting to compute things like
|
||
|
``jnp.dot(M_sp.T, y)`` will result in an error (however, see the next section).
|
||
|
|
||
|
Sparsify transform
|
||
|
------------------
|
||
|
An overarching goal of the JAX sparse implementation is to provide a means to switch from
|
||
|
dense to sparse computation seamlessly, without having to modify the dense implementation.
|
||
|
This sparse experiment accomplishes this through the :func:`sparsify` transform.
|
||
|
|
||
|
Consider this function, which computes a more complicated result from a matrix and a vector input:
|
||
|
|
||
|
>>> def f(M, v):
|
||
|
... return 2 * jnp.dot(jnp.log1p(M.T), v) + 1
|
||
|
...
|
||
|
>>> f(M, y)
|
||
|
Array([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
|
||
|
|
||
|
Were we to pass a sparse matrix to this directly, it would result in an error, because ``jnp``
|
||
|
functions do not recognize sparse inputs. However, with :func:`sparsify`, we get a version of
|
||
|
this function that does accept sparse matrices:
|
||
|
|
||
|
>>> f_sp = sparse.sparsify(f)
|
||
|
|
||
|
>>> f_sp(M_sp, y)
|
||
|
Array([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
|
||
|
|
||
|
Support for :func:`sparsify` includes a large number of the most common primitives, including:
|
||
|
|
||
|
- generalized (batched) matrix products & einstein summations (:obj:`~jax.lax.dot_general_p`)
|
||
|
- zero-preserving elementwise binary operations (e.g. :obj:`~jax.lax.add_p`, :obj:`~jax.lax.mul_p`, etc.)
|
||
|
- zero-preserving elementwise unary operations (e.g. :obj:`~jax.lax.abs_p`, :obj:`~jax.lax.neg_p`, etc.)
|
||
|
- summation reductions (:obj:`~jax.lax.reduce_sum_p`)
|
||
|
- general indexing operations (:obj:`~jax.lax.slice_p`, :obj:`~jax.lax.dynamic_slice_p`, :obj:`~jax.lax.gather_p`)
|
||
|
- concatenation and stacking (:obj:`~jax.lax.concatenate_p`)
|
||
|
- transposition & reshaping (:obj:`~jax.lax.transpose_p`, :obj:`~jax.lax.reshape_p`,
|
||
|
:obj:`~jax.lax.squeeze_p`, :obj:`~jax.lax.broadcast_in_dim_p`)
|
||
|
- some higher-order functions (:obj:`~jax.lax.cond_p`, :obj:`~jax.lax.while_p`, :obj:`~jax.lax.scan_p`)
|
||
|
- some simple 1D convolutions (:obj:`~jax.lax.conv_general_dilated_p`)
|
||
|
|
||
|
Nearly any :mod:`jax.numpy` function that lowers to these supported primitives can be used
|
||
|
within a sparsify transform to operate on sparse arrays. This set of primitives is enough
|
||
|
to enable relatively sophisticated sparse workflows, as the next section will show.
|
||
|
|
||
|
Example: sparse logistic regression
|
||
|
-----------------------------------
|
||
|
As an example of a more complicated sparse workflow, let's consider a simple logistic regression
|
||
|
implemented in JAX. Notice that the following implementation has no reference to sparsity:
|
||
|
|
||
|
>>> import functools
|
||
|
>>> from sklearn.datasets import make_classification
|
||
|
>>> from jax.scipy import optimize
|
||
|
|
||
|
>>> def sigmoid(x):
|
||
|
... return 0.5 * (jnp.tanh(x / 2) + 1)
|
||
|
...
|
||
|
>>> def y_model(params, X):
|
||
|
... return sigmoid(jnp.dot(X, params[1:]) + params[0])
|
||
|
...
|
||
|
>>> def loss(params, X, y):
|
||
|
... y_hat = y_model(params, X)
|
||
|
... return -jnp.mean(y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat))
|
||
|
...
|
||
|
>>> def fit_logreg(X, y):
|
||
|
... params = jnp.zeros(X.shape[1] + 1)
|
||
|
... result = optimize.minimize(functools.partial(loss, X=X, y=y),
|
||
|
... x0=params, method='BFGS')
|
||
|
... return result.x
|
||
|
|
||
|
>>> X, y = make_classification(n_classes=2, random_state=1701)
|
||
|
>>> params_dense = fit_logreg(X, y)
|
||
|
>>> print(params_dense) # doctest: +SKIP
|
||
|
[-0.7298445 0.29893667 1.0248291 -0.44436368 0.8785025 -0.7724008
|
||
|
-0.62893456 0.2934014 0.82974285 0.16838408 -0.39774987 -0.5071844
|
||
|
0.2028872 0.5227761 -0.3739224 -0.7104083 2.4212713 0.6310087
|
||
|
-0.67060554 0.03139788 -0.05359547]
|
||
|
|
||
|
This returns the best-fit parameters of a dense logistic regression problem.
|
||
|
To fit the same model on sparse data, we can apply the :func:`sparsify` transform:
|
||
|
|
||
|
>>> Xsp = sparse.BCOO.fromdense(X) # Sparse version of the input
|
||
|
>>> fit_logreg_sp = sparse.sparsify(fit_logreg) # Sparse-transformed fit function
|
||
|
>>> params_sparse = fit_logreg_sp(Xsp, y)
|
||
|
>>> print(params_sparse) # doctest: +SKIP
|
||
|
[-0.72971725 0.29878938 1.0246326 -0.44430563 0.8784217 -0.77225566
|
||
|
-0.6288222 0.29335397 0.8293481 0.16820715 -0.39764675 -0.5069753
|
||
|
0.202579 0.522672 -0.3740134 -0.7102678 2.4209507 0.6310593
|
||
|
-0.670236 0.03132951 -0.05356663]
|
||
|
"""
|
||
|
|
||
|
# Note: import <name> as <name> is required for names to be exported.
|
||
|
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||
|
|
||
|
from jax.experimental.sparse.ad import (
|
||
|
jacfwd as jacfwd,
|
||
|
jacobian as jacobian,
|
||
|
jacrev as jacrev,
|
||
|
grad as grad,
|
||
|
value_and_grad as value_and_grad,
|
||
|
)
|
||
|
from jax.experimental.sparse.bcoo import (
|
||
|
bcoo_broadcast_in_dim as bcoo_broadcast_in_dim,
|
||
|
bcoo_concatenate as bcoo_concatenate,
|
||
|
bcoo_conv_general_dilated as bcoo_conv_general_dilated,
|
||
|
bcoo_dot_general as bcoo_dot_general,
|
||
|
bcoo_dot_general_p as bcoo_dot_general_p,
|
||
|
bcoo_dot_general_sampled as bcoo_dot_general_sampled,
|
||
|
bcoo_dot_general_sampled_p as bcoo_dot_general_sampled_p,
|
||
|
bcoo_dynamic_slice as bcoo_dynamic_slice,
|
||
|
bcoo_extract as bcoo_extract,
|
||
|
bcoo_extract_p as bcoo_extract_p,
|
||
|
bcoo_fromdense as bcoo_fromdense,
|
||
|
bcoo_fromdense_p as bcoo_fromdense_p,
|
||
|
bcoo_gather as bcoo_gather,
|
||
|
bcoo_multiply_dense as bcoo_multiply_dense,
|
||
|
bcoo_multiply_sparse as bcoo_multiply_sparse,
|
||
|
bcoo_update_layout as bcoo_update_layout,
|
||
|
bcoo_reduce_sum as bcoo_reduce_sum,
|
||
|
bcoo_reshape as bcoo_reshape,
|
||
|
bcoo_rev as bcoo_rev,
|
||
|
bcoo_slice as bcoo_slice,
|
||
|
bcoo_sort_indices as bcoo_sort_indices,
|
||
|
bcoo_sort_indices_p as bcoo_sort_indices_p,
|
||
|
bcoo_spdot_general_p as bcoo_spdot_general_p,
|
||
|
bcoo_squeeze as bcoo_squeeze,
|
||
|
bcoo_sum_duplicates as bcoo_sum_duplicates,
|
||
|
bcoo_sum_duplicates_p as bcoo_sum_duplicates_p,
|
||
|
bcoo_todense as bcoo_todense,
|
||
|
bcoo_todense_p as bcoo_todense_p,
|
||
|
bcoo_transpose as bcoo_transpose,
|
||
|
bcoo_transpose_p as bcoo_transpose_p,
|
||
|
BCOO as BCOO,
|
||
|
)
|
||
|
|
||
|
from jax.experimental.sparse.bcsr import (
|
||
|
bcsr_broadcast_in_dim as bcsr_broadcast_in_dim,
|
||
|
bcsr_concatenate as bcsr_concatenate,
|
||
|
bcsr_dot_general as bcsr_dot_general,
|
||
|
bcsr_dot_general_p as bcsr_dot_general_p,
|
||
|
bcsr_extract as bcsr_extract,
|
||
|
bcsr_extract_p as bcsr_extract_p,
|
||
|
bcsr_fromdense as bcsr_fromdense,
|
||
|
bcsr_fromdense_p as bcsr_fromdense_p,
|
||
|
bcsr_sum_duplicates as bcsr_sum_duplicates,
|
||
|
bcsr_todense as bcsr_todense,
|
||
|
bcsr_todense_p as bcsr_todense_p,
|
||
|
BCSR as BCSR,
|
||
|
)
|
||
|
|
||
|
from jax.experimental.sparse._base import (
|
||
|
JAXSparse as JAXSparse
|
||
|
)
|
||
|
|
||
|
from jax.experimental.sparse.api import (
|
||
|
empty as empty,
|
||
|
eye as eye,
|
||
|
todense as todense,
|
||
|
todense_p as todense_p,
|
||
|
)
|
||
|
|
||
|
from jax.experimental.sparse.util import (
|
||
|
CuSparseEfficiencyWarning as CuSparseEfficiencyWarning,
|
||
|
SparseEfficiencyError as SparseEfficiencyError,
|
||
|
SparseEfficiencyWarning as SparseEfficiencyWarning,
|
||
|
)
|
||
|
|
||
|
from jax.experimental.sparse.coo import (
|
||
|
coo_fromdense as coo_fromdense,
|
||
|
coo_fromdense_p as coo_fromdense_p,
|
||
|
coo_matmat as coo_matmat,
|
||
|
coo_matmat_p as coo_matmat_p,
|
||
|
coo_matvec as coo_matvec,
|
||
|
coo_matvec_p as coo_matvec_p,
|
||
|
coo_todense as coo_todense,
|
||
|
coo_todense_p as coo_todense_p,
|
||
|
COO as COO,
|
||
|
)
|
||
|
|
||
|
from jax.experimental.sparse.csr import (
|
||
|
csr_fromdense as csr_fromdense,
|
||
|
csr_fromdense_p as csr_fromdense_p,
|
||
|
csr_matmat as csr_matmat,
|
||
|
csr_matmat_p as csr_matmat_p,
|
||
|
csr_matvec as csr_matvec,
|
||
|
csr_matvec_p as csr_matvec_p,
|
||
|
csr_todense as csr_todense,
|
||
|
csr_todense_p as csr_todense_p,
|
||
|
CSC as CSC,
|
||
|
CSR as CSR,
|
||
|
)
|
||
|
|
||
|
from jax.experimental.sparse.random import random_bcoo as random_bcoo
|
||
|
from jax.experimental.sparse.transform import (
|
||
|
sparsify as sparsify,
|
||
|
SparseTracer as SparseTracer,
|
||
|
)
|
||
|
|
||
|
from jax.experimental.sparse import linalg as linalg
|