Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/experimental/sparse/bcsr.py
2023-06-19 00:49:18 +02:00

899 lines
36 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BCSR (Batched compressed row) matrix object and associated primitives."""
from __future__ import annotations
from functools import partial
import operator
import warnings
from typing import NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
import jax
import jax.numpy as jnp
from jax import config
from jax import lax
from jax import tree_util
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse import bcoo
from jax.experimental.sparse.util import (
nfold_vmap, _count_stored_elements,
_csr_to_coo, _dot_general_validated_shape,
CuSparseEfficiencyWarning, SparseInfo, Shape)
from jax.util import split_list, safe_zip
from jax._src import api_util
from jax._src import core
from jax._src import dispatch
from jax._src.lax.lax import DotDimensionNumbers, _dot_general_batch_dim_nums
from jax._src.lib import gpu_sparse
from jax._src.lib.mlir.dialects import hlo
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.typing import Array, ArrayLike, DTypeLike
def bcsr_eliminate_zeros(mat: BCSR, nse: Optional[int] = None) -> BCSR:
  """Eliminate zeros in a BCSR representation.

  The array is converted to BCOO, passed through
  ``bcoo.bcoo_eliminate_zeros`` and converted back to BCSR.

  Args:
    mat : BCSR array.
    nse : integer (optional), forwarded unchanged to
      ``bcoo.bcoo_eliminate_zeros``.

  Returns:
    A BCSR array built from the BCOO result.
  """
  as_bcoo = mat.to_bcoo()
  pruned = bcoo.bcoo_eliminate_zeros(as_bcoo, nse=nse)
  return BCSR.from_bcoo(pruned)
def bcsr_sum_duplicates(mat: BCSR, nse: Optional[int] = None) -> BCSR:
  """Sum duplicate indices within a BCSR array, returning an array with sorted indices.

  The work is delegated to ``bcoo.bcoo_sum_duplicates`` applied to the BCOO
  form of ``mat``; the result is converted back to BCSR.

  Args:
    mat : BCSR array
    nse : integer (optional). The number of specified elements in the output matrix. This must
      be specified for bcsr_sum_duplicates to be compatible with JIT and other JAX transformations.
      If not specified, the optimal nse will be computed based on the contents of the data and
      index arrays. If specified nse is larger than necessary, data and index arrays will be padded
      with standard fill values. If smaller than necessary, data elements will be dropped from the
      output matrix.

  Returns:
    mat_out : BCSR array with sorted indices and no duplicate indices.
  """
  as_bcoo = mat.to_bcoo()
  deduplicated = bcoo.bcoo_sum_duplicates(as_bcoo, nse=nse)
  return BCSR.from_bcoo(deduplicated)
def _bcsr_batch_dims_to_front(batched_args, batch_dims, spinfo, batch_size=None):
data, indices, indptr = batched_args
data_bdim, indices_bdim, indptr_bdim = batch_dims
n_batch = indices.ndim - 1 + int(indices_bdim is None)
if not all(b is None or 0 <= b < n_batch for b in batch_dims):
raise NotImplementedError("batch_dims must be None or satisfy 0 < dim < n_batch. "
f"Got {batch_dims=} for {n_batch=}.")
batched_data, batched_indices, batched_indptr = [
lax.expand_dims(arg, [0]) if bdim is None else jnp.moveaxis(arg, bdim, 0)
for arg, bdim in [(data, data_bdim), (indices, indices_bdim), (indptr, indptr_bdim)]]
if batch_size is None:
batch_size = max(arg.shape[dim] for arg, dim in zip(batched_args, batch_dims) if dim is not None)
batched_spinfo = SparseInfo((batch_size, *spinfo.shape),
indices_sorted=spinfo.indices_sorted,
unique_indices=spinfo.unique_indices)
return batched_data, batched_indices, batched_indptr, batched_spinfo
class BCSRProperties(NamedTuple):
  """Structural properties derived from a set of BCSR buffers."""
  # Number of leading batch dimensions.
  n_batch: int
  # Number of trailing dense dimensions.
  n_dense: int
  # Number of stored elements (size of the last axis of ``indices``).
  nse: int
def _compatible(shape1: Sequence[int], shape2: Sequence[int]) -> bool:
return all(s1 in (1, s2) for s1, s2 in safe_zip(shape1, shape2))
def _validate_bcsr_indices(indices: jax.Array, indptr: jax.Array,
                           shape: Sequence[int]) -> BCSRProperties:
  """Check ``indices``/``indptr`` against ``shape`` and derive the BCSR layout.

  Args:
    indices : integer array; its last axis has length ``nse`` and all leading
      axes are batch axes.
    indptr : integer array whose non-batch part must have shape
      ``(shape[n_batch] + 1,)``.
    shape : shape of the full array.

  Returns:
    BCSRProperties with ``n_batch``, ``n_dense`` and ``nse``.

  Raises:
    ValueError: if batch dimensions are not compatible with ``shape`` or if
      ``indptr`` has the wrong trailing shape.
  """
  for index_array in (indices, indptr):
    assert jnp.issubdtype(index_array.dtype, jnp.integer)
  shape = tuple(shape)

  nse = indices.shape[-1]
  n_batch = indices.ndim - 1
  # Whatever is left after the batch dims and the two sparse dims is dense.
  n_dense = len(shape) - n_batch - 2
  assert n_dense >= 0

  for name, arr in (("indices", indices), ("indptr", indptr)):
    if not _compatible(arr.shape[:n_batch], shape[:n_batch]):
      raise ValueError(f"{name} batch dimensions not compatible for "
                       f"{name}.shape={arr.shape!r}, {shape=}")

  expected_indptr_tail = (shape[n_batch] + 1,)
  if indptr.shape[n_batch:] != expected_indptr_tail:
    raise ValueError("indptr shape must match the matrix shape plus 1.")
  return BCSRProperties(n_batch=n_batch, n_dense=n_dense, nse=nse)
def _validate_bcsr(data: jax.Array, indices: jax.Array,
                   indptr: jax.Array, shape: Sequence[int]) -> BCSRProperties:
  """Validate a full ``(data, indices, indptr)`` triple against ``shape``.

  Runs ``_validate_bcsr_indices`` first and then checks that ``data`` has
  compatible batch dimensions and trailing shape ``(nse,) + dense_dims``.

  Returns:
    The BCSRProperties computed by ``_validate_bcsr_indices``.

  Raises:
    ValueError: if the number of sparse dimensions is not 2, or if ``data``
      does not match the layout implied by the index arrays.
  """
  props = _validate_bcsr_indices(indices, indptr, shape)
  shape = tuple(shape)
  n_batch, n_dense, nse = props

  n_sparse = len(shape) - (n_batch + n_dense)
  if n_sparse != 2:
    raise ValueError("BCSR array must have 2 sparse dimensions; "
                     f"{n_sparse} is given.")

  batch_ok = _compatible(data.shape[:n_batch], shape[:n_batch])
  if not batch_ok:
    raise ValueError(f"data batch dimensions not compatible for {data.shape=}, {shape=}")

  expected_tail = (nse,) + shape[n_batch + 2:]
  actual_tail = data.shape[-(n_dense + 1):]
  if actual_tail != expected_tail:
    raise ValueError(f"Invalid {data.shape=} for {nse=}, {n_batch=}, {n_dense=}")
  return props
def _bcsr_to_bcoo(indices: jax.Array, indptr: jax.Array, *,
                  shape: Sequence[int]) -> jax.Array:
  """Given BCSR (indices, indptr), return BCOO (indices).

  ``_csr_to_coo`` is mapped over the batch dimensions and its outputs are
  stacked along a new trailing axis.
  """
  props = _validate_bcsr_indices(indices, indptr, shape)
  batched_csr_to_coo = nfold_vmap(_csr_to_coo, props.n_batch)
  coo_parts = batched_csr_to_coo(indices, indptr)
  return jnp.stack(coo_parts, axis=indices.ndim)
def _bcoo_to_bcsr(indices: Array, *, shape: Sequence[int],
index_dtype: DTypeLike = jnp.int32) -> Tuple[Array, Array]:
"""Given BCOO (indices), return BCSR (indices, indptr).
Note: this assumes that ``indices`` are lexicographically sorted within each batch.
"""
n_batch, n_sparse, _, _ = bcoo._validate_bcoo_indices(indices, shape)
if n_sparse != 2:
raise ValueError("Must have 2 sparse dimensions to be converted to BCSR.")
n_rows = shape[n_batch]
@partial(nfold_vmap, N=n_batch, broadcasted=False)
def get_ptr(i):
indptr = jnp.zeros(n_rows + 1, index_dtype)
return indptr.at[1:].set(jnp.cumsum(
jnp.bincount(i, length=n_rows).astype(index_dtype)))
return indices[..., 1], get_ptr(indices[..., 0])
#--------------------------------------------------------------------
# bcsr_fromdense

# Primitive behind bcsr_fromdense; it is flagged as producing multiple results.
bcsr_fromdense_p = core.Primitive('bcsr_fromdense')
bcsr_fromdense_p.multiple_results = True

# Extra context handed to core.concrete_or_error when ``nse`` is not concrete.
_TRACED_NSE_ERROR = """
The error arose for the nse argument of bcsr_fromdense. In order for
BCSR.fromdense() to be used in traced/compiled code, you must pass a concrete
value to the nse (number of stored elements) argument.
"""
def bcsr_fromdense(mat: ArrayLike, *, nse: Optional[int] = None, n_batch: int = 0,
                   n_dense:int = 0, index_dtype: DTypeLike = jnp.int32) -> BCSR:
  """Create BCSR-format sparse matrix from a dense matrix.

  Args:
    mat : array to be converted to BCSR.
    nse : number of stored elements in each batch. If None, it is computed
      from ``mat`` with ``_count_stored_elements``; it must be a concrete value.
    n_batch : number of batch dimensions (default: 0)
    n_dense : number of dense dimensions (default: 0)
    index_dtype : dtype of sparse indices (default: int32)

  Returns:
    mat_bcsr: BCSR representation of the matrix.
  """
  dense = jnp.asarray(mat)
  requested_nse = (_count_stored_elements(dense, n_batch, n_dense)
                   if nse is None else nse)
  nse_int: int = core.concrete_or_error(operator.index, requested_nse,
                                        _TRACED_NSE_ERROR)
  buffers = _bcsr_fromdense(dense, nse=nse_int, n_batch=n_batch,
                            n_dense=n_dense, index_dtype=index_dtype)
  return BCSR(buffers, shape=dense.shape)
def _bcsr_fromdense(mat: ArrayLike, *, nse: int, n_batch: int = 0, n_dense: int = 0,
                    index_dtype: DTypeLike = jnp.int32) -> Tuple[Array, Array, Array]:
  """Create BCSR-format sparse matrix from a dense matrix.

  This binds the ``bcsr_fromdense`` primitive and returns raw buffers.

  Args:
    mat : array to be converted to BCSR, with
      ``ndim = n_batch + n_sparse + n_dense``.
    nse : number of stored elements in each batch; must be concrete.
    n_batch : number of batch dimensions (default: 0)
    n_dense : number of dense dimensions (default: 0)
    index_dtype : dtype of sparse indices (default: int32)

  Returns:
    data : array of shape
      ``mat.shape[:n_batch] + (nse,) + mat.shape[mat.ndim - n_dense:]``
      and dtype ``mat.dtype``
    indices : array of shape ``mat.shape[:n_batch] + (nse,)`` and dtype of
      ``index_dtype``.
    indptr: array of shape ``mat.shape[:n_batch] + (mat.shape[n_batch] + 1,)``
      and dtype of ``index_dtype``.
  """
  dense = jnp.asarray(mat)
  concrete_nse = core.concrete_or_error(operator.index, nse, _TRACED_NSE_ERROR)
  params = dict(nse=concrete_nse, n_batch=n_batch, n_dense=n_dense,
                index_dtype=index_dtype)
  return bcsr_fromdense_p.bind(dense, **params)
@bcsr_fromdense_p.def_impl
def _bcsr_fromdense_impl(mat, *, nse, n_batch, n_dense, index_dtype):
  """Concrete implementation of the ``bcsr_fromdense`` primitive.

  The dense array is first converted with ``bcoo.bcoo_fromdense``; the BCOO
  indices are then turned into BCSR ``(indices, indptr)``.

  Returns:
    ``(data, indices, indptr)``.

  Raises:
    ValueError: if ``mat`` does not have exactly 2 sparse dimensions.
  """
  mat = jnp.asarray(mat)
  n_sparse = mat.ndim - n_dense - n_batch
  if n_sparse != 2:
    raise ValueError("bcsr_fromdense: must have 2 sparse dimensions.")
  bcoo_mat = bcoo.bcoo_fromdense(mat, nse=nse, index_dtype=index_dtype,
                                 n_dense=n_dense, n_batch=n_batch)
  # Forward index_dtype so that indptr has the dtype promised by the abstract
  # evaluation rule; otherwise it would silently default to int32.
  indices, indptr = _bcoo_to_bcsr(bcoo_mat.indices, shape=mat.shape,
                                  index_dtype=index_dtype)
  return bcoo_mat.data, indices, indptr
@bcsr_fromdense_p.def_abstract_eval
def _bcsr_fromdense_abstract_eval(mat, *, nse, n_batch, n_dense, index_dtype):
  """Abstract evaluation rule: shapes and dtypes of ``(data, indices, indptr)``.

  Raises:
    ValueError: if ``mat`` does not have exactly 2 sparse dimensions.
  """
  n_sparse = mat.ndim - n_batch - n_dense
  if n_sparse != 2:
    raise ValueError("bcsr_fromdense: must have 2 sparse dimensions.")

  batch_shape = mat.shape[:n_batch]
  dense_shape = mat.shape[n_batch + n_sparse:]
  n_rows = mat.shape[n_batch]

  data_aval = core.ShapedArray(batch_shape + (nse,) + dense_shape, mat.dtype)
  indices_aval = core.ShapedArray(batch_shape + (nse,), index_dtype)
  indptr_aval = core.ShapedArray(batch_shape + (n_rows + 1,), index_dtype)
  return data_aval, indices_aval, indptr_aval
def _bcsr_fromdense_batching_rule(batched_args, batch_dims, *, nse, n_batch,
n_dense, index_dtype):
M, = batched_args
bdim, = batch_dims
if not (0 <= bdim <= n_batch):
raise ValueError(f"Expected 0 < bdim <= n_batch; got {bdim=}, {n_batch=}")
return _bcsr_fromdense(M, nse=nse, n_batch=n_batch + 1, n_dense=n_dense, index_dtype=index_dtype), (bdim, bdim, bdim)
def _bcsr_fromdense_jvp(primals, tangents, *, nse, n_batch, n_dense, index_dtype):
  """JVP rule for ``bcsr_fromdense``.

  The tangent of ``data`` is the input tangent extracted at the primal's
  sparsity pattern; ``indices`` and ``indptr`` get symbolic-zero tangents.
  """
  (M,), (Mdot,) = primals, tangents

  primals_out = _bcsr_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense,
                                index_dtype=index_dtype)
  data, indices, indptr = primals_out

  if type(Mdot) is not ad.Zero:
    data_dot = bcsr_extract(indices, indptr, Mdot)
  else:
    data_dot = ad.Zero.from_value(data)

  indices_dot = ad.Zero.from_value(indices)
  indptr_dot = ad.Zero.from_value(indptr)
  return primals_out, (data_dot, indices_dot, indptr_dot)
def _bcsr_fromdense_transpose(ct, M, *, nse, n_batch, n_dense, index_dtype):
data, indices, indptr = ct
n_sparse = M.ndim - n_batch - n_dense
assert data.shape == M.shape[:n_batch] + (nse,) + M.shape[n_batch + n_sparse:]
assert indices.shape == M.shape[:n_batch] + (n_sparse, nse)
assert indptr.shape == M.shape[:n_batch] + (M.shape[n_batch] + 1,)
assert indices.dtype == index_dtype
assert indptr.dtype == index_dtype
if isinstance(indices, ad.Zero) or isinstance(indptr, ad.Zero):
raise ValueError("Cannot transpose with respect to sparse indices")
assert ad.is_undefined_primal(M)
return _bcsr_todense(data, indices, indptr, spinfo=SparseInfo(M.aval.shape))
# Register the JVP, transpose and batching rules for bcsr_fromdense.
ad.primitive_jvps[bcsr_fromdense_p] = _bcsr_fromdense_jvp
ad.primitive_transposes[bcsr_fromdense_p] = _bcsr_fromdense_transpose
batching.primitive_batchers[bcsr_fromdense_p] = _bcsr_fromdense_batching_rule
# The lowering is derived from the Python implementation via mlir.lower_fun.
mlir.register_lowering(bcsr_fromdense_p, mlir.lower_fun(
    _bcsr_fromdense_impl, multiple_results=True))
#----------------------------------------------------------------------
# bcsr_todense
# Primitive behind bcsr_todense.
bcsr_todense_p = core.Primitive('bcsr_todense')
def bcsr_todense(mat: BCSR) -> Array:
  """Convert batched sparse matrix to a dense matrix.

  Args:
    mat: BCSR matrix.

  Returns:
    The dense version of ``mat``.
  """
  data, indices, indptr = mat.data, mat.indices, mat.indptr
  return _bcsr_todense(data, indices, indptr, spinfo=mat._info)
def _bcsr_todense(data: ArrayLike, indices: ArrayLike, indptr: ArrayLike, *,
                  spinfo: SparseInfo) -> Array:
  """Convert batched sparse matrix to a dense matrix.

  This binds the ``bcsr_todense`` primitive on the raw buffers.

  Args:
    data : array of shape ``batch_dims + (nse,) + dense_dims``.
    indices : array of shape ``batch_dims + (nse,)``.
    indptr : array of shape ``batch_dims + (shape[len(batch_dims)] + 1,)``.
    spinfo : SparseInfo. In particular, this includes the shape
      of the matrix, which is equal to
      ``batch_dims + sparse_dims + dense_dims`` where
      ``len(sparse_dims) == 2``.

  Returns:
    mat : array with specified shape and dtype matching ``data``
  """
  operands = [jnp.asarray(buf) for buf in (data, indices, indptr)]
  return bcsr_todense_p.bind(*operands, spinfo=spinfo)
@bcsr_todense_p.def_impl
def _bcsr_todense_impl(data, indices, indptr, *, spinfo):
  """Concrete implementation of ``bcsr_todense``: convert to BCOO, then densify."""
  dense_shape = spinfo.shape
  coo_indices = _bcsr_to_bcoo(indices, indptr, shape=dense_shape)
  as_bcoo = bcoo.BCOO((data, coo_indices), shape=dense_shape)
  return as_bcoo.todense()
@bcsr_todense_p.def_abstract_eval
def _bcsr_todense_abstract_eval(data, indices, indptr, *, spinfo):
  """Abstract evaluation rule: validate the buffers, then report ``spinfo.shape`` with ``data``'s dtype."""
  out_shape = spinfo.shape
  _validate_bcsr(data, indices, indptr, out_shape)
  return core.ShapedArray(out_shape, data.dtype)
def _bcsr_todense_batching_rule(batched_args, batch_dims, *, spinfo):
  """Batching rule for ``bcsr_todense``; the result is batched along axis 0."""
  *buffers, batched_spinfo = _bcsr_batch_dims_to_front(
      batched_args, batch_dims, spinfo)
  dense = _bcsr_todense(*buffers, spinfo=batched_spinfo)
  return dense, 0
def _bcsr_todense_jvp(data_dot, data, indices, indptr, *, spinfo):
  """JVP of ``bcsr_todense`` with respect to ``data``.

  The tangent is the densification of ``data_dot`` with the same index
  arrays; the primal ``data`` is not needed.
  """
  del data  # unused
  tangent_out = _bcsr_todense(data_dot, indices, indptr, spinfo=spinfo)
  return tangent_out
def _bcsr_todense_transpose(ct, data, indices, indptr, *, spinfo):
  """Transpose rule for ``bcsr_todense``.

  Returns:
    ``(data_ct, indices, indptr)`` where ``data_ct`` holds the entries of the
    dense cotangent ``ct`` extracted at the given sparsity pattern.

  Raises:
    ValueError: if ``indices`` or ``indptr`` is an undefined primal.
  """
  expected_shape = spinfo.shape
  assert ad.is_undefined_primal(data)
  if any(ad.is_undefined_primal(idx) for idx in (indices, indptr)):
    raise ValueError("Cannot transpose with respect to sparse indices")
  assert ct.shape == expected_shape
  assert ct.dtype == data.aval.dtype
  data_ct = bcsr_extract(indices, indptr, ct)
  return data_ct, indices, indptr
# A JVP rule is supplied only for the first operand (data); the two index
# operands are given None.
ad.defjvp(bcsr_todense_p, _bcsr_todense_jvp, None, None)
ad.primitive_transposes[bcsr_todense_p] = _bcsr_todense_transpose
batching.primitive_batchers[bcsr_todense_p] = _bcsr_todense_batching_rule
# The lowering is derived from the Python implementation via mlir.lower_fun.
mlir.register_lowering(bcsr_todense_p, mlir.lower_fun(
    _bcsr_todense_impl, multiple_results=False))
#--------------------------------------------------------------------
# bcsr_extract
# Primitive behind bcsr_extract.
bcsr_extract_p = core.Primitive('bcsr_extract')
def bcsr_extract(indices: ArrayLike, indptr: ArrayLike, mat: ArrayLike) -> Array:
  """Extract values from a dense matrix at given BCSR (indices, indptr).

  Args:
    indices: An ndarray; see BCSR indices.
    indptr: An ndarray; see BCSR indptr.
    mat: A dense matrix.

  Returns:
    An ndarray; see BCSR data.
  """
  operands = (indices, indptr, mat)
  return bcsr_extract_p.bind(*operands)
@bcsr_extract_p.def_impl
def _bcsr_extract_impl(indices, indptr, mat):
  """Concrete implementation of ``bcsr_extract``: convert the indices to BCOO and extract."""
  dense = jnp.asarray(mat)
  coo_indices = _bcsr_to_bcoo(indices, indptr, shape=dense.shape)
  return bcoo._bcoo_extract(coo_indices, dense)
@bcsr_extract_p.def_abstract_eval
def _bcsr_extract_abstract_eval(indices, indptr, mat):
  """Abstract evaluation rule: output has shape ``batch_dims + (nse,) + dense_dims`` and ``mat``'s dtype."""
  props = _validate_bcsr_indices(indices, indptr, mat.shape)
  batch_shape = mat.shape[:props.n_batch]
  dense_shape = mat.shape[mat.ndim - props.n_dense:]
  return core.ShapedArray(batch_shape + (props.nse,) + dense_shape, mat.dtype)
def _bcsr_extract_jvp(arr_dot, indices, indptr, arr):
  """JVP of ``bcsr_extract`` with respect to the dense operand.

  The tangent is extracted at the same (indices, indptr) as the primal.
  """
  assert arr_dot.shape == arr.shape
  tangent_out = bcsr_extract(indices, indptr, arr_dot)
  return tangent_out
def _bcsr_extract_transpose(ct, indices, indptr, arr):
  """Transpose rule for ``bcsr_extract``.

  Returns:
    ``(indices, indptr, arr_ct)`` where ``arr_ct`` is ``ct`` densified into the
    shape of the dense operand.

  Raises:
    ValueError: if ``indices`` or ``indptr`` is an undefined primal.
  """
  assert ad.is_undefined_primal(arr)
  if any(ad.is_undefined_primal(idx) for idx in (indices, indptr)):
    raise ValueError("Cannot transpose with respect to sparse indices")
  assert ct.dtype == arr.aval.dtype
  dense_info = SparseInfo(arr.aval.shape)
  arr_ct = _bcsr_todense(ct, indices, indptr, spinfo=dense_info)
  return indices, indptr, arr_ct
def _bcsr_extract_batching_rule(batched_args, batch_dims):
indices, indptr, arr = batched_args
bdim_set = {b for b in batch_dims if b is not None}
if len(bdim_set) != 1:
# TODO(jakevdp): handle this by moving bdim to front?
raise NotImplementedError("bcoo_extract with unequal batch dimensions.")
bdim = next(iter(bdim_set))
if batch_dims[0] is None:
indices = lax.expand_dims(indices, (bdim,))
if batch_dims[1] is None:
indptr = lax.expand_dims(indptr, (bdim,))
if batch_dims[2] is None:
# TODO(jakevdp) can we handle this case without explicit broadcasting?
result_shape = list(arr.shape)
result_shape.insert(bdim, indices.shape[bdim])
arr = lax.broadcast_in_dim(arr, result_shape, (bdim,))
n_batch = indices.ndim - 1
if bdim >= n_batch:
raise ValueError(f"{batch_dims=} out of range for indices with {n_batch=}")
return bcsr_extract(indices, indptr, arr), bdim
# A JVP rule is supplied only for the third operand (the dense matrix); the two
# index operands are given None.
ad.defjvp(bcsr_extract_p, None, None, _bcsr_extract_jvp)
ad.primitive_transposes[bcsr_extract_p] = _bcsr_extract_transpose
batching.primitive_batchers[bcsr_extract_p] = _bcsr_extract_batching_rule
# The lowering is derived from the Python implementation via mlir.lower_fun.
mlir.register_lowering(bcsr_extract_p, mlir.lower_fun(
    _bcsr_extract_impl, multiple_results=False))
#----------------------------------------------------------------------
# bcsr_dot_general
# Primitive behind bcsr_dot_general.
bcsr_dot_general_p = core.Primitive('bcsr_dot_general')
def bcsr_dot_general(lhs: Union[BCSR, Array], rhs: Array, *,
                     dimension_numbers: DotDimensionNumbers,
                     precision: None = None,
                     preferred_element_type: None = None) -> Array:
  """A general contraction operation.

  Args:
    lhs: An ndarray or BCSR-format sparse array.
    rhs: An ndarray.
    dimension_numbers: a tuple of tuples of the form
      `((lhs_contracting_dims, rhs_contracting_dims),
      (lhs_batch_dims, rhs_batch_dims))`.
    precision: unused
    preferred_element_type: unused

  Returns:
    An ndarray containing the result.

  Raises:
    NotImplementedError: for any operand combination other than a dense or
      BCSR ``lhs`` with a dense ``rhs``.
  """
  del precision, preferred_element_type  # unused
  dense_types = (np.ndarray, jax.Array)
  unsupported = ("bcsr_dot_general currently implemented for BCSR "
                 "lhs and ndarray rhs.")

  if not isinstance(rhs, dense_types):
    raise NotImplementedError(unsupported)
  if isinstance(lhs, dense_types):
    # Both operands are dense: defer to the regular dot_general.
    return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)
  if not isinstance(lhs, BCSR):
    raise NotImplementedError(unsupported)

  lhs_data, lhs_indices, lhs_indptr = lhs._bufs
  return _bcsr_dot_general(lhs_data, lhs_indices, lhs_indptr, rhs,
                           dimension_numbers=dimension_numbers,
                           lhs_spinfo=lhs._info)
def _bcsr_dot_general(lhs_data: jax.Array, lhs_indices: jax.Array,
                      lhs_indptr: jax.Array, rhs: Array, *,
                      dimension_numbers: DotDimensionNumbers,
                      lhs_spinfo: SparseInfo) -> Array:
  """Bind the ``bcsr_dot_general`` primitive on raw BCSR buffers and a dense ``rhs``.

  The dimension numbers are normalized with ``api_util._ensure_index_tuple``
  before binding.
  """
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  as_index_tuple = api_util._ensure_index_tuple
  cdims = (as_index_tuple(lhs_contract), as_index_tuple(rhs_contract))
  bdims = (as_index_tuple(lhs_batch), as_index_tuple(rhs_batch))
  operands = [jnp.asarray(x) for x in (lhs_data, lhs_indices, lhs_indptr, rhs)]
  return bcsr_dot_general_p.bind(*operands,
                                 dimension_numbers=(cdims, bdims),
                                 lhs_spinfo=lhs_spinfo)
def _bcsr_dot_general_impl(lhs_data, lhs_indices, lhs_indptr, rhs, *,
                           dimension_numbers, lhs_spinfo):
  """Default implementation of ``bcsr_dot_general``.

  The BCSR index arrays are converted to BCOO indices and the contraction is
  delegated to ``bcoo._bcoo_dot_general_impl``.
  """
  lhs_data, lhs_indices, lhs_indptr, rhs = (
      jnp.asarray(x) for x in (lhs_data, lhs_indices, lhs_indptr, rhs))
  coo_indices = _bcsr_to_bcoo(lhs_indices, lhs_indptr, shape=lhs_spinfo.shape)
  return bcoo._bcoo_dot_general_impl(lhs_data, coo_indices, rhs,
                                     dimension_numbers=dimension_numbers,
                                     lhs_spinfo=lhs_spinfo)
@bcsr_dot_general_p.def_abstract_eval
def _bcsr_dot_general_abstract_eval(lhs_data, lhs_indices, lhs_indptr, rhs, *,
                                    dimension_numbers, lhs_spinfo):
  """Abstract evaluation rule for ``bcsr_dot_general``.

  Raises:
    ValueError: if ``lhs_data`` and ``rhs`` have different dtypes.
    NotImplementedError: if a dot_general batch dimension is not a batch
      dimension of the sparse representation, or if a dense dimension of the
      lhs is contracted.
  """
  if lhs_data.dtype != rhs.dtype:
    raise ValueError("bcsr_dot_general requires arguments to have matching "
                     f"dtypes; got lhs.dtype={lhs_data.dtype}, "
                     f"rhs.dtype={rhs.dtype}")

  (lhs_contracting, _), (lhs_batch, _) = dimension_numbers
  props = _validate_bcsr_indices(lhs_indices, lhs_indptr, lhs_spinfo.shape)
  out_shape = _dot_general_validated_shape(lhs_spinfo.shape, rhs.shape,
                                           dimension_numbers)

  if any(b >= props.n_batch for b in lhs_batch):
    raise NotImplementedError(
      "bcsr_dot_general batch dimensions must be among the batch dimensions in the sparse representtaion.\n"
      f"got {lhs_batch=}, {props.n_batch=}")

  # TODO: support contraction of dense dimensions?
  first_dense_dim = props.n_batch + 2
  if lhs_contracting and max(lhs_contracting) >= first_dense_dim:
    raise NotImplementedError("bcsr_dot_general: contracting over dense dimensions.")

  return core.ShapedArray(out_shape, lhs_data.dtype)
def _bcsr_dot_general_jvp_lhs(lhs_data_dot, lhs_data, lhs_indices, lhs_indptr, rhs, *,
                              dimension_numbers, lhs_spinfo):
  """JVP of ``bcsr_dot_general`` with respect to ``lhs_data``.

  The same contraction is applied with ``lhs_data_dot`` in place of the
  primal data.
  """
  del lhs_data  # unused
  params = dict(dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
  return _bcsr_dot_general(lhs_data_dot, lhs_indices, lhs_indptr, rhs, **params)
def _bcsr_dot_general_jvp_rhs(rhs_dot, lhs_data, lhs_indices, lhs_indptr, rhs, *,
                              dimension_numbers, lhs_spinfo):
  """JVP of ``bcsr_dot_general`` with respect to ``rhs``.

  The same contraction is applied with ``rhs_dot`` in place of the primal rhs.
  """
  del rhs  # unused
  params = dict(dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
  return _bcsr_dot_general(lhs_data, lhs_indices, lhs_indptr, rhs_dot, **params)
def _bcsr_dot_general_transpose(ct, lhs_data, lhs_indices, lhs_indptr, rhs, *,
                                dimension_numbers, lhs_spinfo):
  """Transpose rule for ``bcsr_dot_general``.

  Delegates to ``bcoo._bcoo_dot_general_transpose`` on the BCOO form of the
  lhs indices; the BCSR index arrays are passed through unchanged.
  """
  # TODO(jakevdp): implement this in terms of bcsr_dot_general
  coo_indices = _bcsr_to_bcoo(lhs_indices, lhs_indptr, shape=lhs_spinfo.shape)
  bcoo_result = bcoo._bcoo_dot_general_transpose(
      ct, lhs_data, coo_indices, rhs, dimension_numbers=dimension_numbers,
      lhs_spinfo=lhs_spinfo)
  data_out, _, rhs_out = bcoo_result
  return data_out, lhs_indices, lhs_indptr, rhs_out
def _bcsr_dot_general_batch_rule(batched_args, batch_dims, *,
                                 dimension_numbers, lhs_spinfo):
  """Batching rule for ``bcsr_dot_general``.

  The lhs buffers get their batch axis moved to the front; the dimension
  numbers and the output batch axis come from ``_dot_general_batch_dim_nums``.
  """
  *lhs_args, rhs = batched_args
  *lhs_dims, rhs_bdim = batch_dims

  # When rhs is batched its extent fixes the batch size; otherwise the helper
  # derives it from the lhs buffers.
  if rhs_bdim is None:
    batch_size = None
  else:
    batch_size = rhs.shape[rhs_bdim]

  *new_lhs_args, new_lhs_spinfo = _bcsr_batch_dims_to_front(
      lhs_args, lhs_dims, lhs_spinfo, batch_size=batch_size)

  operand_ndims = (len(lhs_spinfo.shape), rhs.ndim)
  new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
      operand_ndims, (0, rhs_bdim), dimension_numbers)

  batched_out = _bcsr_dot_general(*new_lhs_args, rhs, lhs_spinfo=new_lhs_spinfo,
                                  dimension_numbers=new_dimension_numbers)
  return batched_out, result_batch_dim
# JVP rules are supplied for the first operand (lhs_data) and the last operand
# (rhs); the two index operands are given None.
ad.defjvp(bcsr_dot_general_p, _bcsr_dot_general_jvp_lhs, None, None,
          _bcsr_dot_general_jvp_rhs)
ad.primitive_transposes[bcsr_dot_general_p] = _bcsr_dot_general_transpose
batching.primitive_batchers[bcsr_dot_general_p] = _bcsr_dot_general_batch_rule
def _bcsr_correct_out_of_bound_indices(data, indices, indptr, rhs, *, shape):
  """Zero out entries of ``data`` and ``indices`` at positions ``>= indptr[-1]``.

  With batch dimensions present, the function maps itself over them.

  Args:
    data, indices, indptr : BCSR buffers.
    rhs : not used in the computation; only forwarded on the batched path.
    shape : shape of the BCSR array.

  Returns:
    ``[data, indices]`` with the trailing entries replaced by 0.
  """
  props = _validate_bcsr(data, indices, indptr, shape)
  if props.n_batch:
    unbatched = partial(_bcsr_correct_out_of_bound_indices, rhs=rhs,
                        shape=shape[props.n_batch:])
    return nfold_vmap(unbatched, props.n_batch)(data, indices, indptr)

  extent = indptr[-1]

  def zero_past_extent(buf):
    # Position of each entry along the leading (nse) axis.
    position = lax.broadcasted_iota(indptr.dtype, buf.shape, 0)
    return jnp.where(position < extent, buf, 0)

  return [zero_past_extent(data), zero_past_extent(indices)]
# Lowering-rule form of _bcsr_correct_out_of_bound_indices, used inside the GPU
# lowering of bcsr_dot_general.
_bcsr_correct_out_of_bound_indices_lowered = mlir.lower_fun(
    _bcsr_correct_out_of_bound_indices, multiple_results=True)
def _bcsr_dot_general_gpu_lowering(
    csr_matvec_lowering, csr_matmat_lowering,
    ctx, lhs_data, lhs_indices, lhs_indptr, rhs, *, dimension_numbers,
    lhs_spinfo: SparseInfo):
  """GPU lowering rule for ``bcsr_dot_general``.

  Falls back to ``_bcsr_dot_general_default_lowering`` unless the
  ``jax_bcoo_cusparse_lowering`` config flag is set and the operation is a
  plain, unbatched CSR matrix-vector or matrix-matrix product contracting lhs
  dimension 1 with a float32/float64/complex64/complex128 dtype.

  Args:
    csr_matvec_lowering : lowering used when ``rhs`` is 1D.
    csr_matmat_lowering : lowering used when ``rhs`` is 2D.
    ctx : lowering context; ``ctx.avals_in`` gives the operand avals.
    lhs_data, lhs_indices, lhs_indptr : BCSR buffers of the lhs.
    rhs : dense rhs operand.
    dimension_numbers : ``((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))``.
    lhs_spinfo : SparseInfo of the lhs.

  Returns:
    A single-element list with the result of the chosen lowering, or whatever
    the default lowering returns on the fallback paths.

  Raises:
    ValueError: if ``rhs`` is neither 1D nor 2D when the GPU path is reached.
  """
  if not config.jax_bcoo_cusparse_lowering:
    return _bcsr_dot_general_default_lowering(
      ctx, lhs_data, lhs_indices, lhs_indptr, rhs,
      dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)

  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  lhs_data_aval, lhs_indices_aval, lhs_indptr_aval, rhs_aval = ctx.avals_in
  props = _validate_bcsr(
      lhs_data_aval, lhs_indices_aval, lhs_indptr_aval, lhs_spinfo.shape)

  use_default_lowering = False
  dtype = lhs_data_aval.dtype
  # The chain below is order-sensitive: the dtype warning is only reached when
  # every structural requirement is already satisfied.
  # TODO(vanderplas, tianjianlu): lower batched matmuls to GPU
  if lhs_batch or rhs_batch:
    # batch dimensions in dot_general are not supported
    use_default_lowering = True
  elif len(lhs_spinfo.shape) != 2 or rhs_aval.ndim not in [1, 2]:
    # only matmat / matvec supported
    use_default_lowering = True
  elif props.n_batch or props.n_dense:
    # batch and dense dimensions in BCSR not supported
    use_default_lowering = True
  elif list(lhs_contract) != [1]:
    # cusparse cannot contract over more than one dimension
    use_default_lowering = True
  elif dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
    # This would be supported if not for the dtype.
    warnings.warn(f'bcsr_dot_general cusparse/hipsparse lowering not available '
                  f'for {dtype=}. Falling back to default implementation.',
                  CuSparseEfficiencyWarning)
    use_default_lowering = True

  if use_default_lowering:
    return _bcsr_dot_general_default_lowering(
      ctx, lhs_data, lhs_indices, lhs_indptr, rhs,
      dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)

  # Account for a bug in cusparse: it references indices and data beyond
  # the extent of indptr.
  (lhs_data,), (lhs_indices,) = _bcsr_correct_out_of_bound_indices_lowered(
    ctx, lhs_data, lhs_indices, lhs_indptr, rhs, shape=lhs_spinfo.shape)

  # x_dtype holds the *name* of the keyword through which the rhs dtype is
  # passed; it differs between the matvec and matmat lowerings.
  if rhs_aval.ndim == 1:
    dot_general_fn = csr_matvec_lowering
    x_dtype = 'x_dtype'
  elif rhs_aval.ndim == 2:
    dot_general_fn = csr_matmat_lowering
    x_dtype = 'B_dtype'
    if rhs_contract[0] == 1:
      # Swap the two rhs axes so that the contracted axis comes first.
      rhs = hlo.TransposeOp(
        rhs, permutation=mlir.dense_int_elements([1, 0])).result
  else:
    raise ValueError(f"rhs has to be 1d or 2d; get {rhs_aval.ndim}d.")

  return [dot_general_fn(lhs_data, lhs_indices, lhs_indptr, rhs,
                         shape=lhs_spinfo.shape, transpose=False,
                         data_dtype=lhs_data_aval.dtype,
                         index_dtype=lhs_indices_aval.dtype,
                         **{x_dtype: rhs_aval.dtype})]
# Default lowering, derived from the Python implementation via mlir.lower_fun.
_bcsr_dot_general_default_lowering = mlir.lower_fun(
    _bcsr_dot_general_impl, multiple_results=False)
mlir.register_lowering(
    bcsr_dot_general_p, _bcsr_dot_general_default_lowering)
# NOTE(review): presumably installs an eager impl for the primitive — confirm
# against jax._src.dispatch.
dispatch.simple_impl(bcsr_dot_general_p)
# Platform-specific lowerings: the shared GPU rule is specialized with the
# CUDA or ROCm CSR matvec/matmat kernels when the respective backend is
# reported as supported.
if gpu_sparse.cuda_is_supported:
  mlir.register_lowering(bcsr_dot_general_p,
                         partial(_bcsr_dot_general_gpu_lowering,
                                 gpu_sparse.cuda_csr_matvec,
                                 gpu_sparse.cuda_csr_matmat),
                         platform='cuda')
if gpu_sparse.rocm_is_supported:
  mlir.register_lowering(bcsr_dot_general_p,
                         partial(_bcsr_dot_general_gpu_lowering,
                                 gpu_sparse.rocm_csr_matvec,
                                 gpu_sparse.rocm_csr_matmat),
                         platform='rocm')
#----------------------------------------------------------------------
# BCSR functions that maybe should be primitives?
def bcsr_broadcast_in_dim(mat: BCSR, *, shape: Shape, broadcast_dimensions: Sequence[int]) -> BCSR:
  """Broadcast a BCSR array to ``shape``.

  The array is converted to BCOO, passed through
  ``bcoo.bcoo_broadcast_in_dim`` with the given arguments, and converted back.

  Args:
    mat : BCSR array.
    shape : target shape, forwarded to ``bcoo.bcoo_broadcast_in_dim``.
    broadcast_dimensions : forwarded to ``bcoo.bcoo_broadcast_in_dim``.

  Returns:
    A BCSR array built from the BCOO result.
  """
  as_bcoo = mat.to_bcoo()
  broadcasted = bcoo.bcoo_broadcast_in_dim(
      as_bcoo, shape=shape, broadcast_dimensions=broadcast_dimensions)
  return BCSR.from_bcoo(broadcasted)
def bcsr_concatenate(operands: Sequence[BCSR], *, dimension: int) -> BCSR:
  """Sparse implementation of :func:`jax.lax.concatenate`

  Each operand is converted to BCOO, the BCOO arrays are concatenated with
  ``bcoo.bcoo_concatenate``, and the result is converted back to BCSR.

  Args:
    operands : Sequence of BCSR arrays to concatenate. The arrays must have equal
      shapes, except in the `dimension` axis. Additionally, the arrays must
      have equivalent batch, sparse, and dense dimensions.
    dimension : Positive integer specifying the dimension along which to concatenate
      the arrays. The dimension must be among batch or sparse dimensions of the input;
      concatenation along dense dimensions is not supported.

  Returns:
    A BCSR array containing the concatenation of the inputs.
  """
  bcoo_operands = [operand.to_bcoo() for operand in operands]
  joined = bcoo.bcoo_concatenate(bcoo_operands, dimension=dimension)
  return BCSR.from_bcoo(joined)
@tree_util.register_pytree_node_class
class BCSR(JAXSparse):
  """Experimental batched CSR matrix implemented in JAX.

  The array is stored as three buffers:
    data : values, of shape ``batch_dims + (nse,) + dense_dims``.
    indices : column indices, of shape ``batch_dims + (nse,)``.
    indptr : row pointers, of shape ``batch_dims + (n_rows + 1,)``.
  """

  data: jax.Array
  indices: jax.Array
  indptr: jax.Array
  shape: Shape
  # Layout properties are all derived from the buffer shapes.
  nse = property(lambda self: self.indices.shape[-1])
  dtype = property(lambda self: self.data.dtype)
  n_batch = property(lambda self: self.indices.ndim - 1)
  n_sparse = property(lambda _: 2)
  n_dense = property(lambda self: self.data.ndim - self.indices.ndim)
  indices_sorted: bool
  unique_indices: bool
  _bufs = property(lambda self: (self.data, self.indices, self.indptr))
  _info = property(lambda self: SparseInfo(self.shape, self.indices_sorted,
                                           self.unique_indices))

  @property
  def _sparse_shape(self):
    """The two sparse dimensions of ``shape``, as a tuple."""
    return tuple(self.shape[self.n_batch:self.n_batch + 2])

  def __init__(self, args: Tuple[Array, Array, Array], *, shape: Sequence[int],
               indices_sorted: bool = False, unique_indices: bool = False):
    """Build a BCSR array from ``(data, indices, indptr)`` and validate it.

    Args:
      args : the ``(data, indices, indptr)`` buffers.
      shape : shape of the full array.
      indices_sorted : stored as-is on the instance.
      unique_indices : stored as-is on the instance.

    Raises:
      ValueError: if the buffers are inconsistent with ``shape``.
    """
    self.data, self.indices, self.indptr = map(jnp.asarray, args)
    self.indices_sorted = indices_sorted
    self.unique_indices = unique_indices
    super().__init__(args, shape=shape)
    _validate_bcsr(self.data, self.indices, self.indptr, self.shape)

  def __repr__(self):
    name = self.__class__.__name__
    try:
      nse = self.nse
      n_batch = self.n_batch
      n_dense = self.n_dense
      dtype = self.dtype
      shape = list(self.shape)
    except Exception:  # pylint: disable=broad-except
      # repr must not fail, even on a partially constructed instance.
      repr_ = f"{name}(<invalid>)"
    else:
      extra = f", {nse=}"
      if n_batch: extra += f", {n_batch=}"
      if n_dense: extra += f", {n_dense=}"
      repr_ = f"{name}({dtype}{shape}{extra})"
    if isinstance(self.data, core.Tracer):
      repr_ = f"{type(self.data).__name__}[{repr_}]"
    return repr_

  def transpose(self, *args, **kwargs):
    """Not implemented for BCSR."""
    raise NotImplementedError("Tranpose is not implemented.")

  def tree_flatten(self):
    """Flatten into ``(data, indices, indptr)`` children plus SparseInfo fields as aux data."""
    return (self.data, self.indices, self.indptr), self._info._asdict()

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    """Rebuild an instance from pytree parts without running ``__init__``.

    Raises:
      ValueError: if ``aux_data`` does not have exactly the expected keys.
    """
    obj = object.__new__(cls)
    obj.data, obj.indices, obj.indptr = children
    if aux_data.keys() != {'shape', 'indices_sorted', 'unique_indices'}:
      raise ValueError(f"BCSR.tree_unflatten: invalid {aux_data=}")
    obj.__dict__.update(**aux_data)
    return obj

  @classmethod
  def _empty(cls, shape, *, dtype=None, index_dtype='int32', n_dense=0,
             n_batch=0, nse=0):
    """Create an empty BCSR instance. Public method is sparse.empty().

    Raises:
      ValueError: if any of ``n_dense``, ``n_batch``, ``nse`` is negative or
        the number of sparse dimensions is not 2.
    """
    shape = tuple(shape)
    if n_dense < 0 or n_batch < 0 or nse < 0:
      raise ValueError(f"Invalid inputs: {shape=}, {n_dense=}, {n_batch=}, {nse=}")
    n_sparse = len(shape) - n_dense - n_batch
    if n_sparse != 2:
      raise ValueError("BCSR sparse.empty: must have 2 sparse dimensions.")
    batch_shape, sparse_shape, dense_shape = split_list(shape,
                                                        [n_batch, n_sparse])
    data = jnp.zeros((*batch_shape, nse, *dense_shape), dtype)
    # Column indices are filled with the number of columns, i.e. one past the
    # largest valid column index.
    indices = jnp.full((*batch_shape, nse), jnp.array(sparse_shape[1]),
                       index_dtype)
    indptr = jnp.zeros((*batch_shape, sparse_shape[0] + 1), index_dtype)
    return cls((data, indices, indptr), shape=shape)

  def sum_duplicates(self, nse: Optional[int] = None, remove_zeros: bool = True) -> BCSR:
    """Return a copy of the array with duplicate indices summed.

    Additionally, this operation will result in explicit zero entries removed, and
    indices being sorted in lexicographic order.

    Because the size of the resulting representation depends on the values in the
    arrays, this operation is not compatible with JIT or other transforms. To use
    ``sum_duplicates`` in such cases, you may pass a value to `nse` to specify the
    desired size of the output representation.

    Args:
      nse : integer (optional), if specified, gives the number of specified elements in
        the output sparse representation; if it is larger than the number required, data
        will be padded with zeros and indices will be padded with out-of-bounds values.
        If it is smaller than the number required, data will be silently discarded.
      remove_zeros : bool (default=True). If True, remove explicit zeros from the data
        as part of summing duplicates. If False, then explicit zeros at unique indices
        will remain among the specified elements. Note: remove_zeros=True is incompatible
        with autodiff.
    """
    if remove_zeros:
      return bcsr_eliminate_zeros(self, nse=nse)
    else:
      return bcsr_sum_duplicates(self, nse=nse)

  @classmethod
  def fromdense(cls, mat, *, nse=None, index_dtype=np.int32, n_dense=0,
                n_batch=0):
    """Create a BCSR array from a dense array; see ``bcsr_fromdense``."""
    return bcsr_fromdense(mat, nse=nse, index_dtype=index_dtype,
                          n_dense=n_dense, n_batch=n_batch)

  def todense(self):
    """Create a dense version of the array."""
    return bcsr_todense(self)

  def to_bcoo(self) -> bcoo.BCOO:
    """Convert to a BCOO array with the same data and shape."""
    coo_indices = _bcsr_to_bcoo(self.indices, self.indptr, shape=self.shape)
    return bcoo.BCOO((self.data, coo_indices), shape=self.shape)

  @classmethod
  def from_bcoo(cls, arr: bcoo.BCOO) -> BCSR:
    """Create a BCSR array from a BCOO array with 2 sparse dimensions.

    The indices are sorted first if ``arr`` is not flagged as sorted.

    Raises:
      NotImplementedError: if ``arr.n_sparse != 2``.
    """
    if arr.n_sparse != 2:
      raise NotImplementedError(f"BCSR.from_bcoo requires n_sparse=2; got {arr.n_sparse=}")
    if not arr.indices_sorted:
      arr = arr.sort_indices()
    # Build indptr with the dtype of the BCOO indices so that indices and
    # indptr agree, instead of silently falling back to int32.
    indices, indptr = _bcoo_to_bcsr(arr.indices, shape=arr.shape,
                                    index_dtype=arr.indices.dtype)
    return cls((arr.data, indices, indptr), shape=arr.shape)

  @classmethod
  def from_scipy_sparse(cls, mat, *, index_dtype=None, n_dense=0, n_batch=0):
    """Create a BCSR array from a :mod:`scipy.sparse` array.

    Args:
      mat : 2D scipy sparse array; it is converted with ``mat.tocsr()``.
      index_dtype : dtype for ``indices`` and ``indptr`` (int32 if not given).
      n_dense, n_batch : must both be 0.

    Raises:
      NotImplementedError: if ``n_dense`` or ``n_batch`` is nonzero.
      ValueError: if ``mat`` is not 2D.
    """
    if n_dense != 0 or n_batch != 0:
      raise NotImplementedError("BCSR from_scipy_sparse with nonzero n_dense/n_batch.")
    if mat.ndim != 2:
      raise ValueError(f"BCSR from_scipy_sparse requires 2D array; {mat.ndim}D is given.")
    mat = mat.tocsr()
    data = jnp.asarray(mat.data)
    indices = jnp.asarray(mat.indices).astype(index_dtype or jnp.int32)
    indptr = jnp.asarray(mat.indptr).astype(index_dtype or jnp.int32)
    return cls((data, indices, indptr), shape=mat.shape)
#--------------------------------------------------------------------
# vmappable handlers
def _bcsr_to_elt(cont, _, val, axis):
if axis is None:
return val
if axis >= val.n_batch:
raise ValueError(f"Cannot map in_axis={axis} for BCSR array with n_batch="
f"{val.n_batch}. in_axes for batched BCSR operations must "
"correspond to a batched dimension.")
return BCSR((cont(val.data, axis),
cont(val.indices, axis),
cont(val.indptr, axis)),
shape=val.shape[:axis] + val.shape[axis + 1:])
def _bcsr_from_elt(cont, axis_size, elt, axis):
if axis is None:
return elt
if axis > elt.n_batch:
raise ValueError(f"BCSR: cannot add out_axis={axis} for BCSR array with "
f"n_batch={elt.n_batch}. BCSR batch axes must be a "
"contiguous block of leading dimensions.")
return BCSR((cont(axis_size, elt.data, axis),
cont(axis_size, elt.indices, axis),
cont(axis_size, elt.indptr, axis)),
shape=elt.shape[:axis] + (axis_size,) + elt.shape[axis:])
# Make BCSR usable as a vmap input/output, with _bcsr_to_elt and _bcsr_from_elt
# as the element converters.
# NOTE(review): the meaning of the (int, int, None) arguments is not visible
# here — confirm against batching.register_vmappable.
batching.register_vmappable(BCSR, int, int, _bcsr_to_elt, _bcsr_from_elt, None)