Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/_src/lax/control_flow/loops.py

2107 lines
92 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the loop primitives."""
from functools import partial
import inspect
import itertools
import operator
from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar
import jax
import weakref
from jax._src import core
from jax._src import linear_util as lu
from jax import config # type: ignore[no-redef]
from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
tree_map, tree_flatten_with_path, keystr)
from jax._src.api_util import shaped_abstractify
from jax._src.tree_util import equality_errors
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import source_info_util
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.lax import lax
from jax._src.lax import slicing
from jax._src.lax import windowed_reductions
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.ufuncs import logaddexp
from jax._src.traceback_util import api_boundary
from jax._src.util import (partition_list, safe_map, safe_zip, split_list,
unzip2, weakref_lru_cache)
import numpy as np
from jax._src.lax.control_flow.common import (
_abstractify, _avals_short, _check_tree_and_avals, _initial_style_jaxpr,
_make_closed_jaxpr, _prune_zeros, _typecheck_param, allowed_effects)
_map = safe_map
zip = safe_zip
T = TypeVar('T')
Array = Any
BooleanNumeric = Any # A bool, or a Boolean array.
### Helper functions
def _promote_weak_typed_inputs(in_vals, in_avals, out_avals):
"""Promote weakly-typed in_vals to be compatible with out_avals.
Args:
in_vals : flattened list of input values.
in_avals : corresponding list of avals.
out_avals : list of target output avals.
Returns:
in_vals_new : flattened list of modified in_vals with no weak types.
changed : bool; true if in_vals required modification.
"""
if len(in_vals) != len(in_avals) or len(in_avals) != len(out_avals):
# Calling function is responsible for catching this.
return in_vals, False
weak_mismatches = [i for i, (a1, a2) in enumerate(zip(in_avals, out_avals))
if getattr(a1, 'weak_type', False) and not core.typematch(a1, a2)]
if not weak_mismatches:
return in_vals, False
for i in weak_mismatches:
new_dtype = dtypes.result_type(in_vals[i], out_avals[i])
in_vals[i] = lax.convert_element_type(in_vals[i], new_dtype)
return in_vals, True
### scan
Carry = TypeVar('Carry')
X = TypeVar('X')
Y = TypeVar('Y')
@api_boundary
def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
init: Carry,
xs: X,
length: Optional[int] = None,
reverse: bool = False,
unroll: int = 1) -> Tuple[Carry, Y]:
"""Scan a function over leading array axes while carrying along state.
The `Haskell-like type signature`_ in brief is
.. code-block:: haskell
scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
where we use [t] here to denote the type t with an additional leading axis.
That is, if t is an array type then [t] represents the type with an additional
leading axis, and if t is a pytree (container) type with array leaves then [t]
represents the type with the same pytree structure and corresponding leaves
each with an additional leading axis.
When the type of ``xs`` (denoted `a` above) is an array type or None, and the type
of ``ys`` (denoted `b` above) is an array type, the semantics of :func:`~scan` are
given roughly by this Python implementation::
def scan(f, init, xs, length=None):
if xs is None:
xs = [None] * length
carry = init
ys = []
for x in xs:
carry, y = f(carry, x)
ys.append(y)
return carry, np.stack(ys)
Unlike that Python version, both ``xs`` and ``ys`` may be arbitrary pytree
values, and so multiple arrays can be scanned over at once and produce multiple
output arrays. ``None`` is actually a special case of this, as it represents an
empty pytree.
Also unlike that Python version, :func:`~scan` is a JAX primitive and is
lowered to a single WhileOp. That makes it useful for reducing
compilation times for JIT-compiled functions, since native Python
loop constructs in an :func:`~jax.jit` function are unrolled, leading to large
XLA computations.
Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype
across all iterations (and not just be consistent up to NumPy rank/shape
broadcasting and dtype promotion rules, for example). In other words, the type
``c`` in the type signature above represents an array with a fixed shape and
dtype (or a nested tuple/list/dict container data structure with a fixed
structure and arrays with fixed shape and dtype at the leaves).
.. note::
:py:func:`scan` compiles ``f``, so while it can be combined with
:py:func:`jit`, it's usually unnecessary.
Args:
f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
that ``f`` accepts two arguments where the first is a value of the loop
carry and the second is a slice of ``xs`` along its leading axis, and that
``f`` returns a pair where the first element represents a new value for
the loop carry and the second represents a slice of the output.
init: an initial loop carry value of type ``c``, which can be a scalar,
array, or any pytree (nested Python tuple/list/dict) thereof, representing
the initial loop carry value. This value must have the same structure as
the first element of the pair returned by ``f``.
xs: the value of type ``[a]`` over which to scan along the leading axis,
where ``[a]`` can be an array or any pytree (nested Python
tuple/list/dict) thereof with consistent leading axis sizes.
length: optional integer specifying the number of loop iterations, which
must agree with the sizes of leading axes of the arrays in ``xs`` (but can
be used to perform scans where no input ``xs`` are needed).
reverse: optional boolean specifying whether to run the scan iteration
forward (the default) or in reverse, equivalent to reversing the leading
axes of the arrays in both ``xs`` and in ``ys``.
unroll: optional positive int specifying, in the underlying operation of the
scan primitive, how many scan iterations to unroll within a single
iteration of a loop.
Returns:
A pair of type ``(c, [b])`` where the first element represents the final
loop carry value and the second element represents the stacked outputs of
the second output of ``f`` when scanned over the leading axis of the inputs.
.. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
"""
if not callable(f):
raise TypeError("lax.scan: f argument should be a callable.")
xs_flat, xs_tree = tree_flatten(xs)
try:
lengths = [x.shape[0] for x in xs_flat]
except AttributeError as err:
msg = "scan got value with no leading axis to scan over: {}."
raise ValueError(
msg.format(', '.join(str(x) for x in xs_flat
if not hasattr(x, 'shape')))) from err
if length is not None:
length = int(length)
if not all(length == l for l in lengths):
msg = ("scan got `length` argument of {} which disagrees with "
"leading axis sizes {}.")
raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat]))
else:
unique_lengths = set(lengths)
if len(unique_lengths) > 1:
msg = "scan got values with different leading axis sizes: {}."
raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
elif len(unique_lengths) == 0:
msg = "scan got no values to scan over and `length` not provided."
raise ValueError(msg)
else:
length, = unique_lengths
if config.jax_disable_jit:
if length == 0:
raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
carry = init
ys = []
maybe_reversed = reversed if reverse else lambda x: x
for i in maybe_reversed(range(length)):
xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat]
carry, y = f(carry, tree_unflatten(xs_tree, xs_slice))
ys.append(y)
stack = lambda *ys: jax.numpy.stack(ys)
stacked_y = tree_map(stack, *maybe_reversed(ys))
return carry, stacked_y
xs_avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs_flat]
x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]
def _create_jaxpr(init):
init_flat, init_tree = tree_flatten(init)
in_flat, in_tree = tree_flatten((init, xs))
carry_avals = tuple(_map(_abstractify, init_flat))
jaxpr, consts, out_tree = _initial_style_jaxpr(
f, in_tree, (*carry_avals, *x_avals), "scan")
out_tree_children = out_tree.children()
if len(out_tree_children) != 2:
msg = "scan body output must be a pair, got {}."
raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals)))
carry_avals_out = jaxpr.out_avals[:out_tree_children[0].num_leaves]
return init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr, consts, out_tree, out_tree_children
# The carry input and output avals must match exactly. However, we want to account for
# the case when init contains weakly-typed values (e.g. Python scalars), with avals that
# may not match the output despite being compatible by virtue of their weak type.
# To do this, we compute the jaxpr in two passes: first with the raw inputs, and if
# necessary, a second time with modified init values.
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
new_init_flat, changed = _promote_weak_typed_inputs(init_flat, carry_avals, carry_avals_out)
if changed:
init = tree_unflatten(init_tree, new_init_flat)
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
in_flat, jaxpr, consts, out_tree, out_tree_children = rest
_check_scan_carry_type(f, init, out_tree_children[0], carry_avals_out)
disallowed_effects = allowed_effects.filter_not_in(jaxpr.effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `scan`: {disallowed_effects}')
out = scan_p.bind(*consts, *in_flat,
reverse=reverse, length=length, jaxpr=jaxpr,
num_consts=len(consts), num_carry=len(init_flat),
linear=(False,) * (len(consts) + len(in_flat)),
unroll=unroll)
return tree_unflatten(out_tree, out)
def _check_scan_carry_type(body_fun, in_carry, out_carry_tree, out_avals):
try:
sig = inspect.signature(body_fun)
except (ValueError, TypeError):
sig = None
carry_name = sig and list(sig.parameters)[0]
if carry_name:
component = lambda p: (f'the input carry component {carry_name}{keystr(p)}'
if p else f'the input carry {carry_name}')
else:
component = lambda p: (f'the input carry at path {keystr(p)}'
if p else 'the input carry')
leaves_and_paths, in_carry_tree = tree_flatten_with_path(in_carry)
paths, in_carry_flat = unzip2(leaves_and_paths)
in_avals = _map(_abstractify, in_carry_flat)
if in_carry_tree != out_carry_tree:
try:
out_carry = tree_unflatten(out_carry_tree, out_avals)
except:
out_carry = None
if out_carry is None:
differences = [f'the input tree structure is:\n{in_carry_tree}\n',
f'the output tree structure is:\n{out_carry_tree}\n']
else:
differences = '\n'.join(
f' * {component(path)} is a {thing1} but the corresponding component '
f'of the carry output is a {thing2}, so {explanation}\n'
for path, thing1, thing2, explanation
in equality_errors(in_carry, out_carry))
raise TypeError(
"Scanned function carry input and carry output must have the same "
"pytree structure, but they differ:\n"
f"{differences}\n"
"Revise the scanned function so that its output is a pair where the "
"first element has the same pytree structure as the first argument."
)
if not all(_map(core.typematch, in_avals, out_avals)):
differences = '\n'.join(
f' * {component(path)} has type {in_aval.str_short()}'
' but the corresponding output carry component has type '
f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}\n'
for path, in_aval, out_aval in zip(paths, in_avals, out_avals)
if not core.typematch(in_aval, out_aval))
raise TypeError(
"Scanned function carry input and carry output must have equal types "
"(e.g. shapes and dtypes of arrays), "
"but they differ:\n"
f"{differences}\n"
"Revise the scanned function so that all output types (e.g. shapes "
"and dtypes) match the corresponding input types."
)
def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str:
assert not core.typematch(a1, a2)
if isinstance(a1, core.ShapedArray) and isinstance(a2, core.ShapedArray):
dtype_mismatch = a1.dtype != a2.dtype
shape_mismatch = a1.shape != a2.shape
return (', so ' * (dtype_mismatch or shape_mismatch) +
'the dtypes do not match' * dtype_mismatch +
' and also ' * (dtype_mismatch and shape_mismatch) +
'the shapes do not match' * shape_mismatch)
return ''
def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear,
f_impl, x_avals, y_avals):
consts, init, xs = split_list(args, [num_consts, num_carry])
carry = init
ys = []
for i in range(length):
i_ = length - i - 1 if reverse else i
x = _map(partial(_index_array, i_), x_avals, xs)
out = f_impl(*consts, *carry, *x)
carry, y = split_list(out, [num_carry])
ys.append(y)
ys = list(reversed(ys)) if reverse else ys
ys = list(zip(*ys))
ys = _map(_stack, y_avals, ys)
return (*carry, *ys)
def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear,
f_impl, x_avals, y_avals):
consts, init, xs = split_list(args, [num_consts, num_carry])
def cond_fun(vals):
i, *_ = vals
return i < length
def body_fun(vals):
[i], carry, ys = split_list(vals, [1, num_carry])
i_ = length - i - 1 if reverse else i
x = _map(partial(_dynamic_index_array, i_), x_avals, xs)
out_flat = f_impl(*consts, *carry, *x)
carry_out, y_updates = split_list(out_flat, [num_carry])
ys_out = _map(partial(_update_array, i_), y_avals, ys, y_updates)
return [i + 1] + carry_out + ys_out
ys_init = _map(partial(_empty_array, length), y_avals)
if length == 0:
return init + ys_init
else:
init_val = [lax._const(length, 0)] + init + ys_init
_, *outs = while_loop(cond_fun, body_fun, init_val)
return outs
def _scan_impl_block_unrolled(*args, reverse, length, num_consts, num_carry,
linear, block_length, f_impl, x_avals, y_avals):
consts, init, xs = split_list(args, [num_consts, num_carry])
num_blocks, rem = divmod(length, block_length)
assert rem == 0
partition = partial(_partition_leading, num_blocks, block_length)
xs_block = _map(partition, x_avals, xs)
prepend_aval = partial(_prepend_dim_to_aval, block_length)
x_block_avals = _map(prepend_aval, x_avals)
y_block_avals = _map(prepend_aval, y_avals)
f_impl_block = partial(
_scan_impl_unrolled, reverse=reverse, length=block_length,
num_consts=num_consts, num_carry=num_carry, linear=linear,
f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
outs = _scan_impl_loop(
*consts, *init, *xs_block, reverse=reverse, length=num_blocks,
num_consts=num_consts, num_carry=num_carry, linear=linear,
f_impl=f_impl_block, x_avals=x_block_avals, y_avals=y_block_avals)
carry, ys_blocks = split_list(outs, [num_carry])
combine = partial(_combine_leading, num_blocks, block_length)
ys = _map(combine, y_avals, ys_blocks)
return (*carry, *ys)
def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
unroll):
_, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
f_impl = core.jaxpr_as_fun(jaxpr)
if unroll == 1:
return _scan_impl_loop(
*args, reverse=reverse, length=length, num_consts=num_consts,
num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals,
y_avals=y_avals)
consts, init, xs = split_list(args, [num_consts, num_carry])
num_blocks, rem = divmod(length, unroll)
length_div = num_blocks * unroll
if rem > 0:
if reverse:
split = partial(_split_leading_dim, rem)
xs_rem, xs = unzip2(_map(split, x_avals, xs))
else:
split = partial(_split_leading_dim, length_div)
xs, xs_rem = unzip2(_map(split, x_avals, xs))
outs = _scan_impl_block_unrolled(
*consts, *init, *xs, reverse=reverse, length=length_div,
num_consts=num_consts, num_carry=num_carry, linear=linear,
block_length=unroll, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
carry, ys = split_list(outs, [num_carry])
if rem > 0:
outs = _scan_impl_unrolled(
*consts, *carry, *xs_rem, reverse=reverse, length=rem,
num_consts=num_consts, num_carry=num_carry, linear=linear,
f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
carry, ys_rem = split_list(outs, [num_carry])
if reverse:
ys = _map(_concatenate, y_avals, ys_rem, ys)
else:
ys = _map(_concatenate, y_avals, ys, ys_rem)
return (*carry, *ys)
def _stack(aval, vals):
vals = [lax.expand_dims(x, (0,)) for x in vals]
return lax.concatenate(vals, 0)
def _concatenate(aval, x1, x2):
return lax.concatenate([x1, x2], 0)
def _split_leading_dim(i, aval, x):
assert x.ndim >= 1
return (slicing.slice_in_dim(x, 0, i),
slicing.slice_in_dim(x, i, x.shape[0]))
def _dynamic_index_array(i, aval, x):
return slicing.dynamic_index_in_dim(x, i, keepdims=False)
def _index_array(i, aval, x):
return slicing.index_in_dim(x, i, keepdims=False)
def _empty_array(sz, aval):
return lax.broadcast(lax.empty(aval.dtype), (sz, *aval.shape))
def _update_array(i, aval, xs, x):
return slicing.dynamic_update_index_in_dim(xs, x, i, 0)
def _partition_leading(sz0, sz1, aval, x):
assert x.ndim >= 1
assert x.shape[0] == sz0 * sz1
return lax.reshape(x, (sz0, sz1, *x.shape[1:]))
def _combine_leading(sz0, sz1, aval, x):
assert x.ndim >= 2
assert x.shape[0] == sz0
assert x.shape[1] == sz1
return lax.collapse(x, 0, 2)
def _prepend_dim_to_aval(sz, aval):
return core.unmapped_aval(sz, core.no_axis_name, 0, aval)
def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
linear, unroll):
carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals)
return carry_avals + ys_avals, jaxpr.effects
def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
linear, unroll):
num_xs = len(jaxpr.in_avals) - num_carry - num_consts
num_ys = len(jaxpr.out_avals) - num_carry
nonzeros = [type(t) is not ad_util.Zero for t in tangents]
const_nz, init_nz, xs_nz = split_list(nonzeros, [num_consts, num_carry])
# Fixpoint computation of which carry are not ad.zero: either
# non-zero from init, or the carry out is non-zero. Each iteration promotes
# at least one carry to non-zero. We need at most len(carry) iterations,
# but we need one last iteration to prepare the jaxpr based on the final
# carry_nz.
carry_nz = init_nz
for _ in range(1 + len(carry_nz)):
nonzeros = const_nz + carry_nz + xs_nz
jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr(
jaxpr, nonzeros, instantiate=carry_nz + [False] * num_ys)
carry_nz_out, _ = nonzeros_out[:num_carry], nonzeros_out[num_carry:]
if carry_nz_out == carry_nz:
break
else:
carry_nz = _map(operator.or_, carry_nz, carry_nz_out)
else:
assert False, "Fixpoint not reached"
tangents = [ad.instantiate_zeros(t) if nz else t
for t, nz in zip(tangents, nonzeros)]
consts, init, xs = split_list(primals, [num_consts, num_carry])
all_tangents = split_list(tangents, [num_consts, num_carry])
consts_dot, init_dot, xs_dot = _map(_prune_zeros, all_tangents)
jaxpr_jvp_rearranged = ad.rearrange_binders(
jaxpr_jvp,
[num_consts, num_carry, num_xs], [len(consts_dot), len(init_dot), len(xs_dot)],
[num_carry, num_ys], [len(init_dot), sum(nonzeros_out) - len(init_dot)])
consts_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
jaxpr_jvp_linear = tuple(consts_linear + [True] * len(consts_dot)
+ init_linear + [True] * len(init_dot)
+ xs_linear + [True] * len(xs_dot))
out_flat = scan_p.bind(
*(consts + consts_dot + init + init_dot + xs + xs_dot),
reverse=reverse, length=length, jaxpr=jaxpr_jvp_rearranged,
num_consts=num_consts + len(consts_dot),
num_carry=num_carry + len(init_dot),
linear=jaxpr_jvp_linear, unroll=unroll)
carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys])
primals_out = carry + ys
tangents_out_iter = iter(carry_dot + ys_dot)
tangents_out = [next(tangents_out_iter) if nz else ad_util.Zero.from_value(p)
for p, nz in zip(primals_out, nonzeros_out)]
return primals_out, tangents_out
def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
jaxpr, linear, unroll):
num_ys = len(jaxpr.out_avals) - num_carry
unknowns = [not t.pval.is_known() for t in tracers]
const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])
# Fixpoint computation of which carry elements are unknown. Each iteration
# promotes at least one carry to unknown. We need at most len(carry)
# iterations, but we need one last iteration to prepare the jaxpr based on the
# final carry_uk.
carry_uk = init_uk
for _ in range(1 + len(carry_uk)):
unknowns = const_uk + carry_uk + xs_uk
jaxpr_known, jaxpr_unknown, out_uk, res_avals = pe.partial_eval_jaxpr_nounits(
jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
carry_uk_out, ys_uk = split_list(out_uk, [num_carry])
if carry_uk_out == carry_uk:
break
else:
carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
else:
assert False, "Fixpoint not reached"
num_res = len(res_avals)
del res_avals, carry_uk_out
# Instantiate those inputs which must be treated as unknown from the fixpoint.
tracers = [trace.instantiate_const(t) if uk else t
for t, uk in zip(tracers, unknowns)]
# The residual inputs and outputs of the jaxprs produced haven't yet been
# adapted to the scan calling convention; in particular, jaxpr_known has its
# residual outputs all at the end, meaning they're extensive outputs (which is
# fully general but may be wasteful for residuals which are loop-invariant)
# while jaxpr_unknown has its corresponding residual inputs at the front (just
# as a convention with partial_eval_jaxpr_nounits), making them constant
# inputs. To make them consistent, we move the residual inputs on
# jaxpr_unknown to the end, even though we may move some back in the sequel.
jaxpr_unknown = pe.move_binders_to_back(
jaxpr_unknown, [True] * num_res + [False] * sum(unknowns))
# At this point, all residuals are treated as extensive outputs of jaxpr_known
# (and extensive inputs to jaxpr_unknown). But residuals that are loop-
# invariant can be hoisted out of the scan, rather than letting them get
# broadcast (as in e.g. scanning multiplication by a constant matrix; we don't
# want to broadcast the matrix!). So, outside the loop we perform a partial
# evaluation with known 'const' inputs (but all other inputs unknown).
const_pvals = [pe.PartialVal.known(t.pval.get_known())
for t in tracers[:num_consts] if t.pval.is_known()]
other_pvals = [pe.PartialVal.unknown(aval)
for aval in jaxpr_known.in_avals[len(const_pvals):]]
with source_info_util.reset_name_stack():
jaxpr_known_, invar_pvals_out, jaxpr_known_consts = pe.trace_to_jaxpr_nounits(
lu.wrap_init(core.jaxpr_as_fun(jaxpr_known)), const_pvals + other_pvals,
instantiate=[True] * (len(out_uk) - sum(out_uk)) + [False] * num_res)
jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ())
# The above trace_to_jaxpr_nounits call computed loop-invariant residuals
# (known values in invar_pvals_out) and also computed loop-invariant values
# needed by the new jaxpr_known (in jaxpr_known_consts, which replace the
# previous consts). We need to collect the computed inteisive residuals, and
# move corresponding intensive residual binders in jaxpr_unknown to the front.
res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:]
intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()]
jaxpr_unknown = pe.move_binders_to_front(
jaxpr_unknown,
[False] * sum(unknowns) + [pval.is_known() for pval in res_pvals])
del const_pvals, other_pvals, invar_pvals_out, jaxpr_known_, res_pvals
# We use `jaxpr_known_consts` when we call scan_p.bind with jaxpr_known, and
# we use `intensive_res` when we build the jaxpr eqn with jaxpr_unknown.
# As another optimization, for any extensive inputs that are just forwarded to
# extensive outputs, to avoid a copy (which would be looping over
# dynamic-update-slice) we'd rather forward the input tracer/value. That means
# pruning some outputs from jaxpr_known here, and updating `out_flat` below.
fwds_known = pe._jaxpr_forwarding(jaxpr_known.jaxpr)
# Prune fwds_known to include only extensive input to extensive output.
fwds_known = [in_idx if out_idx >= num_carry - sum(carry_uk) and
in_idx is not None and
in_idx >= len(jaxpr_known_consts) + num_carry - sum(carry_uk)
else None for out_idx, in_idx in enumerate(fwds_known)]
# Drop any extensive output we can instead get by forwarding an input.
# TODO(mattjj): use pe.dce_jaxpr here, though need a fixpoint
jaxpr_known_, () = jaxpr_known.jaxpr, jaxpr_known.consts
jaxpr_known_ = jaxpr_known_.replace(
outvars=[x for x, i in zip(jaxpr_known_.outvars, fwds_known) if i is None])
jaxpr_known = core.ClosedJaxpr(jaxpr_known_, ())
del jaxpr_known_
# We use `fwds_known` below when forming the output of scanning jaxpr_known.
# Run the known part of the scan (if it has any outputs or effects).
known_inputs = (list(jaxpr_known_consts) +
[t.pval.get_known() for t in tracers[num_consts:]
if t.pval.is_known()])
if not jaxpr_known.out_avals and not jaxpr_known.effects:
out_known = []
else:
linear_known = [False] * len(known_inputs) # conservative!
out_known = scan_p.bind(
*known_inputs, reverse=reverse, length=length, jaxpr=jaxpr_known,
num_consts=len(jaxpr_known_consts), num_carry=num_carry - sum(carry_uk),
linear=tuple(linear_known), unroll=unroll)
del linear_known
# Complete the known output by filling in forwarded values using fwds_known.
out_known_iter = iter(out_known)
out_known = [next(out_known_iter) if f is None
else _maybe_put(known_inputs[f]) for f in fwds_known]
assert next(out_known_iter, None) is None
del known_inputs, out_known_iter
# Split known outputs from residuals.
out_known, extensive_res = split_list(out_known, [len(out_uk) - sum(out_uk)])
assert len(intensive_res) + len(extensive_res) == num_res
# Create input tracers for jaxpr_unknown bind.
unknown_inputs = [t for t in tracers if not t.pval.is_known()]
intensive_res = _map(trace.new_instantiated_const, intensive_res)
extensive_res = _map(trace.new_instantiated_const, extensive_res)
# Create output tracers for jaxpr_unknown bind, adapting extensive shapes.
carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)])
ys_avals = [core.unmapped_aval(length, core.no_axis_name, 0, y_aval)
for y_aval in y_avals]
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
for a in itertools.chain(carry_avals, ys_avals)]
del carry_avals, y_avals
# Create equation.
linear_unknown = tuple([False] * len(intensive_res) +
[l for l, uk in zip(linear, unknowns) if uk] +
[False] * len(extensive_res))
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
source = source_info_util.current().replace(name_stack=name_stack)
assert len(out_tracers) == len(jaxpr_unknown.out_avals)
eqn = pe.new_eqn_recipe([*intensive_res, *unknown_inputs, *extensive_res],
out_tracers, scan_p,
dict(reverse=reverse, length=length, unroll=unroll,
jaxpr=jaxpr_unknown, linear=linear_unknown,
num_consts=len(intensive_res) + sum(const_uk),
num_carry=sum(carry_uk)),
jaxpr_unknown.effects, source)
for t in out_tracers: t.recipe = eqn
# Merge known and unknown outputs into final result.
return util.merge_lists(out_uk, out_known, out_tracers)
def _maybe_put(x):
if isinstance(x, np.ndarray):
return dispatch._put_x(
x, jax.sharding.SingleDeviceSharding(jax.devices('cpu')[0]),
shaped_abstractify(x), False)
else:
return x
def _scan_transpose(reduce_axes, cts, *args, reverse, length, num_consts,
num_carry, jaxpr, linear, unroll):
# we've only implemented transposing scans with specific lin/nonlin patterns
consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
num_ires = len(consts_lin) - sum(consts_lin)
num_eres = len(xs_lin) - sum(xs_lin)
if consts_lin != [False] * num_ires + [True] * (len(consts_lin) - num_ires):
raise NotImplementedError
if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres:
raise NotImplementedError
if not all(init_lin):
pass # TODO(mattjj): error check https://github.com/google/jax/issues/1963
consts, _, xs = split_list(args, [num_consts, num_carry])
ires, _ = split_list(consts, [num_ires])
_, eres = split_list(xs, [sum(xs_lin)])
assert not any(ad.is_undefined_primal(r) for r in ires)
assert not any(ad.is_undefined_primal(r) for r in eres)
carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals)
ct_carry, ct_ys = split_list(cts, [num_carry])
ct_carry = _map(ad.instantiate_zeros_aval, carry_avals, ct_carry)
ct_ys = _map(ad.instantiate_zeros_aval, ys_avals, ct_ys)
ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[num_ires:num_consts])
# jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b])
# jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a])
jaxpr_trans = _transpose_scan_jaxpr(
num_ires, num_consts - num_ires, num_eres, jaxpr, reduce_axes)
linear_trans = ([False] * num_ires +
[True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) +
[False] * num_eres)
outs = scan_p.bind(
*(ires + ct_consts + ct_carry + ct_ys + eres), reverse=not reverse,
length=length, jaxpr=jaxpr_trans, num_consts=num_ires,
num_carry=num_consts-num_ires+num_carry, linear=tuple(linear_trans),
unroll=unroll)
ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry])
return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres
# transpose_scan_jaxpr :: ([res1, c, a, res2] -> b)
# -> ([res1, CT c, CT b, res2] -> [CT c, CT a])
def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr, reduce_axes):
num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2
# TODO: allow input cotangent avals to be batched relative to jaxpr.in_avals
# if an axis isn't reduced
res1_avals, c_avals, a_avals, res2_avals = split_list(
jaxpr.in_avals, [num_res1, num_c, num_a])
num_b = len(jaxpr.out_avals)
b_avals = list(jaxpr.out_avals)
@lu.wrap_init
def transposed(*res1_cbar_bbar_res2):
res1, c_bar, b_bar, res2 = split_list(
res1_cbar_bbar_res2, [num_res1, num_c, num_b])
primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] +
[ad.UndefinedPrimal(aval) for aval in a_avals] + res2)
cbar_abar = ad.backward_pass(jaxpr.jaxpr, reduce_axes, False, jaxpr.consts,
primals, b_bar)
_, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a])
a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar)
c_bar = _map(ad.instantiate_zeros_aval, c_avals,
_map(ad.add_tangents, c_bar, new_c_bar))
return c_bar + a_bar
return _make_closed_jaxpr(transposed, res1_avals + c_avals + b_avals + res2_avals)
def _scan_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args,
dims, reverse, length,
jaxpr, num_consts, num_carry, linear, unroll):
num_ys = len(jaxpr.out_avals) - num_carry
orig_batched = [d is not batching.not_mapped for d in dims]
const_batched, init_batched, xs_batched = split_list(orig_batched, [num_consts, num_carry])
# Fixpoint computation of which carry are batched: either
# batched from init, or the carry out is batched. Each iteration promotes
# at least one carry to batched. We need at most len(carry) iterations,
# but we need one last iteration to prepare the jaxpr based on the final
# carry_batched.
carry_batched = init_batched
for _ in range(1 + len(carry_batched)):
batched = const_batched + carry_batched + xs_batched
jaxpr_batched, batched_out = batching.batch_jaxpr(
jaxpr, axis_size, batched,
instantiate=carry_batched + [False] * num_ys,
axis_name=axis_name,
spmd_axis_name=spmd_axis_name,
main_type=main_type)
carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:]
if carry_batched_out == carry_batched:
break
else:
carry_batched = _map(operator.or_, carry_batched, carry_batched_out)
else:
assert False, "Fixpoint not reached"
consts, init, xs = split_list(args, [num_consts, num_carry])
consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry])
new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
else x for x, d in zip(consts, consts_bdims)]
new_init = [batching.broadcast(x, axis_size, 0) if now_batched and not was_batched
else batching.moveaxis(x, d, 0) if now_batched else x
for x, d, was_batched, now_batched in
zip(init, init_bdims, init_batched, carry_batched)]
new_xs = [batching.moveaxis(x, d, 1) if d is not batching.not_mapped and d != 1
else x for x, d in zip(xs, xs_bdims)]
new_args = new_consts + new_init + new_xs
outs = scan_p.bind(
*new_args, reverse=reverse, length=length, jaxpr=jaxpr_batched,
num_consts=num_consts, num_carry=num_carry, linear=linear, unroll=unroll)
carry_bdims = [0 if b else batching.not_mapped for b in carry_batched]
ys_bdims = [1 if b else batching.not_mapped for b in ys_batched]
return outs, carry_bdims + ys_bdims
def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params):
padded_jaxpr = core.ClosedJaxpr(*pe.pad_jaxpr(jaxpr.jaxpr, jaxpr.consts))
return scan_p.bind(*args, jaxpr=padded_jaxpr, **params)
def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn
) -> Tuple[List[bool], core.JaxprEqn]:
jaxpr = eqn.params['jaxpr']
num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
num_xs = len(jaxpr.in_avals) - num_consts - num_carry
used_carry_out, used_extensive_out = split_list(used_outputs, [num_carry])
for i in range(1 + num_carry):
used_outputs = used_carry_out + used_extensive_out
jaxpr_dce, used_inputs = pe.dce_jaxpr(
jaxpr.jaxpr, used_outputs,
instantiate=[False] * num_consts + used_carry_out + [False] * num_xs)
used_consts, used_carry_in, used_extensive_in = \
split_list(used_inputs, [num_consts, num_carry])
if list(used_carry_in) == list(used_carry_out):
break
else:
used_carry_out = _map(operator.or_, used_carry_out, used_carry_in)
else:
assert False, "Fixpoint not reached"
if config.jax_enable_checks: core.check_jaxpr(jaxpr.jaxpr)
new_linear = [l for l, u in zip(eqn.params['linear'], used_inputs) if u]
new_params = dict(eqn.params, num_consts=sum(used_consts),
num_carry=sum(used_carry_in), linear=tuple(new_linear),
jaxpr=core.ClosedJaxpr(jaxpr_dce, jaxpr.consts))
# TODO(mattjj,sharadmv): don't assume effects are never DCE'd?
new_invars = [v for v, used in zip(eqn.invars, used_inputs) if used]
new_outvars = [v for v, used in zip(eqn.outvars, used_outputs) if used]
_, new_effects = eqn.primitive.abstract_eval(*[v.aval for v in new_invars],
**new_params)
new_eqn = pe.new_jaxpr_eqn(
new_invars,
new_outvars,
eqn.primitive, new_params, new_effects, eqn.source_info)
assert len(new_eqn.invars ) == len(new_params['jaxpr'].in_avals )
assert len(new_eqn.outvars) == len(new_params['jaxpr'].out_avals)
return used_inputs, new_eqn
# TODO(mattjj): de-duplicate code with _scan_partial_eval
def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
jaxpr = eqn.params['jaxpr']
num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
num_ys = len(jaxpr.out_avals) - num_carry
# Fixpoint (trivial on 'inst_in', since we might as well make all inputs
# available as DCE can subsequently prune any unused ones)
const_uk, carry_uk, xs_uk = split_list(unks_in, [num_consts, num_carry])
for _ in range(1 + len(carry_uk)):
unks_in = const_uk + carry_uk + xs_uk
jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res = \
pe.partial_eval_jaxpr_custom(
jaxpr.jaxpr, in_unknowns=unks_in, in_inst=True,
ensure_out_unknowns=carry_uk + [False] * num_ys,
ensure_out_inst=True, saveable=saveable)
carry_uk_out, ys_uk = split_list(unks_out, [num_carry])
if carry_uk_out == carry_uk:
break
else:
carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
else:
assert False, "Fixpoint not reached"
jaxpr_known = core.ClosedJaxpr(jaxpr_known_ , jaxpr.consts)
jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, jaxpr.consts)
# Move all residual binders to the back of jaxpr_staged so they're extensive.
# TODO(mattjj): make jaxpr_staged only take instantiated inputs
res_avals = jaxpr_staged.in_avals[:num_res]
jaxpr_staged = pe.move_binders_to_back(
jaxpr_staged, [True] * num_res + [False] * len(jaxpr.in_avals))
# Instantiate all inputs (b/c jaxpr_staged takes all inputs, corresponding to
# passing in_inst argument to partial_eval_jaxpr_custom above).
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
if type(x) is core.Var and not inst]
inst_in = [True] * len(inst_in)
# As an optimization, hoist loop-invariant residuals out of the loop rather
# than using extensive outputs for them. See _scan_partial_eval for comments.
num_const_known = len(const_uk) - sum(const_uk)
num_carry_known = len(carry_uk) - sum(carry_uk)
num_xs_known = len( xs_uk) - sum( xs_uk)
jaxpr_known_hoist, jaxpr_known_loop, loop_dep, consts_known_lp_avals = \
pe.partial_eval_jaxpr_nounits(
jaxpr_known,
[False] * num_const_known + [True] * (num_carry_known + num_xs_known),
[True] * (len(unks_out) - sum(unks_out)) + [False] * num_res)
# jaxpr_known_hoist produces intensive residuals followed by the constants for
# jaxpr_known_loop. We adjust jaxpr_staged to accept intensive res as consts.
_, loop_dep_res = split_list(loop_dep, [len(loop_dep) - num_res])
jaxpr_staged = pe.move_binders_to_front(
jaxpr_staged, [False] * sum(inst_in) + _map(operator.not_, loop_dep_res))
num_intensive_res = len(loop_dep_res) - sum(loop_dep_res)
del loop_dep, num_carry_known, num_xs_known, const_uk
# Create residual variables.
intensive_avals, ext_avals_mapped = partition_list(loop_dep_res, res_avals)
ext_avals = [core.unmapped_aval(eqn.params['length'], core.no_axis_name, 0, a)
for a in ext_avals_mapped]
newvar = core.gensym()
intensive_res = _map(newvar, intensive_avals)
extensive_res = _map(newvar, ext_avals)
# Create known eqn, which is a call_p combining evaluation of
# jaxpr_known_hoist and a scan of jaxpr_known_loop.
ins_known, _ = partition_list(unks_in, eqn.invars)
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
# jaxpr_known_loop takes as input constants output as res by jaxpr_known_hoist
# (corresponding to consts_known_lp_avals) followed by known carry and xs.
linear_known_ = [l for l, uk in zip(eqn.params['linear'], unks_in) if not uk]
_, linear_known_ = split_list(linear_known_, [num_const_known])
linear_known = [False] * len(consts_known_lp_avals) + linear_known_
params_known = dict(eqn.params, jaxpr=jaxpr_known_loop,
num_consts=len(consts_known_lp_avals),
num_carry=len(carry_uk)-sum(carry_uk),
linear=tuple(linear_known))
@lu.wrap_init
def known(*ins_known):
consts_known_hoist, ins_known_lp = split_list(ins_known, [num_const_known])
out_hoist = core.jaxpr_as_fun(jaxpr_known_hoist)(*consts_known_hoist)
intensive_res, consts_known_lp = split_list(out_hoist, [num_intensive_res])
out_loop = scan_p.bind(*consts_known_lp, *ins_known_lp, **params_known)
return [*intensive_res, *out_loop]
call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
known, [v.aval for v in ins_known])
call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
eqn_known = pe.new_jaxpr_eqn(
ins_known, [*intensive_res, *out_binders_known, *extensive_res],
core.closed_call_p, dict(call_jaxpr=call_jaxpr), call_jaxpr.effects,
eqn.source_info)
# Create the staged eqn.
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
linear_staged = ([False] * len(intensive_res) + list(eqn.params['linear']) +
[False] * len(extensive_res))
params_staged = dict(eqn.params, jaxpr=jaxpr_staged,
num_consts=len(intensive_res) + eqn.params['num_consts'],
linear=tuple(linear_staged))
eqn_staged = pe.new_jaxpr_eqn([*intensive_res, *eqn.invars, *extensive_res],
out_binders_staged, eqn.primitive,
params_staged, jaxpr_staged.effects,
eqn.source_info)
new_vars = [*new_inst, *intensive_res, *extensive_res]
return eqn_known, eqn_staged, unks_out, inst_out, new_vars
def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts,
num_carry, jaxpr, linear, unroll):
if not bind_time:
_, *in_atoms = in_atoms
avals = [x.aval for x in in_atoms]
tc = partial(_typecheck_param, 'scan')
tc(reverse, 'reverse', 'bool', type(reverse) is bool)
tc(num_consts, 'num_consts', 'non-negative int',
type(num_consts) is int and num_consts >= 0)
tc(num_carry, 'num_carry', 'non-negative int',
type(num_carry) is int and num_carry >= 0)
tc(jaxpr, 'jaxpr', 'ClosedJaxpr', type(jaxpr) is core.ClosedJaxpr)
tc(linear, 'linear', 'tuple of bool',
type(linear) is tuple and all(type(x) is bool for x in linear))
tc(unroll, 'unroll', 'positive int', type(unroll) is int and unroll > 0)
tc(length, 'length', 'non-negative int', core.greater_equal_dim(length, 0))
if len(linear) != len(avals):
raise core.JaxprTypeError(
f'scan param linear has length {len(linear)} for {len(avals)} operands')
const_avals, init_avals, x_avals = split_list(avals, [num_consts, num_carry])
const_avals_jaxpr, init_avals_jaxpr, x_avals_jaxpr = split_list(
jaxpr.in_avals, [num_consts, num_carry])
carry_avals_jaxpr, y_avals_mapped = split_list(jaxpr.out_avals, [num_carry])
x_avals_mapped = _map(partial(core.mapped_aval, length, 0), x_avals)
y_avals = [core.unmapped_aval(length, core.no_axis_name, 0, a)
for a in y_avals_mapped]
if not all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)):
raise core.JaxprTypeError(
f'scan input carry input and output types mismatch: '
f'\n{_avals_short(init_avals_jaxpr)}\nvs\n{_avals_short(carry_avals_jaxpr)}')
if not all(_map(core.typecompat, const_avals_jaxpr, const_avals)):
raise core.JaxprTypeError(
f'scan jaxpr takes input const types\n{_avals_short(const_avals_jaxpr)},\n'
f'called with consts of type\n{_avals_short(const_avals)}')
if not all(_map(core.typecompat, init_avals_jaxpr, init_avals)):
raise core.JaxprTypeError(
f'scan jaxpr takes input carry types\n{_avals_short(init_avals_jaxpr)},\n'
f'called with initial carry of type\n{_avals_short(init_avals)}')
if not all(_map(core.typecompat, x_avals_jaxpr, x_avals_mapped)):
raise core.JaxprTypeError(
f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n'
f'called with sequence whose items have type\n{_avals_short(x_avals_mapped)}')
return [*init_avals, *y_avals], jaxpr.effects
def _scan_pp_rule(eqn, context, settings):
printed_params = dict(eqn.params)
del printed_params['linear']
if eqn.params['num_consts'] + eqn.params['num_carry'] == len(eqn.invars):
del printed_params['length']
if printed_params['unroll'] == 1:
del printed_params['unroll']
if printed_params['num_carry'] == 0:
del printed_params['num_carry']
if printed_params['num_consts'] == 0:
del printed_params['num_consts']
if not printed_params['reverse']:
del printed_params['reverse']
return core._pp_eqn(eqn.replace(params=printed_params), context, settings)
def scan_bind(*args, **params):
if config.jax_enable_checks:
avals = _map(core.get_aval, args)
in_atoms = [core.Var(0, '', a) for a in avals] # dummies
_scan_typecheck(True, *in_atoms, **params)
core.check_jaxpr(params['jaxpr'].jaxpr)
return core.AxisPrimitive.bind(scan_p, *args, **params)
scan_p = core.AxisPrimitive("scan")
scan_p.multiple_results = True
scan_p.def_custom_bind(scan_bind)
scan_p.def_impl(partial(dispatch.apply_primitive, scan_p))
scan_p.def_effectful_abstract_eval(_scan_abstract_eval)
ad.primitive_jvps[scan_p] = _scan_jvp
ad.reducing_transposes[scan_p] = _scan_transpose
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
xla.register_initial_style_primitive(scan_p)
mlir.register_lowering(scan_p,
mlir.lower_fun(_scan_impl, multiple_results=True))
batching.axis_primitive_batchers[scan_p] = partial(_scan_batching_rule, None)
batching.spmd_axis_primitive_batchers[scan_p] = _scan_batching_rule
core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom
pe.padding_rules[scan_p] = _scan_padding_rule
pe.dce_rules[scan_p] = _scan_dce_rule
# TODO(mattjj,frostig): un-comment this pp rule
# core.pp_eqn_rules[scan_p] = _scan_pp_rule
### while_loop
@api_boundary
def while_loop(cond_fun: Callable[[T], BooleanNumeric],
body_fun: Callable[[T], T],
init_val: T) -> T:
"""Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.
The `Haskell-like type signature`_ in brief is
.. code-block:: haskell
while_loop :: (a -> Bool) -> (a -> a) -> a -> a
The semantics of ``while_loop`` are given by this Python implementation::
def while_loop(cond_fun, body_fun, init_val):
val = init_val
while cond_fun(val):
val = body_fun(val)
return val
Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
to a single WhileOp. That makes it useful for reducing compilation times
for jit-compiled functions, since native Python loop constructs in an ``@jit``
function are unrolled, leading to large XLA computations.
Also unlike the Python analogue, the loop-carried value ``val`` must hold a
fixed shape and dtype across all iterations (and not just be consistent up to
NumPy rank/shape broadcasting and dtype promotion rules, for example). In
other words, the type ``a`` in the type signature above represents an array
with a fixed shape and dtype (or a nested tuple/list/dict container data
structure with a fixed structure and arrays with fixed shape and dtype at the
leaves).
Another difference from using Python-native loop constructs is that
``while_loop`` is not reverse-mode differentiable because XLA computations
require static bounds on memory requirements.
.. note::
:py:func:`while_loop` compiles ``cond_fun`` and ``body_fun``, so while it
can be combined with :py:func:`jit`, it's usually unnecessary.
Args:
cond_fun: function of type ``a -> Bool``.
body_fun: function of type ``a -> a``.
init_val: value of type ``a``, a type that can be a scalar, array, or any
pytree (nested Python tuple/list/dict) thereof, representing the initial
loop carry value.
Returns:
The output from the final iteration of body_fun, of type ``a``.
.. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
"""
if not (callable(body_fun) and callable(cond_fun)):
raise TypeError("lax.while_loop: body_fun and cond_fun arguments should be callable.")
if config.jax_disable_jit:
try:
val = init_val
while cond_fun(val):
val = body_fun(val)
return val
except core.ConcretizationTypeError:
# Can't run this while_loop in Python (e.g. because there's a vmap
# transformation on it), so we fall back to the primitive version.
pass
def _create_jaxpr(init_val):
init_vals, in_tree = tree_flatten((init_val,))
init_avals = tuple(_map(_abstractify, init_vals))
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
cond_fun, in_tree, init_avals, "while_cond")
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
body_fun, in_tree, init_avals, "while_loop")
if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
msg = "cond_fun must return a boolean scalar, but got pytree {}."
raise TypeError(msg.format(cond_tree))
pred_aval = cond_jaxpr.out_avals[0]
if (not isinstance(pred_aval, ShapedArray)
or pred_aval.strip_weak_type().strip_named_shape() != ShapedArray((), np.bool_)):
msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
raise TypeError(msg.format(cond_jaxpr.out_avals))
return init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts, body_consts, body_tree
# The body input and output avals must match exactly. However, we want to account for
# the case when init contains weakly-typed values (e.g. Python scalars), with avals that
# may not match the output despite being compatible by virtue of their weak type.
# To do this, we compute the jaxpr in two passes: first with the raw inputs, and if
# necessary, a second time with modified init values.
init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
new_init_vals, changed = _promote_weak_typed_inputs(init_vals, init_avals, body_jaxpr.out_avals)
if changed:
new_init_val, = tree_unflatten(in_tree, new_init_vals)
init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(new_init_val)
cond_jaxpr, cond_consts, body_consts, body_tree = rest
in_tree_children = in_tree.children()
assert len(in_tree_children) == 1
_check_tree_and_avals("body_fun output and input",
body_tree, body_jaxpr.out_avals,
in_tree_children[0], init_avals)
effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
disallowed_effects = allowed_effects.filter_not_in(effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `while`: {disallowed_effects}')
outs = while_p.bind(*cond_consts, *body_consts, *init_vals,
cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
return tree_unflatten(body_tree, outs)
def _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts
) -> effects.Effects:
joined_effects = set()
for eff in cond_jaxpr.effects:
if isinstance(eff, effects.JaxprInputEffect):
index = eff.input_index
if index >= cond_nconsts:
index += body_nconsts
eff = eff.replace(input_index=index)
joined_effects.add(eff)
for eff in body_jaxpr.effects:
if isinstance(eff, effects.JaxprInputEffect):
index = eff.input_index + cond_nconsts
eff = eff.replace(input_index=index)
joined_effects.add(eff)
return joined_effects
def _while_loop_abstract_eval(*avals, cond_jaxpr, body_jaxpr, body_nconsts,
cond_nconsts):
del avals
joined_effects = _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts,
cond_nconsts)
disallowed_effects = allowed_effects.filter_not_in(joined_effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `while`: {disallowed_effects}')
return _map(raise_to_shaped, body_jaxpr.out_avals), joined_effects
def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type,
args, dims, cond_nconsts, cond_jaxpr,
body_nconsts, body_jaxpr):
from jax._src.callback import _IOEffect, _OrderedIOEffect
if any(eff in branch.effects for eff in [_IOEffect, _OrderedIOEffect]
for branch in [body_jaxpr, cond_jaxpr]):
raise NotImplementedError(
"IO effect not supported in vmap-of-while.")
orig_batched = [d is not batching.not_mapped for d in dims]
cconst_bat, bconst_bat, init_bat = split_list(orig_batched, [cond_nconsts, body_nconsts])
cconsts, bconsts, init = split_list(args, [cond_nconsts, body_nconsts])
cconst_dims, bconst_dims, init_dims = split_list(dims, [cond_nconsts, body_nconsts])
carry_bat = init_bat
# Fixpoint computation of which carry are batched: either
# batched from init, or the carry out is batched. Each iteration promotes
# at least one carry to batched. We need at most len(carry) iterations to
# reach a fixpoint.
for _ in range(1 + len(carry_bat)):
_, carry_bat_out = batching.batch_jaxpr(
body_jaxpr, axis_size, bconst_bat + carry_bat, instantiate=carry_bat,
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
if carry_bat == carry_bat_out:
break
carry_bat = safe_map(operator.or_, carry_bat, carry_bat_out)
else:
assert False, "Fixpoint not reached"
# Knowing how the carry is batched now, we can determine if the predicate is
# batched.
_, (pred_bat,) = batching.batch_jaxpr(
cond_jaxpr, axis_size, cconst_bat + carry_bat, instantiate=False,
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
if pred_bat:
# If the predicate is batched, we have to batch *all* of the carry
# regardless of if the body needs it.
carry_bat = [True] * len(carry_bat)
carry_dims = [0] * len(carry_bat)
body_jaxpr_batched, _ = batching.batch_jaxpr_axes(
body_jaxpr, axis_size, bconst_dims + carry_dims,
carry_dims, axis_name=axis_name, spmd_axis_name=spmd_axis_name,
main_type=main_type)
cond_jaxpr_batched, _ = batching.batch_jaxpr_axes(
cond_jaxpr, axis_size, cconst_dims + carry_dims, [0],
axis_name=axis_name, spmd_axis_name=spmd_axis_name,
main_type=main_type)
else:
# If the predicate is not batched, we can look at the `cond_jaxpr`'s out
# shape to determine the rank of the predicate. From this rank we pick the
# dims of the carry to be batched to ensure that the predicate shape is a
# prefix of the carry in and out shapes. We can then batch the `body_jaxpr`
# according to these new batch dims.
cond_rank = len(cond_jaxpr.out_avals[0].shape)
carry_dims = [cond_rank if b else None for b in carry_bat]
body_jaxpr_batched, _ = batching.batch_jaxpr_axes(
body_jaxpr, axis_size, bconst_dims + carry_dims, carry_dims,
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
# Now we need to rebatch the `cond_jaxpr` according to the new dims of the
# carry.
cond_jaxpr_batched, _ = batching.batch_jaxpr_axes(
cond_jaxpr, axis_size, cconst_dims + carry_dims, (None,),
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
# To prepare the `init` to the `while_p`, we broadcast values if they are
# unbatched and need to have an out axis. If their current batch axis does not
# match the one it needs to be for the translation rule to work, we move it
# into place.
new_init = []
for x, old_axis, new_axis in zip(init, init_dims, carry_dims):
if old_axis is batching.not_mapped and new_axis is not batching.not_mapped:
new_init.append(batching.broadcast(x, axis_size, new_axis))
elif old_axis is batching.not_mapped and new_axis is batching.not_mapped:
new_init.append(x)
else:
assert new_axis is not batching.not_mapped
new_init.append(batching.moveaxis(x, old_axis, new_axis))
outs = while_p.bind(*(cconsts + bconsts + new_init),
cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
return outs, carry_dims
def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts,
body_jaxpr):
nonzeros = [type(t) is not ad_util.Zero for t in tangents]
cconst_nz, bconst_nz, init_nz = split_list(nonzeros, [cond_nconsts, body_nconsts])
carry_nz = init_nz
for _ in range(1 + len(carry_nz)):
body_nonzeros = bconst_nz + carry_nz
body_jvp, nonzeros_out = ad.jvp_jaxpr(
body_jaxpr, body_nonzeros, instantiate=carry_nz)
if nonzeros_out == carry_nz:
break
carry_nz = _map(operator.or_, carry_nz, nonzeros_out)
else:
assert False, "Fixpoint not reached"
nonzeros = cconst_nz + body_nonzeros
tangents = [ad.instantiate_zeros(t) if nz else t
for t, nz in zip(tangents, nonzeros)]
cconst, bconst, init = split_list(primals, [cond_nconsts, body_nconsts])
_, bconst_dot, init_dot = split_list(tangents, [cond_nconsts, body_nconsts])
bconst_dot = _prune_zeros(bconst_dot)
init_dot = _prune_zeros(init_dot)
num_carry = len(primals) - cond_nconsts - body_nconsts
body_jvp_rearranged = ad.rearrange_binders(
body_jvp,
[body_nconsts, num_carry], [len(bconst_dot), len(init_dot)],
[num_carry], [len(init_dot)])
newvar = core.gensym([cond_jaxpr.jaxpr])
invars_aug = (
cond_jaxpr.jaxpr.invars + [newvar(core.get_aval(x)) for x in init_dot])
cond_jaxpr_augmented = core.Jaxpr(cond_jaxpr.jaxpr.constvars,
invars_aug,
cond_jaxpr.jaxpr.outvars,
cond_jaxpr.jaxpr.eqns,
cond_jaxpr.jaxpr.effects)
cond_jaxpr_augmented = core.ClosedJaxpr(cond_jaxpr_augmented, cond_jaxpr.consts)
out = while_p.bind(
*(cconst + bconst + bconst_dot + init + init_dot),
cond_nconsts=cond_nconsts,
cond_jaxpr=cond_jaxpr_augmented,
body_nconsts=len(bconst) + len(bconst_dot),
body_jaxpr=body_jvp_rearranged)
out_carry, out_carry_dot = split_list(out, [num_carry])
out_tangents_iter = iter(out_carry_dot)
out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
for p, nz in zip(out_carry, nonzeros_out)]
return out_carry, out_tangents
def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts: int,
cond_jaxpr: pe.ClosedJaxpr, body_nconsts: int,
body_jaxpr: pe.ClosedJaxpr) -> Sequence[pe.Tracer]:
# As long as some carry (and hence output) are known and the output of
# `cond_jaxpr` is known, we use a portion of the loop body to compute the
# known outputs of the `while_loop`. For the unknown outputs we generate a
# jaxpr to run the whole while, including recomputing the known parts,
# basically like building in checkpointing/rematieralization. This means that
# we don't actually save any computation by partial evaluation if there are
# unknown outputs.
#
# What this achieves is twofold: jax.linearize works, and we can give a proper
# error for reverse differentiation of `while`.
unknowns = [not t.pval.is_known() for t in tracers]
params = dict(cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr,
body_nconsts=body_nconsts, body_jaxpr=body_jaxpr)
cond_consts_uk, body_consts_uk, carry_init_uk = \
split_list(unknowns, [cond_nconsts, body_nconsts])
# Fixpoint computation of unknown carry. Each iteration promotes at least one
# carry to unknown. We need one last iteration to prepare the jaxpr.
carry_uk = carry_init_uk
for _ in range(1 + len(carry_uk)):
body_jaxpr_known, _, carry_out_uk, body_res_avals = pe.partial_eval_jaxpr_nounits( # type: ignore
body_jaxpr, body_consts_uk + carry_uk, instantiate=carry_uk)
if carry_out_uk == carry_uk:
break
else:
carry_uk = _map(operator.or_, carry_uk, carry_out_uk)
else:
assert False, "Fixpoint not reached"
cond_jaxpr_known, _, cond_uk, _ = pe.partial_eval_jaxpr_nounits( # type: ignore
cond_jaxpr, cond_consts_uk + carry_uk, instantiate=False)
if cond_uk[0] or all([not uk for uk in unknowns]) or all(unknowns):
# If conditional is unknown, or all inputs are known, or all are unknown,
# just do the default processing.
return trace.default_process_primitive(while_p, tracers, params)
# Run the known part of the while.
in_consts = [t.pval.get_known() for uk, t in
zip(cond_consts_uk + body_consts_uk + carry_uk, tracers)
if not uk]
cond_nconsts_known = len(cond_consts_uk) - sum(cond_consts_uk)
body_nconsts_known = len(body_consts_uk) - sum(body_consts_uk)
num_known_outs = len(carry_uk) - sum(carry_uk)
# TODO(mattjj): use pe.dce_jaxpr to drop res computations and not just outputs
body_jaxpr_known = body_jaxpr_known.replace(
jaxpr=body_jaxpr_known.jaxpr.replace(
outvars=body_jaxpr_known.jaxpr.outvars[:num_known_outs]))
out_known = while_p.bind(
*in_consts, cond_nconsts=cond_nconsts_known, cond_jaxpr=cond_jaxpr_known,
body_nconsts=body_nconsts_known, body_jaxpr=body_jaxpr_known)
del body_jaxpr_known
# Run the whole while_loop to get all the outputs, then merge with known ones
out_tracers_ = trace.default_process_primitive(while_p, tracers, params)
out_tracers = [t for t, uk in zip(out_tracers_, carry_uk) if uk]
return util.merge_lists(carry_uk, out_known, out_tracers)
# TODO(mattjj): de-duplicate code with _while_partial_eval
def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn):
del saveable # We can't save any residuals anyway (w/o dynamic shapes)!
cond_jaxpr = eqn.params['cond_jaxpr']
cond_nconsts = eqn.params['cond_nconsts']
body_jaxpr = eqn.params['body_jaxpr']
body_nconsts = eqn.params['body_nconsts']
cond_consts_uk, body_consts_uk, carry_init_uk = \
split_list(unks_in, [cond_nconsts, body_nconsts])
# Fixpoint to compute known part of the body (trivial on 'inst_in', since we
# make all inputs available as DCE can subsequently prune any unused ones)
carry_uk = carry_init_uk
for _ in range(1 + len(carry_uk)):
body_unks_in = body_consts_uk + carry_uk
jaxpr_known_, _, carry_uk_out, _, num_res = \
pe.partial_eval_jaxpr_custom(
body_jaxpr.jaxpr, in_unknowns=body_unks_in, in_inst=True,
ensure_out_unknowns=carry_uk, ensure_out_inst=True,
saveable=ad_checkpoint.nothing_saveable)
if carry_uk_out == carry_uk:
break
else:
carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
else:
assert False, "Fixpoint not reached"
assert not num_res
body_jaxpr_known = core.ClosedJaxpr(jaxpr_known_, body_jaxpr.consts)
del jaxpr_known_, carry_uk_out, num_res
# Instantiate all inputs (b/c jaxpr_staged will take all inputs).
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
if type(x) is core.Var and not inst]
# Compute the known part of cond_fun (basically pruning inputs on known side).
cond_unks_in = cond_consts_uk + carry_uk
cond_jaxpr_known_, _, [cond_uk], _, _ = \
pe.partial_eval_jaxpr_custom(
cond_jaxpr.jaxpr, cond_unks_in, in_inst=True,
ensure_out_unknowns=False, ensure_out_inst=True,
saveable=ad_checkpoint.nothing_saveable)
# NOTE(mattjj): I think it should be impossible for the condition to be
# unknown, but asserting that caused a test failure in diffrax. So
# we handle it: if it is unknown, stage out the whole cond function.
if cond_uk:
return None, eqn, [True] * len(carry_uk), [True] * len(carry_uk), new_inst
cond_jaxpr_known = core.ClosedJaxpr(cond_jaxpr_known_, cond_jaxpr.consts)
del cond_uk
# Build the known eqn.
ins_known, _ = partition_list(unks_in, eqn.invars)
out_binders_known, _ = partition_list(carry_uk, eqn.outvars)
params_known = dict(cond_jaxpr=cond_jaxpr_known, body_jaxpr=body_jaxpr_known,
cond_nconsts=len(cond_consts_uk) - sum(cond_consts_uk),
body_nconsts=len(body_consts_uk) - sum(body_consts_uk))
effects_known = core.join_effects(cond_jaxpr_known.effects,
body_jaxpr_known.effects)
eqn_known = pe.new_jaxpr_eqn(ins_known, out_binders_known, while_p,
params_known, effects_known, eqn.source_info)
# Staged eqn is same as input eqn.
eqn_staged = eqn
unks_out = carry_uk
inst_out = [True] * len(unks_out)
return eqn_known, eqn_staged, unks_out, inst_out, new_inst
def _while_transpose_error(*_, **kwargs):
raise ValueError("Reverse-mode differentiation does not work for "
"lax.while_loop or lax.fori_loop with dynamic start/stop values. "
"Try using lax.scan, or using fori_loop with static start/stop.")
# For a while loop with ordered effects in the cond, we need a special
# lowering. Fundamentally, we'd like to rewrite a while loop that looks like
# this:
# ```
# while cond(x):
# x = body(x)
# ```
# into something that looks like this:
# ```
# while True:
# token, pred = cond(token, x)
# if not pred:
# break
# token, x = body(token, x)
# ```
# Unfortunately, with a WhileOp we can't (1) return multiple values
# from a `cond` and (2) can't break a while loop. We thus adopt the
# following rewrite strategy:
# ```
# def new_cond(pred, token, x):
# return pred
# token, pred = cond(token, x)
# while new_cond(pred, token, x):
# token, x = body(token, x)
# token, pred = cond(token, x)
# ```
def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
body_nconsts):
pred_aval = cond_jaxpr.out_avals[0]
batched = bool(pred_aval.shape)
cond_ordered_effects = effects.ordered_effects.filter_in(cond_jaxpr.effects)
if cond_ordered_effects:
def cond(args):
# Pred can be batched
pred = core.eval_jaxpr(cond_jaxpr.jaxpr, cond_jaxpr.consts, *args)[0]
if batched:
pred = lax._reduce_or(pred, tuple(range(len(pred_aval.shape))))
return pred
def body(args):
return tuple(core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args))
def new_cond(pred_args):
pred, _ = pred_args
return pred
def new_body(pred_args):
_, args = pred_args
args = body(args)
pred = cond(args)
return pred, args
def fun(*args):
pred = cond(args)
_, out = while_loop(new_cond, new_body, (pred, args))
return out
return mlir.lower_fun(fun)(ctx, *args)
loop_carry_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
body_effects = effects.ordered_effects.filter_in(body_jaxpr.effects)
num_tokens = len(body_effects)
tokens = [ctx.tokens_in.get(eff) for eff in body_effects]
token_types = [mlir.token_type() for _ in tokens]
loop_carry_types = [*token_types, *loop_carry_types]
flat_loop_carry_types = util.flatten(loop_carry_types)
args = [*tokens, *args]
flat_args = mlir.flatten_lowering_ir_args(args)
while_op = hlo.WhileOp(flat_loop_carry_types, flat_args)
# Loop condition
cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types)
name_stack = ctx.module_context.name_stack.extend('while')
with ir.InsertionPoint(cond_block):
flat_cond_args = [
cond_block.arguments[i] for i in range(len(flat_loop_carry_types))
]
cond_args = util.unflatten(flat_cond_args, _map(len, loop_carry_types))
# Remove tokens from cond args
cond_args = cond_args[num_tokens:]
x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
cond_ctx = ctx.module_context.replace(name_stack=name_stack.extend('cond'))
((pred,),), _ = mlir.jaxpr_subcomp(cond_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
_map(mlir.ir_constants, cond_jaxpr.consts),
*(x + z), dim_var_values=ctx.dim_var_values)
if batched:
pred_ctx = mlir.LoweringRuleContext(
module_context=ctx.module_context,
primitive=None,
avals_in=[pred_aval],
avals_out=[pred_aval.update(shape=())],
tokens_in=mlir.TokenSet(),
tokens_out=None)
pred, = lax._unary_reduce_lower(
hlo.OrOp,
lambda dtype: np.array(False, dtype),
pred_ctx,
pred,
axes=tuple(range(len(pred_aval.shape))))
hlo.ReturnOp([pred])
# Loop body
body_block = while_op.regions[1].blocks.append(*flat_loop_carry_types)
with ir.InsertionPoint(body_block):
flat_body_args = [
body_block.arguments[i] for i in range(len(flat_loop_carry_types))
]
body_args = util.unflatten(flat_body_args, _map(len, loop_carry_types))
# Tokens are at the front of the args list to the while loop
token_args, body_args = util.split_list(body_args, [num_tokens])
tokens_in = mlir.TokenSet(zip(body_effects, token_args))
x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
body_ctx = ctx.module_context.replace(name_stack=name_stack.extend('body'))
new_z, tokens_out = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr,
tokens_in, _map(mlir.ir_constants, body_jaxpr.consts),
*(y + z), dim_var_values=ctx.dim_var_values)
out_tokens = [tokens_out.get(eff) for eff in body_effects]
if batched:
body_pred_ctx = ctx.module_context.replace(
name_stack=name_stack.extend('body_pred'))
((body_pred,),), _ = mlir.jaxpr_subcomp(
body_pred_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
_map(mlir.ir_constants, cond_jaxpr.consts),
*(x + z), dim_var_values=ctx.dim_var_values)
new_z = _map(
partial(_pred_bcast_select_hlo, ctx, pred_aval, body_pred), new_z, z,
body_jaxpr.out_avals)
hlo.ReturnOp([*util.flatten(out_tokens), *util.flatten(x),
*util.flatten(y), *util.flatten(new_z)])
outputs = util.unflatten(while_op.results, _map(len, loop_carry_types))
tokens, _, _, z = util.split_list(outputs, [num_tokens, cond_nconsts, body_nconsts])
if tokens:
ctx.set_tokens_out(mlir.TokenSet(zip(body_effects, tokens)))
return z
def _while_typecheck(_, *in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts,
body_nconsts):
# TODO(frostig,mattjj): check cond_jaxpr, body_jaxpr types
joined_effects = _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts,
cond_nconsts)
disallowed_effects = allowed_effects.filter_not_in(joined_effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `while`: {disallowed_effects}')
return body_jaxpr.out_avals, joined_effects
while_p = core.AxisPrimitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(dispatch.apply_primitive, while_p))
while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.register_initial_style_primitive(while_p)
ad.primitive_transposes[while_p] = _while_transpose_error
batching.axis_primitive_batchers[while_p] = partial(_while_loop_batching_rule, None)
batching.spmd_axis_primitive_batchers[while_p] = _while_loop_batching_rule
pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom
mlir.register_lowering(while_p, _while_lowering)
core.custom_typechecks[while_p] = _while_typecheck
def _pred_bcast_select_hlo(ctx,
pred_aval: core.ShapedArray, pred: ir.Value, xs: Sequence[ir.Value],
ys: Sequence[ir.Value], x_y_aval: core.AbstractValue) -> Sequence[ir.Value]:
if x_y_aval is core.abstract_token:
x, = xs
y, = ys
return [hlo.AfterAllOp([x, y]).result]
else:
assert isinstance(x_y_aval, core.ShapedArray), x_y_aval
x, = xs
y, = ys
assert x.type == y.type, (x.type, y.type)
assert (pred_aval.shape == x_y_aval.shape[:len(pred_aval.shape)]), (
pred_aval.shape, x_y_aval)
x_y_aval = core.physical_aval(x_y_aval)
bcast_pred = mlir.broadcast_in_dim(
ctx, pred, core.DShapedArray(x_y_aval.shape, np.dtype(np.bool_)),
broadcast_dimensions=list(range(len(pred_aval.shape))))
return hlo.SelectOp(bcast_pred, x, y).results
### fori_loop
def _fori_cond_fun(loop_carry):
i, upper, _ = loop_carry
return lax.lt(i, upper)
@weakref_lru_cache
def _fori_body_fun(body_fun):
body_fun = weakref.ref(body_fun)
def while_body_fun(loop_carry):
i, upper, x = loop_carry
return lax.add(i, lax._const(i, 1)), upper, body_fun()(i, x)
return while_body_fun
@weakref_lru_cache
def _fori_scan_body_fun(body_fun):
body_fun = weakref.ref(body_fun)
def scanned_fun(loop_carry, _):
i, x = loop_carry
return (i + 1, body_fun()(i, x)), None
return scanned_fun
@api_boundary
def fori_loop(lower, upper, body_fun, init_val):
"""Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`.
The `Haskell-like type signature`_ in brief is
.. code-block:: haskell
fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a
The semantics of ``fori_loop`` are given by this Python implementation::
def fori_loop(lower, upper, body_fun, init_val):
val = init_val
for i in range(lower, upper):
val = body_fun(i, val)
return val
As the Python version suggests, setting ``upper <= lower`` will produce no
iterations. Negative or custom increments are not supported.
Unlike that Python version, ``fori_loop`` is implemented in terms of either a
call to :func:`jax.lax.while_loop` or a call to :func:`jax.lax.scan`. If the
trip count is static (meaning known at tracing time, perhaps because ``lower``
and ``upper`` are Python integer literals) then the ``fori_loop`` is
implemented in terms of :func:`~scan` and reverse-mode autodiff is supported;
otherwise, a ``while_loop`` is used and reverse-mode autodiff is not
supported. See those functions' docstrings for more information.
Also unlike the Python analogue, the loop-carried value ``val`` must hold a
fixed shape and dtype across all iterations (and not just be consistent up to
NumPy rank/shape broadcasting and dtype promotion rules, for example). In
other words, the type ``a`` in the type signature above represents an array
with a fixed shape and dtype (or a nested tuple/list/dict container data
structure with a fixed structure and arrays with fixed shape and dtype at the
leaves).
.. note::
:py:func:`fori_loop` compiles ``body_fun``, so while it can be combined with
:py:func:`jit`, it's usually unnecessary.
Args:
lower: an integer representing the loop index lower bound (inclusive)
upper: an integer representing the loop index upper bound (exclusive)
body_fun: function of type ``(int, a) -> a``.
init_val: initial loop carry value of type ``a``.
Returns:
Loop value from the final iteration, of type ``a``.
.. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
"""
if not callable(body_fun):
raise TypeError("lax.fori_loop: body_fun argument should be callable.")
# TODO(phawkins): perhaps do more type checking here, better error messages.
lower_dtype = dtypes.canonicalize_dtype(lax.dtype(lower))
upper_dtype = dtypes.canonicalize_dtype(lax.dtype(upper))
if lower_dtype == upper_dtype:
dtype = lower_dtype
else:
# As a special case: allow promotion of weak integers (e.g., Python scalars)
# This improves the ergonomics if one but not both of the loop bounds is a
# scalar.
dtype = None
if (np.issubdtype(lower_dtype, np.signedinteger) and
np.issubdtype(upper_dtype, np.signedinteger)):
lower_weak = dtypes.is_weakly_typed(lower)
upper_weak = dtypes.is_weakly_typed(upper)
if lower_weak and not upper_weak:
dtype = upper_dtype
elif not lower_weak and upper_weak:
dtype = lower_dtype
if dtype is None:
raise TypeError("lower and upper arguments to fori_loop must have equal "
f"types, got {lower_dtype.name} and {upper_dtype.name}")
# If we can specialize on the trip count, call scan instead of a while_loop
# to enable efficient reverse-mode differentiation.
if (isinstance(core.get_aval(lower), ConcreteArray) and
isinstance(core.get_aval(upper), ConcreteArray)):
try:
lower_ = int(lower)
upper_ = int(upper)
except TypeError:
use_scan = False
else:
use_scan = True
else:
use_scan = False
if use_scan:
if config.jax_disable_jit and upper_ == lower_:
# non-jit implementation of scan does not support length=0
return init_val
(_, result), _ = scan(_fori_scan_body_fun(body_fun), (lower_, init_val),
None, length=upper_ - lower_)
return result
if lower_dtype != dtype:
lower = lax.convert_element_type(lower, dtype)
if upper_dtype != dtype:
upper = lax.convert_element_type(upper, dtype)
_, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun),
(lower, upper, init_val))
return result
### map and miscellaneous rules
@api_boundary
def map(f, xs):
"""Map a function over leading array axes.
Like Python's builtin map, except inputs and outputs are in the form of
stacked arrays. Consider using the :func:`~jax.vmap` transform instead, unless you
need to apply a function element by element for reduced memory usage or
heterogeneous computation with other control flow primitives.
When ``xs`` is an array type, the semantics of :func:`~map` are given by this
Python implementation::
def map(f, xs):
return np.stack([f(x) for x in xs])
Like :func:`~scan`, :func:`~map` is implemented in terms of JAX primitives so
many of the same advantages over a Python loop apply: ``xs`` may be an
arbitrary nested pytree type, and the mapped computation is compiled only
once.
Args:
f: a Python function to apply element-wise over the first axis or axes of
``xs``.
xs: values over which to map along the leading axis.
Returns:
Mapped values.
"""
g = lambda _, x: ((), f(x))
_, ys = scan(g, (), xs)
return ys
def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm):
"""Calls RBG in a loop and stacks the results."""
key, = batched_args
bd, = batch_dims
if bd is batching.not_mapped:
return lax.rng_bit_generator_p.bind(key, shape=shape, dtype=dtype,
algorithm=algorithm), (None, None)
key = batching.moveaxis(key, bd, 0)
map_body = lambda k: lax.rng_bit_generator_p.bind(k, shape=shape, dtype=dtype, algorithm=algorithm)
stacked_keys, stacked_bits = map(map_body, key)
return (stacked_keys, stacked_bits), (0, 0)
batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule # type: ignore
### associative_scan
@api_boundary
def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
"""Performs a scan with an associative binary operation, in parallel.
For an introduction to associative scans, see [BLE1990]_.
Args:
fn: A Python callable implementing an associative binary operation with
signature ``r = fn(a, b)``. Function `fn` must be associative, i.e., it
must satisfy the equation
``fn(a, fn(b, c)) == fn(fn(a, b), c)``.
The inputs and result are (possibly nested Python tree structures of)
array(s) matching ``elems``. Each array has a dimension in place
of the ``axis`` dimension. `fn` should be applied elementwise over
the ``axis`` dimension (for example, by using :func:`jax.vmap` over the
elementwise function.)
The result ``r`` has the same shape (and structure) as the two inputs
``a`` and ``b``.
elems: A (possibly nested Python tree structure of) array(s), each with
an ``axis`` dimension of size ``num_elems``.
reverse: A boolean stating if the scan should be reversed with respect to
the ``axis`` dimension.
axis: an integer identifying the axis over which the scan should occur.
Returns:
A (possibly nested Python tree structure of) array(s) of the same shape
and structure as ``elems``, in which the ``k``'th element of ``axis`` is the
result of recursively applying ``fn`` to combine the first ``k`` elements
of ``elems`` along ``axis``. For example, given ``elems = [a, b, c, ...]``,
the result would be ``[a, fn(a, b), fn(fn(a, b), c), ...]``.
Example 1: partial sums of an array of numbers:
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4))
Array([0, 1, 3, 6], dtype=int32)
Example 2: partial products of an array of matrices
>>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2))
>>> partial_prods = lax.associative_scan(jnp.matmul, mats)
>>> partial_prods.shape
(4, 2, 2)
Example 3: reversed partial sums of an array of numbers
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
Array([6, 6, 5, 3], dtype=int32)
.. [BLE1990] Blelloch, Guy E. 1990. "Prefix Sums and Their Applications.",
Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon
University.
"""
if not callable(fn):
raise TypeError("lax.associative_scan: fn argument should be callable.")
elems_flat, tree = tree_flatten(elems)
if reverse:
elems_flat = [lax.rev(elem, [axis]) for elem in elems_flat]
def combine(a_flat, b_flat):
# Lower `fn` to operate on flattened sequences of elems.
a = tree_unflatten(tree, a_flat)
b = tree_unflatten(tree, b_flat)
c = fn(a, b)
c_flat, _ = tree_flatten(c)
return c_flat
# Check that all inputs have a consistent leading dimension `num_elems`.
axis = util.canonicalize_axis(axis, elems_flat[0].ndim)
if core.is_special_dim_size(elems_flat[0].shape[axis]):
raise NotImplementedError("associative scan over axis "
f"of non-constant size: {elems_flat[0].shape[axis]}. You may be "
"able to avoid this on TPU. See b/274176030.")
num_elems = int(elems_flat[0].shape[axis])
if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
raise ValueError('Array inputs to associative_scan must have the same '
'first dimension. (saw: {})'
.format([elem.shape for elem in elems_flat]))
# Summary of algorithm:
#
# Consider elements of `_scan(elems)` at odd indices. That's the same as first
# summing successive pairs of elements of `elems` and performing a scan on
# that half sized tensor. We perform the latter scan by recursion.
#
# Now consider the even elements of `_scan(elems)`. These can be computed
# from the odd elements of `_scan(elems)` by adding each odd element of
# `_scan(elems)` to the matching even element in the original `elems`.
#
# We return the odd and even elements interleaved.
#
# For the base case of the recursion we return the first element
# of `elems` followed by the sum of the first two elements computed as
# a (small two-down-to-one) reduction step.
def _scan(elems):
"""Perform scan on `elems`."""
num_elems = elems[0].shape[axis]
if num_elems < 2:
return elems
# Combine adjacent pairs of elements.
reduced_elems = combine(
[slicing.slice_in_dim(elem, 0, -1, stride=2, axis=axis) for elem in elems],
[slicing.slice_in_dim(elem, 1, None, stride=2, axis=axis)
for elem in elems])
# Recursively compute scan for partially reduced tensors.
odd_elems = _scan(reduced_elems)
if num_elems % 2 == 0:
even_elems = combine(
[slicing.slice_in_dim(e, 0, -1, axis=axis) for e in odd_elems],
[slicing.slice_in_dim(e, 2, None, stride=2, axis=axis) for e in elems])
else:
even_elems = combine(
odd_elems,
[slicing.slice_in_dim(e, 2, None, stride=2, axis=axis) for e in elems])
# The first element of a scan is the same as the first element
# of the original `elems`.
even_elems = [
lax.concatenate([slicing.slice_in_dim(elem, 0, 1, axis=axis), result],
dimension=axis)
for (elem, result) in zip(elems, even_elems)]
return list(_map(partial(_interleave, axis=axis), even_elems, odd_elems))
scans = _scan(elems_flat)
if reverse:
scans = [lax.rev(scanned, [axis]) for scanned in scans]
return tree_unflatten(tree, scans)
def _interleave(a, b, axis):
"""Given two Tensors of static shape, interleave them along the first axis."""
assert a.shape[axis] == b.shape[axis] or a.shape[axis] == b.shape[axis] + 1
a_pad = [(0, 0, 0)] * a.ndim
b_pad = [(0, 0, 0)] * b.ndim
a_pad[axis] = (0, 1 if a.shape[axis] == b.shape[axis] else 0, 1)
b_pad[axis] = (1, 0 if a.shape[axis] == b.shape[axis] else 1, 1)
op = lax.bitwise_or if a.dtype == np.bool_ else lax.add
return op(lax.pad(a, lax._const(a, 0), a_pad),
lax.pad(b, lax._const(b, 0), b_pad))
### Cumulative reductions.
def cumsum(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
"""Computes a cumulative sum along `axis`."""
return cumsum_p.bind(operand, axis=int(axis), reverse=bool(reverse))
def cumprod(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
"""Computes a cumulative product along `axis`."""
return cumprod_p.bind(operand, axis=int(axis), reverse=bool(reverse))
def cummax(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
"""Computes a cumulative maximum along `axis`."""
return cummax_p.bind(operand, axis=int(axis), reverse=bool(reverse))
def cummin(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
"""Computes a cumulative minimum along `axis`."""
return cummin_p.bind(operand, axis=int(axis), reverse=bool(reverse))
def cumlogsumexp(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
"""Computes a cumulative logsumexp along `axis`."""
return cumlogsumexp_p.bind(operand, axis=int(axis), reverse=bool(reverse))
def _cumred_shape_rule(x, *, axis: int, reverse: bool):
if axis < 0 or axis >= x.ndim:
raise ValueError(
f"axis {axis} is out of bounds for array of shape {x.shape}")
return x.shape
def _cumsum_transpose_rule(t, operand, *, axis: int, reverse: bool):
return [cumsum(t, axis=axis, reverse=not reverse)]
def cumred_reduce_window_impl(window_reduce: Callable, x, *, axis: int,
reverse: bool):
n = x.shape[axis]
if n == 0:
return x
padding = [(0, 0)] * x.ndim
padding[axis] = (0, n - 1) if reverse else (n - 1, 0)
strides = [1] * x.ndim
window_dims = [1] * x.ndim
window_dims[axis] = n
return window_reduce(x, window_dims, strides, padding)
def cumred_gpu_impl(window_reduce: Callable, reduce_fn: Callable, x, *,
axis: int, reverse: bool):
# On GPU, reduce_window is executed in a single fusion and associative_scan
# is split into multiple to materialize intermediate calculations.
# On small inputs reduce_window is faster being a single fusion,
# but on larger ones is slower because of O(n^2) complexity.
# This conservative value of the threshold was obtained via benchmarking.
if x.shape[axis] > 32:
return associative_scan(reduce_fn, x, reverse=reverse, axis=axis)
return cumred_reduce_window_impl(window_reduce, x, axis=axis, reverse=reverse)
def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int,
reverse: bool):
operand, = batched_args
bdim, = batch_dims
axis = axis if axis < bdim else axis + 1
return prim.bind(operand, axis=axis, reverse=reverse), bdim
def _cumred_dtype_rule(name, operand, *args, **kw):
if not dtypes.issubdtype(operand.dtype, np.number):
raise TypeError("{} does not accept dtype {}. Accepted dtypes are subtypes "
"of number.".format(name, np.dtype(operand.dtype).name))
return dtypes.canonicalize_dtype(operand.dtype)
def _cumulative_reduction_primitive(name, reduce_fn, reduce_window_fn):
reducer_p = lax.standard_primitive(
_cumred_shape_rule, partial(_cumred_dtype_rule, name),
name)
batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule,
reducer_p)
def register_lowering(fn, platform=None):
mlir.register_lowering(
reducer_p,
mlir.cache_lowering(mlir.lower_fun(fn, multiple_results=False)),
platform=platform)
# Default for platforms not treated specially below.
register_lowering(partial(associative_scan, reduce_fn))
# On GPU, we choose between window reduction and associative scan
# based on the input size.
for platform in ['cuda', 'rocm']:
register_lowering(
partial(cumred_gpu_impl, reduce_window_fn, reduce_fn), platform)
# On TPU, an implementation using reduce_window is handled specially by the
# compiler and is efficient. On other backends, it is O(n^2).
register_lowering(partial(cumred_reduce_window_impl, reduce_window_fn), 'tpu')
return reducer_p
cumsum_p = _cumulative_reduction_primitive("cumsum", lax.add, windowed_reductions._reduce_window_sum)
ad.deflinear2(cumsum_p, _cumsum_transpose_rule)
cumlogsumexp_p = _cumulative_reduction_primitive(
"cumlogsumexp", logaddexp, windowed_reductions._reduce_window_logaddexp)
cumprod_p = _cumulative_reduction_primitive("cumprod", lax.mul, windowed_reductions._reduce_window_prod)
cummax_p = _cumulative_reduction_primitive("cummax", lax.max, windowed_reductions._reduce_window_max)
cummin_p = _cumulative_reduction_primitive("cummin", lax.min, windowed_reductions._reduce_window_min)
def _cumulative_jvp_rule(primals, tangents, *, axis: int, reverse: bool,
combine_fn: Callable):
# Irrespective of backend, we always use the parallel prefix scan
# implementation when differentiating because reduce_window is not
# arbitrarily differentiable.
return api.jvp(partial(associative_scan, combine_fn, axis=axis,
reverse=reverse),
primals, tangents)
ad.primitive_jvps[cumlogsumexp_p] = partial(_cumulative_jvp_rule, combine_fn=logaddexp)
ad.primitive_jvps[cumprod_p] = partial(_cumulative_jvp_rule, combine_fn=lax.mul)
ad.primitive_jvps[cummin_p] = partial(_cumulative_jvp_rule, combine_fn=lax.min)
ad.primitive_jvps[cummax_p] = partial(_cumulative_jvp_rule, combine_fn=lax.max)