627 lines
24 KiB
Python
627 lines
24 KiB
Python
|
# Copyright 2021 The JAX Authors.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
"""COO (coordinate format) matrix object and associated primitives."""
|
||
|
from __future__ import annotations
|
||
|
|
||
|
from functools import partial
|
||
|
import operator
|
||
|
from typing import Any, Dict, NamedTuple, Optional, Sequence, Tuple
|
||
|
import warnings
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
import jax
|
||
|
from jax import lax
|
||
|
from jax.interpreters import mlir
|
||
|
from jax.experimental.sparse._base import JAXSparse
|
||
|
from jax.experimental.sparse.util import _coo_extract, CuSparseEfficiencyWarning
|
||
|
from jax import tree_util
|
||
|
from jax._src import core
|
||
|
from jax._src import dispatch
|
||
|
from jax._src.interpreters import ad
|
||
|
from jax._src.lax.lax import _const
|
||
|
from jax._src.lib.mlir.dialects import hlo
|
||
|
from jax._src.lib import gpu_sparse
|
||
|
from jax._src.numpy.util import promote_dtypes
|
||
|
from jax._src.typing import Array, ArrayLike, DTypeLike
|
||
|
import jax.numpy as jnp
|
||
|
|
||
|
|
||
|
# Type aliases used in annotations throughout this module.
Shape = Tuple[int, ...]  # a tuple of dimension sizes
Dtype = Any  # any object accepted as a dtype
|
||
|
|
||
|
class COOInfo(NamedTuple):
  """Static metadata for a COO matrix, passed to the COO primitives as ``spinfo``."""
  # Dense shape of the represented matrix.
  shape: Shape
  # True when the stored entries are ordered by row index.
  rows_sorted: bool = False
  # True when the stored entries are ordered by column index.
  cols_sorted: bool = False
|
||
|
|
||
|
|
||
|
@tree_util.register_pytree_node_class
class COO(JAXSparse):
  """Experimental COO matrix implemented in JAX.

  Note: this class has minimal compatibility with JAX transforms such as
  grad and autodiff, and offers very little functionality. In general you
  should prefer :class:`jax.experimental.sparse.BCOO`.

  Additionally, there are known failures in the case that `nse` is larger
  than the true number of nonzeros in the represented matrix. This situation
  is better handled in BCOO.

  Attributes:
    data: array of shape ``(nse,)`` holding the stored values.
    row: array of shape ``(nse,)`` holding the row index of each stored value.
    col: array of shape ``(nse,)`` holding the column index of each stored value.
    shape: dense shape of the represented matrix.
  """
  data: jax.Array
  row: jax.Array
  col: jax.Array
  shape: Tuple[int, int]
  nse = property(lambda self: self.data.size)
  dtype = property(lambda self: self.data.dtype)
  # Static metadata handed to the COO primitives as their ``spinfo`` parameter.
  _info = property(lambda self: COOInfo(
      shape=self.shape, rows_sorted=self._rows_sorted,
      cols_sorted=self._cols_sorted))
  _bufs = property(lambda self: (self.data, self.row, self.col))
  _rows_sorted: bool
  _cols_sorted: bool

  def __init__(self, args: Tuple[Array, Array, Array], *, shape: Shape,
               rows_sorted: bool = False, cols_sorted: bool = False):
    """Build a COO matrix from ``(data, row, col)`` buffers.

    Args:
      args: tuple ``(data, row, col)``; each entry is converted with ``jnp.asarray``.
      shape: dense shape of the matrix.
      rows_sorted: whether the entries are known to be ordered by row.
      cols_sorted: whether the entries are known to be ordered by column.
    """
    self.data, self.row, self.col = map(jnp.asarray, args)
    self._rows_sorted = rows_sorted
    self._cols_sorted = cols_sorted
    super().__init__(args, shape=shape)

  @classmethod
  def fromdense(cls, mat: Array, *, nse: Optional[int] = None, index_dtype: DTypeLike = np.int32) -> COO:
    """Create a COO matrix from a dense array (delegates to ``coo_fromdense``)."""
    return coo_fromdense(mat, nse=nse, index_dtype=index_dtype)

  def _sort_indices(self) -> COO:
    """Return a copy of the COO matrix with sorted indices.

    The matrix is sorted by row indices and column indices per row.
    If self._rows_sorted is True, this returns ``self`` without a copy.
    """
    # TODO(jakevdp): would we benefit from lowering this to cusparse sort_rows utility?
    if self._rows_sorted:
      return self
    row, col, data = lax.sort((self.row, self.col, self.data), num_keys=2)
    return self.__class__((data, row, col), shape=self.shape,
                          rows_sorted=True)

  @classmethod
  def _empty(cls, shape: Sequence[int], *, dtype: Optional[DTypeLike] = None,
             index_dtype: DTypeLike = 'int32') -> COO:
    """Create an empty COO instance. Public method is sparse.empty()."""
    shape = tuple(shape)
    if len(shape) != 2:
      raise ValueError(f"COO must have ndim=2; got {shape=}")
    data = jnp.empty(0, dtype)
    row = col = jnp.empty(0, index_dtype)
    return cls((data, row, col), shape=shape, rows_sorted=True,
               cols_sorted=True)

  @classmethod
  def _eye(cls, N: int, M: int, k: int, *, dtype: Optional[DTypeLike] = None,
           index_dtype: DTypeLike = 'int32') -> COO:
    """Create an ``(N, M)`` COO matrix with ones on the ``k``-th diagonal.

    ``k > 0`` selects a diagonal above the main one, ``k < 0`` one below it.
    If the diagonal lies entirely outside the matrix, an empty matrix is returned.
    """
    if k > 0:
      diag_size = min(N, M - k)
    else:
      diag_size = min(N + k, M)

    if diag_size <= 0:
      # if k is out of range, return an empty matrix.
      return cls._empty((N, M), dtype=dtype, index_dtype=index_dtype)

    data = jnp.ones(diag_size, dtype=dtype)
    idx = jnp.arange(diag_size, dtype=index_dtype)
    # ``k`` is a static Python int, so the diagonal offsets are resolved here
    # rather than by staging lax.cond. Both offsets are non-negative, which
    # also keeps them representable for unsigned index dtypes.
    row = lax.add(idx, _const(idx, max(-k, 0)))
    col = lax.add(idx, _const(idx, max(k, 0)))
    return cls((data, row, col), shape=(N, M), rows_sorted=True, cols_sorted=True)

  def todense(self) -> Array:
    """Return the dense representation of this matrix."""
    return coo_todense(self)

  def transpose(self, axes: Optional[Tuple[int, ...]] = None) -> COO:
    """Return the transpose by swapping the row and column buffers.

    Only the default ``axes=None`` is supported.
    """
    if axes is not None:
      raise NotImplementedError("axes argument to transpose()")
    return COO((self.data, self.col, self.row), shape=self.shape[::-1],
               rows_sorted=self._cols_sorted, cols_sorted=self._rows_sorted)

  def tree_flatten(self) -> Tuple[Tuple[Array, Array, Array], Dict[str, Any]]:
    """Pytree children are the three buffers; the static metadata is aux data."""
    return (self.data, self.row, self.col), self._info._asdict()

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    """Rebuild a COO from pytree children, bypassing ``__init__``."""
    obj = object.__new__(cls)
    obj.data, obj.row, obj.col = children
    if aux_data.keys() != {'shape', 'rows_sorted', 'cols_sorted'}:
      raise ValueError(f"COO.tree_unflatten: invalid {aux_data=}")
    obj.shape = aux_data['shape']
    obj._rows_sorted = aux_data['rows_sorted']
    obj._cols_sorted = aux_data['cols_sorted']
    return obj

  def __matmul__(self, other: ArrayLike) -> Array:
    """Sparse @ dense product for 1D (matvec) or 2D (matmat) right operands."""
    if isinstance(other, JAXSparse):
      raise NotImplementedError("matmul between two sparse objects.")
    other = jnp.asarray(other)
    # Both operands must share a dtype for the product primitives.
    data, other = promote_dtypes(self.data, other)
    self_promoted = COO((data, self.row, self.col), **self._info._asdict())
    if other.ndim == 1:
      return coo_matvec(self_promoted, other)
    elif other.ndim == 2:
      return coo_matmat(self_promoted, other)
    else:
      raise NotImplementedError(f"matmul with object of shape {other.shape}")
|
||
|
|
||
|
#--------------------------------------------------------------------
|
||
|
# coo_todense
|
||
|
|
||
|
# Primitive that converts COO buffers (data, row, col) into a dense array.
coo_todense_p = core.Primitive('coo_todense')
|
||
|
|
||
|
def coo_todense(mat: COO) -> Array:
  """Densify a COO-format sparse matrix.

  Args:
    mat : COO matrix

  Returns:
    mat_dense: dense array of shape ``mat.shape`` and dtype ``mat.dtype``.
  """
  data, row, col = mat._bufs
  return _coo_todense(data, row, col, spinfo=mat._info)
|
||
|
|
||
|
def _coo_todense(data: Array, row: Array, col: Array, *, spinfo: COOInfo) -> Array:
  """Convert COO-format sparse matrix buffers to a dense matrix.

  Args:
    data : array of shape ``(nse,)``.
    row : array of shape ``(nse,)``
    col : array of shape ``(nse,)`` and dtype ``row.dtype``
    spinfo : COOInfo object containing matrix metadata

  Returns:
    mat : array of shape ``spinfo.shape`` with dtype matching ``data``. In the
      reference implementation, values at duplicate (row, col) positions are
      summed.
  """
  return coo_todense_p.bind(data, row, col, spinfo=spinfo)
|
||
|
|
||
|
def _coo_todense_impl(data, row, col, *, spinfo):
|
||
|
return jnp.zeros(spinfo.shape, data.dtype).at[row, col].add(data)
|
||
|
|
||
|
@coo_todense_p.def_abstract_eval
def _coo_todense_abstract_eval(data, row, col, *, spinfo):
  """Output aval: dense array of ``spinfo.shape`` with ``data``'s dtype."""
  out_shape, out_dtype = spinfo.shape, data.dtype
  return core.ShapedArray(out_shape, out_dtype)
|
||
|
|
||
|
# Platform-independent lowering built by tracing the jnp reference implementation.
_coo_todense_lowering = mlir.lower_fun(_coo_todense_impl,
                                       multiple_results=False)
|
||
|
|
||
|
def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo):
|
||
|
data_aval, row_aval, _ = ctx.avals_in
|
||
|
dtype = data_aval.dtype
|
||
|
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
|
||
|
warnings.warn(f"coo_todense cusparse/hipsparse lowering not available for {dtype=}. "
|
||
|
"Falling back to default implementation.", CuSparseEfficiencyWarning)
|
||
|
return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo)
|
||
|
|
||
|
if spinfo.rows_sorted:
|
||
|
shape = spinfo.shape
|
||
|
transpose = False
|
||
|
elif spinfo.cols_sorted:
|
||
|
row, col = col, row
|
||
|
transpose = True
|
||
|
shape = spinfo.shape[::-1]
|
||
|
else:
|
||
|
warnings.warn("coo_todense GPU lowering requires matrices with sorted rows or sorted cols. "
|
||
|
"To sort the rows in your matrix, use e.g. mat = mat._sort_indices(). Falling "
|
||
|
"back to the default implementation.", CuSparseEfficiencyWarning)
|
||
|
return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo)
|
||
|
|
||
|
result = coo_todense_hlo(
|
||
|
data, row, col, shape=shape, data_dtype=dtype, index_dtype=row_aval.dtype)
|
||
|
return (
|
||
|
[hlo.TransposeOp(result, mlir.dense_int_elements([1, 0])).result]
|
||
|
if transpose else [result])
|
||
|
|
||
|
|
||
|
def _coo_todense_jvp(data_dot, data, row, col, *, spinfo):
  """JVP w.r.t. ``data``: densify the tangent using the same sparsity pattern."""
  return _coo_todense(data=data_dot, row=row, col=col, spinfo=spinfo)
|
||
|
|
||
|
def _coo_todense_transpose(ct, data, row, col, *, spinfo):
|
||
|
# Note: we assume that transpose has the same sparsity pattern.
|
||
|
# Can we check this?
|
||
|
assert ad.is_undefined_primal(data)
|
||
|
if ad.is_undefined_primal(row) or ad.is_undefined_primal(col):
|
||
|
raise ValueError("Cannot transpose with respect to sparse indices")
|
||
|
assert ct.shape == spinfo.shape
|
||
|
assert row.aval.dtype == col.aval.dtype
|
||
|
assert ct.dtype == data.aval.dtype
|
||
|
return _coo_extract(row, col, ct), row, col
|
||
|
|
||
|
# Autodiff rules: the JVP rule covers `data`; the two `None`s mean the integer
# index operands `row` and `col` get no JVP rule.
ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
# Default lowering, built from the jnp reference implementation.
mlir.register_lowering(coo_todense_p, _coo_todense_lowering)
# Eager implementation via JAX's generic dispatch helper.
dispatch.simple_impl(coo_todense_p)

# GPU-specific lowerings (cusparse on CUDA, hipsparse on ROCm), registered
# only when jaxlib reports the corresponding backend as supported.
if gpu_sparse.cuda_is_supported:
  mlir.register_lowering(
      coo_todense_p,
      partial(_coo_todense_gpu_lowering, gpu_sparse.cuda_coo_todense),
      platform='cuda')
if gpu_sparse.rocm_is_supported:
  mlir.register_lowering(
      coo_todense_p,
      partial(_coo_todense_gpu_lowering, gpu_sparse.rocm_coo_todense),
      platform='rocm')
|
||
|
|
||
|
#--------------------------------------------------------------------
|
||
|
# coo_fromdense
|
||
|
|
||
|
# Primitive that extracts COO buffers (data, row, col) from a dense matrix;
# it returns three outputs, hence multiple_results.
coo_fromdense_p = core.Primitive('coo_fromdense')
coo_fromdense_p.multiple_results = True
|
||
|
|
||
|
def coo_fromdense(mat: Array, *, nse: Optional[int] = None, index_dtype: DTypeLike = jnp.int32) -> COO:
  """Create a COO-format sparse matrix from a dense matrix.

  Args:
    mat : array to be converted to COO. Array-likes (e.g. NumPy arrays or
      nested lists) are converted with ``jnp.asarray``.
    nse : number of specified entries in ``mat``. If not specified,
      it will be computed from the input matrix; this calls ``int()`` on the
      nonzero count, so ``mat`` must then be concrete (not traced).
    index_dtype : dtype of sparse indices

  Returns:
    mat_coo : COO representation of the matrix, marked as row-sorted.
  """
  # Convert first so that `mat != 0` and `mat.shape` work for any array-like.
  mat = jnp.asarray(mat)
  if nse is None:
    nse = int((mat != 0).sum())
  nse_int = core.concrete_or_error(operator.index, nse, "coo_fromdense nse argument")
  return COO(_coo_fromdense(mat, nse=nse_int, index_dtype=index_dtype),
             shape=mat.shape, rows_sorted=True)
|
||
|
|
||
|
def _coo_fromdense(mat: Array, *, nse: int, index_dtype: DTypeLike = jnp.int32) -> Tuple[Array, Array, Array]:
  """Bind ``coo_fromdense_p`` to extract COO buffers from a dense matrix.

  Args:
    mat : array to be converted to COO.
    nse : number of specified entries in ``mat``; must be a concrete integer.
    index_dtype : dtype of sparse indices

  Returns:
    data : array of shape ``(nse,)`` and dtype ``mat.dtype``
    row : array of shape ``(nse,)`` and dtype ``index_dtype``
    col : array of shape ``(nse,)`` and dtype ``index_dtype``
  """
  dense = jnp.asarray(mat)
  static_nse = core.concrete_or_error(operator.index, nse, "nse argument of coo_fromdense()")
  return coo_fromdense_p.bind(dense, nse=static_nse, index_dtype=index_dtype)
|
||
|
|
||
|
def _coo_fromdense_impl(mat, *, nse, index_dtype):
|
||
|
mat = jnp.asarray(mat)
|
||
|
assert mat.ndim == 2
|
||
|
|
||
|
row, col = jnp.nonzero(mat, size=nse)
|
||
|
data = mat[row, col]
|
||
|
|
||
|
true_nonzeros = jnp.arange(nse) < (mat != 0).sum()
|
||
|
data = jnp.where(true_nonzeros, data, 0)
|
||
|
|
||
|
return data, row.astype(index_dtype), col.astype(index_dtype)
|
||
|
|
||
|
@coo_fromdense_p.def_abstract_eval
def _coo_fromdense_abstract_eval(mat, *, nse, index_dtype):
  """Output avals: ``(nse,)`` data with ``mat``'s dtype, plus ``(nse,)`` row/col indices."""
  buf_shape = (nse,)
  index_aval = core.ShapedArray(buf_shape, index_dtype)
  return core.ShapedArray(buf_shape, mat.dtype), index_aval, index_aval
|
||
|
|
||
|
# Platform-independent lowering built by tracing the jnp reference implementation.
_coo_fromdense_lowering = mlir.lower_fun(_coo_fromdense_impl,
                                         multiple_results=True)
|
||
|
|
||
|
def _coo_fromdense_gpu_lowering(coo_fromdense_hlo, ctx, mat, *, nse,
                                index_dtype):
  """GPU (cusparse/hipsparse) lowering rule for ``coo_fromdense_p``.

  Falls back to the generic lowering, emitting a CuSparseEfficiencyWarning,
  when the input dtype is neither floating nor complex.
  """
  dtype = ctx.avals_in[0].dtype
  supported_dtype = (np.issubdtype(dtype, np.floating)
                     or np.issubdtype(dtype, np.complexfloating))
  if not supported_dtype:
    warnings.warn(f"coo_fromdense cusparse/hipsparse lowering not available for {dtype=}. "
                  "Falling back to default implementation.", CuSparseEfficiencyWarning)
    return _coo_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype)

  idx_dtype = np.dtype(index_dtype)
  data, row, col = coo_fromdense_hlo(
      mat,
      nnz=nse,
      data_dtype=dtype,
      index_dtype=idx_dtype,
      index_type=mlir.dtype_to_ir_type(idx_dtype))
  return [data, row, col]
|
||
|
|
||
|
|
||
|
def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype):
  """JVP rule for ``coo_fromdense_p``.

  The ``data`` tangent is the input tangent extracted at the primal's
  (row, col) positions via ``_coo_extract``. The integer index outputs get
  symbolic-zero tangents.
  """
  (M,), (Mdot,) = primals, tangents

  primals_out = _coo_fromdense(M, nse=nse, index_dtype=index_dtype)
  data, row, col = primals_out

  data_dot = (ad.Zero.from_value(data) if type(Mdot) is ad.Zero
              else _coo_extract(row, col, Mdot))
  tangents_out = (data_dot, *(ad.Zero.from_value(idx) for idx in (row, col)))
  return primals_out, tangents_out
|
||
|
|
||
|
def _coo_fromdense_transpose(ct, M, *, nse, index_dtype):
|
||
|
data, row, col = ct
|
||
|
assert len(data) == nse
|
||
|
assert row.dtype == col.dtype == index_dtype
|
||
|
if isinstance(row, ad.Zero) or isinstance(col, ad.Zero):
|
||
|
raise ValueError("Cannot transpose with respect to sparse indices")
|
||
|
assert ad.is_undefined_primal(M)
|
||
|
return _coo_todense(data, row, col, spinfo=COOInfo(shape=M.aval.shape))
|
||
|
|
||
|
# Autodiff rules. The JVP rule is registered directly because it takes all
# primals/tangents at once and produces all three outputs.
ad.primitive_jvps[coo_fromdense_p] = _coo_fromdense_jvp
ad.primitive_transposes[coo_fromdense_p] = _coo_fromdense_transpose
# Default lowering, built from the jnp reference implementation.
mlir.register_lowering(coo_fromdense_p, _coo_fromdense_lowering)
# Eager implementation via JAX's generic dispatch helper.
dispatch.simple_impl(coo_fromdense_p)

# GPU-specific lowerings (cusparse on CUDA, hipsparse on ROCm), registered
# only when jaxlib reports the corresponding backend as supported.
if gpu_sparse.cuda_is_supported:
  mlir.register_lowering(
      coo_fromdense_p,
      partial(_coo_fromdense_gpu_lowering, gpu_sparse.cuda_coo_fromdense),
      platform='cuda')
if gpu_sparse.rocm_is_supported:
  mlir.register_lowering(
      coo_fromdense_p,
      partial(_coo_fromdense_gpu_lowering, gpu_sparse.rocm_coo_fromdense),
      platform='rocm')
|
||
|
|
||
|
#--------------------------------------------------------------------
|
||
|
# coo_matvec
|
||
|
|
||
|
# Primitive for the product of a COO matrix (or its transpose) with a dense vector.
coo_matvec_p = core.Primitive('coo_matvec')
|
||
|
|
||
|
def coo_matvec(mat: COO, v: Array, transpose: bool = False) -> Array:
  """Multiply a COO sparse matrix by a dense vector.

  Args:
    mat : COO matrix
    v : one-dimensional array of size
      ``(mat.shape[0] if transpose else mat.shape[1],)`` and dtype ``mat.dtype``
    transpose : if True, multiply by the transpose of ``mat`` instead.

  Returns:
    y : array of shape ``(mat.shape[1] if transpose else mat.shape[0],)``
      holding the matrix-vector product.
  """
  return _coo_matvec(mat.data, mat.row, mat.col, v, spinfo=mat._info,
                     transpose=transpose)
|
||
|
|
||
|
def _coo_matvec(data: Array, row: Array, col: Array, v: Array, *, spinfo: COOInfo, transpose: bool = False) -> Array:
  """Product of COO sparse matrix buffers and a dense vector.

  Args:
    data : array of shape ``(nse,)``.
    row : array of shape ``(nse,)``
    col : array of shape ``(nse,)`` and dtype ``row.dtype``
    v : array of shape ``(spinfo.shape[0] if transpose else spinfo.shape[1],)``
      and dtype ``data.dtype``
    spinfo : COOInfo object holding the matrix shape and sortedness metadata
    transpose : boolean specifying whether to transpose the sparse matrix
      before computing.

  Returns:
    y : array of shape ``(spinfo.shape[1] if transpose else spinfo.shape[0],)``
      representing the matrix-vector product.
  """
  return coo_matvec_p.bind(data, row, col, v, spinfo=spinfo, transpose=transpose)
|
||
|
|
||
|
def _coo_matvec_impl(data, row, col, v, *, spinfo, transpose):
|
||
|
v = jnp.asarray(v)
|
||
|
if transpose:
|
||
|
row, col = col, row
|
||
|
out_shape = spinfo.shape[1] if transpose else spinfo.shape[0]
|
||
|
dv = data * v[col]
|
||
|
return jnp.zeros(out_shape, dv.dtype).at[row].add(dv)
|
||
|
|
||
|
@coo_matvec_p.def_abstract_eval
def _coo_matvec_abstract_eval(data, row, col, v, *, spinfo, transpose):
  """Shape/dtype rule for ``coo_matvec_p``; the result has ``data``'s dtype."""
  assert data.shape == row.shape == col.shape
  assert data.dtype == v.dtype
  assert row.dtype == col.dtype
  assert len(spinfo.shape) == 2
  assert v.ndim == 1
  # (input length, output length) for the chosen orientation.
  n_in, n_out = spinfo.shape if transpose else spinfo.shape[::-1]
  assert v.shape[0] == n_in
  return core.ShapedArray((n_out,), data.dtype)
|
||
|
|
||
|
# Platform-independent lowering built by tracing the jnp reference implementation.
_coo_matvec_lowering = mlir.lower_fun(_coo_matvec_impl,
                                      multiple_results=False)
|
||
|
|
||
|
def _coo_matvec_gpu_lowering(coo_matvec_hlo, ctx, data, row, col, v, *, spinfo,
|
||
|
transpose):
|
||
|
data_aval, row_aval, _, x_aval = ctx.avals_in
|
||
|
dtype = data_aval.dtype
|
||
|
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
|
||
|
warnings.warn(f"coo_matvec cusparse/hipsparse lowering not available for {dtype=}. "
|
||
|
"Falling back to default implementation.", CuSparseEfficiencyWarning)
|
||
|
return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo, transpose=transpose)
|
||
|
|
||
|
if spinfo.rows_sorted:
|
||
|
shape = spinfo.shape
|
||
|
elif spinfo.cols_sorted:
|
||
|
row, col = col, row
|
||
|
transpose = not transpose
|
||
|
shape = spinfo.shape[::-1]
|
||
|
else:
|
||
|
warnings.warn("coo_matvec GPU lowering requires matrices with sorted rows or sorted cols. "
|
||
|
"To sort the rows in your matrix, use e.g. mat = mat._sort_indices(). Falling "
|
||
|
"back to the default implementation.", CuSparseEfficiencyWarning)
|
||
|
return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo,
|
||
|
transpose=transpose)
|
||
|
|
||
|
return [coo_matvec_hlo(
|
||
|
data, row, col, v, shape=shape, transpose=transpose,
|
||
|
index_dtype=row_aval.dtype, data_dtype=dtype, x_dtype=x_aval.dtype)]
|
||
|
|
||
|
|
||
|
def _coo_matvec_jvp_mat(data_dot, data, row, col, v, *, spinfo, transpose):
  """JVP w.r.t. ``data``: the product is linear in ``data``, so reuse it on the tangent."""
  return _coo_matvec(data=data_dot, row=row, col=col, v=v,
                     spinfo=spinfo, transpose=transpose)
|
||
|
|
||
|
def _coo_matvec_jvp_vec(v_dot, data, row, col, v, *, spinfo, transpose):
  """JVP w.r.t. ``v``: the product is linear in ``v``, so apply the matrix to the tangent."""
  return _coo_matvec(data=data, row=row, col=col, v=v_dot,
                     spinfo=spinfo, transpose=transpose)
|
||
|
|
||
|
def _coo_matvec_transpose(ct, data, row, col, v, *, spinfo, transpose):
|
||
|
assert not ad.is_undefined_primal(row)
|
||
|
assert not ad.is_undefined_primal(col)
|
||
|
|
||
|
if ad.is_undefined_primal(v):
|
||
|
return data, row, col, _coo_matvec(data, row, col, ct, spinfo=spinfo, transpose=not transpose)
|
||
|
else:
|
||
|
v = jnp.asarray(v)
|
||
|
# The following line does this, but more efficiently:
|
||
|
# return _coo_extract(row, col, jnp.outer(ct, v)), row, col, v
|
||
|
return ct[row] * v[col], row, col, v
|
||
|
|
||
|
# Autodiff rules: JVPs for `data` and `v`; the `None`s mean the integer index
# operands `row` and `col` get no JVP rule.
ad.defjvp(coo_matvec_p, _coo_matvec_jvp_mat, None, None, _coo_matvec_jvp_vec)
ad.primitive_transposes[coo_matvec_p] = _coo_matvec_transpose
# Default lowering, built from the jnp reference implementation.
mlir.register_lowering(coo_matvec_p, _coo_matvec_lowering)
# Eager implementation via JAX's generic dispatch helper.
dispatch.simple_impl(coo_matvec_p)

# GPU-specific lowerings (cusparse on CUDA, hipsparse on ROCm), registered
# only when jaxlib reports the corresponding backend as supported.
if gpu_sparse.cuda_is_supported:
  mlir.register_lowering(
      coo_matvec_p,
      partial(_coo_matvec_gpu_lowering, gpu_sparse.cuda_coo_matvec),
      platform='cuda')
if gpu_sparse.rocm_is_supported:
  mlir.register_lowering(
      coo_matvec_p,
      partial(_coo_matvec_gpu_lowering, gpu_sparse.rocm_coo_matvec),
      platform='rocm')
|
||
|
|
||
|
|
||
|
#--------------------------------------------------------------------
|
||
|
# coo_matmat
|
||
|
|
||
|
# Primitive for the product of a COO matrix (or its transpose) with a dense matrix.
coo_matmat_p = core.Primitive('coo_matmat')
|
||
|
|
||
|
def coo_matmat(mat: COO, B: Array, *, transpose: bool = False) -> Array:
  """Product of COO sparse matrix and a dense matrix.

  Args:
    mat : COO matrix
    B : array of shape ``(mat.shape[0] if transpose else mat.shape[1], cols)`` and
      dtype ``mat.dtype``
    transpose : boolean specifying whether to transpose the sparse matrix
      before computing.

  Returns:
    C : array of shape ``(mat.shape[1] if transpose else mat.shape[0], cols)``
      representing the matrix-matrix product.
  """
  data, row, col = mat._bufs
  return _coo_matmat(data, row, col, B, spinfo=mat._info, transpose=transpose)
|
||
|
|
||
|
def _coo_matmat(data: Array, row: Array, col: Array, B: Array, *, spinfo: COOInfo, transpose: bool = False) -> Array:
  """Product of COO sparse matrix buffers and a dense matrix.

  Args:
    data : array of shape ``(nse,)``.
    row : array of shape ``(nse,)``
    col : array of shape ``(nse,)`` and dtype ``row.dtype``
    B : array of shape ``(spinfo.shape[0] if transpose else spinfo.shape[1], cols)``
      and dtype ``data.dtype``
    spinfo : COOInfo object holding the matrix shape and sortedness metadata
    transpose : boolean specifying whether to transpose the sparse matrix
      before computing.

  Returns:
    C : array of shape ``(spinfo.shape[1] if transpose else spinfo.shape[0], cols)``
      representing the matrix-matrix product.
  """
  return coo_matmat_p.bind(data, row, col, B, spinfo=spinfo, transpose=transpose)
|
||
|
|
||
|
def _coo_matmat_impl(data, row, col, B, *, spinfo, transpose):
|
||
|
B = jnp.asarray(B)
|
||
|
if transpose:
|
||
|
row, col = col, row
|
||
|
out_shape = spinfo.shape[1] if transpose else spinfo.shape[0]
|
||
|
dB = data[:, None] * B[col]
|
||
|
return jnp.zeros((out_shape, B.shape[1]), dB.dtype).at[row].add(dB)
|
||
|
|
||
|
@coo_matmat_p.def_abstract_eval
def _coo_matmat_abstract_eval(data, row, col, B, *, spinfo, transpose):
  """Shape/dtype rule for ``coo_matmat_p``; the result has ``data``'s dtype."""
  assert data.shape == row.shape == col.shape
  assert data.dtype == B.dtype
  assert B.ndim == 2
  assert len(spinfo.shape) == 2
  # (contracted length, output rows) for the chosen orientation.
  n_in, n_out = spinfo.shape if transpose else spinfo.shape[::-1]
  assert B.shape[0] == n_in
  return core.ShapedArray((n_out, B.shape[1]), data.dtype)
|
||
|
|
||
|
# Platform-independent lowering built by tracing the jnp reference implementation.
_coo_matmat_lowering = mlir.lower_fun(_coo_matmat_impl,
                                      multiple_results=False)
|
||
|
|
||
|
def _coo_matmat_gpu_lowering(coo_matmat_hlo, ctx, data, row, col, B, *, spinfo,
|
||
|
transpose):
|
||
|
data_aval, row_aval, _, B_aval = ctx.avals_in
|
||
|
dtype = data_aval.dtype
|
||
|
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
|
||
|
warnings.warn(f"coo_matmat cusparse/hipsprse lowering not available for {dtype=}. "
|
||
|
"Falling back to default implementation.", CuSparseEfficiencyWarning)
|
||
|
return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo, transpose=transpose)
|
||
|
|
||
|
if spinfo.rows_sorted:
|
||
|
shape = spinfo.shape
|
||
|
elif spinfo.cols_sorted:
|
||
|
row, col = col, row
|
||
|
transpose = not transpose
|
||
|
shape = spinfo.shape[::-1]
|
||
|
else:
|
||
|
warnings.warn("coo_matmat GPU lowering requires matrices with sorted rows or sorted cols. "
|
||
|
"To sort the rows in your matrix, use e.g. mat = mat._sort_indices(). Falling "
|
||
|
"back to the default implementation.", CuSparseEfficiencyWarning)
|
||
|
return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo,
|
||
|
transpose=transpose)
|
||
|
|
||
|
return [coo_matmat_hlo(data, row, col, B, shape=shape,
|
||
|
transpose=transpose, x_dtype=B_aval.dtype,
|
||
|
data_dtype=data_aval.dtype,
|
||
|
index_dtype=row_aval.dtype)]
|
||
|
|
||
|
|
||
|
def _coo_matmat_jvp_left(data_dot, data, row, col, B, *, spinfo, transpose):
  """JVP w.r.t. ``data``: the product is linear in ``data``, so reuse it on the tangent."""
  return _coo_matmat(data=data_dot, row=row, col=col, B=B,
                     spinfo=spinfo, transpose=transpose)
|
||
|
|
||
|
def _coo_matmat_jvp_right(B_dot, data, row, col, B, *, spinfo, transpose):
  """JVP w.r.t. ``B``: the product is linear in ``B``, so apply the matrix to the tangent."""
  return _coo_matmat(data=data, row=row, col=col, B=B_dot,
                     spinfo=spinfo, transpose=transpose)
|
||
|
|
||
|
def _coo_matmat_transpose(ct, data, row, col, B, *, spinfo, transpose):
|
||
|
assert not ad.is_undefined_primal(row)
|
||
|
assert not ad.is_undefined_primal(col)
|
||
|
if ad.is_undefined_primal(B):
|
||
|
return data, row, col, _coo_matmat(data, row, col, ct, spinfo=spinfo, transpose=not transpose)
|
||
|
else:
|
||
|
B = jnp.asarray(B)
|
||
|
return (ct[row] * B[col]).sum(1), row, col, B
|
||
|
|
||
|
# Autodiff rules: JVPs for `data` and `B`; the `None`s mean the integer index
# operands `row` and `col` get no JVP rule.
ad.defjvp(coo_matmat_p, _coo_matmat_jvp_left, None, None, _coo_matmat_jvp_right)
ad.primitive_transposes[coo_matmat_p] = _coo_matmat_transpose
# Default lowering, built from the jnp reference implementation.
mlir.register_lowering(coo_matmat_p, _coo_matmat_lowering)
# Eager implementation via JAX's generic dispatch helper.
dispatch.simple_impl(coo_matmat_p)

# GPU-specific lowerings (cusparse on CUDA, hipsparse on ROCm), registered
# only when jaxlib reports the corresponding backend as supported.
if gpu_sparse.cuda_is_supported:
  mlir.register_lowering(
      coo_matmat_p,
      partial(_coo_matmat_gpu_lowering, gpu_sparse.cuda_coo_matmat),
      platform='cuda')
if gpu_sparse.rocm_is_supported:
  mlir.register_lowering(
      coo_matmat_p,
      partial(_coo_matmat_gpu_lowering, gpu_sparse.rocm_coo_matmat),
      platform='rocm')
|