2023-06-19 00:49:18 +02:00

2178 lines
96 KiB

# Copyright 2018 The JAX Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
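# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and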
# limitations under the License.
import enum
from functools import partial
import math
from typing import Callable, List, NamedTuple, Optional, Sequence, Tuple, Union
import weakref

import numpy as np

import jax
from jax._src import ad_util
from jax._src import core
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax
from jax._src.lax.utils import (
    _argnum_weak_type,
    _input_dtype,
    standard_primitive,
)
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.typing import Array, ArrayLike, Shape
from jax._src.util import safe_map, safe_zip
# Rebind `map`/`zip` to the safe_* variants from jax._src.util (presumably
# these check that all inputs have equal length -- confirm in util), while
# keeping the builtins reachable as `unsafe_map` / `unsafe_zip`.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
# `dtypes.dtype` with canonicalize=True; used below (e.g. by `gather`) to get
# the operand's dtype when choosing a default fill value.
_dtype = partial(dtypes.dtype, canonicalize=True)
def slice(operand: ArrayLike, start_indices: Sequence[int],
          limit_indices: Sequence[int],
          strides: Optional[Sequence[int]] = None) -> Array:
  """Wraps XLA's `Slice
  <https://www.tensorflow.org/xla/operation_semantics#slice>`_
  operator.

  Args:
    operand: an array to slice.
    start_indices: one start index per dimension of ``operand``.
    limit_indices: one (exclusive) end index per dimension of ``operand``.
    strides: optional per-dimension strides; ``None`` means a stride of 1
      along every dimension.

  Returns:
    The sliced array.
  """
  # Normalize every index sequence to a tuple before binding the primitive;
  # all three parameters must be forwarded to the shape rule.
  return slice_p.bind(operand, start_indices=tuple(start_indices),
                      limit_indices=tuple(limit_indices),
                      strides=None if strides is None else tuple(strides))
def dynamic_slice(
    operand: Union[Array, np.ndarray],
    start_indices: Union[Union[Array, np.ndarray], Sequence[ArrayLike]],
    slice_sizes: Shape,
) -> Array:
  """Wraps XLA's `DynamicSlice
  <https://www.tensorflow.org/xla/operation_semantics#dynamicslice>`_
  operator.

  Args:
    operand: an array to slice.
    start_indices: a list of scalar indices, one per dimension. These values
      may be dynamic.
    slice_sizes: the size of the slice. Must be a sequence of non-negative
      integers with length equal to `ndim(operand)`. Inside a JIT compiled
      function, only static values are supported (all JAX arrays inside JIT
      must have statically known size).

  Returns:
    An array containing the slice.

  Examples:
    Here is a simple two-dimensional dynamic slice:

    >>> x = jnp.arange(12).reshape(3, 4)
    >>> x
    Array([[ 0,  1,  2,  3],
           [ 4,  5,  6,  7],
           [ 8,  9, 10, 11]], dtype=int32)

    >>> dynamic_slice(x, (1, 1), (2, 3))
    Array([[ 5,  6,  7],
           [ 9, 10, 11]], dtype=int32)

    Note the potentially surprising behavior for the case where the requested
    slice overruns the bounds of the array; in this case the start index is
    adjusted to return a slice of the requested size:

    >>> dynamic_slice(x, (1, 1), (2, 4))
    Array([[ 4,  5,  6,  7],
           [ 8,  9, 10, 11]], dtype=int32)
  """
  start_indices = _dynamic_slice_indices(operand, start_indices)
  if jax.config.jax_dynamic_shapes:
    # Traced (dynamic) sizes are passed as positional operands; the static
    # remainder goes into the `slice_sizes` parameter.
    dynamic_sizes, static_sizes = lax._extract_tracers_dyn_shape(slice_sizes)
  else:
    # Without dynamic shapes every size must be static.
    dynamic_sizes = []
    static_sizes = core.canonicalize_shape(slice_sizes)  # type: ignore
  return dynamic_slice_p.bind(operand, *start_indices, *dynamic_sizes,
                              slice_sizes=tuple(static_sizes))
def dynamic_update_slice(operand: Union[Array, np.ndarray], update: ArrayLike,
                         start_indices: Union[Array, Sequence[ArrayLike]]) -> Array:
  """Wraps XLA's `DynamicUpdateSlice
  <https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice>`_
  operator.

  Args:
    operand: an array to slice.
    update: an array containing the new values to write onto `operand`.
    start_indices: a list of scalar indices, one per dimension.

  Returns:
    An array containing the slice.

  Examples:
    Here is an example of updating a one-dimensional slice update:

    >>> x = jnp.zeros(6)
    >>> y = jnp.ones(3)
    >>> dynamic_update_slice(x, y, (2,))
    Array([0., 0., 1., 1., 1., 0.], dtype=float32)

    If the update slice is too large to fit in the array, the start
    index will be adjusted to make it fit:

    >>> dynamic_update_slice(x, y, (3,))
    Array([0., 0., 0., 1., 1., 1.], dtype=float32)
    >>> dynamic_update_slice(x, y, (5,))
    Array([0., 0., 0., 1., 1., 1.], dtype=float32)

    Here is an example of a two-dimensional slice update:

    >>> x = jnp.zeros((4, 4))
    >>> y = jnp.ones((2, 2))
    >>> dynamic_update_slice(x, y, (1, 2))
    Array([[0., 0., 0., 0.],
           [0., 0., 1., 1.],
           [0., 0., 1., 1.],
           [0., 0., 0., 0.]], dtype=float32)
  """
  # Normalize the start indices (same helper as `dynamic_slice`) before
  # binding the primitive; each index becomes a separate positional operand.
  start_indices = _dynamic_slice_indices(operand, start_indices)
  return dynamic_update_slice_p.bind(operand, update, *start_indices)
class GatherDimensionNumbers(NamedTuple):
  """
  Describes the dimension number arguments to an `XLA's Gather operator
  <https://www.tensorflow.org/xla/operation_semantics#gather>`_. See the XLA
  documentation for more details of what the dimension numbers mean.

  Args:
    offset_dims: the set of dimensions in the `gather` output that offset into
      an array sliced from `operand`. Must be a tuple of integers in ascending
      order, each representing a dimension number of the output.
    collapsed_slice_dims: the set of dimensions `i` in `operand` that have
      `slice_sizes[i] == 1` and that should not have a corresponding dimension
      in the output of the gather. Must be a tuple of integers in ascending
      order.
    start_index_map: for each dimension in `start_indices`, gives the
      corresponding dimension in `operand` that is to be sliced. Must be a
      tuple of integers with size equal to `start_indices.shape[-1]`.

  Unlike XLA's `GatherDimensionNumbers` structure, `index_vector_dim` is
  implicit; there is always an index vector dimension and it must always be the
  last dimension. To gather scalar indices, add a trailing dimension of size 1.
  """
  offset_dims: Tuple[int, ...]
  collapsed_slice_dims: Tuple[int, ...]
  start_index_map: Tuple[int, ...]
class GatherScatterMode(enum.Enum):
  """
  Describes how to handle out-of-bounds indices in a gather or scatter.

  Possible values are:

  CLIP:
    Indices will be clamped to the nearest in-range value, i.e., such that the
    entire window to be gathered is in-range.
  FILL_OR_DROP:
    If any part of a gathered window is out of bounds, the entire window
    that is returned, even those elements that were otherwise in-bounds, will be
    filled with a constant.
    If any part of a scattered window is out of bounds, the entire window
    will be discarded.
  PROMISE_IN_BOUNDS:
    The user promises that indices are in bounds. No additional checking will be
    performed. In practice, with the current XLA implementation this means
    that out-of-bounds gathers will be clamped but out-of-bounds scatters will
    be discarded. Gradients will not be correct if indices are out-of-bounds.
  """
  CLIP = enum.auto()
  FILL_OR_DROP = enum.auto()
  PROMISE_IN_BOUNDS = enum.auto()

  @staticmethod
  def from_any(s: Optional[Union[str, 'GatherScatterMode']]):
    """Parses a user-facing mode spec into a `GatherScatterMode`.

    Args:
      s: an existing `GatherScatterMode` (returned unchanged), one of the
        strings ``'clip'``, ``'fill'``, ``'drop'``, ``'promise_in_bounds'``,
        or ``None`` (treated the same as ``'fill'``/``'drop'``).

    Returns:
      The corresponding `GatherScatterMode` member.

    Raises:
      ValueError: if ``s`` is not a recognized mode.
    """
    if isinstance(s, GatherScatterMode):
      return s
    if s == "clip":
      return GatherScatterMode.CLIP
    if s is None or s == "fill" or s == "drop":
      return GatherScatterMode.FILL_OR_DROP
    if s == "promise_in_bounds":
      return GatherScatterMode.PROMISE_IN_BOUNDS
    raise ValueError(f'Unknown gather mode "{s}"')
def gather(operand: ArrayLike, start_indices: ArrayLike,
           dimension_numbers: GatherDimensionNumbers,
           slice_sizes: Shape,
           unique_indices: bool = False,
           indices_are_sorted: bool = False,
           mode: Optional[Union[str, GatherScatterMode]] = None,
           fill_value = None) -> Array:
  """Gather operator.

  Wraps `XLA's Gather operator
  <https://www.tensorflow.org/xla/operation_semantics#gather>`_.

  The semantics of gather are complicated, and its API might change in the
  future. For most use cases, you should prefer `Numpy-style indexing
  <https://numpy.org/doc/stable/reference/arrays.indexing.html>`_
  (e.g., `x[:, (1,4,7), ...]`), rather than using `gather` directly.

  Args:
    operand: an array from which slices should be taken
    start_indices: the indices at which slices should be taken
    dimension_numbers: a `lax.GatherDimensionNumbers` object that describes
      how dimensions of `operand`, `start_indices` and the output relate.
    slice_sizes: the size of each slice. Must be a sequence of non-negative
      integers with length equal to `ndim(operand)`.
    indices_are_sorted: whether `indices` is known to be sorted. If
      true, may improve performance on some backends.
    unique_indices: whether the elements gathered from ``operand`` are
      guaranteed not to overlap with each other. If ``True``, this may improve
      performance on some backends. JAX does not check this promise: if
      the elements overlap the behavior is undefined.
    mode: how to handle indices that are out of bounds: when set to ``'clip'``,
      indices are clamped so that the slice is within bounds, and when
      set to ``'fill'`` or ``'drop'`` gather returns a slice full of
      ``fill_value`` for the affected slice. The behavior for out-of-bounds
      indices when set to ``'promise_in_bounds'`` is implementation-defined.
      Defaults to ``'promise_in_bounds'`` when ``None``.
    fill_value: the fill value to return for out-of-bounds slices when `mode`
      is ``'fill'``. Ignored otherwise. Defaults to ``NaN`` for inexact types,
      the minimum (most negative) value for signed integer types, the maximum
      value for unsigned integer types, and ``True`` for booleans.

  Returns:
    An array containing the gather output.

  Raises:
    ValueError: if ``mode`` is unrecognized, or if ``mode`` is fill/drop,
      ``fill_value`` is ``None`` and the operand dtype has no default fill.
  """
  # Unlike the scatter wrappers (where None parses to FILL_OR_DROP), gather
  # treats an omitted mode as PROMISE_IN_BOUNDS.
  if mode is None:
    mode = GatherScatterMode.PROMISE_IN_BOUNDS
  parsed_mode = GatherScatterMode.from_any(mode)
  if parsed_mode == GatherScatterMode.FILL_OR_DROP:
    if fill_value is None:
      dtype = _dtype(operand)
      if dtypes.issubdtype(dtype, np.inexact):
        fill_value = np.nan
      elif dtypes.issubdtype(dtype, np.signedinteger):
        fill_value = dtypes.iinfo(dtype).min
      elif dtypes.issubdtype(dtype, np.unsignedinteger):
        fill_value = dtypes.iinfo(dtype).max
      elif dtype == dtypes.bool_:
        fill_value = True
      else:
        raise ValueError(f"Unsupported dtype for gather fill_value {dtype}")
  else:
    # fill_value is meaningless outside FILL_OR_DROP; normalize it away.
    fill_value = None
  return gather_p.bind(
      operand, start_indices, dimension_numbers=dimension_numbers,
      slice_sizes=core.canonicalize_shape(slice_sizes),
      unique_indices=bool(unique_indices),
      indices_are_sorted=bool(indices_are_sorted),
      mode=parsed_mode,
      fill_value=fill_value)
class ScatterDimensionNumbers(NamedTuple):
  """
  Describes the dimension number arguments to an `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_. See the XLA
  documentation for more details of what the dimension numbers mean.

  Args:
    update_window_dims: the set of dimensions in the `updates` that are window
      dimensions. Must be a tuple of integers in ascending
      order, each representing a dimension number.
    inserted_window_dims: the set of size 1 window dimensions that must be
      inserted into the shape of `updates`. Must be a tuple of integers in
      ascending order, each representing a dimension number of the output. These
      are the mirror image of `collapsed_slice_dims` in the case of `gather`.
    scatter_dims_to_operand_dims: for each dimension in `scatter_indices`, gives
      the corresponding dimension in `operand`. Must be a sequence of integers
      with size equal to `scatter_indices.shape[-1]`.

  Unlike XLA's `ScatterDimensionNumbers` structure, `index_vector_dim` is
  implicit; there is always an index vector dimension and it must always be the
  last dimension. To scatter scalar indices, add a trailing dimension of size 1.
  """
  update_window_dims: Sequence[int]
  inserted_window_dims: Sequence[int]
  scatter_dims_to_operand_dims: Sequence[int]
def scatter_add(
    operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
    dimension_numbers: ScatterDimensionNumbers, *,
    indices_are_sorted: bool = False, unique_indices: bool = False,
    mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
  """Scatter-add operator.

  Wraps `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
  addition is used to combine updates and values from `operand`.

  The semantics of scatter are complicated, and its API might change in the
  future. For most use cases, you should prefer the
  :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses
  the familiar NumPy indexing syntax.

  Args:
    operand: an array to which the scatter should be applied
    scatter_indices: an array that gives the indices in `operand` to which each
      update in `updates` should be applied.
    updates: the updates that should be scattered onto `operand`.
    dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
      how dimensions of `operand`, `scatter_indices`, `updates` and the output
      relate.
    indices_are_sorted: whether `scatter_indices` is known to be sorted. If
      true, may improve performance on some backends.
    unique_indices: whether the elements to be updated in ``operand`` are
      guaranteed to not overlap with each other. If true, may improve performance on
      some backends. JAX does not check this promise: if the updated elements
      overlap when ``unique_indices`` is ``True`` the behavior is undefined.
    mode: how to handle indices that are out of bounds: when set to 'clip',
      indices are clamped so that the slice is within bounds, and when
      set to 'fill' or 'drop' out-of-bounds updates are dropped. The behavior
      for out-of-bounds indices when set to 'promise_in_bounds' is
      implementation-defined. ``None`` is treated like 'fill'/'drop'.

  Returns:
    An array containing the sum of `operand` and the scattered updates.
  """
  # Trace the combiner (`lax.add`) once; the constant only supplies an
  # abstract value (presumably a scalar of the operand's dtype) to trace with.
  jaxpr, consts = lax._reduction_jaxpr(lax.add,
                                       lax._abstractify(lax._const(operand, 0)))
  return scatter_add_p.bind(
      operand, scatter_indices, updates, update_jaxpr=jaxpr,
      update_consts=consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=GatherScatterMode.from_any(mode))
def scatter_mul(
    operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
    dimension_numbers: ScatterDimensionNumbers, *,
    indices_are_sorted: bool = False, unique_indices: bool = False,
    mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
  """Scatter-multiply operator.

  Wraps `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
  multiplication is used to combine updates and values from `operand`.

  The semantics of scatter are complicated, and its API might change in the
  future. For most use cases, you should prefer the
  :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses
  the familiar NumPy indexing syntax.

  Args:
    operand: an array to which the scatter should be applied
    scatter_indices: an array that gives the indices in `operand` to which each
      update in `updates` should be applied.
    updates: the updates that should be scattered onto `operand`.
    dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
      how dimensions of `operand`, `scatter_indices`, `updates` and the output
      relate.
    indices_are_sorted: whether `scatter_indices` is known to be sorted. If
      true, may improve performance on some backends.
    unique_indices: whether the elements to be updated in ``operand`` are
      guaranteed to not overlap with each other. If true, may improve performance on
      some backends. JAX does not check this promise: if the updated elements
      overlap when ``unique_indices`` is ``True`` the behavior is undefined.
    mode: how to handle indices that are out of bounds: when set to 'clip',
      indices are clamped so that the slice is within bounds, and when
      set to 'fill' or 'drop' out-of-bounds updates are dropped. The behavior
      for out-of-bounds indices when set to 'promise_in_bounds' is
      implementation-defined. ``None`` is treated like 'fill'/'drop'.

  Returns:
    An array containing the product of `operand` and the scattered updates.
  """
  # Trace the combiner (`lax.mul`) once; the constant only supplies an
  # abstract value (presumably a scalar of the operand's dtype) to trace with.
  jaxpr, consts = lax._reduction_jaxpr(lax.mul,
                                       lax._abstractify(lax._const(operand, 1)))
  return scatter_mul_p.bind(
      operand, scatter_indices, updates, update_jaxpr=jaxpr,
      update_consts=consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=GatherScatterMode.from_any(mode))
def scatter_min(
    operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
    dimension_numbers: ScatterDimensionNumbers, *,
    indices_are_sorted: bool = False, unique_indices: bool = False,
    mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
  """Scatter-min operator.

  Wraps `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
  the `min` function is used to combine updates and values from `operand`.

  The semantics of scatter are complicated, and its API might change in the
  future. For most use cases, you should prefer the
  :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses
  the familiar NumPy indexing syntax.

  Args:
    operand: an array to which the scatter should be applied
    scatter_indices: an array that gives the indices in `operand` to which each
      update in `updates` should be applied.
    updates: the updates that should be scattered onto `operand`.
    dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
      how dimensions of `operand`, `scatter_indices`, `updates` and the output
      relate.
    indices_are_sorted: whether `scatter_indices` is known to be sorted. If
      true, may improve performance on some backends.
    unique_indices: whether the elements to be updated in ``operand`` are
      guaranteed to not overlap with each other. If true, may improve performance on
      some backends. JAX does not check this promise: if the updated elements
      overlap when ``unique_indices`` is ``True`` the behavior is undefined.
    mode: how to handle indices that are out of bounds: when set to 'clip',
      indices are clamped so that the slice is within bounds, and when
      set to 'fill' or 'drop' out-of-bounds updates are dropped. The behavior
      for out-of-bounds indices when set to 'promise_in_bounds' is
      implementation-defined. ``None`` is treated like 'fill'/'drop'.

  Returns:
    An array containing the elementwise minimum of `operand` and the
    scattered updates.
  """
  # Trace the combiner (`lax.min`) once; the constant only supplies an
  # abstract value (presumably a scalar of the operand's dtype) to trace with.
  jaxpr, consts = lax._reduction_jaxpr(lax.min,
                                       lax._abstractify(lax._const(operand, 0)))
  return scatter_min_p.bind(
      operand, scatter_indices, updates, update_jaxpr=jaxpr,
      update_consts=consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=GatherScatterMode.from_any(mode))
def scatter_max(
    operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
    dimension_numbers: ScatterDimensionNumbers, *,
    indices_are_sorted: bool = False, unique_indices: bool = False,
    mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
  """Scatter-max operator.

  Wraps `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
  the `max` function is used to combine updates and values from `operand`.

  The semantics of scatter are complicated, and its API might change in the
  future. For most use cases, you should prefer the
  :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses
  the familiar NumPy indexing syntax.

  Args:
    operand: an array to which the scatter should be applied
    scatter_indices: an array that gives the indices in `operand` to which each
      update in `updates` should be applied.
    updates: the updates that should be scattered onto `operand`.
    dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
      how dimensions of `operand`, `scatter_indices`, `updates` and the output
      relate.
    indices_are_sorted: whether `scatter_indices` is known to be sorted. If
      true, may improve performance on some backends.
    unique_indices: whether the elements to be updated in ``operand`` are
      guaranteed to not overlap with each other. If true, may improve performance on
      some backends. JAX does not check this promise: if the updated elements
      overlap when ``unique_indices`` is ``True`` the behavior is undefined.
    mode: how to handle indices that are out of bounds: when set to 'clip',
      indices are clamped so that the slice is within bounds, and when
      set to 'fill' or 'drop' out-of-bounds updates are dropped. The behavior
      for out-of-bounds indices when set to 'promise_in_bounds' is
      implementation-defined. ``None`` is treated like 'fill'/'drop'.

  Returns:
    An array containing the elementwise maximum of `operand` and the
    scattered updates.
  """
  # Trace the combiner (`lax.max`) once; the constant only supplies an
  # abstract value (presumably a scalar of the operand's dtype) to trace with.
  jaxpr, consts = lax._reduction_jaxpr(lax.max,
                                       lax._abstractify(lax._const(operand, 0)))
  return scatter_max_p.bind(
      operand, scatter_indices, updates, update_jaxpr=jaxpr,
      update_consts=consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=GatherScatterMode.from_any(mode))
# To avoid recompilation, we store a dict of weak references to funcs.
# Maps a user-supplied `func` to the `_apply` wrapper that `scatter_apply`
# first built for it (via `setdefault`), so repeated calls with the same func
# reuse the same wrapper object. As a WeakKeyDictionary, an entry does not by
# itself keep its `func` key alive.
_scatter_apply_cache: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
def scatter_apply(
    operand: Array, scatter_indices: Array,
    func: Callable[[Array], Array],
    dimension_numbers: ScatterDimensionNumbers, *,
    update_shape: Shape = (),
    indices_are_sorted: bool = False, unique_indices: bool = False,
    mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
  """Scatter-apply operator.

  Wraps `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where values
  from ``operand`` are replaced with ``func(operand)``, with duplicate indices
  resulting in multiple applications of ``func``.

  The semantics of scatter are complicated, and its API might change in the
  future. For most use cases, you should prefer the
  :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses
  the familiar NumPy indexing syntax.

  Note that in the current implementation, ``scatter_apply`` is not compatible
  with automatic differentiation.

  Args:
    operand: an array to which the scatter should be applied
    scatter_indices: an array that gives the indices in `operand` at which
      `func` should be applied.
    func: unary function that will be applied at each index.
    dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
      how dimensions of `operand`, `scatter_indices`, the (placeholder)
      updates and the output relate.
    update_shape: the shape of the updates at the given indices.
    indices_are_sorted: whether `scatter_indices` is known to be sorted. If
      true, may improve performance on some backends.
    unique_indices: whether the elements to be updated in ``operand`` are
      guaranteed to not overlap with each other. If true, may improve performance on
      some backends. JAX does not check this promise: if the updated elements
      overlap when ``unique_indices`` is ``True`` the behavior is undefined.
    mode: how to handle indices that are out of bounds: when set to 'clip',
      indices are clamped so that the slice is within bounds, and when
      set to 'fill' or 'drop' out-of-bounds updates are dropped. The behavior
      for out-of-bounds indices when set to 'promise_in_bounds' is
      implementation-defined. ``None`` is treated like 'fill'/'drop'.

  Returns:
    An array containing the result of applying `func` to `operand` at the given indices.
  """
  # TODO: can we implement this without a placeholder?
  # Zero-filled placeholder "updates"; the combiner below ignores them.
  unused = lax.full(update_shape, 0, operand.dtype)
  _apply = lambda x, _: func(x)
  try:
    # Reuse the wrapper built for this func earlier so the reduction jaxpr
    # (and hence compilation) can be shared across calls.
    _apply = _scatter_apply_cache.setdefault(func, _apply)
  except TypeError:  # func is not weak referenceable
    pass  # Fall back to the fresh, uncached wrapper.
  jaxpr, consts = lax._reduction_jaxpr(_apply, lax._abstractify(lax._zero(operand)))
  # TODO: implement this via its own primitive so we can define appropriate autodiff rules.
  return scatter_p.bind(
      operand, scatter_indices, unused, update_jaxpr=jaxpr,
      update_consts=consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=GatherScatterMode.from_any(mode))
# Define this outside of scatter to ensure cache hits: a single module-level
# function object means every `scatter` call traces the same combiner.
def _scatter_reduction_computation(x, y):
  """Combiner used by plain `scatter`: the incoming update ``y`` wins over ``x``."""
  return y
def scatter(
    operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
    dimension_numbers: ScatterDimensionNumbers, *,
    indices_are_sorted: bool = False, unique_indices: bool = False,
    mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
  """Scatter-update operator.

  Wraps `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where updates
  replace values from `operand`.

  If multiple updates are performed to the same index of operand, they may be
  applied in any order.

  The semantics of scatter are complicated, and its API might change in the
  future. For most use cases, you should prefer the
  :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses
  the familiar NumPy indexing syntax.

  Args:
    operand: an array to which the scatter should be applied
    scatter_indices: an array that gives the indices in `operand` to which each
      update in `updates` should be applied.
    updates: the updates that should be scattered onto `operand`.
    dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
      how dimensions of `operand`, `scatter_indices`, `updates` and the output
      relate.
    indices_are_sorted: whether `scatter_indices` is known to be sorted. If
      true, may improve performance on some backends.
    unique_indices: whether the elements to be updated in ``operand`` are
      guaranteed to not overlap with each other. If true, may improve performance on
      some backends. JAX does not check this promise: if the updated elements
      overlap when ``unique_indices`` is ``True`` the behavior is undefined.
    mode: how to handle indices that are out of bounds: when set to 'clip',
      indices are clamped so that the slice is within bounds, and when
      set to 'fill' or 'drop' out-of-bounds updates are dropped. The behavior
      for out-of-bounds indices when set to 'promise_in_bounds' is
      implementation-defined. ``None`` is treated like 'fill'/'drop'.

  Returns:
    An array equal to `operand` with the scattered updates written into it.
  """
  # The combiner simply returns the update; it is traced on a scalar of the
  # operand's dtype.
  jaxpr, consts = lax._reduction_jaxpr(_scatter_reduction_computation,
                                       core.ShapedArray((), lax.dtype(operand)))
  return scatter_p.bind(
      operand, scatter_indices, updates, update_jaxpr=jaxpr,
      update_consts=consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=GatherScatterMode.from_any(mode))
def index_take(src: Array, idxs: Array, axes: Sequence[int]) -> Array:
  """Gathers ``src`` at the index vectors ``idxs`` along ``axes``.

  Each element of ``idxs`` supplies the indices for the corresponding entry of
  ``axes``; they are stacked into a single index array whose last dimension is
  the index vector. Indices are reduced modulo the size of their axis (so
  negative indices wrap around), and each selected axis is collapsed out of
  the result.

  Args:
    src: the array to index.
    idxs: a sequence of index arrays, one per entry of ``axes``. Presumably
      1-D and of equal length -- confirm against callers.
    axes: the axes of ``src`` that ``idxs`` index into.

  Returns:
    The gathered array.
  """
  # Stack the per-axis index arrays into an (N, len(axes)) index matrix.
  indices = lax.concatenate([lax.expand_dims(i, (1,)) for i in idxs], 1)
  max_idx = lax.expand_dims(np.array([src.shape[ax] for ax in axes]),
                            tuple(range(indices.ndim - 1)))
  # Wrap indices into range (this also maps negative indices to the end).
  indices = indices % max_idx
  slice_sizes = list(src.shape)
  for ax in axes:
    slice_sizes[ax] = 1
  offset_dims = tuple(range(1, src.ndim - indices.shape[1] + 1))
  dnums = GatherDimensionNumbers(
      offset_dims=offset_dims,
      collapsed_slice_dims=tuple(axes),
      start_index_map=tuple(axes))
  return gather(src, indices, dimension_numbers=dnums,
                slice_sizes=tuple(slice_sizes))
### convenience wrappers around traceables
def slice_in_dim(operand: Union[Array, np.ndarray], start_index: Optional[int],
                 limit_index: Optional[int],
                 stride: int = 1, axis: int = 0) -> Array:
  """Convenience wrapper around slice applying to only one dimension.

  Every axis other than ``axis`` is kept whole. A ``None`` bound selects the
  start (0) or the end (the axis length) of ``axis``; a negative bound has
  the axis length added to it once.

  Args:
    operand: the array to slice.
    start_index: first index to keep along ``axis``, or ``None``.
    limit_index: exclusive end index along ``axis``, or ``None``.
    stride: step between kept elements along ``axis``.
    axis: the axis to slice.

  Returns:
    The sliced array.
  """
  starts = [0] * operand.ndim
  limits = list(operand.shape)
  steps = [1] * operand.ndim

  extent = operand.shape[axis]
  # Resolve `None` to the full extent of the axis.
  lo = 0 if start_index is None else core._canonicalize_dimension(start_index)
  hi = extent if limit_index is None else core._canonicalize_dimension(limit_index)
  # Count negative bounds back from the end of the axis.
  if lo < 0:
    lo = lo + extent
  if hi < 0:
    hi = hi + extent

  axis = int(axis)
  starts[axis] = lo
  limits[axis] = hi
  steps[axis] = int(stride)
  return slice(operand, starts, limits, steps)
def index_in_dim(operand: Union[Array, np.ndarray], index: int, axis: int = 0,
                 keepdims: bool = True) -> Array:
  """Convenience wrapper around slice to perform int indexing.

  Args:
    operand: the array to index.
    index: a static position along ``axis``; a negative value counts from
      the end of the axis.
    axis: the axis to index.
    keepdims: if true, keep ``axis`` as a size-1 dimension; otherwise
      squeeze it out.

  Returns:
    The selected sub-array.

  Raises:
    IndexError: if ``index`` is out of range for ``axis``.
  """
  index, axis = core._canonicalize_dimension(index), int(axis)
  size = operand.shape[axis]
  position = index if index >= 0 else index + size
  if position < 0 or position >= size:
    template = 'index {} is out of bounds for axis {} with size {}'
    raise IndexError(template.format(index, axis, size))
  picked = slice_in_dim(operand, position, position + 1, 1, axis)
  return picked if keepdims else lax.squeeze(picked, (axis,))
def dynamic_slice_in_dim(operand: Union[Array, np.ndarray],
                         start_index: ArrayLike,
                         slice_size: int, axis: int = 0) -> Array:
  """Convenience wrapper around dynamic_slice applying to one dimension.

  Args:
    operand: the array to slice.
    start_index: the (possibly traced) start position along ``axis``.
    slice_size: the static length of the slice along ``axis``.
    axis: the axis to slice; every other axis is taken whole, starting at 0.

  Returns:
    The sliced array.
  """
  # Zero starts for every axis, built from `start_index` via lax._const.
  zero = lax._const(start_index, 0)
  starts: List[ArrayLike] = [zero] * operand.ndim
  sizes = list(operand.shape)
  axis = int(axis)
  starts[axis] = start_index
  sizes[axis] = core._canonicalize_dimension(slice_size)
  return dynamic_slice(operand, starts, sizes)
def dynamic_index_in_dim(operand: Union[Array, np.ndarray],
                         index: Union[int, Array],
                         axis: int = 0, keepdims: bool = True) -> Array:
  """Convenience wrapper around dynamic_slice to perform int indexing.

  Args:
    operand: the array to index.
    index: the (possibly traced) position along ``axis``.
    axis: the axis to index.
    keepdims: if true, keep ``axis`` as a size-1 dimension; otherwise
      squeeze it out.

  Returns:
    The selected sub-array.
  """
  window = dynamic_slice_in_dim(operand, index, 1, axis)
  return window if keepdims else lax.squeeze(window, (axis,))
def dynamic_update_slice_in_dim(operand: Union[Array, np.ndarray],
                                update: ArrayLike,
                                start_index: ArrayLike, axis: int) -> Array:
  """Convenience wrapper around :func:`dynamic_update_slice` to update a slice
  in a single ``axis``.

  Args:
    operand: the array to update.
    update: the values to write (presumably of the same rank as ``operand``,
      as required by :func:`dynamic_update_slice` -- confirm).
    start_index: the (possibly traced) start position along ``axis``.
    axis: the axis along which ``start_index`` applies; every other axis
      starts at index 0.

  Returns:
    An array equal to ``operand`` with ``update`` written at that position.
  """
  axis = int(axis)
  # Zero starts for every axis, built from `start_index` via lax._const.
  start_indices: List[ArrayLike] = [lax._const(start_index, 0)] * lax._ndim(operand)
  start_indices[axis] = start_index
  return dynamic_update_slice(operand, update, start_indices)
def dynamic_update_index_in_dim(operand: Union[Array, np.ndarray],
                                update: ArrayLike, index: ArrayLike,
                                axis: int) -> Array:
  """Convenience wrapper around :func:`dynamic_update_slice` to update a slice
  of size 1 in a single ``axis``.

  Args:
    operand: the array to update.
    update: the values to write, either with the same rank as ``operand`` or
      with ``axis`` omitted (one fewer dimension), in which case a size-1
      ``axis`` is inserted.
    index: the (possibly traced) position along ``axis``.
    axis: the axis to update.

  Returns:
    An array equal to ``operand`` with ``update`` written at ``index``.

  Raises:
    AssertionError: if ``update`` has neither the rank of ``operand`` nor one
      fewer dimension (the check is an ``assert``, so it is skipped under
      ``python -O``).
  """
  axis = int(axis)
  if lax._ndim(update) != lax._ndim(operand):
    assert lax._ndim(update) + 1 == lax._ndim(operand)
    update = lax.expand_dims(update, (axis,))
  return dynamic_update_slice_in_dim(operand, update, index, axis)
def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
  """Shape rule for `slice_p`: validates the parameters, returns the out shape.

  Raises:
    TypeError: if the index or stride sequences have the wrong length or are
      out of range for ``operand``, or if any stride is not positive.
  """
  lax._check_shapelike("slice", "start_indices", start_indices)
  lax._check_shapelike("slice", "limit_indices", limit_indices)
  if operand.ndim != len(start_indices):
    msg = ("slice start_indices must have length equal to the number of "
           "dimensions of the operand, got indices {} for operand shape {}.")
    raise TypeError(msg.format(start_indices, operand.shape))
  if len(start_indices) != len(limit_indices):
    msg = ("slice limit_indices must have the same length as start_indices, "
           "got start_indices {} and limit_indices {}.")
    raise TypeError(msg.format(start_indices, limit_indices))
  if not core.greater_equal_shape(operand.shape, limit_indices):
    msg = ("slice limit_indices must be less than or equal to operand shape, "
           "got limit_indices {} for operand shape {}.")
    raise TypeError(msg.format(limit_indices, operand.shape))
  if not all(core.greater_equal_dim(si, 0) for si in start_indices):
    msg = ("slice start_indices must be greater than or equal to zero, "
           "got start_indices of {}.")
    raise TypeError(msg.format(start_indices))
  if not jax.config.jax_dynamic_shapes:
    if not core.greater_equal_shape(limit_indices, start_indices):
      msg = ("slice limit_indices must be greater than or equal to start_indices,"
             " got start_indices {} and limit_indices {}.")
      raise TypeError(msg.format(start_indices, limit_indices))
  if strides is None or tuple(strides) == (1,) * len(operand.shape):
    # Unit strides: each output dim is limit - start. A literal int 0 start
    # passes `limit` through unchanged (presumably to avoid building a
    # `limit - 0` expression when dims are dynamic).
    shape = [limit if type(start) is int and start == 0 else limit - start
             for start, limit in zip(start_indices, limit_indices)]
    return tuple(shape)

  lax._check_shapelike("slice", "strides", strides)
  if len(strides) != operand.ndim:
    msg = ("slice strides must have length equal to the number of dimensions "
           "of the operand, got strides {} for operand shape {}.")
    raise TypeError(msg.format(strides, operand.shape))
  # Strides must be >= 1, as the error message states. The old bound of 0
  # let a zero stride through to `core.stride_shape`, where it is presumably
  # used as a divisor.
  if not core.greater_equal_shape(strides, (1,) * len(strides)):
    msg = "slice strides must be positive, got {}"
    raise TypeError(msg.format(strides))
  diff = core.diff_shape(limit_indices, start_indices)
  return core.stride_shape(diff, (1,) * len(diff), strides)
def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides):
  """Transpose of `slice_p`: embeds the cotangent ``t`` into zeros via `lax.pad`.

  Each padding-config entry is ``(low, high, interior)``. For unit strides the
  low padding is the start index and the high padding is ``shape - limit``.
  For general strides, ``stride - 1`` zeros are interleaved between kept
  elements, and the high padding is measured from ``real_limits``: one past
  the last element the strided slice actually selected.
  """
  assert ad.is_undefined_primal(operand)
  operand_shape = operand.aval.shape
  if strides is None or np.all(np.equal(strides, 1)):
    pads = zip(start_indices, np.subtract(operand_shape, limit_indices),
               (0,) * len(start_indices))
  else:
    # With k > 0 kept elements the last one sits at
    # start + (k - 1) * stride, so the real limit is start + 1 + (k - 1) * stride.
    # An empty output dimension contributes nothing beyond `start`.
    real_limits = np.add(
        start_indices,
        np.where(np.array(t.shape) == 0, 0,
                 np.add(1, np.multiply(np.subtract(t.shape, 1), strides))))
    pads = zip(start_indices, np.subtract(operand_shape, real_limits),
               np.subtract(strides, 1))
  result = lax.pad(t, lax._const(t, 0), pads)
  assert result.shape == operand_shape, f"{result.shape=} {operand_shape=}"
  return [result]
def _slice_batching_rule(batched_args, batch_dims, *, start_indices,
                         limit_indices, strides):
  """Batching rule for `slice_p`: takes the batch dimension whole.

  A full-range, unit-stride entry is inserted at ``bdim`` into each index
  sequence, and the result keeps its batch dimension at ``bdim``.
  """
  operand, = batched_args
  bdim, = batch_dims

  new_start_indices = list(start_indices)
  new_start_indices.insert(bdim, 0)

  new_limit_indices = list(limit_indices)
  new_limit_indices.insert(bdim, operand.shape[bdim])

  # `strides=None` means unit strides everywhere and must stay None;
  # previously `list(None)` was reached unconditionally and raised.
  if strides is None:
    new_strides = None
  else:
    new_strides = list(strides)
    new_strides.insert(bdim, 1)

  out = slice(operand, new_start_indices, new_limit_indices, new_strides)
  return out, bdim
# The `slice` primitive: output shape from `_slice_shape_rule`; output dtype
# from `_input_dtype` (presumably the operand's dtype -- see lax.utils).
slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice')
# Registered as linear in its operand; the transpose pads the cotangent back
# out to the operand's shape (see `_slice_transpose_rule`).
ad.deflinear2(slice_p, _slice_transpose_rule)
batching.primitive_batchers[slice_p] = _slice_batching_rule
def _slice_lower(ctx, x, *, start_indices, limit_indices, strides):
  """MLIR lowering for `slice_p`; falsy ``strides`` is lowered as all ones."""
  if not strides:
    strides = [1] * len(start_indices)
  (out_aval,) = ctx.avals_out
  lowered = mlir.slice_op(ctx, x, out_aval,
                          start_indices=start_indices,
                          limit_indices=limit_indices,
                          strides=strides)
  return [lowered]

mlir.register_lowering(slice_p, _slice_lower)
def _dynamic_slice_shape_rule(operand, *start_indices, slice_sizes):
  """Shape rule for `dynamic_slice_p`: validates arguments, returns the sizes.

  Requires one scalar start index per operand dimension and one slice size
  per start index, with every size between 0 and the operand's extent.

  Raises:
    TypeError: if any of those requirements is violated.
  """
  n_starts = len(start_indices)
  if operand.ndim != n_starts:
    raise TypeError(
        ("dynamic_slice start_indices must have length equal to the number "
         "of dimensions of the operand, got indices {} for operand shape {}.")
        .format(start_indices, operand.shape))
  if n_starts != len(slice_sizes):
    raise TypeError(
        ("dynamic_slice slice_sizes must have the same length as "
         "start_indices, got start_indices length {} and slice_sizes {}.")
        .format(n_starts, slice_sizes))
  if not core.greater_equal_shape(operand.shape, slice_sizes):
    raise TypeError(
        ("slice slice_sizes must be less than or equal to operand shape, "
         "got slice_sizes {} for operand shape {}.")
        .format(slice_sizes, operand.shape))
  if not all(core.greater_equal_dim(size, 0) for size in slice_sizes):
    raise TypeError(
        ("slice slice_sizes must be greater than or equal to zero, "
         "got slice_sizes of {}.")
        .format(slice_sizes))
  if any(start.ndim != 0 for start in start_indices):
    raise TypeError("start_indices arguments to dynamic_slice must be scalars, "
                    f" got indices {start_indices}")
  return tuple(slice_sizes)
def _dynamic_slice_dtype_rule(operand, *start_indices, slice_sizes):
  """Dtype rule for `dynamic_slice_p`.

  All start indices must be integers of one common dtype. The result has the
  operand's dtype.

  Raises:
    TypeError: listing every index dtype when the indices are not all integers
      of the same dtype.
  """
  if any(i.dtype != start_indices[0].dtype or
         not dtypes.issubdtype(i.dtype, np.integer) for i in start_indices):
    msg = ("index arguments to dynamic_slice must be integers of the same "
           "type, got: {}")
    raise TypeError(msg.format(", ".join(str(i.dtype) for i in start_indices)))
  return operand.dtype
def _dynamic_slice_jvp(primals, tangents, *, slice_sizes):
  """JVP of `dynamic_slice_p`.

  The operand tangent is sliced at the same start indices as the primal. A
  symbolic-zero operand tangent is passed through unchanged.
  """
  operand, *starts = primals
  operand_dot = tangents[0]
  if type(operand_dot) is not ad_util.Zero:
    operand_dot = dynamic_slice_p.bind(operand_dot, *starts,
                                       slice_sizes=slice_sizes)
  primal_out = dynamic_slice_p.bind(operand, *starts, slice_sizes=slice_sizes)
  return primal_out, operand_dot
def _dynamic_slice_transpose_rule(t, operand, *start_indices, slice_sizes):
  """Transpose of `dynamic_slice_p` with respect to its operand.

  The cotangent `t` is written into a zero array of the operand's shape and
  dtype at `start_indices`. The start indices receive no cotangent (``None``).
  """
  assert ad.is_undefined_primal(operand)
  assert all(not ad.is_undefined_primal(s) for s in start_indices)
  index_cts = [None] * len(start_indices)
  aval = operand.aval
  if type(t) is ad_util.Zero:
    return [ad_util.Zero(aval)] + index_cts
  zeros = lax.full(aval.shape, 0, aval.dtype)
  operand_ct = dynamic_update_slice_p.bind(zeros, t, *start_indices)
  return [operand_ct] + index_cts
def _batch_dynamic_slice_indices(indices, bdims):
  """Packs per-dimension scalar start indices into a single index array.

  Args:
    indices: start indices, one per operand dimension. An unbatched entry is a
      scalar. A batched entry is a rank-1 array whose only axis is the batch
      axis; the ``(0,)`` broadcast below depends on this.
    bdims: batch dimension of each entry of `indices`, or ``None``.

  Returns:
    ``(index, index_bdim)``. If any entry is batched, `index` has shape
    ``(batch, len(indices))`` and `index_bdim` is 0. Otherwise `index` has
    shape ``(len(indices),)`` and `index_bdim` is ``None``.
  """
  if len(indices) == 0:
    return np.array([], 'int32'), None
  empty_marker = object()
  # Batch size taken from the first batched entry, if there is one.
  size = next((x.shape[i] for x, i in zip(indices, bdims) if i is not None),
              empty_marker)
  if size is empty_marker:
    return lax.concatenate([lax.broadcast(i, (1,)) for i in indices], 0), None
  indices = lax.concatenate(
      [lax.broadcast_in_dim(x, (size, 1),
                            broadcast_dimensions=((0,) if i is not None else ()))
       for x, i in zip(indices, bdims)],
      1)
  return indices, 0
def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes):
  """Batching rule for `dynamic_slice_p`, implemented via the gather rule.

  The per-dimension start indices are packed into a single index vector (see
  `_batch_dynamic_slice_indices`). The call is then forwarded to the gather
  batching rule with dimension numbers that make every operand dimension an
  offset dimension indexed by the matching entry of that vector.
  """
  # A dynamic slice is a special case of gather; we can delegate to the gather
  # batching rule.
  # TODO(phawkins): consider removing dynamic_slice entirely and using gather
  # always.
  operand, *start_indices = batched_args
  operand_bd, *start_idx_bds = batch_dims
  operand_shape = (operand.shape if operand_bd is batching.not_mapped
                   else tuple(np.delete(operand.shape, operand_bd)))
  dims = tuple(range(len(operand_shape)))
  dnums = GatherDimensionNumbers(offset_dims=dims, collapsed_slice_dims=(),
                                 start_index_map=dims)
  index, index_bdim = _batch_dynamic_slice_indices(start_indices, start_idx_bds)
  return _gather_batching_rule(
      [operand, index], [operand_bd, index_bdim], dimension_numbers=dnums,
      slice_sizes=slice_sizes, unique_indices=True, indices_are_sorted=True,
      mode=GatherScatterMode.PROMISE_IN_BOUNDS, fill_value=None)
def _dynamic_slice_staging_rule(trace, x, *starts_and_dyn_sizes, slice_sizes):
  """Staging rule for `dynamic_slice_p` under dynamic shapes.

  The positional operands after `x` are the start indices (one per dimension
  of `x`) followed by any dynamic slice sizes. With no dynamic sizes, the
  default staging is used. Otherwise the output is a `DShapedArray` whose
  shape merges `slice_sizes` with the dynamic sizes.
  """
  start_indices, dyn = util.split_list(starts_and_dyn_sizes, [x.ndim])
  if not dyn:
    return trace.default_process_primitive(dynamic_slice_p, (x, *start_indices),
                                           dict(slice_sizes=slice_sizes))
  shape = lax._merge_dyn_shape(slice_sizes, dyn)
  aval = core.DShapedArray(shape, x.dtype, False)
  return lax._dyn_shape_staging_rule(trace, dynamic_slice_p, aval, x,
                                     *starts_and_dyn_sizes,
                                     slice_sizes=slice_sizes)
def _dynamic_slice_typecheck_rule(_, x, *starts_and_dyn_sizes, slice_sizes):
  """Typecheck rule for `dynamic_slice_p` equations.

  Without dynamic sizes, the primitive's abstract evaluation is used. With
  dynamic sizes, a `DShapedArray` is built from `slice_sizes` merged with the
  dynamic sizes; literal sizes are replaced by their values. The input's
  dtype and weak type are kept.
  """
  start_indices, dyn = util.split_list(starts_and_dyn_sizes, [x.aval.ndim])
  if not dyn:
    out_aval, effects = dynamic_slice_p.abstract_eval(
        x.aval, *(d.aval for d in start_indices), slice_sizes=slice_sizes)
    return [out_aval], effects
  # TODO(mattjj): perform more checks
  out_shape = lax._merge_dyn_shape(slice_sizes, dyn)
  out_shape = [d.val if type(d) is core.Literal else d for d in out_shape]
  out_aval = core.DShapedArray(tuple(out_shape), x.aval.dtype,
                               x.aval.weak_type)
  return [out_aval], core.no_effects
def _dynamic_slice_padding_rule(in_avals, out_avals, x, *starts_and_dyn,
                                slice_sizes):
  """Padding rule for `dynamic_slice_p`.

  A dynamic size whose aval has a bounded-int (`core.bint`) dtype is replaced
  by that bound. Start indices given as `core.DArray` are unwrapped to their
  `.val`. The slice is then re-emitted with the resulting sizes.
  """
  _, _, dyn_avals = util.split_list(in_avals, [1, x.ndim])
  start_indices, dyn = util.split_list(starts_and_dyn, [x.ndim])
  dyn_ = [a.dtype.bound if type(a.dtype) is core.bint else d
          for a, d in zip(dyn_avals, dyn)]
  slice_sizes_ = lax._merge_dyn_shape(slice_sizes, dyn_)
  start_idx = [d.val if type(d) is core.DArray else d for d in start_indices]
  return [dynamic_slice(x, start_idx, slice_sizes_)]
# The result is weakly typed exactly when the operand (argument 0) is.
dynamic_slice_p = standard_primitive(
    _dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice',
    weak_type_rule=_argnum_weak_type(0))
ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp
ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule
batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule
pe.custom_staging_rules[dynamic_slice_p] = _dynamic_slice_staging_rule
core.custom_typechecks[dynamic_slice_p] = _dynamic_slice_typecheck_rule
pe.padding_rules[dynamic_slice_p] = _dynamic_slice_padding_rule
def _dynamic_slice_lower(ctx, x, *starts_and_dyn_sizes, slice_sizes):
  """Lowers `dynamic_slice_p` to MLIR.

  The operands after `x` are its start indices (one per dimension of `x`)
  followed by any dynamic slice sizes. When dynamic sizes are present, they
  are merged into `slice_sizes` to form the output aval's shape.
  """
  in_aval, *_rest = ctx.avals_in
  [out_aval] = ctx.avals_out
  starts, dyn_sizes = util.split_list(starts_and_dyn_sizes, [in_aval.ndim])
  if dyn_sizes:
    merged = lax._merge_dyn_shape(slice_sizes, dyn_sizes)
    out_aval = out_aval.update(shape=merged)
  return [mlir.dynamic_slice(ctx, out_aval, x, start_indices=starts)]


mlir.register_lowering(dynamic_slice_p, _dynamic_slice_lower)
# def _getslice_lower(ctx, x, lo, hi):
# aval_out, = ctx.avals_out
# return hlo.RealDynamicSliceOp(
# mlir.aval_to_ir_type(aval_out), x,
# mlir.shape_tensor([lo]), mlir.shape_tensor([hi]), mlir.shape_tensor([1])
# ).results
# mlir.register_lowering(getslice_p, _getslice_lower)
def _dynamic_update_slice_shape_rule(operand, update, *start_indices):
  """Shape rule for `dynamic_update_slice_p`.

  `update` must have the operand's rank and fit inside it. There must be one
  scalar start index per operand dimension.

  Returns:
    ``operand.shape``.

  Raises:
    TypeError: if any of the checks above fails.
  """
  operand_shape, update_shape = operand.shape, update.shape
  if update.ndim != operand.ndim:
    raise TypeError(
        ("dynamic_update_slice update must have the same rank as operand, "
         "got update shape {} for operand shape {}.").format(update_shape,
                                                             operand_shape))
  if len(start_indices) != operand.ndim:
    raise TypeError(
        ("dynamic_update_slice start_indices must have length equal to the "
         "rank of operand, got indices {} for operand shape {}.").format(
             start_indices, operand_shape))
  if not core.greater_equal_shape(operand_shape, update_shape):
    raise TypeError(
        ("dynamic_update_slice update shape must be smaller than operand "
         "shape, got update shape {} for operand shape {}.").format(
             update_shape, operand_shape))
  if not all(idx.ndim == 0 for idx in start_indices):
    raise TypeError("start_indices arguments to dynamic_update_slice must be "
                    f"scalars, got indices {start_indices}")
  return operand_shape
def _dynamic_update_slice_dtype_rule(operand, update, *start_indices):
  """Dtype rule for `dynamic_update_slice_p`.

  `operand` and `update` must have the same dtype. All start indices must be
  integers of one common dtype. The result has the operand's dtype.

  Raises:
    TypeError: on mismatched operand/update dtypes, or on non-integer or
      mixed-dtype start indices (every index dtype is listed).
  """
  lax.check_same_dtypes("dynamic_update_slice", operand, update)
  if any(i.dtype != start_indices[0].dtype or
         not dtypes.issubdtype(i.dtype, np.integer) for i in start_indices):
    msg = ("index arguments to dynamic_update_slice must be integers of the "
           "same type, got {}")
    raise TypeError(msg.format(", ".join(str(i.dtype) for i in start_indices)))
  return operand.dtype
def _dynamic_update_slice_jvp(primals, tangents):
  """JVP of `dynamic_update_slice_p`.

  For fixed start indices the primitive is linear in ``(operand, update)``, so
  the output tangent is the same update applied to the tangents. When both
  tangents are symbolic zeros, the output tangent is a symbolic zero and no
  zeros are materialized.
  """
  operand, update = primals[:2]
  start_indices = primals[2:]
  g_operand, g_update = tangents[:2]
  val_out = dynamic_update_slice_p.bind(operand, update, *start_indices)
  if type(g_operand) is ad_util.Zero and type(g_update) is ad_util.Zero:
    tangent_out = ad_util.Zero.from_value(val_out)
  else:
    g_operand = ad.instantiate_zeros(g_operand)
    g_update = ad.instantiate_zeros(g_update)
    tangent_out = dynamic_update_slice_p.bind(g_operand, g_update,
                                              *start_indices)
  return val_out, tangent_out
def _dynamic_update_slice_transpose_rule(t, operand, update, *start_indices):
  """Transpose of `dynamic_update_slice_p`.

  The operand cotangent is `t` with the updated window overwritten by zeros.
  The update cotangent is the window of `t` at `start_indices`. Cotangents are
  produced only for undefined primals, and start indices get ``None``.
  """
  assert all(not ad.is_undefined_primal(x) for x in start_indices)
  if ad.is_undefined_primal(update):
    update_shape = update.aval.shape
  else:
    update_shape = update.shape
  if type(t) is ad_util.Zero:
    operand_t = ad_util.Zero(operand.aval) if ad.is_undefined_primal(operand) else None
    update_t = ad_util.Zero(update.aval) if ad.is_undefined_primal(update) else None
  else:
    dus = dynamic_update_slice_p.bind
    ds = dynamic_slice_p.bind
    zeros = lax._zeros(t, shape=update_shape)
    operand_t = dus(t, zeros, *start_indices) if ad.is_undefined_primal(operand) else None
    update_t = ds(t, *start_indices, slice_sizes=update_shape) if ad.is_undefined_primal(update) else None
  return [operand_t, update_t] + [None] * len(start_indices)
def _dynamic_update_slice_batching_rule(batched_args, batch_dims):
  """Batching rule for `dynamic_update_slice_p`, implemented via scatter.

  The start indices are packed into one index vector, and every update
  dimension becomes a window dimension. CLIP mode is used so that
  out-of-range start indices are clamped rather than dropped.
  """
  # A dynamic update slice is a special case of scatter; we can delegate to the
  # scatter batching rule.
  # TODO(phawkins): consider removing dynamic_update_slice entirely and using
  # scatter always.
  operand, update, *start_idx = batched_args
  operand_bd, update_bd, *start_idx_bd = batch_dims
  update_shape = (np.shape(update) if update_bd is batching.not_mapped
                  else tuple(np.delete(np.shape(update), update_bd)))
  dims = tuple(range(len(update_shape)))
  dnums = ScatterDimensionNumbers(update_window_dims=dims,
                                  inserted_window_dims=(),
                                  scatter_dims_to_operand_dims=dims)
  index, index_bdim = _batch_dynamic_slice_indices(start_idx, start_idx_bd)
  return _scatter_batching_rule(
      scatter, (operand, index, update), (operand_bd, index_bdim, update_bd),
      update_jaxpr=None, update_consts=None, dimension_numbers=dnums,
      indices_are_sorted=True, unique_indices=True,
      mode=GatherScatterMode.CLIP)
dynamic_update_slice_p = standard_primitive(
    _dynamic_update_slice_shape_rule, _dynamic_update_slice_dtype_rule,
    'dynamic_update_slice')
ad.primitive_jvps[dynamic_update_slice_p] = _dynamic_update_slice_jvp
ad.primitive_transposes[dynamic_update_slice_p] = \
    _dynamic_update_slice_transpose_rule
batching.primitive_batchers[dynamic_update_slice_p] = \
    _dynamic_update_slice_batching_rule
def _dynamic_update_slice_lower(ctx, x, update, *start_indices):
  """Lowers `dynamic_update_slice_p` to a single MLIR dynamic_update_slice."""
  aval_out, = ctx.avals_out
  return [mlir.dynamic_update_slice(ctx, aval_out, x, update,
                                    start_indices=start_indices)]


mlir.register_lowering(dynamic_update_slice_p, _dynamic_update_slice_lower)
def _gather_dtype_rule(operand, indices, *, fill_value, **kwargs):
  """Dtype rule for `gather_p`.

  Requires integer indices. The result has the operand's canonicalized dtype,
  and opaque dtypes are allowed.

  Raises:
    ValueError: if `indices` is not of an integer dtype.
  """
  if dtypes.issubdtype(indices.dtype, np.integer):
    return dtypes.canonicalize_dtype(operand.dtype, allow_opaque_dtype=True)
  raise ValueError("indices must have an integer type")
def _rank(arr):
  """Returns the number of dimensions of `arr`, i.e. ``len(arr.shape)``."""
  return len(arr.shape)
def _is_sorted(dims, op_name, name):
  """Raises TypeError unless `dims` is in non-decreasing order.

  `op_name` and `name` only appear in the error message.
  """
  for prev, cur in zip(dims[:-1], dims[1:]):
    if cur < prev:
      raise TypeError(f"{name} in {op_name} op must be sorted; got {dims}")
def _sorted_dims_in_range(dims, rank, op_name, name):
  """Checks that every entry of the sorted sequence `dims` lies in ``[0, rank)``.

  `dims` is expected to be sorted already (callers run `_is_sorted` first), so
  only the first and last entries need checking. An empty `dims` is valid.

  Raises:
    TypeError: naming an offending dimension if any entry is out of range.
  """
  if len(dims) == 0:
    return
  invalid_dim = None
  if dims[0] < 0:
    invalid_dim = dims[0]
  elif dims[-1] >= rank:
    invalid_dim = dims[-1]
  # Compare against None: an out-of-range dim of 0 (rank 0) is falsy.
  if invalid_dim is not None:
    raise TypeError(f"Invalid {name} set in {op_name} op; valid range is "
                    f"[0, {rank}); got: {invalid_dim}.")
def _no_duplicate_dims(dims, op_name, name):
  """Raises TypeError if any dimension appears more than once in `dims`."""
  distinct = len(frozenset(dims))
  if distinct == len(dims):
    return
  raise TypeError(f"{name} in {op_name} op must not repeat; got: {dims}.")
def _gather_shape_rule(operand, indices, *, dimension_numbers,
                       slice_sizes, unique_indices, indices_are_sorted,
                       mode, fill_value):
  """Validates the well-formedness of the arguments to Gather.

  The code implements the checks based on the detailed operation semantics of
  XLA's `Gather <https://openxla.org/xla/operation_semantics#gather>`_
  operator and following the outline of the implementation of
  ShapeInference::InferGatherShape in TensorFlow.

  `unique_indices`, `indices_are_sorted`, `mode` and `fill_value` do not affect
  the result shape. They are accepted only so that the rule matches the
  primitive's parameters.

  Returns:
    The output shape. Positions listed in ``offset_dims`` take the slice sizes
    of the non-collapsed operand dimensions, in order. Every other position
    takes the next batch dimension of `indices` (all of its dimensions except
    the trailing index vector dimension), in order.

  Raises:
    TypeError: if the dimension numbers, slice sizes and index shape are
      inconsistent with each other or with the operand.
  """
  offset_dims = dimension_numbers.offset_dims
  collapsed_slice_dims = dimension_numbers.collapsed_slice_dims
  start_index_map = dimension_numbers.start_index_map

  # Note: in JAX, index_vector_dim is always computed as below, cf. the
  # documentation of the GatherDimensionNumbers class.
  index_vector_dim = _rank(indices) - 1

  # This case should never happen in JAX, due to the implicit construction of
  # index_vector_dim, but is included for completeness.
  if _rank(indices) < index_vector_dim or index_vector_dim < 0:
    raise TypeError(f"Gather index leaf dimension must be within [0, rank("
                    f"indices) + 1). rank(indices) is {_rank(indices)} and "
                    f"gather index leaf dimension is {index_vector_dim}.")

  expanded_indices_shape = list(indices.shape)

  # This case should never happen in JAX, due to the implicit construction of
  # index_vector_dim, but is included for completeness.
  if len(expanded_indices_shape) == index_vector_dim:
    expanded_indices_shape.append(1)

  # Start ValidateGatherDimensions
  # In the error messages output by XLA, "offset_dims" is called "Output window
  # dimensions" in error messages. For consistency's sake, our error messages
  # stick to "offset_dims".
  _is_sorted(offset_dims, "gather", "offset_dims")
  _no_duplicate_dims(offset_dims, "gather", "offset_dims")

  output_shape_rank = len(offset_dims) + _rank(indices) - 1

  for i, offset_dim in enumerate(offset_dims):
    if offset_dim < 0 or offset_dim >= output_shape_rank:
      raise TypeError(f"Offset dimension {i} in gather op is out of bounds; "
                      f"got {offset_dim}, but should have been in "
                      f"[0, {output_shape_rank})")

  if len(start_index_map) != indices.shape[index_vector_dim]:
    raise TypeError(f"Gather op has {len(start_index_map)} elements in "
                    f"start_index_map and the bound of dimension "
                    f"{index_vector_dim=} of indices is "
                    f"{indices.shape[index_vector_dim]}. These two "
                    f"numbers must be equal.")

  for i, operand_dim_for_start_index_i in enumerate(start_index_map):
    if (operand_dim_for_start_index_i < 0 or
        operand_dim_for_start_index_i >= _rank(operand)):
      raise TypeError(f"Invalid start_index_map; domain is "
                      f"[0, {_rank(operand)}), got: "
                      f"{i}->{operand_dim_for_start_index_i}.")

  _no_duplicate_dims(start_index_map, "gather", "start_index_map")

  # _is_sorted and _sorted_dims_in_range are checked in the opposite order
  # compared to the XLA implementation. In cases when the input is not sorted
  # AND there are problematic collapsed_slice_dims, the error message will thus
  # be different.
  _is_sorted(collapsed_slice_dims, "gather", "collapsed_slice_dims")
  _sorted_dims_in_range(collapsed_slice_dims, _rank(operand), "gather",
                        "collapsed_slice_dims")
  _no_duplicate_dims(collapsed_slice_dims, "gather", "collapsed_slice_dims")
  # End ValidateGatherDimensions

  if _rank(operand) != len(slice_sizes):
    raise TypeError(f"Gather op must have one slice size for every input "
                    f"dimension; got: len(slice_sizes)={len(slice_sizes)}, "
                    f"input_shape.rank={_rank(operand)}")

  if len(slice_sizes) != len(offset_dims) + len(collapsed_slice_dims):
    raise TypeError(f"All components of the offset index in a gather op must "
                    f"either be a offset dimension or explicitly collapsed; "
                    f"got len(slice_sizes)={len(slice_sizes)}, "
                    f"output_slice_sizes={offset_dims}, collapsed_slice_dims="
                    f"{collapsed_slice_dims}.")

  for i, slice_size in enumerate(slice_sizes):
    corresponding_input_size = operand.shape[i]
    if not (core.greater_equal_dim(slice_size, 0) and
            core.greater_equal_dim(corresponding_input_size, slice_size)):
      raise TypeError(f"Slice size at index {i} in gather op is out of range, "
                      f"must be within [0, {corresponding_input_size} + 1), "
                      f"got {slice_size}.")

  for i, collapsed_dim in enumerate(collapsed_slice_dims):
    bound = slice_sizes[collapsed_dim]
    if bound != 1:
      raise TypeError(f"Gather op can only collapse slice dims with bound 1, "
                      f"but bound is {bound} for index "
                      f"{collapsed_dim} at position {i}.")

  # Drop the index vector dimension; what remains are the batch dimensions.
  expanded_indices_shape.pop(index_vector_dim)
  batch_sizes = iter(expanded_indices_shape)
  window_sizes = (s for i, s in enumerate(slice_sizes)
                  if i not in collapsed_slice_dims)
  return tuple(next(window_sizes) if i in offset_dims
               else next(batch_sizes) for i in range(output_shape_rank))
def _gather_fill(operand, indices, *, dimension_numbers, slice_sizes,
                 unique_indices, indices_are_sorted, fill_value,
                 output_shape):
  """Lowers a FILL_OR_DROP gather as a PROMISE_IN_BOUNDS gather with masking.

  A batch position is in bounds when every component of its start index lies
  in ``[0, operand_dim - slice_size]`` for the operand dimension that
  component addresses (via ``start_index_map``). In-bounds positions keep the
  gathered window. All other positions are replaced by `fill_value`.

  Args:
    output_shape: shape of the gather result. The per-batch-position mask is
      broadcast over it, with the batch positions placed at the non-offset
      dimensions.
    unique_indices: unused here (see the comment at the gather call).
  """
  dnums = dimension_numbers
  intarray = partial(np.array, dtype=np.int64)
  start_index_map = intarray(dnums.start_index_map)
  operand_dims = lax.shape_as_value(operand.shape)
  indices = lax.convert_element_type(indices, np.int64)
  num_batch_dims = len(indices.shape) - 1

  upper_bound = (operand_dims[start_index_map] -
                 lax.shape_as_value(slice_sizes)[start_index_map])
  mask = lax.bitwise_and(
      lax.ge(indices, np.int64(0)),
      lax.le(indices, lax.expand_dims(upper_bound, tuple(range(num_batch_dims)))))
  # Reduce over the index vector dimension: a position is valid only if all of
  # its index components are.
  mask = lax._reduce_and(mask, [num_batch_dims])

  # Computes the output shape and the positions of the batch dimensions in the
  # output
  output_ndims = num_batch_dims + len(dnums.offset_dims)
  batch_dims_in_output = np.delete(np.arange(output_ndims),
                                   dnums.offset_dims)

  # We don't consume unique_indices directly in gather(), only in its transpose
  # (scatter).
  gather_out = gather(operand, indices, dnums, slice_sizes,
                      indices_are_sorted=indices_are_sorted,
                      mode=GatherScatterMode.PROMISE_IN_BOUNDS)
  return lax.select(
      lax.broadcast_in_dim(mask, output_shape, batch_dims_in_output),
      gather_out, lax.full_like(gather_out, fill_value=fill_value))
def _gather_jvp_rule(g, operand, indices, *, dimension_numbers,
                     slice_sizes, unique_indices, indices_are_sorted, mode,
                     fill_value):
  """JVP of `gather_p` with respect to the operand.

  Gather is linear in the operand, so the tangent `g` is gathered with the
  same parameters. Filled (out-of-bounds) positions get a zero tangent,
  because the fill value does not depend on the operand.
  """
  return gather(g, indices, dimension_numbers, slice_sizes,
                unique_indices=unique_indices,
                indices_are_sorted=indices_are_sorted, mode=mode,
                fill_value=0)
def _gather_transpose_rule(t, operand, indices, *, dimension_numbers,
                           slice_sizes, unique_indices, indices_are_sorted,
                           mode, fill_value):
  """Transpose of `gather_p` with respect to its operand.

  The cotangent is scatter-added into zeros shaped like the operand. The
  scatter uses the dimension numbers that invert the gather (offset dims ->
  update window dims, collapsed dims -> inserted window dims). Indices get no
  cotangent.
  """
  assert ad.is_undefined_primal(operand)
  operand_shape = operand.aval.shape
  if type(t) is ad_util.Zero:
    out = ad_util.Zero(operand.aval)
  else:
    zeros = lax.full(operand_shape, lax._zero(t))
    scatter_dnums = ScatterDimensionNumbers(
        update_window_dims=dimension_numbers.offset_dims,
        inserted_window_dims=dimension_numbers.collapsed_slice_dims,
        scatter_dims_to_operand_dims=dimension_numbers.start_index_map)
    out = scatter_add(zeros, indices, t, scatter_dnums,
                      unique_indices=unique_indices,
                      indices_are_sorted=indices_are_sorted,
                      mode=mode)
  return [out, None]
def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
                          slice_sizes, unique_indices, indices_are_sorted,
                          mode, fill_value):
  """Batching rule for `gather_p`.

  There are three cases, depending on which arguments are batched:

  * only the operand: its batch axis is moved to the front and gathered whole
    as an extra leading offset (window) dimension;
  * only the indices: their batch axis is moved to the front and becomes an
    extra leading batch dimension of the output. The uniqueness/sortedness
    hints are dropped, because different batch members may index the same
    locations;
  * both: both batch axes are moved to the front, and an iota over the batch
    axis is prepended to every index vector, so that each batch member
    gathers from its own operand slice.

  Returns:
    ``(out, 0)``: the batched result, with its batch dimension at 0.
  """
  operand, indices = batched_args
  operand_bdim, indices_bdim = batch_dims

  if operand_bdim is not None and indices_bdim is None:
    operand = batching.moveaxis(operand, operand_bdim, 0)
    slice_sizes = (operand.shape[0],) + slice_sizes
    offset_dims = (0,) + tuple(np.add(1, dimension_numbers.offset_dims))
    collapsed_slice_dims = tuple(np.add(1, dimension_numbers.collapsed_slice_dims))
    start_index_map = tuple(np.add(1, dimension_numbers.start_index_map))
    dnums = GatherDimensionNumbers(
        offset_dims=offset_dims,
        collapsed_slice_dims=collapsed_slice_dims,
        start_index_map=start_index_map)
    return gather(operand, indices, dimension_numbers=dnums,
                  slice_sizes=slice_sizes, unique_indices=unique_indices,
                  indices_are_sorted=indices_are_sorted, mode=mode,
                  fill_value=fill_value), 0

  elif operand_bdim is None and indices_bdim is not None:
    indices = batching.moveaxis(indices, indices_bdim, 0)
    offset_dims = tuple(1 + d for d in dimension_numbers.offset_dims)
    dnums = GatherDimensionNumbers(
        offset_dims=offset_dims,
        collapsed_slice_dims=dimension_numbers.collapsed_slice_dims,
        start_index_map=dimension_numbers.start_index_map)
    # If batching indexed accesses into the same array, the batched gather may
    # no longer have sorted or unique indices.
    return gather(operand, indices, dimension_numbers=dnums,
                  slice_sizes=slice_sizes, unique_indices=False,
                  indices_are_sorted=False, mode=mode, fill_value=fill_value), 0

  # move batch dimensions to the front to simplify logic
  operand = batching.moveaxis(operand, operand_bdim, 0)
  indices = batching.moveaxis(indices, indices_bdim, 0)

  # This slightly awkward special case is needed because the shape rule for
  # gather does not allow size-1 slices out of a size-0 dimension, even if
  # the number of slices is zero. Likely the best fix would be to change the
  # definition of gather() so it can be batched without the construction of
  # an explicit iota of size-1 slices.
  if core.symbolic_equal_dim(operand.shape[0], 0):
    output_shape = _gather_shape_rule(
        core.ShapedArray(operand.shape[1:], operand.dtype),
        core.ShapedArray(indices.shape[1:],
                         dtypes.canonicalize_dtype(indices.dtype)),
        dimension_numbers=dimension_numbers, slice_sizes=slice_sizes,
        unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
        mode=mode, fill_value=fill_value)
    return lax.full((0,) + output_shape, lax._zero(operand)), 0

  # Example: user code had indices shape (3, 4, 5), and we have to deal with
  # indices shape (7, 3, 4, 5). We transform that to indices of shape
  # (7, 3, 4, 6) where we concatenated an iota that counts along our batch
  # dimension to the front of the ndindex.
  count_shape = list(indices.shape)
  count_shape[-1] = 1
  counts = lax.broadcasted_iota(indices.dtype, tuple(count_shape), 0)
  indices = lax.concatenate([counts, indices], len(count_shape) - 1)

  slice_sizes = (1,) + slice_sizes
  collapsed_slice_dims = (0,) + tuple(np.add(1, dimension_numbers.collapsed_slice_dims))
  offset_dims = tuple(np.add(1, dimension_numbers.offset_dims))
  start_index_map = (0,) + tuple(np.add(1, dimension_numbers.start_index_map))

  dnums = GatherDimensionNumbers(
      offset_dims=offset_dims,
      collapsed_slice_dims=collapsed_slice_dims,
      start_index_map=start_index_map)
  return gather(operand, indices, dimension_numbers=dnums,
                slice_sizes=slice_sizes, unique_indices=unique_indices,
                indices_are_sorted=indices_are_sorted, mode=mode,
                fill_value=fill_value), 0
def _gather_pad_rule(in_avals, out_avals, operand, indices, *,
                     dimension_numbers, slice_sizes, unique_indices,
                     indices_are_sorted, mode, fill_value):
  """Padding rule for `gather_p`.

  Only operands without `pe.BoundedAxisSize` dimensions in PROMISE_IN_BOUNDS
  mode are supported; everything else raises NotImplementedError. In the
  supported case the gather is re-emitted with the same dimension numbers,
  slice sizes, mode and fill value. `unique_indices` and `indices_are_sorted`
  are not forwarded.
  """
  operand_aval, _ = in_avals
  has_bounded_axis = any(isinstance(d, pe.BoundedAxisSize)
                         for d in operand_aval.shape)
  if has_bounded_axis:
    raise NotImplementedError
  # with fill, jnp.where on operand; with clip, jnp.where on indices
  if mode != GatherScatterMode.PROMISE_IN_BOUNDS:
    raise NotImplementedError
  out = gather(operand, indices, dimension_numbers=dimension_numbers,
               slice_sizes=slice_sizes, mode=mode, fill_value=fill_value)
  return [out]
# The result is weakly typed exactly when the operand (argument 0) is.
gather_p = standard_primitive(
    _gather_shape_rule, _gather_dtype_rule, 'gather',
    weak_type_rule=_argnum_weak_type(0))
# Differentiable in the operand only; indices are integers.
ad.defjvp(gather_p, _gather_jvp_rule, None)
ad.primitive_transposes[gather_p] = _gather_transpose_rule
batching.primitive_batchers[gather_p] = _gather_batching_rule
pe.padding_rules[gather_p] = _gather_pad_rule
def _gather_lower_opaque(ctx, operand, indices, *,
                         dimension_numbers, slice_sizes, unique_indices,
                         indices_are_sorted, mode, fill_value) -> ir.Value:
  """Lowers a gather of an opaque-dtype operand via its physical representation.

  Each opaque element is backed by a physical element of shape `elt_shape`.
  Those dimensions are appended to the operand as trailing offset dimensions
  and are gathered whole (their slice size is the full extent). The regular
  lowering then runs on the physical avals.
  """
  aval_x, aval_indices = ctx.avals_in
  aval_y, = ctx.avals_out
  elt_shape = aval_x.dtype._rules.physical_element_aval(aval_x.dtype).shape
  trailing_offset_dims = [aval_y.ndim + i for i in range(len(elt_shape))]
  dimension_numbers = dimension_numbers._replace(
      offset_dims=(*dimension_numbers.offset_dims, *trailing_offset_dims))
  slice_sizes = (*slice_sizes, *elt_shape)
  gather_lower = partial(
      _gather_lower, dimension_numbers=dimension_numbers,
      slice_sizes=slice_sizes, unique_indices=unique_indices,
      indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)
  res, = mlir.delegate_lowering(
      ctx, gather_lower, operand, indices,
      avals_in=[core.physical_aval(aval_x), aval_indices],
      avals_out=[core.physical_aval(aval_y)])
  return res
def _gather_lower(ctx, operand, indices, *,
                  dimension_numbers, slice_sizes, unique_indices,
                  indices_are_sorted, mode, fill_value):
  """Lowers `gather_p` to StableHLO.

  * Opaque result dtypes are routed to `_gather_lower_opaque`.
  * FILL_OR_DROP is lowered through `_gather_fill`, which masks a
    PROMISE_IN_BOUNDS gather.
  * PROMISE_IN_BOUNDS and CLIP both become a plain `GatherOp`, or a
    `DynamicGatherOp` when the slice sizes are not static.
  """
  aval_out, = ctx.avals_out
  if dtypes.is_opaque_dtype(aval_out.dtype):
    return [_gather_lower_opaque(
        ctx, operand, indices, dimension_numbers=dimension_numbers,
        slice_sizes=slice_sizes, unique_indices=unique_indices,
        indices_are_sorted=indices_are_sorted, mode=mode,
        fill_value=fill_value)]

  if mode == GatherScatterMode.FILL_OR_DROP:
    gather_fill_fn = mlir.lower_fun(_gather_fill, multiple_results=False)
    return gather_fill_fn(
        ctx, operand, indices,
        dimension_numbers=dimension_numbers, slice_sizes=slice_sizes,
        unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
        fill_value=fill_value, output_shape=aval_out.shape)

  assert mode in (GatherScatterMode.PROMISE_IN_BOUNDS,
                  GatherScatterMode.CLIP), mode
  dnums = hlo.GatherDimensionNumbers.get(
      offset_dims=list(dimension_numbers.offset_dims),
      collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims),
      start_index_map=list(dimension_numbers.start_index_map),
      index_vector_dim=len(ctx.avals_in[1].shape) - 1,
  )
  if not core.is_constant_shape(slice_sizes):
    slice_sizes = mlir.eval_dynamic_shape(ctx, slice_sizes)
    # TODO(burmako): Fix overly conservative type inference of DynamicGatherOp.
    # For now use the build_generic so that we can specify the result type.
    # return hlo.DynamicGatherOp(
    #     operand, indices, mlir.shape_tensor(slice_sizes),
    #     dnums, indices_are_sorted=ir.BoolAttr.get(indices_are_sorted)).results
    results = [mlir.aval_to_ir_type(aval_out)]
    operands = [operand, indices, mlir.shape_tensor(slice_sizes)]
    attributes = {
        "dimension_numbers": dnums,
        "indices_are_sorted": ir.BoolAttr.get(indices_are_sorted)
    }
    return hlo.DynamicGatherOp.build_generic(
        results=results, operands=operands, attributes=attributes).results
  return hlo.GatherOp(
      operand,
      indices,
      dnums,
      mlir.dense_int_elements(slice_sizes),
      indices_are_sorted=ir.BoolAttr.get(indices_are_sorted)).results


mlir.register_lowering(gather_p, _gather_lower)
def _scatter_dtype_rule(operand, indices, updates, **kwargs):
  """Dtype rule used by the scatter primitives (e.g. scatter-add, scatter-mul).

  Requires integer indices and matching operand/update dtypes. The result has
  the operand's canonicalized dtype, and opaque dtypes are allowed.

  Raises:
    ValueError: if `indices` is not of an integer dtype.
  """
  if dtypes.issubdtype(indices.dtype, np.integer):
    lax.check_same_dtypes("scatter", operand, updates)
    return dtypes.canonicalize_dtype(operand.dtype, allow_opaque_dtype=True)
  raise ValueError("indices must have an integer type")
def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr,
                        update_consts, dimension_numbers, indices_are_sorted,
                        unique_indices, mode):
  """Validates the well-formedness of the ``dimension_numbers`` argument to
  Scatter.

  The code implements the checks based on the detailed operation semantics of
  XLA's `Scatter <https://openxla.org/xla/operation_semantics#scatter>`_
  operator and following the outline of the implementation of
  ShapeInference::InferScatterShape in TensorFlow.

  Returns:
    ``operand.shape``. The result of a scatter has the operand's shape.

  Raises:
    TypeError: if the dimension numbers, update shape and index shape are
      inconsistent with each other or with the operand.
  """
  update_window_dims = dimension_numbers.update_window_dims
  inserted_window_dims = dimension_numbers.inserted_window_dims
  scatter_dims_to_operand_dims = dimension_numbers.scatter_dims_to_operand_dims
  # Note: in JAX, index_vector_dim is always computed as below, cf. the
  # documentation of the ScatterDimensionNumbers class.
  index_vector_dim = _rank(indices) - 1

  # This case should never happen in JAX, due to the implicit construction of
  # index_vector_dim, but is included for completeness.
  if _rank(indices) < index_vector_dim or index_vector_dim < 0:
    raise TypeError(f"Scatter index leaf dimension must be within [0, "
                    f"rank(indices) + 1). rank(indices) is {_rank(indices)} "
                    f"and scatter index leaf dimension is {index_vector_dim}.")

  expanded_indices_shape = list(indices.shape)
  # This case should never happen in JAX, due to the implicit construction of
  # index_vector_dim, but is included for completeness.
  if len(expanded_indices_shape) == index_vector_dim:
    expanded_indices_shape.append(1)

  expected_updates_rank = (len(expanded_indices_shape) - 1 +
                           len(update_window_dims))
  if _rank(updates) != expected_updates_rank:
    raise TypeError(f"Updates tensor must be of rank {expected_updates_rank}; "
                    f"got {_rank(updates)}.")

  # Validate update_window_dims
  _is_sorted(update_window_dims, "scatter", "update_window_dims")
  _no_duplicate_dims(update_window_dims, "scatter", "update_window_dims")
  _sorted_dims_in_range(update_window_dims, _rank(updates), "scatter",
                        "update_window_dims")

  # Validate inserted_window_dims
  _is_sorted(inserted_window_dims, "scatter", "inserted_window_dims")
  _no_duplicate_dims(inserted_window_dims, "scatter", "inserted_window_dims")
  _sorted_dims_in_range(inserted_window_dims, _rank(operand), "scatter",
                        "inserted_window_dims")

  # Validate window_size
  window_size = len(update_window_dims) + len(inserted_window_dims)
  if _rank(operand) != window_size:
    raise TypeError(f"Scatter op has window of size {window_size}; doesn't "
                    f"match operand of rank {_rank(operand)}.")

  # Validate scatter_dims_to_operand_dims
  if (len(scatter_dims_to_operand_dims) !=
      indices.shape[index_vector_dim]):
    raise TypeError(f"Scatter op has {len(scatter_dims_to_operand_dims)} "
                    f"elements in scatter_dims_to_operand_dims and the bound "
                    f"of dimension {index_vector_dim=} of "
                    f"indices is {indices.shape[index_vector_dim]}. These two "
                    f"numbers must be equal")

  for i, dim in enumerate(scatter_dims_to_operand_dims):
    if dim < 0 or dim >= _rank(operand):
      raise TypeError(f"Invalid scatter_dims_to_operand_dims mapping; domain "
                      f"is [0, {_rank(operand)}), got: {i}->{dim}.")

  _no_duplicate_dims(scatter_dims_to_operand_dims, "scatter",
                     "scatter_dims_to_operand_dims")

  # Build the membership sets once instead of once per comprehension step.
  inserted = set(inserted_window_dims)
  max_update_slice_sizes = [operand.shape[i] for i in range(len(operand.shape))
                            if i not in inserted]

  for i, update_window_dim in enumerate(update_window_dims):
    if not core.greater_equal_dim(max_update_slice_sizes[i],
                                  updates.shape[update_window_dim]):
      raise TypeError(f"Bounds of the window dimensions of updates must not "
                      f"exceed the bounds of the corresponding dimensions of "
                      f"operand. For dimension {update_window_dim}, updates "
                      f"bound is {updates.shape[update_window_dim]}, operand "
                      f"bound is {max_update_slice_sizes[i]}.")

  window = set(update_window_dims)
  update_scatter_dims = [dim for dim in range(_rank(updates))
                         if dim not in window]

  scatter_dims_seen = 0
  for i in update_scatter_dims:
    # Skip over the index vector dimension of the indices.
    if scatter_dims_seen == index_vector_dim:
      scatter_dims_seen += 1
    if not core.symbolic_equal_dim(updates.shape[i],
                                   expanded_indices_shape[scatter_dims_seen]):
      raise TypeError(f"Bounds of the scatter dimensions of updates must be "
                      f"the same as the bounds of the corresponding dimensions "
                      f"of scatter indices. For scatter dimension {i}, updates "
                      f"bound is {updates.shape[i]}, indices bound is "
                      f"{expanded_indices_shape[scatter_dims_seen]}.")
    scatter_dims_seen += 1

  return operand.shape
def _clamp_scatter_indices(operand, indices, updates, *, dnums):
  """Clamps `indices` to be in-range for a scatter.

  Each index component is clamped to ``[0, operand_dim - window_size]`` for the
  operand dimension it addresses, so that every update window fits entirely
  inside the operand. The window size is 1 for inserted window dimensions; for
  the others it is read from `updates`.

  Returns:
    The clamped indices as int64, with the shape of `indices`.
  """
  slice_sizes = []
  pos = 0
  for i in range(len(operand.shape)):
    if i in dnums.inserted_window_dims:
      slice_sizes.append(1)
    else:
      slice_sizes.append(updates.shape[dnums.update_window_dims[pos]])
      pos += 1

  upper_bounds: core.Shape = tuple(operand.shape[i] - slice_sizes[i]
                                   for i in dnums.scatter_dims_to_operand_dims)
  # Stack upper_bounds into a Array[n]
  upper_bound = lax.shape_as_value(upper_bounds)
  # This fix fails lax_test_no_jax_array
  # Never exceed the largest value representable by the index dtype. The cap is
  # limited to the int64 maximum so that the constant itself fits in int64.
  index_max = min(int(np.iinfo(indices.dtype).max), int(np.iinfo(np.int64).max))
  upper_bound = lax.min(upper_bound, np.int64(index_max))

  upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
                                     (len(indices.shape) - 1,))
  return lax.clamp(np.int64(0), lax.convert_element_type(indices, np.int64),
                   upper_bound)
def _scatter_add_jvp(primals, tangents, *, update_jaxpr, update_consts,
                     dimension_numbers, indices_are_sorted, unique_indices,
                     mode):
  """JVP of `scatter_add_p`, which is linear in the operand and updates.

  The output tangent is the same scatter-add applied to the operand and update
  tangents. The index tangent is ignored, since indices are integers. If both
  tangents are symbolic zeros, the output tangent is a symbolic zero.
  """
  operand, indices, updates = primals
  g_operand, g_indices, g_updates = tangents
  del g_indices  # ignored
  val_out = scatter_add_p.bind(
      operand, indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode)
  if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero:
    tangent_out = ad_util.Zero.from_value(val_out)
  else:
    g_operand = ad.instantiate_zeros(g_operand)
    g_updates = ad.instantiate_zeros(g_updates)
    tangent_out = scatter_add_p.bind(
        g_operand, indices, g_updates, update_jaxpr=update_jaxpr,
        update_consts=update_consts, dimension_numbers=dimension_numbers,
        indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
        mode=mode)
  return val_out, tangent_out
def _scatter_add_transpose_rule(t, operand, indices, updates, *,
                                update_jaxpr, update_consts, dimension_numbers,
                                indices_are_sorted, unique_indices, mode):
  """Transpose of `scatter_add_p`.

  The operand cotangent is `t` itself. The updates cotangent gathers `t` at
  the scattered locations, using the gather that inverts the scatter's
  dimension numbers (with `fill_value=0` for dropped positions). Indices get
  no cotangent.
  """
  assert not ad.is_undefined_primal(indices)
  if ad.is_undefined_primal(updates):
    updates_shape = updates.aval.shape
  else:
    updates_shape = updates.shape
  if type(t) is ad_util.Zero:
    operand_t = ad_util.Zero(operand.aval) if ad.is_undefined_primal(operand) else None
    update_t = ad_util.Zero(updates.aval) if ad.is_undefined_primal(updates) else None
  else:
    operand_t = update_t = None
    if ad.is_undefined_primal(operand):
      operand_t = t

    if ad.is_undefined_primal(updates):
      gather_dnums = GatherDimensionNumbers(
          offset_dims=dimension_numbers.update_window_dims,
          collapsed_slice_dims=dimension_numbers.inserted_window_dims,
          start_index_map=dimension_numbers.scatter_dims_to_operand_dims)
      # Window size per operand dim: 1 for inserted dims, else the update size.
      slice_sizes = []
      pos = 0
      for i in range(len(t.shape)):
        if i in dimension_numbers.inserted_window_dims:
          slice_sizes.append(1)
        else:
          slice_sizes.append(updates_shape[dimension_numbers.update_window_dims[pos]])
          pos += 1
      update_t = gather(t, indices, dimension_numbers=gather_dnums,
                        slice_sizes=slice_sizes, mode=mode, fill_value=0)
  return [operand_t, None, update_t]
def _scatter_mul_transpose_rule(t, operand, indices, updates, *,
                                update_jaxpr, update_consts, dimension_numbers,
                                indices_are_sorted, unique_indices, mode):
  """Transpose of `scatter_mul_p` in whichever of operand/updates is linear.

  The operand cotangent is `t` scatter-multiplied by `updates`. The updates
  cotangent gathers ``t * operand`` at the scattered locations; this is only
  implemented for `unique_indices=True`. Indices get no cotangent.

  Raises:
    NotImplementedError: if the updates cotangent is requested without
      `unique_indices`.
  """
  assert not ad.is_undefined_primal(indices)
  if ad.is_undefined_primal(updates):
    updates_shape = updates.aval.shape
  else:
    updates_shape = updates.shape
  if type(t) is ad_util.Zero:
    operand_t = ad_util.Zero(operand.aval) if ad.is_undefined_primal(operand) else None
    update_t = ad_util.Zero(updates.aval) if ad.is_undefined_primal(updates) else None
  else:
    operand_t = update_t = None
    if ad.is_undefined_primal(operand):
      operand_t = scatter_mul(
          t, indices, updates, dimension_numbers=dimension_numbers,
          indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
          mode=mode)
    if ad.is_undefined_primal(updates):
      if not unique_indices:
        raise NotImplementedError(
            "scatter_mul gradients are only implemented if `unique_indices=True`")
      gather_dnums = GatherDimensionNumbers(
          offset_dims=dimension_numbers.update_window_dims,
          collapsed_slice_dims=dimension_numbers.inserted_window_dims,
          start_index_map=dimension_numbers.scatter_dims_to_operand_dims)
      # Window size per operand dim: 1 for inserted dims, else the update size.
      slice_sizes = []
      pos = 0
      for i in range(len(t.shape)):
        if i in dimension_numbers.inserted_window_dims:
          slice_sizes.append(1)
        else:
          slice_sizes.append(updates_shape[dimension_numbers.update_window_dims[pos]])
          pos += 1
      update_t = gather(lax.mul(t, operand), indices,
                        dimension_numbers=gather_dnums, slice_sizes=slice_sizes,
                        mode=mode, fill_value=0)
  return [operand_t, None, update_t]
def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *,
                           update_jaxpr, update_consts, dimension_numbers,
                           indices_are_sorted, unique_indices, mode):
  """Batching rule for the scatter primitives.

  `scatter_op` is the scatter function to re-emit. The operand and updates
  always get a leading batch axis of size `size`: an existing batch axis is
  moved to the front, otherwise one is created (via
  `batching.bdim_at_front`). If the indices are unbatched, the batch axis
  becomes a leading window dimension of the updates. Otherwise an iota over
  the batch axis is prepended to every index vector, so that each batch member
  scatters into its own operand slice.

  Returns:
    ``(out, 0)``: the batched result, with its batch dimension at 0.
  """
  operand, indices, updates = batched_args
  operand_bdim, indices_bdim, updates_bdim = batch_dims
  del update_jaxpr, update_consts  # Unused.

  # move the operand batch dim to the front if it is not None, otherwise create
  # it at the front (so that we can scatter into it)
  size = next(x.shape[ax] for x, ax in zip(batched_args, batch_dims)
              if ax is not None)
  operand = batching.bdim_at_front(operand, operand_bdim, size)
  operand_bdim = 0

  updates = batching.bdim_at_front(updates, updates_bdim, size)

  if indices_bdim is None:
    inserted_window_dims = tuple(np.add(1, dimension_numbers.inserted_window_dims))
    update_window_dims = (0,) + tuple(np.add(1, dimension_numbers.update_window_dims))
    scatter_dims_to_operand_dims = tuple(np.add(1, dimension_numbers.scatter_dims_to_operand_dims))
    dnums = ScatterDimensionNumbers(
        update_window_dims=update_window_dims,
        inserted_window_dims=inserted_window_dims,
        scatter_dims_to_operand_dims=scatter_dims_to_operand_dims)
    return scatter_op(
        operand, indices, updates, dnums,
        indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
        mode=mode), 0

  # see the third case in _gather_batching_rule for comparison and comments
  indices = batching.bdim_at_front(indices, indices_bdim, size)

  count_shape = list(indices.shape)
  count_shape[-1] = 1
  counts = lax.broadcasted_iota(indices.dtype, tuple(count_shape), 0)
  indices = lax.concatenate([counts, indices], len(count_shape) - 1)

  update_window_dims = tuple(np.add(1, dimension_numbers.update_window_dims))
  inserted_window_dims = (0,) + tuple(np.add(1, dimension_numbers.inserted_window_dims))
  scatter_dims_to_operand_dims = (0,) + tuple(np.add(1, dimension_numbers.scatter_dims_to_operand_dims))

  dnums = ScatterDimensionNumbers(
      update_window_dims=update_window_dims,
      inserted_window_dims=inserted_window_dims,
      scatter_dims_to_operand_dims=scatter_dims_to_operand_dims)
  return scatter_op(
      operand, indices, updates, dnums,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode), 0
# The result is weakly typed exactly when the operand (argument 0) is.
scatter_add_p = standard_primitive(
    _scatter_shape_rule, _scatter_dtype_rule, 'scatter-add',
    weak_type_rule=_argnum_weak_type(0))
ad.primitive_jvps[scatter_add_p] = _scatter_add_jvp
ad.primitive_transposes[scatter_add_p] = _scatter_add_transpose_rule
batching.primitive_batchers[scatter_add_p] = (
    partial(_scatter_batching_rule, scatter_add))
# The result is weakly typed exactly when the operand (argument 0) is.
scatter_mul_p = standard_primitive(
    _scatter_shape_rule, _scatter_dtype_rule, 'scatter-mul',
    weak_type_rule=_argnum_weak_type(0))


def _scatter_mul_jvp_rhs(g, x, i, y, *, dimension_numbers,
                         indices_are_sorted, unique_indices, mode, **kw):
  """JVP contribution of `scatter_mul_p` from the updates tangent `g`.

  Computes ``x * scatter_add(zeros_like(x), i, g)``. This is only implemented
  for `unique_indices=True`.

  Raises:
    NotImplementedError: if `unique_indices` is false.
  """
  if not unique_indices:
    raise NotImplementedError(
        "scatter_mul gradients are only implemented if `unique_indices=True`")
  return lax.mul(x, scatter_add(
      lax.zeros_like_array(x), i, g, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode))


# Operand tangent: scatter-multiply the tangent by the same updates. Indices
# have no tangent. Updates tangent: `_scatter_mul_jvp_rhs`.
ad.defjvp(scatter_mul_p,
          lambda g, x, i, y, **kw: scatter_mul_p.bind(g, i, y, **kw),
          None,
          _scatter_mul_jvp_rhs)
ad.primitive_transposes[scatter_mul_p] = _scatter_mul_transpose_rule
batching.primitive_batchers[scatter_mul_p] = (
    partial(_scatter_batching_rule, scatter_mul))
# JVP rule shared by scatter-min and scatter-max; `scatter_op` is the primitive
# (bound via functools.partial at registration below).
# The primal output is computed with `scatter_op.bind`. When both the operand
# and updates tangents are symbolic zeros, the output tangent is a symbolic
# zero. Otherwise each output location's tangent is assembled from the tangents
# of the values that equal the chosen extremum there, averaged over those
# contenders (see the comment in the body). The original operand entry also
# appears to count when it is retained; note the `num_updates + 1` normalizers.
# NOTE(review): several lines of this function appear to be missing from this
# copy, e.g.:
#   - an `else:` header,
#   - the GatherDimensionNumbers arguments,
#   - the `slice_sizes.append(...)` lines,
#   - the `select`/`scatter_add` heads feeding `num_updates`, `num_refs` and
#     the coefficients.
def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr,
update_consts, dimension_numbers,
indices_are_sorted, unique_indices, mode):
operand, indices, updates = primals
g_operand, g_indices, g_updates = tangents
scatter_dnums = dimension_numbers
updates_shape = updates.shape
# Primal output.
val_out = scatter_op.bind(
operand, indices, updates, update_jaxpr=update_jaxpr,
update_consts=update_consts, dimension_numbers=scatter_dnums,
unique_indices=unique_indices, mode=mode)
# Both differentiable inputs have symbolic-zero tangents, so the output
# tangent is a symbolic zero as well.
if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero:
tangent_out = ad_util.Zero.from_value(val_out)
g_operand = ad.instantiate_zeros(g_operand)
g_updates = ad.instantiate_zeros(g_updates)
# gather_dnums and slice_sizes define the gather op that is the inverse of
# the scatter op specified by scatter_dnums
gather_dnums = GatherDimensionNumbers(
# One slice size per operand dimension. `pos` walks the update window dims
# for the dimensions that are not inserted window dims.
slice_sizes = []
pos = 0
for i in range(len(operand.shape)):
if i in scatter_dnums.inserted_window_dims:
pos += 1
# For consistency with other max operations, if there are two or more values
# in updates that are contending to replace the same index location, the
# resulting tangent at that location will be the average of the associated
# tangents for the values in updates.
# Values before (operand) and after (val_out) the scatter, gathered back to
# each update's target location.
initial_vals = gather(
operand, indices, gather_dnums, np.array(slice_sizes))
target_vals = gather(
val_out, indices, gather_dnums, np.array(slice_sizes))
# Which updates, and which original operand entries, equal the final value.
successful_updates = (updates == target_vals)
retained_values = (initial_vals == target_vals)
# Per-location counts used to average the contending tangents.
num_updates = gather(
lax._zeros(operand), indices,, lax._ones(updates),
num_refs = gather(
updates_normalizer =,
1.0 / (num_updates + 1),
1.0 / num_updates)
updates_coef =,
operand_normalizer =,
1.0 / (num_updates + 1),
operand_coef = (-1.0 + operand_normalizer) / num_refs
# This can be simplified once scatter has transpose implemented
target_tangents = gather(
g_operand, indices, gather_dnums, np.array(slice_sizes))
tangent_updates = (target_tangents * operand_coef +
g_updates * updates_coef)
# The output tangent starts from g_operand and adds the weighted per-update
# contributions (the remaining arguments are missing in this copy).
tangent_out = scatter_add(g_operand,
return val_out, tangent_out
# --- scatter-min / scatter-max ------------------------------------------------
# Both use the generic batching rule and the extremal JVP above.
# NOTE(review): the `standard_primitive(` calls below look truncated here.
scatter_min_p = standard_primitive(
_scatter_shape_rule, _scatter_dtype_rule, 'scatter-min',
batching.primitive_batchers[scatter_min_p] = (
partial(_scatter_batching_rule, scatter_min))
ad.primitive_jvps[scatter_min_p] = partial(_scatter_extremal_jvp, scatter_min_p)
scatter_max_p = standard_primitive(
_scatter_shape_rule, _scatter_dtype_rule, 'scatter-max',
batching.primitive_batchers[scatter_max_p] = (
partial(_scatter_batching_rule, scatter_max))
ad.primitive_jvps[scatter_max_p] = partial(_scatter_extremal_jvp, scatter_max_p)
# JVP rule for the generic `scatter` primitive (overwrite semantics). It has
# three paths:
#   - If both the operand and updates tangents are symbolic zeros, only the
#     primal is computed and the tangent is a symbolic zero.
#   - If `unique_indices` is set, primal and tangent are each a plain scatter.
#   - Otherwise, the id-masking scheme described in the body makes the result
#     well defined when indices overlap.
# NOTE(review): several lines appear to be missing from this copy, e.g.:
#   - the end of the signature and `mode=mode)` continuation lines,
#   - the right-hand side of `num_ids`,
#   - the GatherDimensionNumbers arguments and `slice_sizes` appends,
#   - the `select(...)` heads of the masking expressions.
def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts,
dimension_numbers, indices_are_sorted, unique_indices,
operand, indices, updates = primals
g_operand, g_indices, g_updates = tangents
dnums = dimension_numbers
if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero:
val_out = scatter_p.bind(
operand, indices, updates, update_jaxpr=update_jaxpr,
update_consts=update_consts, dimension_numbers=dnums,
indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
return val_out, ad_util.Zero.from_value(val_out)
g_operand = ad.instantiate_zeros(g_operand)
g_updates = ad.instantiate_zeros(g_updates)
if unique_indices:
# If the user has promised that the updates don't overlap, we can use a much
# simpler JVP.
val_out = scatter_p.bind(
operand, indices, updates, update_jaxpr=update_jaxpr,
update_consts=update_consts, dimension_numbers=dnums,
indices_are_sorted=indices_are_sorted, unique_indices=True, mode=mode)
tangent_out = scatter_p.bind(
g_operand, indices, g_updates, update_jaxpr=update_jaxpr,
update_consts=update_consts, dimension_numbers=dnums,
indices_are_sorted=indices_are_sorted, unique_indices=True, mode=mode)
return val_out, tangent_out
# If there are overlapping indices in the scatter, it is unspecified which
# update "wins". So we use the following perhaps surprising scheme:
# a) attach a positive ID to each update in updates, and perform the scatter
# on the IDs
# b) perform the inverse gather on the scattered IDs (similar to
# _scatter_add_transpose).
# c) use the gathered IDs to mask the primal and tangent values.
# d) perform a scatter-add on the masked primal and tangent values. A benefit
# of using scatter-add here is that we don't need a `scatter` transpose
# rule.
# a) attach a positive ID to each update in `updates`, and perform a scatter
# on the IDs.
# One id per scatter index: the window dimensions of `updates` are collapsed
# to size 1 in the id array's shape.
ids_shape = np.array(updates.shape, dtype=np.int64)
ids_shape[dnums.update_window_dims,] = 1
num_ids =
# Use uint32 ids when they fit, otherwise fall back to uint64.
id_dtype = np.uint32 if (num_ids + 1) < np.iinfo(np.uint32).max else np.uint64
# Ids are iota + 1, so 0 is free to mean "no update" in `scattered_ids`
# (which starts as all zeros).
update_ids = lax.add(lax.reshape(lax.iota(id_dtype, num_ids), ids_shape),
lax._ones(updates, dtype=id_dtype))
scattered_ids = scatter(lax.full(operand.shape, 0, id_dtype),
indices, update_ids, dnums,
unique_indices=unique_indices, mode=mode)
# b) compute the inverse gather that "undoes" the scatter on the id values.
gather_dnums = GatherDimensionNumbers(
slice_sizes = []
pos = 0
for i in range(len(scattered_ids.shape)):
if i in dnums.inserted_window_dims:
pos += 1
gathered_update_ids = gather(scattered_ids, indices,
# c) mask off input elements that do not correspond to a primal output.
masked_operand =, lax._zeros(scattered_ids)),
operand, lax._zeros(operand))
masked_updates =, gathered_update_ids),
updates, lax._zeros(updates))
masked_g_operand =, lax._zeros(scattered_ids)),
g_operand, lax._zeros(g_operand))
masked_g_updates =, gathered_update_ids),
g_updates, lax._zeros(g_updates))
# d) perform scatter-adds to compute the primal and tangent outputs.
val_out = scatter_add(masked_operand, indices, masked_updates,
unique_indices=unique_indices, mode=mode)
tangent_out = scatter_add(masked_g_operand, indices, masked_g_updates,
unique_indices=unique_indices, mode=mode)
return val_out, tangent_out
# Transpose rule for `scatter`. It is only implemented when
# `unique_indices=True` (NotImplementedError otherwise), and `indices` must not
# be a linear (undefined-primal) input.
# Returns the cotangents [operand_t, None, update_t]:
#   - If `t` is a symbolic zero, each requested cotangent is a symbolic zero.
#   - operand_t is `t` with the overwritten locations zeroed. A boolean mask is
#     built by scattering False into an all-True array at the updated
#     locations.
#   - update_t is `t` gathered back at the scattered locations, using the gather
#     that inverts the scatter's dimension numbers.
# NOTE(review): several lines appear to be missing from this copy, e.g.:
#   - the end of the error message,
#   - the `else:` headers,
#   - the GatherDimensionNumbers arguments and `slice_sizes` appends,
#   - the head of the `select(mask, ...)` expression.
def _scatter_transpose_rule(t, operand, indices, updates, *,
update_jaxpr, update_consts, dimension_numbers,
indices_are_sorted, unique_indices, mode):
if not unique_indices:
raise NotImplementedError("scatter transpose is only implemented where"
assert not ad.is_undefined_primal(indices)
# The updates shape comes from the aval when `updates` is a linear input, and
# from the concrete value otherwise.
if ad.is_undefined_primal(updates):
updates_shape = updates.aval.shape
updates_shape = updates.shape
if type(t) is ad_util.Zero:
operand_t = ad_util.Zero(operand.aval) if ad.is_undefined_primal(operand) else None
update_t = ad_util.Zero(updates.aval) if ad.is_undefined_primal(updates) else None
operand_t = update_t = None
if ad.is_undefined_primal(operand):
# Zero out gradient entries that correspond to updated indices.
mask = scatter(lax._ones(t, dtype=np.bool_), indices,
lax.full(updates_shape, False),
unique_indices=True, mode=mode)
operand_t =, t, lax._zeros(t))
if ad.is_undefined_primal(updates):
gather_dnums = GatherDimensionNumbers(
slice_sizes = []
pos = 0
for i in range(len(t.shape)):
if i in dimension_numbers.inserted_window_dims:
pos += 1
update_t = gather(t, indices, dimension_numbers=gather_dnums,
slice_sizes=slice_sizes, mode=mode,
return [operand_t, None, update_t]
# --- scatter (overwrite) -------------------------------------------------------
# NOTE(review): this `standard_primitive(` call looks truncated in this copy.
scatter_p = standard_primitive(
_scatter_shape_rule, _scatter_dtype_rule, 'scatter',
ad.primitive_jvps[scatter_p] = _scatter_jvp
ad.primitive_transposes[scatter_p] = _scatter_transpose_rule
batching.primitive_batchers[scatter_p] = (
partial(_scatter_batching_rule, scatter))
# Lowers a scatter whose operand has an opaque dtype by rewriting it as a
# scatter over the dtype's physical representation:
#   - Each element's physical shape (`physical_element_aval(...).shape`)
#     contributes trailing window dimensions, numbered after the existing
#     `updates` dimensions.
#   - The work is delegated to `_scatter_lower` via `mlir.delegate_lowering`
#     with physical avals.
# Returns the single lowered result.
# NOTE(review): the arguments of `dimension_numbers._replace(` and the rest of
# the `avals_in=[...]` list appear to be missing from this copy.
def _scatter_lower_opaque(ctx, operand, indices, updates, *,
update_jaxpr, update_consts, dimension_numbers,
unique_indices, indices_are_sorted, mode):
aval_x, aval_indices, aval_updates = ctx.avals_in
aval_y, = ctx.avals_out
elt_shape = aval_x.dtype._rules.physical_element_aval(aval_x.dtype).shape
# New window dims appended after the current `updates` dims, one per physical
# element dimension.
trailing_window_dims = [aval_updates.ndim + i for i in range(len(elt_shape))]
dimension_numbers = dimension_numbers._replace(
scatter_lower = partial(
_scatter_lower, update_jaxpr=update_jaxpr, update_consts=update_consts,
dimension_numbers=dimension_numbers, unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted, mode=mode)
res, = mlir.delegate_lowering(
ctx, scatter_lower, operand, indices, updates,
avals_in=[core.physical_aval(aval_x), aval_indices,
return res
# Default MLIR lowering shared by all scatter primitives (registered below).
#   - Opaque-dtype outputs are routed through `_scatter_lower_opaque`.
#   - In GatherScatterMode.CLIP, indices are first clamped by lowering
#     `_clamp_scatter_indices`.
#   - An `hlo.ScatterOp` is then emitted with index_vector_dim set to the last
#     dimension of `indices`. Its update computation is produced by lowering
#     `update_jaxpr` on two scalar block arguments.
# Raises NotImplementedError if `update_jaxpr` has effects.
# NOTE(review): the arguments of `hlo.ScatterOp(`, the rest of the
# ScatterDimensionNumbers/`replace(` calls, and the terminator of the update
# block appear to be missing from this copy.
def _scatter_lower(ctx, operand, indices, updates, *,
update_jaxpr, update_consts, dimension_numbers,
indices_are_sorted, unique_indices, mode):
aval_out, = ctx.avals_out
if dtypes.is_opaque_dtype(aval_out.dtype):
return [_scatter_lower_opaque(
ctx, operand, indices, updates,
update_jaxpr=update_jaxpr, update_consts=update_consts,
dimension_numbers=dimension_numbers, unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted, mode=mode)]
if mode == GatherScatterMode.CLIP:
clip_fn = mlir.lower_fun(_clamp_scatter_indices, multiple_results=False)
(indices,), = clip_fn(ctx.replace(avals_out=None), operand, indices,
updates, dnums=dimension_numbers)
dnums = dimension_numbers
scatter_dnums = hlo.ScatterDimensionNumbers.get(
index_vector_dim=len(ctx.avals_in[1].shape) - 1)
result = mlir.aval_to_ir_types(aval_out)
operand = [operand]
updates = [updates]
op = hlo.ScatterOp(
# The update computation takes two scalars of the output dtype.
scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), aval_out.dtype))
update = op.update_computation.blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(update):
update_ctx = ctx.module_context.replace(
if update_jaxpr.effects:
raise NotImplementedError('Cannot lower effectful `scatter`.')
out_nodes, _ = mlir.jaxpr_subcomp(
update_ctx, update_jaxpr, mlir.TokenSet(), update_consts,
(update.arguments[0],), (update.arguments[1],),
return op.results
# The same lowering serves every scatter variant; they differ only in their
# `update_jaxpr`.
mlir.register_lowering(scatter_p, _scatter_lower)
mlir.register_lowering(scatter_add_p, _scatter_lower)
mlir.register_lowering(scatter_mul_p, _scatter_lower)
mlir.register_lowering(scatter_min_p, _scatter_lower)
mlir.register_lowering(scatter_max_p, _scatter_lower)
def _real_dtype(dtype):
  """Return the real floating-point dtype that `np.finfo` reports for `dtype`.

  For example, complex128 maps to float64, and float32 maps to itself. It is
  used when splitting complex scatter-adds into real and imaginary parts.
  """
  float_info = np.finfo(dtype)
  return float_info.dtype
# GPU-specific lowering for scatter-add (registered with platform="gpu").
#   - Operands other than complex128 go through the generic `_scatter_lower`.
#   - For complex128, the operand and updates are split with hlo.RealOp /
#     hlo.ImagOp. Each part is scattered with an `hlo.AddOp` reducer in the
#     real dtype (float64, via `_real_dtype`). The two results are recombined
#     with hlo.ComplexOp.
# NOTE(review): presumably this works around missing native complex128
# scatter-add support on GPU; confirm.
# NOTE(review): some call arguments (`_scatter_lower(...)`, `clip_fn(...)`,
# `hlo.ScatterDimensionNumbers.get(...)`, `hlo.ScatterOp(...)`) and the reducer's
# return op appear to be missing from this copy.
def _scatter_add_lower_gpu(ctx, operand, indices, updates,
*, update_jaxpr, update_consts, dimension_numbers,
indices_are_sorted, unique_indices, mode):
operand_aval_in, _, updates_aval_in = ctx.avals_in
if operand_aval_in.dtype != np.complex128:
return _scatter_lower(ctx, operand, indices, updates,
unique_indices=unique_indices, mode=mode)
if mode == GatherScatterMode.CLIP:
clip_fn = mlir.lower_fun(_clamp_scatter_indices, multiple_results=False)
(indices,), = clip_fn(ctx.replace(avals_out=None), operand, indices, updates,
aval_out, = ctx.avals_out
dnums = dimension_numbers
scatter_dnums = hlo.ScatterDimensionNumbers.get(
index_vector_dim=len(ctx.avals_in[1].shape) - 1)
real_dtype = _real_dtype(aval_out.dtype)
operand_type_part = mlir.aval_to_ir_types(
core.ShapedArray(aval_out.shape, real_dtype))
# Emits one real-valued scatter-add over a single component (real or imag).
def _scatter(operand_part, updates_part):
operand_part = [operand_part]
updates_part = [updates_part]
scatter = hlo.ScatterOp(
scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), real_dtype))
reducer = scatter.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer):
add = hlo.AddOp(*reducer.arguments).result
return scatter.result
real = _scatter(hlo.RealOp(operand).result, hlo.RealOp(updates).result)
imag = _scatter(hlo.ImagOp(operand).result, hlo.ImagOp(updates).result)
return hlo.ComplexOp(real, imag).results
mlir.register_lowering(scatter_add_p, _scatter_add_lower_gpu, platform="gpu")
def _dynamic_slice_indices(
    operand: Union[Array, np.ndarray],
    start_indices: Union[Union[Array, np.ndarray], Sequence[ArrayLike]]
  ) -> List[ArrayLike]:
  """Normalize `start_indices` with respect to `operand.shape`.

  Each negative start index is wrapped by adding the size of its dimension.
  This is done eagerly when both the index and the dimension are static.
  Otherwise it is staged out, using a runtime `lax.select` when the index
  itself is traced.

  Args:
    operand: the array being sliced. Only its `ndim` and `shape` are read.
    start_indices: one start index per dimension of `operand`, given either as
      a tuple/list of scalars or as a 1-D array.

  Returns:
    A list holding exactly one normalized start index per operand dimension.

  Raises:
    ValueError: if the number of indices differs from `operand.ndim`, or if an
      array of indices is not 1-D.
  """
  if len(start_indices) != operand.ndim:
    msg = ("Length of slice indices must match number of operand dimensions ({} "
          "vs {})")
    raise ValueError(msg.format(len(start_indices), operand.shape))
  if not isinstance(start_indices, (tuple, list)):
    if start_indices.ndim != 1:  # type: ignore[union-attr]
      raise ValueError("Slice indices must be a 1D sequence, got {}"
                       .format(start_indices.shape))  # type: ignore[union-attr]
    start_indices = list(start_indices)
  result: List[ArrayLike] = []
  for i, d in zip(start_indices, operand.shape):
    static_index = isinstance(i, (int, np.integer))
    # We test whether i and d are static to avoid unnecessary staging.
    if static_index and core.is_constant_dim(d):
      result.append(lax.convert_element_type(i + d if i < 0 else i, _dtype(i)))
      # Exactly one entry per dimension: skip the staged paths below.
      continue
    d = core.dimension_as_value(d)
    if static_index:
      # Static index with a non-constant dimension: only a negative index
      # needs the (staged) wrap-around add.
      result.append(i + lax.convert_element_type(d, _dtype(i)) if i < 0 else i)
      continue
    # Traced index: its sign is unknown at trace time, so select at runtime.
    d_arr = lax.convert_element_type(d, _dtype(i))
    result.append(lax.select(i < 0, i + d_arr, i))
  return result