# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the `for_loop` primitive."""
import functools
import operator

from typing import Any, Callable, Generic, List, Optional, Sequence, Set, Tuple, TypeVar, Union

import jax.numpy as jnp
from jax import lax
from jax.api_util import flatten_fun_nokwargs
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax.tree_util import (tree_flatten, tree_structure, tree_unflatten,
                           treedef_tuple, tree_map, tree_leaves, PyTreeDef)

from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src.state.types import (ReadEffect, AbstractRef, StateEffect)
from jax._src.state import discharge as state_discharge
from jax._src.state import primitives as state_primitives
from jax._src.state import utils as state_utils
from jax._src.state import types as state_types
from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip,
                           split_list, split_dict)
from jax._src.lax.control_flow import loops
from jax._src.lax.control_flow.common import _abstractify, _initial_style_jaxpr

## JAX utilities

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

## Helpful type aliases
S = TypeVar('S')
T = TypeVar('T')
class Ref(Generic[T]): pass
Array = Any

ref_set = state_primitives.ref_set
ref_get = state_primitives.ref_get
ref_addupdate = state_primitives.ref_addupdate
discharge_state = state_discharge.discharge_state

## `for_loop` implementation

for_p = core.Primitive('for')
for_p.multiple_results = True

### Tracing utilities

def _hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr:
  all_const_avals = [var.aval for var in jaxpr.constvars]
  is_const_ref = [isinstance(var.aval, AbstractRef) for var in
                  jaxpr.constvars]
  const_avals, const_ref_avals = partition_list(is_const_ref, all_const_avals)
  const_avals = map(AbstractRef, const_avals)
  merged_const_avals = merge_lists(is_const_ref, const_avals, const_ref_avals)
  i_aval, *arg_avals = [var.aval for var in jaxpr.invars]
  in_avals = [i_aval, *merged_const_avals, *arg_avals]
  num_consts = len(merged_const_avals)

  def _hoist(i, *consts_args):
    all_consts, args = split_list(consts_args, [num_consts])
    consts, const_refs = partition_list(is_const_ref, all_consts)
    # We immediately read the const values out of the `Ref`s.
    consts = map(lambda x: ref_get(x, ()), consts)
    all_consts = merge_lists(is_const_ref, consts, const_refs)
    return core.eval_jaxpr(jaxpr, all_consts, i, *args)
  hoisted_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(_hoist), in_avals)
  assert not consts, "All consts should have been converted to refs"
  return hoisted_jaxpr
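
# For example (an illustrative sketch): a traced body jaxpr that captured an
# array constant `c` and has signature `(i, x_ref) -> ()` is rewritten by
# `_hoist_consts_to_refs` into a jaxpr `(i, c_ref, x_ref) -> ()` that begins by
# reading `c = c_ref[()]`, so the caller can pass `c` in as just another `Ref`
# input to the `for` primitive.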

def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef,
                              state_avals: Sequence[core.AbstractValue]
                              ) -> Tuple[core.Jaxpr, List[Any], PyTreeDef]:
  f, out_tree_thunk = flatten_fun_nokwargs(
      lu.wrap_init(f), treedef_tuple((tree_structure(0), state_tree)))
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
      f, state_avals)
  return jaxpr, consts, out_tree_thunk()

def for_loop(nsteps: Union[int, Sequence[int]],
             body: Callable[[Array, Ref[S]], None], init_state: S,
             *, reverse: bool = False, unroll: int = 1) -> S:
  """A for-loop combinator that allows read/write semantics in the loop body.

  `for_loop` is a higher-order function that enables writing loops that can be
  staged out in JIT-ted JAX computations. Unlike `jax.lax.fori_loop`, it allows
  mutation in its body using `Ref`s.

  `for_loop` will initialize `Ref`s with the values in `init_state`. Each
  iteration, `body` will be called with the current `Ref`s, which can be read
  from and written to using `ref_get` and `ref_set`.

  `for_loop` is semantically equivalent to the following Python code:

  ```python
  def for_loop(nsteps, body, init_state):
    refs = tree_map(make_ref, init_state)
    for i in range(nsteps):
      body(i, refs)
    return tree_map(ref_get, refs)
  ```

  Args:
    nsteps: Number of iterations. Passing a sequence of ints produces nested
      loops, in which case `body` takes one index argument per loop level.
    body: A callable that takes in the iteration number as its first argument
      and `Ref`s corresponding to `init_state` as its second argument.
      `body` is free to read from and write to its `Ref`s. `body` should
      not return anything.
    init_state: A Pytree of JAX-compatible values used to initialize the `Ref`s
      that will be passed into the for loop body.
    reverse: Whether to iterate over the loop indices in reverse order.
    unroll: A positive int specifying, in the underlying operation of the
      `for` primitive, how many iterations to unroll within a single iteration
      of a loop. Higher values may speed up execution time at the cost of
      longer compilation time.
  Returns:
    A Pytree of values representing the output of the for loop.
  """
  if unroll < 1:
    raise ValueError("`unroll` must be a positive integer.")
  if isinstance(nsteps, int):
    nsteps = [nsteps]
  if len(nsteps) > 1:
    outer_step, *rest_steps = nsteps
    def wrapped_body(i, refs):
      vals = tree_map(lambda ref: ref_get(ref, ()), refs)
      vals = for_loop(
          rest_steps, functools.partial(body, i), vals, unroll=unroll)
      tree_map(lambda ref, val: ref_set(ref, (), val), refs, vals)
    return for_loop(outer_step, wrapped_body, init_state, unroll=unroll)
  nsteps, = nsteps
  flat_state, state_tree = tree_flatten(init_state)
  state_avals = map(state_utils.val_to_ref_aval, flat_state)
  idx_aval = core.ShapedArray((), jnp.dtype("int32"))
  jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs(
      body, state_tree, [idx_aval, *state_avals])
  if out_tree != tree_structure(None):
    raise Exception("`body` should not return anything.")
  # Remove constvars from jaxpr and turn them into `Ref`s.
  jaxpr = _hoist_consts_to_refs(jaxpr)
  which_linear = (False,) * (len(consts) + len(flat_state))
  out_flat = for_p.bind(*consts, *flat_state, jaxpr=jaxpr, nsteps=int(nsteps),
                        reverse=reverse, which_linear=which_linear,
                        unroll=unroll)
  # Consts are `Ref`s so they are both inputs and outputs. We remove them from
  # the outputs.
  out_flat = out_flat[len(consts):]
  return tree_unflatten(state_tree, out_flat)
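
# Example usage (an illustrative sketch, not part of this module's API):
# summing the entries of an array by mutating a scalar `Ref`.
#
#   import jax.numpy as jnp
#
#   def body(i, refs):
#     x_ref, sum_ref = refs
#     sum_ref[()] = sum_ref[()] + x_ref[i]
#
#   xs = jnp.arange(5.)
#   _, total = for_loop(5, body, (xs, 0.))  # total == 10.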

Carry = TypeVar('Carry')
X = TypeVar('X')
Y = TypeVar('Y')

def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
         init: Carry,
         xs: X,
         length: Optional[int] = None,
         reverse: bool = False,
         unroll: int = 1) -> Tuple[Carry, Y]:
  if not callable(f):
    raise TypeError("scan: f argument should be a callable.")
  if unroll < 1:
    raise ValueError("`unroll` must be a positive integer.")
  xs_flat, xs_tree = tree_flatten(xs)

  try:
    lengths = [x.shape[0] for x in xs_flat]
  except AttributeError as err:
    msg = "scan got value with no leading axis to scan over: {}."
    raise ValueError(
      msg.format(', '.join(str(x) for x in xs_flat
                           if not hasattr(x, 'shape')))) from err

  if length is not None:
    length = int(length)
    if not all(length == l for l in lengths):
      msg = ("scan got `length` argument of {} which disagrees with "
             "leading axis sizes {}.")
      raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat]))
  else:
    unique_lengths = set(lengths)
    if len(unique_lengths) > 1:
      msg = "scan got values with different leading axis sizes: {}."
      raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
    elif len(unique_lengths) == 0:
      msg = "scan got no values to scan over and `length` not provided."
      raise ValueError(msg)
    else:
      length, = unique_lengths

  x_shapes = [x.shape[1:] for x in xs_flat]
  x_dtypes = [dtypes.canonicalize_dtype(x.dtype) for x in xs_flat]
  x_avals = tuple(map(core.ShapedArray, x_shapes, x_dtypes))

  def _create_jaxpr(init):
    init_flat = tree_leaves(init)
    _, in_tree = tree_flatten((init, xs))

    carry_avals = tuple(map(_abstractify, init_flat))
    jaxpr, _, out_tree = _initial_style_jaxpr(
        f, in_tree, carry_avals + x_avals, "scan")
    return jaxpr, out_tree
  jaxpr, out_tree = _create_jaxpr(init)
  _, ys_avals = tree_unflatten(out_tree, jaxpr.out_avals)
  ys = tree_map(lambda aval: jnp.zeros([length, *aval.shape], aval.dtype),
                ys_avals)
  def for_body(i, refs):
    carry_refs, xs_refs, ys_refs = refs
    carry = tree_map(lambda x: x[()], carry_refs)
    x = tree_map(lambda x: x[i], xs_refs)
    carry, y = f(carry, x)
    tree_map(lambda c_ref, c: ref_set(c_ref, (), c), carry_refs, carry)
    tree_map(lambda y_ref, y: ref_set(y_ref, (i,), y), ys_refs, y)
  assert isinstance(length, int)
  init, _, ys = for_loop(length, for_body, (init, xs, ys), reverse=reverse,
                         unroll=unroll)
  return init, ys
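
# Example usage (an illustrative sketch): this `scan` mirrors `jax.lax.scan`
# but is implemented on top of `for_loop`, carrying the accumulator and the
# stacked outputs in `Ref`s.
#
#   def f(carry, x):
#     new_carry = carry + x
#     return new_carry, new_carry
#
#   final, partials = scan(f, 0., jnp.arange(4.))
#   # final == 6., partials == [0., 1., 3., 6.]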

@for_p.def_effectful_abstract_eval
def _for_abstract_eval(*avals, jaxpr, **__):
  # Find out for each of the `Ref`s in our jaxpr what effects they have.
  jaxpr_aval_effects = state_types.get_ref_state_effects(
      [v.aval for v in jaxpr.invars], jaxpr.effects)[1:]
  aval_effects = [set(eff.replace(input_index=eff.input_index - 1)
                      for eff in effs) for aval, effs
                  in zip(avals, jaxpr_aval_effects)
                  if isinstance(aval, AbstractRef)]
  nonlocal_state_effects = core.join_effects(*aval_effects)
  return list(avals), nonlocal_state_effects

@state_discharge.register_discharge_rule(for_p)
def _for_discharge_rule(in_avals, _, *args: Any, jaxpr: core.Jaxpr,
                        reverse: bool, which_linear: Sequence[bool],
                        nsteps: int, unroll: int
                        ) -> Tuple[Sequence[Optional[Any]], Sequence[Any]]:
  out_vals = for_p.bind(*args, jaxpr=jaxpr, reverse=reverse,
                        which_linear=which_linear, nsteps=nsteps,
                        unroll=unroll)
  new_invals = []
  for aval, out_val in zip(in_avals, out_vals):
    new_invals.append(out_val if isinstance(aval, AbstractRef) else None)
  return new_invals, out_vals
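
# Discharging turns a stateful jaxpr into a pure one. For example (an
# illustrative sketch), a body `(i, x_ref) -> ()` whose only statement is
# `x_ref[()] = x_ref[()] + 1.` discharges to a pure jaxpr `(i, x) -> (x + 1.)`:
# `Ref` inputs become plain array inputs and their final values become outputs.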

def _for_impl(*args, jaxpr, nsteps, reverse, which_linear, unroll):
  del which_linear
  discharged_jaxpr, consts = discharge_state(jaxpr, ())
  def body(i, state):
    i_ = nsteps - i - 1 if reverse else i
    return core.eval_jaxpr(discharged_jaxpr, consts, i_, *state)
  return _for_impl_unrolled(body, nsteps, unroll, *args)

def _for_impl_unrolled(body, nsteps, unroll, *args):
  remainder = nsteps % unroll
  i = jnp.int32(0)
  state = list(args)

  for _ in range(remainder):
    state = body(i, state)
    i = i + 1

  def cond(carry):
    i, _ = carry
    return i < nsteps
  def while_body(carry):
    i, state = carry
    for _ in range(unroll):
      state = body(i, state)
      i = i + 1
    return i, state
  _, state = lax.while_loop(cond, while_body, (i, state))
  return state
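
# For example, with `nsteps=10` and `unroll=4`, the first `10 % 4 == 2`
# iterations run inline above, and the `while_loop` then performs 4 body
# evaluations per trip for the remaining 8 iterations, trading compilation
# time (a larger loop body) for runtime.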

mlir.register_lowering(for_p, mlir.lower_fun(_for_impl, multiple_results=True))
for_p.def_impl(functools.partial(dispatch.apply_primitive, for_p))

def _for_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *,
              jaxpr, nsteps, reverse, which_linear, unroll):
  init_batched = [d is not batching.not_mapped for d in dims]
  discharged_jaxpr, body_consts = discharge_state(jaxpr, ())
  batched = init_batched
  for _ in range(len(batched)):
    _, out_batched = batching.batch_jaxpr(
        core.ClosedJaxpr(discharged_jaxpr, body_consts),
        axis_size, [False] + batched, instantiate=batched,
        axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
    if out_batched == batched:
      break
    batched = map(operator.or_, batched, out_batched)
  else:
    raise Exception("Invalid fixpoint")
  args = [batching.broadcast(x, axis_size, 0) if now_bat and not was_bat
          else batching.moveaxis(x, d, 0) if now_bat else x
          for x, d, was_bat, now_bat in zip(args, dims, init_batched, batched)]
  batched_jaxpr_, _ = batching.batch_jaxpr(
      core.ClosedJaxpr(jaxpr, []), axis_size, [False] + batched, [],
      axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
  batched_jaxpr, () = batched_jaxpr_.jaxpr, batched_jaxpr_.consts  # TODO consts
  out_flat = for_p.bind(*args, jaxpr=batched_jaxpr, nsteps=nsteps,
                        reverse=reverse, which_linear=which_linear,
                        unroll=unroll)
  return out_flat, [0 if b else batching.not_mapped for b in batched]
batching.axis_primitive_batchers[for_p] = functools.partial(_for_vmap, None)
batching.spmd_axis_primitive_batchers[for_p] = _for_vmap

def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear,
             unroll):
  nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
  # We need to find out which `Ref`s have nonzero tangents after running the
  # for loop. Ordinarily we do this with a fixed point on the body jaxpr but
  # a `for` body jaxpr is stateful and has no outputs. We therefore discharge
  # the state effect from the jaxpr and we will now have a "symmetric" jaxpr
  # where the inputs line up with the outputs. We use this discharged jaxpr
  # for the fixed point.
  discharged_jaxpr, body_consts = discharge_state(jaxpr, ())
  for _ in range(len(nonzero_tangents)):
    _, out_nonzero_tangents = ad.jvp_jaxpr(
        core.ClosedJaxpr(discharged_jaxpr, body_consts),
        [False] + nonzero_tangents, instantiate=nonzero_tangents)
    if out_nonzero_tangents == nonzero_tangents:
      break
    nonzero_tangents = map(operator.or_, nonzero_tangents, out_nonzero_tangents)
  else:
    raise Exception("Invalid fixpoint")
  tangents = [ad.instantiate_zeros(t) if inst else t
              for t, inst in zip(tangents, nonzero_tangents)]
  tangents = [t for t in tangents if type(t) is not ad_util.Zero]
  closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
  jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, [False] + nonzero_tangents, [])
  jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts  # TODO consts
  jvp_which_linear = which_linear + (True,) * len(tangents)
  out_flat = for_p.bind(*primals, *tangents, jaxpr=jvp_jaxpr,
                        nsteps=nsteps, reverse=reverse,
                        which_linear=jvp_which_linear, unroll=unroll)
  # `out_flat` includes constant inputs into the `for_loop` which are converted
  # into outputs as well. We don't care about these in AD so we throw them out.
  out_primals, out_tangents = split_list(out_flat, [len(primals)])
  out_tangents_iter = iter(out_tangents)
  out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(out_primals, nonzero_tangents)]
  return out_primals, out_tangents
ad.primitive_jvps[for_p] = _for_jvp
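
# For example (an illustrative sketch), for a single `Ref` argument with a
# nonzero tangent, `ad.jvp_jaxpr` rewrites the body from `(i, x_ref) -> ()`
# to `(i, x_ref, t_ref) -> ()`, and the `for` outputs are then the final
# primal values followed by the final tangent values in that order.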

def _partial_eval_jaxpr_custom(jaxpr, in_unknowns, policy):
  # A simple wrapper around `pe.partial_eval_jaxpr_custom` that assumes all
  # inputs are instantiated and doesn't ensure any outputs are unknown or
  # instantiated.
  return pe.partial_eval_jaxpr_custom(
      jaxpr, in_unknowns, [True] * len(in_unknowns), False, False, policy)

_save_everything = lambda *_, **__: True

def _is_read_only(ref_effects: Set[StateEffect]) -> bool:
  assert len(ref_effects) > 0
  if len(ref_effects) > 1:
    # Means we must have a write or accum effect, so not read-only.
    return False
  eff, = ref_effects
  return isinstance(eff, ReadEffect)

def _loop_invariant_outputs(jaxpr: core.Jaxpr) -> List[bool]:
  # Get effects for each of the jaxpr inputs and remove the loop index.
  ref_effects = state_types.get_ref_state_effects(
      [v.aval for v in jaxpr.invars], jaxpr.effects)[1:]
  # We first assume that *read-only `Ref`s* are loop-invariant. We can safely
  # do this because the only way something can be loop-varying is if we write
  # to it at some point. It's *possible* that read-write `Ref`s are
  # loop-invariant, but we conservatively assume they aren't.
  loop_invar_refs = [_is_read_only(effs) if effs else True
                     for effs in ref_effects]
  loop_var_refs = map(operator.not_, loop_invar_refs)

  # We'd like to detect if the outputs of the jaxpr are loop-invariant. An
  # output is loop-invariant if it is downstream of only loop-invariant values
  # (seeded by the read-only `Ref`s). If at any point a loop-varying value
  # interacts with a loop-invariant value, we produce a loop-varying value. We
  # can use `partial_eval` to perform this analysis by treating loop-varying
  # values as "unknown" and loop-invariant values as "known", since when a
  # known and an unknown value interact, they produce an unknown value.
  loop_var_inputs = [True, *loop_var_refs]
  _, _, loop_var_outputs, _, _, = _partial_eval_jaxpr_custom(
      jaxpr, loop_var_inputs, _save_everything)
  return map(operator.not_, loop_var_outputs)
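
# For example (an illustrative sketch), in a body that computes
# `res = c_ref[()] * 2.` where `c_ref` is only ever read, the output `res` is
# downstream of loop-invariant values only, so it is classified loop-invariant
# and can be stored once rather than once per iteration.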

def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
                      jaxpr: core.Jaxpr, nsteps: int, reverse: bool,
                      which_linear: Tuple[bool, ...],
                      unroll: int) -> List[pe.JaxprTracer]:
  num_inputs = len(tracers)
  assert num_inputs == len(jaxpr.invars) - 1
  in_unknowns = [not t.pval.is_known() for t in tracers]
  # We first need to run a fixpoint to determine which of the `Ref`s are
  # unknown after running the for loop. We want to use the jaxpr to determine
  # which `Ref`s are unknown after executing the for loop body given which
  # `Ref`s are unknown before. However, the jaxpr has no outputs. Instead, we
  # discharge the body and run the fixpoint with the discharged jaxpr. We can
  # do this because the outputs of the discharged jaxpr are one-to-one with
  # the inputs.
  discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ())
  discharged_jaxpr = discharged_jaxpr.replace(
      invars=discharged_jaxpr.constvars + discharged_jaxpr.invars,
      constvars=[])
  for _ in range(num_inputs):
    jaxpr_in_unknowns = [False] * len(discharged_consts) + [False, *in_unknowns]
    _, _, out_unknowns, _, _, = pe.partial_eval_jaxpr_custom(
        discharged_jaxpr, jaxpr_in_unknowns, [True] * len(jaxpr_in_unknowns),
        in_unknowns, False, _save_everything)
    out_unknowns = list(out_unknowns)
    if out_unknowns == in_unknowns:
      break
    in_unknowns = map(operator.or_, in_unknowns, out_unknowns)
  else:
    raise Exception("Invalid fixpoint")
  del out_unknowns  # redundant since it's the same as `in_unknowns`
  tracers = tuple(trace.instantiate_const(t) if uk else t  # type: ignore
                  for t, uk in zip(tracers, in_unknowns))

  # We use `partial_eval_jaxpr_custom` here because it won't remove effectful
  # primitives like `get`/`set`.
  jaxpr_known_resout, jaxpr_unknown_resin_, uk_out, inst_out, num_res = \
      _partial_eval_jaxpr_custom(jaxpr, [False, *in_unknowns],
                                 _save_everything)
  # `partial_eval_jaxpr_custom` will give us jaxprs that have hybrid `Ref` and
  # regular valued input/outputs. However, we'd like to bind these jaxprs to a
  # `for`, which expects only `Ref` inputs and no output. We need to convert
  # both of these jaxprs into ones that are compatible with `for`.
  # TODO(sharadmv,mattjj): implement "passthrough" optimization.
  # TODO(sharadmv,mattjj): rematerialize loop-dependent values instead of
  # passing the loop index as a residual.

  # `jaxpr_known_resout` is a jaxpr that maps from all the input `Refs`
  # to output residual values (none of them should be `Ref`s). We'll need to
  # convert the output residual values into `Ref`s that are initially empty
  # `Ref`s that are written to at the end of the jaxpr.

  # Loop-invariant residual optimization
  # Here we are interested in finding out which of the residuals are *not*
  # dependent on the loop index. If a residual is not dependent on the loop
  # index, we don't need to add an extra loop dimension when we convert it
  # from an output into a write.
  loop_invar_res = _loop_invariant_outputs(jaxpr_known_resout)

  jaxpr_known, res_avals = _convert_outputs_to_writes(nsteps,
                                                      jaxpr_known_resout,
                                                      loop_invar_res)
  # We now run the known jaxpr to obtain our residual values.
  known_tracers, _ = partition_list(in_unknowns, tracers)
  known_vals = [t.pval.get_known() for t in known_tracers]
  empty_res = map(ad_util.zeros_like_aval, res_avals)
  jaxpr_known_args = [*known_vals, *empty_res]
  # We assume the known inputs are nonlinear, which is okay to do for AD but
  # not necessarily okay for general partial eval.
  jaxpr_known_which_linear = (False,) * len(jaxpr_known_args)
  out_flat = for_p.bind(*jaxpr_known_args, jaxpr=jaxpr_known, nsteps=nsteps,
                        reverse=reverse, which_linear=jaxpr_known_which_linear,
                        unroll=unroll)
  known_outputs, residuals = split_list(out_flat, [len(known_tracers)])
  residuals = map(trace.new_instantiated_const, residuals)

  # Now we handle the `jaxpr_unknown` that expects residual values as inputs.
  # This jaxpr is the output of `partial_eval_jaxpr_custom` that marks which
  # inputs are actually used.
  # `partial_eval_jaxpr_custom` doesn't remove extra inputs/outputs for you,
  # so we use `dce_jaxpr` here to do that.
  jaxpr_unknown_resin, used_inputs = pe.dce_jaxpr(
      jaxpr_unknown_resin_, [], [True] * num_res + [True, *in_unknowns])
  used_res, (used_i,), used_refs = split_list(used_inputs, [num_res, 1])
  assert all(used_res), "All residuals should be used"
  # To make it compatible with `for`, we need to convert those residual values
  # into `Ref`s.
  jaxpr_unknown = _convert_inputs_to_reads(nsteps, len(res_avals),
                                           jaxpr_unknown_resin,
                                           loop_invar_res)
  # Since not all inputs are used in jaxpr_unknown, we filter the input
  # tracers down using the output of `dce_jaxpr`.
  used_and_known = map(operator.and_, used_refs, map(operator.not_, in_unknowns))
  tracers = [trace.instantiate_const(t) if u_and_k else t for t, u_and_k  # type: ignore
             in zip(tracers, used_and_known)]
  _, known_used = partition_list(used_refs, used_and_known)
  _, used_tracers = partition_list(used_refs, tracers)
  _, used_which_linear = partition_list(used_refs, which_linear)
  which_linear_unknown = (False,) * num_res + tuple(used_which_linear)
  unknown_inputs = [*residuals, *used_tracers]
  # Outputs match inputs, so we construct output tracers that look like the
  # input tracers.
  res_ref_unknown_outputs = [
      pe.JaxprTracer(trace, pe.PartialVal.unknown(t.aval), None)
      for t in unknown_inputs]
  name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
  source = source_info_util.current().replace(name_stack=name_stack)

  assert len(unknown_inputs) == len(res_ref_unknown_outputs)
  assert len(unknown_inputs) == len(jaxpr_unknown.invars) - 1
  eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs,
                          for_p, dict(jaxpr=jaxpr_unknown, nsteps=nsteps,
                                      reverse=reverse,
                                      which_linear=which_linear_unknown,
                                      unroll=unroll),
                          core.no_effects, source)
  for t in res_ref_unknown_outputs: t.recipe = eqn
  _, unknown_outputs = split_list(res_ref_unknown_outputs, [num_res])
  unknown_outputs, _ = partition_list(known_used, unknown_outputs)
  return merge_lists(in_unknowns, known_outputs, unknown_outputs)
pe.custom_partial_eval_rules[for_p] = _for_partial_eval

def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
  jaxpr, nsteps, reverse, which_linear, unroll = split_dict(
      eqn.params, ["jaxpr", "nsteps", "reverse", "which_linear", "unroll"])
  num_inputs = len(eqn.invars)
  # We first need to run a fixpoint to determine which of the `Ref`s are
  # unknown after running the for loop. However, the jaxpr has no outputs.
  # Instead, we discharge the body and run the fixpoint with the discharged
  # jaxpr. We can do this because the outputs of the discharged jaxpr are
  # one-to-one with the inputs.
  discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ())
  discharged_jaxpr = discharged_jaxpr.replace(
      invars=discharged_jaxpr.constvars + discharged_jaxpr.invars,
      constvars=[])
  in_unknowns, in_inst = list(in_unknowns), list(in_inst)
  out_unknowns, out_inst = in_unknowns, in_inst
  for _ in range(num_inputs):
    jaxpr_in_unknowns = [False] * len(discharged_consts) + [False, *in_unknowns]
    _, _, out_unknowns, out_inst, _, = pe.partial_eval_jaxpr_custom(
        discharged_jaxpr, jaxpr_in_unknowns, True,
        ensure_out_unknowns=in_unknowns, ensure_out_inst=True,
        saveable=saveable)
    out_unknowns = list(out_unknowns)
    if out_unknowns == in_unknowns:
      break
    in_unknowns = map(operator.or_, in_unknowns, out_unknowns)
  else:
    if num_inputs > 0: raise Exception("Invalid fixpoint")
  del out_unknowns  # Redundant since it's the same as `in_unknowns`
  new_inst = [x for x, inst in zip(eqn.invars, in_inst)
              if type(x) is core.Var and not inst]
  in_inst = [True] * len(eqn.invars)

  # We use `partial_eval_jaxpr_custom` here because it won't remove effectful
  # primitives like `get`/`set`.
  jaxpr_known_resout, jaxpr_staged_resin_, _, _, num_res = \
      pe.partial_eval_jaxpr_custom(jaxpr, [False, *in_unknowns],
                                   [True, *in_inst], [], [], saveable)

  # `partial_eval_jaxpr_custom` will give us jaxprs that have hybrid `Ref` and
  # non-Ref input/outputs. However, we'd like to bind these jaxprs to a
  # `for`, which expects only `Ref` inputs and no output. We need to convert
  # both of these jaxprs into ones that are compatible with `for`.
  # TODO(sharadmv,mattjj): implement "passthrough" optimization.
  # TODO(sharadmv,mattjj): rematerialize loop-dependent values instead of
  # passing the loop index as a residual.

  # `jaxpr_known_resout` is a jaxpr that maps from all the input `Refs`
  # to output residual values (none of them should be `Ref`s). We'll need to
  # convert the output residual values into `Ref`s that are initially empty
  # `Ref`s that are written to at the end of the jaxpr.

  # Loop-invariant residual optimization
  # Here we are interested in finding out which of the residuals are *not*
  # dependent on the loop index. If a residual is not dependent on the loop
  # index, we don't need to add an extra loop dimension when we convert it
  # from an output into a write.
  loop_invar_res = _loop_invariant_outputs(jaxpr_known_resout)

  jaxpr_known, res_avals = _convert_outputs_to_writes(nsteps,
                                                      jaxpr_known_resout,
                                                      loop_invar_res)

  known_invars, _ = partition_list(in_unknowns, eqn.invars)
  known_outvars, _ = partition_list(in_unknowns, eqn.outvars)
  newvar = core.gensym()
  resvars = map(newvar, res_avals)

  @lu.wrap_init
  def known(*known_vals):
    empty_res = map(ad_util.zeros_like_aval, res_avals)
    jaxpr_known_args = [*known_vals, *empty_res]
    jaxpr_known_which_linear = (False,) * len(jaxpr_known_args)
    return for_p.bind(*jaxpr_known_args, jaxpr=jaxpr_known, nsteps=nsteps,
                      reverse=reverse, which_linear=jaxpr_known_which_linear,
                      unroll=unroll)
  call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
      known, [v.aval for v in known_invars])
  call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
  eqn_known = pe.new_jaxpr_eqn(known_invars, [*known_outvars, *resvars],
                               core.closed_call_p, dict(call_jaxpr=call_jaxpr),
                               call_jaxpr.effects, eqn.source_info)

  jaxpr_staged = _convert_inputs_to_reads(nsteps, len(res_avals),
                                          jaxpr_staged_resin_,
                                          loop_invar_res)
  which_linear_unknown = (False,) * num_res + tuple(which_linear)
  params_staged = dict(eqn.params, jaxpr=jaxpr_staged, reverse=reverse,
                       nsteps=nsteps,
                       which_linear=which_linear_unknown,
                       unroll=unroll)

  @lu.wrap_init
  def staged(*res_and_refs):
    out_flat = for_p.bind(*res_and_refs, **params_staged)
    _, ans = split_list(out_flat, [num_res])
    _, ans = partition_list(out_inst, ans)
    return ans
  call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
      staged, [v.aval for v in [*resvars, *eqn.invars]])
  assert len(jaxpr_staged.invars) - 1 == len(call_jaxpr_.invars)
  call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
  _, outvars = partition_list(out_inst, eqn.outvars)
  eqn_staged = pe.new_jaxpr_eqn([*resvars, *eqn.invars], outvars,
                                core.closed_call_p, dict(call_jaxpr=call_jaxpr),
                                call_jaxpr.effects, eqn.source_info)
  new_vars = [*new_inst, *resvars]
  return eqn_known, eqn_staged, in_unknowns, out_inst, new_vars

pe.partial_eval_jaxpr_custom_rules[for_p] = _for_partial_eval_custom

def _convert_outputs_to_writes(
    nsteps: int, jaxpr: core.Jaxpr, loop_invar_res: Sequence[bool]
    ) -> Tuple[core.Jaxpr, List[core.ShapedArray]]:
  assert not jaxpr.constvars, "Jaxpr shouldn't have constvars."

  in_avals = [v.aval for v in jaxpr.invars]  # [i, *orig_ref_avals]
  @lu.wrap_init
  def eval_jaxpr(i, *refs):
    # We split the refs into the original input refs and the dummy residual
    # refs.
    orig_refs, residual_refs = split_list(refs, [len(in_avals) - 1])
    residual_vals = core.eval_jaxpr(jaxpr, (), i, *orig_refs)
    for res_ref, res_val, loop_invar in zip(residual_refs, residual_vals,
                                            loop_invar_res):
      if loop_invar:
        res_ref[()] = res_val
      else:
        res_ref[i] = res_val
    return []
  # TODO(mattjj, sharadmv): better handling of tokens, which don't have
  # shape/dtype
  res_ref_avals: List[core.AbstractValue] = [
      AbstractRef(v.aval) if loop_invar else  # pytype: disable=attribute-error
      AbstractRef(core.ShapedArray((nsteps, *v.aval.shape),  # pytype: disable=attribute-error
                                   v.aval.dtype))  # pytype: disable=attribute-error
      for v, loop_invar in zip(jaxpr.outvars, loop_invar_res)]
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
      eval_jaxpr, [*in_avals, *res_ref_avals])
  assert not consts
  return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals]  # pytype: disable=attribute-error
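
# For example (an illustrative sketch), a known jaxpr `(i, x_ref) -> res` with
# a loop-varying residual `res` becomes `(i, x_ref, res_ref) -> ()`, where
# `res_ref` carries a leading `nsteps` dimension and the body ends with
# `res_ref[i] = res`; a loop-invariant residual is instead written once via
# `res_ref[()] = res`.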

def _convert_inputs_to_reads(
    nsteps: int, num_res: int, jaxpr: core.Jaxpr,
    loop_invar_res: Sequence[bool]) -> core.Jaxpr:
  assert not jaxpr.constvars, "Jaxpr should not have constvars"

  @lu.wrap_init
  def eval_jaxpr(i, *refs):
    residual_refs, orig_refs = split_list(refs, [num_res])
    residual_vals = [r[()] if loop_invar else r[i] for r, loop_invar
                     in zip(residual_refs, loop_invar_res)]
    () = core.eval_jaxpr(jaxpr, (), *residual_vals, i, *orig_refs)
    return []

  res_val_avals, (i_aval,), orig_ref_avals = \
      split_list([v.aval for v in jaxpr.invars], [num_res, 1])
  res_ref_avals: List[core.AbstractValue] = [
      AbstractRef(aval) if loop_invar else  # pytype: disable=attribute-error
      AbstractRef(core.ShapedArray((nsteps, *aval.shape),  # pytype: disable=attribute-error
                                   aval.dtype))  # pytype: disable=attribute-error
      for aval, loop_invar in zip(res_val_avals, loop_invar_res)]

  jaxpr, _, () = pe.trace_to_jaxpr_dynamic(
      eval_jaxpr, [i_aval, *res_ref_avals, *orig_ref_avals])
  return jaxpr

def transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: List[bool]) -> core.Jaxpr:
  def trans(i, *args):
    # First we want to run the computation to read all the residual refs. We
    # can do that by using partial evaluation with all linear inputs unknown.
    res_jaxpr, tangent_jaxpr_, *_ = \
        _partial_eval_jaxpr_custom(jaxpr, [False, *which_linear],
                                   _save_everything)
    res_args = [x for x, lin in zip(args, which_linear) if not lin]
    res = core.eval_jaxpr(res_jaxpr, (), i, *res_args)

    # Now that we have residual values, we run the tangent jaxpr. It takes as
    # input the residuals, the loop index, and all the refs (at least, the
    # ones that are used in the body). Luckily, `tangent_jaxpr_` has all known
    # and unknown inputs!
    tangent_jaxpr, used = pe.dce_jaxpr(tangent_jaxpr_, [])
    used_res, (used_i,), used_ct = split_list(used, [len(res), 1])
    primals_args = [*(r for u, r in zip(used_res, res) if u)]
    if used_i:
      primals_args = [*primals_args, i]
    ct_args = [x for x, u in zip(args, used_ct) if u]
    ad.backward_pass(
        tangent_jaxpr, (), False, (), (*primals_args, *ct_args), ())
    return []
  jaxpr_trans, _, _ = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
  return jaxpr_trans

def _for_transpose(in_cts, *args, jaxpr, nsteps, reverse, which_linear, unroll):
  # If any in_ct is nonzero, we definitely want it in args_ (and the
  # corresponding x in args could be an undefined primal, but doesn't have to
  # be).
  # For non-res stuff:
  #   getting and setting => (nonzero ct, UndefinedPrimal arg)
  #   just setting        => (nonzero ct, not UndefinedPrimal, dummy value)
  #   just getting        => (zero ct   , UndefinedPrimal arg)
  # For res stuff:
  #                          (zero ct   , not UndefinedPrimal)
  args_ = []
  which_linear_transpose = []
  for x, ct in zip(args, in_cts):
    if type(ct) is ad_util.Zero and not ad.is_undefined_primal(x):
      # this is a residual, take x!
      args_.append(x)
      which_linear_transpose.append(False)
    elif type(ct) is ad_util.Zero and ad.is_undefined_primal(x):
      # the loop was 'just getting', plug in a zero
      args_.append(ad_util.zeros_like_aval(x.aval))
      which_linear_transpose.append(False)
    elif type(ct) is not ad_util.Zero and not ad.is_undefined_primal(x):
      # the loop was 'just setting', grab that cotangent! x is dummy
      args_.append(ct)
      which_linear_transpose.append(False)
    elif type(ct) is not ad_util.Zero and ad.is_undefined_primal(x):
      # the loop was 'getting and setting', grab that cotangent!
      args_.append(ct)
      which_linear_transpose.append(True)

  jaxpr_transpose = transpose_jaxpr(jaxpr, which_linear)
  assert len(args_) == len(jaxpr_transpose.invars) - 1
  all_outs = for_p.bind(*args_, jaxpr=jaxpr_transpose, nsteps=nsteps,
                        reverse=not reverse,
                        which_linear=tuple(which_linear_transpose),
                        unroll=unroll)
  ct_outs = [ct if ad.is_undefined_primal(x) else None
             for x, ct in zip(args, all_outs)]
  return ct_outs
ad.primitive_transposes[for_p] = _for_transpose

### Testing utility

def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False):
  """A `for_loop` implementation that discharges its body right away.

  Potentially useful for testing and benchmarking.
  """
  flat_state, state_tree = tree_flatten(init_state)
  state_avals = map(state_utils.val_to_ref_aval, flat_state)
  idx_aval = core.ShapedArray((), jnp.dtype("int32"))
  jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs(
      body, state_tree, [idx_aval, *state_avals])
  if out_tree != tree_structure(None):
    raise Exception("`body` should not return anything.")
  discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts)

  def fori_body(i, carry):
    i = jnp.int32(i)
    if reverse:
      i = nsteps - i - 1
    out_flat = core.eval_jaxpr(discharged_jaxpr, discharged_consts,
                               i, *carry)
    return out_flat
  out_flat = loops.fori_loop(0, nsteps, fori_body, flat_state)
  return tree_unflatten(state_tree, out_flat)

def run_state(f, init_state):
  @functools.wraps(f)
  def wrapped_body(_, *args):
    return f(*args)
  return for_loop(1, wrapped_body, init_state)
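
# Example usage (an illustrative sketch): `run_state` executes `f` exactly once
# with `Ref`s created from `init_state` and returns their final values.
#
#   def f(refs):
#     x_ref, y_ref = refs
#     y_ref[()] = x_ref[()] * 2.
#
#   x, y = run_state(f, (1., 0.))  # x == 1., y == 2.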