2023-06-19 00:49:18 +02:00
# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Helpers for indexed updates.
import sys
from typing import Any, Callable, Optional, Sequence, Tuple, Union
import warnings
import numpy as np
from jax import lax
from jax._src import core
from jax._src import dtypes
from jax._src import util
from jax._src.lax import lax as lax_internal
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import reductions
from jax._src.numpy.util import check_arraylike, promote_dtypes
# Loose alias for array-like values accepted by the helpers in this module.
Array = Any

# A single index component, as accepted by `x[idx]`-style indexing.
# `types.EllipsisType` only exists on Python >= 3.10, so older interpreters get
# the same union without it.
if sys.version_info >= (3, 10):
  from types import EllipsisType
  SingleIndex = Union[None, int, slice, Sequence[int], Array, EllipsisType]
else:
  SingleIndex = Union[None, int, slice, Sequence[int], Array]
# A full index: either one component or a tuple of components.
Index = Union[SingleIndex, Tuple[SingleIndex, ...]]
# Scalar values that may be used as update values.
Scalar = Union[complex, float, int, np.number]
def _scatter_update(x, idx, y, scatter_op, indices_are_sorted,
                    unique_indices, mode=None, normalize_indices=True):
  """Helper for indexed updates.

  Computes the value of x that would result from computing::
    x[idx] op= y
  except in a pure functional way, with no in-place updating.

  Args:
    x: ndarray to be updated.
    idx: None, an integer, a slice, an ellipsis, an ndarray with integer dtype,
      or a tuple of those indicating the locations of `x` into which to scatter-
      update the values in `y`.
    y: values to be scattered.
    scatter_op: callable, one of lax.scatter, lax.scatter_add, lax.scatter_mul,
      lax.scatter_min, or lax.scatter_max.
    indices_are_sorted: whether `idx` is known to be sorted
    unique_indices: whether `idx` is known to be free of duplicates
    mode: optional ``lax.GatherScatterMode`` describing how out-of-bounds
      indices are handled; passed through unchanged to ``_scatter_impl``.
    normalize_indices: whether negative indices in `idx` are normalized
      (wrapped relative to the dimension size) before scattering; passed
      through unchanged to ``_scatter_impl``.

  Returns:
    An ndarray representing an updated `x` after performing the scatter-update.
  """
  x = jnp.asarray(x)
  # A Python int that fits in x's integer dtype is converted directly to that
  # dtype rather than to a default integer dtype.
  if (isinstance(y, int) and np.issubdtype(x.dtype, np.integer) and
      np.iinfo(x.dtype).min <= y <= np.iinfo(x.dtype).max):
    y = jnp.asarray(y, dtype=x.dtype)
  y = jnp.asarray(y)
  # XLA gathers and scatters are very similar in structure; the scatter logic
  # is more or less a transpose of the gather equivalent.
  treedef, static_idx, dynamic_idx = jnp._split_index_for_jit(idx, x.shape)
  return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                       indices_are_sorted, unique_indices, mode,
                       normalize_indices)
# TODO(phawkins): re-enable jit after fixing excessive recompilation for
# slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.).
# @partial(jit, static_argnums=(2, 3, 4))
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, mode,
                  normalize_indices):
  """Scatters `y` into `x` at an index that has been split for JIT.

  Args:
    x: array to be updated.
    y: array of update values.
    scatter_op: the ``lax.scatter*`` function to apply.
    treedef, static_idx, dynamic_idx: pieces of the original index as produced
      by ``jnp._split_index_for_jit``; they are recombined here with
      ``jnp._merge_static_and_dynamic_indices``.
    indices_are_sorted: caller's hint that the indices are sorted; OR-ed with
      the indexer's own sortedness information.
    unique_indices: caller's hint that the indices are unique; OR-ed with the
      indexer's own uniqueness information.
    mode: out-of-bounds handling, forwarded to `scatter_op`.
    normalize_indices: forwarded to ``jnp._index_to_gather``.

  Returns:
    The updated array, converted back to the dtype and weak type of `x`.
  """
  dtype = lax.dtype(x)
  weak_type = dtypes.is_weakly_typed(x)

  # The result is converted back to x's dtype at the end, so updates that do
  # not promote to that dtype would be cast down; warn about it.
  if dtype != lax.dtype(y) and dtype != dtypes.result_type(x, y):
    # TODO(jakevdp): change this to an error after the deprecation period.
    warnings.warn("scatter inputs have incompatible types: cannot safely cast "
                  f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)}. "
                  "In future JAX releases this will result in an error.",
                  FutureWarning)

  idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
  indexer = jnp._index_to_gather(jnp.shape(x), idx,
                                 normalize_indices=normalize_indices)

  # Avoid calling scatter if the slice shape is empty, both as a fast path and
  # to handle cases like zeros(0)[array([], int32)].
  if core.is_empty_shape(indexer.slice_shape):
    return x

  x, y = promote_dtypes(x, y)

  # Broadcast `y` to the slice output shape.
  y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
  # Collapse any `None`/`jnp.newaxis` dimensions.
  y = jnp.squeeze(y, axis=indexer.newaxis_dims)
  if indexer.reversed_y_dims:
    y = lax.rev(y, indexer.reversed_y_dims)

  # Transpose the gather dimensions into scatter dimensions (cf.
  # lax._gather_transpose_rule)
  dnums = lax.ScatterDimensionNumbers(
      update_window_dims=indexer.dnums.offset_dims,
      inserted_window_dims=indexer.dnums.collapsed_slice_dims,
      scatter_dims_to_operand_dims=indexer.dnums.start_index_map,
  )

  out = scatter_op(
      x, indexer.gather_indices, y, dnums,
      indices_are_sorted=indexer.indices_are_sorted or indices_are_sorted,
      unique_indices=indexer.unique_indices or unique_indices,
      mode=mode)
  return lax_internal._convert_element_type(out, dtype, weak_type)
def _get_identity(op, dtype):
"""Get an appropriate identity for a given operation in a given dtype."""
if op is lax.scatter_add:
return 0
elif op is lax.scatter_mul:
return 1
elif op is lax.scatter_min:
if dtype == dtypes.bool_:
return True
elif jnp.issubdtype(dtype, jnp.integer):
return jnp.iinfo(dtype).max
return float('inf')
elif op is lax.scatter_max:
if dtype == dtypes.bool_:
return False
elif jnp.issubdtype(dtype, jnp.integer):
return jnp.iinfo(dtype).min
return -float('inf')
raise ValueError(f"Unrecognized op: {op}")
def _segment_update(name: str,
                    data: Array,
                    segment_ids: Array,
                    scatter_op: Callable,
                    num_segments: Optional[int] = None,
                    indices_are_sorted: bool = False,
                    unique_indices: bool = False,
                    bucket_size: Optional[int] = None,
                    reducer: Optional[Callable] = None,
                    mode: Optional[lax.GatherScatterMode] = None) -> Array:
  """Shared implementation of the ``segment_*`` reductions.

  Row ``i`` of `data` is combined with `scatter_op` into output row
  ``segment_ids[i]``. The output is pre-filled with the identity of
  `scatter_op` (see ``_get_identity``), so segments that receive no rows keep
  that identity value.

  Args:
    name: name of the public function, used in error messages.
    data: array whose leading axis is reduced.
    segment_ids: integer array with one segment id per row of `data`.
    scatter_op: one of lax.scatter_add, lax.scatter_mul, lax.scatter_min or
      lax.scatter_max.
    num_segments: number of output segments. Defaults to
      ``max(segment_ids) + 1`` (0 when `segment_ids` is empty). Must be
      concrete, since it determines the output shape.
    indices_are_sorted: whether `segment_ids` is known to be sorted.
    unique_indices: whether `segment_ids` is known to be free of duplicates.
    bucket_size: if given, consecutive groups of `bucket_size` rows are reduced
      separately and the per-bucket results are combined with `reducer`.
    reducer: reduction applied along axis 0 to combine bucket results; required
      when `bucket_size` is given.
    mode: out-of-bounds handling; defaults to
      ``lax.GatherScatterMode.FILL_OR_DROP``.

  Returns:
    An array of shape ``(num_segments,) + data.shape[1:]`` and dtype
    ``data.dtype``.

  Raises:
    ValueError: if `num_segments` is negative, if `bucket_size` is not
      positive, or if `bucket_size` is given without a `reducer`.
  """
  check_arraylike(name, data, segment_ids)
  mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode
  data = jnp.asarray(data)
  segment_ids = jnp.asarray(segment_ids)
  dtype = data.dtype
  if num_segments is None:
    # np.max fails on zero-size input; an empty id list needs zero segments.
    num_segments = np.max(segment_ids) + 1 if segment_ids.size else 0
  num_segments = core.concrete_or_error(
      int, num_segments, f"{name}() `num_segments` argument.")
  if num_segments < 0:
    raise ValueError("num_segments must be non-negative.")

  if bucket_size is None:
    out = jnp.full((num_segments,) + data.shape[1:],
                   _get_identity(scatter_op, dtype), dtype=dtype)
    # normalize_indices=False: negative ids are out of bounds (dropped under
    # the default mode) rather than wrapped around.
    return _scatter_update(
        out, segment_ids, data, scatter_op, indices_are_sorted,
        unique_indices, normalize_indices=False, mode=mode)

  # Bucketize indices and perform segment_update on each bucket to improve
  # numerical stability for operations like product and sum.
  if bucket_size <= 0:
    raise ValueError(f"{name}(): bucket_size must be positive, "
                     f"got {bucket_size}.")
  if reducer is None:
    raise ValueError(f"{name}(): a reducer is required when bucket_size "
                     "is given.")
  num_buckets = util.ceil_of_ratio(segment_ids.size, bucket_size)
  out = jnp.full((num_buckets, num_segments) + data.shape[1:],
                 _get_identity(scatter_op, dtype), dtype=dtype)
  # Row i of `data` is scattered to position (i // bucket_size, segment_ids[i]).
  out = _scatter_update(
      out, np.index_exp[jnp.arange(segment_ids.shape[0]) // bucket_size,
                        segment_ids[None, :]],
      data, scatter_op, indices_are_sorted,
      unique_indices, normalize_indices=False, mode=mode)
  return reducer(out, axis=0).astype(dtype)
def segment_sum(data: Array,
                segment_ids: Array,
                num_segments: Optional[int] = None,
                indices_are_sorted: bool = False,
                unique_indices: bool = False,
                bucket_size: Optional[int] = None,
                mode: Optional[lax.GatherScatterMode] = None) -> Array:
  """Computes the sum within segments of an array.

  Similar to TensorFlow's `segment_sum
  <https://www.tensorflow.org/api_docs/python/tf/math/segment_sum>`_

  Args:
    data: an array with the values to be summed.
    segment_ids: an array with integer dtype that indicates the segments of
      `data` (along its leading axis) to be summed. Values can be repeated and
      need not be sorted.
    num_segments: optional, an int with nonnegative value indicating the number
      of segments. The default is set to be the minimum number of segments that
      would support all indices in ``segment_ids``, calculated as
      ``max(segment_ids) + 1``.
      Since `num_segments` determines the size of the output, a static value
      must be provided to use ``segment_sum`` in a JIT-compiled function.
    indices_are_sorted: whether ``segment_ids`` is known to be sorted.
    unique_indices: whether `segment_ids` is known to be free of duplicates.
    bucket_size: size of bucket to group indices into. ``segment_sum`` is
      performed on each bucket separately to improve numerical stability of
      addition. Default ``None`` means no bucketing.
    mode: a :class:`jax.lax.GatherScatterMode` value describing how
      out-of-bounds indices should be handled. By default, values outside of the
      range [0, num_segments) are dropped and do not contribute to the sum.

  Returns:
    An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
    segment sums. Segments that receive no values are 0.

  Examples:
    Simple 1D segment sum:

    >>> data = jnp.arange(5)
    >>> segment_ids = jnp.array([0, 0, 1, 1, 2])
    >>> segment_sum(data, segment_ids)
    Array([1, 5, 4], dtype=int32)

    Using JIT requires static `num_segments`:

    >>> from jax import jit
    >>> jit(segment_sum, static_argnums=2)(data, segment_ids, 3)
    Array([1, 5, 4], dtype=int32)
  """
  return _segment_update(
      "segment_sum", data, segment_ids, lax.scatter_add, num_segments,
      indices_are_sorted, unique_indices, bucket_size, reductions.sum, mode=mode)
def segment_prod(data: Array,
                 segment_ids: Array,
                 num_segments: Optional[int] = None,
                 indices_are_sorted: bool = False,
                 unique_indices: bool = False,
                 bucket_size: Optional[int] = None,
                 mode: Optional[lax.GatherScatterMode] = None) -> Array:
  """Computes the product within segments of an array.

  Similar to TensorFlow's `segment_prod
  <https://www.tensorflow.org/api_docs/python/tf/math/segment_prod>`_

  Args:
    data: an array with the values to be reduced.
    segment_ids: an array with integer dtype that indicates the segments of
      `data` (along its leading axis) to be reduced. Values can be repeated and
      need not be sorted. Values outside of the range [0, num_segments) are
      dropped and do not contribute to the result.
    num_segments: optional, an int with nonnegative value indicating the number
      of segments. The default is set to be the minimum number of segments that
      would support all indices in ``segment_ids``, calculated as
      ``max(segment_ids) + 1``.
      Since `num_segments` determines the size of the output, a static value
      must be provided to use ``segment_prod`` in a JIT-compiled function.
    indices_are_sorted: whether ``segment_ids`` is known to be sorted.
    unique_indices: whether `segment_ids` is known to be free of duplicates.
    bucket_size: size of bucket to group indices into. ``segment_prod`` is
      performed on each bucket separately to improve numerical stability of
      multiplication. Default ``None`` means no bucketing.
    mode: a :class:`jax.lax.GatherScatterMode` value describing how
      out-of-bounds indices should be handled. By default, values outside of the
      range [0, num_segments) are dropped and do not contribute to the product.

  Returns:
    An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
    segment products. Segments that receive no values are 1.

  Examples:
    Simple 1D segment product:

    >>> data = jnp.arange(6)
    >>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
    >>> segment_prod(data, segment_ids)
    Array([ 0,  6, 20], dtype=int32)

    Using JIT requires static `num_segments`:

    >>> from jax import jit
    >>> jit(segment_prod, static_argnums=2)(data, segment_ids, 3)
    Array([ 0,  6, 20], dtype=int32)
  """
  # reductions.prod combines the per-bucket partial products when bucket_size
  # is given.
  return _segment_update(
      "segment_prod", data, segment_ids, lax.scatter_mul, num_segments,
      indices_are_sorted, unique_indices, bucket_size, reductions.prod,
      mode=mode)
def segment_max(data: Array,
                segment_ids: Array,
                num_segments: Optional[int] = None,
                indices_are_sorted: bool = False,
                unique_indices: bool = False,
                bucket_size: Optional[int] = None,
                mode: Optional[lax.GatherScatterMode] = None) -> Array:
  """Computes the maximum within segments of an array.

  Similar to TensorFlow's `segment_max
  <https://www.tensorflow.org/api_docs/python/tf/math/segment_max>`_

  Args:
    data: an array with the values to be reduced.
    segment_ids: an array with integer dtype that indicates the segments of
      `data` (along its leading axis) to be reduced. Values can be repeated and
      need not be sorted. Values outside of the range [0, num_segments) are
      dropped and do not contribute to the result.
    num_segments: optional, an int with nonnegative value indicating the number
      of segments. The default is set to be the minimum number of segments that
      would support all indices in ``segment_ids``, calculated as
      ``max(segment_ids) + 1``.
      Since `num_segments` determines the size of the output, a static value
      must be provided to use ``segment_max`` in a JIT-compiled function.
    indices_are_sorted: whether ``segment_ids`` is known to be sorted.
    unique_indices: whether `segment_ids` is known to be free of duplicates.
    bucket_size: size of bucket to group indices into. ``segment_max`` is
      performed on each bucket separately. Default ``None`` means no bucketing.
    mode: a :class:`jax.lax.GatherScatterMode` value describing how
      out-of-bounds indices should be handled. By default, values outside of the
      range [0, num_segments) are dropped and do not contribute to the result.

  Returns:
    An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
    segment maximums. Segments that receive no values hold the smallest value
    of the dtype (``-inf`` for floating point, ``False`` for booleans).

  Examples:
    Simple 1D segment max:

    >>> data = jnp.arange(6)
    >>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
    >>> segment_max(data, segment_ids)
    Array([1, 3, 5], dtype=int32)

    Using JIT requires static `num_segments`:

    >>> from jax import jit
    >>> jit(segment_max, static_argnums=2)(data, segment_ids, 3)
    Array([1, 3, 5], dtype=int32)
  """
  return _segment_update(
      "segment_max", data, segment_ids, lax.scatter_max, num_segments,
      indices_are_sorted, unique_indices, bucket_size, reductions.max, mode=mode)
def segment_min(data: Array,
                segment_ids: Array,
                num_segments: Optional[int] = None,
                indices_are_sorted: bool = False,
                unique_indices: bool = False,
                bucket_size: Optional[int] = None,
                mode: Optional[lax.GatherScatterMode] = None) -> Array:
  """Computes the minimum within segments of an array.

  Similar to TensorFlow's `segment_min
  <https://www.tensorflow.org/api_docs/python/tf/math/segment_min>`_

  Args:
    data: an array with the values to be reduced.
    segment_ids: an array with integer dtype that indicates the segments of
      `data` (along its leading axis) to be reduced. Values can be repeated and
      need not be sorted. Values outside of the range [0, num_segments) are
      dropped and do not contribute to the result.
    num_segments: optional, an int with nonnegative value indicating the number
      of segments. The default is set to be the minimum number of segments that
      would support all indices in ``segment_ids``, calculated as
      ``max(segment_ids) + 1``.
      Since `num_segments` determines the size of the output, a static value
      must be provided to use ``segment_min`` in a JIT-compiled function.
    indices_are_sorted: whether ``segment_ids`` is known to be sorted.
    unique_indices: whether `segment_ids` is known to be free of duplicates.
    bucket_size: size of bucket to group indices into. ``segment_min`` is
      performed on each bucket separately. Default ``None`` means no bucketing.
    mode: a :class:`jax.lax.GatherScatterMode` value describing how
      out-of-bounds indices should be handled. By default, values outside of the
      range [0, num_segments) are dropped and do not contribute to the result.

  Returns:
    An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
    segment minimums. Segments that receive no values hold the largest value
    of the dtype (``inf`` for floating point, ``True`` for booleans).

  Examples:
    Simple 1D segment min:

    >>> data = jnp.arange(6)
    >>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
    >>> segment_min(data, segment_ids)
    Array([0, 2, 4], dtype=int32)

    Using JIT requires static `num_segments`:

    >>> from jax import jit
    >>> jit(segment_min, static_argnums=2)(data, segment_ids, 3)
    Array([0, 2, 4], dtype=int32)
  """
  return _segment_update(
      "segment_min", data, segment_ids, lax.scatter_min, num_segments,
      indices_are_sorted, unique_indices, bucket_size, reductions.min, mode=mode)