# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for state primitives."""
from functools import partial
from typing import Any, List, Tuple, Union
import numpy as np
from jax._src import ad_util
from jax._src import core
from jax._src import pretty_printer as pp
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax
from jax._src.typing import Array
from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect,
AccumEffect)
from jax._src.util import safe_map, safe_zip, tuple_insert
## General utilities
## JAX utilities
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
## get/swap/addupdate implementations
# `get` reads a value from a `Ref` type, a.k.a.:
# a = get_p.bind(x)
# or we can read using indices:
# a = get_p.bind(x, 0, 1)
# Staging out `a = get_p.bind(x)` where the aval of `x` is
# `Ref((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
# a:f32[3] <- x[]
get_p = core.Primitive("get")
def _get_impl(ref: AbstractRef, *idx: int, **_):
del ref, idx
raise ValueError("Cannot run stateful primitive.")
get_p.def_impl(_get_impl)
Indexer = Tuple[Union[int, slice, Array], ...]
# `...` (Ellipsis) is also accepted, but `types.EllipsisType` cannot be used
# in annotations before Python 3.10.
def _is_trivial_indexer(idx: Indexer) -> bool:
if idx is ...:
return True
if type(idx) is tuple:
if len(idx) == 0:
return True
return len(idx) == 1 and idx[0] is ...
return False
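# For example (an illustrative sketch, not executed here):
#     _is_trivial_indexer(...)             # True
#     _is_trivial_indexer(())              # True
#     _is_trivial_indexer((...,))          # True
#     _is_trivial_indexer((0,))            # False
#     _is_trivial_indexer((slice(None),))  # False (trivial slices are
#                                          # normalized by `_unpack_idx`)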
def _unpack_idx(idx: Indexer, ndim: int
) -> Tuple[Tuple[Array, ...], Tuple[bool, ...]]:
if _is_trivial_indexer(idx):
idx = tuple(slice(None) for _ in range(ndim))
indexed_dims_ = []
non_slice_idx = []
for i in idx:
if isinstance(i, slice):
if i.start is not None or i.stop is not None or i.step is not None:
raise NotImplementedError("Reference indexing only supports trivial slices")
indexed_dims_.append(False)
else:
non_slice_idx.append(i)
indexed_dims_.append(True)
indexed_dims = indexed_dims_ + [False] * (ndim - len(indexed_dims_))
  import jax.numpy as jnp  # local import to avoid a circular dependency
return (tuple(map(jnp.int32, non_slice_idx)), tuple(indexed_dims))
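# For example, on a rank-3 ref (an illustrative sketch, not executed here):
#     _unpack_idx((0, slice(None), 1), 3)
# returns the integer indices `(0, 1)` (as `int32` arrays) and
# `indexed_dims == (True, False, True)`: trivial slices become unindexed dims,
# and any unmentioned trailing dims are padded with `False`, so
# `_unpack_idx((0,), 3)` yields `indexed_dims == (True, False, False)`.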
def _get_slice_output_shape(in_shape: Tuple[int, ...],
idx_shapes: Tuple[Tuple[int, ...], ...],
indexed_dims: Tuple[bool, ...]) -> Tuple[int, ...]:
shape_suffix = [d for i, d in zip(indexed_dims, in_shape) if not i]
  shape_prefix, = set(idx_shapes) or [()]  # all index shapes must agree
# Move shape prefix dimensions to the front
shape = (*shape_prefix, *shape_suffix)
return shape
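# For example (illustrative, not executed here): with `in_shape=(3, 4, 5)`,
# two index arrays of shape `(2,)`, and `indexed_dims=(True, False, True)`,
#     _get_slice_output_shape((3, 4, 5), ((2,), (2,)), (True, False, True))
# returns `(2, 4)`: the shared index shape `(2,)` leads, followed by the
# unindexed dims.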
def _get_indexer(ref: AbstractRef, idx: Indexer
) -> Tuple[Indexer, Tuple[bool, ...]]:
if isinstance(ref.inner_aval, core.ShapedArray):
non_slice_idx, indexed_dims = _unpack_idx(idx, ref.ndim)
else:
if not _is_trivial_indexer(idx):
raise ValueError(
f"Cannot use nontrivial slice on non-shaped `Ref`: {idx}.")
non_slice_idx, indexed_dims = (), ()
return non_slice_idx, indexed_dims
def ref_get(ref: Any, idx: Indexer) -> Array:
"""Reads a value from a `Ref`, a.k.a. value <- ref[idx]."""
ref_aval = core.get_aval(ref)
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"Can only call `get` on a `Ref`: {ref}")
  non_slice_idx, indexed_dims = _get_indexer(ref_aval, idx)
return get_p.bind(ref, *non_slice_idx, indexed_dims=indexed_dims)
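# For example (illustrative; `get_p` cannot be evaluated eagerly, so these
# only make sense under tracing/discharge, where `Ref`s exist):
#     val = ref_get(ref, ())     # read the whole ref, like `ref[...]`
#     row = ref_get(ref, (0,))   # read `ref[0]`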
# `swap` mutates a `Ref`, setting its value and returns its previous value.
# b = swap_p.bind(x, a)
# It generalizes the setting operation for a `Ref` as we can ignore the return
# value:
# _ = swap_p.bind(x, a)
# `swap_p` also takes in index arguments following the value, i.e.:
# _ = swap_p.bind(x, a, 0, 1)
# Staging out `b = swap_p.bind(x, a)` where the aval of `x` is
# `Ref((3,), np.dtype('float32'))` and the aval of `a` is
# `ShapedArray((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
# b:f32[3], x:Ref{f32[3]} <- x, a
# Staging out `_ = swap_p.bind(x, a, i, j)` where the aval of `x` is
# `Ref((3, 4), np.dtype('float32'))`, the aval of `a` is
# `ShapedArray((), np.dtype('float32'))`, and the avals of both `i` and `j`
# are `ShapedArray((), np.dtype('int32'))` leads to a jaxpr eqn printed like
#   x:Ref{f32[3,4]}[i,j] <- a
swap_p = core.Primitive("swap")
def _swap_impl(ref: AbstractRef, value: Array, *idx: int, **_):
del ref, value, idx
raise ValueError("Cannot run stateful primitive.")
swap_p.def_impl(_swap_impl)
def ref_swap(ref: AbstractRef, idx: Indexer, value: Array) -> Array:
"""Sets a `Ref`'s value and returns the original value."""
ref_aval = core.get_aval(ref)
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"Can only call `swap` on a `Ref`: {ref}")
  non_slice_idx, indexed_dims = _get_indexer(ref_aval, idx)
return swap_p.bind(ref, value, *non_slice_idx, indexed_dims=indexed_dims)
def ref_set(ref: AbstractRef, idx: Indexer, value: Array) -> None:
"""Sets a `Ref`'s value, a.k.a. ref[idx] <- value."""
ref_swap(ref, idx, value)
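# For example (illustrative, inside a stateful context):
#     ref_set(ref, (), val)      # ref[...] = val
#     ref_set(ref, (0,), row)    # ref[0] = row
# whereas `ref_swap` additionally returns the value that was overwritten.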
# `addupdate_p` mutates a `Ref`, adding a value to its existing value.
# Semantically,
# ```
# addupdate ref a *idx
# ```
# is equivalent to
# ```
# b = get ref *idx
# c = add b x
# _ = swap ref c *idx
# ```
addupdate_p = core.Primitive('addupdate')
addupdate_p.multiple_results = True
def _addupdate_impl(ref: AbstractRef, value: Array, *idx: int):
del ref, idx, value
raise ValueError("Can't evaluate `addupdate` outside a stateful context.")
addupdate_p.def_impl(_addupdate_impl)
def ref_addupdate(ref: AbstractRef, idx: Indexer, x: Array) -> None:
"""Mutates a ref with an additive update i.e. `ref[idx] += x`."""
ref_aval = core.get_aval(ref)
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"Can only call `addupdate` on a `Ref`: {ref}")
  non_slice_idx, indexed_dims = _get_indexer(ref_aval, idx)
return addupdate_p.bind(ref, x, *non_slice_idx, indexed_dims=indexed_dims)
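# For example (illustrative, inside a stateful context):
#     ref_addupdate(ref, (0,), x)   # ref[0] += x
# Nothing is returned; the accumulation is purely an effect on the ref.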
## get/set/addupdate abstract evaluation rules
def _get_abstract_eval(ref_aval: AbstractRef, *idx,
indexed_dims):
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.")
  if isinstance(ref_aval.inner_aval, core.ShapedArray):
if len(indexed_dims) != len(ref_aval.shape):
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
if sum(indexed_dims) != len(idx):
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
idx_shapes = tuple(i.shape for i in idx)
shape = _get_slice_output_shape(ref_aval.shape, idx_shapes, indexed_dims)
out_aval = ref_aval.inner_aval.update(shape=shape)
else:
if idx:
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
out_aval = ref_aval.inner_aval
return (out_aval, {ReadEffect(0)})
get_p.def_effectful_abstract_eval(_get_abstract_eval)
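# For example (illustrative): `get` on a ref of aval `Ref{f32[3,4]}` with
# `indexed_dims=(True, False)` and a single `i32[2]` index array has output
# aval `f32[2,4]` (index shape first, then the unindexed dims), together
# with a `ReadEffect` on argument 0.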
def _swap_abstract_eval(ref_aval: AbstractRef,
val_aval: core.AbstractValue,
                        *idx: core.ShapedArray, indexed_dims: Tuple[bool, ...]):
out_aval: core.AbstractValue
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
if isinstance(ref_aval.inner_aval, core.ShapedArray):
if len(indexed_dims) != len(ref_aval.shape):
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
if sum(indexed_dims) != len(idx):
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
val_aval = core.raise_to_shaped(val_aval)
assert isinstance(val_aval, core.ShapedArray)
idx_shapes = tuple(i.shape for i in idx)
expected_output_shape = _get_slice_output_shape(
ref_aval.shape, idx_shapes, indexed_dims)
if expected_output_shape != val_aval.shape:
raise ValueError("Invalid shape for `swap`. "
f"Ref shape: {ref_aval.shape}. "
f"Value shape: {val_aval.shape}. "
f"Indices: {idx}. ")
if ref_aval.dtype != val_aval.dtype:
raise ValueError("Invalid dtype for `swap`. "
f"Ref dtype: {ref_aval.dtype}. "
f"Value shape: {val_aval.dtype}. ")
out_aval = core.ShapedArray(expected_output_shape, ref_aval.dtype)
else:
if idx:
raise ValueError("`swap` with nontrivial indexing must be called "
f"on `ShapedArray` `Ref`: {ref_aval}.")
out_aval = ref_aval.inner_aval
return (out_aval, {WriteEffect(0)})
swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
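# For example (illustrative): swapping into a ref of aval `Ref{f32[3,4]}`
# with `indexed_dims=(True, False)` and one `i32[2]` index requires the value
# to be `f32[2,4]` (the shape the corresponding `get` would produce) and of
# matching dtype; the output aval is likewise `f32[2,4]`, with a
# `WriteEffect` on argument 0.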
def _addupdate_abstract_eval(ref_aval: AbstractRef,
val_aval: core.AbstractValue,
                             *idx: core.ShapedArray, indexed_dims: Tuple[bool, ...]):
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.")
if isinstance(ref_aval.inner_aval, core.ShapedArray):
if len(indexed_dims) != len(ref_aval.shape):
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
if sum(indexed_dims) != len(idx):
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
val_aval = core.raise_to_shaped(val_aval)
assert isinstance(val_aval, core.ShapedArray)
idx_shapes = tuple(i.shape for i in idx)
slice_shape = _get_slice_output_shape(
ref_aval.shape, idx_shapes, indexed_dims)
if slice_shape != val_aval.shape:
raise ValueError("Invalid shape for `addupdate`. "
f"Ref shape: {ref_aval.shape}. "
f"Value shape: {val_aval.shape}. "
f"Indices: {idx}. ")
if ref_aval.dtype != val_aval.dtype:
raise ValueError("Invalid dtype for `addupdate`. "
f"Ref dtype: {ref_aval.dtype}. "
f"Value shape: {val_aval.dtype}. ")
elif idx:
raise ValueError("`addupdate` with nontrivial indexing must be called "
f"on `ShapedArray` `Ref`: {ref_aval}.")
return [], {AccumEffect(0)}
addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval)
## Pretty printing for `get` and `swap` in jaxprs
pp_ref = partial(pp.color, intensity=pp.Intensity.NORMAL,
foreground=pp.Color.GREEN)
def _pp_idx(context, non_slice_idx, indexed_dims):
idx_iter = iter(non_slice_idx)
idx = ','.join(core.pp_var(next(idx_iter), context) if indexed else ':'
for indexed in indexed_dims)
assert next(idx_iter, None) is None
return pp.text(idx)
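# For example (illustrative): with `indexed_dims=(True, False, True)` and
# index variables `i` and `j`, `_pp_idx` renders `i,:,j`: indexed dims show
# their index variable, unindexed dims show `:`.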
def _get_pp_rule(eqn, context, settings) -> pp.Doc:
  # Pretty prints `a = get x i` as `a <- x[i]`
y, = eqn.outvars
x, *idx = eqn.invars
idx = _pp_idx(context, idx, eqn.params["indexed_dims"])
lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes)
# TODO more general get
return pp.concat([lhs, pp.text(' <- '), pp_ref(pp.concat([
pp.text(core.pp_var(x, context)), pp.text('['), idx, pp.text(']')]))])
core.pp_eqn_rules[get_p] = _get_pp_rule
def _swap_pp_rule(eqn, context, settings) -> pp.Doc:
y, = eqn.outvars
x, v, *idx = eqn.invars
idx = _pp_idx(context, idx, eqn.params["indexed_dims"])
if type(y) is core.DropVar:
# In the case of a set (ignored return value),
# pretty print `_ = swap x v i` as `x[i] <- v`
del y
return pp.concat([
pp_ref(pp.concat([
pp.text(core.pp_var(x, context)),
pp.text('['), idx, pp.text(']')
])), pp.text(' <- '), pp.text(core.pp_var(v, context))])
else:
# pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v`
x_i = pp.concat([pp.text(core.pp_var(x, context)),
pp.text('['), idx, pp.text(']')])
y = core.pp_vars([y], context, print_shapes=settings.print_shapes)
return pp.concat([y, pp.text(', '), pp_ref(x_i), pp.text(' <- '),
pp_ref(x_i), pp.text(', '),
pp.text(core.pp_var(v, context))])
core.pp_eqn_rules[swap_p] = _swap_pp_rule
def _addupdate_pp_rule(eqn, context, settings) -> pp.Doc:
  # pretty-print `addupdate x v i` as `x[i] += v`
() = eqn.outvars
x, v, *idx = eqn.invars
idx = _pp_idx(context, idx, eqn.params["indexed_dims"])
return pp.concat([
pp_ref(pp.concat([
pp.text(core.pp_var(x, context)),
pp.text('['), idx, pp.text(']')
])), pp.text(' += '), pp.text(core.pp_var(v, context))])
core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule
## get/swap/addupdate JVP rules
def _get_jvp(primals: List[Any], tangents: List[Any], **params: Any):
ref_primal, *idx = primals
assert isinstance(ref_primal.aval, AbstractRef)
ref_tangent, *_ = tangents
assert isinstance(ref_tangent.aval, AbstractRef)
return (get_p.bind(ref_primal, *idx, **params),
get_p.bind(ref_tangent, *idx, **params)) # type: ignore[arg-type]
ad.primitive_jvps[get_p] = _get_jvp
def _swap_jvp(primals: List[Any], tangents: List[Any], **params: Any):
ref_primal, x_primal, *idx = primals
assert isinstance(ref_primal.aval, AbstractRef)
ref_tangent, x_tangent, *_ = tangents
assert isinstance(ref_tangent.aval, AbstractRef)
x_tangent = ad_util.instantiate(x_tangent)
return (swap_p.bind(ref_primal, x_primal, *idx, **params), # type: ignore[arg-type]
swap_p.bind(ref_tangent, x_tangent, *idx, **params)) # type: ignore[arg-type]
ad.primitive_jvps[swap_p] = _swap_jvp
def addupdate_jvp_rule(primals: List[Any], tangents: List[Any], **params: Any):
ref_primal, x_primal, *idx = primals
ref_tangent, x_tangent, *_ = tangents
x_tangent = ad_util.instantiate(x_tangent)
addupdate_p.bind(ref_primal, x_primal, *idx, **params)
addupdate_p.bind(ref_tangent, x_tangent, *idx, **params)
return [], []
ad.primitive_jvps[addupdate_p] = addupdate_jvp_rule
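# For example (illustrative): differentiating through a read `a = ref[i]`
# reads the tangent ref at the same indices, `a_dot = ref_dot[i]`; a swap
# writes the (instantiated, i.e. zero-materialized) tangent value into the
# tangent ref; and an addupdate accumulates the tangent value into the
# tangent ref.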
## get/swap/addupdate transpose rules
def _get_transpose(g, ref, *idx, **params):
# get transpose is addupdate
if type(g) is not ad_util.Zero:
addupdate_p.bind(ref, g, *idx, **params)
return [None] + [None] * len(idx)
ad.primitive_transposes[get_p] = _get_transpose
def _swap_transpose(g, ref, x, *idx, **params):
# swap transpose is swap
x_bar = swap_p.bind(ref, ad_util.instantiate(g), *idx, **params)
return [None, x_bar] + [None] * len(idx)
ad.primitive_transposes[swap_p] = _swap_transpose
def addupdate_transpose(cts_in, ref, x, *idx, **params):
# addupdate transpose is get
del cts_in, x
g = get_p.bind(ref, *idx, **params)
return [None, g] + [None] * len(idx)
ad.primitive_transposes[addupdate_p] = addupdate_transpose
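# For example (illustrative): these rules pair the linear state ops up.
# Transposing `b = get ref i` accumulates the cotangent of `b` into the
# cotangent ref (`addupdate ref_bar b_bar i`); transposing `b = swap ref v i`
# swaps the cotangent of `b` into the cotangent ref and returns what was
# there as the cotangent of `v`; and transposing `addupdate ref v i` reads
# the cotangent of `v` back out with a `get`.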
## get/swap/addupdate partial_eval_custom rules
def _state_partial_eval_custom(prim, saveable, unks_in, inst_in, eqn):
  if any(unks_in):
    # Some inputs are unknown: stage the whole equation out to the unknown
    # jaxpr, and pass any uninstantiated inputs along as residuals.
    res = [v for v, inst in zip(eqn.invars, inst_in) if not inst]
    return None, eqn, [True] * len(eqn.outvars), [True] * len(eqn.outvars), res
  elif saveable(prim, *[var.aval for var in eqn.invars], **eqn.params):
    # All inputs are known and the result is worth saving: evaluate the
    # equation in the known jaxpr only.
    return eqn, None, [False] * len(eqn.outvars), [False] * len(eqn.outvars), []
  # All inputs are known but the result is not saveable: evaluate the equation
  # in the known jaxpr and recompute it in the unknown jaxpr, passing any
  # uninstantiated inputs along as residuals.
  res = [v for v, inst in zip(eqn.invars, inst_in) if not inst]
  return eqn, eqn, [False] * len(eqn.outvars), [True] * len(eqn.outvars), res
pe.partial_eval_jaxpr_custom_rules[get_p] = partial(_state_partial_eval_custom,
get_p)
pe.partial_eval_jaxpr_custom_rules[swap_p] = partial(_state_partial_eval_custom,
swap_p)
pe.partial_eval_jaxpr_custom_rules[addupdate_p] = partial(
_state_partial_eval_custom, addupdate_p)
## get/swap/addupdate batching rules
def _output_bdim(indexed_dims: Tuple[bool, ...], ref_dim: int,
idxs_shape: Tuple[int, ...]):
num_idxs_to_left = sum(indexed_dims[:ref_dim])
return ref_dim - num_idxs_to_left + len(idxs_shape)
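# For example (illustrative): with `indexed_dims=(True, False, True)`,
# `ref_dim=1`, and `idxs_shape=(4,)`, one indexed dim sits to the left of the
# batch dim, so the output batch dim is 1 - 1 + 1 = 1: indexed dims are
# gathered out of place into `len(idxs_shape)` leading output dims, shifting
# the remaining dims accordingly.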
def _get_vmap(batched_args, batched_dims, *, indexed_dims):
axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
if d is not batching.not_mapped}
ref, *idxs = batched_args
ref_dim, *idx_dims = batched_dims
ref_is_batched = ref_dim is not batching.not_mapped
idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims)
bdim_out = 0
if idx_is_batched:
# If at least one of the idx is batched, we broadcast them all and move the
# batch dim to the front.
idxs = tuple(batching.bdim_at_front(i, d, axis_size) for i, d
in zip(idxs, idx_dims))
idxs_shape, = {i.shape for i in idxs} or [()]
if ref_is_batched:
# If ref is batched, we are doing a `get` with an additional axis. If `idxs`
# are also batched, then we are indexing into the batch axis with an `iota`.
indexed_dims = tuple_insert(indexed_dims, ref_dim, idx_is_batched)
if idx_is_batched:
# If we have batched idx, we need to insert the new iota index. The place
# where we add in the new `iota` index is `ref_dim` so we need to compute
# what `ref_dim` *would be* if we inserted it into `idxs` instead, because
# `idxs` doesn't include the non indexed dims.
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
else:
bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)
return get_p.bind(ref, *idxs, indexed_dims=indexed_dims), bdim_out
batching.primitive_batchers[get_p] = _get_vmap
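# For example (illustrative): vmapping a `get` over a ref batched along dim 0
# and a scalar index batched to shape `(n,)` inserts an `(n,)`-shaped `iota`
# as the index for the ref's batch dim, so batch member `k` reads
# `ref[k, i[k]]`; the batched output then carries its batch dim at 0. The
# `swap` and `addupdate` batching rules below follow the same pattern.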
def _swap_vmap(batched_args, batched_dims, *, indexed_dims):
axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
if d is not batching.not_mapped}
ref, val, *idxs = batched_args
ref_dim, val_dim, *idx_dims = batched_dims
ref_is_batched = ref_dim is not batching.not_mapped
val_is_batched = val_dim is not batching.not_mapped
idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims)
if idx_is_batched:
# If at least one of the idx is batched, we broadcast them all and move the
# batch dim to the front.
idxs = tuple(batching.bdim_at_front(i, d, axis_size) for i, d
in zip(idxs, idx_dims))
idxs_shape, = {i.shape for i in idxs} or [()]
if ref_is_batched and not idx_is_batched:
indexed_dims = tuple_insert(indexed_dims, ref_dim, False)
bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)
if not val_is_batched:
val = batching.broadcast(val, axis_size, 0)
val_dim = 0
val = batching.moveaxis(val, val_dim, bdim_out)
elif idx_is_batched:
assert ref_is_batched and val_is_batched
indexed_dims = tuple_insert(indexed_dims, ref_dim, True)
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
val = batching.moveaxis(val, val_dim, 0)
bdim_out = 0
return swap_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), bdim_out
batching.primitive_batchers[swap_p] = _swap_vmap
def _addupdate_vmap(batched_args, batched_dims, *, indexed_dims):
axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
if d is not batching.not_mapped}
ref, val, *idxs = batched_args
ref_dim, val_dim, *idx_dims = batched_dims
ref_is_batched = ref_dim is not batching.not_mapped
val_is_batched = val_dim is not batching.not_mapped
idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in idx_dims)
if idx_is_batched:
# If at least one of the idx is batched, we ensure all have bdims at front.
idxs = tuple(batching.bdim_at_front(i, d, axis_size)
for i, d in zip(idxs, idx_dims))
idxs_shape, = {i.shape for i in idxs} or [()]
if ref_is_batched and not idx_is_batched:
indexed_dims = tuple_insert(indexed_dims, ref_dim, False)
bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)
if not val_is_batched:
val = batching.broadcast(val, axis_size, 0)
val_dim = 0
val = batching.moveaxis(val, val_dim, bdim_out)
elif idx_is_batched:
assert ref_is_batched and val_is_batched
indexed_dims = tuple_insert(indexed_dims, ref_dim, True)
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
idxs_shape, = {i.shape for i in idxs} or [()]
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
val = batching.moveaxis(val, val_dim, 0)
return addupdate_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), []
batching.primitive_batchers[addupdate_p] = _addupdate_vmap