# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Allows JAX to call TensorFlow functions with support for autodiff.
|
|||
|
|
|||
|
**Experimental: please give feedback, and expect changes.**
|
|||
|
|
|||
|
This module introduces the function :func:`call_tf` that allows JAX to call
|
|||
|
TensorFlow functions.
|
|||
|
|
|||
|
For examples and details, see
|
|||
|
https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax.
|
|||
|
|
|||
|
"""
|
|||
|
import functools
from typing import Any, Callable, List, Optional, Sequence, Tuple

from absl import logging
import jax
from jax import dlpack
from jax import dtypes
from jax import numpy as jnp
from jax import tree_util
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import core
from jax._src import custom_derivatives
from jax._src import effects
from jax._src import util
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir.dialects import stablehlo
from jax.experimental.jax2tf import jax2tf as jax2tf_internal
from jax.interpreters import mlir
from jax.interpreters import xla
import numpy as np
import tensorflow as tf


map = util.safe_map
zip = util.safe_zip

TfConcreteFunction = Any
TfVal = jax2tf_internal.TfVal

# The platforms for which to use DLPack to avoid copying (only works on GPU
# and CPU at the moment, and only for DeviceArray). For CPU we don't need
# DLPack, if we are careful.
_DLPACK_PLATFORMS = ("gpu",)


def call_tf(
    callable_tf: Callable,
    has_side_effects=True,
    ordered=False,
    output_shape_dtype=None,
    call_tf_graph=False,
) -> Callable:
  """Calls a TensorFlow function from JAX, with support for reverse autodiff.

  The ``callable_tf`` will be called with TensorFlow-compatible arguments (
  numpy.ndarray, ``tf.Tensor`` or ``tf.Variable``) or pytrees thereof. The
  function must return the same type of results.

  If ``call_tf`` appears in a JAX staging context (:func:`jax.jit`,
  or :func:`jax.pmap`, or :func:`jax.xmap`, or a control-flow primitive) then
  ``callable_tf`` will be compiled with ``tf.function(callable_tf,
  jit_compile=True)``
  and the resulting XLA computation will be embedded in JAX's XLA computation.

  If ``call_tf`` appears outside a JAX staging context, it will be called inline
  using TensorFlow eager mode.

  The ``call_tf`` supports JAX's reverse-mode autodiff, in which case the
  ``callable_tf`` will be differentiated using ``tf.GradientTape``. This means
  that the gradient will be TensorFlow-accurate, e.g., will respect the
  custom gradients that may be defined for the code in ``callable_tf``.

  For an example and more details see the
  `README
  <https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax>`_.

  Args:
    callable_tf: a TensorFlow Callable that can take a pytree of TensorFlow
      arguments.
    has_side_effects: if True then it ensures that instances of this primitive
      are not removed or replicated by JAX optimizations such as dead-code
      elimination.
    ordered: If true, calls are modeled as having ordered effects.
    output_shape_dtype: An optional declaration of the expected shape and dtype
      of the result of the called TensorFlow function. If given it will be used
      during JAX tracing to form the abstract values of the results of the
      `call_tf`. If not given then we form a `tf.Graph` for the called
      TensorFlow function and we use the TensorFlow-inferred shapes and types.
      Must be a pytree matching the structure of the nested structure returned
      from the TensorFlow function, containing objects with `.shape` and
      `.dtype` attributes, e.g., `jax.ShapeDtypeStruct` or `jax.Array`.
    call_tf_graph: EXPERIMENTAL, DO NOT USE. We may change the name in the
      future.

  Returns: a JAX callable that can be invoked with JAX pytree arguments, in
    op-by-op mode or in a staged context. This callable can be used with JAX's
    reverse-mode autodiff (:func:`jax.grad`).
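
  Example, as a minimal sketch (it assumes TensorFlow is installed and uses
  ``tf.math.cos`` as a stand-in for an arbitrary TensorFlow function)::

    import jax
    import jax.numpy as jnp
    import tensorflow as tf
    from jax.experimental import jax2tf

    # Wrap a TensorFlow function so that JAX can call it.
    cos = jax2tf.call_tf(tf.math.cos)

    cos(jnp.float32(1.0))            # op-by-op, runs in TF eager mode
    jax.jit(cos)(jnp.float32(1.0))   # staged, via tf.function(jit_compile=True)
    jax.grad(cos)(jnp.float32(1.0))  # reverse-mode autodiff via tf.GradientTape
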
  """
  @jax.custom_vjp
  def make_call(*args_jax):
    """We wrap it all in `make_call` so that we can attach custom VJP."""

    args_flat_jax, args_treedef = tree_util.tree_flatten(args_jax)
    # Canonicalize the arguments; e.g., makes them x32 if JAX is in 32-bit mode
    def canonical_arg(v):
      v = v if getattr(v, "dtype", None) else np.asarray(v)
      dtype = dtypes.canonicalize_dtype(v.dtype)
      if dtype != v.dtype:
        v = v.astype(dtype)
      return v

    args_flat_jax = tuple(map(canonical_arg, args_flat_jax))
    def make_tensorspec(a_jax):
      a_tf_dtype = jax2tf_internal._to_tf_dtype(a_jax.dtype)
      a_tf_shape = [d if core.is_constant_dim(d) else None for d in a_jax.shape]
      return tf.TensorSpec(a_tf_shape, a_tf_dtype)
    args_flat_sig_tf = tuple(map(make_tensorspec, args_flat_jax))

    if output_shape_dtype is not None:
      output_shape_dtype_flat, output_shape_dtype_tree = tree_util.tree_flatten(output_shape_dtype)
      output_avals = tuple(core.ShapedArray(st.shape, st.dtype) for st in output_shape_dtype_flat)
    else:
      output_avals, output_shape_dtype_tree = None, None

    res_treedef = None  # We'll store here the result treedef
    res_tf_flat = None  # For error reporting
    # The function below will be called at least once, either in eager
    # mode during jax2tf_call_tf or in graph mode during _get_concrete_function_tf()
    def callable_flat_tf(*args_tf_flat: TfVal) -> Sequence[TfVal]:
      args_tf = args_treedef.unflatten(args_tf_flat)
      res_tf = callable_tf(*args_tf)

      # b/279454591: When `callable_tf` is a tf function with zero outputs, it
      # returns a `StatefulPartitionedCall` (if the function is stateful) or
      # `PartitionedCall` (if the function is stateless) op instead of
      # tf.Tensors. We work around this issue by replacing the output `res_tf`
      # with an empty list.

      if isinstance(res_tf, tf.Operation):
        assert (
            res_tf.type == "StatefulPartitionedCall"
            or res_tf.type == "PartitionedCall"
        )
        t_out = res_tf.get_attr("Tout")
        # t_out should be an empty list.
        assert not t_out, (
            "The TF function returned an unexpected result, please check its"
            f" function body. res_tf = {res_tf}"
        )
        res_tf = t_out

      nonlocal res_treedef, res_tf_flat
      res_tf_flat, res_treedef_now = tree_util.tree_flatten(res_tf)
      assert res_treedef is None or res_treedef == res_treedef_now, (
          f"Subsequent calls had different results. Previous {res_treedef} and now {res_treedef_now}")
      res_treedef = res_treedef_now
      if output_avals is not None:
        if res_treedef != output_shape_dtype_tree:
          raise ValueError(
              "The pytree of the TensorFlow function results does not match the "
              "pytree of the declared output_shape_dtype:\n"
              f"results pytree: {res_treedef}\noutput_shape_dtype tree: {output_shape_dtype_tree}")
        assert len(output_avals) == len(res_tf_flat)

      checked_res_tf_flat = [
          check_tf_result(i, r_tf, r_aval)
          for i, (r_tf, r_aval) in enumerate(
              zip(res_tf_flat,
                  (output_avals
                   if output_avals is not None
                   else (None,) * len(res_tf_flat))))]
      return checked_res_tf_flat

    # Prepare a tf.function ahead of time, to cache the concrete functions. This
    # won't be used in op-by-op execution mode.
    # `jit_compile` is not enabled when `call_tf_graph` is True, since the
    # custom call function won't be compilable.
    function_flat_tf = tf.function(
        callable_flat_tf, autograph=False, jit_compile=not call_tf_graph)

    res_jax_flat = call_tf_p.bind(
        *args_flat_jax,
        # Carry the actual function such that op-by-op call can call in TF eager mode.
        callable_flat_tf=callable_flat_tf,
        function_flat_tf=function_flat_tf,
        args_flat_sig_tf=args_flat_sig_tf,
        output_avals=output_avals,
        has_side_effects=has_side_effects,
        ordered=ordered,
        call_tf_graph=call_tf_graph,
    )

    # We must have called callable_flat_tf by now
    assert res_treedef is not None
    return res_treedef.unflatten(res_jax_flat)

  # Define the fwd and bwd custom_vjp functions
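  # For example (illustrative): if `callable_tf` is `tf.math.sin`, then for a
  # primal x and an output cotangent ct, the bwd rule below replays the function
  # under `tf.GradientTape` and returns ct * cos(x), so that
  #   jax.grad(call_tf(tf.math.sin))(x) == np.cos(x)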
  def make_call_vjp_fwd(*args_jax):
    # Return the primal arguments as the residual
    return make_call(*args_jax), args_jax

  def make_call_vjp_bwd(residual_jax, ct_res_jax):
    args_jax = residual_jax  # residual is the primal argument

    def tf_vjp_fun(args_tf, ct_res_tf):
      """Invoke TF gradient."""

      # TF does not like us to watch non-float vars
      def replace_non_float(arg_tf):
        if arg_tf.dtype.is_floating or arg_tf.dtype.is_complex:
          return arg_tf
        else:
          # When watched, this will be ignored. When used in results it will
          # result in a floating 0. gradient, which JAX will ignore (and
          # replace it with a float0)
          return tf.zeros((), dtype=tf.float32)

      watched_args_tf = tf.nest.map_structure(replace_non_float, args_tf)
      with tf.GradientTape(persistent=True) as tape:
        tape.watch(watched_args_tf)
        res = callable_tf(*args_tf)

      tf.nest.assert_same_structure(res, ct_res_tf)
      dres_darg = tape.gradient(
          tf.nest.map_structure(replace_non_float, res),
          sources=watched_args_tf,
          output_gradients=ct_res_tf,
          unconnected_gradients=tf.UnconnectedGradients.ZERO)

      dres_darg = tree_util.tree_map(
          lambda x: x if x is None else tf.convert_to_tensor(x),
          dres_darg,
      )
      tf.nest.assert_same_structure(dres_darg, args_tf)
      return dres_darg

    # Use call_tf to call the VJP function
    ct_args_jax = call_tf(tf_vjp_fun)(args_jax, ct_res_jax)
    # We must make the float0s that JAX expects
    def fix_float0(arg_jax, ct_arg_jax):
      arg_dtype = dtypes.result_type(arg_jax)  # May be scalar
      ct_arg_dtype = core.primal_dtype_to_tangent_dtype(arg_dtype)
      if ct_arg_dtype != ct_arg_jax.dtype:
        return ad_util.zeros_like_aval(core.ShapedArray(np.shape(arg_jax),
                                                        ct_arg_dtype))
      return ct_arg_jax

    ct_args_jax_fixed = tree_util.tree_map(fix_float0, args_jax, ct_args_jax)
    return ct_args_jax_fixed

  make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
  return util.wraps(callable_tf)(make_call)


def check_tf_result(idx: int, r_tf: TfVal, r_aval: Optional[core.ShapedArray]) -> TfVal:
  # Check that the TF function returns values of expected types. This
  # improves error reporting, preventing hard-to-diagnose errors downstream
  try:
    jax2tf_internal._tfval_to_tensor_jax_dtype(r_tf)
  except Exception as e:
    msg = ("The called TF function returns a result that is not "
           f"convertible to JAX: {r_tf}.")
    raise ValueError(msg) from e

  if r_aval is None:
    return r_tf
  # We convert to TF type, and canonicalize to 32-bit if necessary
  r_aval_dtype_tf = jax2tf_internal._to_tf_dtype(r_aval.dtype)
  # Checking shapes is trickier in presence of dynamic shapes. I wish we could
  # check at runtime that the returned shape matches the declared shape. I wish
  # that tf.ensure_shape did this, but it can only take shapes that contain None
  # not computed shapes. However, in eager mode we should be able to resolve
  # the declared shapes to constants and we get better checking.
  if tf.executing_eagerly():
    r_aval_shape_tf = jax2tf_internal._eval_shape(r_aval.shape)
  else:
    r_aval_shape_tf = jax2tf_internal._aval_to_tf_shape(r_aval)
  # We do as much checking as we can here, instead of relying on tf.ensure_shape
  # because the latter gives different errors in eager vs. compiled mode.
  # TODO(b/279454591): This strange error comes from TF. In eager mode the
  # function is supposed to return a TF value with a concrete shape, but
  # sometimes it does not. Here we downgrade the exception to a warning and
  # bypass the check. This case needs to be revisited on the TF side.
  try:
    _ = len(r_tf.shape)
  except ValueError as e:
    msg = (
        "The shape check cannot be performed because the shape of the "
        f"`r_tf` tensor cannot be obtained. r_tf = {r_tf}, r_aval = {r_aval}. "
    )
    msg += str(e)
    logging.warning(msg)
    return r_tf
  if (r_tf.dtype != r_aval_dtype_tf or
      len(r_tf.shape) != len(r_aval_shape_tf) or
      any(r_aval_d is not None and r_tf_d is not None and r_aval_d != r_tf_d
          for r_tf_d, r_aval_d in zip(r_tf.shape, r_aval_shape_tf))):
    msg = ("The shapes or dtypes returned by the TensorFlow function "
           "do not match the declared output_shape_dtype:\n"
           f"Result[{idx}] is {r_tf.dtype}[{r_tf.shape}] vs. expected {r_aval_dtype_tf}[{r_aval_shape_tf}]")
    raise ValueError(msg)
  # At this point tf.ensure_shape does not do much, it should never throw an
  # error, albeit it may refine the shape a bit.
  return tf.ensure_shape(r_tf, r_aval_shape_tf)


call_tf_p = core.Primitive("call_tf")
|
|||
|
call_tf_p.multiple_results = True
|
|||
|
|
|||
|
# The impl will be used in op-by-op mode and calls callable_tf in TF eager mode.
|
|||
|
def _call_tf_impl(*args_jax_flat, callable_flat_tf, **_):
|
|||
|
# On GPU we use dlpack to avoid copies of data to the host.
|
|||
|
def _arg_jax_to_tf(arg_jax):
|
|||
|
if (isinstance(arg_jax, jax.Array) and
|
|||
|
list(arg_jax.devices())[0].platform in _DLPACK_PLATFORMS and
|
|||
|
arg_jax.dtype in dlpack.SUPPORTED_DTYPES):
|
|||
|
arg_dlpack = jax.dlpack.to_dlpack(arg_jax, take_ownership=False)
|
|||
|
return tf.experimental.dlpack.from_dlpack(arg_dlpack)
|
|||
|
# The following avoids copies to the host on CPU, always for DeviceArray
|
|||
|
# and even for ndarray if they are sufficiently aligned.
|
|||
|
# TODO(necula): on TPU this copies to the host!
|
|||
|
return tf.constant(np.asarray(arg_jax))
|
|||
|
|
|||
|
args_tf_flat = tuple(map(_arg_jax_to_tf, args_jax_flat))
|
|||
|
with jax2tf_internal.inside_call_tf():
|
|||
|
# Call in TF eager mode
|
|||
|
res_tf_flat = callable_flat_tf(*args_tf_flat)
|
|||
|
|
|||
|
def _res_tf_to_jax(res_tf: TfVal):
|
|||
|
res_tf, _ = jax2tf_internal._tfval_to_tensor_jax_dtype(res_tf)
|
|||
|
if isinstance(res_tf, tf.Tensor) and res_tf.dtype in dlpack.SUPPORTED_DTYPES:
|
|||
|
res_tf_platform = tf.DeviceSpec.from_string(res_tf.backing_device).device_type
|
|||
|
res_jax_platform = res_tf_platform.lower()
|
|||
|
if res_jax_platform in _DLPACK_PLATFORMS:
|
|||
|
res_dlpack = tf.experimental.dlpack.to_dlpack(res_tf)
|
|||
|
return jax.dlpack.from_dlpack(res_dlpack)
|
|||
|
|
|||
|
# When working with a bfloat16 scalar tf.Tensor,np.asarray() can fail.
|
|||
|
# To handle this special case, we create a numpy copy.
|
|||
|
if res_tf.shape == tf.TensorShape([]) and res_tf.dtype == tf.bfloat16:
|
|||
|
return jax.device_put(jnp.array(res_tf.numpy()))
|
|||
|
else:
|
|||
|
return jax.device_put(np.asarray(res_tf))
|
|||
|
|
|||
|
return list(map(_res_tf_to_jax, res_tf_flat))
|
|||
|
|
|||
|
|
|||
|
call_tf_p.def_impl(_call_tf_impl)
|
|||
|
|
|||
|
@functools.lru_cache(maxsize=128)
def _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf):  # -> tf.ConcreteFunction
  with jax2tf_internal.inside_call_tf():
    return function_flat_tf.get_concrete_function(*args_flat_sig_tf)


# Mark the effectful instances of call_tf
class CallTfEffect(effects.Effect):
  __str__ = lambda _: "CallTfEffect"

call_tf_effect = CallTfEffect()

effects.lowerable_effects.add_type(CallTfEffect)
effects.control_flow_allowed_effects.add_type(CallTfEffect)
effects.remat_allowed_effects.add_type(CallTfEffect)
effects.custom_derivatives_allowed_effects.add_type(CallTfEffect)


class CallTfOrderedEffect(effects.Effect):
  __str__ = lambda _: "CallTfOrderedEffect"


call_tf_ordered_effect = CallTfOrderedEffect()

effects.lowerable_effects.add_type(CallTfOrderedEffect)
effects.control_flow_allowed_effects.add_type(CallTfOrderedEffect)
effects.remat_allowed_effects.add_type(CallTfOrderedEffect)
effects.custom_derivatives_allowed_effects.add_type(CallTfOrderedEffect)
effects.ordered_effects.add_type(CallTfOrderedEffect)


def _call_tf_abstract_eval(
    *args_flat_avals,
    function_flat_tf,
    args_flat_sig_tf,
    has_side_effects,
    ordered,
    output_avals,
    call_tf_graph,
    **__,
):
  # Called only when we form a Jaxpr, i.e., under jit, scan, etc.
  effects = set()
  if ordered:
    effects.add(call_tf_ordered_effect)
  elif has_side_effects:
    effects.add(call_tf_effect)

  # If no output_avals is given, then we ask TF to infer the output shapes.
  # We call this even if output_avals is given because it will ensure that
  # callable_flat_tf is called. Since _get_concrete_function_tf is cached
  # there is a small cost of calling it more often than needed.
  concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf,
                                                        args_flat_sig_tf)
  # TODO(b/278298710): when `call_tf_graph=True` is used with a non-compilable
  # TF function, TensorFlow shape inference is not supported and the concrete
  # function sometimes has no structured output shape attributes. In that case
  # users always need to provide `output_shape_dtype`, except when the
  # tf.function has no return value, in which case `output_shape_dtype` must
  # be `None`.
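  # For example (illustrative), a caller would declare the expected results as
  #   call_tf(fun_tf, call_tf_graph=True,
  #           output_shape_dtype=jax.ShapeDtypeStruct((3, 4), jnp.float32))
  # where the pytree of ShapeDtypeStructs matches the structure of the results
  # of `fun_tf` (a stand-in name for the user's TF function).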
  if len(concrete_function_flat_tf.outputs) == 0:
    return tuple(), effects

  if call_tf_graph and output_avals is None:
    raise ValueError(
        "call_tf with `call_tf_graph=True` must provide output_shape_dtype"
        " arg.")
  if output_avals is not None:
    return output_avals, effects

  def is_fully_known_shape(s):
    return s.rank is not None and all([d is not None for d in s])

  if all(is_fully_known_shape(s)
         for s in concrete_function_flat_tf.output_shapes):
    avals_from_tf = tuple(
        # We convert to JAX type, and canonicalize to 32-bit if necessary
        core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype))
        for dtype, shape in zip(concrete_function_flat_tf.output_dtypes,
                                concrete_function_flat_tf.output_shapes))
    return avals_from_tf, effects

  msg = ("call_tf cannot call functions whose output has dynamic shape. "
         f"Found output shapes: {concrete_function_flat_tf.output_shapes}. "
         "Consider using the `output_shape_dtype` argument to call_tf. "
         "\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf"
         " for a discussion.")
  raise ValueError(msg)


call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval)


def _call_tf_lowering(
    ctx: mlir.LoweringRuleContext,
    *args_op,
    platform,
    function_flat_tf,
    args_flat_sig_tf,
    has_side_effects,
    ordered,
    call_tf_graph,
    output_avals,
    **_,
):
  # We use the same TF lowering device as for the embedding JAX computation.
  # One example when this is needed is when the code refers to variables on one
  # device. Or, for sharding annotations (only supported on TPU).

  if platform in ["cpu", "tpu"]:
    tf_platform = platform.upper()
  elif platform == "cuda":
    tf_platform = "GPU"
  else:
    raise ValueError(f"platform {platform} not supported")

  concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf)

  captured_inputs = []
  if concrete_function_flat_tf.captured_inputs:
    # The function uses either captured variables or tensors.
    msg = (
        "call_tf works best with a TensorFlow function that does not capture "
        "variables or tensors from the context. "
        "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion. "
        f"The following captures were found {concrete_function_flat_tf.captured_inputs}")
    logging.warning(msg)
    for inp in concrete_function_flat_tf.captured_inputs:
      if inp.dtype == tf.resource:  # A variable; lookup by handle
        inp_vars = [v for v in concrete_function_flat_tf.variables if inp is v.handle]
        assert len(inp_vars) == 1, f"Found {inp_vars}"
        captured_inputs.append(inp_vars[0])
      else:
        captured_inputs.append(inp)

  captured_ops = tuple(
      mlir.ir_constant(np.asarray(inp), canonicalize_types=False)
      for inp in captured_inputs
  )

  if call_tf_graph:
    with jax2tf_internal.inside_call_tf():
      return emit_tf_embedded_graph_custom_call(
          ctx,
          concrete_function_flat_tf,
          tuple(args_op) + captured_ops,
          has_side_effects,
          ordered,
          output_avals,
      )

  def convert_to_spec(x):
    if isinstance(x, tf.TensorSpec):
      return x
    else:
      return tf.TensorSpec.from_tensor(x)

  args_tf_flat = [convert_to_spec(a) for a in args_flat_sig_tf]

  with jax2tf_internal.inside_call_tf():
    # When the TF computation uses variables on a particular device, we must
    # get_compiler_ir for that exact device.
    tf_device_name = f"/device:{tf_platform}:0"
    try:
      func_tf_hlo = function_flat_tf.experimental_get_compiler_ir(*args_tf_flat)(
          stage="hlo_serialized", device_name=tf_device_name)
    except Exception as e:
msg = ("Error compiling TensorFlow function (see below for the caught exception)." +
|
|||
|
"\ncall_tf can used " +
|
|||
|
"in a staged context (under jax.jit, lax.scan, etc.) only with " +
|
|||
|
"compilable functions with static output shapes.\n" +
|
|||
|
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion." +
|
|||
|
"\n\nCaught TensorFlow exception: " + str(e))
|
|||
|
raise ValueError(msg) from e
|
|||
|
|
|||
|
  xla_comp = xla_client.XlaComputation(func_tf_hlo)

  # Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode
  def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray:
    if not res_shape.is_static():
msg = ("Compiled TensorFlow function has dynamic output shape " +
|
|||
|
f"{res_shape}. call_tf can used " +
|
|||
|
"in a staged context (under jax.jit, lax.scan, etc.) only with " +
|
|||
|
"compilable functions with static output shapes. " +
|
|||
|
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion.")
|
|||
|
raise ValueError(msg)
|
|||
|
|
|||
|
    res_dtype = res_shape.numpy_dtype()
    jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)
    return core.ShapedArray(res_shape.dimensions(), jax_res_dtype)

  result_shape = xla_comp.program_shape().result_shape()
  if not result_shape.is_tuple():
    # TF does not wrap singletons as tuples, but JAX expects tuples because
    # call_tf is a multiple_results primitive.
    result_shapes = (result_shape,)
  else:
    result_shapes = result_shape.tuple_shapes()  # type: ignore

  result_avals = tuple(map(canonical_res_aval, result_shapes))  # type: ignore

  submodule = mlir.xla_computation_to_mlir_module(xla_comp)
  symtab = ir.SymbolTable(submodule.operation)
  callee_result_types = symtab["main"].type.results
  fn = mlir.merge_mlir_modules(ctx.module_context.module,
                               f"call_tf_{function_flat_tf.name}",
                               submodule)
  call = func_dialect.CallOp(callee_result_types,
                             ir.FlatSymbolRefAttr.get(fn),
                             tuple(args_op) + captured_ops)
  if result_shape.is_tuple():
    flat_results = [hlo.GetTupleElementOp(call, mlir.i32_attr(i)).result
                    for i in range(len(result_shapes))]
  else:
    flat_results = call.results

  if ordered:
    raise NotImplementedError(
        "ordered=True is not supported in the jitted context without"
        " `call_tf_graph=True`"
    )

  outputs = []
  for op, res_aval, res_shape in zip(flat_results, result_avals,
                                     result_shapes):
    if res_aval.dtype != res_shape.numpy_dtype():
      op = hlo.ConvertOp(mlir.aval_to_ir_type(res_aval), op).result
    outputs.append(op)
  return outputs


def _register_call_lowering(platform):
  mlir.register_lowering(call_tf_p, functools.partial(_call_tf_lowering,
                                                      platform=platform),
                         platform=platform)
for platform in ("cpu", "cuda", "tpu"):
  _register_call_lowering(platform)

# Support the call_tf under jax2tf.convert in eager mode
def _jax2tf_call_tf(*args: TfVal,
                    callable_flat_tf: Callable,
                    **_) -> TfVal:
  with jax2tf_internal.inside_call_tf():
    res_tf_flat = callable_flat_tf(*args)
  return res_tf_flat

jax2tf_internal.tf_impl[call_tf_p] = _jax2tf_call_tf


def emit_tf_embedded_graph_custom_call(
    ctx: mlir.LoweringRuleContext,
    concrete_function_flat_tf,
    operands: Sequence[ir.Value],
    has_side_effects,
    ordered,
    output_avals,
):
  """Emits a custom call referencing a tf.Graph embedding of the TF function.

  All information about the called TF function is stored in tf.metadata.
  This includes:
  (1) The called function name: This name will be used by the runtime to execute
  the callback.
  (2) The called function index in the XLACallModule `function_list` attribute.
  """
  call_tf_concrete_function_list = jax2tf_internal.get_thread_local_state_call_tf_concrete_function_list()
  if call_tf_concrete_function_list is None:
    raise ValueError(
        "call_tf_graph=True currently only supports exporting via jax2tf.convert."
    )
  called_index = add_to_call_tf_concrete_function_list(
      concrete_function_flat_tf, call_tf_concrete_function_list)
  call_target_name = "tf.call_tf_function"
  tf_backend_config = {
      "has_token_input_output": ir.BoolAttr.get(ordered),
      "called_index": mlir.i64_attr(called_index),
  }
  result_avals = output_avals if output_avals is not None else tuple()

  operands = list(operands)
  result_types = list(
      util.flatten([mlir.aval_to_ir_types(aval) for aval in result_avals])
  )
  if ordered:
    operands.insert(0, ctx.tokens_in.get(call_tf_ordered_effect)[0])
    result_types.insert(0, mlir.token_type()[0])

  custom_call = hlo.CustomCallOp(
      result_types,
      operands,
      call_target_name=ir.StringAttr.get(call_target_name),
      has_side_effect=ir.BoolAttr.get(has_side_effects),
      api_version=mlir.i32_attr(2),
      called_computations=ir.ArrayAttr.get([]),
      backend_config=ir.StringAttr.get(""),
  )
  # Store TF metadata in unregistered attribute
  custom_call.attributes["tf.backend_config"] = ir.DictAttr.get(
      tf_backend_config
  )

  results = list(custom_call.results)
  if ordered:
    token = results.pop(0)
    ctx.set_tokens_out(mlir.TokenSet({call_tf_ordered_effect: (token,)}))

  return results


def add_to_call_tf_concrete_function_list(concrete_tf_fn: Any, call_tf_concrete_function_list: List[Any]) -> int:
  func_name = concrete_tf_fn.function_def.signature.name
  assert func_name not in [f.function_def.signature.name for f in call_tf_concrete_function_list]
  called_index = len(call_tf_concrete_function_list)
  call_tf_concrete_function_list.append(concrete_tf_fn)
  return called_index