# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the `for_loop` primitive."""
import functools
import operator

from typing import (Any, Callable, Generic, List, Optional, Sequence, Set,
                    Tuple, TypeVar, Union)

import jax.numpy as jnp
from jax import lax
from jax.api_util import flatten_fun_nokwargs
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax.tree_util import (tree_flatten, tree_structure, tree_unflatten,
                           treedef_tuple, tree_map, tree_leaves, PyTreeDef)
from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src.state.types import (ReadEffect, AbstractRef, StateEffect)
from jax._src.state import discharge as state_discharge
from jax._src.state import primitives as state_primitives
from jax._src.state import utils as state_utils
from jax._src.state import types as state_types
from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip,
                           split_list, split_dict)
from jax._src.lax.control_flow import loops
from jax._src.lax.control_flow.common import _abstractify, _initial_style_jaxpr

## JAX utilities

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

## Helpful type aliases
S = TypeVar('S')
T = TypeVar('T')
class Ref(Generic[T]): pass
Array = Any

ref_set = state_primitives.ref_set
ref_get = state_primitives.ref_get
ref_addupdate = state_primitives.ref_addupdate
discharge_state = state_discharge.discharge_state

## `for_loop` implementation

for_p = core.Primitive('for')
for_p.multiple_results = True

### Tracing utilities

def _hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr:
  all_const_avals = [var.aval for var in jaxpr.constvars]
  is_const_ref = [isinstance(var.aval, AbstractRef)
                  for var in jaxpr.constvars]
  const_avals, const_ref_avals = partition_list(is_const_ref, all_const_avals)
  const_avals = map(AbstractRef, const_avals)
  merged_const_avals = merge_lists(is_const_ref, const_avals, const_ref_avals)
  i_aval, *arg_avals = [var.aval for var in jaxpr.invars]
  in_avals = [i_aval, *merged_const_avals, *arg_avals]
  num_consts = len(merged_const_avals)

  def _hoist(i, *consts_args):
    all_consts, args = split_list(consts_args, [num_consts])
    consts, const_refs = partition_list(is_const_ref, all_consts)
    # We immediately read the const values out of the `Ref`s.
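    # Reading with the empty indexer `()` pulls the full value out of each
    # `Ref`, so the original jaxpr (which closed over plain values, not
    # `Ref`s) is evaluated with exactly the values it was traced with.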
    consts = map(lambda x: ref_get(x, ()), consts)
    all_consts = merge_lists(is_const_ref, consts, const_refs)
    return core.eval_jaxpr(jaxpr, all_consts, i, *args)
  hoisted_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(_hoist), in_avals)
  assert not consts, "All consts should have been converted to refs"
  return hoisted_jaxpr

def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef,
                              state_avals: Sequence[core.AbstractValue]
                              ) -> Tuple[core.Jaxpr, List[Any], PyTreeDef]:
  f, out_tree_thunk = flatten_fun_nokwargs(
      lu.wrap_init(f), treedef_tuple((tree_structure(0), state_tree)))
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(f, state_avals)
  return jaxpr, consts, out_tree_thunk()

def for_loop(nsteps: Union[int, Sequence[int]],
             body: Callable[[Array, Ref[S]], None],
             init_state: S,
             *, reverse: bool = False,
             unroll: int = 1) -> S:
  """A for-loop combinator that allows read/write semantics in the loop body.

  `for_loop` is a higher-order function that enables writing loops that can be
  staged out in JIT-ted JAX computations. Unlike `jax.lax.fori_loop`, it
  allows mutation in its body using `Ref`s.

  `for_loop` will initialize `Ref`s with the values in `init_state`. Each
  iteration, `body` will be called with the current `Ref`s, which can be read
  from and written to using `ref_get` and `ref_set`.

  `for_loop` is semantically equivalent to the following Python code:

  ```python
  def for_loop(nsteps, body, init_state):
    refs = tree_map(make_ref, init_state)
    for i in range(nsteps):
      body(i, refs)
    return tree_map(ref_get, refs)
  ```

  Args:
    nsteps: Number of iterations. Can also be a sequence of integers, in which
      case the loop is nested, one level per entry, and `body` receives one
      index argument per level.
    body: A callable that takes in the iteration number as its first argument
      and `Ref`s corresponding to `init_state` as its second argument.
      `body` is free to read from and write to its `Ref`s. `body` should not
      return anything.
    init_state: A Pytree of JAX-compatible values used to initialize the
      `Ref`s that will be passed into the for loop body.
    reverse: A boolean indicating whether to run the loop with the iteration
      index counting down from `nsteps - 1` to 0 instead of up.
    unroll: A positive int specifying, in the underlying operation of the
      `for` primitive, how many iterations to unroll within a single iteration
      of a loop. Higher values may speed up execution time at the cost of
      longer compilation time.

  Returns:
    A Pytree of values representing the output of the for loop.
  """
  if unroll < 1:
    raise ValueError("`unroll` must be a positive integer.")
  if isinstance(nsteps, int):
    nsteps = [nsteps]
  if len(nsteps) > 1:
    outer_step, *rest_steps = nsteps
    def wrapped_body(i, refs):
      vals = tree_map(lambda ref: ref_get(ref, ()), refs)
      vals = for_loop(
          rest_steps, functools.partial(body, i), vals, unroll=unroll)
      tree_map(lambda ref, val: ref_set(ref, (), val), refs, vals)
    return for_loop(outer_step, wrapped_body, init_state, unroll=unroll)
  nsteps, = nsteps
  flat_state, state_tree = tree_flatten(init_state)
  state_avals = map(state_utils.val_to_ref_aval, flat_state)
  idx_aval = core.ShapedArray((), jnp.dtype("int32"))
  jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs(
      body, state_tree, [idx_aval, *state_avals])
  if out_tree != tree_structure(None):
    raise Exception("`body` should not return anything.")
  # Remove constvars from jaxpr and turn them into `Ref`s.
  jaxpr = _hoist_consts_to_refs(jaxpr)
  which_linear = (False,) * (len(consts) + len(flat_state))
  out_flat = for_p.bind(*consts, *flat_state, jaxpr=jaxpr, nsteps=int(nsteps),
                        reverse=reverse, which_linear=which_linear,
                        unroll=unroll)
  # Consts are `Ref`s so they are both inputs and outputs. We remove them from
  # the outputs.
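  # `for_p` returns the final value of every `Ref` in binding order (hoisted
  # consts first, then the flattened state), so the user-visible state is
  # exactly the tail of `out_flat`.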
  out_flat = out_flat[len(consts):]
  return tree_unflatten(state_tree, out_flat)

Carry = TypeVar('Carry')
X = TypeVar('X')
Y = TypeVar('Y')

def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
         init: Carry,
         xs: X,
         length: Optional[int] = None,
         reverse: bool = False,
         unroll: int = 1) -> Tuple[Carry, Y]:
  """A `scan` implementation built on `for_loop`, mirroring `jax.lax.scan`."""
  if not callable(f):
    raise TypeError("scan: f argument should be a callable.")
  if unroll < 1:
    raise ValueError("`unroll` must be a positive integer.")
  xs_flat, xs_tree = tree_flatten(xs)

  try:
    lengths = [x.shape[0] for x in xs_flat]
  except AttributeError as err:
    msg = "scan got value with no leading axis to scan over: {}."
    raise ValueError(
        msg.format(', '.join(str(x) for x in xs_flat
                             if not hasattr(x, 'shape')))) from err

  if length is not None:
    length = int(length)
    if not all(length == l for l in lengths):
      msg = ("scan got `length` argument of {} which disagrees with "
             "leading axis sizes {}.")
      raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat]))
  else:
    unique_lengths = set(lengths)
    if len(unique_lengths) > 1:
      msg = "scan got values with different leading axis sizes: {}."
      raise ValueError(msg.format(', '.join(str(x.shape[0])
                                            for x in xs_flat)))
    elif len(unique_lengths) == 0:
      msg = "scan got no values to scan over and `length` not provided."
      raise ValueError(msg)
    else:
      length, = unique_lengths

  x_shapes = [x.shape[1:] for x in xs_flat]
  x_dtypes = [dtypes.canonicalize_dtype(x.dtype) for x in xs_flat]
  x_avals = tuple(map(core.ShapedArray, x_shapes, x_dtypes))

  def _create_jaxpr(init):
    init_flat = tree_leaves(init)
    _, in_tree = tree_flatten((init, xs))

    carry_avals = tuple(map(_abstractify, init_flat))
    jaxpr, _, out_tree = _initial_style_jaxpr(
        f, in_tree, carry_avals + x_avals, "scan")
    return jaxpr, out_tree
  jaxpr, out_tree = _create_jaxpr(init)
  _, ys_avals = tree_unflatten(out_tree, jaxpr.out_avals)
  ys = tree_map(lambda aval: jnp.zeros([length, *aval.shape], aval.dtype),
                ys_avals)
  def for_body(i, refs):
    carry_refs, xs_refs, ys_refs = refs
    carry = tree_map(lambda x: x[()], carry_refs)
    x = tree_map(lambda x: x[i], xs_refs)
    carry, y = f(carry, x)
    tree_map(lambda c_ref, c: ref_set(c_ref, (), c), carry_refs, carry)
    tree_map(lambda y_ref, y: ref_set(y_ref, (i,), y), ys_refs, y)
  assert isinstance(length, int)
  init, _, ys = for_loop(length, for_body, (init, xs, ys), reverse=reverse,
                         unroll=unroll)
  return init, ys

@for_p.def_effectful_abstract_eval
def _for_abstract_eval(*avals, jaxpr, **__):
  # Find out for each of the `Ref`s in our jaxpr what effects they have.
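  # The jaxpr's first invar is the loop index, which is not a `Ref`, so we
  # drop its entry and shift each effect's `input_index` down by one so that
  # indices refer to the primitive's own operands.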
  jaxpr_aval_effects = state_types.get_ref_state_effects(
      [v.aval for v in jaxpr.invars], jaxpr.effects)[1:]
  aval_effects = [set(eff.replace(input_index=eff.input_index - 1)
                      for eff in effs)
                  for aval, effs in zip(avals, jaxpr_aval_effects)
                  if isinstance(aval, AbstractRef)]
  nonlocal_state_effects = core.join_effects(*aval_effects)
  return list(avals), nonlocal_state_effects

@state_discharge.register_discharge_rule(for_p)
def _for_discharge_rule(in_avals, _, *args: Any, jaxpr: core.Jaxpr,
                        reverse: bool, which_linear: Sequence[bool],
                        nsteps: int, unroll: int
                        ) -> Tuple[Sequence[Optional[Any]], Sequence[Any]]:
  out_vals = for_p.bind(*args, jaxpr=jaxpr, reverse=reverse,
                        which_linear=which_linear, nsteps=nsteps,
                        unroll=unroll)
  new_invals = []
  for aval, out_val in zip(in_avals, out_vals):
    new_invals.append(out_val if isinstance(aval, AbstractRef) else None)
  return new_invals, out_vals

def _for_impl(*args, jaxpr, nsteps, reverse, which_linear, unroll):
  del which_linear
  discharged_jaxpr, consts = discharge_state(jaxpr, ())
  def body(i, state):
    i_ = nsteps - i - 1 if reverse else i
    return core.eval_jaxpr(discharged_jaxpr, consts, i_, *state)
  return _for_impl_unrolled(body, nsteps, unroll, *args)

def _for_impl_unrolled(body, nsteps, unroll, *args):
  remainder = nsteps % unroll
  i = jnp.int32(0)
  state = list(args)

  for _ in range(remainder):
    state = body(i, state)
    i = i + 1

  def cond(carry):
    i, _ = carry
    return i < nsteps
  def while_body(carry):
    i, state = carry
    for _ in range(unroll):
      state = body(i, state)
      i = i + 1
    return i, state
  _, state = lax.while_loop(cond, while_body, (i, state))
  return state

mlir.register_lowering(for_p, mlir.lower_fun(_for_impl, multiple_results=True))
for_p.def_impl(functools.partial(dispatch.apply_primitive, for_p))

def _for_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *,
              jaxpr, nsteps, reverse, which_linear, unroll):
  init_batched = [d is not batching.not_mapped for d in dims]
  discharged_jaxpr, body_consts = discharge_state(jaxpr, ())
  batched = init_batched
  for _ in range(len(batched)):
    _, out_batched = batching.batch_jaxpr(
        core.ClosedJaxpr(discharged_jaxpr, body_consts),
        axis_size, [False] + batched, instantiate=batched,
        axis_name=axis_name, spmd_axis_name=spmd_axis_name,
        main_type=main_type)
    if out_batched == batched:
      break
    batched = map(operator.or_, batched, out_batched)
  else:
    raise Exception("Invalid fixpoint")
  args = [batching.broadcast(x, axis_size, 0) if now_bat and not was_bat
          else batching.moveaxis(x, d, 0) if now_bat else x
          for x, d, was_bat, now_bat
          in zip(args, dims, init_batched, batched)]
  batched_jaxpr_, _ = batching.batch_jaxpr(
      core.ClosedJaxpr(jaxpr, []), axis_size, [False] + batched, [],
      axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
  batched_jaxpr, () = batched_jaxpr_.jaxpr, batched_jaxpr_.consts  # TODO: consts
  out_flat = for_p.bind(*args, jaxpr=batched_jaxpr, nsteps=nsteps,
                        reverse=reverse, which_linear=which_linear,
                        unroll=unroll)
  return out_flat, [0 if b else batching.not_mapped for b in batched]
batching.axis_primitive_batchers[for_p] = functools.partial(_for_vmap, None)
batching.spmd_axis_primitive_batchers[for_p] = _for_vmap

def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear,
             unroll):
  nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
  # We need to find out which `Ref`s have nonzero tangents after running the
  # for loop. Ordinarily we do this with a fixed point on the body jaxpr but
  # a `for` body jaxpr is stateful and has no outputs.
  # We therefore discharge the state effect from the jaxpr, which gives us a
  # "symmetric" jaxpr where the inputs line up with the outputs. We use this
  # discharged jaxpr for the fixed point.
  discharged_jaxpr, body_consts = discharge_state(jaxpr, ())
  for _ in range(len(nonzero_tangents)):
    _, out_nonzero_tangents = ad.jvp_jaxpr(
        core.ClosedJaxpr(discharged_jaxpr, body_consts),
        [False] + nonzero_tangents, instantiate=nonzero_tangents)
    if out_nonzero_tangents == nonzero_tangents:
      break
    nonzero_tangents = map(operator.or_, nonzero_tangents,
                           out_nonzero_tangents)
  else:
    raise Exception("Invalid fixpoint")
  tangents = [ad.instantiate_zeros(t) if inst else t
              for t, inst in zip(tangents, nonzero_tangents)]
  tangents = [t for t in tangents if type(t) is not ad_util.Zero]
  closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
  jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, [False] + nonzero_tangents, [])
  jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts  # TODO: consts
  jvp_which_linear = which_linear + (True,) * len(tangents)
  out_flat = for_p.bind(*primals, *tangents, jaxpr=jvp_jaxpr,
                        nsteps=nsteps, reverse=reverse,
                        which_linear=jvp_which_linear, unroll=unroll)
  # `out_flat` includes constant inputs into the `for_loop` which are
  # converted into outputs as well. We don't care about these in AD so we
  # throw them out.
  out_primals, out_tangents = split_list(out_flat, [len(primals)])
  out_tangents_iter = iter(out_tangents)
  out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(out_primals, nonzero_tangents)]
  return out_primals, out_tangents
ad.primitive_jvps[for_p] = _for_jvp

def _partial_eval_jaxpr_custom(jaxpr, in_unknowns, policy):
  # A simple wrapper around `pe.partial_eval_jaxpr_custom` that assumes all
  # inputs are instantiated and doesn't ensure any outputs are unknown or
  # instantiated.
  return pe.partial_eval_jaxpr_custom(
      jaxpr, in_unknowns, [True] * len(in_unknowns), False, False, policy)

_save_everything = lambda *_, **__: True

def _is_read_only(ref_effects: Set[StateEffect]) -> bool:
  assert len(ref_effects) > 0
  if len(ref_effects) > 1:
    # Means we must have a write or accum effect, so not read-only.
    return False
  eff, = ref_effects
  return isinstance(eff, ReadEffect)

def _loop_invariant_outputs(jaxpr: core.Jaxpr) -> List[bool]:
  # Get effects for each of the jaxpr inputs and remove the loop index.
  ref_effects = state_types.get_ref_state_effects(
      [v.aval for v in jaxpr.invars], jaxpr.effects)[1:]
  # We first assume that *read-only `Ref`s* are loop-invariant. We can safely
  # do this because the only way something can be loop-varying is if we write
  # to it at some point. It's *possible* that read-write `Ref`s are
  # loop-invariant, but we conservatively assume they aren't.
  loop_invar_refs = [_is_read_only(effs) if effs else True
                     for effs in ref_effects]
  loop_var_refs = map(operator.not_, loop_invar_refs)

  # We'd like to detect if the outputs of the jaxpr are loop-invariant. An
  # output is loop-invariant if it is downstream of only loop-invariant values
  # (seeded by the read-only `Ref`s). If at any point a loop-varying value
  # interacts with a loop-invariant value, we produce a loop-varying value. We
  # can use `partial_eval` to perform this analysis by treating loop-varying
  # values as "unknown" and loop-invariant values as "known", since when a
  # known and an unknown value interact, they produce an unknown value.
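  # The loop index (input 0) is the canonical loop-varying value, so we mark
  # it unknown along with the loop-varying `Ref`s.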
  loop_var_inputs = [True, *loop_var_refs]
  _, _, loop_var_outputs, _, _, = _partial_eval_jaxpr_custom(
      jaxpr, loop_var_inputs, _save_everything)
  return map(operator.not_, loop_var_outputs)

def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
                      jaxpr: core.Jaxpr, nsteps: int, reverse: bool,
                      which_linear: Tuple[bool, ...],
                      unroll: int) -> List[pe.JaxprTracer]:
  num_inputs = len(tracers)
  assert num_inputs == len(jaxpr.invars) - 1
  in_unknowns = [not t.pval.is_known() for t in tracers]
  # We first need to run a fixpoint to determine which of the `Ref`s are
  # unknown after running the for loop. We want to use the jaxpr to determine
  # which `Ref`s are unknown after executing the for loop body given which
  # `Ref`s are unknown before. However, the jaxpr has no outputs. Instead, we
  # discharge the body and run the fixpoint with the discharged jaxpr. We can
  # do this because the outputs of the discharged jaxpr are one-to-one with
  # the inputs.
  discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ())
  discharged_jaxpr = discharged_jaxpr.replace(
      invars=discharged_jaxpr.constvars + discharged_jaxpr.invars,
      constvars=[])
  for _ in range(num_inputs):
    jaxpr_in_unknowns = ([False] * len(discharged_consts)
                         + [False, *in_unknowns])
    _, _, out_unknowns, _, _, = pe.partial_eval_jaxpr_custom(
        discharged_jaxpr, jaxpr_in_unknowns, [True] * len(jaxpr_in_unknowns),
        in_unknowns, False, _save_everything)
    out_unknowns = list(out_unknowns)
    if out_unknowns == in_unknowns:
      break
    in_unknowns = map(operator.or_, in_unknowns, out_unknowns)
  else:
    raise Exception("Invalid fixpoint")
  del out_unknowns  # Redundant since it's the same as `in_unknowns`.
  tracers = tuple(trace.instantiate_const(t) if uk else t  # type: ignore
                  for t, uk in zip(tracers, in_unknowns))

  # We use `partial_eval_jaxpr_custom` here because it won't remove effectful
  # primitives like `get`/`set`.
  jaxpr_known_resout, jaxpr_unknown_resin_, uk_out, inst_out, num_res = \
      _partial_eval_jaxpr_custom(jaxpr, [False, *in_unknowns],
                                 _save_everything)
  # `partial_eval_jaxpr_custom` will give us jaxprs that have hybrid `Ref` and
  # regular valued input/outputs. However, we'd like to bind these jaxprs to a
  # `for`, which expects only `Ref` inputs and no output. We need to convert
  # both of these jaxprs into ones that are compatible with `for`.
  # TODO(sharadmv,mattjj): implement "passthrough" optimization.
  # TODO(sharadmv,mattjj): rematerialize loop-dependent values instead of
  # passing the loop index as a residual.
  # `jaxpr_known_resout` is a jaxpr that maps from all the input `Ref`s
  # to output residual values (none of them should be `Ref`s). We'll need to
  # convert those output residual values into writes to initially-empty
  # residual `Ref`s at the end of the jaxpr.
  #
  # Loop-invariant residual optimization
  # Here we are interested in finding out which of the residuals are *not*
  # dependent on the loop index. If a residual is not dependent on the loop
  # index, we don't need to add an extra loop dimension that we read from when
  # we convert it from an output into a write.
  loop_invar_res = _loop_invariant_outputs(jaxpr_known_resout)

  jaxpr_known, res_avals = _convert_outputs_to_writes(nsteps,
                                                      jaxpr_known_resout,
                                                      loop_invar_res)
  # We now run the known jaxpr to obtain our residual values.
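  # The residual `Ref`s start out as zeros; the converted known jaxpr writes
  # each residual on every iteration (or just once, if it is loop-invariant),
  # so after the loop they hold everything the unknown jaxpr needs.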
  known_tracers, _ = partition_list(in_unknowns, tracers)
  known_vals = [t.pval.get_known() for t in known_tracers]
  empty_res = map(ad_util.zeros_like_aval, res_avals)
  jaxpr_known_args = [*known_vals, *empty_res]
  # We assume the known inputs are nonlinear, which is okay to do for AD but
  # not necessarily okay for general partial eval.
  jaxpr_known_which_linear = (False,) * len(jaxpr_known_args)
  out_flat = for_p.bind(*jaxpr_known_args, jaxpr=jaxpr_known, nsteps=nsteps,
                        reverse=reverse,
                        which_linear=jaxpr_known_which_linear,
                        unroll=unroll)
  known_outputs, residuals = split_list(out_flat, [len(known_tracers)])
  residuals = map(trace.new_instantiated_const, residuals)

  # Now we handle `jaxpr_unknown`, which expects residual values as inputs.
  # This jaxpr is the output of `partial_eval_jaxpr_custom`, which doesn't
  # remove unused inputs/outputs for you, so we use `dce_jaxpr` here to do
  # that and to find out which inputs are actually used.
  jaxpr_unknown_resin, used_inputs = pe.dce_jaxpr(
      jaxpr_unknown_resin_, [], [True] * num_res + [True, *in_unknowns])
  used_res, (used_i,), used_refs = split_list(used_inputs, [num_res, 1])
  assert all(used_res), "All residuals should be used"
  # To make it compatible with `for`, we need to convert those residual values
  # into `Ref`s.
  jaxpr_unknown = _convert_inputs_to_reads(nsteps, len(res_avals),
                                           jaxpr_unknown_resin,
                                           loop_invar_res)
  # Since not all inputs are used in `jaxpr_unknown`, we filter the input
  # tracers down using the output of `dce_jaxpr`.
  used_and_known = map(operator.and_, used_refs,
                       map(operator.not_, in_unknowns))
  tracers = [trace.instantiate_const(t) if u_and_k else t  # type: ignore
             for t, u_and_k in zip(tracers, used_and_known)]
  _, known_used = partition_list(used_refs, used_and_known)
  _, used_tracers = partition_list(used_refs, tracers)
  _, used_which_linear = partition_list(used_refs, which_linear)
  which_linear_unknown = (False,) * num_res + tuple(used_which_linear)
  unknown_inputs = [*residuals, *used_tracers]
  # Outputs match inputs, so we construct output tracers that look like the
  # input tracers.
  res_ref_unknown_outputs = [
      pe.JaxprTracer(trace, pe.PartialVal.unknown(t.aval), None)
      for t in unknown_inputs]
  name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
  source = source_info_util.current().replace(name_stack=name_stack)
  assert len(unknown_inputs) == len(res_ref_unknown_outputs)
  assert len(unknown_inputs) == len(jaxpr_unknown.invars) - 1
  eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs,
                          for_p, dict(jaxpr=jaxpr_unknown, nsteps=nsteps,
                                      reverse=reverse,
                                      which_linear=which_linear_unknown,
                                      unroll=unroll),
                          core.no_effects, source)
  for t in res_ref_unknown_outputs:
    t.recipe = eqn
  _, unknown_outputs = split_list(res_ref_unknown_outputs, [num_res])
  unknown_outputs, _ = partition_list(known_used, unknown_outputs)
  return merge_lists(in_unknowns, known_outputs, unknown_outputs)
pe.custom_partial_eval_rules[for_p] = _for_partial_eval

def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
  jaxpr, nsteps, reverse, which_linear, unroll = split_dict(
      eqn.params, ["jaxpr", "nsteps", "reverse", "which_linear", "unroll"])
  num_inputs = len(eqn.invars)
  # We first need to run a fixpoint to determine which of the `Ref`s are
  # unknown after running the for loop. However, the jaxpr has no outputs.
  # Instead, we discharge the body and run the fixpoint with the discharged
  # jaxpr. We can do this because the outputs of the discharged jaxpr are
  # one-to-one with the inputs.
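  # Discharging turns the stateful body into a pure jaxpr that takes the
  # current `Ref` values as regular inputs and returns their final values as
  # outputs, giving the fixpoint the input/output pairing it needs.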
  discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ())
  discharged_jaxpr = discharged_jaxpr.replace(
      invars=discharged_jaxpr.constvars + discharged_jaxpr.invars,
      constvars=[])
  in_unknowns, in_inst = list(in_unknowns), list(in_inst)
  out_unknowns, out_inst = in_unknowns, in_inst
  for _ in range(num_inputs):
    jaxpr_in_unknowns = ([False] * len(discharged_consts)
                         + [False, *in_unknowns])
    _, _, out_unknowns, out_inst, _, = pe.partial_eval_jaxpr_custom(
        discharged_jaxpr, jaxpr_in_unknowns, True,
        ensure_out_unknowns=in_unknowns, ensure_out_inst=True,
        saveable=saveable)
    out_unknowns = list(out_unknowns)
    if out_unknowns == in_unknowns:
      break
    in_unknowns = map(operator.or_, in_unknowns, out_unknowns)
  else:
    if num_inputs > 0:
      raise Exception("Invalid fixpoint")
  del out_unknowns  # Redundant since it's the same as `in_unknowns`.
  new_inst = [x for x, inst in zip(eqn.invars, in_inst)
              if type(x) is core.Var and not inst]
  in_inst = [True] * len(eqn.invars)

  # We use `partial_eval_jaxpr_custom` here because it won't remove effectful
  # primitives like `get`/`set`.
  jaxpr_known_resout, jaxpr_staged_resin_, _, _, num_res = \
      pe.partial_eval_jaxpr_custom(jaxpr, [False, *in_unknowns],
                                   [True, *in_inst], [], [], saveable)
  # `partial_eval_jaxpr_custom` will give us jaxprs that have hybrid `Ref` and
  # non-Ref input/outputs. However, we'd like to bind these jaxprs to a
  # `for`, which expects only `Ref` inputs and no output. We need to convert
  # both of these jaxprs into ones that are compatible with `for`.
  # TODO(sharadmv,mattjj): implement "passthrough" optimization.
  # TODO(sharadmv,mattjj): rematerialize loop-dependent values instead of
  # passing the loop index as a residual.
  # `jaxpr_known_resout` is a jaxpr that maps from all the input `Ref`s
  # to output residual values (none of them should be `Ref`s). We'll need to
  # convert those output residual values into writes to initially-empty
  # residual `Ref`s at the end of the jaxpr.
  #
  # Loop-invariant residual optimization
  # Here we are interested in finding out which of the residuals are *not*
  # dependent on the loop index. If a residual is not dependent on the loop
  # index, we don't need to add an extra loop dimension that we read from when
  # we convert it from an output into a write.
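  # A loop-invariant residual can be stored in a `Ref` of the value's own
  # shape and written once; only loop-varying residuals need a leading
  # `nsteps` dimension indexed by the loop counter.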
  loop_invar_res = _loop_invariant_outputs(jaxpr_known_resout)

  jaxpr_known, res_avals = _convert_outputs_to_writes(nsteps,
                                                      jaxpr_known_resout,
                                                      loop_invar_res)
  known_invars, _ = partition_list(in_unknowns, eqn.invars)
  known_outvars, _ = partition_list(in_unknowns, eqn.outvars)
  newvar = core.gensym()
  resvars = map(newvar, res_avals)

  @lu.wrap_init
  def known(*known_vals):
    empty_res = map(ad_util.zeros_like_aval, res_avals)
    jaxpr_known_args = [*known_vals, *empty_res]
    jaxpr_known_which_linear = (False,) * len(jaxpr_known_args)
    return for_p.bind(*jaxpr_known_args, jaxpr=jaxpr_known, nsteps=nsteps,
                      reverse=reverse,
                      which_linear=jaxpr_known_which_linear,
                      unroll=unroll)
  call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
      known, [v.aval for v in known_invars])
  call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
  eqn_known = pe.new_jaxpr_eqn(known_invars, [*known_outvars, *resvars],
                               core.closed_call_p,
                               dict(call_jaxpr=call_jaxpr),
                               call_jaxpr.effects, eqn.source_info)

  jaxpr_staged = _convert_inputs_to_reads(nsteps, len(res_avals),
                                          jaxpr_staged_resin_,
                                          loop_invar_res)
  which_linear_unknown = (False,) * num_res + tuple(which_linear)
  params_staged = dict(eqn.params, jaxpr=jaxpr_staged, reverse=reverse,
                       nsteps=nsteps, which_linear=which_linear_unknown,
                       unroll=unroll)

  @lu.wrap_init
  def staged(*res_and_refs):
    out_flat = for_p.bind(*res_and_refs, **params_staged)
    _, ans = split_list(out_flat, [num_res])
    _, ans = partition_list(out_inst, ans)
    return ans
  call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
      staged, [v.aval for v in [*resvars, *eqn.invars]])
  assert len(jaxpr_staged.invars) - 1 == len(call_jaxpr_.invars)
  call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
  _, outvars = partition_list(out_inst, eqn.outvars)
  eqn_staged = pe.new_jaxpr_eqn([*resvars, *eqn.invars], outvars,
                                core.closed_call_p,
                                dict(call_jaxpr=call_jaxpr),
                                call_jaxpr.effects, eqn.source_info)
  new_vars = [*new_inst, *resvars]
  return eqn_known, eqn_staged, in_unknowns, out_inst, new_vars
pe.partial_eval_jaxpr_custom_rules[for_p] = _for_partial_eval_custom

def _convert_outputs_to_writes(
    nsteps: int, jaxpr: core.Jaxpr, loop_invar_res: Sequence[bool]
    ) -> Tuple[core.Jaxpr, List[core.ShapedArray]]:
  assert not jaxpr.constvars, "Jaxpr shouldn't have constvars."

  in_avals = [v.aval for v in jaxpr.invars]  # [i, *orig_ref_avals]
  @lu.wrap_init
  def eval_jaxpr(i, *refs):
    # We split the refs into the original input refs and the dummy residual
    # refs.
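    # `in_avals` includes the loop index, so the original jaxpr accounts for
    # `len(in_avals) - 1` of the refs; the remaining refs are the residuals.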
    orig_refs, residual_refs = split_list(refs, [len(in_avals) - 1])
    residual_vals = core.eval_jaxpr(jaxpr, (), i, *orig_refs)
    for res_ref, res_val, loop_invar in zip(residual_refs, residual_vals,
                                            loop_invar_res):
      if loop_invar:
        res_ref[()] = res_val
      else:
        res_ref[i] = res_val
    return []
  # TODO(mattjj, sharadmv): better handling of tokens, which don't have
  # shape/dtype
  res_ref_avals: List[core.AbstractValue] = [
      AbstractRef(v.aval) if loop_invar else  # pytype: disable=attribute-error
      AbstractRef(core.ShapedArray((nsteps, *v.aval.shape),  # pytype: disable=attribute-error
                                   v.aval.dtype))  # pytype: disable=attribute-error
      for v, loop_invar in zip(jaxpr.outvars, loop_invar_res)]
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
      eval_jaxpr, [*in_avals, *res_ref_avals])
  assert not consts
  return jaxpr, [core.ShapedArray(a.shape, a.dtype)  # pytype: disable=attribute-error
                 for a in res_ref_avals]

def _convert_inputs_to_reads(
    nsteps: int, num_res: int, jaxpr: core.Jaxpr,
    loop_invar_res: Sequence[bool]) -> core.Jaxpr:
  assert not jaxpr.constvars, "Jaxpr should not have constvars"

  @lu.wrap_init
  def eval_jaxpr(i, *refs):
    residual_refs, orig_refs = split_list(refs, [num_res])
    residual_vals = [r[()] if loop_invar else r[i]
                     for r, loop_invar in zip(residual_refs, loop_invar_res)]
    () = core.eval_jaxpr(jaxpr, (), *residual_vals, i, *orig_refs)
    return []

  res_val_avals, (i_aval,), orig_ref_avals = \
      split_list([v.aval for v in jaxpr.invars], [num_res, 1])
  res_ref_avals: List[core.AbstractValue] = [
      AbstractRef(aval) if loop_invar else  # pytype: disable=attribute-error
      AbstractRef(core.ShapedArray((nsteps, *aval.shape),  # pytype: disable=attribute-error
                                   aval.dtype))  # pytype: disable=attribute-error
      for aval, loop_invar in zip(res_val_avals, loop_invar_res)]
  jaxpr, _, () = pe.trace_to_jaxpr_dynamic(
      eval_jaxpr, [i_aval, *res_ref_avals, *orig_ref_avals])
  return jaxpr

def transpose_jaxpr(jaxpr: core.Jaxpr,
                    which_linear: List[bool]) -> core.Jaxpr:
  def trans(i, *args):
    # First we want to run the computation to read all the residual refs. We
    # can do that by using partial evaluation with all linear inputs unknown.
    res_jaxpr, tangent_jaxpr_, *_ = \
        _partial_eval_jaxpr_custom(jaxpr, [False, *which_linear],
                                   _save_everything)
    res_args = [x for x, lin in zip(args, which_linear) if not lin]
    res = core.eval_jaxpr(res_jaxpr, (), i, *res_args)

    # Now that we have residual values, we run the tangent jaxpr. It takes as
    # input the residuals, the loop index, and all the refs (at least, the
    # ones that are used in the body). Luckily, `tangent_jaxpr_` has all
    # known and unknown inputs!
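    # With no outputs requested, `dce_jaxpr` keeps only what is needed for
    # the jaxpr's (stateful) effects and reports which inputs are actually
    # used: which residuals, whether the loop index is needed, and which
    # cotangent refs are read.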
    tangent_jaxpr, used = pe.dce_jaxpr(tangent_jaxpr_, [])
    used_res, (used_i,), used_ct = split_list(used, [len(res), 1])
    primals_args = [*(r for u, r in zip(used_res, res) if u)]
    if used_i:
      primals_args = [*primals_args, i]
    ct_args = [x for x, u in zip(args, used_ct) if u]
    ad.backward_pass(
        tangent_jaxpr, (), False, (), (*primals_args, *ct_args), ())
    return []
  jaxpr_trans, _, _ = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
  return jaxpr_trans

def _for_transpose(in_cts, *args, jaxpr, nsteps, reverse, which_linear,
                   unroll):
  # If any in_ct is nonzero, we definitely want it in args_ (and the
  # corresponding x in args could be an undefined primal, but doesn't have to
  # be).
  # For non-res stuff:
  #   getting and setting => (nonzero ct, UndefinedPrimal arg)
  #   just setting =>        (nonzero ct, not UndefinedPrimal, dummy value)
  #   just getting =>        (zero ct   , UndefinedPrimal arg)
  # For res stuff:
  #                          (zero ct   , not UndefinedPrimal)
  args_ = []
  which_linear_transpose = []
  for x, ct in zip(args, in_cts):
    if type(ct) is ad_util.Zero and not ad.is_undefined_primal(x):
      # This is a residual; take x!
      args_.append(x)
      which_linear_transpose.append(False)
    elif type(ct) is ad_util.Zero and ad.is_undefined_primal(x):
      # The loop was 'just getting'; plug in a zero.
      args_.append(ad_util.zeros_like_aval(x.aval))
      which_linear_transpose.append(False)
    elif type(ct) is not ad_util.Zero and not ad.is_undefined_primal(x):
      # The loop was 'just setting'; grab that cotangent! x is a dummy value.
      args_.append(ct)
      which_linear_transpose.append(False)
    elif type(ct) is not ad_util.Zero and ad.is_undefined_primal(x):
      # The loop was 'getting and setting'; grab that cotangent!
      args_.append(ct)
      which_linear_transpose.append(True)

  jaxpr_transpose = transpose_jaxpr(jaxpr, which_linear)
  assert len(args_) == len(jaxpr_transpose.invars) - 1
  all_outs = for_p.bind(*args_, jaxpr=jaxpr_transpose, nsteps=nsteps,
                        reverse=not reverse,
                        which_linear=tuple(which_linear_transpose),
                        unroll=unroll)
  ct_outs = [ct if ad.is_undefined_primal(x) else None
             for x, ct in zip(args, all_outs)]
  return ct_outs
ad.primitive_transposes[for_p] = _for_transpose

### Testing utility

def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False):
  """A `for_loop` implementation that discharges its body right away.

  Potentially useful for testing and benchmarking.
  """
  flat_state, state_tree = tree_flatten(init_state)
  state_avals = map(state_utils.val_to_ref_aval, flat_state)
  idx_aval = core.ShapedArray((), jnp.dtype("int32"))
  jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs(
      body, state_tree, [idx_aval, *state_avals])
  if out_tree != tree_structure(None):
    raise Exception("`body` should not return anything.")
  discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts)

  def fori_body(i, carry):
    i = jnp.int32(i)
    if reverse:
      i = nsteps - i - 1
    out_flat = core.eval_jaxpr(discharged_jaxpr, discharged_consts,
                               i, *carry)
    return out_flat
  out_flat = loops.fori_loop(0, nsteps, fori_body, flat_state)
  return tree_unflatten(state_tree, out_flat)

def run_state(f, init_state):
  @functools.wraps(f)
  def wrapped_body(_, *args):
    return f(*args)
  return for_loop(1, wrapped_body, init_state)
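
# Example usage (a minimal sketch for illustration only; `cumsum` is a
# hypothetical helper, not part of this module's API). It mirrors the pattern
# `scan` uses above: carry a running value in one `Ref` while writing
# per-iteration outputs into another.
#
#   def cumsum(xs):
#     def body(i, refs):
#       x_ref, out_ref, acc_ref = refs
#       acc = ref_get(acc_ref, ()) + ref_get(x_ref, (i,))
#       ref_set(acc_ref, (), acc)
#       ref_set(out_ref, (i,), acc)
#     _, out, _ = for_loop(
#         xs.shape[0], body,
#         (xs, jnp.zeros_like(xs), jnp.zeros_like(xs[0])))
#     return out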