# Copyright 2021 The JAX Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
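# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and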
# limitations under the License.
from __future__ import annotations
import dataclasses
import functools
import itertools as it
from typing import (Union, Optional, Callable, Dict, Tuple, TypeVar,
FrozenSet, Type, Set, List, Sequence, Any)
import numpy as np
import jax.numpy as jnp
from jax import lax
from jax._src import api
from jax._src import linear_util as lu
from jax._src import core
from jax._src import custom_derivatives
from jax._src import effects
from jax._src import pjit
from jax._src import prng
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import tree_util as jtu
from jax._src.ad_util import SymbolicZero
from jax._src.api_util import flatten_fun
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import tree_flatten
from jax._src.tree_util import tree_map
from jax._src.tree_util import tree_unflatten
from jax._src.typing import Array
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
unzip3, weakref_lru_cache, HashableWrapper)
# Shadow the builtins with the length-checking versions from jax._src.util,
# keeping the builtins reachable under `unsafe_*` names.
unsafe_map, map = map, safe_map
unsafe_zip, zip = zip, safe_zip

# Type aliases used in annotations throughout this module.
Bool = Union[bool, Array]
Int = Union[int, Array]
ErrorCategory = Type['JaxException']  # an error class, e.g. NaNError
Payload = List[Union[np.ndarray, Array]]  # array leaves carried by an error
PyTreeDef = jtu.PyTreeDef
Out = TypeVar('Out')
## Utils
def popattr(obj, attrname):
  """Remove attribute `attrname` from `obj` and return the value it held.

  Raises AttributeError (via getattr) if `obj` has no such attribute.
  """
  popped_value = getattr(obj, attrname)
  delattr(obj, attrname)
  return popped_value
def setnewattr(obj, name, val):
  """Set `obj.<name> = val`, asserting that the attribute did not exist yet."""
  absent = object()  # fresh marker object: no existing attribute can be it
  assert getattr(obj, name, absent) is absent
  setattr(obj, name, val)
# Concrete errors
class JaxException(Exception):
  """Python exception which can contain an error message with JAX run-time info."""

  def __init__(self, traceback_info):
    self.traceback_info = traceback_info
    # TODO(lenamartens): re-enable tracebacks when they don't leak tracers.
    # self.with_traceback(self.traceback_info)

  def __init_subclass__(cls):
    super().__init_subclass__()
    # Every error class is a pytree, so that an error instance can be split
    # into array leaves (its payload) and a treedef (its static metadata).
    jtu.register_pytree_node_class(cls)

  def tree_flatten(self):
    # No array leaves by default; the traceback info is static metadata.
    return ([], self.traceback_info)

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    del payload
    return cls(metadata)

  def get_effect_type(self) -> 'ErrorEffect':
    """Return the ErrorEffect describing this error; subclasses override it."""
    raise NotImplementedError
@dataclasses.dataclass(eq=True, frozen=True)
class ErrorEffect(effects.Effect):
  """Effect naming an error class and the shapes/dtypes of its payload."""
  error_type: Type[JaxException]
  shape_dtypes: Tuple[api.ShapeDtypeStruct, ...]

  def __lt__(self, other: 'ErrorEffect'):
    def sort_key(eff):
      # dtypes are not orderable with `<`, so compare their string forms.
      shapes_and_dtypes = tuple((sd.shape, str(sd.dtype))
                                for sd in eff.shape_dtypes)
      return str(eff.error_type), shapes_and_dtypes
    return sort_key(self) < sort_key(other)
class DivisionByZeroError(JaxException):
  """Error recorded when a division's divisor contains a zero."""

  def get_effect_type(self):
    # Carries no array payload, hence an empty shape/dtype tuple.
    return ErrorEffect(DivisionByZeroError, ())

  def __str__(self):
    return f'division by zero at:\n\n{self.traceback_info}'
class NaNError(JaxException):
  """Error recorded when the output of a primitive contains a NaN.

  Args:
    traceback_info: summary of where the primitive was called.
    primitive_name: name of the primitive that produced the NaN.
  """

  def __init__(self, traceback_info, primitive_name):
    # Sets self.traceback_info, which tree_flatten and __str__ read.
    super().__init__(traceback_info)
    self.prim = primitive_name

  def tree_flatten(self):
    # No array leaves; traceback info and primitive name are static metadata.
    return ([], (self.traceback_info, self.prim))

  @classmethod
  def tree_unflatten(cls, metadata, _):
    return cls(*metadata)

  def get_effect_type(self):
    return ErrorEffect(NaNError, ())

  def __str__(self):
    return f'nan generated by primitive: {self.prim} at:\n\n{self.traceback_info}'
class OOBError(JaxException):
  """Error recorded when an indexing primitive sees an out-of-bounds index.

  Args:
    traceback_info: summary of where the primitive was called.
    primitive_name: name of the indexing primitive.
    operand_shape: shape of the indexed array.
    payload: array whose entries are read as `[index, axis, axis size]` by
      `__str__`; its effect type declares it as an int32 array of shape (3,).
  """

  def __init__(self, traceback_info, primitive_name, operand_shape, payload):
    # Sets self.traceback_info, which tree_flatten and __str__ read.
    super().__init__(traceback_info)
    self.prim = primitive_name
    self.operand_shape = operand_shape
    self._payload = payload

  def tree_flatten(self):
    # The payload is the only array leaf; everything else is static.
    return ([self._payload], (self.traceback_info, self.prim, self.operand_shape))

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    return cls(*metadata, payload[0])

  def __str__(self):
    return (f'out-of-bounds indexing for array of '
            f'shape {self.operand_shape}: '
            f'index {self._payload[0]} is out of bounds for axis '
            f'{self._payload[1]} with size {self._payload[2]}. '
            f'Failed at:\n\n{self.traceback_info}')

  def get_effect_type(self):
    return ErrorEffect(OOBError, (api.ShapeDtypeStruct((3,), jnp.int32),))
class FailedCheckError(JaxException):
  """Error recorded when a user `check` fails.

  Args:
    traceback_info: summary of where the check was called.
    fmt_string: message template, formatted with `*a` and `**k` in `__str__`.
    *a, **k: format arguments; their leaves are the error's payload.
  """

  def __init__(self, traceback_info, fmt_string, *a, **k):
    # Sets self.traceback_info, which tree_flatten and __str__ read.
    super().__init__(traceback_info)
    self.fmt_string = fmt_string
    self.args = a
    self.kwargs = k

  def tree_flatten(self):
    return ((self.args, self.kwargs),  # leaves
            (self.traceback_info, self.fmt_string))  # treedef

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    args, kwargs = payload
    return cls(*metadata, *args, **kwargs)

  def __str__(self):
    return (self.fmt_string.format(*self.args, **self.kwargs)
            + f' `check` failed at:\n\n{self.traceback_info}')

  def get_effect_type(self):
    vals = jtu.tree_leaves((self.args, self.kwargs))
    # The effect is keyed on both the error class and the payload's avals.
    return ErrorEffect(
        FailedCheckError,
        tuple(api.ShapeDtypeStruct(x.shape, x.dtype) for x in vals))
@dataclasses.dataclass
class BatchedError(JaxException):
  """One error per mapped (batch) index at which an error happened.

  `error_mapping` maps a batch index tuple to the error found at that index.
  """
  error_mapping: Dict[Tuple[int, ...], JaxException]

  def __post_init__(self):
    # Use the traceback of the first recorded error as this error's own.
    traceback_info = list(self.error_mapping.values())[0].traceback_info
    super().__init__(traceback_info)

  def __str__(self):
    return '\n'.join(f'at mapped index {", ".join(map(str, idx))}: {e}'
                     for idx, e in self.error_mapping.items())
# Error Value
@jtu.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Error:
  """Functional error value threaded through a checkified computation.

  `_pred`, `_code` and `_payload` hold one entry per ErrorEffect (kind of
  error): whether such an error happened, the code of the check that produced
  it, and the array leaves of the corresponding JaxException. `_metadata`
  maps a code to the treedef used to rebuild that JaxException from its
  payload.
  """
  _pred: Dict[ErrorEffect, Bool]
  _code: Dict[ErrorEffect, Int]
  _metadata: Dict[Int, PyTreeDef]  # mapping of code to JaxException treedef.
  _payload: Dict[ErrorEffect, Payload]

  def get(self) -> Optional[str]:
    """Returns error message if error happened, None if no error happened."""
    exp = self.get_exception()
    if exp is not None:
      return str(exp)
    return None

  def get_exception(self) -> Optional[JaxException]:
    """Returns Python exception if error happened, None if no error happened.

    If several kinds of error happened, the one with the smallest code is
    rebuilt. If any predicate is non-scalar (batched), this defers to
    `_get_batched_exception`.
    """
    if any(map(np.shape, self._pred.values())):
      return self._get_batched_exception()
    min_code = None
    cur_effect = None
    for error_effect, code in self._code.items():
      if self._pred[error_effect]:
        if min_code is None or code < min_code:
          min_code = code
          cur_effect = error_effect
    if cur_effect is not None:
      return tree_unflatten(self._metadata[int(min_code)],  # type: ignore
                            self._payload[cur_effect])
    return None

  def throw(self):
    """Pass this error to `_check_error` (defined elsewhere in this module)."""
    _check_error(self)

  def __str__(self):
    return f'Error({self.get()})'

  # Internal helpers

  def _get_batched_exception(self) -> Optional[BatchedError]:
    """Per batch index, rebuild the smallest-code error; None if none happened.

    The batch shape is taken from the first predicate (presumably all
    predicates share it).
    """
    shape = np.shape(list(self._pred.values())[0])
    error_mapping = {}
    for idx in np.ndindex(*shape):
      min_code = None
      cur_effect = None
      for error_effect, code in self._code.items():
        if self._pred[error_effect][idx]:  # type: ignore
          if min_code is None or code[idx] < min_code:
            min_code = code[idx]  # type: ignore
            cur_effect = error_effect

      if cur_effect is not None:
        payload = tree_map(lambda x, i=idx: x[i], self._payload[cur_effect])
        jax_error = tree_unflatten(self._metadata[int(min_code)], payload)  # type: ignore
        error_mapping[idx] = jax_error

    if error_mapping:
      return BatchedError(error_mapping)
    return None

  def _update(self, effect_type: ErrorEffect, pred, code, metadata, payload):
    """Return a new Error with the entry for `effect_type` replaced."""
    new_errs = {**self._pred, **{effect_type: pred}}  # type: ignore
    new_codes = {**self._code, **{effect_type: code}}  # type: ignore
    new_payload = {**self._payload, **{effect_type: payload}}  # type: ignore
    new_metadata = {**self._metadata, **metadata}
    return Error(new_errs, new_codes, new_metadata, new_payload)

  def _add_placeholder_effects(self, effects: Set[ErrorEffect]):
    """Fill out Error with `effects` and np.ones arrays of their payloads."""
    new_err = self._pred.copy()
    new_code = self._code.copy()
    new_payload = self._payload.copy()
    for effect in effects:
      if effect not in self._pred.keys():
        new_err[effect] = False
        new_payload[effect] = list(
            tree_map(lambda a: jnp.ones(a.shape, a.dtype), effect.shape_dtypes))
        # The error value associated with this effect will never become True, so
        # we don't need to set a meaningful code.
        new_code[effect] = -1
    return Error(new_err, new_code, self._metadata, new_payload)

  def _replace(self, *args, **kwargs):
    return dataclasses.replace(self, *args, **kwargs)

  # PyTree methods

  def tree_flatten(self):
    # Metadata (treedefs) is static; predicates, codes and payloads are leaves.
    return ((self._pred, self._code, self._payload), self._metadata)

  @classmethod
  def tree_unflatten(cls, metadata, data):
    pred, code, payload = data
    return cls(pred, code, metadata, payload)
# The empty error: no effects recorded, nothing has failed.
init_error = Error({}, {}, {}, {})
# Returns 1, 2, 3, ... on successive calls; each check gets a unique code.
next_code = functools.partial(next, it.count(1))
def assert_func(error: Error, pred: Bool, new_error: JaxException) -> Error:
  """Return `error` extended with `new_error`, conditioned on `pred`.

  `new_error` is filed under a freshly allocated code, whose treedef is
  recorded as metadata so the exception can be rebuilt later.
  """
  fresh_code = next_code()
  effect = new_error.get_effect_type()
  leaves, treedef = tree_flatten(new_error)
  return update_error(error, pred, fresh_code, {fresh_code: treedef}, leaves,
                      effect)
def update_error(error, pred, code, metadata, payload, effect_type):
  """Merge a (possibly traced) error condition into `error` for `effect_type`.

  The predicate is OR-ed with any existing predicate of the same effect type.
  If an error of that type was already set, its code and payload are kept;
  otherwise `code` and `payload` are used.

  Returns:
    The result of `error._update` with the merged values.
  """
  err_of_type = error._pred.get(effect_type, False)
  out_err = err_of_type | pred
  # An earlier error of the same type takes precedence over the new one.
  out_code = lax.select(err_of_type, error._code.get(effect_type, -1), code)
  cur_payload = error._payload.get(effect_type, None)
  if cur_payload is not None:
    out_payload = tree_map(functools.partial(lax.select, err_of_type),
                           cur_payload, payload)
  else:
    out_payload = payload
  return error._update(effect_type, out_err, out_code, metadata, out_payload)
## Checkify transformation for plumbing functional error values.
@lu.transformation_with_aux
def _flatten_and_get_error_metadata_thunk(*invals):
  """Transformation flattening the wrapped function's `(error, out)` result.

  The auxiliary output is `(out_tree, error_effects)`: the treedef of
  `(error, out)` and the set of ErrorEffects present in the output error.
  """
  error, out = yield invals, {}
  out_vals, out_tree = jtu.tree_flatten((error, out))
  yield out_vals, (out_tree, set(error._pred.keys()))
def default_checkify_rule(primitive: core.Primitive, error: Error,
                          enabled_errors, *invals: core.Value,
                          **params: Any) -> Tuple[Error, Sequence[core.Value]]:
  """Default rule for primitives in `checkify` interpreter.

  Primitives without a `call_jaxpr` param are bound unchanged and `error` is
  passed through. Call- and map-primitives get their inner jaxpr checkified
  recursively, with the flattened error values prepended to the operands.
  """
  if 'call_jaxpr' not in params:
    # Default non-HOP case: just call primitive and don't update error.
    return error, primitive.bind(*invals, **params)

  # Code below handles call- and map-primitives, by recursively calling
  # checkify_jaxpr.
  err_vals, err_tree = jtu.tree_flatten(error)
  num_error_vals = len(err_vals)
  if 'donated_invars' in params:
    # The prepended error values are never donated.
    params = dict(params, donated_invars=(*[False]*num_error_vals,
                                          *params['donated_invars']))

  # call_jaxpr handling
  call_jaxpr = params.pop('call_jaxpr')
  if isinstance(call_jaxpr, core.ClosedJaxpr):  # handle closed_call_p
    jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
  else:
    jaxpr, consts = call_jaxpr, ()
  consts_ = tuple(HashableWrapper(c) for c in consts)
  partial_checkify = lu.hashable_partial(lu.wrap_init(
      checkify_jaxpr_flat_hashable), jaxpr, consts_, enabled_errors, err_tree)
  partial_checkify, metadata = _flatten_and_get_error_metadata_thunk(
      partial_checkify)

  # map-specific params handling.
  if isinstance(primitive, core.MapPrimitive):
    # Update `in_axes` and `out_axes_thunk` params for map primitive.
    out_val_axes = params.pop('out_axes')

    @as_hashable_function(closure=out_val_axes)
    def out_axes_thunk():
      # Error outputs are mapped along axis 0; the rest keep their out_axes.
      out_err_num = metadata()[0].num_leaves - len(out_val_axes)
      return (*(0,)*out_err_num, *out_val_axes)

    params = dict(params,
                  in_axes=(*(None,)*num_error_vals, *params['in_axes']),
                  out_axes_thunk=out_axes_thunk)

  all_vals = primitive.bind(partial_checkify, *err_vals, *invals, **params)

  out_tree, _ = metadata()
  error, out_vals = tree_unflatten(out_tree, all_vals)
  if isinstance(primitive, core.MapPrimitive):
    # The mapped error has a leading axis; reduce it to a single error.
    error = _reduce_any_error(error)
  return error, out_vals
def get_shaped_aval(val):
  """Return the abstract value of `val`, raised to a shaped aval."""
  aval = core.get_aval(val)
  return core.raise_to_shaped(aval)
def checkify_jaxpr(jaxpr: core.ClosedJaxpr, enabled_errors,
                   error: Error, *args) -> Tuple[Error, List[core.Value]]:
  """Evaluate closed `jaxpr` on `args`, threading `error` through its checks."""
  flat_error, error_treedef = jtu.tree_flatten(error)
  inner_jaxpr, inner_consts = jaxpr.jaxpr, jaxpr.consts
  return checkify_jaxpr_flat(inner_jaxpr, inner_consts, enabled_errors,
                             error_treedef, *flat_error, *args)
def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
                        enabled_errors, err_tree: PyTreeDef,
                        *args: core.Value) -> Tuple[Error, List[Any]]:
  """Interpret `jaxpr`, applying each primitive's checkify rule.

  `args` holds the flattened error leaves (described by `err_tree`) followed
  by the jaxpr's inputs. Primitives without an entry in `error_checks` use
  `default_checkify_rule`.

  Returns:
    The final error and the list of the jaxpr's output values.
  """
  env: Dict[core.Var, Any] = {}
  err_vals, in_args = split_list(args, [err_tree.num_leaves])
  error = jtu.tree_unflatten(err_tree, err_vals)

  def read_env(var: core.Atom):
    if isinstance(var, core.Literal):
      return var.val
    return env[var]

  def write_env(var: core.Var, val: Any):
    env[var] = val

  map(write_env, jaxpr.constvars, consts)
  map(write_env, jaxpr.invars, in_args)

  # interpreter loop
  for eqn in jaxpr.eqns:
    invals = map(read_env, eqn.invars)
    checkify_rule = error_checks.get(
        eqn.primitive, functools.partial(default_checkify_rule, eqn.primitive))
    name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
    # Run the rule under the equation's source info so error summaries point
    # at the user's code.
    with source_info_util.user_context(eqn.source_info.traceback,
                                       name_stack=name_stack):
      error, outvals = checkify_rule(error, enabled_errors,
                                     *invals, **eqn.params)
    if eqn.primitive.multiple_results:
      map(write_env, eqn.outvars, outvals)
    else:
      write_env(eqn.outvars[0], outvals)

  return error, map(read_env, jaxpr.outvars)
def checkify_jaxpr_flat_hashable(jaxpr, hashable_consts, enabled_errors,
                                 err_tree, *args):
  """Like `checkify_jaxpr_flat`, with consts wrapped in HashableWrapper.

  The wrapping lets the consts be passed through `lu.hashable_partial`.
  """
  unwrapped_consts = tuple(wrapper.x for wrapper in hashable_consts)
  return checkify_jaxpr_flat(jaxpr, unwrapped_consts, enabled_errors,
                             err_tree, *args)
@lu.transformation_with_aux
def flatten_fun_output(*args):
  """Transformation flattening the wrapped function's output pytree.

  The output treedef is returned as the auxiliary value.
  """
  ans = yield args, {}
  yield tree_flatten(ans)
def _reduce_any_error(error: Error):
  """Collapse an Error whose entries have a leading mapped axis.

  For each effect, one index along axis 0 is selected (an erroring one if any
  entry errored), and its pred, code and payload are kept. The metadata of
  `error` is carried over unchanged.
  """
  out_error = init_error
  for error_effect in error._pred.keys():
    errs, codes, payloads = (error._pred[error_effect],
                             error._code[error_effect],
                             error._payload[error_effect])
    # argsort orders True predicates last, so the final index is an erroring
    # entry whenever at least one entry errored.
    reduced_idx = jnp.argsort(errs)[-1]
    pred, code, payload = tree_map(lambda x, idx=reduced_idx: x[idx],
                                   (errs, codes, payloads))
    out_error = out_error._update(error_effect, pred, code, {}, payload)

  out_error = out_error._replace(_metadata=error._metadata)
  return out_error
## check_p primitive
check_p = core.Primitive('check')
check_p.multiple_results = True  # zero results


# TODO(lenamartens): inherit from Exception instead of ValueError.
class JaxRuntimeError(ValueError):
  """Raised when a `check` evaluated on concrete values finds an error."""
@check_p.def_impl
def check_impl(*args, err_tree, debug):
  """Eager implementation of `check_p`.

  Rebuilds the Error from its flattened leaves. If it holds an error, raises a
  JaxRuntimeError chained from the underlying JaxException. Debug checks do
  nothing here.
  """
  if debug:
    # NOOP (check will only trigger when discharged)
    return []
  error = tree_unflatten(err_tree, args)
  exc = error.get_exception()
  if exc:
    raise JaxRuntimeError(str(exc)) from exc
  return []
@check_p.def_effectful_abstract_eval
def check_abstract_eval(*args, err_tree, debug):
  """Abstract eval: no outputs; the effects are the Error's ErrorEffects."""
  del debug
  return [], set(tree_unflatten(err_tree, args)._pred.keys())
# TODO(lenamartens) add in-depth error explanation to link to in module docs.
# Raised by the check lowering rules for non-debug checks that cannot be
# lowered.
functionalization_error = ValueError(
    'Cannot abstractly evaluate a checkify.check which was not'
    ' functionalized. This probably means you tried to stage'
    ' (jit/scan/pmap/...) a `check` without functionalizing it'
    ' through `checkify.checkify`.'
)
def check_lowering_rule(ctx, *args, err_tree, debug):
  """Lower a `check` to a Python host callback running `python_err`.

  Debug checks lower to nothing. Non-debug checks are only lowered when
  `jax_experimental_unsafe_xla_runtime_errors` is enabled; otherwise
  `functionalization_error` is raised.
  """
  if debug:
    # NOOP (check will only trigger when discharged)
    return []
  if not config.jax_experimental_unsafe_xla_runtime_errors:
    raise functionalization_error

  out_op, _, keep_alive = mlir.emit_python_callback(
      ctx, callback=functools.partial(python_err, err_tree),
      token=None,
      operands=args,
      operand_avals=list(ctx.avals_in),
      result_avals=list(ctx.avals_out),
      has_side_effect=True)
  # Register the keep-alive object with the module so the callback is not
  # garbage-collected while the compiled module may still call it.
  ctx.module_context.add_keepalive(keep_alive)
  return out_op
def check_lowering_rule_unsupported(*a, debug, **k):
  """Lowering rule that only accepts debug checks.

  Debug checks lower to nothing; any other check raises
  `functionalization_error`.
  """
  del a, k  # unused
  if not debug:
    raise functionalization_error
  return []
def python_err(err_tree, *args):
  """Host callback: rebuild the Error from its leaves and check it.

  The rebuilt error is handed to `_check_error` (defined elsewhere in this
  module). Returns no results.
  """
  error = tree_unflatten(err_tree, args)
  _check_error(error)
  return []
# TPU gets the rule that only accepts debug checks; CPU and GPU lower checks
# to a Python host callback.
mlir.register_lowering(check_p, check_lowering_rule_unsupported,
                       platform='tpu')
mlir.register_lowering(check_p, check_lowering_rule,
                       platform='cpu')
mlir.register_lowering(check_p, check_lowering_rule,
                       platform='gpu')
def check_batching_rule(batched_args, batch_dims, *, err_tree, debug):
  """vmap rule for `check_p`: move batch dims to the front, then check.

  The batch size is read from the first mapped operand. Returns no outputs.
  """
  size = next(arg.shape[bdim] for arg, bdim in zip(batched_args, batch_dims)
              if bdim is not batching.not_mapped)
  front_args = [batching.bdim_at_front(arg, bdim, size)
                for arg, bdim in zip(batched_args, batch_dims)]
  _check_error(tree_unflatten(err_tree, front_args), debug=debug)
  return [], []
batching.primitive_batchers[check_p] = check_batching_rule
def check_jvp_rule(primals, _, *, err_tree, debug):
  """JVP rule for `check_p`: check the primals; tangents are ignored."""
  check_p.bind(*primals, err_tree=err_tree, debug=debug)
  primals_out, tangents_out = [], []
  return primals_out, tangents_out
ad.primitive_jvps[check_p] = check_jvp_rule
## checkify rules
# A rule has the shape
#   (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Error, Any)
# i.e. it returns the updated error first, then the primitive's output.
ErrorCheckRule = Callable
# Registry of primitive -> checkify rule, consulted by the interpreter.
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
def summary() -> str:
  """Summarize the current source info as a string, for error messages."""
  return str(source_info_util.summarize(source_info_util.current()))
def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
  """Bind `prim`, then let `check_nans` inspect its output for NaNs."""
  result = prim.bind(*in_vals, **params)
  return check_nans(prim, error, enabled_errors, result), result
def check_nans(prim, error, enabled_errors, out):
  """Return `error` extended with a NaNError if any output of `prim` is NaN.

  A no-op when NaNError is not enabled. PRNG key arrays are never treated as
  NaN. `out` is a sequence of outputs when `prim.multiple_results`, else a
  single output.
  """
  if NaNError not in enabled_errors:
    return error

  def isnan(x):
    if isinstance(x, prng.PRNGKeyArray):
      return False
    return jnp.any(jnp.isnan(x))

  any_nans = (jnp.any(jnp.array([isnan(x) for x in out]))
              if prim.multiple_results else isnan(out))
  return assert_func(error, any_nans, NaNError(summary(), prim.name))
# All primitives which can generate a NaN.
nan_primitives = [
    lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p, lax.atan2_p,
    lax.atan_p, lax.atanh_p, lax.bessel_i0e_p, lax.bessel_i1e_p, lax.cbrt_p,
    lax.conv_general_dilated_p, lax.cos_p, lax.cosh_p, lax.cumlogsumexp_p,
    lax.cummax_p, lax.cummin_p, lax.cumprod_p, lax.cumsum_p, lax.digamma_p,
    lax.dot_general_p, lax.erf_inv_p, lax.erf_p, lax.erfc_p, lax.exp_p,
    lax.expm1_p, lax.fft_p, lax.igamma_grad_a_p, lax.igamma_p, lax.igammac_p,
    lax.integer_pow_p, lax.lgamma_p, lax.linear_solve_p, lax.log1p_p,
    lax.log_p, lax.logistic_p, lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p,
    lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p, lax.reduce_sum_p,
    lax.reduce_window_p, lax.reduce_window_sum_p,
    lax.regularized_incomplete_beta_p, lax.rem_p, lax.rng_uniform_p,
    lax.rsqrt_p, lax.sin_p, lax.sinh_p, lax.sqrt_p, lax.sub_p, lax.tan_p,
    lax.tanh_p,
]

# Each of them gets the generic "bind, then check the output for NaN" rule.
for _prim in nan_primitives:
  error_checks[_prim] = functools.partial(nan_error_check, _prim)
def dynamic_slice_error_check(error, enabled_errors, operand, *start_indices, slice_sizes):
  """Bind dynamic_slice; if OOB checks are enabled, validate its start indices.

  A start index is out of bounds when it is negative or when the slice
  starting there would run past the operand's extent along that axis.
  """
  out = lax.dynamic_slice_p.bind(operand, *start_indices, slice_sizes=slice_sizes)
  if OOBError not in enabled_errors:
    return error, out

  extents = np.array(operand.shape)
  sizes = np.array(slice_sizes)
  starts = jnp.array(start_indices)
  oob_mask = (starts < 0) | (starts + sizes > extents)
  payload = oob_payload(oob_mask, starts, range(operand.ndim), operand.shape)
  oob_error = OOBError(summary(), "dynamic_slice", operand.shape, payload)
  return assert_func(error, jnp.any(oob_mask), oob_error), out
error_checks[lax.dynamic_slice_p] = dynamic_slice_error_check
def dynamic_update_slice_error_check(error, enabled_errors, operand, update, *start_indices):
  """Bind dynamic_update_slice; if OOB checks are enabled, validate starts.

  A start index is out of bounds when it is negative or when `update` placed
  there would extend past the operand's extent along that axis.
  """
  out = lax.dynamic_update_slice_p.bind(operand, update, *start_indices)
  if OOBError not in enabled_errors:
    return error, out

  extents = np.array(operand.shape)
  update_extents = np.array(update.shape)
  starts = jnp.array(start_indices)
  oob_mask = (starts < 0) | (starts + update_extents > extents)
  payload = oob_payload(oob_mask, starts, range(operand.ndim), operand.shape)
  oob_error = OOBError(summary(), "dynamic_update_slice", operand.shape, payload)
  return assert_func(error, jnp.any(oob_mask), oob_error), out
error_checks[lax.dynamic_update_slice_p] = dynamic_update_slice_error_check
def gather_error_check(error, enabled_errors, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
  """Bind gather; if OOB checks are enabled, validate its start indices.

  A start index is out of bounds when it is negative or exceeds the operand
  dimension minus the slice size along the operand axis it indexes.
  """
  out = lax.gather_p.bind(
      operand, start_indices, dimension_numbers=dimension_numbers,
      slice_sizes=slice_sizes, unique_indices=unique_indices,
      indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)
  if OOBError not in enabled_errors:
    return error, out

  # compare to OOB masking logic in lax._gather_translation_rule
  index_map = np.array(dimension_numbers.start_index_map)
  # Largest valid start per indexed operand axis.
  max_start = (np.array(operand.shape)[index_map]
               - np.array(slice_sizes)[index_map])
  # Broadcast over every axis of start_indices except the last (index) axis.
  leading_axes = tuple(range(len(start_indices.shape) - 1))
  max_start = jnp.expand_dims(max_start, axis=leading_axes)
  oob_mask = ((start_indices < 0)
              | (start_indices > max_start.astype(start_indices.dtype)))
  payload = oob_payload(oob_mask, start_indices,
                        dimension_numbers.start_index_map, operand.shape)
  oob_error = OOBError(summary(), "gather", operand.shape, payload)
  return assert_func(error, jnp.any(oob_mask), oob_error), out
error_checks[lax.gather_p] = gather_error_check
def div_error_check(error, enabled_errors, x, y):
  """Checks for division by zero and NaN."""
  if DivisionByZeroError not in enabled_errors:
    return nan_error_check(lax.div_p, error, enabled_errors, x, y)
  divisor_has_zero = jnp.any(jnp.equal(y, 0))
  zero_checked = assert_func(error, divisor_has_zero,
                             DivisionByZeroError(summary()))
  return nan_error_check(lax.div_p, zero_checked, enabled_errors, x, y)
error_checks[lax.div_p] = div_error_check
def oob_payload(oob_mask, indices, dims_map, operand_shape):
  """Describe the first out-of-bounds index, for the error message.

  Returns an int32 array `[index value, operand axis, size of that axis]` for
  the first True entry of `oob_mask` in row-major order (entry 0 if none is
  True). `dims_map[k]` is the operand axis addressed by position `k` along
  the last axis of `indices`.
  """
  # argmin of the negated mask is the flat position of the first True entry.
  first_flat = jnp.argmin(jnp.logical_not(oob_mask))
  position = jnp.unravel_index(first_flat, indices.shape)
  axis = jnp.array(dims_map)[position[-1]]
  axis_size = jnp.array(operand_shape)[axis]
  bad_index = jnp.ravel(indices)[first_flat]
  return jnp.array([bad_index, axis, axis_size], dtype=jnp.int32)
def scatter_oob(operand, indices, updates, dnums):
  """Return `(any_out_of_bounds, payload)` for the scatter `indices`.

  An index is out of bounds if it is negative or larger than the operand size
  along the target axis minus the update window size along that axis.
  """
  # Ref: see clamping code used in scatter_translation_rule
  # Window size per operand axis: 1 for inserted window dims, otherwise the
  # size of the matching update window dimension.
  slice_sizes = []
  pos = 0
  for i in range(len(operand.shape)):
    if i in dnums.inserted_window_dims:
      slice_sizes.append(1)
    else:
      slice_sizes.append(updates.shape[dnums.update_window_dims[pos]])
      pos += 1

  upper_bound = np.array([operand.shape[i] - slice_sizes[i]
                          for i in dnums.scatter_dims_to_operand_dims],
                         np.int64)
  # Clip so the bound is representable in the index dtype.
  upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max)
  upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
                                     (len(indices.shape) - 1,))

  lower_oob = jnp.less(indices, 0)
  upper_oob = jnp.greater(indices, upper_bound.astype(indices.dtype))
  oob_mask = jnp.logical_or(lower_oob, upper_oob)
  payload = oob_payload(oob_mask, indices,
                        dnums.scatter_dims_to_operand_dims, operand.shape)
  return jnp.any(oob_mask), payload
def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
                        *, update_jaxpr, update_consts, dimension_numbers,
                        indices_are_sorted, unique_indices, mode):
  """Checks if indices are within bounds and update does not generate NaN.

  The two checks are independent: the OOB check runs only when OOBError is
  enabled, and `check_nans` is itself a no-op unless NaNError is enabled.
  """
  out = prim.bind(
      operand, indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode)

  if OOBError in enabled_errors:
    out_of_bounds, payload = scatter_oob(operand, indices, updates,
                                         dimension_numbers)
    oob_error = OOBError(summary(), prim.name, operand.shape, payload)
    error = assert_func(error, out_of_bounds, oob_error)
  # Previously skipped whenever OOB checks were disabled.
  error = check_nans(prim, error, enabled_errors, out)
  return error, out
error_checks[lax.scatter_p] = functools.partial(scatter_error_check, lax.scatter_p)
error_checks[lax.scatter_add_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_add_p)
error_checks[lax.scatter_mul_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_mul_p)
error_checks[lax.scatter_min_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_min_p)
error_checks[lax.scatter_max_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_max_p)
# HOP error check rules
def jaxpr_to_checkify_jaxpr(
    jaxpr: core.ClosedJaxpr, enabled_errors, err_tree: PyTreeDef,
    *flat_err_and_in_vals) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]:
  """Trace a checkified version of `jaxpr`.

  `flat_err_and_in_vals` are the avals of the flattened input error followed
  by those of `jaxpr`'s inputs.

  Returns:
    The checked closed jaxpr, the treedef of its `(error, outputs)` result, and
    the ErrorEffects its output error can carry.
  """
  checkify_jaxpr_partial = functools.partial(checkify_jaxpr_flat, jaxpr.jaxpr,
                                             jaxpr.consts, enabled_errors,
                                             err_tree)
  fun = lu.wrap_init(checkify_jaxpr_partial)
  fun, metadata = _flatten_and_get_error_metadata_thunk(fun)
  new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals)
  checked_jaxpr = core.ClosedJaxpr(new_jaxpr, consts)
  out_tree, error_effects = metadata()
  return checked_jaxpr, out_tree, error_effects
def cond_error_check(error: Error, enabled_errors, index, *ops, branches, linear):
  """Checkify rule for `cond_p`.

  All branches must return errors of the same structure. The input error is
  therefore first padded with placeholder entries for every effect that any
  branch can produce. The metadata of all branches is then merged into the
  output error.
  """
  # Get the error-effects out of all branches so the cond can be called with
  # a merged error with all these effects.
  err_vals, err_tree = jtu.tree_flatten(error)
  in_avals = map(get_shaped_aval, [*err_vals, *ops])
  def get_error_effects_from_jaxpr(jxpr):
    _, _, jxpr_effects = jaxpr_to_checkify_jaxpr(jxpr, enabled_errors, err_tree,
                                                 *in_avals)
    return jxpr_effects
  branch_effects = [get_error_effects_from_jaxpr(jxpr) for jxpr in branches]
  merged_error = error._add_placeholder_effects(set().union(*branch_effects))
  err_vals, err_tree = jtu.tree_flatten(merged_error)
  # cond's `linear` covers (*err_vals, *ops); error values are never linear.
  new_linear = (*[False] * len(err_vals), *linear)

  # Update branch jaxprs to be checkified jaxprs.
  in_avals = map(get_shaped_aval, [*err_vals, *ops])
  new_branches, out_trees, _ = unzip3(
      jaxpr_to_checkify_jaxpr(jxpr, enabled_errors, err_tree, *in_avals)
      for jxpr in branches)

  err_and_outs = lax.cond_p.bind(
      index, *err_vals, *ops,
      branches=tuple(new_branches), linear=new_linear)

  # we need to merge metadata across out_trees (a tuple)
  err0, out = tree_unflatten(out_trees[0], err_and_outs)
  merged_metadata = err0._metadata
  for tr in out_trees[1:]:
    err, _ = tree_unflatten(tr, err_and_outs)
    merged_metadata = {**merged_metadata, **err._metadata}
  return err0._replace(_metadata=merged_metadata), out
error_checks[lax.cond_p] = cond_error_check
def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
                     num_consts, num_carry, linear, unroll):
  """Checkify rule for `scan_p`: the error values become part of the carry."""
  consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
  xs_mapped = [core.mapped_aval(length, 0, get_shaped_aval(val)) for val in xs]
  # Query body effects to create a merged error containing all effects (such
  # that in and out carried error are of the same type).
  err_vals, err_tree = jtu.tree_flatten(error)
  new_in_aval = map(get_shaped_aval, [*err_vals, *consts, *carry]) + xs_mapped
  _, _, body_effects = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                               err_tree, *new_in_aval)

  merged_error = error._add_placeholder_effects(body_effects)
  err_vals, err_tree = jtu.tree_flatten(merged_error)

  # Create checked-jaxpr, with the needed pre-processing on the inputs.
  new_in_aval = map(get_shaped_aval, [*err_vals, *consts, *carry]) + xs_mapped
  checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                                        err_tree, *new_in_aval)
  # The checkified body takes (errs, consts, carry, xs); scan wants consts
  # first, so move the const binders to the front.
  tomove = ([False] * len(err_vals) + [True] * len(consts)
            + [False] * (len(carry) + len(xs)))
  checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove)
  new_in_flat = [*consts, *err_vals, *carry, *xs]
  # `linear` must follow the new operand order (consts, errs, carry, xs);
  # the error values are never linear.
  new_linear = (*linear[:num_consts], *[False] * len(err_vals),
                *linear[num_consts:])
  err_and_out = lax.scan_p.bind(
      *new_in_flat, reverse=reverse, length=length, jaxpr=checked_jaxpr,
      num_consts=len(consts), num_carry=len(carry)+len(err_vals),
      linear=new_linear, unroll=unroll)
  err, out = tree_unflatten(out_tree, err_and_out)
  return err, out
error_checks[lax.scan_p] = scan_error_check
def checkify_while_body_jaxpr(
    cond_jaxpr: core.ClosedJaxpr, body_jaxpr: core.ClosedJaxpr,
    enabled_errors, error: Error,
    c_consts_num: int) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]:
  """Checkify a while body that also evaluates the cond on its own outputs.

  Evaluating the cond inside the body lets errors raised by the next cond
  application be recorded in the carried error. The traced body takes
  `(*cond_consts, *body_jaxpr inputs)`. The checkified jaxpr additionally
  takes the error values first.
  """
  cond_f = core.jaxpr_as_fun(cond_jaxpr)
  body_f = core.jaxpr_as_fun(body_jaxpr)
  def new_body_f(*c_consts_and_vals):
    c_consts, vals = split_list(c_consts_and_vals, [c_consts_num])
    out = body_f(*vals)
    # This checks if the next cond application will error
    _ = cond_f(*c_consts, *out)
    return out
  new_body_f_ = lu.wrap_init(new_body_f)
  c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
  jaxpr, _, () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
                                                         *body_jaxpr.in_avals])
  closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
  err_vals, err_tree = jtu.tree_flatten(error)
  err_vals = map(get_shaped_aval, err_vals)
  flat_err_and_in_vals = [*err_vals, *c_consts_avals, *body_jaxpr.in_avals]
  jaxpr, out_tree, error_effects = jaxpr_to_checkify_jaxpr(
      closed_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)
  return jaxpr, out_tree, error_effects
def ignore_error_output_jaxpr(jaxpr, num_error_vals):
"""Constructs a checked jaxpr which does not output its error value."""
consts = jaxpr.consts
jaxpr = jaxpr.jaxpr
new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[num_error_vals:])
return core.ClosedJaxpr(new_jaxpr, consts)
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
                           cond_jaxpr, body_nconsts, body_jaxpr):
  """Checkify rule for `while_p`.

  The cond is checked once up front. The checkified body re-runs the cond on
  its outputs, so errors raised by later cond applications are captured in
  the error carried through the loop.

  Raises:
    ValueError: if the cond predicate is batched (non-scalar).
  """
  if cond_jaxpr.out_avals[0].shape:
    # TODO(lenamartens, sharadmv): support batched while.
    raise ValueError('Checkify does not support batched while-loops '
                     '(checkify-of-vmap-of-while). \nHint: if possible, move '
                     'the vmap to the outer level to get '
                     'vmap-of-checkify-of-while.')
  c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])

  # Check if the first cond application will error.
  error, _ = checkify_jaxpr(cond_jaxpr, enabled_errors, error, *c_consts, *carry)

  _, _, error_effects = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr,
                                                  enabled_errors, error,
                                                  cond_nconsts)
  # merged error!
  error = error._add_placeholder_effects(error_effects)
  err_vals, err_tree = jtu.tree_flatten(error)
  checked_body_jaxpr_, body_out_tree, _ = checkify_while_body_jaxpr(
      cond_jaxpr, body_jaxpr, enabled_errors, error, cond_nconsts)
  num_error_vals = len(err_vals)
  to_move = ([False] * num_error_vals + [True] * cond_nconsts
             + [True] * body_nconsts + [False] * len(carry))
  checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)

  cond_in_flat = [*err_vals, *c_consts, *carry]
  cond_in_flat = map(get_shaped_aval, cond_in_flat)
  checked_cond_jaxpr, _, _ = jaxpr_to_checkify_jaxpr(cond_jaxpr, enabled_errors,
                                                     err_tree, *cond_in_flat)
  # Drop the cond's error outputs so it returns only its original output(s).
  compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr, num_error_vals)
  to_move = [False] * num_error_vals + [True] * cond_nconsts + [False] * len(carry)
  compat_cond_jaxpr = pe.move_binders_to_front(compat_cond_jaxpr_, to_move)

  # Cond consts are passed twice: once to the cond, and once to the body,
  # which re-runs the cond.
  new_in_flat = [*c_consts, *c_consts, *b_consts, *err_vals, *carry]
  all_out_vals = lax.while_p.bind(
      *new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr,
      body_nconsts=cond_nconsts+body_nconsts, body_jaxpr=checked_body_jaxpr)
  # body_out_tree will have all the metadata of cond because it executes a cond!
  error, out = tree_unflatten(body_out_tree, all_out_vals)
  return error, out
error_checks[lax.while_p] = while_loop_error_check
def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
                     in_shardings, out_shardings, resource_env,
                     donated_invars, name,
                     inline, keep_unused):
  """Checkify rule for `pjit_p`: error values are extra pjit inputs/outputs.

  The added error values get unspecified shardings and are never donated.
  """
  # jaxpr to checked_jaxpr
  err_vals, err_tree = jtu.tree_flatten(error)
  new_vals_in = [*err_vals, *vals_in]
  in_avals = tuple(map(get_shaped_aval, new_vals_in))
  checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                                       err_tree, *in_avals)

  # Update pjit params to account for extra error values.
  num_error_vals = len(err_vals)
  num_out_error_vals = out_tree.num_leaves - len(out_shardings)
  sharding = sharding_impls.UNSPECIFIED

  new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)
  new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings)
  new_donated_invars = (*[False] * num_error_vals, *donated_invars)

  err_and_out = pjit.pjit_p.bind(
      *new_vals_in,
      jaxpr=checked_jaxpr,
      in_shardings=new_in_shardings,
      out_shardings=new_out_shardings,
      resource_env=resource_env,
      donated_invars=new_donated_invars,
      name=name,
      inline=inline,
      keep_unused=keep_unused,
  )
  return tree_unflatten(out_tree, err_and_out)
error_checks[pjit.pjit_p] = pjit_error_check
def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
                         jvp_jaxpr_thunk, call_jaxpr, **params):
  """Checkify rule for `custom_jvp_call_p`.

  The primal function is checkified. The JVP rule only forwards the incoming
  error, so its checks are not added.
  """
  # The types to have in mind are:
  #   jvp : (a -> b) -> (a, T a) -> (b, T b)
  #   checkify : (a -> b) -> a -> Err b
  #   jvp-of-checkify : (a -> b) -> (a, T a) -> (Err b, T (Err b))
  # where because Err is a pytree, we necessarily have T (Err b) = Err' (T b)
  # where the other Err' components are trivial (of float0 dtype).
  # Semantically, we don't add checks to the JVP rule. To check the result of a
  # JVP rule, one must instead use checkify-of-jvp. Thus this implementation
  # just forwards the input error and code (and trivial tangents) to the output.
  err_vals, err_tree = jtu.tree_flatten(in_err)
  partial_checkify = lu.wrap_init(
      functools.partial(checkify_jaxpr_flat, call_jaxpr.jaxpr,
                        call_jaxpr.consts, enabled_errors, err_tree))
  partial_checkify, f_metadata = _flatten_and_get_error_metadata_thunk(
      partial_checkify)
  jvp = lift_jvp(err_tree.num_leaves, num_consts, jvp_jaxpr_thunk)
  jvp, jvp_out_tree = flatten_fun_output(jvp)
  all_outs = custom_derivatives.custom_jvp_call_p.bind(
      partial_checkify, jvp, *err_vals, *in_vals, **params)
  # Only one of the two functions was traced; its aux tells us which.
  fst, out_metadata = lu.merge_linear_aux(f_metadata, jvp_out_tree)
  if fst:
    err_and_out_tree, _ = out_metadata
    out_err, out_vals = tree_unflatten(err_and_out_tree, all_outs)
  else:
    err_vals, out_vals = split_list(all_outs, [len(err_vals)])
    # forward input error to output
    out_err = jtu.tree_unflatten(err_tree, err_vals)
  return out_err, out_vals
error_checks[custom_derivatives.custom_jvp_call_p] = custom_jvp_call_rule
# Compared to custom_derivatives.lift_jvp, we're handling the extra inputs and
# outputs that checkify adds (just forwarding the error data's primal and
# tangent components). The jaxpr in jvp_jaxpr_thunk doesn't expect those.
# TODO(mattjj): can we simplify this, or dedup with custom_derivatives.lift_jvp?
# Adding another layer of lu.transformation was tricky, though maybe doable.
def lift_jvp(num_errs, num_consts, jvp_jaxpr_thunk):
  """Build a JVP function that runs the user JVP jaxpr and forwards error data.

  Args:
    num_errs: number of flattened error leaves among the inputs.
    num_consts: number of const inputs.
    jvp_jaxpr_thunk: called with one bool per (non-const, non-error) input
      tangent, True where that tangent is a `SymbolicZero`; returns
      ``(jvp_jaxpr, jvp_consts, out_zeros)``.

  Returns:
    An `lu.WrappedFun` (required by the `lu` transformations it is passed to).
    It takes ``(*primals, *tangents)`` with both halves of equal length and
    returns ``[*primal_errs, *out_primals, *tangent_errs, *out_tangents]``.
  """
  @lu.wrap_init
  def jvp(*xs):
    n, ragged = divmod(len(xs), 2)
    assert not ragged
    # NOTE(review): each half is sliced as [consts, errs, args]. The
    # custom_jvp_call_rule in this file binds `*err_vals, *in_vals`, so the
    # forwarded error leaves below are only correct if that matches (e.g. when
    # num_consts == 0) -- confirm.
    primals, tangents = xs[num_consts+num_errs:n], xs[n+num_consts+num_errs:]
    zeros = [type(t) is SymbolicZero for t in tangents]
    jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_thunk(*zeros)
    # The JVP jaxpr only takes the tangents that are not symbolic zeros.
    nonzero_tangents = [t for t in tangents if type(t) is not SymbolicZero]
    out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
    out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
    # Re-insert symbolic zeros for the outputs flagged in `out_zeros`.
    nz_out_tangents_ = iter(nz_out_tangents)
    out_tangents = [SymbolicZero(core.get_aval(p).at_least_vspace())
                    if z else next(nz_out_tangents_)
                    for p, z in zip(out_primals, out_zeros)]
    assert next(nz_out_tangents_, None) is None
    primal_errs = xs[num_consts:num_consts+num_errs]
    tangent_errs = xs[n+num_consts:n+num_consts+num_errs]
    return [*primal_errs, *out_primals, *tangent_errs, *out_tangents]
  return jvp
def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr,
                               fwd_jaxpr_thunk, num_consts, bwd, out_trees,
                               **params):
  """Checkify rule for `custom_vjp_call_jaxpr_p`.

  The primal function is checkified. The forward rule ignores the error inputs
  and is not checkified. The backward rule returns `None` for each error input.

  Args:
    in_err: the error value flowing into this equation.
    enabled_errors: the set of enabled error categories.
    *in_vals: the equation's flat operands; the first `num_consts` of them are
      skipped before evaluating the forward jaxpr.
    fun_jaxpr: closed jaxpr of the primal function.
    fwd_jaxpr_thunk: returns ``(fwd_jaxpr, fwd_consts)`` given the zero flags.
    num_consts: number of const operands.
    bwd: the user backward rule.
    out_trees: forwarded unchanged to `custom_vjp_call_p`.
    **params: remaining primitive params (presumably ``symbolic_zeros``),
      forwarded unchanged to `custom_vjp_call_p`.

  Returns:
    An ``(out_err, out_vals)`` pair.
  """
  err_vals, err_tree = jtu.tree_flatten(in_err)
  num_errs = err_tree.num_leaves
  checkified_fun = lu.wrap_init(
      functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr,
                        fun_jaxpr.consts, enabled_errors, err_tree))
  checkified_fun, fun_metadata = _flatten_and_get_error_metadata_thunk(
      checkified_fun)

  @lu.wrap_init
  def checkified_fwd(*args):
    # TODO(lenamartens, sharadmv): why not checkify here?
    # Arguments arrive interleaved as (x0, zero0, x1, zero1, ...).
    xs, zeros = args[::2], args[1::2]
    # Drop the error leaves; the user's forward rule doesn't expect them.
    xs, zeros = xs[num_errs:], zeros[num_errs:]
    fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros)
    xs_without_consts = xs[num_consts:]
    return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts)

  # Prepend a `None` cotangent for every error input.
  bwd_ = lambda *args: (*(None,)*num_errs, *bwd(*args))
  checkified_fwd, fwd_out_tree = flatten_fun_output(checkified_fwd)
  all_outs = custom_derivatives.custom_vjp_call_p.bind(
      checkified_fun, checkified_fwd, bwd_, *err_vals, *in_vals,
      out_trees=out_trees, **params)
  fst, out_metadata = lu.merge_linear_aux(fun_metadata, fwd_out_tree)
  if fst:
    # The checkified primal function's metadata is available.
    err_and_out_tree, _ = out_metadata
    out_err, out_vals = tree_unflatten(err_and_out_tree, all_outs)
  else:
    # The forward rule emits no error leaves: forward the input error.
    out_err, out_vals = in_err, all_outs
  return out_err, out_vals

error_checks[custom_derivatives.custom_vjp_call_jaxpr_p] = custom_vjp_call_jaxpr_rule
def check_discharge_rule(error, enabled_errors, *args, err_tree, debug):
  """Checkify rule for `check_p`: fold enabled errors into the carried error.

  Args:
    error: the error value accumulated so far.
    enabled_errors: the set of enabled error categories.
    *args: flattened leaves of the error carried by the `check_p` equation.
    err_tree: treedef used to rebuild that error from `args`.
    debug: unused by this rule.

  Returns:
    ``(discharged_error, [])``: the updated error and an empty output list.
  """
  del debug
  new_error = tree_unflatten(err_tree, args)
  # Split up new_error into error to be functionalized if it's included in
  # enabled_errors (=discharged_error) and an error to be defunctionalized if
  # it's not included (=recharged_error)
  discharged_error = error
  recharged_error = init_error
  for error_effect, pred in new_error._pred.items():
    code = new_error._code[error_effect]
    payload = new_error._payload[error_effect]
    # Metadata is merged in below, so pass an empty dict here.
    if error_effect.error_type in enabled_errors:
      discharged_error = update_error(discharged_error, pred, code, {}, payload,
                                      error_effect)
    else:
      recharged_error = update_error(recharged_error, pred, code, {}, payload,
                                     error_effect)

  discharged_error = discharged_error._replace(
      _metadata={**new_error._metadata, **discharged_error._metadata})
  recharged_error = recharged_error._replace(_metadata=new_error._metadata)
  # TODO(lenamartens): we actually need to recharge, but this would be a
  # breaking API change so leaving for a follow-up.
  # check_error(recharged_error)
  return discharged_error, []

error_checks[check_p] = check_discharge_rule
## checkify public api

# Predefined sets of error categories, accepted by the ``errors`` argument of
# `checkify`. They are frozensets, so they can be combined with set operations.
user_checks = frozenset((FailedCheckError,))
nan_checks = frozenset((NaNError,))
index_checks = frozenset((OOBError,))
div_checks = frozenset((DivisionByZeroError,))
# Composite sets built from the ones above.
float_checks = nan_checks.union(div_checks)
automatic_checks = float_checks.union(index_checks)
all_checks = automatic_checks.union(user_checks)
def checkify(f: Callable[..., Out],
             errors: FrozenSet[ErrorCategory] = user_checks
             ) -> Callable[..., Tuple[Error, Out]]:
  """Functionalize `check` calls in `f`, and optionally add run-time error checks.

  Run-time errors are either user-added :func:`~check` assertions, or
  automatically added checks like NaN checks, depending on the ``errors``
  argument.

  The returned function will return an Error object `err` along with the output
  of the original function. ``err.get()`` will either return ``None`` (if no
  error occurred) or a string containing an error message. This error message
  will correspond to the first error which occurred. ``err.throw()`` will raise
  a ValueError with the error message if an error occurred.

  By default only user-added :func:`~check` assertions are enabled. You can
  enable automatic checks through the ``errors`` argument.

  The automatic check sets which can be enabled, and when an error is generated:
    - ``user_checks``: a :func:`~check` evaluated to False.
    - ``nan_checks``: a floating-point operation generated a NaN value
      as output.
    - ``div_checks``: a division by zero.
    - ``index_checks``: an index was out-of-bounds.

  Multiple categories can be enabled together by passing in an error `Set` (eg.
  ``errors=nan_checks``). Multiple sets can be re-combined (eg.
  ``errors=float_checks|user_checks``)

  Args:
    f: Callable which can contain user checks (see :func:`~check`).
    errors: A set of ErrorCategory values which defines the set of enabled
      checks. By default only explicit ``checks`` are enabled
      (``user_checks``). You can also for example enable NAN and
      DIV errors by passing the ``float_checks`` set, or for
      example combine multiple sets through set operations
      (``float_checks | user_checks``)

  Returns:
    A function which accepts the same arguments as ``f`` and returns as output
    a pair where the first element is an ``Error`` value, representing the first
    failed :func:`~check`, and the second element is the original output of
    ``f``.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>>
    >>> @jax.jit
    ... def f(x):
    ...   y = jnp.sin(x)
    ...   return x+y
    >>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf)
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin
  """
  @traceback_util.api_boundary
  def checked_fun(*args, **kwargs):
    # close over all arguments so they're not turned into abstract values.
    in_tree = jtu.tree_structure(((), {}))
    closed_f = lambda: f(*args, **kwargs)
    # stage:
    fun_, out_tree = flatten_fun(lu.wrap_init(closed_f), in_tree)
    debug = pe.debug_info(closed_f, in_tree, out_tree, False, 'checkify')
    jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
    jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
    # checkify:
    error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts)
    return error, jtu.tree_unflatten(out_tree(), out_flat)
  return checked_fun
def check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
  """Check a predicate, add an error with msg if predicate is False.

  This is an effectful operation, and can't be staged (jitted/scanned/...).
  Before staging a function with checks, :func:`~checkify` it!

  Args:
    pred: if False, a FailedCheckError error is added.
    msg: error message if error is added. Can be a format string.
    fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
      `msg`, eg.:
      ``check(.., "check failed on values {} and {named_arg}", x, named_arg=y)``
      Note that these arguments can be traced values allowing you to add
      run-time values to the error message.
      Note that tracking these run-time arrays will increase your memory usage,
      even if no error happens.

  Raises:
    TypeError: if ``pred`` is not a scalar boolean, or a formatting argument
      leaf is not an array.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>> def f(x):
    ...   checkify.check(x>0, "{x} needs to be positive!", x=x)
    ...   return 1/x
    >>> checked_f = checkify.checkify(f)
    >>> err, out = jax.jit(checked_f)(-3.)
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: -3. needs to be positive!
  """
  _check(pred, msg, False, *fmt_args, **fmt_kwargs)
def _check(pred, msg, debug, *fmt_args, **fmt_kwargs):
  """Shared implementation of `check` and `debug_check`.

  Args:
    pred: scalar boolean predicate; the check fails when it is False.
    msg: error message, possibly a format string.
    debug: True for `debug_check`, False for `check`. Selects the name used in
      the scalar-pred TypeError and is forwarded to `_check_error`.
    *fmt_args, **fmt_kwargs: formatting arguments; every pytree leaf must be a
      JAX or NumPy array.

  Raises:
    TypeError: if `pred` is not a scalar boolean, or a formatting leaf is not
      an array.
  """
  if not is_scalar_pred(pred):
    api_name = 'debug_check' if debug else 'check'
    raise TypeError(f'{api_name} takes a scalar pred as argument, got {pred}')
  for leaf in jtu.tree_leaves((fmt_args, fmt_kwargs)):
    if isinstance(leaf, (Array, np.ndarray)):
      continue
    raise TypeError('Formatting arguments to checkify.check need to be '
                    'PyTrees of arrays, but got '
                    f'{repr(leaf)} of type {type(leaf)}.')
  failure = FailedCheckError(summary(), msg, *fmt_args, **fmt_kwargs)
  # The check fails where `pred` is False, so assert on its negation.
  err = assert_func(init_error, jnp.logical_not(pred), failure)
  _check_error(err, debug=debug)
def _check_error(error, *, debug=False):
  """Bind `check_p` on the flattened leaves of `error`.

  If any predicate in `error` has a non-empty shape, the error is first passed
  through `_reduce_any_error` (presumably reducing it to scalar predicates;
  TODO confirm).

  Returns:
    The result of `check_p.bind`.
  """
  if any(np.shape(p) for p in error._pred.values()):
    error = _reduce_any_error(error)
  leaves, treedef = tree_flatten(error)
  return check_p.bind(*leaves, err_tree=treedef, debug=debug)
def is_scalar_pred(pred) -> bool:
  """Return whether `pred` is a Python bool or a 0-d boolean JAX array."""
  if isinstance(pred, bool):
    return True
  if not isinstance(pred, Array):
    return False
  return pred.shape == () and pred.dtype == jnp.dtype('bool')
def debug_check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
  """Check a predicate when running under checkify, otherwise is a no-op.

  A `debug_check` will only be run if it is transformed by :func:`~checkify`,
  otherwise the check will be dropped.

  Args:
    pred: if False, a FailedCheckError error is added.
    msg: error message if error is added. Can be a format string.
    fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
      `msg`, eg.:
      ``debug_check(.., "check failed on values {} and {named}", x, named=y)``
      Note that these arguments can be traced values allowing you to add
      run-time values to the error message.
      Note that tracking these run-time arrays will increase your memory usage,
      even if no error happens.

  Raises:
    TypeError: if ``pred`` is not a scalar boolean, or a formatting argument
      leaf is not an array.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>> def f(x):
    ...   checkify.debug_check(x!=0, "cannot be zero!")
    ...   return x
    >>> _ = f(0)  # running without checkify means no debug_check is run.
    >>> checked_f = checkify.checkify(f)
    >>> err, out = jax.jit(checked_f)(0)  # running with checkify runs debug_check.
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: cannot be zero!
  """
  _check(pred, msg, True, *fmt_args, **fmt_kwargs)
def check_error(error: Error) -> None:
  """Raise an Exception if ``error`` represents a failure. Functionalized by :func:`~checkify`.

  The semantics of this function are equivalent to:

  >>> def check_error(err: Error) -> None:
  ...   err.throw()  # can raise ValueError

  But unlike that implementation, ``check_error`` can be functionalized using
  the :func:`~checkify` transformation.

  This function is similar to :func:`~check` but with a different signature: whereas
  :func:`~check` takes as arguments a boolean predicate and a new error message
  string, this function takes an ``Error`` value as argument. Both :func:`~check`
  and this function raise a Python Exception on failure (a side-effect), and
  thus cannot be staged out by :func:`~jax.jit`, :func:`~jax.pmap`,
  :func:`~jax.lax.scan`, etc. Both also can
  be functionalized by using :func:`~checkify`.

  But unlike :func:`~check`, this function is like a direct inverse of
  :func:`~checkify`:
  whereas :func:`~checkify` takes as input a function which
  can raise a Python
  Exception and produces a new function without that effect but which produces
  an ``Error`` value as output, this ``check_error`` function can accept an
  ``Error`` value as input and can produce the side-effect of raising an
  Exception. That is, while :func:`~checkify` goes from
  functionalizable Exception
  effect to error value, this ``check_error`` goes from error value to
  functionalizable Exception effect.

  ``check_error`` is useful when you want to turn checks represented by an
  ``Error`` value (produced by functionalizing ``checks`` via
  :func:`~checkify`) back into Python Exceptions.

  Args:
    error: Error to check.

  Raises:
    ValueError: if ``error`` is not an ``Error`` instance.

  For example, you might want to functionalize part of your program through
  checkify, stage out your functionalized code through :func:`~jax.jit`, then
  re-inject your error value outside of the :func:`~jax.jit`:

  >>> import jax
  >>> from jax.experimental import checkify
  >>> def f(x):
  ...   checkify.check(x>0, "must be positive!")
  ...   return x
  >>> def with_inner_jit(x):
  ...   checked_f = checkify.checkify(f)
  ...   # a checkified function can be jitted
  ...   error, out = jax.jit(checked_f)(x)
  ...   checkify.check_error(error)
  ...   return out
  >>> _ = with_inner_jit(1)  # no failed check
  >>> with_inner_jit(-1)  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
    ...
  jax._src.checkify.JaxRuntimeError: must be positive!
  >>> # can re-checkify
  >>> error, _ = checkify.checkify(with_inner_jit)(-1)
  """
  if not isinstance(error, Error):
    raise ValueError('check_error takes an Error as argument, '
                     f'got type {type(error)} instead.')
  _check_error(error, debug=False)