
# NOTE: repository-viewer residue preserved as a comment:
#   164 lines, 6.5 KiB; revision dated 2023-06-19 00:49:18 +02:00.
# Copyright 2021 The JAX Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
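# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.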
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX primitives related to sparse operations.

This is experimental work to explore sparse support in JAX.

The primitives defined here are deliberately low-level: each primitive implements
a common sparse operation (sparse to dense, dense to sparse, sparse matrix/vector
product, sparse matrix/matrix product) for two common sparse representations
(CSR and COO).

These routines have reference implementations defined via XLA scatter/gather
operations that will work on any backend, although they are not particularly
performant. On GPU runtimes built against CUDA 11.0/ROCm 5.0 or newer, each
operation is computed efficiently via cusparse/hipsparse.

Further down are some examples of potential high-level wrappers for sparse objects.
(API should be considered unstable and subject to change).
"""
from functools import partial
import operator
from typing import Optional, Union
import jax
from jax import tree_util
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.bcoo import BCOO
from jax.experimental.sparse.bcsr import BCSR
from jax.experimental.sparse.coo import COO
from jax.experimental.sparse.csr import CSR, CSC
from jax.experimental.sparse.util import _coo_extract
from jax.interpreters import mlir
from jax._src import core
from jax._src import dtypes
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.typing import Array, DTypeLike, Shape
# todense function to convert sparse matrices to dense while letting
# dense matrices pass through.
# `todense_p` is the primitive that `todense()` binds; its JVP, transpose,
# batching and lowering rules are attached to it further down in this module.
todense_p = core.Primitive('todense')
# Explicitly mark the primitive as producing a single result value.
todense_p.multiple_results = False
def todense(arr: Union[JAXSparse, Array]) -> Array:
    """Convert input to a dense matrix. If input is already dense, pass through.

    The input is flattened into its pytree leaves plus a tree definition; the
    leaves are passed positionally to the ``todense`` primitive and the tree
    definition rides along as the ``tree`` parameter so the rules of the
    primitive can rebuild the original object.
    """
    leaves, treedef = tree_util.tree_flatten(arr)
    return todense_p.bind(*leaves, tree=treedef)
# Register as the primitive's implementation: without this nothing in the
# module associates `_todense_impl` with `todense_p` for direct evaluation
# (only the JVP, transpose, batching and lowering rules are registered below).
@todense_p.def_impl
def _todense_impl(*bufs, tree):
    """Concrete implementation of the ``todense`` primitive.

    Args:
      *bufs: flattened leaves of the original argument, in the order produced
        by ``tree_util.tree_flatten``.
      tree: the matching tree definition used to rebuild the argument.

    Returns:
      ``arr.todense()`` when the rebuilt object is a ``JAXSparse`` instance,
      otherwise the rebuilt object itself, unchanged.
    """
    arr = tree_util.tree_unflatten(tree, bufs)
    if isinstance(arr, JAXSparse):
        return arr.todense()
    # Anything that is not a JAXSparse passes straight through.
    return arr
# Register as the primitive's abstract-evaluation (shape/dtype) rule; nothing
# else in the module attaches this function to `todense_p`.
@todense_p.def_abstract_eval
def _todense_abstract_eval(*bufs, tree):
    """Abstract evaluation rule for the ``todense`` primitive.

    Args:
      *bufs: flattened (abstract) leaves of the original argument.
      tree: the matching tree definition used to rebuild the argument.

    Returns:
      The rebuilt object itself when it is already a ``core.ShapedArray``
      (dense input), otherwise a ``core.ShapedArray`` carrying the rebuilt
      object's ``shape`` and ``dtype``.
    """
    arr = tree_util.tree_unflatten(tree, bufs)
    if isinstance(arr, core.ShapedArray):
        return arr
    # NOTE(review): the original expression was cut off after
    # `dtypes.is_weakly_typed(`; it is completed with `arr.data`, assuming every
    # sparse container exposes its stored values as `.data` - confirm.
    weak_type = dtypes.is_weakly_typed(arr.data)
    return core.ShapedArray(arr.shape, arr.dtype, weak_type=weak_type)
def _todense_jvp(primals, tangents, *, tree):
    """JVP rule for the ``todense`` primitive.

    Only the first flattened buffer may carry a tangent: the rule asserts that
    the first tangent is not a symbolic zero and that every remaining tangent
    is. The output tangent is obtained by binding the primitive again with the
    first primal replaced by its tangent and the other primals kept as-is.

    Returns:
      A ``(primal_out, tangent_out)`` pair.
    """
    lead_tangent = tangents[0]
    assert not isinstance(lead_tangent, ad.Zero)
    assert all(isinstance(t, ad.Zero) for t in tangents[1:])

    dense_primal = todense_p.bind(*primals, tree=tree)
    dense_tangent = todense_p.bind(lead_tangent, *primals[1:], tree=tree)
    return dense_primal, dense_tangent
def _todense_transpose(ct, *bufs, tree):
    """Transpose rule for the ``todense`` primitive.

    Args:
      ct: cotangent of the dense output.
      *bufs: flattened buffers of the original argument; the first must be an
        undefined primal (the linear input) and none of the others may be.
      tree: tree definition describing how ``bufs`` form the original object.

    Returns:
      A tuple with one entry per buffer: the cotangent for the first buffer
      followed by the remaining buffers passed back unchanged. For a dense
      input this is just ``(ct,)``.

    Raises:
      NotImplementedError: if the container type is not dense, BCOO, BCSR
        or COO.
    """
    assert ad.is_undefined_primal(bufs[0])
    assert all(not ad.is_undefined_primal(buf) for buf in bufs[1:])

    # Rebuild the container around a unique placeholder so its type can be
    # inspected without touching the (partly undefined) buffers themselves.
    placeholder = object()
    skeleton = tree_util.tree_unflatten(tree, [placeholder] * len(bufs))

    from jax.experimental.sparse import BCOO, BCSR
    from jax.experimental.sparse.bcoo import _bcoo_extract
    from jax.experimental.sparse.bcsr import bcsr_extract

    # A bare leaf: the input was already dense, so the cotangent passes through.
    if skeleton is placeholder:
        return (ct,)

    # For the sparse formats, the format's extract helper is applied to `ct`
    # with the stored index buffers (presumably picking out the entries of `ct`
    # at the stored positions - see the respective helpers to confirm).
    if isinstance(skeleton, BCOO):
        _, indices = bufs
        return _bcoo_extract(indices, ct), indices
    if isinstance(skeleton, BCSR):
        _, indices, indptr = bufs
        return bcsr_extract(indices, indptr, ct), indices, indptr
    if isinstance(skeleton, COO):
        _, row, col = bufs
        return _coo_extract(row, col, ct), row, col

    raise NotImplementedError(f"todense_transpose for {type(skeleton)}")
def _todense_batching_rule(batched_args, batch_dims, *, tree):
    """Batching rule for the ``todense`` primitive.

    Maps ``_todense_impl`` (with ``tree`` fixed) over the batched buffers using
    ``batch_dims`` as the input axes.

    Returns:
      A ``(batched_output, 0)`` pair, i.e. the result is reported as batched
      along axis 0.
    """
    impl_for_tree = partial(_todense_impl, tree=tree)
    mapped_impl = jax.vmap(impl_for_tree, batch_dims)
    return mapped_impl(*batched_args), 0
# Attach the transformation rules for `todense_p`: lowering (derived from the
# Python implementation via `mlir.lower_fun`), batching, transposition and JVP.
mlir.register_lowering(
    todense_p,
    mlir.lower_fun(_todense_impl, multiple_results=False),
)
batching.primitive_batchers[todense_p] = _todense_batching_rule
ad.primitive_transposes[todense_p] = _todense_transpose
ad.primitive_jvps[todense_p] = _todense_jvp
def empty(shape: Shape, dtype: Optional[DTypeLike]=None, index_dtype: DTypeLike = 'int32',
          sparse_format: str = 'bcoo', **kwds) -> JAXSparse:
    """Create an empty sparse array.

    Args:
      shape: sequence of integers giving the array shape.
      dtype: (optional) dtype of the array.
      index_dtype: (optional) dtype of the index arrays.
      sparse_format: string specifying the matrix format (e.g. ['bcoo']).
      **kwds: additional keywords passed to the format-specific _empty constructor.

    Returns:
      mat: empty sparse matrix.

    Raises:
      ValueError: if ``sparse_format`` is not one of the recognized formats.
    """
    formats = {'bcsr': BCSR, 'bcoo': BCOO, 'coo': COO, 'csr': CSR, 'csc': CSC}
    if sparse_format not in formats:
        raise ValueError(f"sparse_format={sparse_format!r} not recognized; "
                         f"must be one of {list(formats.keys())}")
    # Delegate construction to the selected format's own `_empty` constructor.
    cls = formats[sparse_format]
    return cls._empty(shape, dtype=dtype, index_dtype=index_dtype, **kwds)
def eye(N: int, M: Optional[int] = None, k: int = 0, dtype: Optional[DTypeLike] = None,
        index_dtype: DTypeLike = 'int32', sparse_format: str = 'bcoo', **kwds) -> JAXSparse:
    """Create 2D sparse identity matrix.

    Args:
      N: int. Number of rows in the output.
      M: int, optional. Number of columns in the output. If None, defaults to `N`.
      k: int, optional. Index of the diagonal: 0 (the default) refers to the main
        diagonal, a positive value refers to an upper diagonal, and a negative value
        to a lower diagonal.
      dtype: data-type, optional. Data-type of the returned array.
      index_dtype: (optional) dtype of the index arrays.
      sparse_format: string specifying the matrix format (e.g. ['bcoo']).
      **kwds: additional keywords passed to the format-specific _eye constructor.

    Returns:
      I: two-dimensional sparse matrix with ones along the k-th diagonal.

    Raises:
      ValueError: if ``sparse_format`` is not one of the recognized formats.
    """
    formats = {'bcoo': BCOO, 'coo': COO, 'csr': CSR, 'csc': CSC}
    # Reject unknown formats with the same descriptive error that `empty` uses,
    # rather than letting the dictionary lookup below surface a bare KeyError.
    if sparse_format not in formats:
        raise ValueError(f"sparse_format={sparse_format!r} not recognized; "
                         f"must be one of {list(formats.keys())}")
    if M is None:
        M = N
    # Coerce the size/offset arguments through `operator.index`
    # (presumably this errors out on traced, non-concrete values - confirm
    # against `core.concrete_or_error`).
    N = core.concrete_or_error(operator.index, N)
    M = core.concrete_or_error(operator.index, M)
    k = core.concrete_or_error(operator.index, k)
    # Delegate construction to the selected format's own `_eye` constructor.
    cls = formats[sparse_format]
    return cls._eye(M=M, N=N, k=k, dtype=dtype, index_dtype=index_dtype, **kwds)