# Copyright 2020 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Provides JAX and TensorFlow interoperation APIs.""" from functools import partial import contextlib import math import operator import os import re import threading from typing import ( Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast) import warnings from absl import logging import numpy as np import jax from jax import lax from jax import config from jax import custom_derivatives from jax import random from jax import numpy as jnp from jax import tree_util from jax import sharding from jax.experimental import maps from jax.experimental.jax2tf import shape_poly from jax.experimental.jax2tf import impl_no_xla from jax.experimental.jax2tf import jax_export from jax.interpreters import xla from jax._src import ad_checkpoint from jax._src import ad_util from jax._src import api from jax._src import api_util from jax._src import core from jax._src import dispatch from jax._src import dtypes from jax._src import linear_util as lu from jax._src import op_shardings from jax._src import sharding_impls from jax._src import pjit from jax._src import prng from jax._src import random as random_internal from jax._src import source_info_util from jax._src import util from jax._src.interpreters import ad from jax._src.lax import control_flow as lax_control_flow from jax._src.lax import lax as lax_internal from jax._src.lax import linalg as lax_linalg from jax._src.lax import slicing as lax_slicing from jax._src.lax import windowed_reductions as lax_windowed_reductions from jax._src.lib import xla_client from jax._src.numpy.ufuncs import logaddexp import tensorflow as tf # type: ignore[import] # These don't have public equivalents. # pylint: disable=g-direct-tensorflow-import from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import] from tensorflow.compiler.xla import xla_data_pb2 # type: ignore[import] from tensorflow.core.framework import attr_value_pb2 # type: ignore[import] try: from tensorflow.python.compiler.xla.experimental import xla_sharding # type: ignore[import] except ModuleNotFoundError: # This can be removed when TF 2.10 support is no longer needed. from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding # type: ignore[import] from tensorflow.python.framework import ops as tf_ops # type: ignore[import] from tensorflow.python.eager import context as tf_context # type: ignore[import] # pylint: enable=g-direct-tensorflow-import NameStack = source_info_util.NameStack PolyShape = shape_poly.PolyShape DType = Any # A temporary internal flag, to enable the wrapping of jax.jit functions # with tf.function(jit_compile=True). See #7389. This change has triggered a # number of failures in TF. We keep this until we are confident that it does # not create problems. # TODO(b/207464757): figure out why this change breaks test _WRAP_JAX_JIT_WITH_TF_FUNCTION = False # The scope name need to be a valid TensorFlow name. See # https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/core/framework/node_def_util.cc#L731 _VALID_SCOPE_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*$") _INVALID_SCOPE_CHAR = re.compile("[^A-Za-z0-9_.\\/-]") map = util.safe_map zip = util.safe_zip def _sanitize_scope_name(name): scope_name = _INVALID_SCOPE_CHAR.sub("_", name) if not _VALID_SCOPE_REGEX.match(scope_name): scope_name = f".{scope_name}" return scope_name # A value suitable in a TF tracing context: tf.Tensor, tf.Variable, # or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.) TfVal = Any PrecisionType = int # Enum xla_data.PrecisionConfig.Precision def _is_tfval(v: TfVal) -> bool: if isinstance(v, (tf.Tensor, tf.Variable)): return True try: # Include all convertible types, even if not supported on accelerators. with tf.device("CPU"): tf.constant(v) return True except: return False # The implementation rules for primitives. The rule will be called with the # arguments (TfVal) and must return TfVal (or a sequence thereof, # if primitive.multiple_results). The exception are primarily the # control-flow primitives. tf_impl: Dict[core.Primitive, Callable[..., Any]] = {} # Some primitive implementation rules need the abstract values of arguments # and the results. This is the case for the primitives implemented using # _convert_jax_impl and those that need to adjust the shape of the outputs # due to missing TF shape inference rules for TFXLA ops. The rules for these # primitives should be added to `tf_impl_with_avals`. # The abstract value are passed to the implementation as two special kwargs # `_in_avals` (a tuple of core.ShapedArray) and `_out_aval` (a # core.ShapedArray, or a tuple thereof when primitive.multiple_results). tf_impl_with_avals: Dict[core.Primitive, Callable[..., Any]] = {} # XLA is not linked in all environments when converting a primitive. If this is # the case, we first search for implementation rules for primitives in the # following map. These implementations are workarounds, making use of TF ops # that do work when XLA is not linked in. tf_impl_no_xla = impl_no_xla.tf_impl_no_xla # In order to ensure that JAX picks up the proper user-frame for source # locations we will register the TensorFlow source path as an internal # path with source_info_util. The typical stack when a JAX primitive # conversion happens is: # jax2tf.process_primitive (top of stack) # jax tracing machinery ... # tf.custom_gradient machinery ... # jax2tf.converted_fun # tf function machinery ... # user code invokes the converted function on TF tensors # # We need to skip over not only JAX internal frames, but TF internal frames # also. # We register the TensorFlow source path lazily _has_registered_tf_source_path = False class _ThreadLocalState(threading.local): def __init__(self): # XLA is not linked in all environments; when converting a primitive, if this # variable is disabled, we try harder to use only standard TF ops if they are # applicable to the concrete use case; if the resulting conversion path ends up # requiring a TFXLA operation, an exception is thrown instead. self.enable_xla = True # Keep track if we are inside a call_tf. In that context we disable the # safety check that we are not inside JAX transformations. self.inside_call_tf = False # Maps dimension variables to TF expressions, for non-native lowering self.shape_env: Sequence[Tuple[str, TfVal]] = () # Whether to actually include XLA op metadata in the generated TF ops # TODO(b/189306134): implement support for XLA metadata self.include_xla_op_metadata = False # A cache for the tf.convert_to_tensor for constants. We try to preserve # sharing for constants, to enable tf.Graph to take advantage of it. # See https://github.com/google/jax/issues/7992. self.constant_cache = None # None means that we don't use a cache. We # may be outside a conversion scope. # A cache for the outside tf name_scope when the converted # function is running. We will add this as the prefix to the generated tf op # name. For example, the tf op name will be like # "{tf_outer_name_scope}/JAX_NAME_STACKS" self.tf_outer_name_scope = "" # A dict collecting all tf concrete_functions called by stablehlo.custom_call # This is used only by native serialization (unlike all the other # thread-local state). self.call_tf_concrete_function_list: Optional[List[Any]] = None _thread_local_state = _ThreadLocalState() def _get_current_name_stack() -> Union[NameStack, str]: return source_info_util.current_name_stack() @contextlib.contextmanager def inside_call_tf(): # Set the inside_call_tf flag for a context. prev = _thread_local_state.inside_call_tf _thread_local_state.inside_call_tf = True try: yield finally: _thread_local_state.inside_call_tf = prev def get_thread_local_state_call_tf_concrete_function_list() -> ( Optional[List[Any]] ): return _thread_local_state.call_tf_concrete_function_list @partial(api_util.api_hook, tag="jax2tf_convert") def convert(fun_jax: Callable, *, polymorphic_shapes=None, with_gradient=True, enable_xla=True, # TODO(necula): remove the experimental flag experimental_native_lowering="default", native_serialization="default", native_serialization_platforms=(), native_serialization_strict_checks=True) -> Callable: """Allows calling a JAX function from a TensorFlow program. See [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md) for more details about usage and common problems. Args: fun_jax: target JAX function to be called. Its arguments and return value should be JAX arrays, or nested standard Python containers (tuple/list/dict) thereof (pytrees). polymorphic_shapes: Specifies input shapes to be treated polymorphically during lowering. .. warning:: The shape-polymorphic lowering is an experimental feature. It is meant to be sound, but it is known to reject some JAX programs that are shape polymorphic. The details of this feature can change. It should be `None` (all arguments are monomorphic), a single PolyShape or string (applies to all arguments), or a tuple/list of the same length as the function arguments. For each argument the shape specification should be `None` (monomorphic argument), or a Python object with the same pytree structure as the argument. See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). A shape specification for an array argument should be an object `PolyShape(dim0, dim1, ..., dimn)` where each `dim` is a dimension specification: a positive integer denoting a monomorphic dimension of the given size, or a string denoting a dimension variable assumed to range over non-zero dimension sizes, or the special placeholder string "_" denoting a monomorphic dimension whose size is given by the actual argument. As a shortcut, an Ellipsis suffix in the list of dimension specifications stands for a list of "_" placeholders. For convenience, a shape specification can also be given as a string representation, e.g.: "batch, ...", "batch, height, width, _", possibly with surrounding parentheses: "(batch, ...)". The lowering fails if it cannot ensure that the it would produce the same sequence of TF ops for any non-zero values of the dimension variables. polymorphic_shapes are only supported for positional arguments; shape polymorphism is not supported for keyword arguments. See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) for more details. with_gradient: if set (default), add a tf.custom_gradient to the lowered function, by converting the ``jax.vjp(fun)``. This means that reverse-mode TensorFlow AD is supported for the output TensorFlow function, and the value of the gradient will be JAX-accurate. enable_xla: if set (default), use the simplest conversion and use XLA TF ops when necessary. These ops are known to create issues for the TFLite and TFjs converters. For those cases, unset this parameter so the lowering tries harder to use non-XLA TF ops to lower the function and aborts if this is not possible. Cannot be set to `False` when using `native_serialization`. native_serialization: serialize the JAX function natively to StableHLO with compatibility guarantees. This makes it easier to have confidence that the code executed when calling this function from TensorFlow is exactly the same as JAX would run natively. The "default" value defers to `False` if `enable_xla` is set to `False` or to the configuration flag `--jax2tf_default_native_serialization` otherwise. Native serialization cannot be used with `enable_xla=False`. native_serialization_platforms: In conjunction with `native_serialization`, specify the platform(s) for which to lower the code. Must be a tuple of strings, including a subset of: 'cpu', 'cuda', 'rocm', 'tpu'. The default (empty tuple), specifies the JAX default backend on the machine where the lowering is done. native_serialization_strict_checks: In conjunction with `native_serialization`, enable the following checks: (A) the lowered computation is executed on a platform for which it was lowered; (B) the serialized computation contains only custom calls with targets that are guaranteed to be stable, (more to come). Returns: A version of `fun_jax` that expects TfVals as arguments (or tuple/lists/dicts thereof), and returns TfVals as outputs, and uses only TensorFlow ops and thus can be called from a TensorFlow program. """ if native_serialization == "default": if not enable_xla: native_serialization = False else: native_serialization = config.jax2tf_default_native_serialization if native_serialization and not enable_xla: raise ValueError( "native_serialization is not supported with enable_xla=False") if native_serialization_platforms: if not native_serialization: warnings.warn( "using native_serialization_platforms without native_serialization. " "The parameter will have no effect, since the same code is serialized " "for all platforms without native_serialization.") if (not isinstance(native_serialization_platforms, (list, tuple)) or not all(p in ["tpu", "cpu", "gpu"] for p in native_serialization_platforms)): raise ValueError( "native_serialization_platforms must be a sequence " "containing a subset of {'cpu', 'gpu', 'tpu'}. " f"Got: {native_serialization_platforms}") native_serialization_platforms = tuple(native_serialization_platforms) if len(native_serialization_platforms) > 1: raise NotImplementedError( "native_serialization_platforms is not yet implemented for multiple platforms") api.check_callable(fun_jax) def converted_fun_tf(*args_tf: TfVal, **kwargs_tf: TfVal) -> TfVal: # TODO: is there a better way to check if we are inside a transformation? if not core.trace_state_clean() and not _thread_local_state.inside_call_tf: # It is Ok to nest convert when we are inside a call_tf raise ValueError( "convert must be used outside all JAX transformations." + f"Trace state: {core.thread_local_state.trace_state.trace_stack}") global _has_registered_tf_source_path if not _has_registered_tf_source_path: source_info_util.register_exclusion(os.path.dirname(tf.__file__)) _has_registered_tf_source_path = True def shape_and_dtype_tf(a: TfVal) -> Tuple[Sequence[Optional[int]], DType]: # The shape and JAX dtype for a TF argument tf_arg_shape = np.shape(a) # Fix the shape for TF1 tf_arg_shape = tuple(d.value if isinstance(d, tf.compat.v1.Dimension) else d for d in tf_arg_shape) _, a_jax_dtype = _tfval_to_tensor_jax_dtype(a) return tf_arg_shape, a_jax_dtype args_specs = jax_export.poly_specs(args_tf, polymorphic_shapes=polymorphic_shapes, get_shape_and_dtype=shape_and_dtype_tf) # The polymorphic_shapes argument refers to positional arguments only. # We assume None for the kwargs. kwargs_specs = jax_export.poly_specs(kwargs_tf, polymorphic_shapes=None, get_shape_and_dtype=shape_and_dtype_tf) combined_args_tf = (args_tf, kwargs_tf) args_flat_tf: Sequence[TfVal] args_flat_tf, args_kwargs_tree = tree_util.tree_flatten(combined_args_tf) args_flat_tf = tuple( map(preprocess_arg_tf, range(len(args_flat_tf)), args_flat_tf)) impl: SerializationImpl if native_serialization: impl = NativeSerializationImpl( fun_jax, args_specs=args_specs, kwargs_specs=kwargs_specs, native_serialization_platforms=native_serialization_platforms, native_serialization_strict_checks=native_serialization_strict_checks) else: impl = GraphSerializationImpl( fun_jax, args_specs=args_specs, kwargs_specs=kwargs_specs, args_flat_tf=args_flat_tf, enable_xla=enable_xla) try: impl.before_conversion() outs_tree: tree_util.PyTreeDef = None # type: ignore if with_gradient: @tf.custom_gradient def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal: nonlocal outs_tree outs_tf, outs_avals, outs_tree = impl.run_fun_tf(args_flat_tf) return (tuple(outs_tf), _make_custom_gradient_fn_tf( impl=impl, args_tf=args_flat_tf, outs_avals=outs_avals, outs_tf=outs_tf)) outs_flat_tf = converted_fun_flat_with_custom_gradient_tf(*args_flat_tf) else: outs_tf, _, outs_tree = impl.run_fun_tf(args_flat_tf) message = ("The jax2tf-converted function does not support gradients. " "Use `with_gradient` parameter to enable gradients") # We use PreventGradient, which is propagated through a SavedModel. outs_flat_tf = [ tf.raw_ops.PreventGradient(input=o, message=message) for o in outs_tf ] finally: impl.after_conversion() outs_flat_tf = [tf.identity(x, "jax2tf_out") for x in outs_flat_tf] out_tf = tree_util.tree_unflatten(outs_tree, outs_flat_tf) return out_tf return converted_fun_tf class SerializationImpl: """Implementation details for jax2tf serialization. Abstract superclass for subclassing. """ def before_conversion(self): """Called in the resulting TF function, before any other method. Useful to set any global context.""" raise NotImplementedError def after_conversion(self): """Called in the resulting TF function, after conversion is done. Useful to restore any global context set up by `before_conversion`.""" raise NotImplementedError def run_fun_tf(self, args_flat_tf: Sequence[TfVal] ) -> Tuple[Sequence[TfVal], Sequence[core.ShapedArray], tree_util.PyTreeDef]: """Runs the resulting TF function. Args: args_flat_tf: a flat tuple of tf.Tensor arguments Returns: a tuple with: outs_tfs: a flat tuple of tf.Tensor results outs_avals: a flat tuple of JAX abstract values for the underlying JAX function. outs_tree: the PyTreeDef for the outputs """ raise NotImplementedError def run_vjp_fun_tf(self, vjp_args_flat_tf: Sequence[TfVal], outs_avals: Sequence[core.AbstractValue]) -> Sequence[TfVal]: """Runs the VJP function as a TF function. Args: vjp_args_flat_tf: the flattened sequence of tf.Tensor, including the primal arguments followed by the output cotangents. outs_avals: the flattened primal outputs avals Returns: the flattened sequence of input cotangents. """ raise NotImplementedError class NativeSerializationImpl(SerializationImpl): def __init__(self, fun_jax, *, args_specs, kwargs_specs, native_serialization_platforms: Sequence[str], native_serialization_strict_checks: bool): self.fun_jax = fun_jax self.args_specs = args_specs self.kwargs_specs = kwargs_specs self.native_serialization_strict_checks = native_serialization_strict_checks if native_serialization_platforms: self.lowering_platform: Optional[str] = native_serialization_platforms[0] else: self.lowering_platform = None def before_conversion(self): _prev_func_list = _thread_local_state.call_tf_concrete_function_list _thread_local_state.call_tf_concrete_function_list = [] def _restore_context(): _thread_local_state.call_tf_concrete_function_list = _prev_func_list self._restore_context = _restore_context self.exported = jax_export.export( self.fun_jax, lowering_platform=self.lowering_platform, strict_checks=self.native_serialization_strict_checks )(*self.args_specs, **self.kwargs_specs) def after_conversion(self): self._restore_context() def run_fun_tf(self, args_flat_tf: Sequence[TfVal] ) -> Tuple[Sequence[TfVal], Sequence[core.ShapedArray], tree_util.PyTreeDef]: results = _run_exported_as_tf(args_flat_tf, self.exported) return results, tuple(self.exported.out_avals), self.exported.out_tree def run_vjp_fun_tf(self, vjp_args_flat_tf: Sequence[TfVal], outs_avals: Sequence[core.AbstractValue]) -> Sequence[TfVal]: del outs_avals exported_vjp = self.exported.vjp() vjp_args_flat_tf = tuple(tf.identity(arg, f"jax2tf_arg_{arg_idx}") for arg_idx, arg in enumerate(vjp_args_flat_tf)) in_cts_flat = _run_exported_as_tf(vjp_args_flat_tf, exported_vjp) return tuple(tf.identity(arg, "jax2tf_out") for arg in in_cts_flat) class GraphSerializationImpl(SerializationImpl): def __init__(self, fun_jax, *, args_specs, kwargs_specs, args_flat_tf: Sequence[TfVal], enable_xla: bool): self.fun_jax = fun_jax self.args_specs = args_specs self.kwargs_specs = kwargs_specs self.enable_xla = enable_xla fun_name = getattr(fun_jax, "__name__", "unknown") name_stack = util.wrap_name(fun_name, "jax2tf") self.name_stack = name_stack self.args_flat_tf = args_flat_tf def before_conversion(self): prev_enable_xla = _thread_local_state.enable_xla prev_include_xla_op_metadata = _thread_local_state.include_xla_op_metadata prev_tf_outer_name_scope = _thread_local_state.tf_outer_name_scope def _restore_context(): _thread_local_state.enable_xla = prev_enable_xla _thread_local_state.include_xla_op_metadata = prev_include_xla_op_metadata _thread_local_state.tf_outer_name_scope = prev_tf_outer_name_scope _thread_local_state.shape_env = () self._restore_context = _restore_context _thread_local_state.enable_xla = self.enable_xla # TODO(b/189306134): implement support for XLA metadata _thread_local_state.include_xla_op_metadata = False _thread_local_state.tf_outer_name_scope = tf.get_current_name_scope() assert not _thread_local_state.shape_env, f"Unexpected shape environment {_thread_local_state.shape_env}" args_specs_flat, self.in_tree = tree_util.tree_flatten( (self.args_specs, self.kwargs_specs)) self.args_avals_flat = tuple( map(lambda a: core.raise_to_shaped(core.get_aval(a)), args_specs_flat)) dim_vars = shape_poly.all_dim_vars(self.args_avals_flat) dim_values, _ = _interpret_fun_jax( partial(shape_poly.compute_dim_vars_from_arg_shapes, self.args_avals_flat, args_kwargs_tree=self.in_tree), self.args_flat_tf, self.args_avals_flat, self.name_stack) _thread_local_state.shape_env = zip(dim_vars, dim_values) fun_flat_jax, out_tree_thunk = flatten_fun_jax(self.fun_jax, self.in_tree) # out_tree_thunk will be ready after we call run_fun_tf below. self.fun_flat_jax = fun_flat_jax self.out_tree_thunk = out_tree_thunk def after_conversion(self): self._restore_context() def run_fun_tf(self, args_flat_tf: Sequence[TfVal] ) -> Tuple[Sequence[TfVal], Sequence[core.ShapedArray], tree_util.PyTreeDef]: outs_tf, outs_avals = _interpret_fun_jax( self.fun_flat_jax, args_flat_tf, self.args_avals_flat, self.name_stack, fresh_constant_cache=True) return outs_tf, outs_avals, self.out_tree_thunk() def run_vjp_fun_tf(self, vjp_args_flat_tf: Sequence[TfVal], outs_avals: Sequence[core.AbstractValue]) -> Sequence[TfVal]: def fun_vjp_jax(*args_and_out_cts_flat_jax): # Takes a flat list of primals and output cotangents args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax, [len(self.args_avals_flat)]) _, pullback_jax = jax.vjp(self.fun_flat_jax, *args_flat_jax) return pullback_jax(out_cts_flat_jax) vjp_in_avals = tuple(self.args_avals_flat) + tuple(outs_avals) vjp_polymorphic_shapes = tuple(str(a.shape) # Note: may be _DimExpr, not just DimVar for a in vjp_in_avals) # type: ignore return convert( fun_vjp_jax, with_gradient=False, polymorphic_shapes=vjp_polymorphic_shapes, native_serialization=False)(*vjp_args_flat_tf) def dtype_of_val(val: TfVal) -> DType: """Computes the TensorFlow dtype using JAX's typing rules. If the value is a tf.Tensor, it starts with its dtype. If the value is a constant it uses JAX to infer its dtype. The resulting dtype follows the JAX type inference rules, and depends on the value of the JAX_ENABLE_X64 flag. See README.md for how 64-bit values are treated. """ tval, _ = _tfval_to_tensor_jax_dtype(val) return tval.dtype @partial(api_util.api_hook, tag="jax2tf_eval_polymorphic_shapes") def eval_polymorphic_shape(fun_jax: Callable, *, polymorphic_shapes=None) -> Callable: """Evaluates the output shape in presence of shape polymorphism. This is done without lowering or executing the function, same as for `jax.eval_shape`. Args: fun_jax: target JAX function to be called. Its arguments and return value should be JAX arrays, or nested standard Python containers (tuple/list/dict) thereof (pytrees). polymorphic_shapes: Specifies input shapes to be treated polymorphically during shape evaluation. See discussion for `jax2tf.convert`. .. warning:: The shape-polymorphic lowering is an experimental feature. Returns: a function that takes `jax.ShapeDtypeStruct`s (or any values with `.shape` and `.dtype` attributes) corresponding to the inputs for `fun_jax`, and returns a tuple with: * the jax.ShapeDtypeStruct corresponding to the result, as for `jax.eval_shape`. The shape may contain symbolic dimension expressions. * the value that can be passed to `polymorphic_shapes` for a subsequent call to `jax2tf.eval_polymorphic_shape`, or `jax2tf.convert`. For example: >>> import jax >>> from jax.experimental import jax2tf >>> from jax import numpy as jnp >>> >>> f = lambda A, x: jnp.sin(jnp.dot(A, x)) >>> A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32) >>> x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32) >>> out_spec, out_poly_shape = jax2tf.eval_polymorphic_shape(f, polymorphic_shapes=["a, b", "b, c"])(A, x) >>> print(out_spec.shape) ("a", "c") >>> print(out_poly_shape) (a, c) >>> res_spec, res_poly_shape = jax2tf.eval_polymorphic_shape(lambda x: x.T, polymorphic_shapes=[out_poly_shape])(out_spec) >>> print(res_poly_shape) (c, a) """ def do_eval_polymorphic_shape(*args_specs) -> Any: args_poly_specs = jax_export.poly_specs( args_specs, polymorphic_shapes=polymorphic_shapes) res_poly_spec = jax.eval_shape(fun_jax, *args_poly_specs) # TODO(necula): For now we export the polymorphic shapes using `str`. res_polymorphic_shape = tree_util.tree_map(lambda r: str(r.shape), res_poly_spec) return res_poly_spec, res_polymorphic_shape return do_eval_polymorphic_shape # Internals def flatten_fun_jax(fun_jax: Callable, in_tree, ) -> Tuple[Callable, Callable]: """Wraps the function to take a (flat) list of positional args. jax2tf works better and is simpler when the JAX function takes and returns just a tuple of values (no pytrees, no kwargs). This is in part because jax.vjp does not support kwargs and we can only set tf.custom_gradient on functions with flat arguments and results Returns: * the wrapped JAX function taking and returning a flat list of arguments * a thunk that can be called after the wrapped function has been called to return the output pytree. """ out_tree_ref = None def fun_flat_jax(*args_flat_jax): tree_args, tree_kwargs = tree_util.tree_unflatten(in_tree, args_flat_jax) tree_res = fun_jax(*tree_args, **tree_kwargs) res_flat_jax, out_tree = tree_util.tree_flatten(tree_res) nonlocal out_tree_ref assert out_tree_ref is None or out_tree_ref == out_tree out_tree_ref = out_tree return res_flat_jax return fun_flat_jax, lambda: out_tree_ref def preprocess_arg_tf(arg_idx: int, arg_tf: TfVal) -> TfVal: """Pre-processes the TF args. Returns: a tuple with the pre-processed TF arg, the TF shape, and the JAX dtype. """ if not _is_tfval(arg_tf): msg = (f"Argument {arg_tf} of type {type(arg_tf)} of jax2tf.convert(f) should " "be NumPy array, scalar, tf.Variable, or tf.Tensor") raise TypeError(msg) # May cast the args_flat to JAX types, using JAX's interpretation # of types of constants. arg_tf, _ = _tfval_to_tensor_jax_dtype(arg_tf) # Name input tensors; do this after we have cast the arguments arg_tf = tf.identity(arg_tf, f"jax2tf_arg_{arg_idx}") return arg_tf def _make_custom_gradient_fn_tf(*, impl: SerializationImpl, args_tf: Sequence[TfVal], outs_avals: Sequence[core.ShapedArray], outs_tf: Sequence[TfVal]): """Prepares the TF function to be used with tf.custom_gradient. Args: impl: the serialization implementation details args_tf: the flattened TF arguments of the primal function outs_avals: the flattened output JAX abstract values of the primal function outs_tf: the flattened TF outputs of the primal function """ def grad_fn_tf(*out_cts_flat_tf: TfVal, variables=None): if variables: raise ValueError( "Unexpected variables used in forward pass. " "This should not happen for first-order differentiation. " f"{variables=}") # TODO: enable higher-order gradients with tf.name_scope("jax2tf_vjp"): def fix_out_ct(out_ct_tf, out_ct_aval: core.ShapedArray, out_tf: TfVal): # If the primal function has outputs of integer or bool types, and if we are # under a tf.function context, then TF will pass None in _out_cts_flat # in place of these values. We should change these to float0 or # else JAX gets unhappy. See issue #6975. if out_ct_tf is not None: return out_ct_tf assert core.primal_dtype_to_tangent_dtype(out_ct_aval.dtype) == dtypes.float0, f"{out_ct_tf=}" # Note that out_ct_aval.shape contains dimension variable from the # primal function scope. We use tf.zeros_like to make a 0 of the right shape. return tf.zeros_like(out_tf, dtype=_tf_np_dtype_for_float0) out_cts_fixed_flat_tf = tuple(map(fix_out_ct, out_cts_flat_tf, outs_avals, outs_tf)) vjp_args_flat_tf = tuple(args_tf) + out_cts_fixed_flat_tf in_cts_flat = impl.run_vjp_fun_tf(vjp_args_flat_tf, outs_avals) # We do not need to fix the in_cts because the TF gradient machinery # will adjust the unconnected gradients and those for integer types. return in_cts_flat return grad_fn_tf @contextlib.contextmanager def _extended_name_stack(extra_name_stack: Optional[str]): name_ctx = (source_info_util.extend_name_stack(extra_name_stack) if extra_name_stack else contextlib.nullcontext()) with name_ctx: yield return def _interpret_fun_jax( fun_jax: Callable, args_tf: Sequence[TfVal], args_avals: Sequence[core.ShapedArray], extra_name_stack: Optional[str], fresh_constant_cache: bool = False, ) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]: with core.new_base_main(TensorFlowTrace) as main: # type: ignore subtrace_fun = _interpret_subtrace(lu.wrap_init(fun_jax), main, args_avals) with _extended_name_stack(extra_name_stack): with core.new_sublevel(): out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \ _call_wrapped_with_new_constant_cache(subtrace_fun, args_tf, fresh_constant_cache=fresh_constant_cache) del main return util.unzip2(out_vals) def _run_exported_as_tf(args_flat_tf: Sequence[TfVal], exported: jax_export.Exported, ) -> Sequence[TfVal]: """Runs the `exported` as an XlaCallModule TF op. Returns: the flattened tuple of results. """ args_avals = exported.in_avals # TF values may be integer types for float0 def _convert_value(val, aval): # Check the shape assert all(d_aval == d_val for d_aval, d_val in zip(aval.shape, val.shape) if core.is_constant_dim(d_aval)), (aval, val) conversion_dtype = _to_tf_dtype(aval.dtype) if conversion_dtype != aval.dtype: return tf.cast(val, conversion_dtype) else: return val args_flat_tf = tuple(map(_convert_value, args_flat_tf, args_avals)) out_shapes_tf = tuple( tuple(d if core.is_constant_dim(d) else None for d in out_aval.shape) for out_aval in exported.out_avals) out_types = tuple(_to_tf_dtype(out_aval.dtype) for out_aval in exported.out_avals) kept_args_avals = [aval for i, aval in enumerate(exported.in_avals) if i in exported.module_kept_var_idx] kept_args_flat_tf = [atf for i, atf in enumerate(args_flat_tf) if i in exported.module_kept_var_idx] call_module_attrs = dict( version=exported.xla_call_module_version, Tout=out_types, Sout=out_shapes_tf, function_list=[ concrete_fn.function_def.signature.name for concrete_fn in _thread_local_state.call_tf_concrete_function_list ] if _thread_local_state.call_tf_concrete_function_list is not None else [], ) if exported.xla_call_module_version >= 3: if exported.strict_checks: call_module_attrs["platforms"] = (exported.lowering_platform.upper(),) else: call_module_attrs["platforms"] = () # No platform checking if logging.vlog_is_on(3): # We already logged the MLIR module when we exported it. logging.vlog(3, "XlaCallModule %s", str(call_module_attrs)) call_module_attrs["module"] = exported.mlir_module_serialized # Apply the shardings on arguments and results for pjit. This is redundant # because the mlir_module_text will already contain the shardings, but it # makes it easier for tools like the TPU inference converter to see the # sharding without digging into the `module` attribute of the `XlaCallModule` # op, in the same way as it is done for the legacy jax2tf conversion. # Do not apply XlaSharding for REPLICATED, on inputs and outputs. # This is an agreed convention, and also improves usability under TF eager. # See b/255511660. if exported.in_shardings is not None: args_flat_tf = tuple( map(partial(_shard_value, skip_replicated_sharding=tf.executing_eagerly()), kept_args_flat_tf, kept_args_avals, exported.in_shardings)) res = tfxla.call_module(args_flat_tf, **call_module_attrs) # TODO(b/278940799): Replace the TF v1 API with public TF2 API. # Add the custom call tf.function into the default graph, so those functions # will be available during tf.SavedModel.save. if _thread_local_state.call_tf_concrete_function_list is not None: for concrete_fn in _thread_local_state.call_tf_concrete_function_list: tf.compat.v1.get_default_graph()._add_function_recursive( concrete_fn._inference_function ) if exported.out_shardings is not None: res = list(map(partial(_shard_value, skip_replicated_sharding=tf.executing_eagerly()), res, exported.out_avals, exported.out_shardings)) res = tuple(map(_convert_value, res, exported.out_avals)) return res def _call_wrapped_with_new_constant_cache(fun: lu.WrappedFun, in_vals: Sequence[TfVal], fresh_constant_cache: bool = False ) -> Sequence[Tuple[TfVal, core.ShapedArray]]: try: prev_constant_cache = _thread_local_state.constant_cache # Start a new cache, so that we don't share constants across tf.function # boundaries. if fresh_constant_cache: _thread_local_state.constant_cache = {} else: prev_constant_cache_keys = set(prev_constant_cache.keys()) if prev_constant_cache is not None else set() out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \ fun.call_wrapped(*in_vals) finally: if (not fresh_constant_cache and prev_constant_cache is not None and _WRAP_JAX_JIT_WITH_TF_FUNCTION): newly_added_keys = set(prev_constant_cache.keys()) - prev_constant_cache_keys # Delete the newly added keys for k in newly_added_keys: del prev_constant_cache[k] _thread_local_state.constant_cache = prev_constant_cache return out_vals def _convert_jax_impl(impl_jax: Callable, *, multiple_results=True, with_physical_avals=False, extra_name_stack: Optional[str] = None) -> Callable: """Convert the JAX implementation of a primitive. Args: impl_jax: typically the impl-rule for a primitive, with signature `(*args_jax: JaxVal, **kwargs) -> Sequence[JaxVal]`. This function implements a primitive in terms of other primitives. multiple_results: whether `impl_jax` returns a sequence of results. extra_name_stack: additional element to add to the name stack for the converted ops. Returns: a function with signature `(*args_tf: TfVal, _in_avals, _out_aval, **kwargs) -> Sequence[TfVal]`. """ def wrapped_tf(*args_tf: TfVal, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray, **kwargs) -> Sequence[TfVal]: if with_physical_avals: _in_avals = map(_jax_physical_aval, _in_avals) _out_aval = _jax_physical_aval(_out_aval) # We wrap the impl_jax to always return a tuple of results. def impl_multiple_results_jax(*args_jax): results_jax = impl_jax(*args_jax, **kwargs) return results_jax if multiple_results else [results_jax] results_tf, _ = _interpret_fun_jax( impl_multiple_results_jax, args_tf, _in_avals, extra_name_stack) return results_tf if multiple_results else results_tf[0] return wrapped_tf @lu.transformation def _interpret_subtrace(main: core.MainTrace, in_avals: Sequence[core.ShapedArray], *in_vals: TfVal): trace = TensorFlowTrace(main, core.cur_sublevel()) in_tracers = tuple( TensorFlowTracer(trace, val, aval) for val, aval in zip(in_vals, in_avals)) outs = yield in_tracers, {} # type: Sequence[TfVal] out_tracers: Iterable[TensorFlowTracer] = ( map(trace.full_raise, outs)) # type: ignore out_vals_with_avals: Sequence[Tuple[TfVal, core.ShapedArray]] = ( tuple((t.val, t.aval) for t in out_tracers)) yield out_vals_with_avals def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args_tf: TfVal, extra_name_stack: Optional[str], fresh_constant_cache: bool = True) -> Sequence[TfVal]: """Evaluates a Jaxpr with tf.Tensor arguments. This is most often used as the body of a tf.function, or tf.switch_case, in which case it should use a fresh constant cache. The output is a sequence of TfVal, suitable for use with TF. """ outs_tf, _ = _interpret_fun_jax(core.jaxpr_as_fun(jaxpr), args_tf, jaxpr.in_avals, extra_name_stack, fresh_constant_cache=fresh_constant_cache) return outs_tf def _jax_physical_aval(aval: core.ShapedArray) -> core.ShapedArray: """Converts JAX avals from logical to physical, if relevant. JAX might have avals whose logical vs physical shape/dtype may differ, and only the physical view is expected to possibly relate to TF. TF impl rules should operate on the physical form. A JAX logical aval might even correspond, in principle, to several physical avals, but we don't support those here. Instead we assert there is only one and return it. """ physical_aval = core.physical_aval(aval) assert (len(physical_aval.shape) >= len(aval.shape) and physical_aval.shape[:len(aval.shape)] == aval.shape), (physical_aval, aval) return physical_aval def _jax_physical_dtype(dtype): # assuming () is a fine stand-in shape return _jax_physical_aval(core.ShapedArray((), dtype)).dtype def _aval_to_tf_shape(aval: core.ShapedArray) -> Tuple[Optional[int], ...]: """Generate a TF shape, possibly containing None for polymorphic dimensions.""" aval = _jax_physical_aval(aval) return tuple(map(lambda d: None if shape_poly.is_poly_dim(d) else d, aval.shape)) # type: ignore[attr-defined] # In the TF world, we represent float0 as zeros of this type. # We pick bool because this is what JAX uses when it lowers float0 to HLO. _tf_np_dtype_for_float0 = np.bool_ def _to_tf_dtype(jax_dtype): # Note that converting _to_tf_dtype and _to_jax_dtype are not inverses, # due to float0 and 64-bit behavior. try: jax_dtype = _jax_physical_dtype(jax_dtype) except TypeError: # `jax_dtype` isn't actually a valid jax dtype (e.g. it is # tf.float32), so there is no physical dtype anyway pass if jax_dtype == dtypes.float0: jax_dtype = _tf_np_dtype_for_float0 return tf.dtypes.as_dtype(jax_dtype) def _to_jax_dtype(tf_dtype): # Note that converting _to_tf_dtype and _to_jax_dtype are not inverses, # due to float0 and 64-bit behavior. dt = dtypes.canonicalize_dtype(tf_dtype.as_numpy_dtype) if dt not in dtypes._jax_dtype_set: raise TypeError(f"dtype {dt} is not a valid JAX array " "type. Only arrays of numeric types are supported by JAX.") return dt def _tfval_to_tensor_jax_dtype(val: TfVal, jax_dtype: Optional[DType] = None, memoize_constants=False) -> Tuple[TfVal, DType]: """Converts a scalar, ndarray, or tf.Tensor to a tf.Tensor with proper type. If `jax_dtype` is missing, uses JAX typing rules. See README.md for details regarding 64-bit values. Args: val: a scalar, ndarray, tf.Tensor, or tf.Variable jax_dtype: an optional dtype to use. If missing, uses JAX type inference rules for constants. memoize_constants: whether to memoize TF constants. We can't do this everywhere, we may be outside of a conversion scope. Returns: a tuple with a tf.Tensor with the type as needed by JAX, and the JAX type. """ if isinstance(val, (tf.Tensor, tf.Variable)): jax_dtype = jax_dtype or _to_jax_dtype(val.dtype) # Give JAX a chance to pick the type conversion_dtype = _to_tf_dtype(jax_dtype) if conversion_dtype != val.dtype: # May need to cast for 64-bit values return tf.cast(val, conversion_dtype), jax_dtype else: return val, jax_dtype else: # A constant jax_dtype = jax_dtype or xla.abstractify(val).dtype # TODO(document): We assume that the value of a constant does not # change through the scope of the function. But it may be an ndarray, ... # JAX has the same problem when generating HLO. const_key = (id(val), jax_dtype) # Since we use id(val) as a cache key, we have to make sure that we keep # the previous `val` alive. Otherwise, for a ndarray, it can get garbage # collected and reused for a different value, which would create correctness # issues. We keep the `val` alive by storing in the cache the pair # `(val, tf_val)`. # Only memoize non-scalars. JAX will lift all non-scalar constants as # Jaxpr consts, to the top level of the Jaxpr. This ensures that we see them # early, when entering the Jaxpr, so we create the tf.const early and its # scope is the entire Jaxpr. do_memoize = (memoize_constants and np.size(val) > 1 and _thread_local_state.constant_cache is not None) if do_memoize: _, tf_val = _thread_local_state.constant_cache.get(const_key, (None, None)) else: tf_val = None if tf_val is None: conversion_dtype = _to_tf_dtype(jax_dtype) # The float0 type is not known to TF. if jax_dtype == dtypes.float0: val = np.zeros(np.shape(val), conversion_dtype.as_numpy_dtype) if hasattr(val, 'dtype') and dtypes.is_opaque_dtype(val.dtype): val = val.dtype._rules.physical_const(val) tf_val = tf.convert_to_tensor(val, dtype=conversion_dtype) if do_memoize: _thread_local_state.constant_cache[const_key] = (val, tf_val) return tf_val, jax_dtype def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfVal]: # Returns a tuple of shape_poly.dim_as_value_dtype # Used only for non-native lowering assert all(map(lambda x: x is not None, shape)), ( f"Argument shape should be a valid JAX shape but got {shape}") if dtype is not None: shape = _jax_physical_aval(core.ShapedArray(shape, dtype)).shape if core.is_constant_shape(shape): return tuple(int(d) for d in shape) dim_vars, dim_values = util.unzip2(_thread_local_state.shape_env) shape_values_tf, _ = _interpret_fun_jax( partial(core.evaluate_shape, shape, dim_vars), dim_values, [core.dim_value_aval()] * len(dim_values), "") # type: ignore # Keep only the non-constant dimensions return tuple(operator.index(d) if core.is_constant_dim(d) else d_tf for d, d_tf in zip(shape, shape_values_tf)) def _ensure_tf_shape_if_dynamic(x: TfVal, shape): # Update TF tensor `x` with shape `shape` if the shape of `x`` is dynamic. if x.shape.is_fully_defined(): return x return tf.ensure_shape(x, shape) def _assert_matching_abstract_shape(x: TfVal, shape: Sequence[shape_poly.DimSize]): """Asserts that shape matches x.shape in the known dimensions and has dimension polynomials elsewhere.""" # Ensures that the shape does not contain None; it should contain symbolic expressions. def check_one(xd: Optional[int], sd: Any): if core.is_constant_dim(sd): return xd == sd else: assert isinstance(sd, shape_poly._DimExpr) return True assert (len(x.shape) == len(shape) and all(check_one(xd, sd) for xd, sd in zip(x.shape, shape))), \ f"Shape {shape} does not match x.shape {x.shape}" # TODO(b/26854495): pylint doesn't understand slots and inheritance. # pylint: disable=assigning-non-slot class TensorFlowTracer(core.Tracer): """Tracer class that boxes a TF value and a JAX abstract value. In addition to the TF value we carry the JAX abstract value because there are some cases when it cannot be recovered from the value: when we are converting with polymorphic shapes or when the JAX aval has a custom element type. In these cases the shape of the value may have dimensions set to `None`, or it may only correspond to the JAX "physical" (TF/lowering-compatible) shape, so the JAX abstract value may contain more precise information. When the value has a partially-known shape, the dimensions marked as `None` must correspond to non-constant dimensions in the abstract value. See README.md for details. """ # val: TfVal # _aval: core.ShapedArray __slots__ = ["val", "_aval"] def __init__(self, trace: "TensorFlowTrace", val: TfVal, aval: core.AbstractValue): self._trace = trace self._aval = aval phys_aval = _jax_physical_aval(self._aval) # type: ignore[arg-type] if isinstance(val, (tf.Tensor, tf.Variable)): val_shape = val.shape if config.jax_enable_checks: assert len(phys_aval.shape) == len(val_shape), f"_aval.shape={phys_aval.shape} different rank than {val_shape=}" # To compare types, we must handle float0 in JAX and x64 in TF if phys_aval.dtype == dtypes.float0: assert _to_tf_dtype(phys_aval.dtype) == val.dtype, f"expected {phys_aval.dtype} == {val.dtype}" else: assert phys_aval.dtype == _to_jax_dtype(val.dtype), f"expected {phys_aval.dtype} == {val.dtype}" for aval_dim, val_dim in zip(phys_aval.shape, val_shape): # type: ignore[attr-defined] if val_dim is None: assert shape_poly.is_poly_dim(aval_dim), f"expected {phys_aval.shape} == {val_shape}" # type: ignore[attr-defined] elif not shape_poly.is_poly_dim(aval_dim): assert aval_dim == val_dim, f"expected {phys_aval.shape} == {val_shape}" # type: ignore[attr-defined] else: # We have a TF value with known shape, and the abstract shape is a shape variable. try: aval_int = int(_eval_shape([aval_dim])) # type: ignore except (TypeError, KeyError): continue assert aval_int == val_dim, f"expected {phys_aval.shape} == {val_shape}. Found {aval_int} != {val_dim}." # type: ignore self.val = _tfval_to_tensor_jax_dtype(val, phys_aval.dtype, memoize_constants=True)[0] # type: ignore[attr-defined] @property def aval(self): return self._aval def full_lower(self): return self def _make_op_metadata(primitive: core.Primitive, params: Dict, *, source_info: source_info_util.SourceInfo, ) -> xla_client.OpMetadata: eqn_str = (str(source_info.name_stack) + '/' + core.str_eqn_compact(primitive.name, params)) frame = source_info_util.user_frame(source_info) return xla_client.OpMetadata( op_type=primitive.name, op_name=eqn_str, source_file=xla.get_canonical_source_file(frame) if frame else None, source_line=frame.start_line if frame else None) class TensorFlowTrace(core.Trace): """Trace class that underlies the jax2tf transformation. We are going to ensure that jax2tf.convert is never nested inside other transformations. This is sufficient for intended use cases (converting fully-transformed JAX code). It also simplifies our job because we do not have to handle situations where we apply primitives on a mix of TF values and JAX tracers from an outer transformation. E.g., for addition both the TF values and the JAX tracers have an override and they get confused if they see values from the other world. Hence a TFT trace does not interact with non-TFT traces at lower-level. For higher-order control-flow primitives we invoke recursively _interpret_fun on the body of the conditional, which will create a nested TFT. We do want to allow transformations nested inside a TensorFlowTrace (TFT), but those will introduce their own MainTrace, and any operations involving those will be done on those traces, i.e., not a concern for TFT. """ def pure(self, val: TfVal) -> TensorFlowTracer: """Lifts a non-Tracer into the TensorFlowTracer. This function may be called by way of trace.full_raise. """ if hasattr(val, "__jax_array__"): val = val.__jax_array__() if isinstance(val, TensorFlowTracer): return val tf_val, jax_dtype = _tfval_to_tensor_jax_dtype(val, memoize_constants=True) return TensorFlowTracer( self, tf_val, core.ShapedArray(np.shape(val), jax_dtype, weak_type=dtypes.is_weakly_typed(val))) def lift(self, val: core.Tracer) -> TensorFlowTracer: # This would be called when we need to raise a tracer from a lower-level # main into the TensorFlowTrace. Since the TensorFlowTrace is never nested # inside another transform, there are no lower-level main traces. assert False def sublift(self, val: TensorFlowTracer) -> TensorFlowTracer: # This is called when we need to raise a tracer from the same main, # but a lower sublevel. This could come from a nested jit. return TensorFlowTracer(self, val.val, val._aval) def process_primitive(self, primitive: core.Primitive, tracers: Sequence[TensorFlowTracer], params) -> TensorFlowTracer: impl, impl_needs_avals = self.get_primitive_impl(primitive) args_avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers) # This is a bit conservative, doing abstract_eval even in op-by-op execution # but we needed it for, e.g., shape_polymorphism where only JAX's # abstract evaluation rules can properly track polymorphic shapes. # Unfortunately under op-by-op execution this is a rare occasion where we # need abstract evaluation. out_aval, _ = primitive.abstract_eval(*args_avals, **params) args_tf: Sequence[TfVal] = [t.val for t in tracers] def invoke_impl() -> TfVal: if impl_needs_avals: return impl( *args_tf, _in_avals=args_avals, # type: ignore _out_aval=out_aval, **params) else: return impl(*args_tf, **params) current_name_stack = _get_current_name_stack() # We don't use `str(name_stack)` because it uses parentheses for # transformations, which aren't allowed in `name_scope`. scope = '/'.join([s.name for s in current_name_stack.stack]) # type: ignore[union-attr] # Here we reset the name scope to the memorized TF name scope # + JAX name stack by using absolute scope. # We need to add a '/' to the name stack string to force `tf.name_scope` # to interpret it as an absolute scope, not a relative scope. if _thread_local_state.tf_outer_name_scope: scope = f"{_thread_local_state.tf_outer_name_scope}/{scope}" if not scope.endswith("/"): scope = scope + "/" with tf.name_scope(_sanitize_scope_name(scope)): if _thread_local_state.include_xla_op_metadata: op_metadata = _make_op_metadata(primitive, params, source_info=source_info_util.current()) op_metadata_proto = xla_data_pb2.OpMetadata( op_type=op_metadata.op_type, op_name=op_metadata.op_name, source_file=op_metadata.source_file, source_line=op_metadata.source_line ) with tf_ops.get_default_graph()._attr_scope( {"_XlaOpMetadata": attr_value_pb2.AttrValue( s=op_metadata_proto.SerializeToString())}): val_out = invoke_impl() else: val_out = invoke_impl() if primitive.multiple_results: out = [ TensorFlowTracer(self, v, a) for v, a in zip(val_out, out_aval) ] # type: ignore else: out = TensorFlowTracer(self, val_out, out_aval) # type: ignore # Check that the impl rule returned a value of expected shape and dtype # TODO: adapt this to match polymorphic shapes if config.jax_enable_checks: if primitive.multiple_results: for o, expected_aval in zip(out, out_aval): # type: ignore assert o.aval.strip_weak_type() == expected_aval.strip_weak_type(), ( f"{primitive}: out.aval = {o.aval}; expected {expected_aval}") else: assert out.aval == out_aval, ( # type: ignore f"{primitive}: out.aval = {out.aval}; expected {out_aval}" ) # type: ignore return out # type: ignore def process_call(self, call_primitive: core.Primitive, fun: lu.WrappedFun, tracers: Sequence[TensorFlowTracer], params): assert call_primitive.multiple_results vals: Sequence[TfVal] = [t.val for t in tracers] avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers) interpreted_fun = _interpret_subtrace(fun, self.main, avals) extra_name_stack = None with _extended_name_stack(extra_name_stack): with core.new_sublevel(): vals_out = interpreted_fun.call_wrapped(*vals) return [TensorFlowTracer(self, v, a) for v, a in vals_out] def post_process_call(self, call_primitive: core.Primitive, out_tracers: Sequence[TensorFlowTracer], params): # We encountered a call primitive whose result (out_tracers) include # TensorFlowTracer that were not passed through its arguments (captured from # the environment). vals = tuple(t.val for t in out_tracers) main = self.main def todo(vals: Sequence[TfVal]): # TODO: is name_stack correct? trace = TensorFlowTrace(main, core.cur_sublevel()) return [ TensorFlowTracer(trace, v, out_tracer.aval) for v, out_tracer in zip(vals, out_tracers) ] return vals, todo def process_map(self, map_primitive, f, tracers, params): raise NotImplementedError("process_map") def post_process_map(self, map_primitive, out_tracers, params): raise NotImplementedError("post_process_map") def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): # Drop the custom differentiation rule and act like a call primitive. This # behavior is desirable because jax2tf stages code out of the JAX system, so # there are no more JAX differentiation transformations to be applied. del jvp, symbolic_zeros # Unused. return self.process_call(core.call_p, fun, tracers, {}) def post_process_custom_jvp_call(self, out_tracers, _): assert False # unreachable assuming jax2tf runs with clean trace state def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, symbolic_zeros): # Drop the custom differentiation rule and act like a call primitive. This # behavior is desirable because jax2tf stages code out of the JAX system, so # there are no more JAX differentiation transformations to be applied. del fwd, bwd, out_trees, symbolic_zeros # Unused. return self.process_call(core.call_p, fun, tracers, {}) def post_process_custom_vjp_call(self, out_tracers, _): assert False # unreachable assuming jax2tf runs with clean trace state def post_process_custom_vjp_call_fwd(self, *_, **__): assert False # unreachable assuming jax2tf runs with clean trace state def get_primitive_impl(self, p: core.Primitive) -> Tuple[Callable, bool]: # Returns the primitive implementation and whether the implementation # takes abstract values (see definition of tf_impl_with_avals) if not _thread_local_state.enable_xla: try: return tf_impl_no_xla[p], True # Always require avals. except KeyError: pass try: return tf_impl[p], False except KeyError: try: return tf_impl_with_avals[p], True except KeyError as err: msg = "TensorFlow interpretation rule for '{}' not implemented" raise NotImplementedError(msg.format(p)) from err def _unexpected_primitive(p: core.Primitive, *args, **kwargs): assert False, f"Encountered unexpected primitive {p}" # Call primitives are inlined for unexpected in [core.call_p, maps.xmap_p]: tf_impl[unexpected] = partial(_unexpected_primitive, unexpected) # Primitives that are not yet implemented must be explicitly declared here. tf_not_yet_impl = [ "clz", "igamma_grad_a", "random_gamma_grad", "reduce_xor", "schur", "closed_call", "unreachable", "bint", "getslice", "full_to_shard", "shard_to_full", "pure_callback", "for", "inspect_sharding", "io_callback", "shard_map", "global_array_to_host_local_array", "host_local_array_to_global_array", "call_exported", # Not high priority? "after_all", "all_to_all", "check", "create_token", "custom_transpose_call", "custom_vmap_call", "infeed", "linear_call", "outfeed", "pmax_p", "pmin", "ppermute", "psum", "pmax", "pgather", "reduce_scatter", "axis_index", "pdot", "all_gather", "lu_pivots_to_permutation", "xla_pmap", "geqrf", "householder_product", "hessenberg", "tridiagonal", "eigh_jacobi", ] tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient tf_impl[ad_util.zeros_like_p] = tf.zeros_like def _add(x: TfVal, y: TfVal) -> TfVal: return tf.raw_ops.AddV2(x=x, y=y) tf_impl[ad_util.add_jaxvals_p] = _add tf_impl[dispatch.device_put_p] = lambda x, device=None, src=None: x tf_impl[lax_internal.copy_p] = lambda x: x def _neg(x: TfVal) -> TfVal: if x.dtype.is_unsigned: signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[x.dtype] x_signed = tf.cast(x, signed_dtype) res_signed = tf.math.negative(x_signed) return tf.cast(res_signed, x.dtype) else: return tf.math.negative(x) tf_impl[lax.neg_p] = _neg def _sign(x: TfVal) -> TfVal: if x.dtype.is_unsigned: # TF and XLA do not support tf.math.sign for unsigned types. return tf.where( tf.math.equal(x, 0), tf.constant(0, dtype=x.dtype), tf.constant(1, dtype=x.dtype)) else: return tf.math.sign(x) tf_impl[lax.sign_p] = _sign tf_impl[lax.floor_p] = tf.math.floor tf_impl[lax.ceil_p] = tf.math.ceil def _round(operand, *, rounding_method, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): if rounding_method is lax.RoundingMethod.AWAY_FROM_ZERO: # JAX uses a single HLO op Round here sign = _sign(operand) operand *= sign floor = tf.math.floor(operand) operand -= floor cond = tf.math.equal(operand, tf.constant(np.array(0.5), operand.dtype)) return sign * ( tf.where(cond, tf.constant(np.array(1), operand.dtype), tf.math.round(operand)) + floor) else: # rounding_method is RoundingMethod.TO_NEAREST_EVEN return tf.math.round(operand) tf_impl_with_avals[lax.round_p] = _round tf_impl[lax.nextafter_p] = tf.math.nextafter def _population_count(x): orig_dtype = x.dtype return tf.cast(tf.raw_ops.PopulationCount(x=x), orig_dtype) tf_impl[lax.population_count_p] = _population_count tf_impl[lax.is_finite_p] = tf.math.is_finite def _abs(x: TfVal) -> TfVal: # TF and XLA do not support tf.math.abs for unsigned types. return tf.math.abs(x) if not x.dtype.is_unsigned else x tf_impl[lax.abs_p] = _abs tf_impl[lax.pow_p] = tf.math.pow def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): # Follows the implementation in lax._integer_pow_translation_rule if y == 0: return tf.broadcast_to( tf.constant(1, dtype=x.dtype, shape=()), _eval_shape(_out_aval.shape)) is_reciprocal = y < 0 if is_reciprocal: y = -y acc = None while y > 0: if y & 1: acc = x if acc is None else tf.math.multiply(acc, x) y >>= 1 if y > 0: x = tf.math.multiply(x, x) return tf.math.reciprocal(acc) if is_reciprocal else acc tf_impl_with_avals[lax.integer_pow_p] = _integer_pow tf_impl[lax.exp_p] = tf.math.exp tf_impl[lax.expm1_p] = tf.math.expm1 tf_impl[lax.log_p] = tf.math.log tf_impl[lax.log1p_p] = tf.math.log1p tf_impl[lax.tan_p] = tf.math.tan tf_impl[lax.tanh_p] = tf.math.tanh tf_impl[lax.sin_p] = tf.math.sin tf_impl[lax.sinh_p] = tf.math.sinh tf_impl[lax.cos_p] = tf.math.cos tf_impl[lax.cosh_p] = tf.math.cosh tf_impl_with_avals[lax.acos_p] = _convert_jax_impl( lax_internal.acos_impl, multiple_results=False) tf_impl_with_avals[lax.asin_p] = _convert_jax_impl( lax_internal.asin_impl, multiple_results=False) tf_impl_with_avals[lax.atan_p] = _convert_jax_impl( lax_internal.atan_impl, multiple_results=False) # TODO(phawkins): use tf.math.sigmoid here instead. tf_impl_with_avals[lax.logistic_p] = _convert_jax_impl( lax_internal.logistic_impl, multiple_results=False) def _atan2(y, x, **kwargs): if x.dtype.is_complex or y.dtype.is_complex: complex_component_dtype = { tf.complex64: tf.float32, tf.complex128: tf.float64 }.get(y.dtype) zero = tf.constant(0, complex_component_dtype) one = tf.constant(1, complex_component_dtype) i = tf.complex(zero, one) return -i * tf.math.log((x + i * y)/tf.math.sqrt(x * x + y * y)) else: return tf.math.atan2(y, x) tf_impl[lax.atan2_p] = _atan2 tf_impl[lax.acosh_p] = tf.math.acosh tf_impl[lax.atanh_p] = tf.math.atanh tf_impl[lax.asinh_p] = tf.math.asinh tf_impl[lax.sqrt_p] = tf.math.sqrt tf_impl[lax.rsqrt_p] = tf.math.rsqrt def _cbrt(x): return tf.math.sign(x) * tf.math.pow(tf.math.abs(x), 1/3) tf_impl[lax.cbrt_p] = _cbrt tf_impl[lax.lgamma_p] = tf.math.lgamma tf_impl[lax.digamma_p] = tf.math.digamma tf_impl[lax.igamma_p] = tf.math.igamma tf_impl[lax.igammac_p] = tf.math.igammac tf_impl[lax.regularized_incomplete_beta_p] = tf.math.betainc tf_impl[lax.erf_p] = tf.math.erf tf_impl[lax.erfc_p] = tf.math.erfc tf_impl[lax.erf_inv_p] = tf.math.erfinv tf_impl[lax.bessel_i0e_p] = tf.math.bessel_i0e tf_impl[lax.bessel_i1e_p] = tf.math.bessel_i1e tf_impl[lax.complex_p] = tf.complex def _conj(x, **kwargs): # The only dtypes that are allowed are: float32, float64, complex64, and # complex128. if x.dtype == tf.float32: return tf.cast(x, tf.complex64) elif x.dtype == tf.float64: return tf.cast(x, tf.complex128) else: return tf.math.conj(x) tf_impl[lax.conj_p] = _conj tf_impl[lax.real_p] = tf.math.real tf_impl[lax.imag_p] = tf.math.imag tf_impl[lax.add_p] = _add tf_impl[lax.sub_p] = tf.math.subtract tf_impl[lax.mul_p] = tf.math.multiply def _iota(*, dtype, shape, dimension): dtype = _to_tf_dtype(dtype) # Some dtypes are unsupported, like uint32, so we just fall back to int32. # TODO(mattjj, necula): improve tf.range dtype handling shape_tf = _eval_shape(shape) vec = tf.range(tf.cast(shape_tf[dimension], tf.int32), dtype=tf.int32) vec_shape = [-1 if i == dimension else 1 for i in range(len(shape))] return tf.cast(tf.broadcast_to(tf.reshape(vec, vec_shape), shape_tf), dtype) tf_impl[lax.iota_p] = _iota def _div(lhs, rhs): if lhs.dtype.is_integer: quotient = tf.math.floordiv(lhs, rhs) select = tf.math.logical_and( tf.not_equal(_sign(lhs), _sign(rhs)), tf.not_equal(tf.math.floormod(lhs, rhs), 0)) return tf.where(select, quotient + 1, quotient) else: return tf.math.truediv(lhs, rhs) def _rem(lhs, rhs): return _sign(lhs) * tf.math.floormod(_abs(lhs), _abs(rhs)) tf_impl[lax.div_p] = _div tf_impl[lax.rem_p] = _rem def _minmax(x: TfVal, y: TfVal, *, is_min: bool, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray,) -> TfVal: # For complex numbers use lexicographic ordering, like JAX if dtypes.issubdtype(x.dtype.as_numpy_dtype, np.complexfloating): return _convert_jax_impl( partial(lax_internal._minmax_complex_lowering, lax_cmp_pick_x=lax.lt if is_min else lax.gt), multiple_results=False)(x, y, _in_avals=_in_avals, _out_aval=_out_aval) elif x.dtype.as_numpy_dtype == np.bool_: return (tf.math.logical_and if is_min else tf.math.logical_or)(x, y) else: return (tf.math.minimum if is_min else tf.math.maximum)(x, y) def _minmax_scalar(x: TfVal, y: TfVal, *, is_min: bool) -> TfVal: # For reducers we will need min/max for scalars only. In that case we # can construct the AbstractValues outselves, even in the presence of # shape polymorphism. assert len(x.shape) == 0 and len(y.shape) == 0, f"x: {x.shape}, y: {y.shape}" aval = core.ShapedArray((), _to_jax_dtype(x.dtype)) return _minmax(x, y, is_min=is_min, _in_avals=[aval, aval], _out_aval=aval) tf_impl_with_avals[lax.max_p] = partial(_minmax, is_min=False) tf_impl_with_avals[lax.min_p] = partial(_minmax, is_min=True) # Map from TF signed types to TF unsigned types. _SIGNED_TO_UNSIGNED_TABLE = { tf.int8: tf.uint8, tf.int16: tf.uint16, tf.int32: tf.uint32, tf.int64: tf.uint64, } # Map from TF unsigned types to TF signed types. _UNSIGNED_TO_SIGNED_TABLE = {u: s for s, u in _SIGNED_TO_UNSIGNED_TABLE.items()} # Note: Bitwise operations only yield identical results on unsigned integers! # pylint: disable=protected-access def _shift_right_arithmetic_raw(x, y): if x.dtype.is_unsigned: assert x.dtype == y.dtype orig_dtype = x.dtype signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[orig_dtype] x = tf.cast(x, signed_dtype) y = tf.cast(y, signed_dtype) res = tf.bitwise.right_shift(x, y) return tf.cast(res, orig_dtype) else: return tf.bitwise.right_shift(x, y) def _shift_right_arithmetic(x, y): # TF shift is "implementation defined" if the shift amount is negative # or larger or equal to the size of the value. We implement the XLA # semantics to return the shift by the max value (x_bits - 1). # TODO: it is likely better to add XlaOps for shifts x_bits = 8 * x.dtype.size clamp_y = tf.where(_shift_in_bounds(x, y), y, x_bits - 1) return _shift_right_arithmetic_raw(x, clamp_y) tf_impl[lax.shift_right_arithmetic_p] = _shift_right_arithmetic def _shift_right_logical_raw(x, y): if x.dtype.is_unsigned: return tf.bitwise.right_shift(x, y) else: assert x.dtype == y.dtype orig_dtype = x.dtype unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[orig_dtype] x = tf.cast(x, unsigned_dtype) y = tf.cast(y, unsigned_dtype) res = tf.bitwise.right_shift(x, y) return tf.cast(res, orig_dtype) def _shift_right_logical(x, y): # TF shift is "implementation defined" if the shift amount is negative # or larger or equal to the size of the value. We implement the XLA semantics # to return 0. # TODO: it is likely better to add XlaOps for shifts return tf.where( _shift_in_bounds(x, y), _shift_right_logical_raw(x, y), tf.zeros_like(x)) tf_impl[lax.shift_right_logical_p] = _shift_right_logical def _shift_left(x, y): # TF shift is "implementation defined" if the shift amount is negative # or larger or equal to the size of the value. We implement the XLA semantics # to return 0. # TODO: it is likely better to add XlaOps for shifts return tf.where( _shift_in_bounds(x, y), tf.bitwise.left_shift(x, y), tf.zeros_like(x)) tf_impl[lax.shift_left_p] = _shift_left def _shift_in_bounds(x: TfVal, y: TfVal) -> TfVal: # Return the TF expression for when y is within bounds (0 <= y < |x|) x_bits = 8 * x.dtype.size # TF does not have comparisons for uint16 and uint32 (despite what the # documentation says) y_comp = tf.cast( y, _UNSIGNED_TO_SIGNED_TABLE[y.dtype]) if y.dtype.is_unsigned else y y_lt_x_bits = tf.math.less(y_comp, x_bits) y_ge_0 = tf.math.greater_equal(y_comp, 0) return tf.logical_and(y_lt_x_bits, y_ge_0) def _not(x): """Computes bitwise not with support for booleans. Numpy and JAX support bitwise not for booleans by applying a logical not! This means that applying bitwise_not yields an unexpected result: jnp.bitwise_not(jnp.array([True, False])) >> DeviceArray([False, True], dtype=bool) if you assume that booleans are simply casted to integers. jnp.bitwise_not(jnp.array([True, False]).astype(np.int32)).astype(bool) >> DeviceArray([True, True], dtype=bool) """ if x.dtype == tf.bool: return tf.logical_not(x) else: return tf.bitwise.invert(x) tf_impl[lax.not_p] = _not def handle_boolean_args(f, argnums: Sequence[int], boolean_f=None): """Computes functions with some bool args and bool results using int8. This is needed because some TF ops do not work for bool args, e.g., inequalities, min/max. Args: f: a TF callable to wrap. It will be called with non-boolean arguments. argnums: the positional arguments that may be booleans. boolean_f: [Optional] a TF callable compatible with boolean arguments. Returns: a TF callable that can take a mix of boolean positional arguments (in the positions specified by `argnums`) and some non-boolean positional arguments. If there are no boolean arguments, just calls `f`. Otherwise, it calls `boolean_f` if defined. Otherwise, casts the boolean arguments to `int8`, calls `f`, then casts the result to `bool`. """ argnums = tf.nest.flatten(argnums) def wrapper(*args: TfVal, **kwargs): argnum_types = {args[i].dtype for i in argnums} if tf.bool not in argnum_types: return f(*args, **kwargs) else: # All argnums should be boolean assert len(argnum_types) == 1, argnum_types if boolean_f != None: return boolean_f(*args, **kwargs) else: args_cast = [(tf.cast(a, tf.int8) if i in argnums else a) for i, a in enumerate(args)] if "_in_avals" in kwargs: def cast_aval(aval): assert aval.dtype == np.bool_ return core.ShapedArray(aval.shape, np.int8) _in_avals_cast = [ cast_aval(aval) if i in argnums else aval for i, aval in enumerate(kwargs["_in_avals"]) ] _out_aval_cast = tf.nest.map_structure(cast_aval, kwargs["_out_aval"]) kwargs = dict( kwargs, _in_avals=_in_avals_cast, _out_aval=_out_aval_cast) out = f(*args_cast, **kwargs) return tf.nest.map_structure(lambda o: tf.cast(o, tf.bool), out) return wrapper tf_impl[lax.or_p] = handle_boolean_args(tf.bitwise.bitwise_or, argnums=(0, 1), boolean_f=tf.logical_or) tf_impl[lax.and_p] = handle_boolean_args(tf.bitwise.bitwise_and, argnums=(0, 1), boolean_f=tf.logical_and) tf_impl[lax.xor_p] = handle_boolean_args(tf.bitwise.bitwise_xor, argnums=(0, 1), boolean_f=tf.math.logical_xor) tf_impl[lax.eq_p] = tf.math.equal tf_impl[lax.ne_p] = tf.math.not_equal boolean_greater = lambda x,y: tf.logical_and(x, tf.logical_not(y)) # Only one combo: T,F -> T boolean_less = lambda x,y: tf.logical_and(tf.logical_not(x), y) # Only one combo: F,T -> T boolean_greater_or_equal = lambda x, y: tf.logical_not(boolean_less(x,y)) # All cases except F,T boolean_less_or_equal = lambda x, y: tf.logical_not(boolean_greater(x,y)) # All cases except T,F tf_impl[lax.gt_p] = handle_boolean_args(tf.math.greater, argnums=(0, 1), boolean_f=boolean_greater) tf_impl[lax.lt_p] = handle_boolean_args(tf.math.less, argnums=(0, 1), boolean_f=boolean_less) tf_impl[lax.ge_p] = handle_boolean_args(tf.math.greater_equal, argnums=(0, 1), boolean_f=boolean_greater_or_equal) tf_impl[lax.le_p] = handle_boolean_args(tf.math.less_equal, argnums=(0, 1), boolean_f=boolean_less_or_equal) tf_impl[lax.linalg.cholesky_p] = tf.linalg.cholesky def _convert_element_type(operand, *, new_dtype, weak_type=False): old_dtype = operand.dtype.as_numpy_dtype if (dtypes.issubdtype(old_dtype, np.complexfloating) and not dtypes.issubdtype(new_dtype, np.complexfloating)): operand = tf.math.real(operand) if (dtypes.issubdtype(old_dtype, np.floating) and not (dtypes.issubdtype(new_dtype, np.floating) or dtypes.issubdtype( new_dtype, np.complexfloating) or new_dtype == np.bool_)): sign = _sign(operand) operand = sign * tf.math.floor(sign * operand) return tf.dtypes.cast(operand, _to_tf_dtype(new_dtype)) tf_impl[lax.convert_element_type_p] = _convert_element_type def _bitcast_convert_type(operand, new_dtype): if operand.dtype == new_dtype: return operand return tf.bitcast(operand, _to_tf_dtype(new_dtype)) tf_impl[lax.bitcast_convert_type_p] = _bitcast_convert_type def _clamp(minval, operand, maxval, *, _in_avals, _out_aval): # The below permits mirroring the behavior of JAX when maxval < minval op_shape_tf_val = _eval_shape(_in_avals[1].shape, _in_avals[1].dtype) maxval = tf.broadcast_to(maxval, op_shape_tf_val) minval = tf.math.minimum(tf.broadcast_to(minval, op_shape_tf_val), maxval) return tf.clip_by_value(operand, minval, maxval) tf_impl_with_avals[lax.clamp_p] = _clamp def _concatenate(*operands, dimension): return tf.concat(operands, axis=tf.cast(dimension, tf.int32)) tf_impl[lax.concatenate_p] = _concatenate def _conv_general_dimension_numbers_proto(dimension_numbers): """Converts a ConvDimensionNumbers to an XLA ConvolutionDimensionNumbers.""" assert isinstance(dimension_numbers, lax.ConvDimensionNumbers) lhs_spec, rhs_spec, out_spec = dimension_numbers proto = xla_data_pb2.ConvolutionDimensionNumbers() proto.input_batch_dimension = lhs_spec[0] proto.input_feature_dimension = lhs_spec[1] proto.output_batch_dimension = out_spec[0] proto.output_feature_dimension = out_spec[1] proto.kernel_output_feature_dimension = rhs_spec[0] proto.kernel_input_feature_dimension = rhs_spec[1] proto.input_spatial_dimensions.extend(lhs_spec[2:]) proto.kernel_spatial_dimensions.extend(rhs_spec[2:]) proto.output_spatial_dimensions.extend(out_spec[2:]) return proto def _precision_config_proto(precision: Optional[Tuple[PrecisionType, PrecisionType]]): """Convert an integer to an XLA.PrecisionConfig.""" if precision is None: return None proto = xla_data_pb2.PrecisionConfig() proto.operand_precision.append(int(precision[0])) proto.operand_precision.append(int(precision[1])) return proto def _conv_general_dilated(lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers: lax.ConvDimensionNumbers, feature_group_count: int, batch_group_count: int, precision: Optional[Tuple[PrecisionType, PrecisionType]], preferred_element_type: Optional[DType], _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): """Implementation of lax.conv_general_dilated_p using XlaConv.""" out_tf_shape = _aval_to_tf_shape(_out_aval) dnums_proto = _conv_general_dimension_numbers_proto(dimension_numbers) precision_config_proto = _precision_config_proto(precision) def gen_conv(lhs, rhs, preferred_element_type: Optional[DType]): tf_version = tuple(int(v) for v in tf.__version__.split(".")[:2]) if tf_version >= (2, 8): # TODO(necula): remove when 2.8.0 is the stable TF version (and supports # batch_group_count. padding_tf = [_eval_shape(p) for p in padding] out = tfxla.conv( lhs, rhs, window_strides, padding_tf, lhs_dilation, rhs_dilation, dnums_proto, feature_group_count=feature_group_count, batch_group_count=batch_group_count, precision_config=precision_config_proto, preferred_element_type=preferred_element_type, use_v2=True) else: if batch_group_count != 1: raise ValueError( "The batch_group_count parameter for conv requires TF version " "at least 2.8.0. You may want to use tf-nightly.") padding_tf = [_eval_shape(p) for p in padding] out = tfxla.conv( lhs, rhs, window_strides, padding_tf, lhs_dilation, rhs_dilation, dnums_proto, feature_group_count=feature_group_count, precision_config=precision_config_proto, preferred_element_type=preferred_element_type, use_v2=True) # TODO: implement shape inference for XlaConv out = _ensure_tf_shape_if_dynamic(out, out_tf_shape) if _WRAP_JAX_JIT_WITH_TF_FUNCTION: out = tf.stop_gradient(out) # See #7839 return out # Follow the lowering for complex convolutions from # lax._conv_general_dilated_translation. We can use the same conversion on all # platforms because on XLA:TPU the compiler does the same as a rewrite. preferred_float_et: Optional[Any] if np.issubdtype(_in_avals[0].dtype, np.complexfloating): if preferred_element_type is not None: # Convert complex dtype to types used for real and imaginary parts assert np.issubdtype(preferred_element_type, np.complexfloating) preferred_float_et = ( np.float64 if preferred_element_type == np.complex128 else np.float32) else: preferred_float_et = None lhs_real, lhs_imag = tf.math.real(lhs), tf.math.imag(lhs) rhs_real, rhs_imag = tf.math.real(rhs), tf.math.imag(rhs) k1 = gen_conv(_add(lhs_real, lhs_imag), rhs_real, preferred_float_et) k2 = gen_conv(lhs_real, tf.math.subtract(rhs_imag, rhs_real), preferred_float_et) k3 = gen_conv(lhs_imag, _add(rhs_real, rhs_imag), preferred_float_et) return tf.complex(tf.math.subtract(k1, k3), _add(k1, k2)) else: return gen_conv(lhs, rhs, preferred_element_type) tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated def _dot_general(lhs, rhs, *, dimension_numbers, precision: Optional[Tuple[PrecisionType, PrecisionType]], preferred_element_type: Optional[DType], _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): """Implementation of lax.dot_general_p in terms of tf.linalg.einsum.""" (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers dnums_proto = xla_data_pb2.DotDimensionNumbers() dnums_proto.lhs_contracting_dimensions.extend(lhs_contracting) dnums_proto.rhs_contracting_dimensions.extend(rhs_contracting) dnums_proto.lhs_batch_dimensions.extend(lhs_batch) dnums_proto.rhs_batch_dimensions.extend(rhs_batch) precision_config_proto = _precision_config_proto(precision) res = tfxla.dot_general( lhs, rhs, dnums_proto, precision_config_proto, preferred_element_type=preferred_element_type, use_v2=True) if _WRAP_JAX_JIT_WITH_TF_FUNCTION: res = tf.stop_gradient(res) # See #7839 return res tf_impl_with_avals[lax.dot_general_p] = _dot_general def _broadcast_in_dim(operand, *, shape, broadcast_dimensions, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): # for i in range(len(operand.shape)): # result.shape[bcast_dims[i]] <- operand.shape[i] # bcast_dims must be strictly increasing. # len(bcast_dims) == len(operand.shape) op_shape = _in_avals[0].shape dtype = _in_avals[0].dtype add_1s_shape = [1] * len(shape) for i, broadcast_dim_i in enumerate(broadcast_dimensions): add_1s_shape[broadcast_dim_i] = op_shape[i] with_1s = tf.reshape(operand, _eval_shape(add_1s_shape, dtype=dtype)) return tf.broadcast_to(with_1s, _eval_shape(shape, dtype=dtype)) tf_impl_with_avals[lax.broadcast_in_dim_p] = _broadcast_in_dim def _empty(*, dtype): if dtypes.is_opaque_dtype(dtype): raise NotImplementedError # TODO(frostig,mattjj): jax2tf handlers return tf.constant(np.array(0, dtype=dtype)) tf_impl[lax_internal.empty_p] = _empty def _reshape(operand, *, new_sizes, dimensions, _in_avals, _out_aval): if dimensions is None: dimensions = tf.range(tf.rank(operand)) new_sizes_tf = _eval_shape(new_sizes, _in_avals[0].dtype) return tf.reshape(tf.transpose(operand, dimensions), new_sizes_tf) tf_impl_with_avals[lax.reshape_p] = _reshape def _squeeze(operand, *, dimensions, _in_avals, _out_aval): op_aval = _jax_physical_aval(_in_avals[0]) op_shape = op_aval.shape new_shape = tuple(d for i, d in enumerate(op_shape) if i not in dimensions) new_shape_tf = _eval_shape(new_shape, op_aval.dtype) return tf.reshape(operand, new_shape_tf) tf_impl_with_avals[lax.squeeze_p] = _squeeze def _pad(operand, padding_value, *, padding_config, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): low, high, interior = util.unzip3(map(_eval_shape, padding_config)) # type: ignore out = tfxla.pad(operand, padding_value, low, high, interior) # TODO: implement shape inference for XlaPad (when some padding_config is constant) out = _ensure_tf_shape_if_dynamic(out, _aval_to_tf_shape(_out_aval)) if _WRAP_JAX_JIT_WITH_TF_FUNCTION: out = tf.stop_gradient(out) # See #7839 return out tf_impl_with_avals[lax.pad_p] = _pad def _rev(operand, *, dimensions): return tf.reverse(operand, dimensions) tf_impl[lax.rev_p] = _rev def _where(which, *cases): if which.dtype == tf.bool: assert len(cases) <= 2 return cases if len(cases) == 1 else tf.where(which, cases[1], cases[0]) def _select(offset, cases): assert len(cases) > 0 if len(cases) == 1: return cases[0] mid = len(cases) // 2 return tf.where(tf.less(which, offset + mid), _select(offset, cases[:mid]), _select(mid, cases[mid:])) return _select(0, cases) tf_impl[lax.select_n_p] = _where def _transpose(operand, *, permutation): return tf.transpose(operand, perm=permutation) tf_impl[lax.transpose_p] = _transpose axes_to_axis = lambda func: lambda operand, axes: func(operand, axis=axes) # reduce_sum and reduce_prod are not supported for bool tf_impl[lax.reduce_sum_p] = axes_to_axis(tf.reduce_sum) tf_impl[lax.reduce_prod_p] = axes_to_axis(tf.reduce_prod) tf_impl[lax.reduce_max_p] = handle_boolean_args( axes_to_axis(tf.reduce_max), argnums=[0], boolean_f=axes_to_axis(tf.reduce_any)) # Max is T if any one is T tf_impl[lax.reduce_min_p] = handle_boolean_args( axes_to_axis(tf.reduce_min), argnums=[0], boolean_f=axes_to_axis(tf.reduce_all)) # Min is F if not all are T tf_impl[lax.reduce_or_p] = axes_to_axis(tf.reduce_any) tf_impl[lax.reduce_and_p] = axes_to_axis(tf.reduce_all) def _argminmax(is_min: bool, operand: TfVal, axes: Sequence[int], index_dtype: DType, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): # Follow the JAX implementation, using a XlaReduce with a custom comparator if is_min: extra_name_stack = "argmin" value_comparator = lax.lt get_identity = lax_internal._get_min_identity else: extra_name_stack = "argmax" value_comparator = lax.gt get_identity = lax_internal._get_max_identity res = _convert_jax_impl( partial(lax_internal._compute_argminmax, value_comparator, get_identity), multiple_results=False, extra_name_stack=extra_name_stack)( operand, index_dtype=index_dtype, axes=axes, _in_avals=_in_avals, _out_aval=_out_aval) return res tf_impl_with_avals[lax.argmin_p] = partial(_argminmax, True) tf_impl_with_avals[lax.argmax_p] = partial(_argminmax, False) _add_fn = tf.function(_add, autograph=False) _ge_fn = tf.function(tf.math.greater_equal, autograph=False) def _select_and_gather_add( tangents: TfVal, operand: TfVal, select_prim: core.Primitive, window_dimensions: Sequence[int], window_strides: Sequence[int], base_dilation: Sequence[int], window_dilation: Sequence[int], padding: Sequence[Tuple[int, int]], _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): # Note: this function follows the pattern in # jax.lax._select_and_gather_add_translation. dtype = operand.dtype nbits = dtypes.finfo(dtype.as_numpy_dtype).bits # Specializing the function for 64 bits. Only up to 32 bits are supported on TPU, # we thus intend to let the code throw a different exception on this platform. max_bits = 64 assert nbits <= max_bits double_word_reduction = nbits * 2 <= max_bits const = lambda dtype, x: tf.constant(np.array(x), dtype) if double_word_reduction: word_dtype = lax_internal._UINT_DTYPES[nbits] double_word_dtype = lax_internal._UINT_DTYPES[nbits * 2] # Packs two values into a tuple. def pack(a, b): a = _bitcast_convert_type(a, word_dtype) b = _bitcast_convert_type(b, word_dtype) a = _convert_element_type(a, new_dtype=double_word_dtype) b = _convert_element_type(b, new_dtype=double_word_dtype) a = tf.bitwise.left_shift(a, const(double_word_dtype, nbits)) return tf.bitwise.bitwise_or(a, b) # Unpacks the first element of a tuple. def fst(t): assert t.dtype == double_word_dtype st = _shift_right_logical(t, const(double_word_dtype, nbits)) return _bitcast_convert_type( _convert_element_type(st, new_dtype=word_dtype), dtype) # Unpacks the second element of a tuple. def snd(t): return _bitcast_convert_type( _convert_element_type(t, new_dtype=word_dtype), dtype) else: raise NotImplementedError( f"TODO: need to pack {nbits * 2} bits but this platform can only go up to {max_bits} bits." ) assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim def reducer(x, y): which = tf_impl[select_prim] return tf_impl[lax.select_n_p](which(fst(x), fst(y)), y, x) init = -np.inf if select_prim is lax.ge_p else np.inf init_identity = lambda x: pack(const(dtype, init), const(dtype, 0)) out = _specialized_reduce_window( reducer, init_identity, pack(operand, tangents), window_dimensions=window_dimensions, window_strides=window_strides, padding=padding, base_dilation=base_dilation, window_dilation=window_dilation, _in_avals=_in_avals, _out_aval=_out_aval) return snd(out) tf_impl_with_avals[lax.select_and_gather_add_p] = _select_and_gather_add def _common_reduce_window(operand, init_val, reducer, window_dimensions, window_strides, padding, base_dilation, window_dilation, _in_avals, _out_aval): o_spec = tf.TensorSpec((), dtype=operand.dtype) reducer_fn = tf.function( reducer, autograph=False).get_concrete_function(o_spec, o_spec) if not isinstance(init_val, (tf.Tensor, tf.Variable)): init_val = tf.constant(init_val, operand.dtype) window_dimensions_tf = _eval_shape(window_dimensions) window_strides_tf = _eval_shape(window_strides) window_dilation_tf = _eval_shape(window_dilation) base_dilation_tf = _eval_shape(base_dilation) padding_tf = [_eval_shape(p) for p in padding] out = tfxla.reduce_window( operand, init_val, reducer_fn, window_dimensions_tf, window_strides_tf, base_dilations=base_dilation_tf, window_dilations=window_dilation_tf, padding=padding_tf) # TODO: implement shape inference for XlaReduceWindow out = _ensure_tf_shape_if_dynamic(out, _aval_to_tf_shape(_out_aval)) if _WRAP_JAX_JIT_WITH_TF_FUNCTION: out = tf.stop_gradient(out) # See #7839 return out def _reduce_window(*args, jaxpr, consts, window_dimensions, window_strides, padding, base_dilation, window_dilation, _in_avals, _out_aval): """TensorFlow implementation of reduce_window. Args: operands: N dimensional arrays containing elements of type T init_values: starting values of the reduction jaxpr: the jaxpr corresponding to the reduction function consts: the constants associated with jaxpr. window_dimensions: array of integers for window dimension values window_strides: array of integers for window stride values padding: array of pairs of integers for padding values base_dilation: array of integers for base dilation values window_dilation: array of integers for window dilation values Returns: The reduced operand. """ assert len(consts) == 0, "Reduction computation cannot have constants" operands, init_values = util.split_list(args, [len(args) // 2]) if len(operands) != 1: raise NotImplementedError("jax2tf does not support variadic reduce_window") def reducer(arg1: TfVal, arg2: TfVal) -> TfVal: closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) res, = _interpret_jaxpr(closed_jaxpr, arg1, arg2, extra_name_stack=None) return res return (_common_reduce_window(operands[0], init_values[0], reducer, window_dimensions, window_strides, padding, base_dilation, window_dilation, _in_avals, _out_aval[0]),) def _specialized_reduce_window(reducer, identity, operand, *, window_dimensions, window_strides, padding, base_dilation, window_dilation, _in_avals, _out_aval, name=None): """Wraps the TensorFlow reduce window operation based on a reducer and an identity function defining the initial value of the reduction depending on the dtype of the operand. Args: reducer: reduction function of type TfVal -> TfVal -> TfVal identity: function that takes a TensorFlow dtype as a parameter and returns the starting value of the reduction. operand: N dimensional array containing elements of type T window_dimensions: array of integers for window dimension values window_strides: array of integers for window stride values padding: array of pairs of integers for padding values base_dilation: array of integers for base dilation values window_dilation: array of integers for window dilation values name: the name of the specialized reduce window primitive for which this conversion function is called. This information may help to choose a different conversion path (optional) Returns: The reduced operand. """ return _common_reduce_window(operand, identity(operand.dtype), reducer, window_dimensions, window_strides, padding, base_dilation, window_dilation, _in_avals, _out_aval) def _get_max_identity(tf_dtype): numpy_tf_dtype = tf_dtype.as_numpy_dtype if tf_dtype == tf.bfloat16 or dtypes.issubdtype(numpy_tf_dtype, np.inexact): return numpy_tf_dtype(-np.inf) elif dtypes.issubdtype(numpy_tf_dtype, np.integer): return dtypes.iinfo(numpy_tf_dtype).min else: assert dtypes.issubdtype( numpy_tf_dtype, np.bool_), (f"{tf_dtype} has no defined max identity") return False def _get_min_identity(tf_dtype): numpy_tf_dtype = tf_dtype.as_numpy_dtype if tf_dtype == tf.bfloat16 or dtypes.issubdtype(numpy_tf_dtype, np.inexact): return numpy_tf_dtype(np.inf) elif dtypes.issubdtype(numpy_tf_dtype, np.integer): return dtypes.iinfo(numpy_tf_dtype).max else: assert dtypes.issubdtype( numpy_tf_dtype, np.bool_), (f"{tf_dtype} has no defined min identity") return True # pylint: disable=protected-access tf_impl_with_avals[lax.reduce_window_sum_p] = ( partial(_specialized_reduce_window, _add, lambda x: 0, name="reduce_window_sum")) tf_impl_with_avals[lax.reduce_window_min_p] = ( partial(_specialized_reduce_window, partial(_minmax_scalar, is_min=True), _get_min_identity, name="reduce_window_min")) tf_impl_with_avals[lax.reduce_window_max_p] = ( partial(_specialized_reduce_window, partial(_minmax_scalar, is_min=False), _get_max_identity, name="reduce_window_max")) tf_impl_with_avals[lax.reduce_window_p] = _reduce_window # pylint: enable=protected-access def _reduce(*operands: TfVal, computation: Callable, jaxpr: core.Jaxpr, consts: Sequence[Any], dimensions: Sequence[int], _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray) -> Sequence[TfVal]: del computation assert not consts assert len(operands) % 2 == 0 # operands: op1, op2, ..., init_val1, init_val2, ... # reducer takes op1[i], op2[i], ..., init_val1, init_val2, ... nr_operands = len(operands) // 2 init_vals = operands[nr_operands:] operands = operands[0:nr_operands] reducer_arg_spec = tuple([tf.TensorSpec((), op.dtype) for op in init_vals] * 2) def reducer_computation(*args: TfVal) -> TfVal: closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) res = _interpret_jaxpr(closed_jaxpr, *args, extra_name_stack=None) return res xla_reducer_computation = ( tf.function(reducer_computation, autograph=False).get_concrete_function(*reducer_arg_spec)) outs = tfxla.variadic_reduce(operands, init_vals, dimensions_to_reduce=dimensions, reducer=xla_reducer_computation) if _WRAP_JAX_JIT_WITH_TF_FUNCTION: outs = tuple(tf.stop_gradient(out) for out in outs) # See #7839 return outs tf_impl_with_avals[lax.reduce_p] = _reduce # We use lax.cumred_reduce_window_impl to convert cummax, # cummin, cumsum and cumprod. This is efficient on TPU, but the complexity is # O(n^2) on other backends. This may be implemented using associative_scan # instead to favor different backends. def _cumred(lax_reduce_fn: Callable, lax_reduce_window_fn: Callable, extra_name_stack: str): if config.jax2tf_associative_scan_reductions: return _convert_jax_impl(partial(lax_control_flow.associative_scan, lax_reduce_fn), multiple_results=False, extra_name_stack=extra_name_stack) else: return _convert_jax_impl(partial(lax_control_flow.cumred_reduce_window_impl, lax_reduce_window_fn), multiple_results=False, extra_name_stack=extra_name_stack) tf_impl_with_avals[lax.cummax_p] = _cumred( lax_reduce_window_fn=lax_windowed_reductions._reduce_window_max, lax_reduce_fn=lax.max, extra_name_stack="cummax") tf_impl_with_avals[lax.cummin_p] = _cumred( lax_reduce_window_fn=lax_windowed_reductions._reduce_window_min, lax_reduce_fn=lax.min, extra_name_stack="cummin") tf_impl_with_avals[lax.cumlogsumexp_p] = _cumred( lax_reduce_window_fn=lax_windowed_reductions._reduce_window_logaddexp, lax_reduce_fn=logaddexp, extra_name_stack="cumlogsumexp") tf_impl_with_avals[lax.cumsum_p] = _cumred( lax_reduce_window_fn=lax_windowed_reductions._reduce_window_sum, lax_reduce_fn=lax.add, extra_name_stack="cumsum") tf_impl_with_avals[lax.cumprod_p] = _cumred( lax_reduce_window_fn=lax_windowed_reductions._reduce_window_prod, lax_reduce_fn=lax.mul, extra_name_stack="cumprod") def _select_and_scatter(operand, source, init_value, select_jaxpr, select_consts, scatter_jaxpr, scatter_consts, window_dimensions, window_strides, padding): raise NotImplementedError("TODO: jax2tf can not convert _select_and_scatter") tf_impl[lax.select_and_scatter_p] = _select_and_scatter @partial(handle_boolean_args, argnums=(0, 1)) def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions, window_strides, padding, _in_avals, _out_aval): init_value = tf.zeros((), operand.dtype) select_fn = ( tf.function(tf_impl[select_prim], autograph=False).get_concrete_function( init_value, init_value)) scatter_fn = _add_fn.get_concrete_function(init_value, init_value) out = tfxla.select_and_scatter(operand, window_dimensions, window_strides, padding, source, init_value, select_fn, scatter_fn) out = _ensure_tf_shape_if_dynamic(out, _aval_to_tf_shape(_out_aval)) if _WRAP_JAX_JIT_WITH_TF_FUNCTION: out = tf.stop_gradient(out) # See #7839 return out tf_impl_with_avals[lax.select_and_scatter_add_p] = _select_and_scatter_add def _random_seed_impl(seeds: TfVal, *, impl, _in_avals, _out_aval): def impl_wrapper(seeds: TfVal, *, impl): return prng.random_seed_impl_base(seeds, impl=impl) converted_impl = _convert_jax_impl( impl_wrapper, multiple_results=False, with_physical_avals=True, extra_name_stack="random_seed") return converted_impl( seeds, impl=impl, _in_avals=_in_avals, _out_aval=_out_aval) tf_impl_with_avals[prng.random_seed_p] = _random_seed_impl def _random_split_impl(keys: TfVal, *, count, _in_avals, _out_aval): keys_aval, = _in_avals def impl_wrapper(keys: TfVal, *, count): return prng.random_split_impl_base( keys_aval.dtype.impl, keys, keys_aval.ndim, count=count) converted_impl = _convert_jax_impl( impl_wrapper, multiple_results=False, with_physical_avals=True, extra_name_stack="random_split") return converted_impl( keys, count=count, _in_avals=_in_avals, _out_aval=_out_aval) tf_impl_with_avals[prng.random_split_p] = _random_split_impl def _random_fold_in_impl(keys: TfVal, msgs: TfVal, *, _in_avals, _out_aval): keys_aval, _ = _in_avals def impl_wrapper(keys: TfVal, msgs: TfVal): return prng.random_fold_in_impl_base( keys_aval.dtype.impl, keys, msgs, keys_aval.shape) converted_impl = _convert_jax_impl( impl_wrapper, multiple_results=False, with_physical_avals=True, extra_name_stack="random_fold_in") return converted_impl( keys, msgs, _in_avals=_in_avals, _out_aval=_out_aval) tf_impl_with_avals[prng.random_fold_in_p] = _random_fold_in_impl def _random_bits_impl(keys: TfVal, *, bit_width, shape, _in_avals, _out_aval): keys_aval, = _in_avals def impl_wrapper(keys: TfVal, **kwargs): return prng.random_bits_impl_base( keys_aval.dtype.impl, keys, keys_aval.ndim, bit_width=bit_width, shape=shape) converted_impl = _convert_jax_impl( impl_wrapper, multiple_results=False, with_physical_avals=True, extra_name_stack="random_bits") return converted_impl(keys, bit_width=bit_width, shape=shape, _in_avals=_in_avals, _out_aval=_out_aval) tf_impl_with_avals[prng.random_bits_p] = _random_bits_impl def _random_wrap_impl(base_arr: TfVal, *, impl, _in_avals, _out_aval): return base_arr tf_impl_with_avals[prng.random_wrap_p] = _random_wrap_impl def _random_unwrap_impl(keys: TfVal, *, _in_avals, _out_aval): return keys tf_impl_with_avals[prng.random_unwrap_p] = _random_unwrap_impl def _threefry2x32_jax_impl(*args: TfVal, _in_avals, _out_aval): res = _convert_jax_impl( partial(prng._threefry2x32_lowering, use_rolled_loops=False), multiple_results=True, extra_name_stack="threefry")( *args, _in_avals=_in_avals, _out_aval=_out_aval) return res tf_impl_with_avals[prng.threefry2x32_p] = _threefry2x32_jax_impl # Use the vmap implementation, otherwise on TPU the performance is really bad # With use_vmap=True on, we get about the same performance for JAX and jax2tf. tf_impl_with_avals[random.random_gamma_p] = _convert_jax_impl( partial(random_internal._gamma_impl, use_vmap=True), multiple_results=False, extra_name_stack="random_gamma") def _rng_bit_generator(key: TfVal, *, shape, dtype, algorithm) -> Sequence[TfVal]: is_uint32_key = key.dtype == _to_tf_dtype(jnp.uint32) if is_uint32_key: key = tf.reshape(key, (2, 2)) key = tfxla.bitcast_convert_type(key, _to_tf_dtype(jnp.uint64)) shape_tf = _eval_shape(shape) # JAX uses XLA algorithm enums; tfxla uses tf.random.Algorithm if algorithm == lax.RandomAlgorithm.RNG_THREE_FRY: algorithm_tf = tf.random.Algorithm.THREEFRY elif algorithm == lax.RandomAlgorithm.RNG_PHILOX: algorithm_tf = tf.random.Algorithm.PHILOX elif algorithm == lax.RandomAlgorithm.RNG_DEFAULT: algorithm_tf = tf.random.Algorithm.AUTO_SELECT else: assert False (new_key, res) = tfxla.rng_bit_generator(algorithm_tf.value, key, shape_tf, dtype=_to_tf_dtype(dtype)) if is_uint32_key: new_key = tfxla.bitcast_convert_type(new_key, _to_tf_dtype(jnp.uint32)) new_key = tf.reshape(new_key, (4,)) if _WRAP_JAX_JIT_WITH_TF_FUNCTION: # See #7839 new_key = tf.stop_gradient(new_key) res = tf.stop_gradient(res) return new_key, res tf_impl[lax.rng_bit_generator_p] = _rng_bit_generator def _rng_uniform(minval: TfVal, maxval: TfVal, *, shape) -> TfVal: shape_tf = _eval_shape(shape) return tf.random.uniform(shape_tf, minval=minval, maxval=maxval, dtype=minval.dtype) tf_impl[lax.rng_uniform_p] = _rng_uniform def _iota_2x32_shape(*, shape): def _add(x, y): return x + y def _mul(x, y): if not core.is_constant_dim(x): x = tf.cast(_eval_shape((x,))[0], y.dtype) x = tf.broadcast_to(x, tf.shape(y)) return x * y def _cast32(xs): return tf.cast(xs, _to_tf_dtype(jnp.uint32)) iotas = [_iota(dtype=jnp.uint64, shape=shape, dimension=dimension) for dimension in range(len(shape))] counts = prng.bcast_iotas_to_reshaped_iota(_add, _mul, shape, iotas) counts_lo = _cast32(counts) counts_hi = _cast32(tf.bitwise.right_shift(counts, 32)) return counts_hi, counts_lo tf_impl[prng.iota_2x32_shape_p] = _iota_2x32_shape def _gather_dimensions_proto(indices_shape, dimension_numbers): proto = xla_data_pb2.GatherDimensionNumbers() proto.offset_dims.extend(dimension_numbers.offset_dims) proto.collapsed_slice_dims.extend(dimension_numbers.collapsed_slice_dims) proto.start_index_map.extend(dimension_numbers.start_index_map) assert indices_shape proto.index_vector_dim = len(indices_shape) - 1 return proto def _maybe_cast_to_int64(x: TfVal) -> TfVal: if x.dtype != tf.int32 and x.dtype != tf.int64: return tf.cast(x, tf.int64) return x @partial(handle_boolean_args, argnums=[0]) def _gather(operand, start_indices, *, dimension_numbers, slice_sizes: core.Shape, indices_are_sorted, unique_indices, mode, fill_value, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): """Tensorflow implementation of gather.""" if mode == lax.GatherScatterMode.FILL_OR_DROP: gather_fill_fn = _convert_jax_impl(lax_slicing._gather_fill, multiple_results=False) return gather_fill_fn( operand, start_indices, dimension_numbers=dimension_numbers, slice_sizes=slice_sizes, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, fill_value=fill_value, output_shape=_out_aval.shape, _in_avals=_in_avals, _out_aval=_out_aval) operand_aval = _in_avals[0] start_indices = _maybe_cast_to_int64(start_indices) if dtypes.is_opaque_dtype(operand_aval.dtype): opaque_shape = _jax_physical_aval(operand_aval).shape[len(operand_aval.shape):] trailing_offset_dims = [len(_out_aval.shape) + i for i in range(len(opaque_shape))] dimension_numbers = dimension_numbers._replace( offset_dims=(*dimension_numbers.offset_dims, *trailing_offset_dims)) slice_sizes = (*slice_sizes, *opaque_shape) proto = _gather_dimensions_proto(start_indices.shape, dimension_numbers) slice_sizes_tf = _eval_shape(slice_sizes) out = tfxla.gather(operand, start_indices, proto, slice_sizes_tf, indices_are_sorted) out = _ensure_tf_shape_if_dynamic(out, _aval_to_tf_shape(_out_aval)) if _WRAP_JAX_JIT_WITH_TF_FUNCTION: out = tf.stop_gradient(out) # See #7839 return out tf_impl_with_avals[lax.gather_p] = _gather def _slice(operand, start_indices, limit_indices, strides, _in_avals, _out_aval): if strides is None: strides = [1] * len(start_indices) slices = tuple( map(slice, _eval_shape(start_indices), _eval_shape(limit_indices), _eval_shape(strides))) out = operand[slices] # TODO(b/184503314): improve shape inference for __getitem__ # E.g., operand.shape=(b, 5, 3), start_indices=(0, 1, 1), limit_indices=(b, 5, 3), strides=(1, 2, 1) out = _ensure_tf_shape_if_dynamic(out, _aval_to_tf_shape(_out_aval)) return out tf_impl_with_avals[lax.slice_p] = _slice def _dynamic_slice(operand, *start_indices, slice_sizes: core.Shape, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): start_indices = _maybe_cast_to_int64(tf.stack(start_indices)) operand_aval = _in_avals[0] if dtypes.is_opaque_dtype(operand_aval.dtype): opaque_shape = _jax_physical_aval(operand_aval).shape[len(operand_aval.shape):] slice_sizes = (*slice_sizes, *opaque_shape) start_indices = tf.concat([start_indices, tf.zeros((len(opaque_shape),), dtype=start_indices.dtype)], axis=0) slice_sizes_tf = _eval_shape(slice_sizes) res = tfxla.dynamic_slice(operand, start_indices, size_indices=slice_sizes_tf) if _WRAP_JAX_JIT_WITH_TF_FUNCTION: res = tf.stop_gradient(res) # See #7839 return res tf_impl_with_avals[lax.dynamic_slice_p] = _dynamic_slice def _dynamic_update_slice(operand, update, *start_indices, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): start_indices = _maybe_cast_to_int64(tf.stack(start_indices)) operand_aval = _in_avals[0] if dtypes.is_opaque_dtype(operand_aval.dtype): opaque_shape = _jax_physical_aval(operand_aval).shape[len(operand_aval.shape):] start_indices = tf.concat([start_indices, tf.zeros((len(opaque_shape),), dtype=start_indices.dtype)], axis=0) out = tfxla.dynamic_update_slice(operand, update, start_indices) if _WRAP_JAX_JIT_WITH_TF_FUNCTION: out = tf.stop_gradient(out) # See #7839 return out tf_impl_with_avals[lax.dynamic_update_slice_p] = _dynamic_update_slice def _scatter_dimensions_proto(indices_shape, dimension_numbers): proto = xla_data_pb2.ScatterDimensionNumbers() proto.update_window_dims.extend(dimension_numbers.update_window_dims) proto.inserted_window_dims.extend(dimension_numbers.inserted_window_dims) proto.scatter_dims_to_operand_dims.extend( dimension_numbers.scatter_dims_to_operand_dims) assert indices_shape proto.index_vector_dim = len(indices_shape) - 1 return proto def _scatter(operand, scatter_indices, updates, *, update_jaxpr, update_consts, dimension_numbers, indices_are_sorted, unique_indices, mode, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): del unique_indices if mode == lax.GatherScatterMode.CLIP: clip_fn = _convert_jax_impl(lax_slicing._clamp_scatter_indices, multiple_results=False) scatter_indices = clip_fn( operand, scatter_indices, updates, dnums=dimension_numbers, _in_avals=_in_avals, _out_aval=_in_avals[1]) assert len(update_consts) == 0, "Update computation cannot have constants" proto = _scatter_dimensions_proto(scatter_indices.shape, dimension_numbers) def update_computation(arg1: TfVal, arg2: TfVal) -> TfVal: closed_jaxpr = core.ClosedJaxpr(update_jaxpr, update_consts) res, = _interpret_jaxpr(closed_jaxpr, arg1, arg2, extra_name_stack=None) return res o_spec = tf.TensorSpec((), dtype=operand.dtype) xla_update_computation = ( tf.function(update_computation, autograph=False).get_concrete_function(o_spec, o_spec)) out = tfxla.scatter( operand, scatter_indices, updates, xla_update_computation, proto, indices_are_sorted=indices_are_sorted) if _WRAP_JAX_JIT_WITH_TF_FUNCTION: out = tf.stop_gradient(out) # See #7839 return out tf_impl_with_avals[lax.scatter_p] = _scatter tf_impl_with_avals[lax.scatter_min_p] = _scatter tf_impl_with_avals[lax.scatter_max_p] = _scatter tf_impl_with_avals[lax.scatter_mul_p] = _scatter tf_impl_with_avals[lax.scatter_add_p] = _scatter def _cond(index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr], linear: Sequence[bool]) -> Sequence[TfVal]: del linear # tf.cond needs lambdas with no arguments. branches_tf = [ partial(_interpret_jaxpr, jaxpr, *operands, # Same name stack as the XLA translation of cond_p extra_name_stack=f"branch_{i}_fun") for i, jaxpr in enumerate(branches) ] # Same name stack as XLA translation of cond_p # Note: extend_name_stack is a contextmanager, which is callable as a decorator. branches_tf = list(map(source_info_util.extend_name_stack("cond"), # type: ignore[arg-type] branches_tf)) return tf.switch_case(index, branches_tf) tf_impl[lax.cond_p] = _cond def _while(*args: TfVal, cond_nconsts: int, cond_jaxpr: core.ClosedJaxpr, body_nconsts: int, body_jaxpr: core.ClosedJaxpr) -> Sequence[TfVal]: cond_consts, body_consts, init_carry = util.split_list( args, [cond_nconsts, body_nconsts]) if cond_jaxpr.out_avals[0].shape: # type: ignore[attr-defined] # The conditional is not a scalar, this must be a batched while return _batched_cond_while( *args, cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr, body_nconsts=body_nconsts, body_jaxpr=body_jaxpr) # The conditional must return a single value to TF def cond_tf_func(*args: TfVal) -> TfVal: pred, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *args, # Same name stack as the XLA translation of while_p extra_name_stack="while/cond") return pred body_tf_func = partial(_interpret_jaxpr, body_jaxpr, *body_consts, extra_name_stack="while/body") # Sometimes TF infers more specific shapes for the init_carry, and this has # led to errors: "enters the loop with shape (1,), but has shape (None,) after one iteration" shape_invariants = [tf.TensorShape(_aval_to_tf_shape(_out_aval)) for _out_aval in body_jaxpr.out_avals] return tf.while_loop(cond_tf_func, body_tf_func, init_carry, shape_invariants=shape_invariants) def _batched_cond_while(*args: TfVal, cond_nconsts: int, cond_jaxpr: core.ClosedJaxpr, body_nconsts: int, body_jaxpr: core.ClosedJaxpr) -> Sequence[TfVal]: """Interprets a while_loop with a batched condition. A batched while has a conditional that returns a tensor of booleans, and a body that returns a list of tensors whose leading dimensions match those of the conditional tensor. We need to turn it into a while with scalar boolean conditional. We will expand the loop carry to include a prefix with the current tensor boolean condition. We prepend to the loop the first calculation of the tensor boolean condition. The loop condition will use a "reduce_any" to calculate a scalar boolean from the tensor boolean condition. The end of the loop body will compute the new carry using a "tf.where", and we compute the new tensor boolean condition. """ cond_consts, body_consts, init_carry = util.split_list( args, [cond_nconsts, body_nconsts]) # Initial computation of batched condition init_pred_b, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *init_carry, extra_name_stack="while/body_pred") def new_cond_tf_func(pred_b: TfVal, *carry: TfVal) -> TfVal: pred = tf.reduce_any(pred_b, axis=list(range(len(pred_b.shape)))) return pred def new_body_tf_func(pred_b: TfVal, *carry: TfVal) -> Sequence[TfVal]: new_carry: Sequence[TfVal] = _interpret_jaxpr(body_jaxpr, *body_consts, *carry, extra_name_stack="while/body") # We repeat those carries for which the loop termination condition is false def select_one_carry(new_c: TfVal, c: TfVal, c_aval: core.ShapedArray) -> TfVal: pred_b_bcast = _broadcast_in_dim( pred_b, shape=_jax_physical_aval(c_aval).shape, # a JAX shape broadcast_dimensions=list(range(len(pred_b.shape))), _in_avals=cond_jaxpr.out_avals, _out_aval=core.ShapedArray(c_aval.shape, np.bool_)) return tf.where(pred_b_bcast, new_c, c) selected_carry: Sequence[TfVal] = list(map(select_one_carry, new_carry, carry, body_jaxpr.out_avals)) next_pred_b, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *selected_carry, extra_name_stack="body_pred") return (next_pred_b, *selected_carry) _, *res_carry = tf.while_loop(new_cond_tf_func, new_body_tf_func, (init_pred_b, *init_carry)) return res_carry tf_impl[lax.while_p] = _while # We use the scan impl rule to rewrite in terms of while. tf_impl_with_avals[lax.scan_p] = _convert_jax_impl( lax_control_flow._scan_impl, extra_name_stack="scan") tf_impl_with_avals[ad_checkpoint.remat_p] = \ _convert_jax_impl(partial(ad_checkpoint.remat_lowering, # TODO: jax2tf cannot discriminate by platform is_gpu_platform=False), multiple_results=True, extra_name_stack="checkpoint") tf_impl[ad_checkpoint.name_p] = lambda x, *, name: x # TODO: Remove once tensorflow is 2.10.0 everywhere. if hasattr(tfxla, 'optimization_barrier'): tf_impl[lax_control_flow.optimization_barrier_p] = tfxla.optimization_barrier def _top_k(operand: TfVal, k: int) -> Tuple[TfVal, TfVal]: # Some types originally incompatible with tf.math.top_k can be promoted # to a compatible type without loss of precision. def promote_tf_dtype(tf_dtype): if tf_dtype in [tf.bool, tf.uint8, tf.uint16]: return tf.uint32 if tf_dtype in [tf.int8, tf.int16]: return tf.int32 if tf_dtype is tf.float16: return tf.float32 return None conversion_dtype = promote_tf_dtype(operand.dtype) if core.is_special_dim_size(k): k_tf = _eval_shape((k,))[0] k_tf = tf.cast(k_tf, tf.int32) # TopK works only for int32 else: k_tf = k if conversion_dtype: values, indices = tf.math.top_k( tf.dtypes.cast(operand, conversion_dtype), k=k_tf, sorted=True) return tf.dtypes.cast(values, operand.dtype), indices else: return tf.math.top_k(operand, k=k_tf, sorted=True) tf_impl[lax.top_k_p] = _top_k def _approx_top_k(operand: TfVal, k: int, reduction_dimension: int, recall_target: float, is_max_k: bool, reduction_input_size_override: int, aggregate_to_topk: bool) -> Tuple[TfVal, TfVal]: k_tf = _eval_shape((k,))[0] if is_max_k: return tf.math.approx_max_k(operand, k_tf, reduction_dimension, recall_target, reduction_input_size_override, aggregate_to_topk) else: return tf.math.approx_min_k(operand, k_tf, reduction_dimension, recall_target, reduction_input_size_override, aggregate_to_topk) tf_impl[lax.approx_top_k_p] = _approx_top_k def _sort(*operands: TfVal, dimension: int, is_stable: bool, num_keys: int) -> Tuple[TfVal, ...]: assert 1 <= num_keys <= len(operands) assert 0 <= dimension < len( operands[0].shape ), f"Invalid {dimension} for ndim {len(operands[0].shape)}" comparator_spec: List[tf.TensorSpec] = [] comparator_jax_in_avals: List[core.ShapedArray] = [] for op in operands: o_spec = tf.TensorSpec((), dtype=op.dtype) comparator_spec.extend([o_spec, o_spec]) o_aval = core.ShapedArray((), _to_jax_dtype(op.dtype)) comparator_jax_in_avals.extend([o_aval, o_aval]) # Use the same comparator that JAX uses when compiling to XLA, to get the # proper NaN/Inf total order, and the lexicographic ordering. # The comparator is a 2N-argument TF function, with arguments [2k] and [2k +1] # corresponding to two scalars from operand[k]. def lexicographic_comparator(*tf_args: TfVal) -> TfVal: return _convert_jax_impl( lax_internal._sort_lt_comparator, multiple_results=False)( *tf_args, _in_avals=comparator_jax_in_avals, _out_aval=core.ShapedArray((), np.bool_), num_keys=num_keys) xla_comparator_computation = ( tf.function(lexicographic_comparator, autograph=False).get_concrete_function(*comparator_spec)) results = tfxla.variadic_sort( operands, dimension=dimension, is_stable=is_stable, comparator=xla_comparator_computation) if _WRAP_JAX_JIT_WITH_TF_FUNCTION: results = tuple(tf.stop_gradient(out) for out in results) # See #7839 return results tf_impl[lax.sort_p] = _sort def _fft(x, *, fft_type, fft_lengths, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): FFT, IFFT, RFFT, IRFFT = list(map(xla_client.FftType, [0, 1, 2, 3])) x_aval, = _in_avals x_shape = x_aval.shape if fft_type == IRFFT: expected_lengths = x_shape[-len(fft_lengths):-1] + ((x_shape[-1] - 1) * 2,) else: expected_lengths = x_shape[-len(fft_lengths):] if expected_lengths != fft_lengths: raise NotImplementedError( f"Unsupported {fft_lengths=} for {fft_type=} of " f"array with shape={x.shape}.") tf_funcs = { FFT: [tf.signal.fft, tf.signal.fft2d, tf.signal.fft3d], IFFT: [tf.signal.ifft, tf.signal.ifft2d, tf.signal.ifft3d], RFFT: [tf.signal.rfft, tf.signal.rfft2d, tf.signal.rfft3d], IRFFT: [tf.signal.irfft, tf.signal.irfft2d, tf.signal.irfft3d] } res = tf_funcs[fft_type][len(fft_lengths) - 1](x) return _ensure_tf_shape_if_dynamic(res, _aval_to_tf_shape(_out_aval)) tf_impl_with_avals[lax.fft_p] = _fft def _qr(operand, full_matrices): return tf.linalg.qr(operand, full_matrices=full_matrices) tf_impl[lax.linalg.qr_p] = _qr def _svd(operand, full_matrices, compute_uv): result = tf.linalg.svd(operand, full_matrices, compute_uv) if not compute_uv: return result, s, u, v = result return s, u, tf.linalg.adjoint(v) tf_impl[lax.linalg.svd_p] = _svd def _eig(operand: TfVal, compute_left_eigenvectors: bool, compute_right_eigenvectors: bool): if compute_left_eigenvectors and compute_right_eigenvectors: # TODO(bchetioui): didn't find a 100% reliable, easy and satisfying way to # sort the left eigenvectors in the right order. The jax.numpy.linalg API # suggests to me that left eigenvectors are anyway seldom used, so I # think it is acceptable to leave as unimplemented for now. msg = ("Conversion of eig is not implemented when both " "compute_left_eigenvectors and compute_right_eigenvectors are set " "to True.") raise NotImplementedError(msg) elif not (compute_left_eigenvectors or compute_right_eigenvectors): return tuple([tf.linalg.eigvals(operand)]) elif compute_right_eigenvectors: return tuple(tf.linalg.eig(operand)) else: # compute_left_eigenvectors == True wH, vl = tf.linalg.eig(tf.linalg.adjoint(operand)) wHH = tf.math.conj(wH) return tuple([wHH, vl]) tf_impl[lax.linalg.eig_p] = _eig def _eigh(operand: TfVal, lower: bool, sort_eigenvalues: bool, _in_avals, _out_aval): del sort_eigenvalues if operand.shape[-1] == 0: v, w = operand, tf.reshape(operand, _eval_shape(_in_avals[0].shape[:-1])) else: if not lower: operand = tf.linalg.adjoint(operand) w, v = tf.linalg.eigh(operand) cast_type = { tf.complex64: tf.float32, tf.complex128: tf.float64 }.get(operand.dtype) if cast_type is not None: w = tf.cast(w, cast_type) return v, w tf_impl_with_avals[lax.linalg.eigh_p] = _eigh def _lu(operand: TfVal, _in_avals, _out_aval): return _convert_jax_impl(lax_linalg._lu_python, extra_name_stack="lu")( operand, _in_avals=_in_avals, _out_aval=_out_aval) tf_impl_with_avals[lax.linalg.lu_p] = _lu def _triangular_solve(a: TfVal, b: TfVal, *, left_side: bool, lower: bool, transpose_a: bool, conjugate_a: bool, unit_diagonal: bool, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): if unit_diagonal: a_aval, _ = _in_avals a_shape = _eval_shape(a_aval.shape) a = tf.linalg.set_diag(a, tf.ones(a_shape[:-1], dtype=a.dtype)) if not left_side: rank = len(a.shape) transpose_dimensions = list(range(rank - 2)) + [rank - 1, rank - 2] a = tf.transpose(a, transpose_dimensions) b = tf.transpose(b, transpose_dimensions) lower = not lower # adjoint == transpose for real dtypes, so special care need only be taken # for complex types. if a.dtype in [tf.complex64, tf.complex128]: if (transpose_a and not conjugate_a) or (not transpose_a and conjugate_a): a = tf.math.conj(a) result = tf.linalg.triangular_solve(a, b, lower=lower, adjoint=transpose_a) if not left_side: result = tf.transpose(result, transpose_dimensions) return result tf_impl_with_avals[lax.linalg.triangular_solve_p] = _triangular_solve def _linear_solve(*args: TfVal, const_lengths, jaxprs, _in_avals, _out_aval): return _convert_jax_impl(lax_control_flow._custom_linear_solve_impl, extra_name_stack="linear_solve")( *args, const_lengths=const_lengths, jaxprs=jaxprs, _in_avals=_in_avals, _out_aval=_out_aval) tf_impl_with_avals[lax.linear_solve_p] = _linear_solve def _tridiagonal_solve(*args: TfVal, _in_avals, _out_aval, **params): return _convert_jax_impl(lax_linalg._tridiagonal_solve_jax, multiple_results=False, extra_name_stack="tridiagonal_solve")( *args, _in_avals=_in_avals, _out_aval=_out_aval) tf_impl_with_avals[lax.linalg.tridiagonal_solve_p] = _tridiagonal_solve def _custom_jvp_call(*args: TfVal, call_jaxpr: core.ClosedJaxpr, jvp_jaxpr_thunk: Callable, num_consts: int) -> Sequence[TfVal]: # TODO(necula): ensure that there is no AD transformation in scope del jvp_jaxpr_thunk, num_consts return _interpret_jaxpr(call_jaxpr, *args, extra_name_stack="custom_jvp", fresh_constant_cache=False) tf_impl[custom_derivatives.custom_jvp_call_p] = _custom_jvp_call def _custom_vjp_call_jaxpr(*args: TfVal, fun_jaxpr: core.ClosedJaxpr, **_) -> Sequence[TfVal]: # TODO(necula): ensure that there is no AD transformation in scope return _interpret_jaxpr(fun_jaxpr, *args, extra_name_stack="custom_vjp", fresh_constant_cache=False) tf_impl[custom_derivatives.custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr def _custom_lin(*args: TfVal, **_) -> Sequence[TfVal]: raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp " "function.") tf_impl[ad.custom_lin_p] = _custom_lin PartitionsOrReplicated = Optional[Tuple[int, ...]] def split_to_logical_devices(tensor: TfVal, partition_dimensions: PartitionsOrReplicated): """Like TPUMPStrategy.experimental_split_to_logical_devices. For jax2tf purposes we want to avoid needing to thread the `strategy` object through the generated computation. It seems that the original function needs the strategy object only for error checking, which we assume is done upstream by JAX. Args: tensor: Input tensor to annotate. partition_dimensions: A list of integers, with one integer per tensor dimension, specifying in how many parts the dimension should be split. The product of integers must equal the number of devices per replica. use_sharding_op: whether to use a sharding op, or not. Returns: an annotated tensor. """ # TODO: this is only for sharded_jit. Either remove, or implement in terms # of _shard_values. if partition_dimensions is None: return xla_sharding.replicate(tensor, use_sharding_op=True) num_partition_splits = math.prod(partition_dimensions) tile_assignment = np.arange(num_partition_splits).reshape( partition_dimensions) return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True) def _shard_value(val: TfVal, aval: core.ShapedArray, sd: sharding.XLACompatibleSharding, *, skip_replicated_sharding: bool) -> TfVal: """Apply sharding to a TfVal.""" if sharding_impls.is_unspecified(sd): return val sharding_proto: xla_client.OpSharding = cast( xla_client.OpSharding, sd._to_xla_hlo_sharding(aval.ndim).to_proto()) # type: ignore if (skip_replicated_sharding and op_shardings.is_op_sharding_replicated(sharding_proto)): return val # To use xla_sharding.py, we must have a xla_data_pb2.OpSharding. xla_sharding_proto: xla_data_pb2.OpSharding = ( xla_data_pb2.OpSharding( type=int(sharding_proto.type), tile_assignment_dimensions=sharding_proto.tile_assignment_dimensions, tile_assignment_devices=sharding_proto.tile_assignment_devices, replicate_on_last_tile_dim=sharding_proto.replicate_on_last_tile_dim, last_tile_dims=sharding_proto.last_tile_dims)) if tf_context.executing_eagerly(): raise ValueError( "A jit function with sharded arguments or results must be used under a `tf.function` context. " "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning for a discussion") return xla_sharding.Sharding(proto=xla_sharding_proto).apply_to_tensor( val, use_sharding_op=True) def _pjit(*args: TfVal, jaxpr: core.ClosedJaxpr, in_shardings: Sequence[sharding.XLACompatibleSharding], out_shardings: Sequence[sharding.XLACompatibleSharding], resource_env: maps.ResourceEnv, donated_invars, name: str, keep_unused: bool, inline: bool, _in_avals: Sequence[core.ShapedArray], _out_aval: Sequence[core.ShapedArray]) -> TfVal: del donated_invars # Apply sharding annotation to the arguments sharded_args: Sequence[TfVal] = tuple( map(partial(_shard_value, skip_replicated_sharding=not _thread_local_state.enable_xla), args, _in_avals, in_shardings)) results = _interpret_jaxpr(jaxpr, *sharded_args, extra_name_stack=util.wrap_name(name, "pjit"), fresh_constant_cache=False) sharded_results: Sequence[TfVal] = tuple( map(partial(_shard_value, skip_replicated_sharding=not _thread_local_state.enable_xla), results, _out_aval, out_shardings)) return tuple(sharded_results) tf_impl_with_avals[pjit.pjit_p] = _pjit def _pjit_sharding_constraint(arg: TfVal, *, sharding: sharding.NamedSharding, resource_env: maps.ResourceEnv, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray, **kwargs) -> TfVal: return _shard_value(arg, _in_avals[0], sharding, skip_replicated_sharding=False) tf_impl_with_avals[pjit.sharding_constraint_p] = _pjit_sharding_constraint def _dimension_size_jax2tf(op: TfVal, *, dimension, _in_avals, _out_aval): dim_tf = tf.shape(op)[dimension] if dim_tf.dtype != _to_tf_dtype(_out_aval.dtype): return _convert_element_type(dim_tf, new_dtype=_out_aval.dtype, weak_type=_out_aval.weak_type) else: return dim_tf tf_impl_with_avals[shape_poly.dimension_size_p] = _dimension_size_jax2tf def _dim_as_value_jax2tf(dim: shape_poly.DimSize): dim_tf, = _eval_shape((dim,)) return dim_tf tf_impl[shape_poly.dim_as_value_p] = _dim_as_value_jax2tf def _reduce_precision(x, *, exponent_bits, mantissa_bits): return tfxla.reduce_precision(x, exponent_bits=exponent_bits, mantissa_bits=mantissa_bits) tf_impl[lax.reduce_precision_p] = _reduce_precision def _register_checkpoint_pytrees(): """Registers TF custom container types as pytrees.""" m = tf.Module() # The types here are automagically changed by TensorFlow's checkpointing # infrastructure. m.a = (tf.Module(), tf.Module()) m.b = [tf.Module(), tf.Module()] m.c = {"a": tf.Module()} tuple_wrapper = type(m.a) list_wrapper = type(m.b) dict_wrapper = type(m.c) # TF AutoTrackable swaps container types out for wrappers. assert tuple_wrapper is not tuple assert list_wrapper is not list assert dict_wrapper is not dict jax.tree_util.register_pytree_node(tuple_wrapper, lambda xs: (tuple(xs), None), lambda _, xs: tuple(xs)) jax.tree_util.register_pytree_node(list_wrapper, lambda xs: (tuple(xs), None), lambda _, xs: list(xs)) jax.tree_util.register_pytree_node( dict_wrapper, lambda s: (tuple(s.values()), tuple(s.keys())), lambda k, xs: dict_wrapper(zip(k, xs))) _register_checkpoint_pytrees()