503 lines
21 KiB
Python
503 lines
21 KiB
Python
![]() |
# Copyright 2022 The JAX Authors.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
"""Module for state primitives."""
|
||
|
from functools import partial
|
||
|
|
||
|
from typing import Any, List, Tuple, Union
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
|
||
|
from jax._src import ad_util
|
||
|
from jax._src import core
|
||
|
from jax._src import pretty_printer as pp
|
||
|
from jax._src.interpreters import ad
|
||
|
from jax._src.interpreters import batching
|
||
|
from jax._src.interpreters import partial_eval as pe
|
||
|
from jax._src.lax import lax
|
||
|
from jax._src.typing import Array
|
||
|
from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect,
|
||
|
AccumEffect)
|
||
|
from jax._src.util import safe_map, safe_zip, tuple_insert
|
||
|
|
||
|
|
||
|
## General utilities

## JAX utilities

# Shadow the builtin `map`/`zip` with the length-checked `safe_map`/`safe_zip`
# variants; keep the permissive builtins around under `unsafe_` names.
# NOTE: the right-hand sides read the builtins *before* the names are rebound,
# so the assignment order here is load-bearing.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
|
||
|
|
||
|
## get/swap/addupdate implementations
|
||
|
|
||
|
# `get` reads a value from a `Ref` type, a.k.a.:
|
||
|
# a = get_p.bind(x)
|
||
|
# or we can read using indices:
|
||
|
# a = get_p.bind(x, 0, 1)
|
||
|
# Staging out `a = get_p.bind(x)` where the aval of `x` is
|
||
|
# `Ref((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
|
||
|
# a:f32[3] <- x[]
|
||
|
get_p = core.Primitive("get")
|
||
|
|
||
|
def _get_impl(ref: AbstractRef, *idx: int, **_):
|
||
|
del ref, idx
|
||
|
raise ValueError("Cannot run stateful primitive.")
|
||
|
get_p.def_impl(_get_impl)
|
||
|
|
||
|
Indexer = Tuple[Union[int, slice, Array], ...]
|
||
|
# or Ellipsis, but that can't be annotated until Python 3.10? (types.EllipsisType)
|
||
|
|
||
|
def _is_trivial_indexer(idx: Indexer) -> bool:
|
||
|
if idx is ...:
|
||
|
return True
|
||
|
if type(idx) is tuple:
|
||
|
if len(idx) == 0:
|
||
|
return True
|
||
|
return len(idx) == 1 and idx[0] is ...
|
||
|
return False
|
||
|
|
||
|
def _unpack_idx(idx: Indexer, ndim: int
                ) -> Tuple[Tuple[Array, ...], Tuple[bool, ...]]:
  """Splits `idx` into the non-slice index values and a per-dim indexed mask.

  A trivial indexer expands to all-slices. Only trivial slices (`:`) are
  supported; anything else raises `NotImplementedError`.
  """
  import jax.numpy as jnp  # local import, presumably to avoid an import cycle
  if _is_trivial_indexer(idx):
    idx = tuple(slice(None) for _ in range(ndim))
  non_slice_idx = []
  indexed = []
  for sub in idx:
    if isinstance(sub, slice):
      if (sub.start, sub.stop, sub.step) != (None, None, None):
        raise NotImplementedError("Reference indexing only supports trivial slices")
      indexed.append(False)
    else:
      non_slice_idx.append(sub)
      indexed.append(True)
  # Trailing dims not mentioned by the indexer are implicitly unindexed.
  indexed.extend([False] * (ndim - len(indexed)))
  return (tuple(map(jnp.int32, non_slice_idx)), tuple(indexed))
|
||
|
|
||
|
def _get_slice_output_shape(in_shape: Tuple[int, ...],
|
||
|
idx_shapes: Tuple[Tuple[int, ...], ...],
|
||
|
indexed_dims: Tuple[bool, ...]) -> Tuple[int, ...]:
|
||
|
shape_suffix = [d for i, d in zip(indexed_dims, in_shape) if not i]
|
||
|
shape_prefix, = set(idx_shapes) or [()] # tie fighter
|
||
|
# Move shape prefix dimensions to the front
|
||
|
shape = (*shape_prefix, *shape_suffix)
|
||
|
return shape
|
||
|
|
||
|
def _get_indexer(ref: AbstractRef, idx: Indexer
                 ) -> Tuple[Indexer, Tuple[bool, ...]]:
  """Normalizes `idx` against `ref` into (non-slice indices, indexed mask)."""
  if not isinstance(ref.inner_aval, core.ShapedArray):
    # Non-shaped refs admit only the trivial indexer (`ref[...]`).
    if not _is_trivial_indexer(idx):
      raise ValueError(
          f"Cannot use nontrivial slice on non-shaped `Ref`: {idx}.")
    return (), ()
  return _unpack_idx(idx, ref.ndim)
|
||
|
|
||
|
def ref_get(ref: Any, idx: Indexer) -> Array:
  """Reads a value from a `Ref`, a.k.a. value <- ref[idx]."""
  if not isinstance(core.get_aval(ref), AbstractRef):
    raise ValueError(f"Can only call `get` on a `Ref`: {ref}")
  non_slice_idx, indexed_dims = _get_indexer(ref, idx)
  return get_p.bind(ref, *non_slice_idx, indexed_dims=indexed_dims)
|
||
|
|
||
|
# `swap` mutates a `Ref`, setting its value and returns its previous value.
|
||
|
# b = swap_p.bind(x, a)
|
||
|
# It generalizes the setting operation for a `Ref` as we can ignore the return
|
||
|
# value:
|
||
|
# _ = swap_p.bind(x, a)
|
||
|
# `swap_p` also takes in index arguments following the value, i.e.:
|
||
|
# _ = swap_p.bind(x, a, 0, 1)
|
||
|
# Staging out `b = swap_p.bind(x, a)` where the aval of `x` is
|
||
|
# `Ref((3,), np.dtype('float32'))` and the aval of `a` is
|
||
|
# `ShapedArray((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
|
||
|
# b:f32[3], x:Ref{f32[3]} <- x, a
|
||
|
# Staging out `_ = swap_p.bind(x, a, i, j)` where the aval of `x` is
|
||
|
# `Ref((3,), np.dtype('float32'))` , the aval of `a` is
|
||
|
# `ShapedArray((3,), np.dtype('float32'))`, and the avals of both `i` and `j`
|
||
|
# are `ShapedArray((), np.dtype('int32'))` leads to a jaxpr eqn printed like
|
||
|
# x:Ref{f32[3]}[i, j] <- a
|
||
|
swap_p = core.Primitive("swap")
|
||
|
|
||
|
def _swap_impl(ref: AbstractRef, value: Array, *idx: int, **_):
|
||
|
del ref, value, idx
|
||
|
raise ValueError("Cannot run stateful primitive.")
|
||
|
swap_p.def_impl(_swap_impl)
|
||
|
|
||
|
def ref_swap(ref: AbstractRef, idx: Indexer, value: Array) -> Array:
  """Sets a `Ref`'s value and returns the original value."""
  if not isinstance(core.get_aval(ref), AbstractRef):
    raise ValueError(f"Can only call `swap` on a `Ref`: {ref}")
  non_slice_idx, indexed_dims = _get_indexer(ref, idx)
  return swap_p.bind(ref, value, *non_slice_idx, indexed_dims=indexed_dims)
|
||
|
|
||
|
def ref_set(ref: AbstractRef, idx: Indexer, value: Array) -> None:
  """Sets a `Ref`'s value, a.k.a. ref[idx] <- value."""
  # `set` is just `swap` with the returned previous value discarded.
  ref_swap(ref, idx, value)
|
||
|
|
||
|
# `addupdate_p` mutates a `Ref`, adding a value to its existing value.
|
||
|
# Semantically,
|
||
|
# ```
|
||
|
# addupdate ref a *idx
|
||
|
# ```
|
||
|
# is equivalent to
|
||
|
# ```
|
||
|
# b = get ref *idx
|
||
|
# c = add b x
|
||
|
# _ = swap ref c *idx
|
||
|
# ```
|
||
|
addupdate_p = core.Primitive('addupdate')
|
||
|
addupdate_p.multiple_results = True
|
||
|
|
||
|
def _addupdate_impl(ref: AbstractRef, value: Array, *idx: int):
|
||
|
del ref, idx, value
|
||
|
raise ValueError("Can't evaluate `addupdate` outside a stateful context.")
|
||
|
addupdate_p.def_impl(_addupdate_impl)
|
||
|
|
||
|
def ref_addupdate(ref: AbstractRef, idx: Indexer, x: Array) -> None:
  """Mutates a ref with an additive update i.e. `ref[idx] += x`."""
  if not isinstance(core.get_aval(ref), AbstractRef):
    raise ValueError(f"Can only call `addupdate` on a `Ref`: {ref}")
  non_slice_idx, indexed_dims = _get_indexer(ref, idx)
  return addupdate_p.bind(ref, x, *non_slice_idx, indexed_dims=indexed_dims)
|
||
|
|
||
|
## get/set/addupdate abstract evaluation rules
|
||
|
|
||
|
def _get_abstract_eval(ref_aval: AbstractRef, *idx,
                       indexed_dims):
  """Abstract evaluation rule for `get`.

  Validates `idx`/`indexed_dims` against the ref's shape and returns the
  sliced output aval together with a `ReadEffect` on operand 0.

  NOTE(review): the original body re-checked
  `isinstance(ref_aval.inner_aval, core.ShapedArray)` immediately inside the
  branch guarded by that very condition; the unreachable check was removed.
  """
  if not isinstance(ref_aval, AbstractRef):
    raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.")
  if isinstance(ref_aval.inner_aval, core.ShapedArray):
    if len(indexed_dims) != len(ref_aval.shape):
      raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
    if sum(indexed_dims) != len(idx):
      raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
    idx_shapes = tuple(i.shape for i in idx)
    shape = _get_slice_output_shape(ref_aval.shape, idx_shapes, indexed_dims)
    out_aval = ref_aval.inner_aval.update(shape=shape)
  else:
    # Non-shaped refs can only be read whole.
    if idx:
      raise ValueError("Cannot index non-shaped array with nontrivial indices.")
    out_aval = ref_aval.inner_aval
  return (out_aval, {ReadEffect(0)})
get_p.def_effectful_abstract_eval(_get_abstract_eval)
|
||
|
|
||
|
def _swap_abstract_eval(ref_aval: AbstractRef,
                        val_aval: core.AbstractValue,
                        *idx: core.ShapedArray, indexed_dims: Tuple[bool]):
  """Abstract evaluation rule for `swap`.

  Checks that the indexed slice of `ref_aval` matches `val_aval` in shape and
  dtype, and returns the previous-value aval plus a `WriteEffect` on
  operand 0.
  """
  out_aval: core.AbstractValue
  if not isinstance(ref_aval, AbstractRef):
    raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
  if isinstance(ref_aval.inner_aval, core.ShapedArray):
    if len(indexed_dims) != len(ref_aval.shape):
      raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
    if sum(indexed_dims) != len(idx):
      raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
    val_aval = core.raise_to_shaped(val_aval)
    assert isinstance(val_aval, core.ShapedArray)
    idx_shapes = tuple(i.shape for i in idx)
    expected_output_shape = _get_slice_output_shape(
        ref_aval.shape, idx_shapes, indexed_dims)
    if expected_output_shape != val_aval.shape:
      raise ValueError("Invalid shape for `swap`. "
                       f"Ref shape: {ref_aval.shape}. "
                       f"Value shape: {val_aval.shape}. "
                       f"Indices: {idx}. ")
    if ref_aval.dtype != val_aval.dtype:
      # Fixed: this message previously labeled the value dtype as "Value
      # shape", which made the error misleading.
      raise ValueError("Invalid dtype for `swap`. "
                       f"Ref dtype: {ref_aval.dtype}. "
                       f"Value dtype: {val_aval.dtype}. ")
    out_aval = core.ShapedArray(expected_output_shape, ref_aval.dtype)
  else:
    if idx:
      raise ValueError("`swap` with nontrivial indexing must be called "
                       f"on `ShapedArray` `Ref`: {ref_aval}.")
    out_aval = ref_aval.inner_aval
  return (out_aval, {WriteEffect(0)})
swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
|
||
|
|
||
|
|
||
|
def _addupdate_abstract_eval(ref_aval: AbstractRef,
                             val_aval: core.AbstractValue,
                             *idx: core.ShapedArray, indexed_dims: Tuple[bool]):
  """Abstract evaluation rule for `addupdate`.

  Checks that the indexed slice of `ref_aval` matches `val_aval` in shape and
  dtype. `addupdate` has no outputs; it produces an `AccumEffect` on
  operand 0.

  NOTE(review): the original body ended with an `elif idx:` duplicate of the
  nontrivial-indexing check below, which was unreachable (the same condition
  already raises up front); the dead branch was removed.
  """
  if not isinstance(ref_aval, AbstractRef):
    raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.")
  if idx and not isinstance(ref_aval.inner_aval, core.ShapedArray):
    raise ValueError("`addupdate` with nontrivial indexing must be called "
                     f"on `ShapedArray` `Ref`: {ref_aval}.")
  if isinstance(ref_aval.inner_aval, core.ShapedArray):
    if len(indexed_dims) != len(ref_aval.shape):
      raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
    if sum(indexed_dims) != len(idx):
      raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
    val_aval = core.raise_to_shaped(val_aval)
    assert isinstance(val_aval, core.ShapedArray)
    idx_shapes = tuple(i.shape for i in idx)
    slice_shape = _get_slice_output_shape(
        ref_aval.shape, idx_shapes, indexed_dims)
    if slice_shape != val_aval.shape:
      raise ValueError("Invalid shape for `addupdate`. "
                       f"Ref shape: {ref_aval.shape}. "
                       f"Value shape: {val_aval.shape}. "
                       f"Indices: {idx}. ")
    if ref_aval.dtype != val_aval.dtype:
      # Fixed: this message previously labeled the value dtype as "Value
      # shape", which made the error misleading.
      raise ValueError("Invalid dtype for `addupdate`. "
                       f"Ref dtype: {ref_aval.dtype}. "
                       f"Value dtype: {val_aval.dtype}. ")
  return [], {AccumEffect(0)}
addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval)
|
||
|
|
||
|
## Pretty printing for `get` and `swap` in jaxprs
|
||
|
|
||
|
# Render `Ref`s in green when pretty-printing jaxprs so they stand out from
# ordinary values.
pp_ref = partial(pp.color, intensity=pp.Intensity.NORMAL,
                 foreground=pp.Color.GREEN)
|
||
|
|
||
|
def _pp_idx(context, non_slice_idx, indexed_dims):
  """Renders an indexer as e.g. `i,:,j` — one entry per dim, `:` if unindexed."""
  remaining = iter(non_slice_idx)
  parts = []
  for indexed in indexed_dims:
    parts.append(core.pp_var(next(remaining), context) if indexed else ':')
  # Every non-slice index must correspond to exactly one indexed dim.
  assert next(remaining, None) is None
  return pp.text(','.join(parts))
|
||
|
|
||
|
def _get_pp_rule(eqn, context, settings) -> pp.Doc:
  # Pretty prints `a = get x i` as `a <- x[i]`.
  # (The output variable is on the left of the arrow, the indexed ref on the
  # right.)
  y, = eqn.outvars
  x, *idx = eqn.invars
  idx = _pp_idx(context, idx, eqn.params["indexed_dims"])
  lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes)
  # TODO more general get
  return pp.concat([lhs, pp.text(' <- '), pp_ref(pp.concat([
      pp.text(core.pp_var(x, context)), pp.text('['), idx, pp.text(']')]))])
core.pp_eqn_rules[get_p] = _get_pp_rule
|
||
|
|
||
|
def _swap_pp_rule(eqn, context, settings) -> pp.Doc:
  """Pretty-prints `swap` eqns, specializing the pure-`set` case."""
  y, = eqn.outvars
  x, v, *idx = eqn.invars
  idx = _pp_idx(context, idx, eqn.params["indexed_dims"])
  if type(y) is core.DropVar:
    # In the case of a set (ignored return value),
    # pretty print `_ = swap x v i` as `x[i] <- v`
    del y
    return pp.concat([
        pp_ref(pp.concat([
          pp.text(core.pp_var(x, context)),
          pp.text('['), idx, pp.text(']')
        ])), pp.text(' <- '), pp.text(core.pp_var(v, context))])
  else:
    # pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v`
    x_i = pp.concat([pp.text(core.pp_var(x, context)),
                     pp.text('['), idx, pp.text(']')])
    y = core.pp_vars([y], context, print_shapes=settings.print_shapes)
    return pp.concat([y, pp.text(', '), pp_ref(x_i), pp.text(' <- '),
                      pp_ref(x_i), pp.text(', '),
                      pp.text(core.pp_var(v, context))])
core.pp_eqn_rules[swap_p] = _swap_pp_rule
|
||
|
|
||
|
def _addupdate_pp_rule(eqn, context, settings) -> pp.Doc:
  # pretty-print ` = addupdate x i v` as `x[i] += v`
  () = eqn.outvars  # addupdate has no outputs
  x, v, *idx = eqn.invars
  idx = _pp_idx(context, idx, eqn.params["indexed_dims"])
  return pp.concat([
      pp_ref(pp.concat([
        pp.text(core.pp_var(x, context)),
        pp.text('['), idx, pp.text(']')
      ])), pp.text(' += '), pp.text(core.pp_var(v, context))])
core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule
|
||
|
|
||
|
## get/swap/addupdate JVP rules
|
||
|
|
||
|
def _get_jvp(primals: List[Any], tangents: List[Any], **params: Any):
  """JVP of `get`: reads are linear, so read both the primal and tangent refs
  at the same indices."""
  ref_primal, *idx = primals
  assert isinstance(ref_primal.aval, AbstractRef)
  ref_tangent, *_ = tangents  # index tangents are ignored (indices are ints)
  assert isinstance(ref_tangent.aval, AbstractRef)
  primal_out = get_p.bind(ref_primal, *idx, **params)
  tangent_out = get_p.bind(ref_tangent, *idx, **params)  # type: ignore[arg-type]
  return primal_out, tangent_out
ad.primitive_jvps[get_p] = _get_jvp
|
||
|
|
||
|
def _swap_jvp(primals: List[Any], tangents: List[Any], **params: Any):
  """JVP of `swap`: swap the primal value into the primal ref and the
  (materialized) tangent value into the tangent ref."""
  ref_primal, x_primal, *idx = primals
  assert isinstance(ref_primal.aval, AbstractRef)
  ref_tangent, x_tangent, *_ = tangents
  assert isinstance(ref_tangent.aval, AbstractRef)
  # Symbolic-zero tangents must be materialized before being stored.
  x_tangent = ad_util.instantiate(x_tangent)
  primal_out = swap_p.bind(ref_primal, x_primal, *idx, **params)  # type: ignore[arg-type]
  tangent_out = swap_p.bind(ref_tangent, x_tangent, *idx, **params)  # type: ignore[arg-type]
  return primal_out, tangent_out
ad.primitive_jvps[swap_p] = _swap_jvp
|
||
|
|
||
|
def addupdate_jvp_rule(primals: List[Any], tangents: List[Any], **params: Any):
  """JVP of `addupdate`: additively update both the primal and tangent refs."""
  ref_primal, x_primal, *idx = primals
  ref_tangent, x_tangent, *_ = tangents
  # Materialize a symbolic-zero tangent before accumulating it.
  x_tangent = ad_util.instantiate(x_tangent)
  addupdate_p.bind(ref_primal, x_primal, *idx, **params)
  addupdate_p.bind(ref_tangent, x_tangent, *idx, **params)
  # No outputs, so no primal or tangent results.
  return [], []
ad.primitive_jvps[addupdate_p] = addupdate_jvp_rule
|
||
|
|
||
|
## get/swap/addupdate transpose rules
|
||
|
|
||
|
def _get_transpose(g, ref, *idx, **params):
  """Transpose of `get` is `addupdate`: accumulate the cotangent into `ref`."""
  # A symbolic-zero cotangent contributes nothing, so only bind otherwise.
  if type(g) is not ad_util.Zero:
    addupdate_p.bind(ref, g, *idx, **params)
  return [None] * (1 + len(idx))
ad.primitive_transposes[get_p] = _get_transpose
|
||
|
|
||
|
def _swap_transpose(g, ref, x, *idx, **params):
  """Transpose of `swap` is `swap`: store the cotangent and read out `x`'s."""
  x_bar = swap_p.bind(ref, ad_util.instantiate(g), *idx, **params)
  return [None, x_bar, *([None] * len(idx))]
ad.primitive_transposes[swap_p] = _swap_transpose
|
||
|
|
||
|
def addupdate_transpose(cts_in, ref, x, *idx, **params):
  """Transpose of `addupdate` is `get`: read `x`'s cotangent out of `ref`."""
  del cts_in, x  # Unused: addupdate has no outputs.
  g = get_p.bind(ref, *idx, **params)
  return [None, g, *([None] * len(idx))]
ad.primitive_transposes[addupdate_p] = addupdate_transpose
|
||
|
|
||
|
## get/swap/addupdate partial_eval_custom rules
|
||
|
|
||
|
def _state_partial_eval_custom(prim, saveable, unks_in, inst_in, eqn):
  # Shared partial-eval rule for get/swap/addupdate.
  # Returns (known_eqn, staged_eqn, out_unknowns, out_inst, residuals).
  if any(unks_in):
    # Any unknown input makes the whole eqn unknown: stage it entirely;
    # inputs not yet instantiated become residuals.
    res = [v for v, inst in zip(eqn.invars, inst_in) if not inst]
    return None, eqn, [True] * len(eqn.outvars), [True] * len(eqn.outvars), res
  elif saveable(prim, *[var.aval for var in eqn.invars], **eqn.params):
    # All inputs known and the policy says save: keep it in the known jaxpr
    # only, with no residuals.
    return eqn, None, [False] * len(eqn.outvars), [False] * len(eqn.outvars), []
  # Known but not saveable: run in the known jaxpr AND replay in the staged
  # jaxpr, passing non-instantiated inputs through as residuals.
  res = [v for v, inst in zip(eqn.invars, inst_in) if not inst]
  return eqn, eqn, [False] * len(eqn.outvars), [True] * len(eqn.outvars), res

pe.partial_eval_jaxpr_custom_rules[get_p] = partial(_state_partial_eval_custom,
                                                    get_p)
pe.partial_eval_jaxpr_custom_rules[swap_p] = partial(_state_partial_eval_custom,
                                                     swap_p)
pe.partial_eval_jaxpr_custom_rules[addupdate_p] = partial(
    _state_partial_eval_custom, addupdate_p)
|
||
|
|
||
|
## get/swap/addupdate batching rules
|
||
|
|
||
|
def _output_bdim(indexed_dims: Tuple[bool, ...], ref_dim: int,
|
||
|
idxs_shape: Tuple[int, ...]):
|
||
|
num_idxs_to_left = sum(indexed_dims[:ref_dim])
|
||
|
return ref_dim - num_idxs_to_left + len(idxs_shape)
|
||
|
|
||
|
def _get_vmap(batched_args, batched_dims, *, indexed_dims):
  # Batching rule for `get`. All batched operands must share one axis size
  # (the destructuring assert below enforces this).
  axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
                if d is not batching.not_mapped}
  ref, *idxs = batched_args
  ref_dim, *idx_dims = batched_dims

  ref_is_batched = ref_dim is not batching.not_mapped
  idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims)
  bdim_out = 0

  if idx_is_batched:
    # If at least one of the idx is batched, we broadcast them all and move the
    # batch dim to the front.
    idxs = tuple(batching.bdim_at_front(i, d, axis_size) for i, d
                 in zip(idxs, idx_dims))
  # All indexers share one shape (empty tuple if there are none).
  idxs_shape, = {i.shape for i in idxs} or [()]
  if ref_is_batched:
    # If ref is batched, we are doing a `get` with an additional axis. If `idxs`
    # are also batched, then we are indexing into the batch axis with an `iota`.
    indexed_dims = tuple_insert(indexed_dims, ref_dim, idx_is_batched)
    if idx_is_batched:
      # If we have batched idx, we need to insert the new iota index. The place
      # where we add in the new `iota` index is `ref_dim` so we need to compute
      # what `ref_dim` *would be* if we inserted it into `idxs` instead, because
      # `idxs` doesn't include the non indexed dims.
      idx_place = [i for i, i_dim in enumerate(indexed_dims)
                   if i_dim].index(ref_dim)
      iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
      idxs = tuple_insert(idxs, idx_place, iota)
    else:
      # Unindexed batched ref: the batch axis lands after the gathered prefix.
      bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)
  return get_p.bind(ref, *idxs, indexed_dims=indexed_dims), bdim_out
batching.primitive_batchers[get_p] = _get_vmap
|
||
|
|
||
|
def _swap_vmap(batched_args, batched_dims, *, indexed_dims):
  # Batching rule for `swap`. All batched operands must share one axis size.
  axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
                if d is not batching.not_mapped}
  ref, val, *idxs = batched_args
  ref_dim, val_dim, *idx_dims = batched_dims
  ref_is_batched = ref_dim is not batching.not_mapped
  val_is_batched = val_dim is not batching.not_mapped
  idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims)
  if idx_is_batched:
    # If at least one of the idx is batched, we broadcast them all and move the
    # batch dim to the front.
    idxs = tuple(batching.bdim_at_front(i, d, axis_size) for i, d
                 in zip(idxs, idx_dims))
  idxs_shape, = {i.shape for i in idxs} or [()]
  # NOTE(review): if only `val` is batched (ref and idxs unbatched), neither
  # branch below runs and `bdim_out` is unbound at the return — verify this
  # case is unreachable from callers.
  if ref_is_batched and not idx_is_batched:
    # Batched ref, unbatched indices: the ref's batch axis is an extra
    # unindexed dim; align `val`'s batch axis with the output batch dim.
    indexed_dims = tuple_insert(indexed_dims, ref_dim, False)
    bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)
    if not val_is_batched:
      val = batching.broadcast(val, axis_size, 0)
      val_dim = 0
    val = batching.moveaxis(val, val_dim, bdim_out)
  elif idx_is_batched:
    # Batched indices: index into the ref's batch axis with an iota, inserted
    # at the position `ref_dim` occupies among the indexed dims.
    assert ref_is_batched and val_is_batched
    indexed_dims = tuple_insert(indexed_dims, ref_dim, True)
    idx_place = [i for i, i_dim in enumerate(indexed_dims)
                 if i_dim].index(ref_dim)
    iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
    idxs = tuple_insert(idxs, idx_place, iota)
    val = batching.moveaxis(val, val_dim, 0)
    bdim_out = 0
  return swap_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), bdim_out
batching.primitive_batchers[swap_p] = _swap_vmap
|
||
|
|
||
|
def _addupdate_vmap(batched_args, batched_dims, *, indexed_dims):
  # Batching rule for `addupdate`. All batched operands must share one axis
  # size.
  axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
                if d is not batching.not_mapped}
  ref, val, *idxs = batched_args
  ref_dim, val_dim, *idx_dims = batched_dims
  ref_is_batched = ref_dim is not batching.not_mapped
  val_is_batched = val_dim is not batching.not_mapped
  idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims)
  if idx_is_batched:
    # If at least one of the idx is batched, we ensure all have bdims at front.
    idxs = tuple(batching.bdim_at_front(i, d, axis_size)
                 for i, d in zip(idxs, idx_dims))
  idxs_shape, = {i.shape for i in idxs} or [()]
  if ref_is_batched and not idx_is_batched:
    # Batched ref, unbatched indices: treat the batch axis as an extra
    # unindexed dim and align `val`'s batch axis with it.
    indexed_dims = tuple_insert(indexed_dims, ref_dim, False)
    bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)
    if not val_is_batched:
      val = batching.broadcast(val, axis_size, 0)
      val_dim = 0
    val = batching.moveaxis(val, val_dim, bdim_out)
  elif idx_is_batched:
    # Batched indices: index into the ref's batch axis with an iota inserted
    # at the position `ref_dim` occupies among the indexed dims.
    assert ref_is_batched and val_is_batched
    indexed_dims = tuple_insert(indexed_dims, ref_dim, True)
    idx_place = [i for i, i_dim in enumerate(indexed_dims)
                 if i_dim].index(ref_dim)
    # (Recomputed here in the original; identical to the value above.)
    idxs_shape, = {i.shape for i in idxs} or [()]
    iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
    idxs = tuple_insert(idxs, idx_place, iota)
    val = batching.moveaxis(val, val_dim, 0)
  # addupdate has no outputs, so the output batch dims list is empty.
  return addupdate_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), []
batching.primitive_batchers[addupdate_p] = _addupdate_vmap
|