from functools import partial
from typing import (Any, Callable, Optional, Sequence, Union, Tuple)
import warnings
import numpy as np
from jax import tree_util
from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import util
from jax._src.core import ShapedArray, ConcreteArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lax import lax
from jax._src.lax import convolution
from jax._src.lax import slicing
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.ufuncs import logaddexp
# Within this module `map` and `zip` deliberately shadow the builtins with the
# `safe_map` / `safe_zip` helpers from `jax._src.util`.
# NOTE(review): presumably these enforce equal argument lengths — confirm in
# `jax._src.util`.
map = util.safe_map
zip = util.safe_zip
# Loose alias used for array-like values in annotations below.
Array = Any
def reduce_window(operand, init_value, computation: Callable,
                  window_dimensions: core.Shape, window_strides: Sequence[int],
                  padding: Union[str, Sequence[Tuple[int, int]]],
                  base_dilation: Optional[Sequence[int]] = None,
                  window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Reduces sliding windows of `operand` with `computation`.

  Wraps XLA's ReduceWindow (with general padding) operator. When
  `computation` and `init_value` are recognised by
  `_get_monoid_window_reducer`, the matching specialised reducer is used;
  otherwise the generic `reduce_window_p` primitive is bound with a reducer
  jaxpr built from `computation`.

  Args:
    operand: an array, or a pytree of arrays, to reduce.
    init_value: initial value(s) with the same pytree structure as `operand`.
    computation: binary reduction function.
    window_dimensions: window size for each operand dimension.
    window_strides: window stride for each operand dimension.
    padding: either a padding-mode string (converted to explicit pads with
      `lax.padtype_to_pads`, using the first operand's shape) or a sequence of
      `(low, high)` pairs.
    base_dilation: optional operand dilation; defaults to all ones.
    window_dilation: optional window dilation; defaults to all ones.

  Returns:
    The reduced value(s), with the same pytree structure as `operand`.

  Raises:
    ValueError: if `operand` and `init_value` differ in tree structure or leaf
      count, if there are no operands, or if the output tree of `computation`
      differs from the operand tree.
  """
  flat_operands, operand_tree = tree_util.tree_flatten(operand)
  flat_init_values, init_value_tree = tree_util.tree_flatten(init_value)
  if operand_tree != init_value_tree:
    raise ValueError('Operands must have the same tree structure as '
                     f'init_values: {operand_tree} vs. {init_value_tree}')
  if len(flat_operands) == 0:
    raise ValueError('reduce_window must have at least one operand.')
  if len(flat_operands) != len(flat_init_values):
    raise ValueError('Must have same total number of operands as init_values: '
                     f' {len(flat_operands)} vs. {len(flat_init_values)}')
  if isinstance(padding, str):
    # A padding mode is resolved against the *dilated* window extent.
    dilated_window_dims = (
        window_dimensions if window_dilation is None else
        lax._dilate_shape(window_dimensions, window_dilation))
    padding = tuple(lax.padtype_to_pads(
        flat_operands[0].shape, dilated_window_dims, window_strides, padding))
  else:
    padding = tuple(padding)
  if base_dilation is None:
    base_dilation = (1,) * len(window_dimensions)
  if window_dilation is None:
    window_dilation = (1,) * len(window_dimensions)
  monoid_reducer = _get_monoid_window_reducer(computation, flat_init_values)
  if monoid_reducer:
    return monoid_reducer(operand, window_dimensions, window_strides, padding,
                          base_dilation, window_dilation)
  flat_init_avals = map(lax._abstractify, flat_init_values)
  jaxpr, consts, out_tree = lax._variadic_reduction_jaxpr(
      computation, tuple(flat_init_avals), init_value_tree)
  if operand_tree != out_tree:
    raise ValueError(
        'reduce_window output must have the same tree structure as the operands'
        f' {operand_tree} vs. {out_tree}')
  # Every window parameter is a required keyword of the primitive's rules, so
  # all five must be forwarded (as hashable tuples).
  out_flat = reduce_window_p.bind(
      *flat_operands, *flat_init_values, jaxpr=jaxpr, consts=consts,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides), padding=padding,
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))
  return tree_util.tree_unflatten(out_tree, out_flat)
def _get_monoid_window_reducer(monoid_op: Callable,
xs: Sequence[Array]) -> Optional[Callable]:
if len(xs) != 1:
return None
x, = xs
aval = core.get_aval(x)
if (type(aval) is ConcreteArray) and aval.shape == ():
if monoid_op is lax.add:
return aval.val == 0 and _reduce_window_sum
elif monoid_op is lax.max:
return (aval.val == lax._get_max_identity(aval.dtype)
and _reduce_window_max)
elif monoid_op is lax.min:
return (aval.val == lax._get_min_identity(aval.dtype)
and _reduce_window_min)
return None
def _reduce_window_sum(operand: Array, window_dimensions: core.Shape,
                       window_strides: Sequence[int],
                       padding: Sequence[Tuple[int, int]],
                       base_dilation: Optional[Sequence[int]] = None,
                       window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Binds `reduce_window_sum_p` over `operand`.

  Args:
    operand: the array to reduce.
    window_dimensions: window size per dimension.
    window_strides: window stride per dimension.
    padding: `(low, high)` padding pairs per dimension.
    base_dilation: operand dilation; defaults to all ones.
    window_dilation: window dilation; defaults to all ones.

  Returns:
    The result of binding the windowed-sum primitive.
  """
  if base_dilation is None:
    base_dilation = (1,) * len(window_dimensions)
  if window_dilation is None:
    window_dilation = (1,) * len(window_dimensions)
  # The shape rule takes both dilations as required keywords.
  return reduce_window_sum_p.bind(
      operand, window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides), padding=tuple(padding),
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))
def _reduce_window_prod(operand: Array, window_dimensions: core.Shape,
                        window_strides: Sequence[int],
                        padding: Sequence[Tuple[int, int]],
                        base_dilation: Optional[Sequence[int]] = None,
                        window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Windowed product, via the generic `reduce_window_p` with `lax.mul`.

  Args:
    operand: the array to reduce.
    window_dimensions: window size per dimension.
    window_strides: window stride per dimension.
    padding: `(low, high)` padding pairs per dimension.
    base_dilation: operand dilation; defaults to all ones.
    window_dilation: window dilation; defaults to all ones.

  Returns:
    The single output of the generic reduce-window primitive.
  """
  # Initial value is the constant 1 built from `operand` via `lax._const`.
  init_value = lax._const(operand, 1)
  jaxpr, consts = lax._reduction_jaxpr(lax.mul, lax._abstractify(init_value))
  if base_dilation is None:
    base_dilation = (1,) * len(window_dimensions)
  if window_dilation is None:
    window_dilation = (1,) * len(window_dimensions)
  # `reduce_window_p` has multiple results; unpack the single one.
  out, = reduce_window_p.bind(
      operand, init_value, jaxpr=jaxpr, consts=consts,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides), padding=tuple(padding),
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))
  return out
def _reduce_window_max(operand: Array, window_dimensions: core.Shape,
                       window_strides: Sequence[int],
                       padding: Sequence[Tuple[int, int]],
                       base_dilation: Optional[Sequence[int]] = None,
                       window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Binds `reduce_window_max_p` over `operand`.

  Args:
    operand: the array to reduce.
    window_dimensions: window size per dimension.
    window_strides: window stride per dimension.
    padding: `(low, high)` padding pairs per dimension.
    base_dilation: operand dilation; defaults to all ones.
    window_dilation: window dilation; defaults to all ones.

  Returns:
    The result of binding the windowed-max primitive.
  """
  if base_dilation is None:
    base_dilation = (1,) * len(window_dimensions)
  if window_dilation is None:
    window_dilation = (1,) * len(window_dimensions)
  # The shape rule takes both dilations as required arguments.
  return reduce_window_max_p.bind(
      operand, window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides), padding=tuple(padding),
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))
def _reduce_window_min(operand: Array, window_dimensions: core.Shape,
                       window_strides: Sequence[int],
                       padding: Sequence[Tuple[int, int]],
                       base_dilation: Optional[Sequence[int]] = None,
                       window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Binds `reduce_window_min_p` over `operand`.

  Args:
    operand: the array to reduce.
    window_dimensions: window size per dimension.
    window_strides: window stride per dimension.
    padding: `(low, high)` padding pairs per dimension.
    base_dilation: operand dilation; defaults to all ones.
    window_dilation: window dilation; defaults to all ones.

  Returns:
    The result of binding the windowed-min primitive.
  """
  if base_dilation is None:
    base_dilation = (1,) * len(window_dimensions)
  if window_dilation is None:
    window_dilation = (1,) * len(window_dimensions)
  # The shape rule takes both dilations as required arguments.
  return reduce_window_min_p.bind(
      operand, window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides), padding=tuple(padding),
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))
def _reduce_window_logaddexp(
    operand: Array, window_dimensions: core.Shape,
    window_strides: Sequence[int],
    padding: Sequence[Tuple[int, int]],
    base_dilation: Optional[Sequence[int]] = None,
    window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Windowed log-add-exp, via the generic `reduce_window_p` with `logaddexp`.

  Args:
    operand: the array to reduce.
    window_dimensions: window size per dimension.
    window_strides: window stride per dimension.
    padding: `(low, high)` padding pairs per dimension.
    base_dilation: operand dilation; defaults to all ones.
    window_dilation: window dilation; defaults to all ones.

  Returns:
    The single output of the generic reduce-window primitive.
  """
  # Initial value is the constant -inf built from `operand` via `lax._const`.
  init_value = lax._const(operand, -np.inf)
  jaxpr, consts = lax._reduction_jaxpr(logaddexp, lax._abstractify(init_value))
  if base_dilation is None:
    base_dilation = (1,) * len(window_dimensions)
  if window_dilation is None:
    window_dilation = (1,) * len(window_dimensions)
  # `reduce_window_p` has multiple results; unpack the single one.
  out, = reduce_window_p.bind(
      operand, init_value, jaxpr=jaxpr, consts=consts,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides), padding=tuple(padding),
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))
  return out
def _select_and_scatter(operand: Array, select: Callable,
                        window_dimensions: core.Shape,
                        window_strides: Sequence[int],
                        padding: Sequence[Tuple[int, int]], source: Array,
                        init_value: Array, scatter: Callable) -> Array:
  """Binds `select_and_scatter_p` with traced select/scatter computations.

  Both `select` and `scatter` are turned into reduction jaxprs against the
  abstract value of `init_value`; the window parameters are forwarded to the
  primitive as tuples.
  """
  traced_select = lax._reduction_jaxpr(select, lax._abstractify(init_value))
  traced_scatter = lax._reduction_jaxpr(scatter, lax._abstractify(init_value))
  sel_jaxpr, sel_consts = traced_select
  sct_jaxpr, sct_consts = traced_scatter
  window_params = dict(
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides),
      padding=tuple(padding),
  )
  return select_and_scatter_p.bind(
      operand, source, init_value,
      select_jaxpr=sel_jaxpr, select_consts=sel_consts,
      scatter_jaxpr=sct_jaxpr, scatter_consts=sct_consts,
      **window_params)
def _select_and_scatter_add(source: Array, operand: Array,
                            select_prim: core.Primitive,
                            window_dimensions: core.Shape,
                            window_strides: Sequence[int],
                            padding: Sequence[Tuple[int, int]]) -> Array:
  """Binds `select_and_scatter_add_p`.

  Args:
    source: values to scatter.
    operand: array whose windows drive the selection.
    select_prim: the comparison primitive used to select within a window.
    window_dimensions: window size per dimension.
    window_strides: window stride per dimension.
    padding: `(low, high)` padding pairs per dimension.

  Returns:
    The result of binding the primitive.
  """
  # `window_dimensions` is a required keyword of the primitive's shape rule
  # and implementation, so it must be forwarded with the other parameters.
  return select_and_scatter_add_p.bind(
      source, operand, select_prim=select_prim,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides), padding=tuple(padding))
def _select_and_gather_add(tangents: Array, operand: Array,
                           select_prim: core.Primitive,
                           window_dimensions: core.Shape,
                           window_strides: Sequence[int],
                           padding: Sequence[Tuple[int, int]],
                           base_dilation: Sequence[int],
                           window_dilation: Sequence[int]) -> Array:
  """Extracts the tangent corresponding to the minimum or maximum element in
  each window of the `operand` array.

  Wraps XLA's `ReduceWindow` operator, which applies a reduction function to
  all elements in each window of the input multi-dimensional array. In this
  case, the input multi-dimensional array is built by packing each element in
  the `operand` array with its corresponding element in the `tangents` array.

  Args:
    tangents: an array
    operand: an array with the same shape as `tangents`
    select_prim: a reduction function (restricted to `ge_p` and `le_p`)
    window_dimensions: an array of integers for window dimension values
    window_strides: an array of integers for window stride values
    padding: a sequence of `(low, high)` integer padding pairs
    base_dilation: an array of integers for base dilation values
    window_dilation: an array of integers for window dilation values

  Returns:
    An array containing the elements in `tangents` corresponding to the output
    of the reduction of `operand` in each window.
  """
  return select_and_gather_add_p.bind(
      tangents, operand, select_prim=select_prim,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides), padding=tuple(padding),
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))
def _reduce_window_abstract_eval_rule(
    *avals, jaxpr, consts, window_dimensions, window_strides, padding,
    base_dilation, window_dilation):
  """Abstract evaluation for the generic reduce-window primitive.

  `avals` holds the operand avals followed by an equal number of init-value
  avals. Each operand must share its dtype with its init value, and every
  init value must be a scalar. All outputs share the shape computed from the
  first operand and keep their operand's dtype.

  Raises:
    TypeError: on a dtype mismatch or a non-scalar init value.
  """
  half = len(avals) // 2
  operand_avals, init_val_avals = util.split_list(avals, [half])

  dtype_clashes = [op_aval.dtype != init_aval.dtype
                   for op_aval, init_aval in zip(operand_avals, init_val_avals)]
  if any(dtype_clashes):
    template = ("reduce_window got inconsistent dtypes for operands and init_values:"
                " got operand dtypes {} and init_value dtypes {}.")
    operand_dtypes = [op_aval.dtype for op_aval in operand_avals]
    init_dtypes = [init_aval.dtype for init_aval in init_val_avals]
    raise TypeError(template.format(operand_dtypes, init_dtypes))

  init_shapes = [init_aval.shape for init_aval in init_val_avals]
  if not all(len(shape) == 0 for shape in init_shapes):
    template = ("reduce_window expected init_values to be scalars but init_values "
                "have shapes {}.")
    raise TypeError(template.format(init_shapes))

  result_shape = _common_reduce_window_shape_rule(
      operand_avals[0], window_dimensions, window_strides, padding,
      base_dilation, window_dilation)
  return tuple(ShapedArray(result_shape, op_aval.dtype)
               for op_aval in operand_avals)
def _generic_reduce_window_batch_rule(
    batched_args, batch_dims, *, jaxpr, consts, window_dimensions,
    window_strides, padding, base_dilation, window_dilation):
  """Batching rule for the generic `reduce_window_p`.

  The batch dimension of every operand is moved to the front and treated as
  an extra leading dimension with a size-1, stride-1, unpadded, undilated
  window.

  Args:
    batched_args: operands followed by the same number of init values.
    batch_dims: batch dimension (or None) for each entry of `batched_args`.

  Returns:
    The batched outputs and their batch dimensions (all 0).

  Raises:
    NotImplementedError: if any init value is batched.
  """
  num_operands = len(batched_args) // 2
  operands, init_values = util.split_list(batched_args, [num_operands])
  operand_bdims, init_value_bdims = util.split_list(batch_dims, [num_operands])

  if any(init_bdim is not None for init_bdim in init_value_bdims):
    raise NotImplementedError("reduce_window batching is not implemented for "
                              "initial values")

  size = next(x.shape[ax] for x, ax in zip(operands, operand_bdims)
              if ax is not None)
  operands = [batching.bdim_at_front(arg, bdim, size)
              for arg, bdim in zip(operands, operand_bdims)]
  window_dimensions = (1,) + window_dimensions
  window_strides = (1,) + window_strides
  padding = ((0, 0),) + padding
  base_dilation = (1,) + base_dilation
  window_dilation = (1,) + window_dilation
  outs = reduce_window_p.bind(
      *(operands + init_values), jaxpr=jaxpr, consts=consts,
      window_dimensions=window_dimensions, window_strides=window_strides,
      padding=padding, base_dilation=base_dilation,
      window_dilation=window_dilation)
  return outs, (0,) * num_operands
# Generic (possibly variadic) reduce-window primitive.
reduce_window_p = core.Primitive('reduce_window')
reduce_window_p.multiple_results = True
reduce_window_p.def_impl(partial(dispatch.apply_primitive, reduce_window_p))
# Attach the abstract-evaluation rule defined above; without it the rule is
# never registered on the primitive.
reduce_window_p.def_abstract_eval(_reduce_window_abstract_eval_rule)
batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule
def _generic_reduce_window_lower(ctx, *args, jaxpr, consts,
                                 window_dimensions, window_strides, padding,
                                 base_dilation, window_dilation):
  """Lowers the generic `reduce_window_p` to an HLO `ReduceWindowOp`.

  `args` holds the operands followed by the same number of init values. The
  reducer region takes two scalars per operand and is populated by lowering
  `jaxpr` inside it.

  Raises:
    NotImplementedError: if `jaxpr` has effects.
  """
  operands, init_values = util.split_list(args, [len(args) // 2])
  _, init_value_avals = util.split_list(ctx.avals_in, [len(operands)])
  scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
  rw = hlo.ReduceWindowOp(
      map(mlir.aval_to_ir_type, ctx.avals_out),
      operands,
      init_values,
      mlir.dense_int_elements(window_dimensions),
      window_strides=mlir.dense_int_elements(window_strides),
      base_dilations=mlir.dense_int_elements(base_dilation),
      window_dilations=mlir.dense_int_elements(window_dilation),
      padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                          shape=(len(padding), 2)))
  # Block arguments: one scalar per operand for each side of the reduction.
  reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
  with ir.InsertionPoint(reducer):
    if jaxpr.effects:
      raise NotImplementedError('Cannot lower effectful `reduce_window`.')
    out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr,
        mlir.TokenSet(), consts, *([a] for a in reducer.arguments),
        dim_var_values=ctx.dim_var_values)
    # The region must be terminated with the reducer's results.
    hlo.ReturnOp(util.flatten(out_nodes))
  return rw.results
# Register the generic lowering for `reduce_window_p` (no platform given).
mlir.register_lowering(reduce_window_p, _generic_reduce_window_lower)
def _reduce_window_sum_shape_rule(operand, *, window_dimensions, window_strides,
                                  padding, base_dilation, window_dilation):
  """Shape rule for windowed sum: numeric dtypes only, common window shape.

  Raises:
    TypeError: if `operand.dtype` is not a sub-dtype of `np.number`.
  """
  is_numeric = dtypes.issubdtype(operand.dtype, np.number)
  if not is_numeric:
    template = "operand to reduce_window_sum must have a number dtype, got {}"
    dtype_name = np.dtype(operand.dtype).name
    raise TypeError(template.format(dtype_name))
  window_args = (window_dimensions, window_strides, padding,
                 base_dilation, window_dilation)
  return _common_reduce_window_shape_rule(operand, *window_args)
def _reduce_window_sum_transpose_rule(cotangent, operand, *, window_dimensions,
                                      window_strides, padding, base_dilation,
                                      window_dilation):
  """Transpose (VJP) rule for the windowed-sum primitive.

  The cotangent is padded (with `stride - 1` interior zeros per dimension and
  the edge padding computed by `_conv_general_vjp_lhs_padding`) and then
  window-summed again, which must reproduce the operand's shape.

  Returns:
    A one-element list holding the cotangent with respect to `operand`.
  """
  assert ad.is_undefined_primal(operand)
  input_shape = operand.aval.shape
  pads = convolution._conv_general_vjp_lhs_padding(
      input_shape, window_dimensions, window_strides, cotangent.shape, padding,
      base_dilation, window_dilation)
  ones = [1] * len(input_shape)
  padding_config = [(lo, hi, stride - 1)
                    for (lo, hi), stride in zip(pads, window_strides)]
  pad_cotangent = lax.pad(cotangent, lax._zero(cotangent), padding_config)
  # In the transposed reduction the forward base dilation acts as the window
  # stride; padding was already applied explicitly above.
  result = _reduce_window_sum(pad_cotangent, window_dimensions, base_dilation,
                              [(0, 0)] * len(input_shape),
                              base_dilation=ones,
                              window_dilation=window_dilation)
  assert result.shape == input_shape, (result.shape, input_shape)
  return [result]
def _reduce_window_batch_rule(reduce_window, batched_args, bdims, *,
window_dimensions, window_strides, padding,
base_dilation, window_dilation):
operand, = batched_args
bdim, = bdims
if bdim is not None:
window_dimensions = \
window_dimensions[:bdim] + (1,) + window_dimensions[bdim:]
window_strides = window_strides[:bdim] + (1,) + window_strides[bdim:]
padding = padding[:bdim] + ((0, 0),) + padding[bdim:]
base_dilation = base_dilation[:bdim] + (1,) + base_dilation[bdim:]
window_dilation = window_dilation[:bdim] + (1,) + window_dilation[bdim:]
operand = reduce_window(operand, window_dimensions, window_strides, padding,
base_dilation, window_dilation)
return operand, bdim
# Windowed-sum primitive: shape from `_reduce_window_sum_shape_rule`, dtype
# rule `lax._input_dtype`.
reduce_window_sum_p = lax.standard_primitive(
_reduce_window_sum_shape_rule, lax._input_dtype, 'reduce_window_sum')
# Registered through `ad.deflinear2` with its transpose rule.
ad.deflinear2(reduce_window_sum_p, _reduce_window_sum_transpose_rule)
# Batching uses the shared rule, re-binding through `_reduce_window_sum`.
batching.primitive_batchers[reduce_window_sum_p] = partial(
_reduce_window_batch_rule, _reduce_window_sum)
def _reduce_window_chooser_jvp_rule(prim, g, operand, *, window_dimensions,
                                    window_strides, padding, base_dilation,
                                    window_dilation):
  """JVP rule for windowed max/min.

  Args:
    prim: `lax.max_p` or `lax.min_p`; selects the comparison primitive
      (`lax.ge_p` for max, `lax.le_p` for min).
    g: the tangent of `operand`.
    operand: the primal input.

  Returns:
    The result of `_select_and_gather_add` on `g` and `operand`.
  """
  assert prim is lax.max_p or prim is lax.min_p
  select_prim = lax.ge_p if prim is lax.max_p else lax.le_p
  return _select_and_gather_add(g, operand, select_prim, window_dimensions,
                                window_strides, padding, base_dilation,
                                window_dilation)
def _common_reduce_window_shape_rule(operand, window_dimensions,
                                     window_strides, padding, base_dilation,
                                     window_dilation):
  """Validates window parameters and computes the reduce-window output shape.

  Raises:
    TypeError: if the number of window dimensions differs from the operand
      rank, or if strides / dilations differ in length from
      `window_dimensions`.
  """
  # NOTE(review): `non_zero_shape=True` is assumed for the first two checks
  # (their calls had a trailing argument) — confirm against
  # `lax._check_shapelike`.
  lax._check_shapelike("reduce_window", "window_dimensions", window_dimensions,
                       non_zero_shape=True)
  lax._check_shapelike("reduce_window", "window_strides", window_strides,
                       non_zero_shape=True)
  lax._check_shapelike("reduce_window", "base_dilation", base_dilation)
  lax._check_shapelike("reduce_window", "window_dilation", window_dilation)
  if operand.ndim != len(window_dimensions):
    msg = ("reduce_window got the wrong number of window_dimensions for "
           "operand: got operand shape {} with window_dimensions {}.")
    raise TypeError(msg.format(operand.shape, window_dimensions))
  if len(window_strides) != len(window_dimensions):
    msg = ("reduce_window got inconsistent window_strides and "
           "window_dimensions: got window_strides {} and window_dimensions {}.")
    raise TypeError(msg.format(window_strides, window_dimensions))
  if len(base_dilation) != len(window_dimensions):
    msg = ("reduce_window got inconsistent base_dilation and "
           "window_dimensions: got base_dilation {} and window_dimensions {}.")
    raise TypeError(msg.format(base_dilation, window_dimensions))
  if len(window_dilation) != len(window_dimensions):
    msg = ("reduce_window got inconsistent window_dilation and "
           "window_dimensions: got window_dilation {} and window_dimensions "
           "{}.")
    raise TypeError(msg.format(window_dilation, window_dimensions))

  return reduce_window_shape_tuple(operand.shape, window_dimensions,
                                   window_strides, padding, base_dilation,
                                   window_dilation)
def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
                              padding, base_dilation=None,
                              window_dilation=None):
  """Computes the output shape of a reduce-window operation.

  Args:
    operand_shape: shape of the operand.
    window_dimensions: window size per dimension.
    window_strides: window stride per dimension.
    padding: `(low, high)` padding pairs per dimension.
    base_dilation: if given, the operand shape is dilated by it first.
    window_dilation: if given, the window shape is dilated by it first.

  Returns:
    The strided shape of the padded (and dilated) operand under the window.
  """
  if base_dilation is not None:
    operand_shape = lax._dilate_shape(operand_shape, base_dilation)
  if window_dilation is not None:
    window_dimensions = lax._dilate_shape(window_dimensions, window_dilation)
  pads_lo, pads_hi = util.unzip2(padding)
  operand_padded = core.sum_shapes(operand_shape, pads_lo, pads_hi)
  return core.stride_shape(operand_padded, window_dimensions, window_strides)
# Windowed-max primitive: common shape rule, dtype rule `lax._input_dtype`.
reduce_window_max_p = lax.standard_primitive(
    _common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_max')
# The chooser JVP takes the underlying elementwise primitive as its first
# argument (it accepts only `lax.max_p` or `lax.min_p`).
ad.defjvp(reduce_window_max_p, partial(_reduce_window_chooser_jvp_rule,
                                       lax.max_p))
batching.primitive_batchers[reduce_window_max_p] = partial(
    _reduce_window_batch_rule, _reduce_window_max)
# Windowed-min primitive: common shape rule, dtype rule `lax._input_dtype`.
reduce_window_min_p = lax.standard_primitive(
    _common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_min')
# The chooser JVP takes the underlying elementwise primitive as its first
# argument (it accepts only `lax.max_p` or `lax.min_p`).
ad.defjvp(reduce_window_min_p, partial(_reduce_window_chooser_jvp_rule,
                                       lax.min_p))

# Kept as a module-level name and reused for the batcher registration.
_reduce_window_min_batch_rule = partial(_reduce_window_batch_rule,
                                        _reduce_window_min)
batching.primitive_batchers[reduce_window_min_p] = _reduce_window_min_batch_rule
def _reduce_window_lower(
    reduce_op, init_value, ctx, operand, *,
    window_dimensions, window_strides, padding, base_dilation, window_dilation):
  """Lowers a monoid window reduction to an HLO `ReduceWindowOp`.

  Args:
    reduce_op: callable building the scalar reduction op from the two reducer
      block arguments.
    init_value: callable mapping a dtype to the reduction's initial value.
    ctx: the lowering rule context.
    operand: the IR value to reduce.

  Returns:
    The results of the created `ReduceWindowOp`.

  Raises:
    NotImplementedError: if any window parameter is not a constant shape.
  """
  aval_out, = ctx.avals_out
  operand_aval, = ctx.avals_in
  scalar_aval = operand_aval.update(shape=())
  scalar_type = mlir.aval_to_ir_type(scalar_aval)
  if any(not core.is_constant_shape(s)
         for s in [window_dimensions, window_dilation, window_strides, base_dilation, *padding]):
    raise NotImplementedError("ReduceWindowOp for dynamic shapes")
  rw = hlo.ReduceWindowOp(
      mlir.aval_to_ir_types(aval_out), [operand],
      [mlir.full_like_aval(ctx, init_value(scalar_aval.dtype), scalar_aval)],
      mlir.dense_int_elements(window_dimensions),
      window_strides=mlir.dense_int_elements(window_strides),
      base_dilations=mlir.dense_int_elements(base_dilation),
      window_dilations=mlir.dense_int_elements(window_dilation),
      padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                          shape=(len(padding), 2)))
  reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(reducer):
    # The region body is the binary reduction applied to the two arguments.
    hlo.ReturnOp(reduce_op(*reducer.arguments))
  return rw.results
# Lowerings for the monoid window reductions. Each binds
# `_reduce_window_lower` with (reduce_op, init_value): the reducer op and a
# callable that maps a dtype to the reduction's initial value.
mlir.register_lowering(reduce_window_sum_p, partial(
_reduce_window_lower, hlo.AddOp, lambda _: 0))
mlir.register_lowering(reduce_window_min_p, partial(
_reduce_window_lower, mlir.min_hlo, lax._get_min_identity))
mlir.register_lowering(reduce_window_max_p, partial(
_reduce_window_lower, mlir.max_hlo, lax._get_max_identity))
def _select_and_scatter_shape_rule(
    operand, source, init_value, *, select_jaxpr, select_consts, scatter_jaxpr,
    scatter_consts, window_dimensions, window_strides, padding):
  """Shape rule for `select_and_scatter_p`: the output has `operand`'s shape.

  Raises:
    TypeError: if `window_dimensions` and `window_strides` differ in length.
  """
  lax._check_shapelike("select_and_scatter", "window_dimensions",
                       window_dimensions)
  lax._check_shapelike("select_and_scatter", "window_strides", window_strides)
  if len(window_dimensions) != len(window_strides):
    msg = ("select_and_scatter got inconsistent window_strides and "
           "window_dimensions: got window_strides {} and window_dimensions {}.")
    raise TypeError(msg.format(window_strides, window_dimensions))
  return operand.shape
# Select-and-scatter primitive: output shape from
# `_select_and_scatter_shape_rule`, dtype rule `lax._input_dtype`.
select_and_scatter_p = lax.standard_primitive(
_select_and_scatter_shape_rule, lax._input_dtype, 'select_and_scatter')
def _select_and_scatter_lower(
    ctx, operand, source, init_value, *, select_jaxpr,
    select_consts, scatter_jaxpr, scatter_consts, window_dimensions,
    window_strides, padding):
  """Lowers `select_and_scatter_p` to an HLO `SelectAndScatterOp`.

  The `select` and `scatter` regions each take two scalars of the operand's
  element type and are filled by lowering the corresponding jaxpr.

  Raises:
    NotImplementedError: if either jaxpr has effects.
  """
  operand_aval, source_aval, init_value_aval = ctx.avals_in
  aval_out, = ctx.avals_out
  scalar_aval = operand_aval.update(shape=())
  scalar_type = mlir.aval_to_ir_type(scalar_aval)
  op = hlo.SelectAndScatterOp(
      mlir.aval_to_ir_type(aval_out),
      operand,
      source,
      init_value,
      window_dimensions=mlir.dense_int_elements(window_dimensions),
      window_strides=mlir.dense_int_elements(window_strides),
      padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                          shape=(len(padding), 2)))
  select = op.select.blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(select):
    if select_jaxpr.effects:
      raise NotImplementedError('Cannot lower effectful `select`.')
    out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, select_jaxpr,
                                      mlir.TokenSet(), select_consts,
                                      *([a] for a in select.arguments),
                                      dim_var_values=ctx.dim_var_values)
    hlo.ReturnOp(util.flatten(out_nodes))
  scatter = op.scatter.blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(scatter):
    if scatter_jaxpr.effects:
      raise NotImplementedError('Cannot lower effectful `scatter`.')
    out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, scatter_jaxpr,
                                      mlir.TokenSet(), scatter_consts,
                                      *([a] for a in scatter.arguments),
                                      dim_var_values=ctx.dim_var_values)
    hlo.ReturnOp(util.flatten(out_nodes))
  return op.results
# Register the lowering for `select_and_scatter_p` (no platform given).
mlir.register_lowering(select_and_scatter_p, _select_and_scatter_lower)
def _select_and_scatter_add_shape_rule(
source, operand, *, select_prim, window_dimensions, window_strides,
return operand.shape
def _select_and_scatter_add_jvp(
    primals, tangents, *, select_prim, window_dimensions, window_strides,
    padding):
  """JVP rule for `select_and_scatter_add_p`.

  Only the tangent of `source` contributes; the tangent of `operand` is
  discarded.

  Returns:
    The primal output and its tangent (a symbolic zero when the source
    tangent is a symbolic zero).
  """
  source, operand = primals
  g_source, g_operand = tangents
  val_out = _select_and_scatter_add(
      source, operand, select_prim, window_dimensions, window_strides,
      padding)
  del g_operand
  if type(g_source) is ad_util.Zero:
    tangent_out = ad_util.Zero.from_value(val_out)
  else:
    # Only scatter a real tangent; a symbolic zero must not be bound.
    tangent_out = _select_and_scatter_add(
        g_source, operand, select_prim, window_dimensions,
        window_strides, padding)
  return val_out, tangent_out
def _select_and_scatter_add_transpose(
    t, source, operand, *, select_prim, window_dimensions, window_strides,
    padding):
  """Transpose rule for `select_and_scatter_add_p` with respect to `source`.

  Returns:
    `[source_cotangent, None]`; the cotangent is a symbolic zero when `t` is,
    otherwise it is gathered with `_select_and_gather_add` using unit
    dilations.
  """
  assert ad.is_undefined_primal(source) and not ad.is_undefined_primal(operand)
  if type(t) is ad_util.Zero:
    return [ad_util.Zero(source.aval), None]
  ones = (1,) * len(window_dimensions)
  source_t = _select_and_gather_add(t, operand, select_prim, window_dimensions,
                                    window_strides, padding, ones, ones)
  return [source_t, None]
def _select_and_scatter_add_batch_rule(
    batched_args, batch_dims, *, select_prim, window_dimensions, window_strides,
    padding):
  """Batching rule for `select_and_scatter_add_p`.

  Both arguments get their batch dimension moved to the front, and a size-1,
  stride-1, unpadded window entry is prepended for that leading dimension.

  Returns:
    The batched output and its batch dimension (0).
  """
  source, operand = batched_args
  s_bdim, o_bdim = batch_dims
  size = next(a.shape[bdim] for a, bdim in zip(batched_args, batch_dims)
              if bdim is not None)
  source = batching.bdim_at_front(source, s_bdim, size)
  operand = batching.bdim_at_front(operand, o_bdim, size)

  window_dimensions = (1,) + window_dimensions
  window_strides = (1,) + window_strides
  padding = ((0, 0),) + padding
  out = _select_and_scatter_add(source, operand, select_prim, window_dimensions,
                                window_strides, padding)
  return out, 0
# Select-and-scatter-add primitive and its transpose / JVP / batching rules.
select_and_scatter_add_p = lax.standard_primitive(
    _select_and_scatter_add_shape_rule, lax._input_dtype,
    'select_and_scatter_add')

ad.primitive_transposes[select_and_scatter_add_p] = \
    _select_and_scatter_add_transpose
ad.primitive_jvps[select_and_scatter_add_p] = _select_and_scatter_add_jvp
batching.primitive_batchers[select_and_scatter_add_p] = \
    _select_and_scatter_add_batch_rule
def _select_and_scatter_add_impl(source, operand, *,
                                 select_prim, window_dimensions, window_strides,
                                 padding, expand_padding):
  """Implements select-and-scatter-add through `_select_and_scatter`.

  Args:
    source: values to scatter.
    operand: array whose windows drive the selection.
    select_prim: comparison primitive used to select within a window.
    window_dimensions: window size per dimension.
    window_strides: window stride per dimension.
    padding: `(low, high)` padding pairs per dimension.
    expand_padding: if true, pad `operand` explicitly with the selection's
      identity, run with zero padding, then slice the result back to the
      original operand shape.

  Returns:
    The scattered array.
  """
  dtype = source.dtype
  select = lambda x, y: select_prim.bind(x, y)
  scatter = lax.bitwise_or if dtype == np.bool_ else lax.add
  if expand_padding:
    operand_shape = operand.shape
    original_padding = padding
    identity = (lax._get_max_identity if select_prim is lax.ge_p
                else lax._get_min_identity)
    pads = [(lo, hi, 0) for (lo, hi) in padding]
    operand = lax.pad(operand, identity(dtype), pads)
    padding = [(0, 0) for _ in padding]
  out = _select_and_scatter(
      operand, select, window_dimensions, window_strides, padding, source,
      lax._zero(operand), scatter)
  if expand_padding:
    # Drop the explicit padding: keep `[lo, lo + original_dim)` on each axis.
    start_indices = [lo for (lo, hi) in original_padding]
    stop_indices = [lo + d for ((lo, hi), d) in zip(original_padding,
                                                    operand_shape)]
    out = slicing.slice(out, start_indices, stop_indices)
  return out
# Default lowering: no explicit padding expansion.
mlir.register_lowering(select_and_scatter_add_p, mlir.lower_fun(
    partial(_select_and_scatter_add_impl, expand_padding=False),
    multiple_results=False))
# TODO(b/161704903): workaround for XLA/CPU crash.
mlir.register_lowering(select_and_scatter_add_p, mlir.lower_fun(
    partial(_select_and_scatter_add_impl, expand_padding=True),
    multiple_results=False), platform='cpu')
# TODO(b/182390722): workaround for XLA/GPU crash.
mlir.register_lowering(select_and_scatter_add_p, mlir.lower_fun(
    partial(_select_and_scatter_add_impl, expand_padding=True),
    multiple_results=False), platform='gpu')
def _select_and_gather_add_shape_rule(
    tangents, operand, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """Shape rule for `select_and_gather_add_p`.

  Raises:
    TypeError: if `tangents` and `operand` have different shapes.
  """
  if tangents.shape != operand.shape:
    msg = ("select_and_gather_add tangents and operand shapes must match, "
           "got {} and {}.")
    raise TypeError(msg.format(tangents.shape, operand.shape))
  return _common_reduce_window_shape_rule(
      operand, window_dimensions, window_strides, padding, base_dilation,
      window_dilation)
def _select_and_gather_add_lowering(
    ctx: mlir.LoweringRuleContext,
    tangents, operand, *, select_prim,
    window_dimensions, window_strides, padding, base_dilation, window_dilation,
    max_bits=64):
  """Lowers `select_and_gather_add_p` to a single-array `ReduceWindowOp`.

  Each (operand, tangent) pair is packed into one unsigned integer word, the
  window reduction selects the word whose operand half wins the comparison,
  and the tangent half is unpacked from the result.

  Args:
    ctx: the lowering rule context.
    tangents: IR value of the tangents.
    operand: IR value of the operand.
    select_prim: `lax.ge_p` or `lax.le_p`.
    max_bits: widest integer word available for packing. If two values of the
      operand dtype do not fit, both are reduced to half precision and packed
      into a single word (a warning is emitted).
      NOTE(review): the default of 64 is assumed — confirm against callers.

  Returns:
    A one-element list with the selected tangents.
  """
  _, operand_aval, = ctx.avals_in
  out_aval, = ctx.avals_out
  assert isinstance(operand_aval, core.ShapedArray), operand_aval
  dtype = operand_aval.dtype
  etype = mlir.dtype_to_ir_type(dtype)
  nbits = dtypes.finfo(dtype).bits

  assert nbits <= max_bits
  double_word_reduction = nbits * 2 <= max_bits

  const = lambda dtype, x: mlir.ir_constant(np.array(x, dtype=dtype),
                                            canonicalize_types=False)

  def _broadcast_scalar_const(x, aval_out):
    return mlir.broadcast_in_dim(ctx, const(aval_out.dtype, x),
                                 aval_out,
                                 broadcast_dimensions=())

  if double_word_reduction:
    # TODO(b/73062247): XLA doesn't yet implement ReduceWindow on tuples, so
    # we implement a pair-wise ReduceWindow by packing two k-bit values into
    # 2k-bit unsigned integer using bit tricks.
    word_dtype = lax._UINT_DTYPES[nbits]
    double_word_dtype = lax._UINT_DTYPES[nbits * 2]
    word_type = mlir.dtype_to_ir_type(word_dtype)  # type: ignore
    double_word_type = mlir.dtype_to_ir_type(double_word_dtype)  # type: ignore

    # Packs two values into a double_word_type.
    def pack(a, b, ab_aval):
      word_type_ab_aval = ab_aval.update(dtype=word_dtype)
      double_word_type_ab_aval = ab_aval.update(dtype=double_word_dtype)
      a = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), a)
      b = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), b)
      a = hlo.ConvertOp(mlir.aval_to_ir_type(double_word_type_ab_aval), a)
      b = hlo.ConvertOp(mlir.aval_to_ir_type(double_word_type_ab_aval), b)
      a = hlo.ShiftLeftOp(a,
                          _broadcast_scalar_const(nbits, double_word_type_ab_aval))
      return hlo.OrOp(a, b)

    # Unpacks the first element of a double_word_type.
    def fst(t):
      assert not ir.RankedTensorType(t.type).shape
      st = hlo.ShiftRightLogicalOp(t, const(double_word_dtype, nbits))
      return hlo.BitcastConvertOp(
          ir.RankedTensorType.get([], etype),
          hlo.ConvertOp(ir.RankedTensorType.get([], word_type), st)).result

    # Unpacks the second element of a double_word_type.
    def snd(t, t_aval):
      return hlo.BitcastConvertOp(
          mlir.aval_to_ir_type(t_aval.update(dtype=dtype)),
          hlo.ConvertOp(mlir.aval_to_ir_type(t_aval.update(dtype=word_dtype)), t)).result

  else:
    # The double-word trick above only works if we have a sufficiently large
    # type. As an alternative, we can pack two half words into a single word,
    # at the cost of precision.
    # TODO(b/73062247): add support for tuple reductions and remove this case.
    warnings.warn("Using reduced precision for gradient of reduce-window "
                  "min/max operator to work around missing XLA support for "
                  "pair-reductions. This is likely from a second or "
                  "higher derivative of a max-pooling operation.")
    r_nbits = nbits // 2
    # Drop/round the bottom mantissa bits.
    nexp = dtypes.finfo(dtype).nexp
    nmant = r_nbits - nexp - 1

    double_word_dtype = word_dtype = lax._UINT_DTYPES[nbits]
    double_word_type = word_type = mlir.dtype_to_ir_type(word_dtype)  # type: ignore

    # Packs two values into a double_word_type.
    def pack(a, b, ab_aval):
      word_type_ab_aval = ab_aval.update(dtype=word_dtype)
      a = hlo.ReducePrecisionOp(a, exponent_bits=mlir.i32_attr(nexp),
                                mantissa_bits=mlir.i32_attr(nmant))
      b = hlo.ReducePrecisionOp(b, exponent_bits=mlir.i32_attr(nexp),
                                mantissa_bits=mlir.i32_attr(nmant))
      a = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), a)
      b = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), b)
      b = hlo.ShiftRightLogicalOp(
          b, _broadcast_scalar_const(r_nbits, word_type_ab_aval))
      return hlo.OrOp(a, b)

    # Unpacks the first element of a double_word_type.
    def fst(t):
      assert not ir.RankedTensorType(t.type).shape
      st = hlo.AndOp(t, const(word_dtype, ((1 << r_nbits) - 1) << r_nbits))
      return hlo.BitcastConvertOp(ir.RankedTensorType.get([], etype),
                                  st).result

    # Unpacks the second element of a double_word_type.
    def snd(t, t_aval):
      return hlo.BitcastConvertOp(
          mlir.aval_to_ir_type(t_aval.update(dtype=dtype)),
          hlo.ShiftLeftOp(t, _broadcast_scalar_const(r_nbits, t_aval.update(dtype=word_dtype)))
      ).result

  assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
  init = -np.inf if select_prim is lax.ge_p else np.inf
  double_word_out_aval = out_aval.update(dtype=double_word_dtype)
  rw = hlo.ReduceWindowOp(
      [mlir.aval_to_ir_type(double_word_out_aval)],
      pack(operand, tangents, operand_aval),
      pack(const(dtype, init), const(dtype, 0), core.ShapedArray((), dtype)),
      mlir.dense_int_elements(window_dimensions),
      window_strides=mlir.dense_int_elements(window_strides),
      base_dilations=mlir.dense_int_elements(base_dilation),
      window_dilations=mlir.dense_int_elements(window_dilation),
      padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                          shape=(len(padding), 2)))
  scalar_type = ir.RankedTensorType.get([], double_word_type)
  reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(reducer):
    x, y = reducer.arguments
    assert select_prim is lax.ge_p or select_prim is lax.le_p
    cmp_op = "GE" if select_prim is lax.ge_p else "LE"
    # Compare on the operand half, but carry the whole packed word along.
    out = hlo.SelectOp(mlir.compare_hlo(fst(x), fst(y), cmp_op), x, y)
    hlo.ReturnOp(out)
  return [snd(rw.result, double_word_out_aval)]
# TODO(phawkins): use this translation rule on all platforms.
def _select_and_gather_add_using_variadic_reducewindow(
    tangents, operand, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """Implements select-and-gather-add as a two-operand `reduce_window`.

  The reduction runs over (operand, tangent) pairs: `select_prim` compares the
  operand components, and the winning pair is kept whole. Only the tangent
  component of the result is returned.
  """
  def reducer(x, y):
    kx, vx = x
    ky, vy = y
    which = select_prim.bind(kx, ky)
    # Pick key and value from the same side so they stay paired.
    return (lax.select(which, kx, ky), lax.select(which, vx, vy))

  assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
  init = -np.inf if select_prim is lax.ge_p else np.inf
  _, out = reduce_window(
      (operand, tangents),
      (np.array(init, dtype=operand.dtype), np.array(0, dtype=operand.dtype)),
      reducer, window_dimensions, window_strides, padding, base_dilation,
      window_dilation)
  return out
def _select_and_gather_add_jvp(
    primals, tangents, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """JVP rule for `select_and_gather_add_p`.

  Only the tangent of the first primal contributes; the tangent of `operand`
  is discarded.

  Returns:
    The primal output and its tangent (a symbolic zero when the incoming
    tangent is a symbolic zero).
  """
  source, operand = primals
  g_source, g_operand = tangents
  val_out = _select_and_gather_add(
      source, operand, select_prim, window_dimensions, window_strides,
      padding, base_dilation, window_dilation)
  del g_operand
  if type(g_source) is ad_util.Zero:
    tangent_out = ad_util.Zero.from_value(val_out)
  else:
    # Only gather a real tangent; a symbolic zero must not be bound.
    tangent_out = _select_and_gather_add(
        g_source, operand, select_prim, window_dimensions,
        window_strides, padding, base_dilation, window_dilation)
  return val_out, tangent_out
def _select_and_gather_add_transpose(
    t, tangents, operand, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """Transpose rule for `select_and_gather_add_p` with respect to `tangents`.

  With base dilation, `operand` is first padded with `d - 1` interior
  identity elements per dimension, and the scattered result is sliced with
  stride `base_dilation` afterwards.

  Returns:
    `[tangents_cotangent, None]`.

  Raises:
    NotImplementedError: if any window dilation differs from 1.
  """
  assert select_prim in (lax.le_p, lax.ge_p)
  assert (ad.is_undefined_primal(tangents) and
          not ad.is_undefined_primal(operand))
  if any(d != 1 for d in window_dilation):
    msg = ("VJP not implemented for select_and_gather (MaxPool) with window "
           "dilation, got window_dilation={}.")
    raise NotImplementedError(msg.format(window_dilation))
  if type(t) is ad_util.Zero:
    return [ad_util.Zero(tangents.aval), None]
  has_base_dilation = any(d != 1 for d in base_dilation)
  if has_base_dilation:
    select_identity = (lax._get_max_identity if select_prim is lax.ge_p
                       else lax._get_min_identity)
    operand = lax.pad(operand, select_identity(operand.dtype),
                      tuple((0, 0, d - 1) for d in base_dilation))
  result = _select_and_scatter_add(t, operand, select_prim, window_dimensions,
                                   window_strides, padding)
  if has_base_dilation:
    # Keep only the positions that correspond to original (undilated) elements.
    result = slicing.slice(result, (0,) * len(result.shape), result.shape,
                           base_dilation)
  return [result, None]
def _select_and_gather_add_batching_rule(
    batched_args, batch_dims, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """Batching rule for `select_and_gather_add_p`.

  Both arguments get their batch dimension moved to the front, and a size-1,
  stride-1, unpadded, undilated window entry is prepended for that leading
  dimension.

  Returns:
    The batched output and its batch dimension (0).
  """
  t, x = batched_args
  t_bdim, x_bdim = batch_dims
  size = next(a.shape[bdim] for a, bdim in zip(batched_args, batch_dims)
              if bdim is not None)
  t = batching.bdim_at_front(t, t_bdim, size)
  x = batching.bdim_at_front(x, x_bdim, size)
  window_dimensions = (1,) + window_dimensions
  window_strides = (1,) + window_strides
  padding = ((0, 0),) + padding
  base_dilation = (1,) + base_dilation
  window_dilation = (1,) + window_dilation
  out = _select_and_gather_add(t, x, select_prim, window_dimensions,
                               window_strides, padding, base_dilation,
                               window_dilation)
  return (out, 0)
# Select-and-gather-add primitive and its JVP / transpose / batching rules.
select_and_gather_add_p = lax.standard_primitive(
    _select_and_gather_add_shape_rule, lax._input_dtype,
    'select_and_gather_add')
ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp
ad.primitive_transposes[select_and_gather_add_p] = \
    _select_and_gather_add_transpose
batching.primitive_batchers[select_and_gather_add_p] = \
    _select_and_gather_add_batching_rule
# Default lowering goes through the variadic reduce_window implementation.
mlir.register_lowering(select_and_gather_add_p, mlir.lower_fun(
    _select_and_gather_add_using_variadic_reducewindow,
    multiple_results=False))
# TODO(b/183233858): use variadic reducewindow on GPU, when implemented.