# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
|
||
|
import operator
|
||
|
|
||
|
from jax import dtypes
|
||
|
from jax import vmap
|
||
|
from jax import random
|
||
|
from jax.util import split_list
|
||
|
import jax.numpy as jnp
|
||
|
from jax.experimental import sparse
|
||
|
|
||
|
|
||
|
def random_bcoo(key, shape, *, dtype=jnp.float_, indices_dtype=None,
                nse=0.2, n_batch=0, n_dense=0, unique_indices=True,
                sorted_indices=False, generator=random.uniform, **kwds):
  """Generate a random BCOO matrix.

  Args:
    key : random.PRNGKey to be passed to ``generator`` function.
    shape : tuple specifying the shape of the array to be generated.
    dtype : dtype of the array to be generated.
    indices_dtype: dtype of the BCOO indices. If None, the canonicalized
      ``jnp.int_`` dtype is used.
    nse : number of specified elements in the matrix, or if 0 < nse < 1, a
      fraction of the sparse elements to be specified (default: 0.2). With
      ``unique_indices=True`` the resolved value must satisfy
      ``0 <= nse <= prod(sparse_shape)``. Otherwise only ``nse >= 0`` is
      required (indices may then repeat), except that a sparse shape of size
      zero admits only ``nse == 0``.
    n_batch : number of batch dimensions. must satisfy ``n_batch >= 0`` and
      ``n_batch + n_dense <= len(shape)``.
    n_dense : number of dense dimensions. must satisfy ``n_dense >= 0`` and
      ``n_batch + n_dense <= len(shape)``.
    unique_indices : boolean specifying whether indices should be unique
      (default: True).
    sorted_indices : boolean specifying whether indices should be row-sorted in
      lexicographical order (default: False).
    generator : function for generating random values accepting a key, shape,
      and dtype. It defaults to :func:`jax.random.uniform`, and may be any
      function with a similar signature.
    **kwds : additional keyword arguments to pass to ``generator``.

  Returns:
    arr : a sparse.BCOO array with the specified properties.

  Raises:
    ValueError: if ``n_batch``/``n_dense`` are inconsistent with ``shape``,
      if ``nse`` is out of range, or if ``indices_dtype`` cannot represent
      every flat sparse index.
  """
  shape = tuple(map(operator.index, shape))
  n_batch = operator.index(n_batch)
  n_dense = operator.index(n_dense)
  if n_batch < 0 or n_dense < 0 or n_batch + n_dense > len(shape):
    raise ValueError(f"Invalid {n_batch=}, {n_dense=} for {shape=}")
  n_sparse = len(shape) - n_batch - n_dense
  # Layout of a BCOO shape: (*batch_dims, *sparse_dims, *dense_dims).
  batch_shape = shape[:n_batch]
  sparse_shape = shape[n_batch:n_batch + n_sparse]
  dense_shape = shape[n_batch + n_sparse:]
  batch_size = math.prod(batch_shape)
  sparse_size = math.prod(sparse_shape)

  # Resolve a fractional nse before validating it, so that the default
  # nse=0.2 is accepted for a zero-size sparse shape (it resolves to 0).
  if 0 < nse < 1:
    nse = int(math.ceil(nse * sparse_size))
  # Sampling without replacement may select every element, so
  # nse == sparse_size is valid. Sampling with replacement has no upper
  # bound, except that nothing can be drawn from an empty population.
  max_nse = sparse_size if unique_indices or sparse_size == 0 else math.inf
  if not 0 <= nse <= max_nse:
    raise ValueError(f"got {nse=}, expected to be between 0 and {sparse_size}")
  nse = operator.index(nse)

  data_shape = batch_shape + (nse,) + dense_shape
  indices_shape = batch_shape + (nse, n_sparse)
  if indices_dtype is None:
    indices_dtype = dtypes.canonicalize_dtype(jnp.int_)
  # Flat indices lie in [0, sparse_size - 1]; only that largest value must fit.
  if sparse_size - 1 > jnp.iinfo(indices_dtype).max:
    raise ValueError(f"{indices_dtype=} does not have enough range to generate "
                     f"sparse indices of size {sparse_size}.")

  @vmap
  def _indices(subkey):
    # Nothing to sample: either there are no sparse dimensions, or zero
    # elements were requested. The nse == 0 case also avoids calling
    # random.choice on an empty population.
    if not sparse_shape or nse == 0:
      return jnp.empty((nse, n_sparse), dtype=indices_dtype)
    flat_ind = random.choice(subkey, sparse_size, shape=(nse,),
                             replace=not unique_indices).astype(indices_dtype)
    # Convert flat positions into one column per sparse dimension.
    return jnp.column_stack(jnp.unravel_index(flat_ind, sparse_shape))

  # One key for the data plus one independent key per batch element.
  keys = random.split(key, batch_size + 1)
  data_key, index_keys = keys[0], keys[1:]
  data = generator(data_key, shape=data_shape, dtype=dtype, **kwds)
  indices = _indices(index_keys).reshape(indices_shape)
  mat = sparse.BCOO((data, indices), shape=shape)
  return mat.sort_indices() if sorted_indices else mat