Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/experimental/host_callback.py
2023-06-19 00:49:18 +02:00

2066 lines
84 KiB
Python

# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Primitives for calling Python functions on the host from JAX accelerator code.
**Experimental: please give feedback, and expect changes.**
This module introduces the host callback functions :func:`call`,
:func:`id_tap`, and :func:`id_print`, that send their arguments from the device
to the host and invoke user-defined Python functions on the host, optionally
returning results back to the device computation.
We show below how these functions can be used. We start with :func:`call`,
and we discuss examples of calling from JAX to arbitrary Python functions
on the CPU, e.g., to use NumPy CPU custom kernels. Then we
show uses of :func:`id_tap` and :func:`id_print`, which have the restriction
that they cannot return values from the host to the device.
These primitives are generally faster
because they are executed asynchronously with the device code.
In particular, they can be used to tap into and to debug JAX code.
Using :func:`call` to call a host function and return results to device
-----------------------------------------------------------------------
Use :func:`call` to invoke a computation on the host and return
NumPy arrays to the device computation.
Host computation is useful, e.g., when a device computation needs some data
that requires I/O on the host, or it needs a library that is available on the
host and you do not want to code it in JAX.
For example, eigen decomposition for general matrices in JAX does not work on TPU.
We can call the Numpy implementation from any JAX accelerator computation,
using a host computation::
# This function runs on the host
def host_eig(m: np.ndarray) -> np.ndarray:
return np.linalg.eigvals(m)
# This function is used in JAX
def device_fun(m):
# We send "m" to the host, asking it to call "host_eig" and return the result.
# We have to specify the result shape and dtype, either in the form of an
# example return value or any object that has `shape` and `dtype` attributes,
# e.g., a NumPy array or a `jax.ShapeDtypeStruct`.
return hcb.call(host_eig, m,
# Given an input of shape (..., d, d), eig output has shape (..., d)
result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype))
The :func:`call` function and the Python host function both take a single argument
and return a single result, but those can be pytrees. Note that we must tell
the :func:`call` what shape and dtype to expect from the host invocation, using
the ``result_shape`` keyword argument.
This is important because the device code is compiled with that expectation.
There will be an error raised at runtime if the actual invocation produces a
different result shape. In general, **such errors and also exceptions raised
by the host computation may be difficult to debug**. See the Debugging section
below.
This is a problem for :func:`call` but not for :func:`id_tap` because for the
latter the device code does not expect a returned value.
The :func:`call` API can be used inside a jit or pmap computation or inside
cond/scan/while control flow. When used inside :func:`jax.pmap`, there will be
separate calls to the host from each of the participating devices::
def host_sin(x, *, device):
# The ``device`` argument is passed due to ``call_with_device=True`` below.
print(f"Invoking host_sin with {x.shape} on {device}")
return np.sin(x)
# Use pmap to run the computation on two devices
jax.pmap(lambda x: hcb.call(host_sin, x,
result_shape=x,
# Ask that the `host_sin` function be passed `device=dev`
call_with_device=True))(
np.ones((2, 4), dtype=np.float32))
# prints (in arbitrary order)
# Invoking host_sin with (4,) on cpu:0
# Invoking host_sin with (4,) on cpu:1
Note that :func:`call` does not support any JAX transformations, but as we
show below one can make use of the
existing support for `Custom differentiation in JAX <https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html>`_.
Using :func:`id_tap` to call a Python function on the host, with no returned values
-----------------------------------------------------------------------------------
The :func:`id_tap` and :func:`id_print` are special cases of :func:`call`, when
you just want the side effects of your Python callback. These functions have
the advantage that once the arguments have been sent to the host, the device
computation can proceed without waiting for the Python callback to return.
For :func:`id_tap` you can specify your Python callback to be called, while
:func:`id_print` uses a built-in callback that prints the arguments to
`stdout` on the host.
The Python function passed
to :func:`id_tap` takes two positional arguments (the value tapped
from the device computation along with a ``transforms`` tuple,
described below). Optionally, the function may be passed a keyword argument
``device`` with the Device from which the value was tapped.
A few examples::
def host_func(arg, transforms):
...do something with arg...
# calls host_func(2x, []) on host
id_tap(host_func, 2 * x)
# calls host_func((2x, 3x), [])
id_tap(host_func, (2 * x, 3 * x)) # The argument can be a pytree
# calls host_func(2x, [], device=jax.devices()[0])
id_tap(host_func, 2 * x, tap_with_device=True) # Pass the device to the tap
# calls host_func(2x, [], what='activation')
id_tap(functools.partial(host_func, what='activation'), 2 * x)
# calls host_func(dict(x=x, y=y), what='data')
id_tap(lambda tap, transforms: host_func(tap, what='data'), dict(x=x, y=y))
The above examples can all be adapted to use :func:`id_print` instead, with
the difference that :func:`id_print` prints on the host the positional argument,
along with any additional kwargs and the automatic kwarg ``transforms``.
Using :func:`barrier_wait` to wait until all callbacks have executed
--------------------------------------------------------------------
If your Python callbacks have side-effects you may need to wait until the
computation has finished to ensure that the side-effects have been observed.
You can use the :func:`barrier_wait` function for that purpose::
accumulator = []
def host_log(arg, transforms):
# We just record the arguments in a list
accumulator.append(arg)
def device_fun(c):
id_tap(host_log, x)
id_tap(host_log, 2. * x)
jax.jit(device_fun)(1.)
jax.jit(device_fun)(1.)
# At this point, we have started two computations, each with two
# taps, but they may not have yet executed.
barrier_wait()
# Now we know that all the computations started before `barrier_wait`
# on all devices, have finished, and all the callbacks have finished
# executing.
Note that :func:`barrier_wait` will start one
tiny computation with one tap on each of the `jax.local_devices()` and
will wait for all these taps to be received.
An alternative to using :func:`barrier_wait` is to just wait for the end
of the computation, if all the callbacks are :func:`call`::
accumulator = p[]
def host_log(arg):
# We just record the arguments in a list
accumulator.append(arg)
return 0. # return something
def device_fun(c):
y = call(host_log, x, result_shape=jax.ShapeDtypeStruct((), np.float32))
z = call(host_log, 2. * x, result_shape=jax.ShapeDtypeStruct((), np.float32))
return y + z # return something that uses both results
res1 = jax.jit(device_fun)(1.)
res2 = jax.jit(device_fun)(1.)
res1.block_until_ready()
res2.block_until_ready()
Behavior under parallelization transformations
----------------------------------------------
In presence of :func:`jax.pmap` the code will run on multiple devices and
each device will tap its values independently.
It may be helpful to use the ``tap_with_device`` option for :func:`id_print`
or :func:`id_tap`, so that you see which device is sending which data::
jax.pmap(power3, devices=jax.local_devices()[:2])(np.array([3., 4.])
# device=cpu:0 what=x,x^2: (3., 9.) # from the first device
# device=cpu:1 what=x,x^2: (4., 16.) # from the second device
When using :func:`jax.pmap` with multiple devices on multiple hosts, every
host will receive callbacks from all of its local devices, with an operand
that corresponds to each device slice. For a
:func:`call`, the callback must return to each device only the slice of the
result that pertains to the corresponding device.
When using the experimental :func:`pjit.pjit` the code will run on multiple
devices on different shards of the input. The current implementation of
host callbacks will ensure that a single device will collect and outfeed
the entire operand, in a single callback. The callback function is supposed
to return the entire array, which will then be sent in a single infeed to the
same device that issued the outfeed. This device is then responsible for
sending the required shards to the other devices::
with jax.sharding.Mesh(jax.local_devices()[:2], ["d"]):
pjit.pjit(power3, in_shardings=(P("d"),),
out_shardings=(P("d"),))(np.array([3., 4.]))
# device=TPU:0 what=x,x^2: ( [3., 4.],
# [9., 16.] )
Note that the collection of the operand on one device may result in OOM if
the operand was sharded across devices.
When using :func:`pjit.pjit` with multiple devices on multiple hosts, only
the host for the device 0 (w.r.t. the mesh) will receive the callback, with
the operand collected
from all participating devices on all hosts. For a :func:`call`, the callback
must return the entire array for all devices on all hosts.
Behavior under JAX autodiff transformations
-------------------------------------------
When used under a JAX autodiff transformation, the host callback functions
operate on the primal values only. Consider the following example::
def power3(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2")
return y * x
power3(3.)
# what: x,x^2 : (3., 9.)
(You can see these examples tested in `host_callback_test.HostCallbackTapTest.test_tap_transforms`.)
When used under :func:`jax.jvp` there will be one callback with the primal
values only::
jax.jvp(power3, (3.,), (0.1,))
# what: x,x^2 : (3., 9.)
Similarly for :func:`jax.grad`, we get a callback from the forward computation
only::
jax.grad(power3)(3.)
# what: x,x^2 : (3., 9.)
If you want to invoke the callback on the tangents during a :func:`jax.jvp`,
you can use a custom_jvp. For example, you can define a function that does
nothing interesting except that its custom_jvp will print the tangents::
@jax.custom_jvp
def print_tangents(arg):
return None
@print_tangents.defjvp
def print_tangents_jvp(primals, tangents):
arg_dot, = tangents
hcb.id_print(arg_dot, what="tangents")
return primals, tangents
Then you use this function in the places where you want to tap the tangents::
def power3_with_tangents(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2")
print_tangents((x, y))
return y * x
jax.jvp(power3_with_tangents, (3.,), (0.1,))
# what: x,x^2 : (3., 9.)
# what: tangents : (0.1, 0.6)
You can do a similar thing for the cotangents during :func:`jax.grad`. This
time you must be careful to use in the rest of the computation the values whose
cotangents you want to tap. Hence we make the ``print_cotangents`` return
its argument::
@jax.custom_vjp
def print_cotangents(arg):
# Must return the argument for which we want the cotangent.
return arg
# f_fwd: a -> (b, residual)
def print_cotangents_fwd(arg):
return print_cotangents(arg), None
# f_bwd: (residual, CT b) -> [CT a]
def print_cotangents_bwd(residual, ct_b):
hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream)
return ct_b,
print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd)
def power3_with_cotangents(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
(x1, y1) = print_cotangents((x, y))
# Must use the output of print_cotangents
return y1 * x1
jax.grad(power3_with_cotangents)(3.)
# what: x,x^2 : (3., 9.)
# what: cotangents : (9., 3.)
If you use :func:`ad_checkpoint.checkpoint` to rematerialize the residuals
for the backward pass, then the callbacks from the primal computation will
be called twice::
jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)
# what: x,x^2 : (3., 9.)
# what: x,x^2 : (27., 729.)
# what: x,x^2 : (3., 9.)
The callbacks are, in order from: the primal computation of the inner ``power3``,
the primal computation of the outer ``power3``, and the rematerialization
of the residuals for the inner ``power3``.
Behavior under jax.vmap
-----------------------
The host callback functions :func:`id_print` and :func:`id_tap` support the
vectorization transformation :func:`jax.vmap`.
For :func:`jax.vmap` the arguments to the callback are batched,
and the callback function is
passed an additional special ``transforms`` containing a list of transformation descriptors
in the form ``("batch", {"batch_dims": ...})``, where ``...``` denotes the
batched dimensions for the tapped values (one entry per argument, `
`None`` denotes an argument that was broadcast).
jax.vmap(power3)(np.array([2., 3.]))
# transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : ([2., 3.], [4., 9.])
See documentation for :func:`id_tap`, :func:`id_print`, and :func:`call`.
For more usage example, see tests/host_callback_test.py.
Using :func:`call` to call a TensorFlow function, with reverse-mode autodiff support
------------------------------------------------------------------------------------
Another possible use for host computation is to invoke a library written for
another framework, such as TensorFlow.
In this case it becomes interesting to support JAX autodiff for host callbacks
by deferring to the autodiff mechanism in TensorFlow,
using the :func:`jax.custom_vjp` mechanism.
This is relatively easy to do, once one understands both the JAX custom VJP
and the TensorFlow autodiff mechanisms.
The code for how this can be done is shown in the ``call_tf_full_ad``
function in `host_callback_to_tf_test.py <https://github.com/google/jax/blob/main/tests/host_callback_to_tf_test.py>`_.
This example supports arbitrary higher-order differentiation as well.
Note that if you just want to call TensorFlow functions from JAX, you can also
use the `jax2tf.call_tf function <https://github.com/google/jax/blob/main/jax/experimental/jax2tf/call_tf.py>`_.
Using :func:`call` to call a JAX function on another device, with reverse-mode autodiff support
------------------------------------------------------------------------------------------------
It should not be surprising that we can use host computation to invoke a JAX
computation on another device. The arguments are sent from the accelerator to
the host, and then to the outside device on which the JAX host
computation will run, and then the results are sent back to the original accelerator.
The code for how this can be done is shown in the ``call_jax_other_device function``
in `host_callback_test.py <https://github.com/google/jax/blob/main/tests/host_callback_test.py>`_.
Low-level details and debugging
-------------------------------
The host callback functions will be executed for each device in the order in
which the send operations were performed on the device.
The host callback functions for multiple devices may be interleaved.
The data from the devices is received by separate threads managed by the JAX
runtime (one thread per device). The runtime maintains a buffer of
configurable size (see the flag ``--jax_host_callback_max_queue_byte_size``).
When the buffer is full, all the receiving threads are paused
which eventually pauses the computation on devices. The runtime has one
additional thread for each device to invoke the Python user functions with the
received data. If the processing of the callbacks is slow, it may actually
lead to the runtime buffer filling up, and eventually pausing the computation
on the devices when they need to send something.
For more details on the outfeed receiver runtime mechanism see
`runtime code
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/outfeed_receiver.cc>`_.
In order to pause the execution until all data from computations already
started on devices has arrived and has been processed, use :func:`barrier_wait`.
Exceptions from the user-defined callback functions are logged along with their
stack traces, but the receiving threads are not stopped. Instead the last
exception is recorded and the subsequent :func:`barrier_wait` will
raise :exc:`CallbackException` if any exception had occurred
in one of the tap functions. This exception will include the text and the
stack trace of the last exception encountered.
One further complication arises for callback functions that must return
results to the call origin device, such as :func:`call()`. This is handled
differently on CPU/GPU devices compared to TPU devices.
On CPU/GPU devices, in order to avoid the device computation
being stuck waiting for a result that will never arrive, in case of any
error during the processing of the callback (whether raised by the user-code
itself or due to a mismatch of the returned value and the expected return_shape)
we send the device a "fake" result of shape ``int8[12345]``.
This will make the device
computation abort because the received data is different than the one that
it expects. On CPU the runtime will crash with a distinctive error message:
```
Check failed: buffer->length() == buffer_length (12345 vs. ...)
```
On GPU, the failure is more user-friendly and will be surfaced to the Python
program as:
```
RET_CHECK failure ... Mismatch between infeed source buffer shape s8[12345] ...
```
To debug the underlying cause for these messages, see the Debugging section.
On TPU devices, there is currently no shape check for infeed, so we take the
safer route of not sending this fake result in case of errors. This means
that the computation will hang, and no exception will be raised (but any
exceptions in the callback functions will still appear in the logs).
The current implementation uses the outfeed mechanism provided by XLA. The
mechanism itself is quite primitive in the sense that a receiver must know
exactly the shape of each incoming packet, and how many packets are expected.
This makes it hard to use for multiple kinds of data in the same computation,
and it is practically impossible to use it under conditionals or in loops
of non-constant iteration count. Furthermore, code that uses the outfeed
mechanism directly cannot be transformed by JAX. All these limitations are
addressed by the host callback functions. The tapping API introduced here
makes it easy to share the outfeed mechanism for multiple purposes, while
supporting all transformations.
**Note that after you have used the host callback functions, you cannot
use lax.outfeed directly**. You may want to :func:`stop_outfeed_receiver`
if you later need to use lax.outfeed.
Since the actual calls to your callback functions are made from the C++
receiver, it may be hard to debug the calls. In particular, the stack trace
will not include the calling code. You can use the flag
``jax_host_callback_inline`` (or the environment variable
``JAX_HOST_CALLBACK_INLINE``) to ensure that the calls to the callbacks are
inlined. This works only if the calls are outside a staging context
(:func:`~jax.jit` or a control-flow primitive).
The C++ `receiver
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/outfeed_receiver.cc>`_
is started automatically on the first call to :func:`id_tap`. In order to stop
it properly, upon start an ``atexit`` handler is registered to call
:func:`barrier_wait` with the logging name "at_exit".
There are a few environment variables that you can use to turn on logging
for the C++ outfeed `receiver backend
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/outfeed_receiver.cc>`_.
* ``TF_CPP_MIN_LOG_LEVEL=0``: will turn on INFO logging, needed for all below.
* ``TF_CPP_MIN_VLOG_LEVEL=3``: will make all VLOG logging up to level 3 behave
like INFO logs. This may be too much, but you will see which modules are
logging relevant info, and then you can select which modules to log from.
* ``TF_CPP_VMODULE=<module_name>=3`` (the module name can be either C++ or
Python, without the extension).
You should also use the ``--verbosity=2`` flag so that you see the logs
from Python.
For example, you can try to enable logging in the ``host_callback`` module:
``TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple``
If you want to enable logging in lower-level implementation modules try:
``TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple``
(For bazel tests use --test_arg=--vmodule=...
Still to do:
* More performance tests.
* Explore implementation with outside compilation for TPU.
* Explore implementation with XLA CustomCall for CPU and GPU.
"""
import atexit
import functools
import itertools
import logging
import math
import threading
import traceback
from typing import (Any, Callable, Dict, List, Optional, Sequence,
Tuple, cast)
import warnings
from jax._src import api
from jax._src import core
from jax import config
from jax import custom_derivatives
from jax._src import dtypes
from jax import lax
from jax.experimental import pjit
from jax._src.interpreters import ad, batching, pxla
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src import ad_checkpoint
from jax._src import dispatch
from jax._src import pretty_printer as pp
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import util
from jax._src.lib import pytree
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
from jax._src.lib.mlir.dialects import hlo
import numpy as np
FLAGS = config.FLAGS
logger = logging.getLogger(__name__)
def _inline_host_callback() -> bool:
return FLAGS.jax_host_callback_inline
def _use_outfeed(platform: str) -> bool:
return (platform in ("tpu", "gpu", "cuda", "rocm") or FLAGS.jax_host_callback_outfeed)
def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend):
"""Should be called whenever outfeed (or infeed) will be used."""
if xb.using_pjrt_c_api(backend):
raise NotImplementedError(
"host_callback functionality isn't supported with the new Cloud TPU "
"runtime. See https://jax.readthedocs.io/en/latest/debugging/index.html"
" and "
"https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html"
" for alternatives. Please file a feature request at "
"https://github.com/google/jax/issues if none of the alternatives are "
"sufficent.")
xops = xla_client._xla.ops
XlaOp = xla_client.XlaOp
XlaShape = xla_client.Shape
XlaBuilder = xla_client.XlaBuilder
XlaDevice = xla_client.Device
XlaLocalClient = xla_client.Client
DType = Any
def id_tap(tap_func,
arg,
*,
result=None,
tap_with_device=False,
device_index=0,
**kwargs):
"""Host-callback tap primitive, like identity function with a call to ``tap_func``.
**Experimental: please give feedback, and expect changes!**
``id_tap`` behaves semantically like the identity function but has the
side-effect that a user-defined Python function is called with the runtime
value of the argument.
Args:
tap_func: tap function to call like ``tap_func(arg, transforms)``, with
``arg`` as described below and where ``transforms`` is the sequence of
applied JAX transformations in the form ``(name, params)``. If the
`tap_with_device` optional argument is True, then the invocation also
includes the device from which the value is tapped as a keyword argument:
``tap_func(arg, transforms, device=dev)``.
arg: the argument passed to the tap function, can be a pytree of JAX
types.
result: if given, specifies the return value of ``id_tap``. This value is
not passed to the tap function, and in fact is not sent from the device to
the host. If the ``result`` parameter is not specified then the return
value of ``id_tap`` is ``arg``.
tap_with_device: if True then the tap function is invoked with the
device from which the tap originates as a keyword argument.
device_index: specifies from which device the tap function is invoked in a
SPMD program. Works only when using the outfeed implementation mechanism,
i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
Returns:
``arg``, or ``result`` if given.
The order of execution is by data dependency: after all the arguments and
the value of ``result`` if present, are computed and before the returned
value is used. At least one of the returned values of ``id_tap`` must be
used in the rest of the computation, or else this operation has no effect.
Tapping works even for code executed on accelerators and even for code under
JAX transformations.
For more details see the :mod:`jax.experimental.host_callback` module documentation.
"""
if kwargs:
msg = (
"Support for **kwargs in ``id_tap`` has been removed. Instead, "
"pre-apply keyword arguments, either by using a closure or by passing "
"``functools.partial(tap_func, **kwargs)``.")
raise TypeError(msg)
if FLAGS.jax_host_callback_ad_transforms:
warnings.warn('The flag jax_host_callback_ad_transforms is for temporary '
'backwards compatibility mode. This flag, and the behavior '
'it enabled will be removed soon.',
FutureWarning)
if result is not None:
flat_results, result_treedef = pytree.flatten(result)
for r in flat_results:
dispatch.check_arg(r)
call_res = _call(
tap_func,
arg,
call_with_device=tap_with_device,
result_shape=None,
identity=True,
device_index=device_index)
if result is not None:
# Return the results, but add a dependency on the call, to ensure it
# is kept in the graph.
if FLAGS.jax_host_callback_ad_transforms:
call_flat_results, _ = pytree.flatten(call_res)
if call_flat_results:
call_flat_results = [id_tap_dep_p.bind(r, call_flat_results[0])
for r in flat_results]
else:
call_flat_results = flat_results
return result_treedef.unflatten(call_flat_results)
else:
return result
else:
return call_res
def id_print(arg,
*,
result=None,
tap_with_device=False,
device_index=0,
output_stream=None,
threshold=None,
**kwargs):
"""Like :func:`id_tap` with a printing tap function.
**Experimental: please give feedback, and expect changes!**
On each invocation of the printing tap, the ``kwargs`` if present
will be printed first (sorted by keys). Then arg will be printed,
with the arrays stringified with ``numpy.array2string``.
See the :func:`id_tap` documentation.
Additional keyword arguments:
* ``tap_with_device`` if True, will print also the device from which
the value originates.
* ``output_stream`` if given then it will be used instead of the
built-in ``print``. The string will be passed as
``output_stream.write(s)``.
* ``threshold`` is passed to ``numpy.array2string``.
For more details see the :mod:`jax.experimental.host_callback` module documentation.
"""
printer = functools.partial(_print_tap_func,
output_stream=output_stream,
threshold=threshold, **kwargs)
return id_tap(
printer,
arg,
result=result,
tap_with_device=tap_with_device,
device_index=device_index)
def call(callback_func: Callable, arg, *,
result_shape=None,
call_with_device=False,
device_index=0):
"""Make a call to the host, and expect a result.
**Experimental: please give feedback, and expect changes!**
Args:
callback_func: The Python function to invoke on the host as
``callback_func(arg)``. If the ``call_with_device`` optional argument is True,
then the invocation also includes the ``device`` kwarg with the device
from which the call originates: ``callback_func(arg, device=dev)``. This function
must return a pytree of numpy ndarrays.
arg: the argument passed to the callback function, can be a pytree of JAX
types.
result_shape: a value that describes the expected shape and dtype of the
result. This can be a numeric scalar, from which a shape and dtype are
obtained, or an object that has ``.shape`` and ``.dtype`` attributes.
If the result of the callback is a pytree, then ``result_shape`` should
also be a pytree with the same structure. In particular, ``result_shape``
can be `()` or `None` if the function does not have any results.
The device code containing ``call`` is compiled with the expected result shape and dtype,
and an error will be raised at runtime if the actual ``callback_func``
invocation returns a different kind of result.
call_with_device: if True then the callback function is invoked with the
device from which the call originates as a keyword argument.
device_index: specifies from which device the tap function is invoked in a
SPMD program. Works only when using the outfeed implementation mechanism,
i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
Returns:
the result of the ``callback_func`` invocation.
For more details see the :mod:`jax.experimental.host_callback` module documentation.
"""
return _call(callback_func, arg, result_shape=result_shape,
call_with_device=call_with_device, identity=False,
device_index=device_index)
# We need the wrapper function to have hash and equality defined since it is
# used as a primitive keyword argument, and we want a compilation cache hit if
# the user uses the same function twice.
class _CallbackWrapper:
def __init__(self, callback_func, identity, call_with_device):
self.callback_func = callback_func
self.identity = identity
self.call_with_device = call_with_device
def __hash__(self):
return hash((self.callback_func, self.identity, self.call_with_device))
def __eq__(self, other):
return (self.callback_func == other.callback_func and
self.identity == other.identity and
self.call_with_device == other.call_with_device)
def __call__(self, arg, device, transforms):
if self.identity:
# For id_tap, we pass the transforms, for backwards compatibility
if self.call_with_device:
return self.callback_func(arg, transforms, device=device)
else:
return self.callback_func(arg, transforms)
else:
if self.call_with_device:
return self.callback_func(arg, device=device)
else:
return self.callback_func(arg)
# Helper function to implement both `call` and `id_tap`. The two cases are
# differentiated by the `identity` flag.
def _call(callback_func: Callable,
arg,
*,
result_shape=None,
call_with_device=False,
device_index=0,
identity=False):
# Lazy initialization
_initialize_outfeed_receiver(
max_callback_queue_size_bytes=FLAGS.jax_host_callback_max_queue_byte_size)
api.check_callable(callback_func)
flat_args, arg_treedef = pytree.flatten(arg)
for arg in flat_args:
dispatch.check_arg(arg)
# See definition of outside_call_p for what parameters it takes
params: Dict[str, Any] = {}
# TODO: wrap function
params["callback"] = _CallbackWrapper(callback_func, identity,
call_with_device)
params["identity"] = identity
params["arg_treedef"] = arg_treedef
params["device_index"] = device_index
if not identity:
# Turn abstract values into ShapesDtypeStruct
flat_results_shape, result_treedef = pytree.flatten(result_shape)
try:
flat_results_aval = [core.ShapedArray(np.shape(r), dtypes.dtype(r, canonicalize=True))
for r in flat_results_shape]
except Exception:
msg = ("result_shape should be a pytree of values with structure "
"matching the expected result of the callback function. The "
"values must be either numeric scalars, or must have 'shape' and "
f"'dtype' attributes. Got {result_shape}")
raise ValueError(msg)
params["result_treedef"] = result_treedef
params["flat_results_aval"] = tuple(flat_results_aval)
flat_results = outside_call_p.bind(*flat_args, **params)
return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results)
# We need the lock for when we use the CustomCall implementation of callbacks.
# The outfeed implementation is driven by a single thread from C++.
_print_tap_lock = threading.Lock()
def _print_tap_func(
arg, transforms, *, device=None,
output_stream=None, threshold=1024, **kwargs):
"""The consumer for id_print.
We provide this as a simple tapping function for printing.
This is **experimental** and may not want to add many features to it;
it should be easy for the user to roll their own printing function.
Args:
device: the device from which the value originates (only if
``tap_with_device`` was used for :func:`id_print`).
output_stream: a function whose `write` method is called with the strings to
be output.
threshold: the value of numpy.array2string threshold parameter.
**kwargs: all other keyword args are printed before printing `arg`.
"""
def emit_str(s: str):
if output_stream is not None:
output_stream.write(s + "\n")
else:
print(s)
if transforms:
kwargs['transforms'] = [(name, params) if params else name
for name, params in transforms]
if device is not None:
kwargs['device'] = device
kv_pairs = " ".join([
f"{k}: {v}" for k, v in sorted(kwargs.items())
])
def pp_val(arg) -> pp.Doc:
if isinstance(arg, tuple):
return pp.group(pp.concat([
pp.text("( "),
pp.nest(2, pp.join(pp.brk(), [pp_val(e) for e in arg])),
pp.text(" )")
]))
elif isinstance(arg, list):
return pp.group(pp.concat([
pp.text("[ "),
pp.nest(2, pp.join(pp.brk(), [pp_val(e) for e in arg])),
pp.text(" ]")
]))
elif isinstance(arg, dict):
return pp.group(pp.concat([
pp.text("{ "),
pp.nest(2, pp.join(pp.brk(), [
pp.text(f"{k}=") + pp_val(v) for k, v in sorted(arg.items())
])),
pp.text(" }")
]))
elif isinstance(arg, np.ndarray):
return pp.text(np.array2string(arg, threshold=threshold))
else:
return pp.text(str(arg))
with _print_tap_lock:
if kv_pairs:
emit_str(kv_pairs)
emit_str(str(pp_val(arg)))
def _values_to_avals(vals) -> Sequence[core.ShapedArray]:
return tuple(core.raise_to_shaped(core.get_aval(v)) for v in vals)
### The id_tap_dep primitive
# The id_tap_dep_p primitive is used to create a dependency of the result of
# id_tap on the actual tap operation. This is only needed when the
# id_tap function is used with the `result` parameter. This primitive acts
# as the identity operator on the first argument.
#
# For example, given `id_tap(f, (a, b), result=(r, s)`, we convert this to
#
# a1, b1 = outside_call_p(f, a, b)
# r1 = id_tap_dep_p(r, a1)
# s1 = id_tap_dep_p(s, a1)
#
# There are always two arguments and the result is equal to the first.
id_tap_dep_p = core.Primitive("id_tap_dep")
id_tap_dep_p.multiple_results = False
id_tap_dep_p.def_impl(lambda r, _: r)
xla.register_translation(id_tap_dep_p,
lambda ctx, avals_in, avals_out, a_res, a_tap: [a_res])
id_tap_dep_p.def_abstract_eval(lambda r_a, _: r_a)
def _id_tap_dep_jvp_rule(primals, tangents):
if FLAGS.jax_host_callback_ad_transforms:
assert False
tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
return (id_tap_dep_p.bind(primals[0], primals[1]),
id_tap_dep_p.bind(tangents_instantiated[0], tangents_instantiated[1]))
ad.primitive_jvps[id_tap_dep_p] = _id_tap_dep_jvp_rule
def _id_tap_dep_transpose_rule(cts, arg_res, arg_tap):
if FLAGS.jax_host_callback_ad_transforms:
assert False
if ad.is_undefined_primal(arg_res):
ct_res = _instantiate_zeros(cts, arg_res)
else:
ct_res = None
if ad.is_undefined_primal(arg_tap):
ct_tap = ad.Zero(arg_tap.aval)
else:
ct_tap = None
return (ct_res, ct_tap)
ad.primitive_transposes[id_tap_dep_p] = _id_tap_dep_transpose_rule
def _id_tap_dep_batching_rule(batched_args, batch_dims):
if FLAGS.jax_host_callback_ad_transforms:
assert False
arg_res, arg_tap = batched_args
return id_tap_dep_p.bind(arg_res, arg_tap), batch_dims[0]
batching.primitive_batchers[id_tap_dep_p] = _id_tap_dep_batching_rule
### The outside_call primitive
"""
This primitive is used to implement the `call` and `id_tap` functions.
It takes several positional arguments that are the flattened
according to `arg_treedef`.
The result of the primitive is computed based on the `identity` parameter,
as follows:
* if `identity` is True, then the results are the same as the
positional arguments of the primitive (except perhaps the last couple of
arguments, see `has_token`). In this case, `result_treedef` and
`flat_results_aval` are ignored, and `args_treedef` describes the result also.
* if `identity` is False, then the results are those from
the call to the outside computation:
flatten(callback(arg_treedef.unflatten(args), device=...))
In this case, the callback results must match `result_treedef`
and `flat_results_aval`.
It takes the following parameters:
* callback: the function to invoke with the unflattened arguments,
the device and the transforms: `callback(arrays, device, transforms)`
* arg_treedef: the treedef for the argument.
* identity: see description above.
* result_treedef, flat_results_aval: describes the expected result of the
callback. Only used when not `identity`.
* transforms: a tuple of the transformations that have been applied. Each
element of the tuple is itself a tuple with the first element the name
of the transform. The remaining elements depend on the transform. For
example, for `batch`, the parameters are the dimensions that have been
batched, and for `mask` the logical shapes. These are unpacked by
_outside_call_run_callback before passing to the user function.
* has_token: a boolean, when True it means that the last positional argument
is the current token. In this case, the result of the primitive is
going to be the non-token positional arguments, along with the updated
token. The tokens and this parameter are added after all the JAX
transformations, just before staging XLA.
* device_index: an integer, denotes from which device the invocation is from.
Works only when using the outfeed implementation mechanism, i.e., does
not work on CPU unless --jax_host_callback_outfeed=True.
"""
outside_call_p = core.Primitive("outside_call")
outside_call_p.multiple_results = True
core.outfeed_primitives.add(outside_call_p)
def _outside_call_abstract_eval(*args_a: pe.AbstractValue,
identity, **params) -> Sequence[pe.AbstractValue]:
if identity:
# Do some validation here
assert "result_treedef" not in params
assert "flat_results_aval" not in params
return args_a
assert params["device_index"] is not None
assert params["result_treedef"] is not None
assert params["flat_results_aval"] is not None
flat_results_aval = params["flat_results_aval"]
if "has_token" in params and params["has_token"]:
assert len(args_a) >= 2
return flat_results_aval + args_a[-2:]
else:
return flat_results_aval
outside_call_p.def_abstract_eval(_outside_call_abstract_eval)
def _outside_call_impl(*args, **params):
assert "has_token" not in params
if _inline_host_callback():
device_index = params["device_index"]
device = xb.devices()[device_index]
results = _outside_call_run_callback(args, device, send_infeed=False, **params)
return results
else:
# We use the jitted-version of the primitive even for eager execution, both
# so that we do not duplicate logic, but also so that all outfeed is received
# by the outfeed_listeners, in the same thread from a given device. If we were
# to process the tap here, it would be coming from the main thread. Also,
# even in eager execution some primitives, such as while, are compiled.
# It would be confusing to process a sequence "id_tap; while" in two
# different threads.
return dispatch.apply_primitive(outside_call_p, *args, **params)
outside_call_p.def_impl(_outside_call_impl)
def _outside_call_translation_rule(ctx,
avals_in,
avals_out,
*args_op: XlaOp,
has_token,
identity,
device_index,
flat_results_aval=(),
**params):
# We expect the current tokens at the end, inserted by _rewrite_jaxpr.
assert has_token
current_token = args_op[-2]
current_itoken = args_op[-1]
comp = ctx.builder
assert comp.get_shape(current_token).is_token() and comp.get_shape(current_itoken).is_token(), (
"The last two arguments must be tokens")
args_to_outfeed = args_op[:-2]
# Some platforms refuse to infeed empty arrays. We generate constants
# instead.
non_empty_flat_results_aval = list(filter(lambda aval: not (_aval_is_empty(aval)),
flat_results_aval))
need_callback_results_on_device = (not identity and
len(non_empty_flat_results_aval) > 0)
use_outfeed = _use_outfeed(ctx.platform)
# TODO(sharadmv): Delete non-outfeed path when jaxlib minimum version is
# bumped past 0.3.8.
assert use_outfeed, 'Should be using MLIR path for `CustomCall` lowering'
send_infeed = use_outfeed and need_callback_results_on_device
generated_infeed = False # Keep track if we emitted an infeed op
if use_outfeed:
_raise_if_using_outfeed_with_pjrt_c_api(xb.get_backend(ctx.platform))
callback_id = _register_callback(
functools.partial(
_outside_call_run_callback,
send_infeed=send_infeed,
identity=identity,
flat_results_aval=flat_results_aval,
**params))
next_token = _callback_handler_data.receiver.add_outfeed(
comp, current_token, callback_id, args_to_outfeed, device_index)
if identity:
results = list(args_to_outfeed)
next_itoken = current_itoken
else:
empty_results = [
xops.ConstantLiteral(comp, np.zeros(aval.shape, aval.dtype))
for aval in flat_results_aval
if _aval_is_empty(aval)
]
if non_empty_flat_results_aval:
assert need_callback_results_on_device
after_outfeed_itoken = xops.AfterAll(comp, [current_itoken, next_token])
# We shard the infeed as AssignedDevice(device_index). This must match the
# outfeed (from outfeed_receiver.cc). Since `lax.infeed` does not support
# this kind of sharding, we use a custom translation for infeed.
array_sharding_proto = xla_client.OpSharding()
array_sharding_proto.type = xla_client.OpSharding.Type.MAXIMAL
array_sharding_proto.tile_assignment_dimensions = [1]
array_sharding_proto.tile_assignment_devices = [device_index]
token_sharding_proto = xla_client.OpSharding()
token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED
infeed_sharding_proto = xla.tuple_sharding_proto(
[array_sharding_proto] * len(non_empty_flat_results_aval) +
[token_sharding_proto])
shape = [
shape.with_major_to_minor_layout_if_absent()
for x in non_empty_flat_results_aval
for shape in xla.aval_to_xla_shapes(x)
]
build_infeed = functools.partial(xops.InfeedWithToken,
after_outfeed_itoken,
xla_client.Shape.tuple_shape(shape))
outs_and_token = xla.with_sharding_proto(comp, infeed_sharding_proto,
build_infeed)
outs = xops.GetTupleElement(outs_and_token, 0)
next_itoken = xops.GetTupleElement(outs_and_token, 1)
non_empty_results = [
xops.GetTupleElement(outs, i)
for i in range(len(non_empty_flat_results_aval))
]
generated_infeed = True
results = [
empty_results.pop(0)
if _aval_is_empty(result_aval) else non_empty_results.pop(0)
for result_aval in flat_results_aval
]
else:
results = empty_results
next_itoken = current_itoken
else: # !use_outfeed : CustomCall implementation
if device_index != 0:
raise ValueError("The device_index feature works only when using outfeed.")
# TODO(necula): this is a weak attempt to get the device. This works
# inside pmap, but does not work when we just execute on a single device,
# because in such executions we always get replica_id == 0.
replica_id = xla_client.ops.ReplicaId(comp)
callback_operands = (current_token, replica_id) + args_to_outfeed
if identity:
callback_flat_results_aval = (core.abstract_token,)
else:
callback_flat_results_aval = (core.abstract_token,) + flat_results_aval
def wrapped_callback(*args):
token, replica_id, *arrays = args
result_arrays = _outside_call_run_callback(
arrays,
xb.local_devices()[replica_id],
send_infeed=False,
# The same parameters as outside_call_p
identity=identity,
flat_results_aval=flat_results_aval,
**params)
if identity:
# For identity, we do not pass the any results back to the device
result_arrays = ()
return (token,) + result_arrays
result_shapes = [
xla.aval_to_xla_shapes(res_aval)[0]
for res_aval in callback_flat_results_aval
]
backend = ctx.module_context.backend
token_and_results_op, keep_alive = backend.emit_python_callback(
wrapped_callback,
comp,
callback_operands,
result_shapes,
operand_layouts=None,
has_side_effects=True)
_callback_handler_data.keep_alives.append(keep_alive)
next_token, *results = (xops.GetTupleElement(token_and_results_op, i)
for i in range(len(callback_flat_results_aval)))
# We must put the two tokens at the end
if identity:
results = list(args_to_outfeed)
next_itoken = current_itoken
assert generated_infeed == send_infeed, (
f"generated_infeed ({generated_infeed}) != send_infeed ({send_infeed})")
assert identity or len(results) == len(flat_results_aval), (
f"got {len(results)} but expected {len(flat_results_aval)}. "
f"identity = {identity}")
return results + [next_token, next_itoken]
xla.register_translation(outside_call_p, _outside_call_translation_rule)
def _outside_call_lowering(ctx: mlir.LoweringRuleContext,
*args,
has_token: bool,
identity: bool,
device_index: int,
flat_results_aval=(),
**params):
"""MLIR Lowering for `CustomCall`-based HCB."""
platform = ctx.module_context.platform
use_outfeed = _use_outfeed(platform)
if use_outfeed:
# Fall back to XLA path if we are using the outfeed
# TODO(sharadmv): update to use MLIR for this path as well and delete
# XLA lowering
return mlir.xla_fallback_lowering(outside_call_p)(
ctx,
*args,
has_token=has_token,
identity=identity,
flat_results_aval=flat_results_aval,
device_index=device_index,
**params)
else:
if device_index != 0:
raise ValueError("The device_index feature works only when using outfeed.")
# We expect the current tokens at the end, inserted by _rewrite_jaxpr.
assert has_token
current_token = args[-2]
current_itoken = args[-1]
assert current_token.type == hlo.TokenType.get(), "The last two arguments must be tokens"
assert current_itoken.type == hlo.TokenType.get(), "The last two arguments must be tokens"
args_to_outfeed = args[:-2]
# TODO(necula): this is a weak attempt to get the device. This works
# inside pmap, but does not work when we just execute on a single device,
# because in such executions we always get replica_id == 0.
replica_id = hlo.ReplicaIdOp()
callback_operands = [replica_id, *args_to_outfeed]
callback_operand_avals = [
core.ShapedArray((), np.uint32), *ctx.avals_in[:-2]]
if identity:
callback_flat_results_aval = []
else:
callback_flat_results_aval = [*flat_results_aval]
def wrapped_callback(*args):
replica_id, *arrays = args
result_arrays = _outside_call_run_callback(
arrays,
xb.local_devices()[replica_id],
send_infeed=False,
# The same parameters as outside_call_p
identity=identity,
flat_results_aval=flat_results_aval,
**params)
if identity:
# For identity, we do not pass the any results back to the device
result_arrays = ()
return result_arrays
if isinstance(
ctx.module_context.axis_context,
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
):
# Apply maximal sharding so pjit only executes the callback on device device_index.
sharding = xla_client.OpSharding()
sharding.type = xla_client.OpSharding.Type.MAXIMAL
sharding.tile_assignment_dimensions = [1]
sharding.tile_assignment_devices = [device_index]
else:
sharding = None
results, next_token, keep_alive = mlir.emit_python_callback(ctx,
wrapped_callback, current_token, callback_operands,
callback_operand_avals, callback_flat_results_aval, # type: ignore[arg-type]
has_side_effect=True, sharding=sharding)
_callback_handler_data.keep_alives.append(keep_alive)
# We must put the two tokens at the end
if identity:
results = list(args_to_outfeed)
next_itoken = current_itoken
assert identity or len(results) == len(flat_results_aval), (
f"got {len(results)} but expected {len(flat_results_aval)}. "
f"identity = {identity}")
return results + [next_token, next_itoken]
mlir.register_lowering(outside_call_p, _outside_call_lowering, platform="cpu")
def _outside_call_run_callback(
arrays, device, *,
send_infeed=True,
# The same parameters as outside_call_p
callback, arg_treedef,
identity, result_treedef=None, flat_results_aval=None,
transforms=(), has_token=False):
"""Performs the callback:
callback(arg, device, transforms)
Called during the device computation once we have the argument, either from
an inlined callback or from an XLA computation outfeed.
Returns the flat list of result arrays. If `send_infeed` then it will also send
the flat list of results to the device.
"""
def _unpack_transforms(transforms) -> Tuple[Tuple[str, Dict[str, Any]], ...]:
def _unpack_transform(name, *params):
if name == "batch":
return name, dict(batch_dims=params[0])
elif name == "mask":
return name, dict(logical_shapes=5)
else:
assert not params, f"{name}, {params}"
return name, dict()
return tuple(_unpack_transform(*t) for t in transforms)
try:
arg = api.tree_unflatten(arg_treedef, arrays)
unpacked_transforms = _unpack_transforms(transforms)
logger.debug(
"Outside call invoking call_func %s, device=%s, transforms=%s",
callback, device, unpacked_transforms
)
res = callback(arg, device, unpacked_transforms)
if identity:
return tuple(arrays)
else: # Check the type of the callback results
assert result_treedef is not None
assert flat_results_aval is not None
actual_flat_results, actual_result_treedef = pytree.flatten(res)
if actual_result_treedef != result_treedef:
msg = (f"Callback func {callback} should have returned a result "
f"with pytree {result_treedef} but returned "
f"{actual_result_treedef}")
raise TypeError(msg)
canonical_flat_results = tuple(util.safe_map(xla.canonicalize_dtype, actual_flat_results))
actual_flat_results_aval = _values_to_avals(canonical_flat_results)
logger.debug(
"Outside call %s result %s. Sending to infeed for device %s.",
callback, flat_results_aval, device,
)
if not all(ea.strip_weak_type() == ra.strip_weak_type()
for ea, ra in util.safe_zip(flat_results_aval,
actual_flat_results_aval)):
msg = (f"Callback func {callback} should have returned a result "
"with abstract values "
f"{result_treedef.unflatten(flat_results_aval)} "
f"but returned {actual_result_treedef.unflatten(actual_flat_results_aval)}")
raise TypeError(msg)
if send_infeed:
# Do not send the 0-sized arrays
non_empty_canonical_flat_results = tuple(filter(lambda r: not _aval_is_empty(r),
canonical_flat_results))
device.transfer_to_infeed(non_empty_canonical_flat_results)
return canonical_flat_results
except Exception as e:
logger.error("Outside call %s threw exception %s.", callback, e)
if send_infeed:
# Prepare some results to send in case of error. We are sending something
# with a distinctive shape (int8[12345]), one that is unlikely to be what the device
# expects. This should have the effect to abort the device computation,
# with an error message that we recognize. On TPU there seem to be no
# such check, and if we send anything at all the device computation will
# use some garbage data. So, on TPU we prefer to not send anything and let
# the computation hang.
# TODO: implement a proper error handling for TPU
if device.platform != "tpu":
canonical_flat_results = [xla.canonicalize_dtype(np.arange(12345, dtype=np.int8))]
logger.debug("Outside call consumer %s exception %s. Sending to infeed the error result.",
callback, e)
device.transfer_to_infeed(tuple(canonical_flat_results))
else:
logger.debug("Outside call consumer %s exception %s. On TPU we do not send infeed.",
callback, e)
raise e # Let the exception propagate
def _add_transform(params: Dict, name: str, *transform_params) -> Dict:
"""Adds the `transform` to the params["transforms"].
Uses a tuple representation internally, will be unpacked before the
callback by _ConsumerCallable.
"""
new_transform = (name, *transform_params)
return dict(
params, transforms=(params.get("transforms", ()) + (new_transform,)))
def _aval_is_empty(aval) -> bool:
return math.prod(aval.shape) == 0
def _instantiate_zeros(tan, arg):
"""Turn special ad.zero tangents into arrays of 0s for sending to host.
Args:
tan: the tangent.
arg: the argument for which we need to instantiate the tangent
Returns: tan if is is not ad.Zero, otherwise a 0 array of appropriate type
and shape
"""
if type(tan) is not ad.Zero:
return tan
return ad.instantiate_zeros_aval(tan.aval, tan)
def _outside_call_jvp_rule(primals, tangents, **params):
assert "has_token" not in params
if not params["identity"]:
raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
if FLAGS.jax_host_callback_ad_transforms:
tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
arg_treedef = params["arg_treedef"]
# The argument to the jvp tap is a pair of the tapped primals and tangents
jvp_flat_args, jvp_arg_treedef = api.tree_flatten(
(arg_treedef.unflatten(primals),
arg_treedef.unflatten(tangents_instantiated)))
out_all = outside_call_p.bind(
*jvp_flat_args,
**dict(_add_transform(params, "jvp"),
arg_treedef=jvp_arg_treedef,
))
out_primals_tapped, out_tangents_tapped = util.split_list(out_all, [len(primals)])
return tuple(out_primals_tapped), tuple(out_tangents_tapped)
else:
out_primals_tapped = outside_call_p.bind(*primals, **params)
return tuple(out_primals_tapped), tangents
ad.primitive_jvps[outside_call_p] = _outside_call_jvp_rule
def _outside_call_partial_eval_rule(trace, *args, **params):
# partial eval is used after jvp and before transpose.
if not FLAGS.jax_host_callback_ad_transforms:
# TODO: just remote the partial eval rule
return trace.default_process_primitive(outside_call_p, args, params)
transforms = params.get("transforms", ())
if not transforms or transforms[-1] != ("jvp",):
# We are not in the process of computing VJP
return trace.default_process_primitive(outside_call_p, args, params)
# The args have been prepared by the id_tap_jvp_rule: primals, tangents. The
# result is a pair of the primal outputs and output tangents.
# One invariant that JAX requires is that if the primals arguments are known
# then the primal outputs must be known. So, if the primal arguments are known
# and some of the tangents are unknown, then we must split the tap into
# one for the primals (thus the output will be considered known), and a
# separate tap for the tangents.
assert "has_token" not in params
if not params["identity"]:
raise NotImplementedError("differentiation rules are implemented only for id_tap, not for call.")
assert len(args) % 2 == 0
nr_primals = len(args) // 2
primals, tangents = util.split_list(args, [nr_primals])
all_primals_known = all(p.is_known() for p in primals)
some_tangents_unknown = any(not t.is_known() for t in tangents)
if not (all_primals_known and some_tangents_unknown):
return trace.default_process_primitive(outside_call_p, args, params)
prims, _ = params["arg_treedef"].unflatten(args)
_, primals_treedef = api.tree_flatten(prims)
outs_known = trace.default_process_primitive(
outside_call_p, primals,
dict(params,
arg_treedef=primals_treedef,
transforms=transforms[:-1]))
# Now compute the unknowns using the whole tap, and merge them with the tapped ones
outs_all_unknown = trace.default_process_primitive(outside_call_p, args, params)
outs_primals_unknown, outs_tangents_unknown = util.split_list(
outs_all_unknown, [nr_primals])
outs_combined = (
[pe.JaxprTracer(trace, pe.PartialVal.known(primal_known),
primal_unknown.recipe)
for primal_known, primal_unknown in util.safe_zip(outs_known, outs_primals_unknown)] +
outs_tangents_unknown)
return tuple(outs_combined)
pe.custom_partial_eval_rules[outside_call_p] = _outside_call_partial_eval_rule
def _outside_call_transpose_rule(cts, *args, **params):
if not params["identity"]:
raise NotImplementedError("differentiation rules are implemented only for id_tap, not for call.")
assert "has_token" not in params
assert len(cts) == len(args)
cts_instantiated = tuple(map(_instantiate_zeros, cts, args))
# The args have been prepared by the id_tap_jvp_rule: tapped_primals, tapped_tangents, rest_primals, rest_tangents
transforms = params.get("transforms", ())
if not transforms or transforms[-1] != ("jvp",):
# TODO: I should understand better when can this happen. It seems to arise
# in scan.
return outside_call_p.bind(
*cts_instantiated,
**_add_transform(params, "transpose"))
if not FLAGS.jax_host_callback_ad_transforms:
assert False
assert len(args) % 2 == 0
nr_primals = len(args) // 2
args_unflat, tan_unflat = params["arg_treedef"].unflatten(args)
_, vjp_arg_treedef = api.tree_flatten(args_unflat)
# We want to tap the cts_tapped_tangents
cts_primals, cts_tangents = util.split_list(cts_instantiated, [nr_primals])
cts_tangents_through_tap = outside_call_p.bind(
*cts_tangents,
**dict(_add_transform(params, "transpose"),
arg_treedef=vjp_arg_treedef))
return cts_primals + cts_tangents_through_tap
ad.primitive_transposes[outside_call_p] = _outside_call_transpose_rule
def _outside_call_batching_rule(batched_args, batch_dims, **params):
if not params["identity"]:
raise NotImplementedError("batching rules are implemented only for id_tap, not for call.")
assert "has_token" not in params
new_params = _add_transform(params, "batch", batch_dims)
res = outside_call_p.bind(*batched_args, **new_params)
return res, batch_dims
batching.primitive_batchers[outside_call_p] = _outside_call_batching_rule
####
#### Jaxpr rewriting logic to thread the tokens through stateful primitives.
####
def _rewrite_closed_jaxpr(cjaxpr: core.ClosedJaxpr, has_input_token: bool,
has_output_token: bool) -> core.ClosedJaxpr:
"""Rewrites a ClosedJaxpr to thread the token, if needed."""
new_jaxpr = _rewrite_jaxpr(cjaxpr.jaxpr, has_input_token, has_output_token)
return core.ClosedJaxpr(new_jaxpr, cjaxpr.consts)
def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
has_output_token: bool) -> core.Jaxpr:
"""Rewrite a Jaxpr to thread the token, if needed."""
assert has_input_token or not has_output_token
if not has_input_token and not core.jaxpr_uses_outfeed(jaxpr):
return jaxpr
mk_new_var = core.gensym([jaxpr])
eqns: List[core.JaxprEqn] = []
# store the incoming tokens
last_token_var = mk_new_var(core.abstract_token)
last_itoken_var = mk_new_var(core.abstract_token)
if has_input_token:
invars = jaxpr.invars + [last_token_var, last_itoken_var]
else:
invars = jaxpr.invars
# We need tokens but none is given in input; make one depending on all invars
eqns.append(
core.new_jaxpr_eqn(jaxpr.invars, [last_token_var],
lax.create_token_p, {}, core.no_effects, source_info_util.current()))
eqns.append(
core.new_jaxpr_eqn(jaxpr.invars, [last_itoken_var],
lax.create_token_p, {}, core.no_effects, source_info_util.current()))
for eqn in jaxpr.eqns:
if not core.primitive_uses_outfeed(eqn.primitive, eqn.params):
eqns.append(eqn)
else:
output_token_var = mk_new_var(last_token_var.aval)
output_itoken_var = mk_new_var(last_itoken_var.aval)
_rewrite_eqn(eqn, eqns, last_token_var, output_token_var,
last_itoken_var, output_itoken_var, mk_new_var)
last_token_var = output_token_var
last_itoken_var = output_itoken_var
outvars = jaxpr.outvars + ([last_token_var, last_itoken_var] if has_output_token else [])
new_jaxpr = core.Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr.effects)
return new_jaxpr
def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
input_token_var: core.Var, output_token_var: core.Var,
input_itoken_var: core.Var, output_itoken_var: core.Var,
mk_new_var: Callable[[core.AbstractValue], core.Var]):
"""Rewrite an `eqn` and append equations to `eqns`.
This is only called if the current primitive uses outfeed.
Assume that the current token is in `input_token_var` and the resulting
token must end in `output_token_var`.
Append the result of rewriting to `eqns`.
"""
if eqn.primitive is outside_call_p:
assert "has_token" not in eqn.params
eqns.append(eqn.replace(invars=eqn.invars + [input_token_var, input_itoken_var],
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(eqn.params, has_token=True)))
elif eqn.primitive is lax.while_p:
cond_jaxpr, _, body_jaxpr, _ = util.split_dict(
eqn.params,
["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
if core.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
_rewrite_while_outfeed_cond(eqn, eqns, input_token_var, output_token_var,
input_itoken_var, output_itoken_var,
mk_new_var)
return
eqns.append(
eqn.replace(
invars=eqn.invars + [input_token_var, input_itoken_var],
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
body_jaxpr=_rewrite_closed_jaxpr(body_jaxpr, True, True),
cond_jaxpr=_rewrite_closed_jaxpr(cond_jaxpr, True, False))))
elif eqn.primitive is lax.cond_p:
branches, linear = util.split_dict(eqn.params, ["branches", "linear"])
index, *operands = eqn.invars
new_invars = [index, *operands, input_token_var, input_itoken_var]
eqns.append(
eqn.replace(
invars=new_invars, outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
branches=tuple(
_rewrite_closed_jaxpr(jaxpr, True, True)
for jaxpr in branches),
linear=(*linear, False, False))))
elif eqn.primitive is lax.scan_p:
num_consts, num_carry, carry_jaxpr, linear, _, _, _ = util.split_dict(
eqn.params,
["num_consts", "num_carry", "jaxpr", "linear", "reverse", "length",
"unroll"])
# We add the tokens right at the end of carry
nr_const_and_carry = num_consts + num_carry
new_invars = eqn.invars[0:nr_const_and_carry] + [
input_token_var, input_itoken_var] + eqn.invars[nr_const_and_carry:]
new_jaxpr = _rewrite_closed_jaxpr(carry_jaxpr, True, True)
# The rewrite has put the token at end, it has to be at end of carry
new_jaxpr_invars = new_jaxpr.jaxpr.invars
new_jaxpr_invars = (
new_jaxpr_invars[0:nr_const_and_carry] + new_jaxpr_invars[-2:] +
new_jaxpr_invars[nr_const_and_carry:-2])
new_jaxpr = new_jaxpr.replace(jaxpr=new_jaxpr.jaxpr.replace(invars=new_jaxpr_invars))
new_jaxpr_outvars = new_jaxpr.jaxpr.outvars
new_jaxpr_outvars = (
new_jaxpr_outvars[0:num_carry] + new_jaxpr_outvars[-2:] +
new_jaxpr_outvars[num_carry:-2])
new_jaxpr = new_jaxpr.replace(jaxpr=new_jaxpr.jaxpr.replace(outvars=new_jaxpr_outvars))
eqns.append(
eqn.replace(
invars=new_invars,
# Output token is at the end of carry result
outvars=(eqn.outvars[0:num_carry] + [output_token_var, output_itoken_var] +
eqn.outvars[num_carry:]),
params=dict(
eqn.params,
jaxpr=new_jaxpr,
num_carry=num_carry + 2,
linear=linear[0:nr_const_and_carry] + (False, False) + linear[nr_const_and_carry:])))
elif eqn.primitive is pxla.xla_pmap_p:
# We broadcast the input token into an array of tokens
call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
eqns.append(
eqn.replace(
invars=eqn.invars + [input_token_var, input_itoken_var],
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
donated_invars=eqn.params["donated_invars"] + (False, False),
# Sharding/unsharding of tokens in pmap_translation are special
# cased to just pass-through the token
in_axes=eqn.params["in_axes"] + (None, None),
out_axes=eqn.params["out_axes"] + (0, 0))))
elif eqn.primitive is custom_derivatives.custom_jvp_call_p:
fun_jaxpr = eqn.params["call_jaxpr"]
def unreachable_thunk():
assert False, "Should not be reached"
eqns.append(
eqn.replace(
invars=eqn.invars + [input_token_var, input_itoken_var],
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
call_jaxpr=_rewrite_closed_jaxpr(fun_jaxpr, True, True),
jvp_jaxpr_thunk=unreachable_thunk
)))
elif eqn.primitive is custom_derivatives.custom_vjp_call_jaxpr_p:
fun_jaxpr = eqn.params["fun_jaxpr"]
new_invars = [*eqn.invars, input_token_var, input_itoken_var]
def unreachable_thunk():
assert False, "Should not be reached"
eqns.append(
eqn.replace(
invars=new_invars,
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
fun_jaxpr=_rewrite_closed_jaxpr(fun_jaxpr, True, True),
fwd_jaxpr_thunk=unreachable_thunk,
# The following are illegal values for the parameters, they
# should not be needed because this rewrite is just before
# compilation to XLA, which does not use those parameters.
bwd="illegal param",
out_trees="illegal param")))
elif eqn.primitive is pjit.pjit_p:
jaxpr = cast(core.ClosedJaxpr, eqn.params["jaxpr"])
eqns.append(
eqn.replace(
invars=eqn.invars + [input_token_var, input_itoken_var],
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
jaxpr=_rewrite_closed_jaxpr(jaxpr, True, True),
donated_invars=eqn.params["donated_invars"] + (False, False),
in_shardings=(
eqn.params["in_shardings"]
+ (sharding_impls.UNSPECIFIED, sharding_impls.UNSPECIFIED)
),
out_shardings=(
eqn.params["out_shardings"]
+ (sharding_impls.UNSPECIFIED, sharding_impls.UNSPECIFIED)
),
),
)
)
elif eqn.primitive is ad_checkpoint.remat_p:
jaxpr_ = cast(core.Jaxpr, eqn.params["jaxpr"])
eqns.append(
eqn.replace(
invars=eqn.invars + [input_token_var, input_itoken_var],
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
jaxpr=_rewrite_jaxpr(jaxpr_, True, True),
)))
else:
raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")
def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
input_token_var: core.Var,
output_token_var: core.Var,
input_itoken_var: core.Var,
output_itoken_var: core.Var,
mk_new_var: Callable):
"""Rewrite a while whose cond has outfeed"""
cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict(
eqn.params, ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
transformed_cond_jaxpr = _rewrite_closed_jaxpr(cond_jaxpr, True, True)
carry_invars = eqn.invars[cond_nconsts + body_nconsts:]
# pred1, token1, itoken1 = rewrite(COND)(cond_consts, carry_invars, input_token, input_itoken)
pred1_and_token1 = [
mk_new_var(ov.aval) for ov in transformed_cond_jaxpr.jaxpr.outvars
]
eqns.append(
core.new_jaxpr_eqn(
eqn.invars[0:cond_nconsts] + carry_invars + [input_token_var, input_itoken_var],
pred1_and_token1, core.call_p,
dict(
call_jaxpr=transformed_cond_jaxpr.jaxpr,
name="cond_before"),
transformed_cond_jaxpr.jaxpr.effects,
eqn.source_info))
# Make a new cond "lambda pred, carry, token, itoken: pred"
new_cond_pred_invar = mk_new_var(cond_jaxpr.out_avals[0])
new_cond_invars = (
[new_cond_pred_invar] + [mk_new_var(cv.aval) for cv in carry_invars] +
[mk_new_var(input_token_var.aval),
mk_new_var(input_itoken_var.aval)])
new_cond_jaxpr = core.ClosedJaxpr(
core.Jaxpr([], new_cond_invars, [new_cond_pred_invar], [], set()), [])
# Make a new body:
# "lambda cond_constvars, body_constvars, pred, carry, token, itoken:
# carry2, token2, itoken2 = rewrite(BODY)(body_constvars, carry, token, itoken)
# pred2, token3, itoken3 = rewrite(COND)(cond_constvars, carry2, token2, itoken2)
# (pred2, carry2, token3, itoken3)
transformed_body_jaxpr = _rewrite_closed_jaxpr(body_jaxpr, True, True)
new_body_invars_cond_constvars = [
mk_new_var(v.aval) for v in eqn.invars[0:cond_nconsts]
]
new_body_invars_body_constvars = [
mk_new_var(v.aval)
for v in eqn.invars[cond_nconsts:cond_nconsts + body_nconsts]
]
new_body_invars_pred = mk_new_var(cond_jaxpr.out_avals[0])
new_body_invars_carry = [mk_new_var(cv.aval) for cv in carry_invars]
new_body_invars_token = mk_new_var(input_token_var.aval)
new_body_invars_itoken = mk_new_var(input_itoken_var.aval)
new_body_carry2 = [mk_new_var(cv.aval) for cv in carry_invars]
new_body_token2 = mk_new_var(input_token_var.aval)
new_body_itoken2 = mk_new_var(input_itoken_var.aval)
new_body_pred2 = mk_new_var(cond_jaxpr.out_avals[0])
new_body_token3 = mk_new_var(input_token_var.aval)
new_body_itoken3 = mk_new_var(input_itoken_var.aval)
new_body_eqns = [
core.new_jaxpr_eqn(
new_body_invars_body_constvars + new_body_invars_carry +
[new_body_invars_token, new_body_invars_itoken],
new_body_carry2 + [new_body_token2, new_body_itoken2],
core.call_p,
dict(
call_jaxpr=transformed_body_jaxpr.jaxpr,
name="body"),
transformed_body_jaxpr.effects,
eqn.source_info),
core.new_jaxpr_eqn(
new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2, new_body_itoken2],
[new_body_pred2, new_body_token3, new_body_itoken3], core.call_p,
dict(
call_jaxpr=transformed_cond_jaxpr.jaxpr,
name="cond_body"),
transformed_cond_jaxpr.effects,
eqn.source_info)
]
effects = core.join_effects(*(eqn.effects for eqn in new_body_eqns))
new_body_jaxpr = core.ClosedJaxpr(
core.Jaxpr([], (new_body_invars_cond_constvars +
new_body_invars_body_constvars + [new_body_invars_pred] +
new_body_invars_carry + [new_body_invars_token, new_body_invars_itoken]),
([new_body_pred2] + new_body_carry2 + [new_body_token3, new_body_itoken3]),
new_body_eqns, effects), [])
pred_out = mk_new_var(cond_jaxpr.out_avals[0])
eqns.append(
core.new_jaxpr_eqn(
(eqn.invars[0:cond_nconsts + body_nconsts] + [pred1_and_token1[0]] +
carry_invars + pred1_and_token1[1:]),
([pred_out] + eqn.outvars + [output_token_var, output_itoken_var]),
lax.while_p,
dict(
cond_jaxpr=new_cond_jaxpr,
cond_nconsts=0,
body_jaxpr=new_body_jaxpr,
body_nconsts=cond_nconsts + body_nconsts),
new_body_jaxpr.effects,
eqn.source_info))
# We need an identity primitive to simplify rewriting
id_p = core.Primitive("id")
id_p.multiple_results = True
id_p.def_impl(lambda *args: args)
id_p.def_abstract_eval(lambda *args: args)
xla.register_translation(id_p, lambda ctx, avals_in, avals_out, *args: args)
dispatch.outfeed_rewriter = lambda j: _rewrite_jaxpr(j, False, False)
class CallbackException(Exception):
"""Signals that some callback function had exceptions.
Raised by :func:`barrier_wait`. See the :mod:`jax.experimental.host_callback`
module documentation for details.
"""
pass
TapFunctionException = CallbackException # For backwards compatibility
class _CallbackHandlerData:
"""Keep track of the outfeed receiver data."""
receiver: Any
initialized: bool
on_exit: bool
lock: threading.Lock
last_callback_exception: Optional[Tuple[Exception, str]]
clients: Tuple[XlaLocalClient, ...]
devices: Tuple[XlaDevice, ...]
consumer_registry: Dict[Callable, int]
consumer_registry_by_id: Dict[int, Callable]
def __init__(self):
self.receiver = None # Initialize lazily, when first needed
self.initialized = False
self.on_exit = False
self.lock = threading.Lock()
self.last_callback_exception = None
self.clients = ()
self.devices = ()
# The consumer registries must be live for the lifetime of the program,
# because we may have cached compilations that embed consumer ids, and we
# do not want the id reused for other shapes.
# Used only for the outfeed mechanism.
self.callback_registry = dict()
self.callback_registry_by_id = dict()
# For now we keep here the keep_alives for the emit_python_callback. This is
# a leak. We ought to attach these to the executable.
self.keep_alives = []
def stop(self):
"""Wait for all pending outfeeds and stop the receiver."""
self.receiver = None # GC will trigger the destructor
self.initialized = False
self.clients = ()
self.devices = ()
# Do not clear the consumer registries.
_callback_handler_data = _CallbackHandlerData()
# This function is called from C++; it must not allow exceptions through.
def _callback_input_received(device, consumer_id, arrays: Tuple):
array_repr = ", ".join([f"({a.dtype}{a.shape})" for a in arrays])
logger.debug("Callback input received on device %s for consumer %s arrays: %s",
device, consumer_id, array_repr)
callback = _callback_handler_data.callback_registry_by_id.get(consumer_id)
assert callback is not None, "We should have crashed in the runtime"
try:
return callback(arrays, device)
except Exception as e:
formatted_e = traceback.format_exc()
logger.error("Postponing exception raised in callback function: %s", formatted_e)
_callback_handler_data.last_callback_exception = (e, formatted_e)
def _register_callback(callback: Callable) -> int:
"""Registers a callback function, cache by hash of callback.
The callback is a function to be invoked as `callback(arrays, device)`.
"""
callback_id = _callback_handler_data.callback_registry.get(callback)
if callback_id is not None:
return callback_id
callback_id = hash(callback) & 0xFFFFFFFC # pybind11 has trouble here with large ints
callback_id += 1 # Reserve the consumer ID 0
assert callback_id not in _callback_handler_data.callback_registry, (
"callback id collision")
_callback_handler_data.callback_registry[callback] = callback_id
_callback_handler_data.callback_registry_by_id[callback_id] = callback
return callback_id
def _initialize_outfeed_receiver(
max_callback_queue_size_bytes: int = int(256 * 1e6)):
"""Creates and starts the outfeed_receiver.
This function is called lazily only when we compile an id_tap.
Args:
* clients: the list of clients (backends) on whose devices to listen on.
* max_callback_queue_size_bytes: an optional integer to bound the maximum
size of arrays in the callback queue. When this limit is reached the
device listener pauses.
"""
outfeed_receiver_module = xla_extension.outfeed_receiver
with _callback_handler_data.lock:
if _callback_handler_data.initialized:
return
# By default, all devices on all supported backends.
clients = [backend for name, backend in xb.backends().items()
if name in ("cpu", "cuda", "rocm", "tpu")]
devices = list(
itertools.chain(*[backend.local_devices() for backend in clients]))
_callback_handler_data.clients = clients # type: ignore[assignment]
_callback_handler_data.devices = devices # type: ignore[assignment]
clients_with_outfeed = [c for c in clients if _use_outfeed(c.platform)]
for client in clients_with_outfeed:
_raise_if_using_outfeed_with_pjrt_c_api(client)
if clients_with_outfeed:
devices_with_outfeed = list(
itertools.chain(*[backend.local_devices() for backend in clients_with_outfeed]))
if logger.isEnabledFor(logging.DEBUG):
device_repr = ", ".join([str(d) for d in devices_with_outfeed])
logger.debug("Starting outfeed_receiver for %s. max_callback_queue_size_bytes=%s",
device_repr, max_callback_queue_size_bytes)
_callback_handler_data.receiver = outfeed_receiver_module.start(
_callback_input_received, tuple(clients_with_outfeed),
max_callback_queue_size_bytes,
xb.get_compile_options(1, 1).executable_build_options) # type:ignore
def exit_handler():
# Prevent logging usage during compilation, gives errors under pytest
dispatch._on_exit = True # type: ignore[protected-access]
if not _callback_handler_data.on_exit:
_callback_handler_data.on_exit = True
barrier_wait("at_exit")
atexit.register(exit_handler) # We wait as long as we have callbacks
_callback_handler_data.initialized = True
def barrier_wait(logging_name: Optional[str] = None):
"""Blocks the calling thread until all current outfeed is processed.
Waits until all callbacks from computations already running on all devices
have been received and processed by the Python callbacks. Raises
CallbackException if there were exceptions while processing the callbacks.
This works by enqueueing a special tap computation to all devices to which
we are listening for outfeed. Once all those tap computations are done, we
return from barrier_wait.
Note: If any of the devices are busy and cannot accept new computations,
this will deadlock.
Args:
logging_name: an optional string that will be used in the logging statements
for this invocation. See `Debugging` in the module documentation.
For more details see the :mod:`jax.experimental.host_callback` module documentation.
"""
logging_name = logging_name or ""
logger.debug("barrier_wait[%s]: start", logging_name)
lock = threading.Lock()
cv = threading.Condition(lock=lock)
devices_at_barrier = [] # Protected by lock
def barrier_tap_received(dev_idx, _):
device = _callback_handler_data.devices[dev_idx]
logger.debug(
"barrier_wait[%s]: at barrier_tap for device %s. Thread %s",
logging_name, device, threading.current_thread()
)
with lock:
devices_at_barrier.append(device)
if logger.isEnabledFor(logging.DEBUG):
waiting_for_devices = [d for d in _callback_handler_data.devices
if d not in devices_at_barrier]
logger.debug(
"barrier_wait[%s]: still waiting for %s devices at barrier (%s)",
logging_name, len(waiting_for_devices), waiting_for_devices
)
cv.notify()
for d_idx, d in enumerate(_callback_handler_data.devices):
logger.debug("barrier_wait[%s]: enqueueing barrier on device %s", logging_name, d)
x_on_dev = api.device_put(d_idx, device=d)
api.jit(lambda x: id_tap(barrier_tap_received, x), device=d)(x_on_dev)
logger.debug("barrier_wait[%s]: waiting for callbacks", logging_name)
with lock:
cv.wait_for(lambda: len(devices_at_barrier) == len(_callback_handler_data.devices))
logger.debug("barrier_wait[%s]: done", logging_name)
if _callback_handler_data.last_callback_exception is not None:
last_exception, formatted_last_exception = _callback_handler_data.last_callback_exception
_callback_handler_data.last_callback_exception = None
raise CallbackException(
"There were exceptions during callback processing. "
f"Last one was: {formatted_last_exception}") from last_exception
def stop_outfeed_receiver():
"""Stops the outfeed receiver runtime.
This waits for all outfeeds from computations already running on all devices,
and then stops the outfeed receiver runtime. The runtime will be restarted
next time you use a tap function.
It should not be necessary to use this function, unless you want to start
using lax.outfeed directly after having used host callbacks.
"""
_callback_handler_data.stop()