Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/experimental/jax2tf/jax2tf.py
2023-06-19 00:49:18 +02:00

3400 lines
131 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides JAX and TensorFlow interoperation APIs."""
from functools import partial
import contextlib
import math
import operator
import os
import re
import threading
from typing import (
Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union,
cast)
import warnings
from absl import logging
import numpy as np
import jax
from jax import lax
from jax import config
from jax import custom_derivatives
from jax import random
from jax import numpy as jnp
from jax import tree_util
from jax import sharding
from jax.experimental import maps
from jax.experimental.jax2tf import shape_poly
from jax.experimental.jax2tf import impl_no_xla
from jax.experimental.jax2tf import jax_export
from jax.interpreters import xla
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import op_shardings
from jax._src import sharding_impls
from jax._src import pjit
from jax._src import prng
from jax._src import random as random_internal
from jax._src import source_info_util
from jax._src import util
from jax._src.interpreters import ad
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lax import lax as lax_internal
from jax._src.lax import linalg as lax_linalg
from jax._src.lax import slicing as lax_slicing
from jax._src.lax import windowed_reductions as lax_windowed_reductions
from jax._src.lib import xla_client
from jax._src.numpy.ufuncs import logaddexp
import tensorflow as tf # type: ignore[import]
# These don't have public equivalents.
# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import]
from tensorflow.compiler.xla import xla_data_pb2 # type: ignore[import]
from tensorflow.core.framework import attr_value_pb2 # type: ignore[import]
try:
from tensorflow.python.compiler.xla.experimental import xla_sharding # type: ignore[import]
except ModuleNotFoundError:
# This can be removed when TF 2.10 support is no longer needed.
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding # type: ignore[import]
from tensorflow.python.framework import ops as tf_ops # type: ignore[import]
from tensorflow.python.eager import context as tf_context # type: ignore[import]
# pylint: enable=g-direct-tensorflow-import
# Short aliases for types used throughout this module.
NameStack = source_info_util.NameStack
PolyShape = shape_poly.PolyShape
DType = Any

# A temporary internal flag, to enable the wrapping of jax.jit functions
# with tf.function(jit_compile=True). See #7389. This change has triggered a
# number of failures in TF. We keep this until we are confident that it does
# not create problems.
# TODO(b/207464757): figure out why this change breaks test
_WRAP_JAX_JIT_WITH_TF_FUNCTION = False

# The scope name needs to be a valid TensorFlow name. See
# https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/core/framework/node_def_util.cc#L731
# `_VALID_SCOPE_REGEX` matches an entire valid scope name, while
# `_INVALID_SCOPE_CHAR` matches any single character outside the allowed set
# (it is used to replace such characters in `_sanitize_scope_name`).
_VALID_SCOPE_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*$")
_INVALID_SCOPE_CHAR = re.compile("[^A-Za-z0-9_.\\/-]")

# Shadow the builtins `map` and `zip` for the rest of this module with JAX's
# "safe" variants. NOTE(review): presumably these check that all argument
# sequences have the same length — confirm in jax._src.util.
map = util.safe_map
zip = util.safe_zip
def _sanitize_scope_name(name):
  """Turns `name` into a string usable as a TensorFlow name scope.

  Every character that may not appear in a TF scope name is replaced with
  "_". If the result still does not form a valid scope name (for example,
  because of its first character), a "." is prepended to it.
  """
  cleaned = _INVALID_SCOPE_CHAR.sub("_", name)
  if _VALID_SCOPE_REGEX.match(cleaned):
    return cleaned
  return "." + cleaned
# A value suitable in a TF tracing context: tf.Tensor, tf.Variable,
# or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
TfVal = Any
# An XLA precision value, represented as a plain Python int.
PrecisionType = int  # Enum xla_data.PrecisionConfig.Precision
def _is_tfval(v: TfVal) -> bool:
  """Returns True if `v` can be used as a TF value.

  This holds for tf.Tensor and tf.Variable instances, and for any value that
  `tf.constant` accepts. The conversion is attempted on the CPU device so
  that types unsupported on accelerators are still recognized.
  """
  if isinstance(v, (tf.Tensor, tf.Variable)):
    return True
  try:
    # Include all convertible types, even if not supported on accelerators.
    with tf.device("CPU"):
      tf.constant(v)
    return True
  except Exception:  # pylint: disable=broad-except
    # Any conversion failure means "not a TF value". Catch `Exception` only,
    # so that KeyboardInterrupt / SystemExit are not swallowed.
    return False
# The implementation rules for primitives. The rule will be called with the
# arguments (TfVal) and must return TfVal (or a sequence thereof,
# if primitive.multiple_results). The exceptions are primarily the
# control-flow primitives.
tf_impl: Dict[core.Primitive, Callable[..., Any]] = {}

# Some primitive implementation rules need the abstract values of arguments
# and the results. This is the case for the primitives implemented using
# _convert_jax_impl and those that need to adjust the shape of the outputs
# due to missing TF shape inference rules for TFXLA ops. The rules for these
# primitives should be added to `tf_impl_with_avals`.
# The abstract values are passed to the implementation as two special kwargs
# `_in_avals` (a tuple of core.ShapedArray) and `_out_aval` (a
# core.ShapedArray, or a tuple thereof when primitive.multiple_results).
tf_impl_with_avals: Dict[core.Primitive, Callable[..., Any]] = {}

# XLA is not linked in all environments when converting a primitive. If this is
# the case, we first search for implementation rules for primitives in the
# following map. These implementations are workarounds, making use of TF ops
# that do work when XLA is not linked in.
tf_impl_no_xla = impl_no_xla.tf_impl_no_xla
# In order to ensure that JAX picks up the proper user-frame for source
# locations we will register the TensorFlow source path as an internal
# path with source_info_util. The typical stack when a JAX primitive
# conversion happens is:
#    jax2tf.process_primitive  (top of stack)
#    jax tracing machinery ...
#    tf.custom_gradient machinery ...
#    jax2tf.converted_fun
#    tf function machinery ...
#    user code invokes the converted function on TF tensors
#
# We need to skip over not only JAX internal frames, but TF internal frames
# also.
# We register the TensorFlow source path lazily: the flag below is flipped to
# True the first time a function returned by `convert` is invoked, right
# after calling `source_info_util.register_exclusion` on the TF install dir.
_has_registered_tf_source_path = False
class _ThreadLocalState(threading.local):
def __init__(self):
# XLA is not linked in all environments; when converting a primitive, if this
# variable is disabled, we try harder to use only standard TF ops if they are
# applicable to the concrete use case; if the resulting conversion path ends up
# requiring a TFXLA operation, an exception is thrown instead.
self.enable_xla = True
# Keep track if we are inside a call_tf. In that context we disable the
# safety check that we are not inside JAX transformations.
self.inside_call_tf = False
# Maps dimension variables to TF expressions, for non-native lowering
self.shape_env: Sequence[Tuple[str, TfVal]] = ()
# Whether to actually include XLA op metadata in the generated TF ops
# TODO(b/189306134): implement support for XLA metadata
self.include_xla_op_metadata = False
# A cache for the tf.convert_to_tensor for constants. We try to preserve
# sharing for constants, to enable tf.Graph to take advantage of it.
# See https://github.com/google/jax/issues/7992.
self.constant_cache = None # None means that we don't use a cache. We
# may be outside a conversion scope.
# A cache for the outside tf name_scope when the converted
# function is running. We will add this as the prefix to the generated tf op
# name. For example, the tf op name will be like
# "{tf_outer_name_scope}/JAX_NAME_STACKS"
self.tf_outer_name_scope = ""
# A dict collecting all tf concrete_functions called by stablehlo.custom_call
# This is used only by native serialization (unlike all the other
# thread-local state).
self.call_tf_concrete_function_list: Optional[List[Any]] = None
# The single module-wide instance; being a `threading.local`, each thread sees
# its own copy of the attributes.
_thread_local_state = _ThreadLocalState()
def _get_current_name_stack() -> Union[NameStack, str]:
  """Returns the JAX name stack currently in effect (from source_info_util)."""
  current_stack = source_info_util.current_name_stack()
  return current_stack
@contextlib.contextmanager
def inside_call_tf():
  """Marks the current thread as executing inside a call_tf.

  While the context is active, `_thread_local_state.inside_call_tf` is True.
  On exit, including exit by exception, the flag is restored to whatever
  value it had on entry, so nested uses compose correctly.
  """
  saved_flag = _thread_local_state.inside_call_tf
  _thread_local_state.inside_call_tf = True
  try:
    yield
  finally:
    _thread_local_state.inside_call_tf = saved_flag
def get_thread_local_state_call_tf_concrete_function_list() -> (
    Optional[List[Any]]
):
  """Returns this thread's list of collected TF concrete functions.

  The value is None unless a native-serialization conversion has installed a
  list (see `NativeSerializationImpl.before_conversion`).
  """
  state = _thread_local_state
  return state.call_tf_concrete_function_list
@partial(api_util.api_hook, tag="jax2tf_convert")
def convert(fun_jax: Callable,
            *,
            polymorphic_shapes=None,
            with_gradient=True,
            enable_xla=True,
            # TODO(necula): remove the experimental flag
            experimental_native_lowering="default",
            native_serialization="default",
            native_serialization_platforms=(),
            native_serialization_strict_checks=True) -> Callable:
  """Allows calling a JAX function from a TensorFlow program.

  See
  [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md)
  for more details about usage and common problems.

  Args:
    fun_jax: target JAX function to be called. Its arguments and return value
      should be JAX arrays, or nested standard Python containers
      (tuple/list/dict) thereof (pytrees).
    polymorphic_shapes: Specifies input shapes to be treated polymorphically
      during lowering.

      .. warning:: The shape-polymorphic lowering is an experimental feature.
        It is meant to be sound, but it is known to reject some JAX programs
        that are shape polymorphic. The details of this feature can change.

      It should be `None` (all arguments are monomorphic), a single PolyShape
      or string (applies to all arguments), or a tuple/list of the same length
      as the function arguments. For each argument the shape specification
      should be `None` (monomorphic argument), or a Python object with the
      same pytree structure as the argument.
      See [how optional parameters are matched to
      arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).

      A shape specification for an array argument should be an object
      `PolyShape(dim0, dim1, ..., dimn)`
      where each `dim` is a dimension specification: a positive integer
      denoting a monomorphic dimension of the given size, or a string denoting
      a dimension variable assumed to range over non-zero dimension sizes, or
      the special placeholder string "_" denoting a monomorphic dimension
      whose size is given by the actual argument. As a shortcut, an Ellipsis
      suffix in the list of dimension specifications stands for a list of "_"
      placeholders.

      For convenience, a shape specification can also be given as a string
      representation, e.g.: "batch, ...", "batch, height, width, _", possibly
      with surrounding parentheses: "(batch, ...)".

      The lowering fails if it cannot ensure that it would produce the same
      sequence of TF ops for any non-zero values of the dimension variables.

      polymorphic_shapes are only supported for positional arguments; shape
      polymorphism is not supported for keyword arguments.

      See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
      for more details.
    with_gradient: if set (default), add a tf.custom_gradient to the lowered
      function, by converting the ``jax.vjp(fun)``. This means that
      reverse-mode TensorFlow AD is supported for the output TensorFlow
      function, and the value of the gradient will be JAX-accurate.
    enable_xla: if set (default), use the simplest conversion
      and use XLA TF ops when necessary. These ops are known to create issues
      for the TFLite and TFjs converters. For those cases, unset this parameter
      so the lowering tries harder to use non-XLA TF ops to lower the
      function and aborts if this is not possible. Cannot be set to `False`
      when using `native_serialization`.
    experimental_native_lowering: not read by this function (it is slated
      for removal, see the TODO in the signature); use `native_serialization`
      instead.
    native_serialization: serialize the JAX function natively to
      StableHLO with compatibility guarantees. This makes it easier to have
      confidence that the code executed when calling this function from
      TensorFlow is exactly the same as JAX would run natively.
      The "default" value defers to `False` if `enable_xla`
      is set to `False` or to the configuration flag
      `--jax2tf_default_native_serialization` otherwise.
      Native serialization cannot be used with `enable_xla=False`.
    native_serialization_platforms: In conjunction with
      `native_serialization`, specify the platform(s)
      for which to lower the code. Must be a list or tuple of strings, each
      one of: 'cpu', 'gpu', 'tpu'. Currently at most one platform may be
      given. The default (empty tuple) specifies the JAX default
      backend on the machine where the lowering is done.
    native_serialization_strict_checks: In conjunction with
      `native_serialization`, enable the following
      checks: (A) the lowered computation is executed on a platform for which
      it was lowered; (B) the serialized computation contains only custom
      calls with targets that are guaranteed to be stable, (more to come).

  Returns:
    A version of `fun_jax` that expects TfVals as arguments (or
    tuple/lists/dicts thereof), and returns TfVals as outputs, and uses
    only TensorFlow ops and thus can be called from a TensorFlow program.

  Raises:
    ValueError: if `native_serialization` is requested with
      `enable_xla=False`, or if `native_serialization_platforms` is not a
      list/tuple drawn from {'cpu', 'gpu', 'tpu'}. The returned function also
      raises ValueError when it is called under a JAX transformation (other
      than from inside a `call_tf`).
    NotImplementedError: if more than one platform is given in
      `native_serialization_platforms`.
  """
  if native_serialization == "default":
    if not enable_xla:
      native_serialization = False
    else:
      native_serialization = config.jax2tf_default_native_serialization

  if native_serialization and not enable_xla:
    raise ValueError(
        "native_serialization is not supported with enable_xla=False")

  if native_serialization_platforms:
    if not native_serialization:
      warnings.warn(
          "using native_serialization_platforms without native_serialization. "
          "The parameter will have no effect, since the same code is serialized "
          "for all platforms without native_serialization.")

    if (not isinstance(native_serialization_platforms, (list, tuple)) or
        not all(p in ["tpu", "cpu", "gpu"] for p in native_serialization_platforms)):
      raise ValueError(
          "native_serialization_platforms must be a sequence "
          "containing a subset of {'cpu', 'gpu', 'tpu'}. "
          f"Got: {native_serialization_platforms}")
    native_serialization_platforms = tuple(native_serialization_platforms)
    if len(native_serialization_platforms) > 1:
      raise NotImplementedError(
          "native_serialization_platforms is not yet implemented for multiple platforms")

  api.check_callable(fun_jax)

  def converted_fun_tf(*args_tf: TfVal, **kwargs_tf: TfVal) -> TfVal:

    # TODO: is there a better way to check if we are inside a transformation?
    if not core.trace_state_clean() and not _thread_local_state.inside_call_tf:
      # It is Ok to nest convert when we are inside a call_tf
      raise ValueError(
          "convert must be used outside all JAX transformations. " +
          f"Trace state: {core.thread_local_state.trace_state.trace_stack}")

    # Register the TF source directory as a JAX-internal path once, so that
    # source locations point at user frames rather than TF internals.
    global _has_registered_tf_source_path
    if not _has_registered_tf_source_path:
      source_info_util.register_exclusion(os.path.dirname(tf.__file__))
      _has_registered_tf_source_path = True

    def shape_and_dtype_tf(a: TfVal) -> Tuple[Sequence[Optional[int]], DType]:
      # The shape and JAX dtype for a TF argument
      tf_arg_shape = np.shape(a)
      # Fix the shape for TF1
      tf_arg_shape = tuple(d.value if isinstance(d, tf.compat.v1.Dimension) else d
                           for d in tf_arg_shape)
      _, a_jax_dtype = _tfval_to_tensor_jax_dtype(a)
      return tf_arg_shape, a_jax_dtype

    args_specs = jax_export.poly_specs(args_tf,
                                       polymorphic_shapes=polymorphic_shapes,
                                       get_shape_and_dtype=shape_and_dtype_tf)
    # The polymorphic_shapes argument refers to positional arguments only.
    # We assume None for the kwargs.
    kwargs_specs = jax_export.poly_specs(kwargs_tf,
                                         polymorphic_shapes=None,
                                         get_shape_and_dtype=shape_and_dtype_tf)
    combined_args_tf = (args_tf, kwargs_tf)
    args_flat_tf: Sequence[TfVal]
    args_flat_tf, args_kwargs_tree = tree_util.tree_flatten(combined_args_tf)

    args_flat_tf = tuple(
        map(preprocess_arg_tf, range(len(args_flat_tf)), args_flat_tf))

    impl: SerializationImpl
    if native_serialization:
      impl = NativeSerializationImpl(
          fun_jax,
          args_specs=args_specs, kwargs_specs=kwargs_specs,
          native_serialization_platforms=native_serialization_platforms,
          native_serialization_strict_checks=native_serialization_strict_checks)
    else:
      impl = GraphSerializationImpl(
          fun_jax,
          args_specs=args_specs, kwargs_specs=kwargs_specs,
          args_flat_tf=args_flat_tf,
          enable_xla=enable_xla)
    try:
      impl.before_conversion()

      outs_tree: tree_util.PyTreeDef = None  # type: ignore
      if with_gradient:
        @tf.custom_gradient
        def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal:
          nonlocal outs_tree
          outs_tf, outs_avals, outs_tree = impl.run_fun_tf(args_flat_tf)
          return (tuple(outs_tf),
                  _make_custom_gradient_fn_tf(
                      impl=impl,
                      args_tf=args_flat_tf,
                      outs_avals=outs_avals,
                      outs_tf=outs_tf))

        outs_flat_tf = converted_fun_flat_with_custom_gradient_tf(*args_flat_tf)
      else:
        outs_tf, _, outs_tree = impl.run_fun_tf(args_flat_tf)
        message = ("The jax2tf-converted function does not support gradients. "
                   "Use `with_gradient` parameter to enable gradients")
        # We use PreventGradient, which is propagated through a SavedModel.
        outs_flat_tf = [
            tf.raw_ops.PreventGradient(input=o, message=message)
            for o in outs_tf
        ]
    finally:
      # Always undo the thread-local context installed by before_conversion.
      impl.after_conversion()

    outs_flat_tf = [tf.identity(x, "jax2tf_out") for x in outs_flat_tf]
    out_tf = tree_util.tree_unflatten(outs_tree, outs_flat_tf)
    return out_tf

  return converted_fun_tf
class SerializationImpl:
  """Interface for the jax2tf serialization strategies.

  Subclasses provide every method below. The base-class versions exist only
  to define the contract, and each of them raises NotImplementedError.
  """

  def before_conversion(self):
    """Hook run first, inside the resulting TF function.

    Intended for installing any global context the conversion needs.
    """
    raise NotImplementedError()

  def after_conversion(self):
    """Hook run last, inside the resulting TF function.

    Intended for undoing whatever `before_conversion` installed.
    """
    raise NotImplementedError()

  def run_fun_tf(self,
                 args_flat_tf: Sequence[TfVal]
                 ) -> Tuple[Sequence[TfVal], Sequence[core.ShapedArray], tree_util.PyTreeDef]:
    """Executes the converted function.

    Args:
      args_flat_tf: a flat tuple of tf.Tensor arguments

    Returns:
      A triple `(outs_tfs, outs_avals, outs_tree)`: the flat tuple of
      tf.Tensor results, the flat tuple of JAX abstract values of the
      underlying JAX function's outputs, and the PyTreeDef of the outputs.
    """
    raise NotImplementedError()

  def run_vjp_fun_tf(self,
                     vjp_args_flat_tf: Sequence[TfVal],
                     outs_avals: Sequence[core.AbstractValue]) -> Sequence[TfVal]:
    """Executes the VJP of the converted function as TF ops.

    Args:
      vjp_args_flat_tf: the flat tf.Tensor sequence made of the primal
        arguments followed by the output cotangents.
      outs_avals: the flat avals of the primal outputs.

    Returns:
      The flat sequence of input cotangents.
    """
    raise NotImplementedError()
class NativeSerializationImpl(SerializationImpl):
  """Serialization that exports the JAX function via `jax_export`.

  The exported module is executed from TF by `_run_exported_as_tf`.
  """

  def __init__(self, fun_jax, *,
               args_specs, kwargs_specs,
               native_serialization_platforms: Sequence[str],
               native_serialization_strict_checks: bool):
    self.fun_jax = fun_jax
    self.args_specs = args_specs
    self.kwargs_specs = kwargs_specs
    self.native_serialization_strict_checks = native_serialization_strict_checks
    # Only the first requested platform is used for lowering; None lets
    # jax_export pick its default.
    self.lowering_platform: Optional[str] = (
        native_serialization_platforms[0] if native_serialization_platforms
        else None)

  def before_conversion(self):
    """Starts collecting call_tf concrete functions and exports `fun_jax`."""
    saved_function_list = _thread_local_state.call_tf_concrete_function_list
    _thread_local_state.call_tf_concrete_function_list = []

    def _restore_context():
      _thread_local_state.call_tf_concrete_function_list = saved_function_list

    # Install the restore hook before exporting, so after_conversion works
    # even if the export below raises.
    self._restore_context = _restore_context
    exporter = jax_export.export(
        self.fun_jax,
        lowering_platform=self.lowering_platform,
        strict_checks=self.native_serialization_strict_checks
    )
    self.exported = exporter(*self.args_specs, **self.kwargs_specs)

  def after_conversion(self):
    """Restores the call_tf concrete-function list saved at the start."""
    self._restore_context()

  def run_fun_tf(self,
                 args_flat_tf: Sequence[TfVal]
                 ) -> Tuple[Sequence[TfVal], Sequence[core.ShapedArray], tree_util.PyTreeDef]:
    """Runs the exported module on the flat TF arguments."""
    exported = self.exported
    results = _run_exported_as_tf(args_flat_tf, exported)
    return results, tuple(exported.out_avals), exported.out_tree

  def run_vjp_fun_tf(self,
                     vjp_args_flat_tf: Sequence[TfVal],
                     outs_avals: Sequence[core.AbstractValue]) -> Sequence[TfVal]:
    """Runs the exported VJP module on primals followed by output cotangents."""
    # Not needed: `_run_exported_as_tf` reads the avals from the exported VJP.
    del outs_avals
    exported_vjp = self.exported.vjp()
    named_args = tuple(
        tf.identity(arg, f"jax2tf_arg_{idx}")
        for idx, arg in enumerate(vjp_args_flat_tf))
    in_cts_flat = _run_exported_as_tf(named_args, exported_vjp)
    return tuple(tf.identity(ct, "jax2tf_out") for ct in in_cts_flat)
class GraphSerializationImpl(SerializationImpl):
  """Serialization by interpreting the JAX function into TF ops.

  `convert` uses it when `native_serialization` is False. The JAX function is
  traced with `_interpret_fun_jax`, and its VJP is converted by a nested,
  non-native `convert`.
  """
  def __init__(self, fun_jax, *,
               args_specs, kwargs_specs,
               args_flat_tf: Sequence[TfVal],
               enable_xla: bool):
    self.fun_jax = fun_jax
    self.args_specs = args_specs
    self.kwargs_specs = kwargs_specs
    self.enable_xla = enable_xla

    fun_name = getattr(fun_jax, "__name__", "unknown")
    # Extra name-stack entry passed to `_interpret_fun_jax`.
    name_stack = util.wrap_name(fun_name, "jax2tf")
    self.name_stack = name_stack
    # The pre-processed flat TF arguments; used in `before_conversion` to
    # compute the TF values of the dimension variables.
    self.args_flat_tf = args_flat_tf

  def before_conversion(self):
    """Installs the thread-local conversion context and prepares `fun_jax`.

    Saves `enable_xla`, `include_xla_op_metadata` and `tf_outer_name_scope`
    (restored by `after_conversion`), computes the dimension-variable values
    into `_thread_local_state.shape_env`, and builds the flattened function.
    """
    prev_enable_xla = _thread_local_state.enable_xla
    prev_include_xla_op_metadata = _thread_local_state.include_xla_op_metadata
    prev_tf_outer_name_scope = _thread_local_state.tf_outer_name_scope
    def _restore_context():
      _thread_local_state.enable_xla = prev_enable_xla
      _thread_local_state.include_xla_op_metadata = prev_include_xla_op_metadata
      _thread_local_state.tf_outer_name_scope = prev_tf_outer_name_scope
      # shape_env is reset to empty rather than to a saved value; the
      # assertion below requires it to be empty on entry.
      _thread_local_state.shape_env = ()
    # Installed before any state is changed, so that `after_conversion`
    # (called from a `finally` in `convert`) can always undo the changes.
    self._restore_context = _restore_context
    _thread_local_state.enable_xla = self.enable_xla
    # TODO(b/189306134): implement support for XLA metadata
    _thread_local_state.include_xla_op_metadata = False
    _thread_local_state.tf_outer_name_scope = tf.get_current_name_scope()
    assert not _thread_local_state.shape_env, f"Unexpected shape environment {_thread_local_state.shape_env}"

    args_specs_flat, self.in_tree = tree_util.tree_flatten(
        (self.args_specs, self.kwargs_specs))
    self.args_avals_flat = tuple(
        map(lambda a: core.raise_to_shaped(core.get_aval(a)), args_specs_flat))
    # Compute the TF values of the dimension variables from the actual TF
    # argument shapes, by interpreting the JAX helper on the TF arguments.
    dim_vars = shape_poly.all_dim_vars(self.args_avals_flat)
    dim_values, _ = _interpret_fun_jax(
        partial(shape_poly.compute_dim_vars_from_arg_shapes,
                self.args_avals_flat, args_kwargs_tree=self.in_tree),
        self.args_flat_tf, self.args_avals_flat, self.name_stack)

    _thread_local_state.shape_env = zip(dim_vars, dim_values)

    fun_flat_jax, out_tree_thunk = flatten_fun_jax(self.fun_jax, self.in_tree)
    # out_tree_thunk will be ready after we call run_fun_tf below.
    self.fun_flat_jax = fun_flat_jax
    self.out_tree_thunk = out_tree_thunk

  def after_conversion(self):
    """Restores the thread-local context saved by `before_conversion`."""
    self._restore_context()

  def run_fun_tf(self,
                 args_flat_tf: Sequence[TfVal]
                 ) -> Tuple[Sequence[TfVal], Sequence[core.ShapedArray], tree_util.PyTreeDef]:
    """Interprets the flattened JAX function on TF values.

    Uses a fresh constant cache. The output tree is available from
    `out_tree_thunk` only after this interpretation has run.
    """
    outs_tf, outs_avals = _interpret_fun_jax(
        self.fun_flat_jax,
        args_flat_tf, self.args_avals_flat,
        self.name_stack,
        fresh_constant_cache=True)
    return outs_tf, outs_avals, self.out_tree_thunk()

  def run_vjp_fun_tf(self,
                     vjp_args_flat_tf: Sequence[TfVal],
                     outs_avals: Sequence[core.AbstractValue]) -> Sequence[TfVal]:
    """Converts the VJP of the flattened JAX function and runs it on TF values.

    The VJP is converted with a nested `convert` call, using
    `with_gradient=False`, `native_serialization=False`, and polymorphic
    shapes given by `str(aval.shape)` for each primal input and output.
    """
    def fun_vjp_jax(*args_and_out_cts_flat_jax):
      # Takes a flat list of primals and output cotangents
      args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax, [len(self.args_avals_flat)])
      _, pullback_jax = jax.vjp(self.fun_flat_jax, *args_flat_jax)
      return pullback_jax(out_cts_flat_jax)

    vjp_in_avals = tuple(self.args_avals_flat) + tuple(outs_avals)
    vjp_polymorphic_shapes = tuple(str(a.shape)  # Note: may be _DimExpr, not just DimVar
                                   for a in vjp_in_avals)  # type: ignore
    return convert(
        fun_vjp_jax,
        with_gradient=False,
        polymorphic_shapes=vjp_polymorphic_shapes,
        native_serialization=False)(*vjp_args_flat_tf)
def dtype_of_val(val: TfVal) -> DType:
  """Returns the TensorFlow dtype that JAX's typing rules assign to `val`.

  For a tf.Tensor the starting point is its own dtype; for a constant, JAX
  infers the dtype. Either way the result follows JAX type inference and so
  depends on the JAX_ENABLE_X64 flag. See README.md for how 64-bit values are
  treated.
  """
  return _tfval_to_tensor_jax_dtype(val)[0].dtype
@partial(api_util.api_hook, tag="jax2tf_eval_polymorphic_shapes")
def eval_polymorphic_shape(fun_jax: Callable,
                           *,
                           polymorphic_shapes=None) -> Callable:
  """Computes the output shape of `fun_jax` under shape polymorphism.

  Like `jax.eval_shape`, this neither lowers nor executes the function.

  Args:
    fun_jax: target JAX function to be called. Its arguments and return value
      should be JAX arrays, or nested standard Python containers
      (tuple/list/dict) thereof (pytrees).
    polymorphic_shapes: Specifies input shapes to be treated polymorphically
      during shape evaluation. See discussion for `jax2tf.convert`.

      .. warning:: The shape-polymorphic lowering is an experimental feature.

  Returns: a function that takes `jax.ShapeDtypeStruct`s (or any values
    with `.shape` and `.dtype` attributes) corresponding to the inputs for
    `fun_jax`, and returns a tuple with:

      * the jax.ShapeDtypeStruct corresponding to the result, as for
       `jax.eval_shape`. The shape may contain symbolic dimension expressions.
      * the value that can be passed to `polymorphic_shapes` for a subsequent
        call to `jax2tf.eval_polymorphic_shape`, or `jax2tf.convert`.

  For example:

  >>> import jax
  >>> from jax.experimental import jax2tf
  >>> from jax import numpy as jnp
  >>>
  >>> f = lambda A, x: jnp.sin(jnp.dot(A, x))
  >>> A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
  >>> x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
  >>> out_spec, out_poly_shape = jax2tf.eval_polymorphic_shape(f, polymorphic_shapes=["a, b", "b, c"])(A, x)
  >>> print(out_spec.shape)
  ("a", "c")
  >>> print(out_poly_shape)
  (a, c)
  >>> res_spec, res_poly_shape = jax2tf.eval_polymorphic_shape(lambda x: x.T, polymorphic_shapes=[out_poly_shape])(out_spec)
  >>> print(res_poly_shape)
  (c, a)
  """
  def do_eval_polymorphic_shape(*args_specs) -> Any:
    poly_specs = jax_export.poly_specs(
        args_specs, polymorphic_shapes=polymorphic_shapes)
    out_spec = jax.eval_shape(fun_jax, *poly_specs)

    # TODO(necula): For now we export the polymorphic shapes using `str`.
    def _shape_as_str(leaf):
      return str(leaf.shape)

    out_poly_shape = tree_util.tree_map(_shape_as_str, out_spec)
    return out_spec, out_poly_shape

  return do_eval_polymorphic_shape
# Internals

def flatten_fun_jax(fun_jax: Callable,
                    in_tree,
                    ) -> Tuple[Callable, Callable]:
  """Adapts `fun_jax` to take and return flat sequences of values.

  jax2tf is simpler when the JAX function takes and returns plain tuples (no
  pytrees, no kwargs). This is partly because jax.vjp does not support
  kwargs, and tf.custom_gradient can only wrap functions with flat arguments
  and results.

  Args:
    fun_jax: the function to wrap.
    in_tree: the PyTreeDef of the `(args, kwargs)` pair of `fun_jax`.

  Returns:
    A pair of:
      * the wrapped function, which takes flat positional leaves and returns
        the flat list of result leaves;
      * a thunk returning the output PyTreeDef. It returns None until the
        wrapped function has been called at least once.
  """
  recorded_out_tree = None

  def fun_flat_jax(*args_flat_jax):
    nonlocal recorded_out_tree
    call_args, call_kwargs = tree_util.tree_unflatten(in_tree, args_flat_jax)
    res_flat_jax, out_tree = tree_util.tree_flatten(
        fun_jax(*call_args, **call_kwargs))
    # Repeated calls must produce the same output structure.
    assert recorded_out_tree is None or recorded_out_tree == out_tree
    recorded_out_tree = out_tree
    return res_flat_jax

  def get_out_tree():
    return recorded_out_tree

  return fun_flat_jax, get_out_tree
def preprocess_arg_tf(arg_idx: int,
                      arg_tf: TfVal) -> TfVal:
  """Pre-processes one flattened TF argument of a converted function.

  Args:
    arg_idx: position of the argument among the flattened arguments; used to
      name the resulting tensor `jax2tf_arg_{arg_idx}`.
    arg_tf: the argument value.

  Returns:
    The argument as a TF tensor. It may have been cast to follow JAX's
    interpretation of constant types (via `_tfval_to_tensor_jax_dtype`), and
    it is wrapped in a `tf.identity` named `jax2tf_arg_{arg_idx}`.

  Raises:
    TypeError: if `arg_tf` is not a tf.Tensor, a tf.Variable, or a value that
      `tf.constant` can convert.
  """
  if not _is_tfval(arg_tf):
    msg = (f"Argument {arg_tf} of type {type(arg_tf)} of jax2tf.convert(f) should "
           "be NumPy array, scalar, tf.Variable, or tf.Tensor")
    raise TypeError(msg)

  # May cast the args_flat to JAX types, using JAX's interpretation
  # of types of constants.
  arg_tf, _ = _tfval_to_tensor_jax_dtype(arg_tf)
  # Name input tensors; do this after we have cast the arguments
  arg_tf = tf.identity(arg_tf, f"jax2tf_arg_{arg_idx}")
  return arg_tf
def _make_custom_gradient_fn_tf(*,
                                impl: SerializationImpl,
                                args_tf: Sequence[TfVal],
                                outs_avals: Sequence[core.ShapedArray],
                                outs_tf: Sequence[TfVal]):
  """Prepares the TF function to be used with tf.custom_gradient.

  Args:
    impl: the serialization implementation details
    args_tf: the flattened TF arguments of the primal function
    outs_avals: the flattened output JAX abstract values of the primal function
    outs_tf: the flattened TF outputs of the primal function

  Returns:
    A `grad_fn_tf(*out_cts_flat_tf, variables=None)` function. It maps the
    output cotangents to the flattened input cotangents by calling
    `impl.run_vjp_fun_tf` on the primal arguments followed by the (fixed-up)
    output cotangents.
  """

  def grad_fn_tf(*out_cts_flat_tf: TfVal,
                 variables=None):
    if variables:
      raise ValueError(
          "Unexpected variables used in forward pass. "
          "This should not happen for first-order differentiation. "
          f"{variables=}")

    # TODO: enable higher-order gradients
    with tf.name_scope("jax2tf_vjp"):
      def fix_out_ct(out_ct_tf, out_ct_aval: core.ShapedArray, out_tf: TfVal):
        # If the primal function has outputs of integer or bool types, and if
        # we are under a tf.function context, then TF will pass None in
        # _out_cts_flat in place of these values. We should change these to
        # float0 or else JAX gets unhappy. See issue #6975.
        if out_ct_tf is not None:
          return out_ct_tf
        # A None cotangent is only expected for outputs whose tangent type is
        # float0; report the offending aval (out_ct_tf is always None here).
        assert core.primal_dtype_to_tangent_dtype(out_ct_aval.dtype) == dtypes.float0, (
            f"Got a None cotangent for a non-float0 tangent output: {out_ct_aval=}")
        # Note that out_ct_aval.shape contains dimension variable from the
        # primal function scope. We use tf.zeros_like to make a 0 of the right shape.
        return tf.zeros_like(out_tf, dtype=_tf_np_dtype_for_float0)

      out_cts_fixed_flat_tf = tuple(map(fix_out_ct, out_cts_flat_tf, outs_avals, outs_tf))
      vjp_args_flat_tf = tuple(args_tf) + out_cts_fixed_flat_tf
      in_cts_flat = impl.run_vjp_fun_tf(vjp_args_flat_tf, outs_avals)

    # We do not need to fix the in_cts because the TF gradient machinery
    # will adjust the unconnected gradients and those for integer types.
    return in_cts_flat

  return grad_fn_tf
@contextlib.contextmanager
def _extended_name_stack(extra_name_stack: Optional[str]):
name_ctx = (source_info_util.extend_name_stack(extra_name_stack)
if extra_name_stack
else contextlib.nullcontext())
with name_ctx:
yield
return
def _interpret_fun_jax(
    fun_jax: Callable,
    args_tf: Sequence[TfVal],
    args_avals: Sequence[core.ShapedArray],
    extra_name_stack: Optional[str],
    fresh_constant_cache: bool = False,
) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]:
  """Evaluates `fun_jax` on TF values under a `TensorFlowTrace`.

  Args:
    fun_jax: the JAX function to interpret.
    args_tf: the TF values for the arguments of `fun_jax`.
    args_avals: the JAX abstract values corresponding to `args_tf`.
    extra_name_stack: if non-empty, pushed onto the JAX name stack for the
      duration of the interpretation.
    fresh_constant_cache: forwarded to `_call_wrapped_with_new_constant_cache`.

  Returns:
    A pair `(outs_tf, outs_avals)`, obtained by unzipping the
    `(TfVal, core.ShapedArray)` pairs produced by the interpretation.
  """
  with core.new_base_main(TensorFlowTrace) as main:  # type: ignore
    subtrace_fun = _interpret_subtrace(lu.wrap_init(fun_jax), main, args_avals)
    with _extended_name_stack(extra_name_stack):
      with core.new_sublevel():
        out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \
            _call_wrapped_with_new_constant_cache(subtrace_fun, args_tf,
                                                  fresh_constant_cache=fresh_constant_cache)
      # Drop our reference to `main` before leaving `new_base_main`.
      # NOTE(review): presumably so that JAX's tracer-leak checking on exit
      # does not see this reference — confirm against core.new_base_main.
      del main

  return util.unzip2(out_vals)
def _run_exported_as_tf(args_flat_tf: Sequence[TfVal],
                        exported: jax_export.Exported,
                        ) -> Sequence[TfVal]:
  """Runs the `exported` as an XlaCallModule TF op.

  Only the arguments whose indices are in `exported.module_kept_var_idx` are
  passed to the op, whether or not input shardings are present.

  Args:
    args_flat_tf: the flattened TF arguments, one for each of
      `exported.in_avals`.
    exported: the exported JAX function to run.

  Returns: the flattened tuple of results.
  """
  args_avals = exported.in_avals

  # TF values may be integer types for float0
  def _convert_value(val, aval):
    # Check the shape
    assert all(d_aval == d_val
               for d_aval, d_val in zip(aval.shape, val.shape)
               if core.is_constant_dim(d_aval)), (aval, val)
    conversion_dtype = _to_tf_dtype(aval.dtype)
    if conversion_dtype != aval.dtype:
      return tf.cast(val, conversion_dtype)
    else:
      return val

  args_flat_tf = tuple(map(_convert_value, args_flat_tf, args_avals))

  out_shapes_tf = tuple(
      tuple(d if core.is_constant_dim(d) else None
            for d in out_aval.shape)
      for out_aval in exported.out_avals)

  out_types = tuple(_to_tf_dtype(out_aval.dtype) for out_aval in exported.out_avals)

  # Select the arguments that the serialized module keeps. A set makes the
  # membership test O(1); `enumerate` preserves the argument order.
  kept_var_idx = set(exported.module_kept_var_idx)
  kept_args_avals = [aval for i, aval in enumerate(exported.in_avals)
                     if i in kept_var_idx]
  kept_args_flat_tf = tuple(atf for i, atf in enumerate(args_flat_tf)
                            if i in kept_var_idx)

  call_module_attrs = dict(
      version=exported.xla_call_module_version,
      Tout=out_types,
      Sout=out_shapes_tf,
      function_list=[
          concrete_fn.function_def.signature.name
          for concrete_fn in _thread_local_state.call_tf_concrete_function_list
      ] if _thread_local_state.call_tf_concrete_function_list is not None else [],
  )

  if exported.xla_call_module_version >= 3:
    if exported.strict_checks:
      call_module_attrs["platforms"] = (exported.lowering_platform.upper(),)
    else:
      call_module_attrs["platforms"] = ()  # No platform checking

  if logging.vlog_is_on(3):
    # We already logged the MLIR module when we exported it.
    logging.vlog(3, "XlaCallModule %s", str(call_module_attrs))

  call_module_attrs["module"] = exported.mlir_module_serialized

  # Apply the shardings on arguments and results for pjit. This is redundant
  # because the mlir_module_text will already contain the shardings, but it
  # makes it easier for tools like the TPU inference converter to see the
  # sharding without digging into the `module` attribute of the `XlaCallModule`
  # op, in the same way as it is done for the legacy jax2tf conversion.
  # Do not apply XlaSharding for REPLICATED, on inputs and outputs.
  # This is an agreed convention, and also improves usability under TF eager.
  # See b/255511660.
  if exported.in_shardings is not None:
    kept_args_flat_tf = tuple(
        map(partial(_shard_value, skip_replicated_sharding=tf.executing_eagerly()),
            kept_args_flat_tf, kept_args_avals, exported.in_shardings))

  # Always pass only the kept arguments. Pruned arguments have no
  # corresponding parameter in the serialized module.
  res = tfxla.call_module(kept_args_flat_tf, **call_module_attrs)
  # TODO(b/278940799): Replace the TF v1 API with public TF2 API.
  # Add the custom call tf.function into the default graph, so those functions
  # will be available during tf.SavedModel.save.
  if _thread_local_state.call_tf_concrete_function_list is not None:
    for concrete_fn in _thread_local_state.call_tf_concrete_function_list:
      tf.compat.v1.get_default_graph()._add_function_recursive(
          concrete_fn._inference_function
      )

  if exported.out_shardings is not None:
    res = list(map(partial(_shard_value, skip_replicated_sharding=tf.executing_eagerly()),
                   res, exported.out_avals, exported.out_shardings))
  res = tuple(map(_convert_value, res, exported.out_avals))
  return res
def _call_wrapped_with_new_constant_cache(fun: lu.WrappedFun,
                                          in_vals: Sequence[TfVal],
                                          fresh_constant_cache: bool = False
                                          ) -> Sequence[Tuple[TfVal, core.ShapedArray]]:
  """Calls `fun.call_wrapped(*in_vals)` with a scoped TF constant cache.

  Args:
    fun: the wrapped function to call.
    in_vals: the TF values passed as positional arguments.
    fresh_constant_cache: if True, the call runs with a new, empty constant
      cache, so constants are not shared across tf.function boundaries. If
      False, the current cache is shared. When
      `_WRAP_JAX_JIT_WITH_TF_FUNCTION` is set, entries added during the call
      are removed afterwards.
      In both cases the previous cache object is reinstated when the call
      finishes, even if it raises.

  Returns:
    the result of `fun.call_wrapped(*in_vals)`.
  """
  # Capture the previous state before the `try`, so the `finally` clause
  # never refers to locals that were not yet bound.
  prev_constant_cache = _thread_local_state.constant_cache
  if not fresh_constant_cache and prev_constant_cache is not None:
    prev_constant_cache_keys = set(prev_constant_cache.keys())
  else:
    prev_constant_cache_keys = set()
  try:
    if fresh_constant_cache:
      # Start a new cache, so that we don't share constants across tf.function
      # boundaries.
      _thread_local_state.constant_cache = {}
    out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \
        fun.call_wrapped(*in_vals)
  finally:
    if (not fresh_constant_cache and
        prev_constant_cache is not None and
        _WRAP_JAX_JIT_WITH_TF_FUNCTION):
      newly_added_keys = set(prev_constant_cache.keys()) - prev_constant_cache_keys
      # Delete the newly added keys
      for k in newly_added_keys:
        del prev_constant_cache[k]
    _thread_local_state.constant_cache = prev_constant_cache
  return out_vals
def _convert_jax_impl(impl_jax: Callable, *,
                      multiple_results=True,
                      with_physical_avals=False,
                      extra_name_stack: Optional[str] = None) -> Callable:
  """Convert the JAX implementation of a primitive.

  Args:
    impl_jax: typically the impl-rule for a primitive, with signature
      `(*args_jax: JaxVal, **kwargs) -> Sequence[JaxVal]`. This function
      implements a primitive in terms of other primitives.
    multiple_results: whether `impl_jax` returns a sequence of results.
    with_physical_avals: if True, the `_in_avals` given to the returned
      function are mapped through `_jax_physical_aval` before interpretation.
    extra_name_stack: additional element to add to the name stack for the
      converted ops.

  Returns:
    a function with signature `(*args_tf: TfVal, _in_avals, _out_aval, **kwargs)
    -> Sequence[TfVal]`.
  """

  def wrapped_tf(*args_tf: TfVal, _in_avals: Sequence[core.ShapedArray],
                 _out_aval: core.ShapedArray,
                 **kwargs) -> Sequence[TfVal]:
    avals_for_interp = _in_avals
    if with_physical_avals:
      avals_for_interp = map(_jax_physical_aval, _in_avals)
      # The physical output aval is not used below; the call is kept for the
      # consistency assertion inside `_jax_physical_aval`.
      _jax_physical_aval(_out_aval)

    def impl_as_sequence(*args_jax):
      # Normalize `impl_jax` so it always returns a sequence of results.
      outcome = impl_jax(*args_jax, **kwargs)
      if multiple_results:
        return outcome
      return [outcome]

    outs_tf, _ = _interpret_fun_jax(impl_as_sequence, args_tf,
                                    avals_for_interp, extra_name_stack)
    if multiple_results:
      return outs_tf
    return outs_tf[0]

  return wrapped_tf
@lu.transformation
def _interpret_subtrace(main: core.MainTrace,
                        in_avals: Sequence[core.ShapedArray],
                        *in_vals: TfVal):
  """Linear-util transformation that runs the wrapped function on TF tracers.

  Each input value is boxed, with its aval, in a TensorFlowTracer of a new
  TensorFlowTrace at the current sublevel. The outputs of the wrapped function
  are raised to that trace and yielded as `(val, aval)` pairs.
  """
  trace = TensorFlowTrace(main, core.cur_sublevel())
  in_tracers = tuple(TensorFlowTracer(trace, v, a)
                     for v, a in zip(in_vals, in_avals))
  outs = yield in_tracers, {}  # type: Sequence[TfVal]
  raised: Iterable[TensorFlowTracer] = map(trace.full_raise, outs)  # type: ignore
  pairs: Sequence[Tuple[TfVal, core.ShapedArray]] = tuple(
      (tracer.val, tracer.aval) for tracer in raised)
  yield pairs
def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args_tf: TfVal,
                     extra_name_stack: Optional[str],
                     fresh_constant_cache: bool = True) -> Sequence[TfVal]:
  """Evaluates a closed Jaxpr on tf.Tensor arguments.

  Typically used as the body of a tf.function or of a tf.switch_case branch,
  which is why a fresh constant cache is the default.

  Returns:
    the outputs as a sequence of TfVal, suitable for use with TF.
  """
  jaxpr_fun = core.jaxpr_as_fun(jaxpr)
  interpreted = _interpret_fun_jax(jaxpr_fun, args_tf, jaxpr.in_avals,
                                   extra_name_stack,
                                   fresh_constant_cache=fresh_constant_cache)
  return interpreted[0]
def _jax_physical_aval(aval: core.ShapedArray) -> core.ShapedArray:
"""Converts JAX avals from logical to physical, if relevant.
JAX might have avals whose logical vs physical shape/dtype may
differ, and only the physical view is expected to possibly
relate to TF. TF impl rules should operate on the physical form.
A JAX logical aval might even correspond, in principle, to several
physical avals, but we don't support those here. Instead we assert
there is only one and return it.
"""
physical_aval = core.physical_aval(aval)
assert (len(physical_aval.shape) >= len(aval.shape) and
physical_aval.shape[:len(aval.shape)] == aval.shape), (physical_aval, aval)
return physical_aval
def _jax_physical_dtype(dtype):
  """Returns the physical dtype that JAX uses for the logical `dtype`."""
  # A scalar aval is used as a stand-in; this assumes the physical dtype does
  # not depend on the shape.
  scalar_aval = core.ShapedArray((), dtype)
  return _jax_physical_aval(scalar_aval).dtype
def _aval_to_tf_shape(aval: core.ShapedArray) -> Tuple[Optional[int], ...]:
  """Returns the TF shape for `aval`, with None for polymorphic dimensions.

  The shape is taken from the physical aval (see `_jax_physical_aval`).
  """
  physical = _jax_physical_aval(aval)
  return tuple(None if shape_poly.is_poly_dim(dim) else dim
               for dim in physical.shape)  # type: ignore[attr-defined]
# In the TF world, we represent float0 as zeros of this type.
# We pick bool because this is what JAX uses when it lowers float0 to HLO.
_tf_np_dtype_for_float0 = np.bool_


def _to_tf_dtype(jax_dtype):
  """Returns the tf.DType used to represent `jax_dtype` in TF.

  The dtype is first mapped to its JAX physical dtype when possible, and
  float0 is represented by `_tf_np_dtype_for_float0`. Note that
  `_to_tf_dtype` and `_to_jax_dtype` are not inverses, due to float0 and
  64-bit behavior.
  """
  try:
    physical = _jax_physical_dtype(jax_dtype)
  except TypeError:
    # `jax_dtype` isn't actually a valid JAX dtype (e.g. it is tf.float32),
    # so there is no physical dtype to look up.
    physical = jax_dtype
  if physical == dtypes.float0:
    return tf.dtypes.as_dtype(_tf_np_dtype_for_float0)
  return tf.dtypes.as_dtype(physical)
def _to_jax_dtype(tf_dtype):
# Note that converting _to_tf_dtype and _to_jax_dtype are not inverses,
# due to float0 and 64-bit behavior.
dt = dtypes.canonicalize_dtype(tf_dtype.as_numpy_dtype)
if dt not in dtypes._jax_dtype_set:
raise TypeError(f"dtype {dt} is not a valid JAX array "
"type. Only arrays of numeric types are supported by JAX.")
return dt
def _tfval_to_tensor_jax_dtype(val: TfVal,
                               jax_dtype: Optional[DType] = None,
                               memoize_constants=False) -> Tuple[TfVal, DType]:
  """Converts a scalar, ndarray, or tf.Tensor to a tf.Tensor with proper type.

  If `jax_dtype` is missing, uses JAX typing rules.
  See README.md for details regarding 64-bit values.

  Args:
    val: a scalar, ndarray, tf.Tensor, or tf.Variable
    jax_dtype: an optional dtype to use. If missing, uses JAX type inference
      rules for constants.
    memoize_constants: whether to memoize TF constants. We can't do this
      everywhere, we may be outside of a conversion scope.

  Returns:
    a tuple with a tf.Tensor with the type as needed by JAX, and the JAX type.
  """
  if isinstance(val, (tf.Tensor, tf.Variable)):
    jax_dtype = jax_dtype or _to_jax_dtype(val.dtype)  # Give JAX a chance to pick the type
    conversion_dtype = _to_tf_dtype(jax_dtype)
    if conversion_dtype != val.dtype:  # May need to cast for 64-bit values
      return tf.cast(val, conversion_dtype), jax_dtype
    else:
      return val, jax_dtype
  else:  # A constant
    jax_dtype = jax_dtype or xla.abstractify(val).dtype
    # TODO(document): We assume that the value of a constant does not
    # change through the scope of the function. But it may be an ndarray, ...
    # JAX has the same problem when generating HLO.
    const_key = (id(val), jax_dtype)
    # Since we use id(val) as a cache key, we have to make sure that we keep
    # the previous `val` alive. Otherwise, for a ndarray, it can get garbage
    # collected and reused for a different value, which would create correctness
    # issues. We keep the `val` alive by storing in the cache the pair
    # `(val, tf_val)`.
    #
    # Only memoize non-scalars. JAX will lift all non-scalar constants as
    # Jaxpr consts, to the top level of the Jaxpr. This ensures that we see them
    # early, when entering the Jaxpr, so we create the tf.const early and its
    # scope is the entire Jaxpr.
    do_memoize = (memoize_constants and np.size(val) > 1 and _thread_local_state.constant_cache is not None)
    if do_memoize:
      _, tf_val = _thread_local_state.constant_cache.get(const_key, (None, None))
    else:
      tf_val = None
    if tf_val is None:
      conversion_dtype = _to_tf_dtype(jax_dtype)
      # Keep `val` bound to the original object: the cache entry below must
      # hold the object whose id() is the key. Replacements go into a
      # separate name.
      val_to_convert = val
      # The float0 type is not known to TF.
      if jax_dtype == dtypes.float0:
        val_to_convert = np.zeros(np.shape(val), conversion_dtype.as_numpy_dtype)
      if (hasattr(val_to_convert, 'dtype') and
          dtypes.is_opaque_dtype(val_to_convert.dtype)):
        val_to_convert = val_to_convert.dtype._rules.physical_const(val_to_convert)
      tf_val = tf.convert_to_tensor(val_to_convert, dtype=conversion_dtype)
      if do_memoize:
        _thread_local_state.constant_cache[const_key] = (val, tf_val)
    return tf_val, jax_dtype
def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfVal]:
  """Evaluates a possibly-symbolic JAX shape to concrete or TF values.

  Used only for non-native lowering. If `dtype` is given, the physical shape
  of `ShapedArray(shape, dtype)` is evaluated instead of `shape`.

  Returns:
    a tuple with one entry per dimension. Constant dimensions are Python ints.
    Symbolic dimensions are TF values computed from the dimension variables
    in `_thread_local_state.shape_env`.
  """
  assert all(d is not None for d in shape), (
      f"Argument shape should be a valid JAX shape but got {shape}")
  if dtype is not None:
    shape = _jax_physical_aval(core.ShapedArray(shape, dtype)).shape
  if core.is_constant_shape(shape):
    return tuple(int(d) for d in shape)

  dim_vars, dim_values = util.unzip2(_thread_local_state.shape_env)
  dim_avals = [core.dim_value_aval()] * len(dim_values)
  shape_values_tf, _ = _interpret_fun_jax(
      partial(core.evaluate_shape, shape, dim_vars),
      dim_values, dim_avals, "")  # type: ignore
  # Constant dimensions stay Python ints; symbolic ones take the TF value.
  return tuple(operator.index(d) if core.is_constant_dim(d) else d_tf
               for d, d_tf in zip(shape, shape_values_tf))
def _ensure_tf_shape_if_dynamic(x: TfVal, shape):
  """Applies `tf.ensure_shape(x, shape)` unless `x`'s shape is fully defined.

  A tensor whose static shape is already fully defined is returned unchanged.
  """
  if not x.shape.is_fully_defined():
    return tf.ensure_shape(x, shape)
  return x
def _assert_matching_abstract_shape(x: TfVal, shape: Sequence[shape_poly.DimSize]):
  """Asserts that `shape` matches `x.shape` in the known dimensions.

  The ranks must agree. Constant entries of `shape` must equal the
  corresponding entry of `x.shape`. Non-constant entries must be dimension
  polynomials (`shape_poly._DimExpr`), so `shape` must not contain None.
  """

  def dims_agree(x_dim: Optional[int], s_dim: Any) -> bool:
    if not core.is_constant_dim(s_dim):
      assert isinstance(s_dim, shape_poly._DimExpr)
      return True
    return x_dim == s_dim

  same_rank = len(x.shape) == len(shape)
  assert same_rank and all(
      dims_agree(x_dim, s_dim) for x_dim, s_dim in zip(x.shape, shape)), \
      f"Shape {shape} does not match x.shape {x.shape}"
# TODO(b/26854495): pylint doesn't understand slots and inheritance.
# pylint: disable=assigning-non-slot
class TensorFlowTracer(core.Tracer):
  """Tracer class that boxes a TF value and a JAX abstract value.

  In addition to the TF value we carry the JAX abstract value because
  there are some cases when it cannot be recovered from the value:
  when we are converting with polymorphic shapes or when the JAX aval
  has a custom element type. In these cases the shape of the value may
  have dimensions set to `None`, or it may only correspond to the JAX
  "physical" (TF/lowering-compatible) shape, so the JAX abstract value
  may contain more precise information.

  When the value has a partially-known shape, the dimensions marked as `None`
  must correspond to non-constant dimensions in the abstract value.

  See README.md for details.
  """
  # val: TfVal
  # _aval: core.ShapedArray
  __slots__ = ["val", "_aval"]

  def __init__(self, trace: "TensorFlowTrace", val: TfVal,
               aval: core.AbstractValue):
    self._trace = trace
    self._aval = aval
    phys_aval = _jax_physical_aval(self._aval)  # type: ignore[arg-type]

    if isinstance(val, (tf.Tensor, tf.Variable)):
      val_shape = val.shape

      if config.jax_enable_checks:
        assert len(phys_aval.shape) == len(val_shape), f"_aval.shape={phys_aval.shape} different rank than {val_shape=}"
        # To compare types, we must handle float0 in JAX and x64 in TF
        if phys_aval.dtype == dtypes.float0:
          assert _to_tf_dtype(phys_aval.dtype) == val.dtype, f"expected {phys_aval.dtype} == {val.dtype}"
        else:
          assert phys_aval.dtype == _to_jax_dtype(val.dtype), f"expected {phys_aval.dtype} == {val.dtype}"

        for aval_dim, val_dim in zip(phys_aval.shape, val_shape):  # type: ignore[attr-defined]
          if val_dim is None:
            assert shape_poly.is_poly_dim(aval_dim), f"expected {phys_aval.shape} == {val_shape}"  # type: ignore[attr-defined]
          elif not shape_poly.is_poly_dim(aval_dim):
            assert aval_dim == val_dim, f"expected {phys_aval.shape} == {val_shape}"  # type: ignore[attr-defined]
          else:
            # We have a TF value with known shape, and the abstract shape is a shape variable.
            try:
              # `_eval_shape` returns a 1-tuple here; convert its only element.
              # Applying `int` to the tuple itself always raised TypeError,
              # which silently disabled this check.
              aval_int = int(_eval_shape([aval_dim])[0])  # type: ignore
            except (TypeError, KeyError):
              # Presumably the value is symbolic (e.g., a graph tensor) or the
              # dimension variable is not bound; skip the check.
              continue
            assert aval_int == val_dim, f"expected {phys_aval.shape} == {val_shape}. Found {aval_int} != {val_dim}."  # type: ignore

    self.val = _tfval_to_tensor_jax_dtype(val,
                                          phys_aval.dtype,
                                          memoize_constants=True)[0]  # type: ignore[attr-defined]

  @property
  def aval(self):
    """The JAX abstract value (possibly more precise than `val`'s TF shape)."""
    return self._aval

  def full_lower(self):
    return self
def _make_op_metadata(primitive: core.Primitive,
                      params: Dict, *,
                      source_info: source_info_util.SourceInfo,
                      ) -> xla_client.OpMetadata:
  """Builds the XLA OpMetadata describing one application of `primitive`.

  The op name is the name stack followed by the compact equation string. The
  source file and line come from the user frame of `source_info`, when there
  is one; otherwise they are None.
  """
  op_name = "/".join([str(source_info.name_stack),
                      core.str_eqn_compact(primitive.name, params)])
  frame = source_info_util.user_frame(source_info)
  if frame:
    source_file = xla.get_canonical_source_file(frame)
    source_line = frame.start_line
  else:
    source_file = None
    source_line = None
  return xla_client.OpMetadata(
      op_type=primitive.name,
      op_name=op_name,
      source_file=source_file,
      source_line=source_line)
class TensorFlowTrace(core.Trace):
  """Trace class that underlies the jax2tf transformation.

  We are going to ensure that jax2tf.convert is never nested inside other
  transformations. This is sufficient for intended use cases (converting
  fully-transformed JAX code). It also simplifies our job because we do not have
  to handle situations where we apply primitives on a mix of TF values and
  JAX tracers from an outer transformation. E.g., for addition both the TF
  values and the JAX tracers have an override and they get confused if they
  see values from the other world.

  Hence a TFT trace does not interact with non-TFT traces at lower-level. For
  higher-order control-flow primitives we invoke recursively
  _interpret_fun on the body of the conditional, which will create a nested TFT.

  We do want to allow transformations nested inside a TensorFlowTrace (TFT), but
  those will introduce their own MainTrace, and any operations involving those
  will be done on those traces, i.e., not a concern for TFT.
  """
  def pure(self, val: TfVal) -> TensorFlowTracer:
    """Lifts a non-Tracer into the TensorFlowTracer.

    This function may be called by way of trace.full_raise.
    """
    # Objects implementing `__jax_array__` are first converted to arrays.
    if hasattr(val, "__jax_array__"):
      val = val.__jax_array__()
      if isinstance(val, TensorFlowTracer):
        return val
    tf_val, jax_dtype = _tfval_to_tensor_jax_dtype(val, memoize_constants=True)
    return TensorFlowTracer(
        self, tf_val, core.ShapedArray(np.shape(val), jax_dtype,
                                       weak_type=dtypes.is_weakly_typed(val)))

  def lift(self, val: core.Tracer) -> TensorFlowTracer:
    # This would be called when we need to raise a tracer from a lower-level
    # main into the TensorFlowTrace. Since the TensorFlowTrace is never nested
    # inside another transform, there are no lower-level main traces.
    assert False

  def sublift(self, val: TensorFlowTracer) -> TensorFlowTracer:
    # This is called when we need to raise a tracer from the same main,
    # but a lower sublevel. This could come from a nested jit.
    return TensorFlowTracer(self, val.val, val._aval)

  def process_primitive(self, primitive: core.Primitive,
                        tracers: Sequence[TensorFlowTracer],
                        params) -> TensorFlowTracer:
    """Applies the TF implementation of `primitive` to the boxed TF values.

    The output aval(s) come from JAX abstract evaluation. The impl rule runs
    inside a TF name scope built from the JAX name stack and, if enabled,
    with the `_XlaOpMetadata` attribute set. Returns a list of tracers when
    `primitive.multiple_results`, otherwise a single tracer.
    """
    impl, impl_needs_avals = self.get_primitive_impl(primitive)
    args_avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers)
    # This is a bit conservative, doing abstract_eval even in op-by-op execution
    # but we needed it for, e.g., shape_polymorphism where only JAX's
    # abstract evaluation rules can properly track polymorphic shapes.
    # Unfortunately under op-by-op execution this is a rare occasion where we
    # need abstract evaluation.
    out_aval, _ = primitive.abstract_eval(*args_avals, **params)
    args_tf: Sequence[TfVal] = [t.val for t in tracers]
    def invoke_impl() -> TfVal:
      # Rules registered "with avals" also receive the input/output avals.
      if impl_needs_avals:
        return impl(
            *args_tf,
            _in_avals=args_avals,  # type: ignore
            _out_aval=out_aval,
            **params)
      else:
        return impl(*args_tf, **params)

    current_name_stack = _get_current_name_stack()
    # We don't use `str(name_stack)` because it uses parentheses for
    # transformations, which aren't allowed in `name_scope`.
    scope = '/'.join([s.name for s in current_name_stack.stack])  # type: ignore[union-attr]

    # Here we reset the name scope to the memorized TF name scope
    # + JAX name stack by using absolute scope.
    # We need to add a '/' to the name stack string to force `tf.name_scope`
    # to interpret it as an absolute scope, not a relative scope.
    if _thread_local_state.tf_outer_name_scope:
      scope = f"{_thread_local_state.tf_outer_name_scope}/{scope}"

    if not scope.endswith("/"):
      scope = scope + "/"

    with tf.name_scope(_sanitize_scope_name(scope)):
      if _thread_local_state.include_xla_op_metadata:
        op_metadata = _make_op_metadata(primitive, params,
                                        source_info=source_info_util.current())
        op_metadata_proto = xla_data_pb2.OpMetadata(
            op_type=op_metadata.op_type,
            op_name=op_metadata.op_name,
            source_file=op_metadata.source_file,
            source_line=op_metadata.source_line
        )
        # Attach the serialized metadata as an attribute on the ops created
        # by the impl rule.
        with tf_ops.get_default_graph()._attr_scope(
            {"_XlaOpMetadata": attr_value_pb2.AttrValue(
                s=op_metadata_proto.SerializeToString())}):
          val_out = invoke_impl()
      else:
        val_out = invoke_impl()

    if primitive.multiple_results:
      out = [
          TensorFlowTracer(self, v, a)
          for v, a in zip(val_out, out_aval)
      ]  # type: ignore
    else:
      out = TensorFlowTracer(self, val_out, out_aval)  # type: ignore

    # Check that the impl rule returned a value of expected shape and dtype
    # TODO: adapt this to match polymorphic shapes
    if config.jax_enable_checks:
      if primitive.multiple_results:
        for o, expected_aval in zip(out, out_aval):  # type: ignore
          assert o.aval.strip_weak_type() == expected_aval.strip_weak_type(), (
              f"{primitive}: out.aval = {o.aval}; expected {expected_aval}")
      else:
        assert out.aval == out_aval, (  # type: ignore
            f"{primitive}: out.aval = {out.aval}; expected {out_aval}"
        )  # type: ignore
    return out  # type: ignore

  def process_call(self, call_primitive: core.Primitive, fun: lu.WrappedFun,
                   tracers: Sequence[TensorFlowTracer], params):
    """Inlines a call primitive by interpreting `fun` at a new sublevel."""
    assert call_primitive.multiple_results
    vals: Sequence[TfVal] = [t.val for t in tracers]
    avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers)
    interpreted_fun = _interpret_subtrace(fun, self.main, avals)
    # NOTE(review): `extra_name_stack` is always None here, so the name-stack
    # context below is a no-op.
    extra_name_stack = None
    with _extended_name_stack(extra_name_stack):
      with core.new_sublevel():
        vals_out = interpreted_fun.call_wrapped(*vals)
    return [TensorFlowTracer(self, v, a) for v, a in vals_out]

  def post_process_call(self, call_primitive: core.Primitive,
                        out_tracers: Sequence[TensorFlowTracer], params):
    # We encountered a call primitive whose result (out_tracers) include
    # TensorFlowTracer that were not passed through its arguments (captured from
    # the environment).
    vals = tuple(t.val for t in out_tracers)
    main = self.main

    def todo(vals: Sequence[TfVal]):
      # TODO: is name_stack correct?
      # Rebox the values in a trace at the then-current sublevel, keeping the
      # avals of the original out_tracers.
      trace = TensorFlowTrace(main, core.cur_sublevel())
      return [
          TensorFlowTracer(trace, v, out_tracer.aval)
          for v, out_tracer in zip(vals, out_tracers)
      ]
    return vals, todo

  def process_map(self, map_primitive, f, tracers, params):
    raise NotImplementedError("process_map")

  def post_process_map(self, map_primitive, out_tracers, params):
    raise NotImplementedError("post_process_map")

  def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
    # Drop the custom differentiation rule and act like a call primitive. This
    # behavior is desirable because jax2tf stages code out of the JAX system, so
    # there are no more JAX differentiation transformations to be applied.
    del jvp, symbolic_zeros  # Unused.
    return self.process_call(core.call_p, fun, tracers, {})

  def post_process_custom_jvp_call(self, out_tracers, _):
    assert False  # unreachable assuming jax2tf runs with clean trace state

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
                              symbolic_zeros):
    # Drop the custom differentiation rule and act like a call primitive. This
    # behavior is desirable because jax2tf stages code out of the JAX system, so
    # there are no more JAX differentiation transformations to be applied.
    del fwd, bwd, out_trees, symbolic_zeros  # Unused.
    return self.process_call(core.call_p, fun, tracers, {})

  def post_process_custom_vjp_call(self, out_tracers, _):
    assert False  # unreachable assuming jax2tf runs with clean trace state

  def post_process_custom_vjp_call_fwd(self, *_, **__):
    assert False  # unreachable assuming jax2tf runs with clean trace state

  def get_primitive_impl(self, p: core.Primitive) -> Tuple[Callable, bool]:
    # Returns the primitive implementation and whether the implementation
    # takes abstract values (see definition of tf_impl_with_avals)
    # Lookup order: tf_impl_no_xla (only when XLA is disabled), then tf_impl,
    # then tf_impl_with_avals; NotImplementedError if none has a rule.
    if not _thread_local_state.enable_xla:
      try:
        return tf_impl_no_xla[p], True  # Always require avals.
      except KeyError:
        pass
    try:
      return tf_impl[p], False
    except KeyError:
      try:
        return tf_impl_with_avals[p], True
      except KeyError as err:
        msg = "TensorFlow interpretation rule for '{}' not implemented"
        raise NotImplementedError(msg.format(p)) from err
def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
  """TF rule for primitives that are never expected to reach the TF trace."""
  assert False, f"Encountered unexpected primitive {p}"


# Call primitives are inlined
for unexpected in (core.call_p, maps.xmap_p):
  tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)
# Primitives that are not yet implemented must be explicitly declared here.
# Entries are primitive *names* (e.g. "pmax", not "pmax_p"), and each name is
# listed once.
tf_not_yet_impl = [
    "clz",
    "igamma_grad_a",
    "random_gamma_grad",
    "reduce_xor",
    "schur",
    "closed_call",
    "unreachable",
    "bint",
    "getslice",
    "full_to_shard",
    "shard_to_full",
    "pure_callback",
    "for",
    "inspect_sharding",
    "io_callback",
    "shard_map",
    "global_array_to_host_local_array",
    "host_local_array_to_global_array",
    "call_exported",

    # Not high priority?
    "after_all",
    "all_to_all",
    "check",
    "create_token",
    "custom_transpose_call",
    "custom_vmap_call",
    "infeed",
    "linear_call",
    "outfeed",
    "pmin",
    "ppermute",
    "psum",
    "pmax",
    "pgather",
    "reduce_scatter",
    "axis_index",
    "pdot",
    "all_gather",
    "lu_pivots_to_permutation",
    "xla_pmap",
    "geqrf",
    "householder_product",
    "hessenberg",
    "tridiagonal",
    "eigh_jacobi",
]
tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
tf_impl[ad_util.zeros_like_p] = tf.zeros_like


def _add(x: TfVal, y: TfVal) -> TfVal:
  """Elementwise addition, emitted directly as the raw TF `AddV2` op."""
  summed = tf.raw_ops.AddV2(x=x, y=y)
  return summed


tf_impl[ad_util.add_jaxvals_p] = _add
# device_put and copy are identities in TF.
tf_impl[dispatch.device_put_p] = lambda x, device=None, src=None: x
tf_impl[lax_internal.copy_p] = lambda x: x
def _neg(x: TfVal) -> TfVal:
  """Elementwise negation.

  Unsigned inputs are negated in the same-width signed type and cast back to
  their original dtype.
  """
  if not x.dtype.is_unsigned:
    return tf.math.negative(x)
  as_signed = tf.cast(x, _UNSIGNED_TO_SIGNED_TABLE[x.dtype])
  negated = tf.math.negative(as_signed)
  return tf.cast(negated, x.dtype)


tf_impl[lax.neg_p] = _neg
def _sign(x: TfVal) -> TfVal:
  """Elementwise sign; for unsigned inputs this is 0 where x == 0, else 1."""
  if not x.dtype.is_unsigned:
    return tf.math.sign(x)
  # TF and XLA do not support tf.math.sign for unsigned types.
  is_zero = tf.math.equal(x, 0)
  zero = tf.constant(0, dtype=x.dtype)
  one = tf.constant(1, dtype=x.dtype)
  return tf.where(is_zero, zero, one)


tf_impl[lax.sign_p] = _sign
tf_impl[lax.floor_p] = tf.math.floor
tf_impl[lax.ceil_p] = tf.math.ceil
def _round(operand, *, rounding_method,
           _in_avals: Sequence[core.ShapedArray],
           _out_aval: core.ShapedArray):
  """Rounds `operand` according to `rounding_method`.

  TO_NEAREST_EVEN maps directly to `tf.math.round`. AWAY_FROM_ZERO is
  emulated: round the magnitude, sending an exact .5 fraction up, then
  reapply the sign.
  """
  if rounding_method is not lax.RoundingMethod.AWAY_FROM_ZERO:
    # rounding_method is RoundingMethod.TO_NEAREST_EVEN
    return tf.math.round(operand)
  # JAX uses a single HLO op Round here
  sign = _sign(operand)
  magnitude = operand * sign
  whole = tf.math.floor(magnitude)
  fraction = magnitude - whole
  is_half = tf.math.equal(fraction, tf.constant(np.array(0.5), fraction.dtype))
  rounded_fraction = tf.where(is_half,
                              tf.constant(np.array(1), fraction.dtype),
                              tf.math.round(fraction))
  return sign * (rounded_fraction + whole)


tf_impl_with_avals[lax.round_p] = _round
tf_impl[lax.nextafter_p] = tf.math.nextafter
def _population_count(x):
  """Per-element count of set bits, cast back to `x`'s dtype."""
  bit_counts = tf.raw_ops.PopulationCount(x=x)
  return tf.cast(bit_counts, x.dtype)


tf_impl[lax.population_count_p] = _population_count
tf_impl[lax.is_finite_p] = tf.math.is_finite
def _abs(x: TfVal) -> TfVal:
  """Elementwise absolute value; unsigned inputs are returned unchanged."""
  # TF and XLA do not support tf.math.abs for unsigned types.
  if x.dtype.is_unsigned:
    return x
  return tf.math.abs(x)


tf_impl[lax.abs_p] = _abs
tf_impl[lax.pow_p] = tf.math.pow
def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray],
                 _out_aval: core.ShapedArray):
  """Computes `x ** y` for a Python int `y` by repeated squaring.

  Follows the implementation in lax._integer_pow_translation_rule. `y == 0`
  yields ones of the output shape, and a negative `y` yields the reciprocal
  of `x ** -y`.
  """
  if y == 0:
    return tf.broadcast_to(
        tf.constant(1, dtype=x.dtype, shape=()), _eval_shape(_out_aval.shape))
  remaining = abs(y)
  base = x
  result = None
  while True:
    if remaining & 1:
      result = base if result is None else tf.math.multiply(result, base)
    remaining >>= 1
    if not remaining:
      break
    base = tf.math.multiply(base, base)
  return tf.math.reciprocal(result) if y < 0 else result


tf_impl_with_avals[lax.integer_pow_p] = _integer_pow
tf_impl[lax.exp_p] = tf.math.exp
tf_impl[lax.expm1_p] = tf.math.expm1
tf_impl[lax.log_p] = tf.math.log
tf_impl[lax.log1p_p] = tf.math.log1p
tf_impl[lax.tan_p] = tf.math.tan
tf_impl[lax.tanh_p] = tf.math.tanh
tf_impl[lax.sin_p] = tf.math.sin
tf_impl[lax.sinh_p] = tf.math.sinh
tf_impl[lax.cos_p] = tf.math.cos
tf_impl[lax.cosh_p] = tf.math.cosh
tf_impl_with_avals[lax.acos_p] = _convert_jax_impl(
    lax_internal.acos_impl, multiple_results=False)
tf_impl_with_avals[lax.asin_p] = _convert_jax_impl(
    lax_internal.asin_impl, multiple_results=False)
tf_impl_with_avals[lax.atan_p] = _convert_jax_impl(
    lax_internal.atan_impl, multiple_results=False)

# TODO(phawkins): use tf.math.sigmoid here instead.
tf_impl_with_avals[lax.logistic_p] = _convert_jax_impl(
    lax_internal.logistic_impl, multiple_results=False)
def _atan2(y, x, **kwargs):
  """Elementwise atan2(y, x).

  Real inputs use `tf.math.atan2`. Complex inputs use the formula
  `-i * log((x + i*y) / sqrt(x*x + y*y))`.
  """
  if not (x.dtype.is_complex or y.dtype.is_complex):
    return tf.math.atan2(y, x)
  real_dtype = {
      tf.complex64: tf.float32,
      tf.complex128: tf.float64,
  }.get(y.dtype)
  i = tf.complex(tf.constant(0, real_dtype), tf.constant(1, real_dtype))
  return -i * tf.math.log((x + i * y) / tf.math.sqrt(x * x + y * y))


tf_impl[lax.atan2_p] = _atan2
tf_impl[lax.acosh_p] = tf.math.acosh
tf_impl[lax.atanh_p] = tf.math.atanh
tf_impl[lax.asinh_p] = tf.math.asinh
tf_impl[lax.sqrt_p] = tf.math.sqrt
tf_impl[lax.rsqrt_p] = tf.math.rsqrt
def _cbrt(x):
  """Real cube root computed as sign(x) * |x| ** (1/3)."""
  sign = tf.math.sign(x)
  return sign * tf.math.pow(tf.math.abs(x), 1 / 3)


tf_impl[lax.cbrt_p] = _cbrt
tf_impl[lax.lgamma_p] = tf.math.lgamma
tf_impl[lax.digamma_p] = tf.math.digamma
tf_impl[lax.igamma_p] = tf.math.igamma
tf_impl[lax.igammac_p] = tf.math.igammac
tf_impl[lax.regularized_incomplete_beta_p] = tf.math.betainc
tf_impl[lax.erf_p] = tf.math.erf
tf_impl[lax.erfc_p] = tf.math.erfc
tf_impl[lax.erf_inv_p] = tf.math.erfinv
tf_impl[lax.bessel_i0e_p] = tf.math.bessel_i0e
tf_impl[lax.bessel_i1e_p] = tf.math.bessel_i1e
tf_impl[lax.complex_p] = tf.complex
def _conj(x, **kwargs):
  """Complex conjugate.

  The only dtypes that are allowed are float32, float64, complex64 and
  complex128. Real inputs are cast to the complex type of matching precision.
  """
  if x.dtype == tf.float32:
    complex_dtype = tf.complex64
  elif x.dtype == tf.float64:
    complex_dtype = tf.complex128
  else:
    return tf.math.conj(x)
  return tf.cast(x, complex_dtype)


tf_impl[lax.conj_p] = _conj
tf_impl[lax.real_p] = tf.math.real
tf_impl[lax.imag_p] = tf.math.imag

tf_impl[lax.add_p] = _add
tf_impl[lax.sub_p] = tf.math.subtract
tf_impl[lax.mul_p] = tf.math.multiply
def _iota(*, dtype, shape, dimension):
  """Returns an array of `shape` whose values count up along `dimension`."""
  tf_dtype = _to_tf_dtype(dtype)
  # Some dtypes are unsupported, like uint32, so we just fall back to int32.
  # TODO(mattjj, necula): improve tf.range dtype handling
  shape_tf = _eval_shape(shape)
  length = tf.cast(shape_tf[dimension], tf.int32)
  counter = tf.range(length, dtype=tf.int32)
  # Reshape to rank len(shape) with every axis except `dimension` of size 1.
  counter_shape = [-1 if axis == dimension else 1 for axis in range(len(shape))]
  broadcast = tf.broadcast_to(tf.reshape(counter, counter_shape), shape_tf)
  return tf.cast(broadcast, tf_dtype)


tf_impl[lax.iota_p] = _iota
def _div(lhs, rhs):
  """Division with integer results rounded toward zero.

  Integer division starts from `floordiv`, which rounds down, and adds one
  when the signs differ and the division is inexact. Other dtypes use true
  division.
  """
  if not lhs.dtype.is_integer:
    return tf.math.truediv(lhs, rhs)
  quotient = tf.math.floordiv(lhs, rhs)
  signs_differ = tf.not_equal(_sign(lhs), _sign(rhs))
  inexact = tf.not_equal(tf.math.floormod(lhs, rhs), 0)
  needs_bump = tf.math.logical_and(signs_differ, inexact)
  return tf.where(needs_bump, quotient + 1, quotient)


def _rem(lhs, rhs):
  """Remainder whose sign follows `lhs`: sign(lhs) * (|lhs| mod |rhs|)."""
  magnitude_rem = tf.math.floormod(_abs(lhs), _abs(rhs))
  return _sign(lhs) * magnitude_rem


tf_impl[lax.div_p] = _div
tf_impl[lax.rem_p] = _rem
def _minmax(x: TfVal, y: TfVal, *, is_min: bool,
            _in_avals: Sequence[core.ShapedArray],
            _out_aval: core.ShapedArray,) -> TfVal:
  """Elementwise min (`is_min=True`) or max of `x` and `y`.

  Complex inputs use the JAX lowering, which orders lexicographically like
  JAX. Booleans use logical and/or.
  """
  np_dtype = x.dtype.as_numpy_dtype
  if dtypes.issubdtype(np_dtype, np.complexfloating):
    # For complex numbers use lexicographic ordering, like JAX
    pick_x = lax.lt if is_min else lax.gt
    complex_impl = _convert_jax_impl(
        partial(lax_internal._minmax_complex_lowering, lax_cmp_pick_x=pick_x),
        multiple_results=False)
    return complex_impl(x, y, _in_avals=_in_avals, _out_aval=_out_aval)
  if np_dtype == np.bool_:
    op = tf.math.logical_and if is_min else tf.math.logical_or
  else:
    op = tf.math.minimum if is_min else tf.math.maximum
  return op(x, y)
def _minmax_scalar(x: TfVal, y: TfVal, *, is_min: bool) -> TfVal:
  """min/max of two scalar TF values, e.g. for use in reducers.

  Because both operands are scalars, the abstract values can be constructed
  here directly, even in the presence of shape polymorphism.
  """
  assert len(x.shape) == 0 and len(y.shape) == 0, f"x: {x.shape}, y: {y.shape}"
  scalar_aval = core.ShapedArray((), _to_jax_dtype(x.dtype))
  return _minmax(x, y, is_min=is_min,
                 _in_avals=[scalar_aval, scalar_aval], _out_aval=scalar_aval)


tf_impl_with_avals[lax.max_p] = partial(_minmax, is_min=False)
tf_impl_with_avals[lax.min_p] = partial(_minmax, is_min=True)
# Map from TF signed integer types to the unsigned type of the same width.
_SIGNED_TO_UNSIGNED_TABLE = dict(zip(
    (tf.int8, tf.int16, tf.int32, tf.int64),
    (tf.uint8, tf.uint16, tf.uint32, tf.uint64),
))

# Inverse map: TF unsigned integer types to the signed type of the same width.
_UNSIGNED_TO_SIGNED_TABLE = {
    unsigned: signed for signed, unsigned in _SIGNED_TO_UNSIGNED_TABLE.items()
}
# Note: Bitwise operations only yield identical results on unsigned integers!
# pylint: disable=protected-access
def _shift_right_arithmetic_raw(x, y):
  """Arithmetic right shift of `x` by `y`, without clamping `y`.

  Unsigned operands are cast to the same-width signed type for the shift and
  the result is cast back.
  """
  if not x.dtype.is_unsigned:
    return tf.bitwise.right_shift(x, y)
  assert x.dtype == y.dtype
  unsigned_dtype = x.dtype
  signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[unsigned_dtype]
  shifted = tf.bitwise.right_shift(tf.cast(x, signed_dtype),
                                   tf.cast(y, signed_dtype))
  return tf.cast(shifted, unsigned_dtype)
def _shift_right_arithmetic(x, y):
# TF shift is "implementation defined" if the shift amount is negative
# or larger or equal to the size of the value. We implement the XLA
# semantics to return the shift by the max value (x_bits - 1).
# TODO: it is likely better to add XlaOps for shifts
x_bits = 8 * x.dtype.size
clamp_y = tf.where(_shift_in_bounds(x, y), y, x_bits - 1)
return _shift_right_arithmetic_raw(x, clamp_y)
tf_impl[lax.shift_right_arithmetic_p] = _shift_right_arithmetic
def _shift_right_logical_raw(x, y):
  """Zero-filling right shift, with no handling of out-of-range shifts.

  Signed inputs are reinterpreted as the unsigned dtype of the same width,
  shifted, and cast back. Unsigned inputs are shifted directly.
  """
  if x.dtype.is_unsigned:
    return tf.bitwise.right_shift(x, y)
  assert x.dtype == y.dtype
  signed_dtype = x.dtype
  unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[signed_dtype]
  shifted = tf.bitwise.right_shift(tf.cast(x, unsigned_dtype),
                                   tf.cast(y, unsigned_dtype))
  return tf.cast(shifted, signed_dtype)
def _shift_right_logical(x, y):
  """Logical right shift with XLA semantics for out-of-range shifts.

  TF's shift is "implementation defined" when the shift amount is negative
  or at least the bit width of `x`. XLA returns 0 in that case.
  """
  # TODO: it is likely better to add XlaOps for shifts
  in_bounds = _shift_in_bounds(x, y)
  return tf.where(in_bounds, _shift_right_logical_raw(x, y), tf.zeros_like(x))
tf_impl[lax.shift_right_logical_p] = _shift_right_logical
def _shift_left(x, y):
  """Left shift with XLA semantics for out-of-range shifts.

  TF's shift is "implementation defined" when the shift amount is negative
  or at least the bit width of `x`. XLA returns 0 in that case.
  """
  # TODO: it is likely better to add XlaOps for shifts
  in_bounds = _shift_in_bounds(x, y)
  return tf.where(in_bounds, tf.bitwise.left_shift(x, y), tf.zeros_like(x))
tf_impl[lax.shift_left_p] = _shift_left
def _shift_in_bounds(x: TfVal, y: TfVal) -> TfVal:
  """Returns a bool tensor: True where 0 <= y < (bit width of x's dtype)."""
  bit_width = 8 * x.dtype.size
  # TF does not have comparisons for uint16 and uint32 (despite what the
  # documentation says), so compare in the signed type of the same width.
  if y.dtype.is_unsigned:
    y = tf.cast(y, _UNSIGNED_TO_SIGNED_TABLE[y.dtype])
  below_width = tf.math.less(y, bit_width)
  non_negative = tf.math.greater_equal(y, 0)
  return tf.logical_and(below_width, non_negative)
def _not(x):
  """Bitwise NOT; booleans get a logical NOT, as in NumPy and JAX.

  This is not the same as casting booleans to integers first:

    jnp.bitwise_not(jnp.array([True, False]))
    >> [False, True]   (dtype=bool)

    jnp.bitwise_not(jnp.array([True, False]).astype(np.int32)).astype(bool)
    >> [True, True]    (dtype=bool)
  """
  is_bool = x.dtype == tf.bool
  return tf.logical_not(x) if is_bool else tf.bitwise.invert(x)
tf_impl[lax.not_p] = _not
def handle_boolean_args(f, argnums: Sequence[int], boolean_f=None):
  """Computes functions with some bool args and bool results using int8.

  This is needed because some TF ops do not work for bool args, e.g.,
  inequalities, min/max.

  Args:
    f: a TF callable to wrap. It will be called with non-boolean arguments.
    argnums: the positional arguments that may be booleans.
    boolean_f: [Optional] a TF callable compatible with boolean
      arguments.

  Returns: a TF callable that can take a mix of boolean positional arguments
    (in the positions specified by `argnums`) and some non-boolean positional
    arguments. If there are no boolean arguments, just calls `f`. Otherwise,
    it calls `boolean_f` if defined. Otherwise, casts the boolean
    arguments to `int8`, calls `f`, then casts the result to `bool`. In that
    last case, `_in_avals`/`_out_aval` keyword arguments (if passed) are
    rewritten from bool to int8 so they describe the cast arguments.
  """
  argnums = tf.nest.flatten(argnums)

  def wrapper(*args: TfVal, **kwargs):
    argnum_types = {args[i].dtype for i in argnums}
    if tf.bool not in argnum_types:
      return f(*args, **kwargs)
    # All argnums should be boolean: mixing bool and non-bool is unsupported.
    assert len(argnum_types) == 1, argnum_types
    if boolean_f is not None:
      return boolean_f(*args, **kwargs)

    args_cast = [(tf.cast(a, tf.int8) if i in argnums else a)
                 for i, a in enumerate(args)]
    if "_in_avals" in kwargs:
      def cast_aval(aval):
        assert aval.dtype == np.bool_
        return core.ShapedArray(aval.shape, np.int8)
      _in_avals_cast = [
          cast_aval(aval) if i in argnums else aval
          for i, aval in enumerate(kwargs["_in_avals"])
      ]
      _out_aval_cast = tf.nest.map_structure(cast_aval, kwargs["_out_aval"])
      kwargs = dict(
          kwargs, _in_avals=_in_avals_cast, _out_aval=_out_aval_cast)
    out = f(*args_cast, **kwargs)
    return tf.nest.map_structure(lambda o: tf.cast(o, tf.bool), out)

  return wrapper
tf_impl[lax.or_p] = handle_boolean_args(
    tf.bitwise.bitwise_or, argnums=(0, 1), boolean_f=tf.logical_or)
tf_impl[lax.and_p] = handle_boolean_args(
    tf.bitwise.bitwise_and, argnums=(0, 1), boolean_f=tf.logical_and)
tf_impl[lax.xor_p] = handle_boolean_args(
    tf.bitwise.bitwise_xor, argnums=(0, 1), boolean_f=tf.math.logical_xor)
tf_impl[lax.eq_p] = tf.math.equal
tf_impl[lax.ne_p] = tf.math.not_equal
# Orderings on booleans (False < True), written with logical ops because the
# TF inequality ops do not accept bool arguments.
def boolean_greater(x, y):
  # Only one combo: T,F -> T
  return tf.logical_and(x, tf.logical_not(y))
def boolean_less(x, y):
  # Only one combo: F,T -> T
  return tf.logical_and(tf.logical_not(x), y)
def boolean_greater_or_equal(x, y):
  # All cases except F,T
  return tf.logical_not(boolean_less(x, y))
def boolean_less_or_equal(x, y):
  # All cases except T,F
  return tf.logical_not(boolean_greater(x, y))
tf_impl[lax.gt_p] = handle_boolean_args(
    tf.math.greater, argnums=(0, 1), boolean_f=boolean_greater)
tf_impl[lax.lt_p] = handle_boolean_args(
    tf.math.less, argnums=(0, 1), boolean_f=boolean_less)
tf_impl[lax.ge_p] = handle_boolean_args(
    tf.math.greater_equal, argnums=(0, 1), boolean_f=boolean_greater_or_equal)
tf_impl[lax.le_p] = handle_boolean_args(
    tf.math.less_equal, argnums=(0, 1), boolean_f=boolean_less_or_equal)
tf_impl[lax.linalg.cholesky_p] = tf.linalg.cholesky
def _convert_element_type(operand, *, new_dtype, weak_type=False):
  """Casts `operand` to `new_dtype`, following JAX conversion semantics.

  Converting complex to a non-complex dtype keeps only the real part.
  Converting floating point to a dtype that is not floating, complex, or
  bool first rounds towards zero (sign(x) * floor(sign(x) * x)).
  `weak_type` is accepted but not used here.
  """
  src_dtype = operand.dtype.as_numpy_dtype
  if (dtypes.issubdtype(src_dtype, np.complexfloating)
      and not dtypes.issubdtype(new_dtype, np.complexfloating)):
    operand = tf.math.real(operand)
  if dtypes.issubdtype(src_dtype, np.floating):
    keeps_fraction = (dtypes.issubdtype(new_dtype, np.floating) or
                      dtypes.issubdtype(new_dtype, np.complexfloating) or
                      new_dtype == np.bool_)
    if not keeps_fraction:
      sign = _sign(operand)
      operand = sign * tf.math.floor(sign * operand)
  return tf.dtypes.cast(operand, _to_tf_dtype(new_dtype))
tf_impl[lax.convert_element_type_p] = _convert_element_type
def _bitcast_convert_type(operand, new_dtype):
  """Reinterprets the bits of `operand` as `new_dtype` (no-op if unchanged)."""
  if operand.dtype != new_dtype:
    operand = tf.bitcast(operand, _to_tf_dtype(new_dtype))
  return operand
tf_impl[lax.bitcast_convert_type_p] = _bitcast_convert_type
def _clamp(minval, operand, maxval, *, _in_avals, _out_aval):
  """Clamps `operand` into [minval, maxval] elementwise."""
  operand_aval = _in_avals[1]
  target_shape = _eval_shape(operand_aval.shape, operand_aval.dtype)
  upper = tf.broadcast_to(maxval, target_shape)
  # Capping the lower bound at the upper bound mirrors the behavior of JAX
  # when maxval < minval.
  lower = tf.math.minimum(tf.broadcast_to(minval, target_shape), upper)
  return tf.clip_by_value(operand, lower, upper)
tf_impl_with_avals[lax.clamp_p] = _clamp
def _concatenate(*operands, dimension):
  """Concatenates `operands` along axis `dimension`."""
  axis = tf.cast(dimension, tf.int32)
  return tf.concat(operands, axis=axis)
tf_impl[lax.concatenate_p] = _concatenate
def _conv_general_dimension_numbers_proto(dimension_numbers):
  """Converts a ConvDimensionNumbers to an XLA ConvolutionDimensionNumbers.

  Each of the lhs/rhs/out specs lists two leading dimensions followed by the
  spatial dimensions: (batch, feature) for lhs and out, and
  (output feature, input feature) for rhs.
  """
  assert isinstance(dimension_numbers, lax.ConvDimensionNumbers)
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  lhs_batch, lhs_feature, *lhs_spatial = lhs_spec
  rhs_out_feature, rhs_in_feature, *rhs_spatial = rhs_spec
  out_batch, out_feature, *out_spatial = out_spec

  proto = xla_data_pb2.ConvolutionDimensionNumbers()
  proto.input_batch_dimension = lhs_batch
  proto.input_feature_dimension = lhs_feature
  proto.output_batch_dimension = out_batch
  proto.output_feature_dimension = out_feature
  proto.kernel_output_feature_dimension = rhs_out_feature
  proto.kernel_input_feature_dimension = rhs_in_feature
  proto.input_spatial_dimensions.extend(lhs_spatial)
  proto.kernel_spatial_dimensions.extend(rhs_spatial)
  proto.output_spatial_dimensions.extend(out_spatial)
  return proto
def _precision_config_proto(precision: Optional[Tuple[PrecisionType,
                                                      PrecisionType]]):
  """Converts an (lhs, rhs) pair of precisions to an XLA PrecisionConfig.

  Args:
    precision: None, or a pair of precision values, one per operand. Each
      value is converted with `int()` and appended, in order, to the proto's
      `operand_precision` field.

  Returns:
    An `xla_data_pb2.PrecisionConfig`, or None when `precision` is None.
  """
  if precision is None:
    return None
  proto = xla_data_pb2.PrecisionConfig()
  proto.operand_precision.append(int(precision[0]))
  proto.operand_precision.append(int(precision[1]))
  return proto
def _conv_general_dilated(lhs, rhs, *,
                          window_strides, padding, lhs_dilation,
                          rhs_dilation,
                          dimension_numbers: lax.ConvDimensionNumbers,
                          feature_group_count: int,
                          batch_group_count: int,
                          precision: Optional[Tuple[PrecisionType, PrecisionType]],
                          preferred_element_type: Optional[DType],
                          _in_avals: Sequence[core.ShapedArray],
                          _out_aval: core.ShapedArray):
  """Implementation of lax.conv_general_dilated_p using XlaConv.

  Complex convolutions are decomposed into three real convolutions, following
  lax._conv_general_dilated_translation.

  Raises:
    ValueError: if batch_group_count != 1 and the installed TF is older than
      2.8.
  """
  out_tf_shape = _aval_to_tf_shape(_out_aval)
  dnums_proto = _conv_general_dimension_numbers_proto(dimension_numbers)
  precision_config_proto = _precision_config_proto(precision)

  def gen_conv(lhs, rhs, preferred_element_type: Optional[DType]):
    major_minor = tuple(int(v) for v in tf.__version__.split(".")[:2])
    conv_kwargs = dict(
        feature_group_count=feature_group_count,
        precision_config=precision_config_proto,
        preferred_element_type=preferred_element_type,
        use_v2=True)
    if major_minor >= (2, 8):
      # TODO(necula): drop this version check once TF 2.8.0 (the first
      # version supporting batch_group_count) is the stable TF version.
      conv_kwargs["batch_group_count"] = batch_group_count
    elif batch_group_count != 1:
      raise ValueError(
          "The batch_group_count parameter for conv requires TF version "
          "at least 2.8.0. You may want to use tf-nightly.")
    padding_tf = [_eval_shape(p) for p in padding]
    out = tfxla.conv(lhs, rhs, window_strides, padding_tf, lhs_dilation,
                     rhs_dilation, dnums_proto, **conv_kwargs)
    # TODO: implement shape inference for XlaConv
    out = _ensure_tf_shape_if_dynamic(out, out_tf_shape)
    if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
      out = tf.stop_gradient(out)  # See #7839
    return out

  if not np.issubdtype(_in_avals[0].dtype, np.complexfloating):
    return gen_conv(lhs, rhs, preferred_element_type)

  # Complex case. The same rewrite is used on all platforms because on
  # XLA:TPU the compiler performs the same rewrite.
  preferred_float_et: Optional[Any] = None
  if preferred_element_type is not None:
    # Map the complex result type to the type of its real/imaginary parts.
    assert np.issubdtype(preferred_element_type, np.complexfloating)
    preferred_float_et = (
        np.float64 if preferred_element_type == np.complex128 else np.float32)
  lhs_real, lhs_imag = tf.math.real(lhs), tf.math.imag(lhs)
  rhs_real, rhs_imag = tf.math.real(rhs), tf.math.imag(rhs)
  k1 = gen_conv(_add(lhs_real, lhs_imag), rhs_real, preferred_float_et)
  k2 = gen_conv(lhs_real, tf.math.subtract(rhs_imag, rhs_real),
                preferred_float_et)
  k3 = gen_conv(lhs_imag, _add(rhs_real, rhs_imag), preferred_float_et)
  # real = k1 - k3 = lr*rr - li*ri ; imag = k1 + k2 = lr*ri + li*rr
  return tf.complex(tf.math.subtract(k1, k3), _add(k1, k2))
tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated
def _dot_general(lhs, rhs, *, dimension_numbers,
                 precision: Optional[Tuple[PrecisionType, PrecisionType]],
                 preferred_element_type: Optional[DType],
                 _in_avals: Sequence[core.ShapedArray],
                 _out_aval: core.ShapedArray):
  """Implementation of lax.dot_general_p using tfxla.dot_general (v2).

  Args:
    lhs, rhs: the operands.
    dimension_numbers: ((lhs_contracting, rhs_contracting),
      (lhs_batch, rhs_batch)), as for lax.dot_general.
    precision: None or a pair of operand precisions.
    preferred_element_type: optional result element type, forwarded to
      tfxla.dot_general.
    _in_avals, _out_aval: abstract values of inputs/output (not used here).

  Returns:
    The result of tfxla.dot_general, with gradients stopped when
    _WRAP_JAX_JIT_WITH_TF_FUNCTION is set.
  """
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  dnums_proto = xla_data_pb2.DotDimensionNumbers()
  dnums_proto.lhs_contracting_dimensions.extend(lhs_contracting)
  dnums_proto.rhs_contracting_dimensions.extend(rhs_contracting)
  dnums_proto.lhs_batch_dimensions.extend(lhs_batch)
  dnums_proto.rhs_batch_dimensions.extend(rhs_batch)
  precision_config_proto = _precision_config_proto(precision)
  res = tfxla.dot_general(
      lhs,
      rhs,
      dnums_proto,
      precision_config_proto,
      preferred_element_type=preferred_element_type,
      use_v2=True)
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    res = tf.stop_gradient(res)  # See #7839
  return res
tf_impl_with_avals[lax.dot_general_p] = _dot_general
def _broadcast_in_dim(operand, *, shape, broadcast_dimensions,
                      _in_avals: Sequence[core.ShapedArray],
                      _out_aval: core.ShapedArray):
  """Broadcasts `operand` to `shape`.

  Operand dimension i maps to output dimension broadcast_dimensions[i]
  (which must be strictly increasing, with one entry per operand dimension).
  The operand is first reshaped to insert size-1 dimensions for every output
  dimension it lacks, then broadcast with tf.broadcast_to.
  """
  operand_aval = _in_avals[0]
  expanded_shape = [1] * len(shape)
  for operand_dim, out_dim in enumerate(broadcast_dimensions):
    expanded_shape[out_dim] = operand_aval.shape[operand_dim]
  expanded = tf.reshape(
      operand, _eval_shape(expanded_shape, dtype=operand_aval.dtype))
  return tf.broadcast_to(expanded,
                         _eval_shape(shape, dtype=operand_aval.dtype))
tf_impl_with_avals[lax.broadcast_in_dim_p] = _broadcast_in_dim
def _empty(*, dtype):
  """Converts lax_internal.empty_p to a scalar zero constant of `dtype`."""
  if not dtypes.is_opaque_dtype(dtype):
    return tf.constant(np.array(0, dtype=dtype))
  raise NotImplementedError  # TODO(frostig,mattjj): jax2tf handlers
tf_impl[lax_internal.empty_p] = _empty
def _reshape(operand, *, new_sizes, dimensions, _in_avals, _out_aval):
  """Transposes `operand` by `dimensions` (identity if None), then reshapes."""
  perm = tf.range(tf.rank(operand)) if dimensions is None else dimensions
  target_shape = _eval_shape(new_sizes, _in_avals[0].dtype)
  return tf.reshape(tf.transpose(operand, perm), target_shape)
tf_impl_with_avals[lax.reshape_p] = _reshape
def _squeeze(operand, *, dimensions, _in_avals, _out_aval):
  """Removes the axes listed in `dimensions`, using the physical input aval."""
  physical_aval = _jax_physical_aval(_in_avals[0])
  kept_sizes = [size for axis, size in enumerate(physical_aval.shape)
                if axis not in dimensions]
  target_shape = _eval_shape(tuple(kept_sizes), physical_aval.dtype)
  return tf.reshape(operand, target_shape)
tf_impl_with_avals[lax.squeeze_p] = _squeeze
def _pad(operand, padding_value, *, padding_config,
         _in_avals: Sequence[core.ShapedArray],
         _out_aval: core.ShapedArray):
  """Pads `operand` with XlaPad; `padding_config` holds (low, high, interior)
  triples, one per dimension."""
  evaluated_config = [_eval_shape(dim_config) for dim_config in padding_config]
  low, high, interior = util.unzip3(evaluated_config)
  padded = tfxla.pad(operand, padding_value, low, high, interior)
  # TODO: implement shape inference for XlaPad (when some padding_config is constant)
  padded = _ensure_tf_shape_if_dynamic(padded, _aval_to_tf_shape(_out_aval))
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    padded = tf.stop_gradient(padded)  # See #7839
  return padded
tf_impl_with_avals[lax.pad_p] = _pad
def _rev(operand, *, dimensions):
  """Reverses `operand` along every axis listed in `dimensions`."""
  return tf.reverse(operand, axis=dimensions)
tf_impl[lax.rev_p] = _rev
def _where(which, *cases):
  """Implementation of lax.select_n_p: elementwise picks cases[which].

  A boolean `which` selects between at most two cases (False -> cases[0],
  True -> cases[1]). An integer `which` is resolved with a balanced tree of
  tf.where calls over the case indices.

  Fixes two bugs in the previous version: a single case with a boolean
  `which` returned the tuple `cases` instead of the case itself, and the
  right half of each split used offset `mid` instead of `offset + mid`,
  which picked the wrong case when there were 5 or more cases.
  """
  if which.dtype == tf.bool:
    assert len(cases) <= 2
    return cases[0] if len(cases) == 1 else tf.where(which, cases[1], cases[0])

  def _select(offset, cases):
    # `cases` is a contiguous slice of the original cases whose first
    # element has global index `offset`.
    assert len(cases) > 0
    if len(cases) == 1:
      return cases[0]
    mid = len(cases) // 2
    return tf.where(tf.less(which, offset + mid),
                    _select(offset, cases[:mid]),
                    _select(offset + mid, cases[mid:]))

  return _select(0, cases)
tf_impl[lax.select_n_p] = _where
def _transpose(operand, *, permutation):
  """Permutes the axes of `operand` according to `permutation`."""
  return tf.transpose(operand, permutation)
tf_impl[lax.transpose_p] = _transpose
def axes_to_axis(func):
  """Adapts a TF reduction taking `axis=` to lax's `axes` parameter name."""
  def reduce_over_axes(operand, axes):
    return func(operand, axis=axes)
  return reduce_over_axes
# reduce_sum and reduce_prod are not supported for bool
tf_impl[lax.reduce_sum_p] = axes_to_axis(tf.reduce_sum)
tf_impl[lax.reduce_prod_p] = axes_to_axis(tf.reduce_prod)
# On booleans, max is T if any one is T, and min is F if not all are T.
tf_impl[lax.reduce_max_p] = handle_boolean_args(
    axes_to_axis(tf.reduce_max), argnums=[0],
    boolean_f=axes_to_axis(tf.reduce_any))
tf_impl[lax.reduce_min_p] = handle_boolean_args(
    axes_to_axis(tf.reduce_min), argnums=[0],
    boolean_f=axes_to_axis(tf.reduce_all))
tf_impl[lax.reduce_or_p] = axes_to_axis(tf.reduce_any)
tf_impl[lax.reduce_and_p] = axes_to_axis(tf.reduce_all)
def _argminmax(is_min: bool, operand: TfVal, axes: Sequence[int],
               index_dtype: DType,
               _in_avals: Sequence[core.ShapedArray],
               _out_aval: core.ShapedArray):
  """Implementation of argmin (`is_min`) / argmax by converting JAX's own
  implementation, which uses an XlaReduce with a custom comparator."""
  if is_min:
    name_stack, comparator = "argmin", lax.lt
    identity_fn = lax_internal._get_min_identity
  else:
    name_stack, comparator = "argmax", lax.gt
    identity_fn = lax_internal._get_max_identity
  converted = _convert_jax_impl(
      partial(lax_internal._compute_argminmax, comparator, identity_fn),
      multiple_results=False,
      extra_name_stack=name_stack)
  return converted(operand, index_dtype=index_dtype, axes=axes,
                   _in_avals=_in_avals, _out_aval=_out_aval)
tf_impl_with_avals[lax.argmin_p] = partial(_argminmax, True)
tf_impl_with_avals[lax.argmax_p] = partial(_argminmax, False)
# tf.function-wrapped scalar helpers. `_add_fn` supplies the scatter
# computation in _select_and_scatter_add.
_add_fn = tf.function(_add, autograph=False)
_ge_fn = tf.function(tf.math.greater_equal, autograph=False)
def _select_and_gather_add(
    tangents: TfVal, operand: TfVal, select_prim: core.Primitive,
    window_dimensions: Sequence[int], window_strides: Sequence[int],
    base_dilation: Sequence[int], window_dilation: Sequence[int],
    padding: Sequence[Tuple[int, int]], _in_avals: Sequence[core.ShapedArray],
    _out_aval: core.ShapedArray):
  """Implementation of lax.select_and_gather_add_p.

  For each window, returns the `tangents` value at the position of the
  `operand` element chosen by `select_prim` (lax.ge_p or lax.le_p). Each
  (operand, tangent) pair is packed into one unsigned integer of twice the
  bit width (operand bits in the high word, tangent bits in the low word), so
  a single reduce_window can carry both. The tangent is unpacked at the end.

  Raises:
    NotImplementedError: if packing would need more than 64 bits, i.e. the
      operand dtype is wider than 32 bits.
  """
  # Note: this function follows the pattern in
  # jax.lax._select_and_gather_add_translation.
  dtype = operand.dtype
  nbits = dtypes.finfo(dtype.as_numpy_dtype).bits
  # Specializing the function for 64 bits. Only up to 32 bits are supported on TPU,
  # we thus intend to let the code throw a different exception on this platform.
  max_bits = 64
  assert nbits <= max_bits
  double_word_reduction = nbits * 2 <= max_bits
  # Builds a TF constant of `dtype` from a Python/NumPy scalar `x`.
  const = lambda dtype, x: tf.constant(np.array(x), dtype)
  if double_word_reduction:
    word_dtype = lax_internal._UINT_DTYPES[nbits]
    double_word_dtype = lax_internal._UINT_DTYPES[nbits * 2]
    # Packs two values into one double-width word: the bits of `a` go in the
    # high word and the bits of `b` in the low word.
    def pack(a, b):
      a = _bitcast_convert_type(a, word_dtype)
      b = _bitcast_convert_type(b, word_dtype)
      a = _convert_element_type(a, new_dtype=double_word_dtype)
      b = _convert_element_type(b, new_dtype=double_word_dtype)
      a = tf.bitwise.left_shift(a, const(double_word_dtype, nbits))
      return tf.bitwise.bitwise_or(a, b)
    # Unpacks the first (high-word) element of a packed pair.
    def fst(t):
      assert t.dtype == double_word_dtype
      st = _shift_right_logical(t, const(double_word_dtype, nbits))
      return _bitcast_convert_type(
          _convert_element_type(st, new_dtype=word_dtype), dtype)
    # Unpacks the second (low-word) element of a packed pair; the narrowing
    # conversion keeps only the low bits.
    def snd(t):
      return _bitcast_convert_type(
          _convert_element_type(t, new_dtype=word_dtype), dtype)
  else:
    raise NotImplementedError(
        f"TODO: need to pack {nbits * 2} bits but this platform can only go up to {max_bits} bits."
    )
  assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
  # Keeps whichever packed pair holds the operand chosen by `select_prim`:
  # select_n(pred, y, x) yields x where pred is True.
  def reducer(x, y):
    which = tf_impl[select_prim]
    return tf_impl[lax.select_n_p](which(fst(x), fst(y)), y, x)
  # Initial value: -inf for ge (max-like) and +inf for le (min-like)
  # selection, paired with a zero tangent. The dtype argument is ignored.
  init = -np.inf if select_prim is lax.ge_p else np.inf
  init_identity = lambda x: pack(const(dtype, init), const(dtype, 0))
  out = _specialized_reduce_window(
      reducer,
      init_identity,
      pack(operand, tangents),
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      padding=padding,
      base_dilation=base_dilation,
      window_dilation=window_dilation,
      _in_avals=_in_avals,
      _out_aval=_out_aval)
  return snd(out)
tf_impl_with_avals[lax.select_and_gather_add_p] = _select_and_gather_add
def _common_reduce_window(operand, init_val, reducer, window_dimensions,
                          window_strides, padding, base_dilation,
                          window_dilation, _in_avals, _out_aval):
  """Shared XlaReduceWindow lowering for a binary scalar `reducer`.

  `reducer` is traced as a tf.function over two scalars of the operand's
  dtype. A non-tensor `init_val` is turned into a constant of that dtype.
  """
  scalar_spec = tf.TensorSpec((), dtype=operand.dtype)
  reducer_fn = tf.function(reducer, autograph=False).get_concrete_function(
      scalar_spec, scalar_spec)
  if not isinstance(init_val, (tf.Tensor, tf.Variable)):
    init_val = tf.constant(init_val, operand.dtype)
  (window_dimensions_tf, window_strides_tf,
   window_dilation_tf, base_dilation_tf) = (
       _eval_shape(dims) for dims in (window_dimensions, window_strides,
                                      window_dilation, base_dilation))
  padding_tf = [_eval_shape(pad) for pad in padding]
  out = tfxla.reduce_window(
      operand, init_val, reducer_fn, window_dimensions_tf, window_strides_tf,
      base_dilations=base_dilation_tf, window_dilations=window_dilation_tf,
      padding=padding_tf)
  # TODO: implement shape inference for XlaReduceWindow
  out = _ensure_tf_shape_if_dynamic(out, _aval_to_tf_shape(_out_aval))
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    out = tf.stop_gradient(out)  # See #7839
  return out
def _reduce_window(*args, jaxpr, consts, window_dimensions,
                   window_strides, padding, base_dilation, window_dilation,
                   _in_avals, _out_aval):
  """TensorFlow implementation of reduce_window.

  Args:
    *args: the operands followed by the same number of initial values
      (operands + init_values). Only a single operand is supported.
    jaxpr: the jaxpr corresponding to the reduction function.
    consts: the constants associated with jaxpr; must be empty.
    window_dimensions: array of integers for window dimension values.
    window_strides: array of integers for window stride values.
    padding: array of pairs of integers for padding values.
    base_dilation: array of integers for base dilation values.
    window_dilation: array of integers for window dilation values.
    _in_avals: abstract values of `args`.
    _out_aval: sequence of abstract values of the results; only the first
      entry is used.

  Returns:
    A 1-tuple holding the reduced operand.

  Raises:
    NotImplementedError: for variadic reduce_window (more than one operand).
  """
  assert len(consts) == 0, "Reduction computation cannot have constants"
  operands, init_values = util.split_list(args, [len(args) // 2])
  if len(operands) != 1:
    raise NotImplementedError("jax2tf does not support variadic reduce_window")
  # Built once and reused by every trace of `reducer`.
  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)

  def reducer(arg1: TfVal, arg2: TfVal) -> TfVal:
    res, = _interpret_jaxpr(closed_jaxpr, arg1, arg2, extra_name_stack=None)
    return res

  return (_common_reduce_window(operands[0], init_values[0], reducer,
                                window_dimensions, window_strides, padding,
                                base_dilation, window_dilation, _in_avals,
                                _out_aval[0]),)
def _specialized_reduce_window(reducer,
                               identity,
                               operand,
                               *,
                               window_dimensions,
                               window_strides,
                               padding,
                               base_dilation,
                               window_dilation,
                               _in_avals,
                               _out_aval,
                               name=None):
  """Reduce-window whose initial value is derived from the operand's dtype.

  Args:
    reducer: reduction function of type TfVal -> TfVal -> TfVal.
    identity: function from a TensorFlow dtype to the starting value of the
      reduction.
    operand: N dimensional array containing elements of type T.
    window_dimensions: array of integers for window dimension values.
    window_strides: array of integers for window stride values.
    padding: array of pairs of integers for padding values.
    base_dilation: array of integers for base dilation values.
    window_dilation: array of integers for window dilation values.
    _in_avals, _out_aval: abstract values of the input(s) and output.
    name: the name of the specialized reduce window primitive being
      converted. It is accepted here but not used by this implementation.

  Returns:
    The reduced operand.
  """
  init_val = identity(operand.dtype)
  return _common_reduce_window(
      operand, init_val, reducer, window_dimensions, window_strides, padding,
      base_dilation, window_dilation, _in_avals, _out_aval)
def _get_max_identity(tf_dtype):
  """Identity for max-reductions: -inf, the integer minimum, or False."""
  np_dtype = tf_dtype.as_numpy_dtype
  if tf_dtype == tf.bfloat16 or dtypes.issubdtype(np_dtype, np.inexact):
    return np_dtype(-np.inf)
  if dtypes.issubdtype(np_dtype, np.integer):
    return dtypes.iinfo(np_dtype).min
  assert dtypes.issubdtype(
      np_dtype, np.bool_), (f"{tf_dtype} has no defined max identity")
  return False
def _get_min_identity(tf_dtype):
  """Identity for min-reductions: +inf, the integer maximum, or True."""
  np_dtype = tf_dtype.as_numpy_dtype
  if tf_dtype == tf.bfloat16 or dtypes.issubdtype(np_dtype, np.inexact):
    return np_dtype(np.inf)
  if dtypes.issubdtype(np_dtype, np.integer):
    return dtypes.iinfo(np_dtype).max
  assert dtypes.issubdtype(
      np_dtype, np.bool_), (f"{tf_dtype} has no defined min identity")
  return True
# pylint: disable=protected-access
# Sum windows start from 0; min/max windows start from the dtype's identity.
tf_impl_with_avals[lax.reduce_window_sum_p] = partial(
    _specialized_reduce_window, _add, lambda x: 0,
    name="reduce_window_sum")
tf_impl_with_avals[lax.reduce_window_min_p] = partial(
    _specialized_reduce_window, partial(_minmax_scalar, is_min=True),
    _get_min_identity, name="reduce_window_min")
tf_impl_with_avals[lax.reduce_window_max_p] = partial(
    _specialized_reduce_window, partial(_minmax_scalar, is_min=False),
    _get_max_identity, name="reduce_window_max")
tf_impl_with_avals[lax.reduce_window_p] = _reduce_window
# pylint: enable=protected-access
def _reduce(*operands: TfVal,
            computation: Callable,
            jaxpr: core.Jaxpr,
            consts: Sequence[Any],
            dimensions: Sequence[int],
            _in_avals: Sequence[core.ShapedArray],
            _out_aval: core.ShapedArray) -> Sequence[TfVal]:
  """Implementation of lax.reduce_p using tfxla.variadic_reduce.

  `operands` is (op1, op2, ..., init_val1, init_val2, ...). The reducer,
  built by interpreting `jaxpr`, takes twice as many scalars as there are
  inputs: op1[i], op2[i], ..., init_val1, init_val2, ...
  `computation` is ignored in favor of `jaxpr`.
  """
  del computation
  assert not consts
  assert len(operands) % 2 == 0
  num_inputs = len(operands) // 2
  arrays, init_vals = operands[:num_inputs], operands[num_inputs:]
  scalar_specs = [tf.TensorSpec((), init.dtype) for init in init_vals]

  def reducer_computation(*args: TfVal) -> TfVal:
    return _interpret_jaxpr(core.ClosedJaxpr(jaxpr, consts), *args,
                            extra_name_stack=None)

  xla_reducer_computation = tf.function(
      reducer_computation, autograph=False).get_concrete_function(
          *(scalar_specs * 2))
  outs = tfxla.variadic_reduce(arrays, init_vals,
                               dimensions_to_reduce=dimensions,
                               reducer=xla_reducer_computation)
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    outs = tuple(tf.stop_gradient(out) for out in outs)  # See #7839
  return outs
tf_impl_with_avals[lax.reduce_p] = _reduce
# By default we use lax_control_flow.cumred_reduce_window_impl to convert
# cummax, cummin, cumlogsumexp, cumsum and cumprod. This is efficient on TPU,
# but the complexity is O(n^2) on other backends. When
# config.jax2tf_associative_scan_reductions is set, associative_scan is used
# instead, which favors those other backends.
def _cumred(lax_reduce_fn: Callable,
            lax_reduce_window_fn: Callable,
            extra_name_stack: str):
  """Builds the converter for a cumulative reduction primitive.

  Uses associative_scan over `lax_reduce_fn` when
  config.jax2tf_associative_scan_reductions is set, and otherwise
  cumred_reduce_window_impl over `lax_reduce_window_fn`.
  """
  if config.jax2tf_associative_scan_reductions:
    jax_impl = partial(lax_control_flow.associative_scan, lax_reduce_fn)
  else:
    jax_impl = partial(lax_control_flow.cumred_reduce_window_impl,
                       lax_reduce_window_fn)
  return _convert_jax_impl(jax_impl,
                           multiple_results=False,
                           extra_name_stack=extra_name_stack)
tf_impl_with_avals[lax.cummax_p] = _cumred(
    lax_reduce_fn=lax.max,
    lax_reduce_window_fn=lax_windowed_reductions._reduce_window_max,
    extra_name_stack="cummax")
tf_impl_with_avals[lax.cummin_p] = _cumred(
    lax_reduce_fn=lax.min,
    lax_reduce_window_fn=lax_windowed_reductions._reduce_window_min,
    extra_name_stack="cummin")
tf_impl_with_avals[lax.cumlogsumexp_p] = _cumred(
    lax_reduce_fn=logaddexp,
    lax_reduce_window_fn=lax_windowed_reductions._reduce_window_logaddexp,
    extra_name_stack="cumlogsumexp")
tf_impl_with_avals[lax.cumsum_p] = _cumred(
    lax_reduce_fn=lax.add,
    lax_reduce_window_fn=lax_windowed_reductions._reduce_window_sum,
    extra_name_stack="cumsum")
tf_impl_with_avals[lax.cumprod_p] = _cumred(
    lax_reduce_fn=lax.mul,
    lax_reduce_window_fn=lax_windowed_reductions._reduce_window_prod,
    extra_name_stack="cumprod")
def _select_and_scatter(operand, source, init_value, select_jaxpr,
                        select_consts, scatter_jaxpr, scatter_consts,
                        window_dimensions, window_strides, padding):
  """Placeholder for lax.select_and_scatter_p, which jax2tf cannot convert.

  Raises:
    NotImplementedError: always.
  """
  message = "TODO: jax2tf can not convert _select_and_scatter"
  raise NotImplementedError(message)
tf_impl[lax.select_and_scatter_p] = _select_and_scatter
@partial(handle_boolean_args, argnums=(0, 1))
def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
                            window_strides, padding, _in_avals, _out_aval):
  """Implementation of lax.select_and_scatter_add_p via XlaSelectAndScatter.

  The select computation is the TF conversion of `select_prim`; the scatter
  computation is addition, starting from a zero of the operand's dtype.
  """
  zero = tf.zeros((), operand.dtype)
  select_fn = tf.function(
      tf_impl[select_prim], autograph=False).get_concrete_function(zero, zero)
  scatter_fn = _add_fn.get_concrete_function(zero, zero)
  out = tfxla.select_and_scatter(
      operand, window_dimensions, window_strides, padding, source, zero,
      select_fn, scatter_fn)
  out = _ensure_tf_shape_if_dynamic(out, _aval_to_tf_shape(_out_aval))
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    out = tf.stop_gradient(out)  # See #7839
  return out
tf_impl_with_avals[lax.select_and_scatter_add_p] = _select_and_scatter_add
def _random_seed_impl(seeds: TfVal, *, impl, _in_avals, _out_aval):
  """Converts prng.random_seed_p by lowering JAX's own implementation."""
  def impl_wrapper(seeds: TfVal, *, impl):
    return prng.random_seed_impl_base(seeds, impl=impl)
  return _convert_jax_impl(
      impl_wrapper,
      multiple_results=False,
      with_physical_avals=True,
      extra_name_stack="random_seed",
  )(seeds, impl=impl, _in_avals=_in_avals, _out_aval=_out_aval)
tf_impl_with_avals[prng.random_seed_p] = _random_seed_impl
def _random_split_impl(keys: TfVal, *, count, _in_avals, _out_aval):
  """Converts prng.random_split_p by lowering JAX's own implementation."""
  keys_aval, = _in_avals
  def impl_wrapper(keys: TfVal, *, count):
    return prng.random_split_impl_base(
        keys_aval.dtype.impl, keys, keys_aval.ndim, count=count)
  return _convert_jax_impl(
      impl_wrapper,
      multiple_results=False,
      with_physical_avals=True,
      extra_name_stack="random_split",
  )(keys, count=count, _in_avals=_in_avals, _out_aval=_out_aval)
tf_impl_with_avals[prng.random_split_p] = _random_split_impl
def _random_fold_in_impl(keys: TfVal, msgs: TfVal, *, _in_avals, _out_aval):
  """Converts prng.random_fold_in_p by lowering JAX's own implementation."""
  keys_aval, _ = _in_avals
  def impl_wrapper(keys: TfVal, msgs: TfVal):
    return prng.random_fold_in_impl_base(
        keys_aval.dtype.impl, keys, msgs, keys_aval.shape)
  return _convert_jax_impl(
      impl_wrapper,
      multiple_results=False,
      with_physical_avals=True,
      extra_name_stack="random_fold_in",
  )(keys, msgs, _in_avals=_in_avals, _out_aval=_out_aval)
tf_impl_with_avals[prng.random_fold_in_p] = _random_fold_in_impl
def _random_bits_impl(keys: TfVal, *, bit_width, shape, _in_avals, _out_aval):
  """Converts prng.random_bits_p by lowering JAX's own implementation."""
  keys_aval, = _in_avals
  def impl_wrapper(keys: TfVal, **kwargs):
    # bit_width and shape come from the enclosing scope; kwargs is unused.
    return prng.random_bits_impl_base(
        keys_aval.dtype.impl, keys, keys_aval.ndim,
        bit_width=bit_width, shape=shape)
  return _convert_jax_impl(
      impl_wrapper,
      multiple_results=False,
      with_physical_avals=True,
      extra_name_stack="random_bits",
  )(keys, bit_width=bit_width, shape=shape,
    _in_avals=_in_avals, _out_aval=_out_aval)
tf_impl_with_avals[prng.random_bits_p] = _random_bits_impl
def _random_wrap_impl(base_arr: TfVal, *, impl, _in_avals, _out_aval):
  """random_wrap_p is the identity here: returns `base_arr` unchanged."""
  return base_arr
tf_impl_with_avals[prng.random_wrap_p] = _random_wrap_impl
def _random_unwrap_impl(keys: TfVal, *, _in_avals, _out_aval):
  """random_unwrap_p is the identity here: returns `keys` unchanged."""
  return keys
tf_impl_with_avals[prng.random_unwrap_p] = _random_unwrap_impl
def _threefry2x32_jax_impl(*args: TfVal, _in_avals, _out_aval):
  """Converts threefry2x32_p via JAX's lowering, with loops not rolled."""
  lowering = partial(prng._threefry2x32_lowering, use_rolled_loops=False)
  converted = _convert_jax_impl(lowering, multiple_results=True,
                                extra_name_stack="threefry")
  return converted(*args, _in_avals=_in_avals, _out_aval=_out_aval)
tf_impl_with_avals[prng.threefry2x32_p] = _threefry2x32_jax_impl
# Use the vmap implementation, otherwise on TPU the performance is really bad.
# With use_vmap=True we get about the same performance for JAX and jax2tf.
tf_impl_with_avals[random.random_gamma_p] = _convert_jax_impl(
    partial(random_internal._gamma_impl, use_vmap=True),
    multiple_results=False,
    extra_name_stack="random_gamma")
def _rng_bit_generator(key: TfVal, *, shape, dtype, algorithm) -> Sequence[TfVal]:
  """Implementation of lax.rng_bit_generator_p via tfxla.rng_bit_generator.

  A uint32 key is reshaped to (2, 2) and bitcast to uint64 for tfxla; the new
  key is converted back and reshaped to (4,) afterwards.

  Returns:
    (new_key, random_bits), the latter of `shape` and `dtype`.

  Raises:
    ValueError: if `algorithm` is not one of the supported
      lax.RandomAlgorithm values. An explicit raise is used instead of
      `assert False`, which is stripped under `python -O` and would then
      fail with an UnboundLocalError.
  """
  is_uint32_key = key.dtype == _to_tf_dtype(jnp.uint32)
  if is_uint32_key:
    key = tf.reshape(key, (2, 2))
    key = tfxla.bitcast_convert_type(key, _to_tf_dtype(jnp.uint64))
  shape_tf = _eval_shape(shape)
  # JAX uses XLA algorithm enums; tfxla uses tf.random.Algorithm
  if algorithm == lax.RandomAlgorithm.RNG_THREE_FRY:
    algorithm_tf = tf.random.Algorithm.THREEFRY
  elif algorithm == lax.RandomAlgorithm.RNG_PHILOX:
    algorithm_tf = tf.random.Algorithm.PHILOX
  elif algorithm == lax.RandomAlgorithm.RNG_DEFAULT:
    algorithm_tf = tf.random.Algorithm.AUTO_SELECT
  else:
    raise ValueError(f"Unsupported RNG algorithm for jax2tf: {algorithm}")
  (new_key, res) = tfxla.rng_bit_generator(algorithm_tf.value, key, shape_tf,
                                           dtype=_to_tf_dtype(dtype))
  if is_uint32_key:
    new_key = tfxla.bitcast_convert_type(new_key, _to_tf_dtype(jnp.uint32))
    new_key = tf.reshape(new_key, (4,))
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    # See #7839
    new_key = tf.stop_gradient(new_key)
    res = tf.stop_gradient(res)
  return new_key, res
tf_impl[lax.rng_bit_generator_p] = _rng_bit_generator
def _rng_uniform(minval: TfVal, maxval: TfVal, *, shape) -> TfVal:
  """Implementation of lax.rng_uniform_p with tf.random.uniform, using the
  dtype of `minval`."""
  return tf.random.uniform(
      _eval_shape(shape), minval=minval, maxval=maxval, dtype=minval.dtype)
tf_impl[lax.rng_uniform_p] = _rng_uniform
def _iota_2x32_shape(*, shape):
  """Implementation of prng.iota_2x32_shape_p.

  Combines per-dimension uint64 iotas with prng.bcast_iotas_to_reshaped_iota
  and returns the result split into two uint32 tensors: (high 32 bits,
  low 32 bits).
  """
  def plus(x, y):
    return x + y

  def times(x, y):
    if not core.is_constant_dim(x):
      # A symbolic dimension must be materialized as a tensor shaped like y.
      x = tf.broadcast_to(tf.cast(_eval_shape((x,))[0], y.dtype), tf.shape(y))
    return x * y

  def to_uint32(xs):
    return tf.cast(xs, _to_tf_dtype(jnp.uint32))

  per_dim_iotas = [_iota(dtype=jnp.uint64, shape=shape, dimension=d)
                   for d in range(len(shape))]
  counts = prng.bcast_iotas_to_reshaped_iota(plus, times, shape, per_dim_iotas)
  counts_lo = to_uint32(counts)
  counts_hi = to_uint32(tf.bitwise.right_shift(counts, 32))
  return counts_hi, counts_lo
tf_impl[prng.iota_2x32_shape_p] = _iota_2x32_shape
def _gather_dimensions_proto(indices_shape, dimension_numbers):
  """Converts JAX gather dimension numbers to an XLA GatherDimensionNumbers.

  The index vector dimension is always the last dimension of the indices,
  so `indices_shape` must be non-empty.
  """
  proto = xla_data_pb2.GatherDimensionNumbers()
  # The proto fields carry the same names as the JAX dimension numbers.
  for field_name in ("offset_dims", "collapsed_slice_dims", "start_index_map"):
    getattr(proto, field_name).extend(getattr(dimension_numbers, field_name))
  assert indices_shape
  proto.index_vector_dim = len(indices_shape) - 1
  return proto
def _maybe_cast_to_int64(x: TfVal) -> TfVal:
  """Returns `x` unchanged if it is int32/int64, else `x` cast to int64."""
  if x.dtype in (tf.int32, tf.int64):
    return x
  return tf.cast(x, tf.int64)
@partial(handle_boolean_args, argnums=[0])
def _gather(operand, start_indices, *, dimension_numbers, slice_sizes: core.Shape,
            indices_are_sorted, unique_indices, mode, fill_value,
            _in_avals: Sequence[core.ShapedArray],
            _out_aval: core.ShapedArray):
  """Tensorflow implementation of gather.

  FILL_OR_DROP mode is delegated to JAX's `lax_slicing._gather_fill`. All other
  modes call `tfxla.gather` directly. For opaque dtypes, the gather is done on
  the physical representation: the trailing physical dimensions become extra
  offset dimensions and are sliced in full.
  """
  if mode == lax.GatherScatterMode.FILL_OR_DROP:
    gather_fill = _convert_jax_impl(lax_slicing._gather_fill,
                                    multiple_results=False)
    return gather_fill(
        operand, start_indices,
        dimension_numbers=dimension_numbers,
        slice_sizes=slice_sizes,
        unique_indices=unique_indices,
        indices_are_sorted=indices_are_sorted,
        fill_value=fill_value,
        output_shape=_out_aval.shape,
        _in_avals=_in_avals,
        _out_aval=_out_aval)

  start_indices = _maybe_cast_to_int64(start_indices)
  operand_aval = _in_avals[0]
  if dtypes.is_opaque_dtype(operand_aval.dtype):
    logical_rank = len(operand_aval.shape)
    opaque_shape = _jax_physical_aval(operand_aval).shape[logical_rank:]
    out_rank = len(_out_aval.shape)
    extra_offset_dims = tuple(range(out_rank, out_rank + len(opaque_shape)))
    dimension_numbers = dimension_numbers._replace(
        offset_dims=(*dimension_numbers.offset_dims, *extra_offset_dims))
    slice_sizes = (*slice_sizes, *opaque_shape)

  gather_proto = _gather_dimensions_proto(start_indices.shape, dimension_numbers)
  gathered = tfxla.gather(operand, start_indices, gather_proto,
                          _eval_shape(slice_sizes), indices_are_sorted)
  gathered = _ensure_tf_shape_if_dynamic(gathered, _aval_to_tf_shape(_out_aval))
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    gathered = tf.stop_gradient(gathered)  # See #7839
  return gathered


tf_impl_with_avals[lax.gather_p] = _gather
def _slice(operand, start_indices, limit_indices, strides, _in_avals,
           _out_aval):
  """Static (strided) slice implemented with Python indexing on the tensor.

  `strides=None` means a stride of 1 along every dimension.
  """
  if strides is None:
    strides = [1] * len(start_indices)
  starts = _eval_shape(start_indices)
  limits = _eval_shape(limit_indices)
  steps = _eval_shape(strides)
  index = tuple(slice(lo, hi, step) for lo, hi, step in zip(starts, limits, steps))
  sliced = operand[index]
  # TODO(b/184503314): improve shape inference for __getitem__
  # E.g., operand.shape=(b, 5, 3), start_indices=(0, 1, 1), limit_indices=(b, 5, 3), strides=(1, 2, 1)
  return _ensure_tf_shape_if_dynamic(sliced, _aval_to_tf_shape(_out_aval))


tf_impl_with_avals[lax.slice_p] = _slice
def _dynamic_slice(operand, *start_indices, slice_sizes: core.Shape,
                   _in_avals: Sequence[core.ShapedArray],
                   _out_aval: core.ShapedArray):
  """Lowers lax.dynamic_slice to tfxla.dynamic_slice.

  For opaque dtypes, the trailing physical dimensions are sliced in full,
  starting at offset 0.
  """
  indices = _maybe_cast_to_int64(tf.stack(start_indices))
  operand_aval = _in_avals[0]
  if dtypes.is_opaque_dtype(operand_aval.dtype):
    logical_rank = len(operand_aval.shape)
    opaque_shape = _jax_physical_aval(operand_aval).shape[logical_rank:]
    slice_sizes = (*slice_sizes, *opaque_shape)
    zero_starts = tf.zeros((len(opaque_shape),), dtype=indices.dtype)
    indices = tf.concat([indices, zero_starts], axis=0)
  sliced = tfxla.dynamic_slice(operand, indices,
                               size_indices=_eval_shape(slice_sizes))
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    sliced = tf.stop_gradient(sliced)  # See #7839
  return sliced


tf_impl_with_avals[lax.dynamic_slice_p] = _dynamic_slice
def _dynamic_update_slice(operand, update, *start_indices,
                          _in_avals: Sequence[core.ShapedArray],
                          _out_aval: core.ShapedArray):
  """Lowers lax.dynamic_update_slice to tfxla.dynamic_update_slice.

  For opaque dtypes, the start index along each trailing physical dimension
  is 0.
  """
  indices = _maybe_cast_to_int64(tf.stack(start_indices))
  operand_aval = _in_avals[0]
  if dtypes.is_opaque_dtype(operand_aval.dtype):
    logical_rank = len(operand_aval.shape)
    opaque_shape = _jax_physical_aval(operand_aval).shape[logical_rank:]
    zero_starts = tf.zeros((len(opaque_shape),), dtype=indices.dtype)
    indices = tf.concat([indices, zero_starts], axis=0)
  updated = tfxla.dynamic_update_slice(operand, update, indices)
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    updated = tf.stop_gradient(updated)  # See #7839
  return updated


tf_impl_with_avals[lax.dynamic_update_slice_p] = _dynamic_update_slice
def _scatter_dimensions_proto(indices_shape, dimension_numbers):
  """Converts JAX scatter dimension numbers to an XLA ScatterDimensionNumbers.

  The index vector dimension is always the last dimension of the indices,
  so `indices_shape` must be non-empty.
  """
  proto = xla_data_pb2.ScatterDimensionNumbers()
  # The proto fields carry the same names as the JAX dimension numbers.
  for field_name in ("update_window_dims", "inserted_window_dims",
                     "scatter_dims_to_operand_dims"):
    getattr(proto, field_name).extend(getattr(dimension_numbers, field_name))
  assert indices_shape
  proto.index_vector_dim = len(indices_shape) - 1
  return proto
def _scatter(operand, scatter_indices, updates, *, update_jaxpr, update_consts,
             dimension_numbers, indices_are_sorted, unique_indices, mode,
             _in_avals: Sequence[core.ShapedArray],
             _out_aval: core.ShapedArray):
  """Shared TF lowering for scatter, scatter_{min,max,mul,add}.

  The combining function is `update_jaxpr`, interpreted as a scalar TF
  function and passed to tfxla.scatter. In CLIP mode the indices are first
  clamped with JAX's `lax_slicing._clamp_scatter_indices`.
  """
  del unique_indices
  if mode == lax.GatherScatterMode.CLIP:
    clamp_indices = _convert_jax_impl(lax_slicing._clamp_scatter_indices,
                                      multiple_results=False)
    scatter_indices = clamp_indices(
        operand, scatter_indices, updates, dnums=dimension_numbers,
        _in_avals=_in_avals, _out_aval=_in_avals[1])

  assert len(update_consts) == 0, "Update computation cannot have constants"
  proto = _scatter_dimensions_proto(scatter_indices.shape, dimension_numbers)
  combiner_jaxpr = core.ClosedJaxpr(update_jaxpr, update_consts)

  def update_computation(arg1: TfVal, arg2: TfVal) -> TfVal:
    res, = _interpret_jaxpr(combiner_jaxpr, arg1, arg2, extra_name_stack=None)
    return res

  # The combiner operates on scalars of the operand dtype.
  scalar_spec = tf.TensorSpec((), dtype=operand.dtype)
  traced_update = tf.function(update_computation, autograph=False)
  xla_update_computation = traced_update.get_concrete_function(scalar_spec,
                                                               scalar_spec)
  scattered = tfxla.scatter(operand, scatter_indices, updates,
                            xla_update_computation, proto,
                            indices_are_sorted=indices_are_sorted)
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    scattered = tf.stop_gradient(scattered)  # See #7839
  return scattered


tf_impl_with_avals[lax.scatter_p] = _scatter
tf_impl_with_avals[lax.scatter_min_p] = _scatter
tf_impl_with_avals[lax.scatter_max_p] = _scatter
tf_impl_with_avals[lax.scatter_mul_p] = _scatter
tf_impl_with_avals[lax.scatter_add_p] = _scatter
def _cond(index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr],
          linear: Sequence[bool]) -> Sequence[TfVal]:
  """Lowers lax.cond (switch) to tf.switch_case over interpreted branches."""
  del linear
  # Same name stack as XLA translation of cond_p.
  # Note: extend_name_stack is a contextmanager, which is callable as a decorator.
  # A single context-manager object wraps every branch, as before.
  cond_scope = source_info_util.extend_name_stack("cond")
  # tf.switch_case needs zero-argument callables, so operands are bound here.
  branches_tf = []
  for branch_idx, branch_jaxpr in enumerate(branches):
    branch_fn = partial(_interpret_jaxpr, branch_jaxpr, *operands,
                        # Same name stack as the XLA translation of cond_p
                        extra_name_stack=f"branch_{branch_idx}_fun")
    branches_tf.append(cond_scope(branch_fn))  # type: ignore[arg-type]
  return tf.switch_case(index, branches_tf)


tf_impl[lax.cond_p] = _cond
def _while(*args: TfVal, cond_nconsts: int, cond_jaxpr: core.ClosedJaxpr,
           body_nconsts: int, body_jaxpr: core.ClosedJaxpr) -> Sequence[TfVal]:
  """Lowers lax.while_loop to tf.while_loop.

  `args` holds the condition constants, then the body constants, then the
  initial carry. A non-scalar predicate indicates a batched while, which is
  delegated to `_batched_cond_while`.
  """
  cond_consts, body_consts, init_carry = util.split_list(
      args, [cond_nconsts, body_nconsts])
  if cond_jaxpr.out_avals[0].shape:  # type: ignore[attr-defined]
    return _batched_cond_while(
        *args,
        cond_nconsts=cond_nconsts,
        cond_jaxpr=cond_jaxpr,
        body_nconsts=body_nconsts,
        body_jaxpr=body_jaxpr)

  def cond_tf_func(*carry: TfVal) -> TfVal:
    # TF needs the loop condition to return a single value.
    pred, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *carry,
                             # Same name stack as the XLA translation of while_p
                             extra_name_stack="while/cond")
    return pred

  body_tf_func = partial(_interpret_jaxpr, body_jaxpr, *body_consts,
                         extra_name_stack="while/body")
  # Pin each carry's shape to the body's output aval. Otherwise TF may infer a
  # more specific init_carry shape and fail with "enters the loop with shape
  # (1,), but has shape (None,) after one iteration".
  shape_invariants = [
      tf.TensorShape(_aval_to_tf_shape(carry_aval))
      for carry_aval in body_jaxpr.out_avals
  ]
  return tf.while_loop(cond_tf_func, body_tf_func, init_carry,
                       shape_invariants=shape_invariants)
def _batched_cond_while(*args: TfVal, cond_nconsts: int,
                        cond_jaxpr: core.ClosedJaxpr, body_nconsts: int,
                        body_jaxpr: core.ClosedJaxpr) -> Sequence[TfVal]:
  """Interprets a while_loop with a batched condition.

  A batched while has a conditional that returns a tensor of booleans, and
  a body that returns a list of tensors whose leading dimensions match those
  of the conditional tensor.

  We need to turn it into a while with scalar boolean conditional. We will
  expand the loop carry to include a prefix with the current tensor boolean
  condition. We prepend to the loop the first calculation of the tensor boolean
  condition. The loop condition will use a "reduce_any" to calculate a scalar
  boolean from the tensor boolean condition. The end of the loop body will
  compute the new carry using a "tf.where", and we compute the new tensor
  boolean condition.
  """
  cond_consts, body_consts, init_carry = util.split_list(
      args, [cond_nconsts, body_nconsts])
  # Initial computation of batched condition
  init_pred_b, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *init_carry,
                                  extra_name_stack="while/body_pred")

  def new_cond_tf_func(pred_b: TfVal, *carry: TfVal) -> TfVal:
    # Keep looping while the predicate is true for any batch element.
    pred = tf.reduce_any(pred_b, axis=list(range(len(pred_b.shape))))
    return pred

  def new_body_tf_func(pred_b: TfVal, *carry: TfVal) -> Sequence[TfVal]:
    new_carry: Sequence[TfVal] = _interpret_jaxpr(body_jaxpr, *body_consts,
                                                  *carry,
                                                  extra_name_stack="while/body")

    # Batch elements whose predicate is already false keep their previous
    # carry; the others take the newly computed one.
    def select_one_carry(new_c: TfVal, c: TfVal, c_aval: core.ShapedArray) -> TfVal:
      # NOTE(review): `shape` uses the physical shape while `_out_aval` uses the
      # logical shape of c_aval; confirm this is intended for opaque dtypes.
      pred_b_bcast = _broadcast_in_dim(
          pred_b,
          shape=_jax_physical_aval(c_aval).shape,  # a JAX shape
          broadcast_dimensions=list(range(len(pred_b.shape))),
          _in_avals=cond_jaxpr.out_avals,
          _out_aval=core.ShapedArray(c_aval.shape, np.bool_))
      return tf.where(pred_b_bcast, new_c, c)

    selected_carry: Sequence[TfVal] = list(
        map(select_one_carry, new_carry, carry, body_jaxpr.out_avals))
    # Use the same "while/" prefix as the initial predicate computation above,
    # instead of a bare "body_pred".
    next_pred_b, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *selected_carry,
                                    extra_name_stack="while/body_pred")
    return (next_pred_b, *selected_carry)

  _, *res_carry = tf.while_loop(new_cond_tf_func, new_body_tf_func,
                                (init_pred_b, *init_carry))
  return res_carry


tf_impl[lax.while_p] = _while
# scan is rewritten in terms of while by reusing JAX's scan impl rule.
tf_impl_with_avals[lax.scan_p] = _convert_jax_impl(
    lax_control_flow._scan_impl, extra_name_stack="scan")

# TODO: jax2tf cannot discriminate by platform, so the remat lowering is
# always requested with is_gpu_platform=False.
tf_impl_with_avals[ad_checkpoint.remat_p] = _convert_jax_impl(
    partial(ad_checkpoint.remat_lowering, is_gpu_platform=False),
    extra_name_stack="checkpoint",
    multiple_results=True,
)

# name_p returns its operand unchanged.
tf_impl[ad_checkpoint.name_p] = lambda x, *, name: x

# TODO: Remove once tensorflow is 2.10.0 everywhere.
if hasattr(tfxla, 'optimization_barrier'):
  tf_impl[lax_control_flow.optimization_barrier_p] = tfxla.optimization_barrier
def _top_k(operand: TfVal, k: int) -> Tuple[TfVal, TfVal]:
  """Lowers lax.top_k to tf.math.top_k (sorted results).

  Dtypes that are promoted below are cast to a wider type for tf.math.top_k.
  The returned values are cast back to the operand dtype; the indices are
  returned as produced by TF. A symbolic `k` is evaluated and cast to int32.
  """
  # Some types originally incompatible with tf.math.top_k can be promoted
  # to a compatible type without loss of precision.
  def promote_tf_dtype(tf_dtype):
    # Use equality (not identity) for DType comparisons, consistently.
    if tf_dtype in (tf.bool, tf.uint8, tf.uint16):
      return tf.uint32
    if tf_dtype in (tf.int8, tf.int16):
      return tf.int32
    if tf_dtype == tf.float16:
      return tf.float32
    return None

  conversion_dtype = promote_tf_dtype(operand.dtype)
  if core.is_special_dim_size(k):
    k_tf = _eval_shape((k,))[0]
    k_tf = tf.cast(k_tf, tf.int32)  # TopK works only for int32
  else:
    k_tf = k
  if conversion_dtype is not None:
    values, indices = tf.math.top_k(
        tf.dtypes.cast(operand, conversion_dtype), k=k_tf, sorted=True)
    return tf.dtypes.cast(values, operand.dtype), indices
  return tf.math.top_k(operand, k=k_tf, sorted=True)


tf_impl[lax.top_k_p] = _top_k
def _approx_top_k(operand: TfVal, k: int, reduction_dimension: int,
                  recall_target: float, is_max_k: bool,
                  reduction_input_size_override: int,
                  aggregate_to_topk: bool) -> Tuple[TfVal, TfVal]:
  """Lowers lax.approx_top_k to tf.math.approx_max_k / approx_min_k."""
  k_tf = _eval_shape((k,))[0]
  approx_fn = tf.math.approx_max_k if is_max_k else tf.math.approx_min_k
  return approx_fn(operand, k_tf, reduction_dimension, recall_target,
                   reduction_input_size_override, aggregate_to_topk)


tf_impl[lax.approx_top_k_p] = _approx_top_k
def _sort(*operands: TfVal, dimension: int, is_stable: bool,
          num_keys: int) -> Tuple[TfVal, ...]:
  """Lowers lax.sort to tfxla.variadic_sort.

  The comparator is JAX's own `_sort_lt_comparator`, converted to a TF
  function over 2 * len(operands) scalars.
  """
  assert 1 <= num_keys <= len(operands)
  ndim = len(operands[0].shape)
  assert 0 <= dimension < ndim, f"Invalid {dimension} for ndim {ndim}"

  comparator_spec: List[tf.TensorSpec] = []
  comparator_jax_in_avals: List[core.ShapedArray] = []
  for operand in operands:
    scalar_spec = tf.TensorSpec((), dtype=operand.dtype)
    scalar_aval = core.ShapedArray((), _to_jax_dtype(operand.dtype))
    comparator_spec += [scalar_spec, scalar_spec]
    comparator_jax_in_avals += [scalar_aval, scalar_aval]

  # Use the same comparator that JAX uses when compiling to XLA, to get the
  # proper NaN/Inf total order, and the lexicographic ordering.
  # The comparator is a 2N-argument TF function, with arguments [2k] and [2k +1]
  # corresponding to two scalars from operand[k].
  def lexicographic_comparator(*tf_args: TfVal) -> TfVal:
    lt_comparator = _convert_jax_impl(
        lax_internal._sort_lt_comparator, multiple_results=False)
    return lt_comparator(
        *tf_args,
        _in_avals=comparator_jax_in_avals,
        _out_aval=core.ShapedArray((), np.bool_),
        num_keys=num_keys)

  traced_comparator = tf.function(lexicographic_comparator, autograph=False)
  xla_comparator_computation = traced_comparator.get_concrete_function(
      *comparator_spec)
  results = tfxla.variadic_sort(
      operands,
      dimension=dimension,
      is_stable=is_stable,
      comparator=xla_comparator_computation)
  if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
    results = tuple(tf.stop_gradient(out) for out in results)  # See #7839
  return results


tf_impl[lax.sort_p] = _sort
def _fft(x, *, fft_type, fft_lengths,
         _in_avals: Sequence[core.ShapedArray],
         _out_aval: core.ShapedArray):
  """Lowers lax.fft to the tf.signal 1D/2D/3D FFT family.

  Only FFT lengths that match the operand's trailing dimensions are supported.

  Raises:
    NotImplementedError: if `fft_lengths` does not match the operand shape.
  """
  FFT, IFFT, RFFT, IRFFT = [xla_client.FftType(code) for code in range(4)]
  x_aval, = _in_avals
  x_shape = x_aval.shape
  num_axes = len(fft_lengths)
  if fft_type == IRFFT:
    # For IRFFT, an input last axis of size m corresponds to an output length
    # of 2 * (m - 1).
    expected_lengths = x_shape[-num_axes:-1] + ((x_shape[-1] - 1) * 2,)
  else:
    expected_lengths = x_shape[-num_axes:]
  if expected_lengths != fft_lengths:
    raise NotImplementedError(
        f"Unsupported {fft_lengths=} for {fft_type=} of "
        f"array with shape={x.shape}.")

  # Indexed by the number of FFT axes minus one.
  fns_by_rank = {
      FFT: (tf.signal.fft, tf.signal.fft2d, tf.signal.fft3d),
      IFFT: (tf.signal.ifft, tf.signal.ifft2d, tf.signal.ifft3d),
      RFFT: (tf.signal.rfft, tf.signal.rfft2d, tf.signal.rfft3d),
      IRFFT: (tf.signal.irfft, tf.signal.irfft2d, tf.signal.irfft3d),
  }
  transformed = fns_by_rank[fft_type][num_axes - 1](x)
  return _ensure_tf_shape_if_dynamic(transformed, _aval_to_tf_shape(_out_aval))


tf_impl_with_avals[lax.fft_p] = _fft
def _qr(operand, full_matrices):
  """QR decomposition via tf.linalg.qr."""
  q_and_r = tf.linalg.qr(operand, full_matrices=full_matrices)
  return q_and_r


tf_impl[lax.linalg.qr_p] = _qr
def _svd(operand, full_matrices, compute_uv):
  """SVD via tf.linalg.svd.

  When `compute_uv` is false, only TF's single result is returned, wrapped in
  a 1-tuple. Otherwise returns `(s, u, adjoint(v))`.
  """
  result = tf.linalg.svd(operand, full_matrices, compute_uv)
  if compute_uv:
    s, u, v = result
    return s, u, tf.linalg.adjoint(v)
  return (result,)


tf_impl[lax.linalg.svd_p] = _svd
def _eig(operand: TfVal, compute_left_eigenvectors: bool,
         compute_right_eigenvectors: bool):
  """Eigendecomposition via tf.linalg.eig / tf.linalg.eigvals.

  Raises:
    NotImplementedError: if both left and right eigenvectors are requested.
  """
  if compute_left_eigenvectors and compute_right_eigenvectors:
    # TODO(bchetioui): didn't find a 100% reliable, easy and satisfying way to
    # sort the left eigenvectors in the right order. The jax.numpy.linalg API
    # suggests to me that left eigenvectors are anyway seldom used, so I
    # think it is acceptable to leave as unimplemented for now.
    msg = ("Conversion of eig is not implemented when both "
           "compute_left_eigenvectors and compute_right_eigenvectors are set "
           "to True.")
    raise NotImplementedError(msg)
  if compute_right_eigenvectors:
    return tuple(tf.linalg.eig(operand))
  if compute_left_eigenvectors:
    # Left eigenvectors come from the eigendecomposition of adjoint(operand);
    # its eigenvalues are conjugated to recover the operand's eigenvalues.
    adjoint_eigvals, left_vecs = tf.linalg.eig(tf.linalg.adjoint(operand))
    return (tf.math.conj(adjoint_eigvals), left_vecs)
  return (tf.linalg.eigvals(operand),)


tf_impl[lax.linalg.eig_p] = _eig
def _eigh(operand: TfVal, lower: bool, sort_eigenvalues: bool, _in_avals,
          _out_aval):
  """Hermitian eigendecomposition via tf.linalg.eigh.

  Returns `(eigenvectors, eigenvalues)`. Eigenvalues of complex inputs are
  cast to the matching real dtype. When `lower` is false, the adjoint of the
  operand is decomposed instead.
  """
  del sort_eigenvalues
  if operand.shape[-1] == 0:
    # Empty matrix: the operand doubles as the eigenvectors, and the
    # eigenvalues are a reshape of it to the input shape without its last axis.
    eigvecs = operand
    eigvals = tf.reshape(operand, _eval_shape(_in_avals[0].shape[:-1]))
  else:
    matrix = operand if lower else tf.linalg.adjoint(operand)
    eigvals, eigvecs = tf.linalg.eigh(matrix)
  real_dtype = {
      tf.complex64: tf.float32,
      tf.complex128: tf.float64,
  }.get(operand.dtype)
  if real_dtype is not None:
    eigvals = tf.cast(eigvals, real_dtype)
  return eigvecs, eigvals


tf_impl_with_avals[lax.linalg.eigh_p] = _eigh
def _lu(operand: TfVal, _in_avals, _out_aval):
  """LU decomposition, lowered through JAX's `lax_linalg._lu_python`."""
  lu_impl = _convert_jax_impl(lax_linalg._lu_python, extra_name_stack="lu")
  return lu_impl(operand, _in_avals=_in_avals, _out_aval=_out_aval)


tf_impl_with_avals[lax.linalg.lu_p] = _lu
def _triangular_solve(a: TfVal, b: TfVal, *, left_side: bool, lower: bool,
                      transpose_a: bool, conjugate_a: bool, unit_diagonal: bool,
                      _in_avals: Sequence[core.ShapedArray],
                      _out_aval: core.ShapedArray):
  """Lowers lax.linalg.triangular_solve to tf.linalg.triangular_solve.

  A right-side solve is handled by swapping the last two axes of `a` and `b`,
  solving on the left with the opposite triangle, and swapping the result
  back.
  """
  if unit_diagonal:
    a_aval, _ = _in_avals
    a_shape = _eval_shape(a_aval.shape)
    a = tf.linalg.set_diag(a, tf.ones(a_shape[:-1], dtype=a.dtype))

  swap_last_two = None
  if not left_side:
    rank = len(a.shape)
    swap_last_two = [*range(rank - 2), rank - 1, rank - 2]
    a = tf.transpose(a, swap_last_two)
    b = tf.transpose(b, swap_last_two)
    lower = not lower

  # adjoint == transpose for real dtypes, so only complex types need an
  # explicit conjugation, when exactly one of transpose_a/conjugate_a is set.
  is_complex = a.dtype in (tf.complex64, tf.complex128)
  if is_complex and bool(transpose_a) != bool(conjugate_a):
    a = tf.math.conj(a)

  solution = tf.linalg.triangular_solve(a, b, lower=lower, adjoint=transpose_a)
  if swap_last_two is not None:
    solution = tf.transpose(solution, swap_last_two)
  return solution


tf_impl_with_avals[lax.linalg.triangular_solve_p] = _triangular_solve
def _linear_solve(*args: TfVal, const_lengths, jaxprs, _in_avals, _out_aval):
  """Lowers custom_linear_solve via JAX's `_custom_linear_solve_impl`."""
  solve_impl = _convert_jax_impl(lax_control_flow._custom_linear_solve_impl,
                                 extra_name_stack="linear_solve")
  return solve_impl(*args,
                    const_lengths=const_lengths,
                    jaxprs=jaxprs,
                    _in_avals=_in_avals,
                    _out_aval=_out_aval)


tf_impl_with_avals[lax.linear_solve_p] = _linear_solve
def _tridiagonal_solve(*args: TfVal, _in_avals, _out_aval, **params):
  """Lowers tridiagonal_solve via JAX's `_tridiagonal_solve_jax`.

  Extra primitive params are accepted and not forwarded.
  """
  del params
  solve_impl = _convert_jax_impl(lax_linalg._tridiagonal_solve_jax,
                                 multiple_results=False,
                                 extra_name_stack="tridiagonal_solve")
  return solve_impl(*args, _in_avals=_in_avals, _out_aval=_out_aval)


tf_impl_with_avals[lax.linalg.tridiagonal_solve_p] = _tridiagonal_solve
def _custom_jvp_call(*args: TfVal, call_jaxpr: core.ClosedJaxpr,
                     jvp_jaxpr_thunk: Callable,
                     num_consts: int) -> Sequence[TfVal]:
  """Interprets only the primal `call_jaxpr` of a custom_jvp function.

  The custom JVP rule (`jvp_jaxpr_thunk`) is ignored.
  """
  # TODO(necula): ensure that there is no AD transformation in scope
  del jvp_jaxpr_thunk, num_consts
  return _interpret_jaxpr(call_jaxpr, *args,
                          fresh_constant_cache=False,
                          extra_name_stack="custom_jvp")


tf_impl[custom_derivatives.custom_jvp_call_p] = _custom_jvp_call
def _custom_vjp_call_jaxpr(*args: TfVal, fun_jaxpr: core.ClosedJaxpr,
                           **_) -> Sequence[TfVal]:
  """Interprets only the primal `fun_jaxpr` of a custom_vjp function.

  All other primitive params are ignored.
  """
  # TODO(necula): ensure that there is no AD transformation in scope
  return _interpret_jaxpr(fun_jaxpr, *args,
                          fresh_constant_cache=False,
                          extra_name_stack="custom_vjp")


tf_impl[custom_derivatives.custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr
def _custom_lin(*args: TfVal, **_) -> Sequence[TfVal]:
  """Always raises: forward-mode AD of a custom_vjp function is unsupported.

  Raises:
    TypeError: unconditionally.
  """
  message = ("can't apply forward-mode autodiff (jvp) to a custom_vjp "
             "function.")
  raise TypeError(message)


tf_impl[ad.custom_lin_p] = _custom_lin
# A per-dimension partition count, or None for a replicated tensor.
PartitionsOrReplicated = Optional[Tuple[int, ...]]


def split_to_logical_devices(tensor: TfVal,
                             partition_dimensions: PartitionsOrReplicated):
  """Like TPUMPStrategy.experimental_split_to_logical_devices.

  For jax2tf purposes we want to avoid needing to thread the `strategy` object
  through the generated computation. It seems that the original function needs
  the strategy object only for error checking, which we assume is done upstream
  by JAX.

  Args:
    tensor: Input tensor to annotate.
    partition_dimensions: A sequence of integers, with one integer per tensor
      dimension, specifying in how many parts the dimension should be split. The
      product of integers must equal the number of devices per replica. If
      None, the tensor is annotated as replicated.

  Returns:
    The tensor annotated with an XLA sharding op (always emitted with
    `use_sharding_op=True`).
  """
  # TODO: this is only for sharded_jit. Either remove, or implement in terms
  # of _shard_values.
  if partition_dimensions is None:
    return xla_sharding.replicate(tensor, use_sharding_op=True)
  # Tile assignment: devices 0..N-1 laid out over the partition grid.
  num_partition_splits = math.prod(partition_dimensions)
  tile_assignment = np.arange(num_partition_splits).reshape(
      partition_dimensions)
  return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
def _shard_value(val: TfVal,
                 aval: core.ShapedArray,
                 sd: sharding.XLACompatibleSharding, *,
                 skip_replicated_sharding: bool) -> TfVal:
  """Apply sharding to a TfVal.

  Returns `val` unchanged when `sd` is unspecified, or when it is replicated
  and `skip_replicated_sharding` is set. Otherwise returns `val` annotated
  with an XLA sharding op.

  Raises:
    ValueError: if an annotation is needed while TF executes eagerly.
  """
  if sharding_impls.is_unspecified(sd):
    return val

  hlo_sharding_proto: xla_client.OpSharding = cast(
      xla_client.OpSharding, sd._to_xla_hlo_sharding(aval.ndim).to_proto())  # type: ignore
  is_replicated = op_shardings.is_op_sharding_replicated(hlo_sharding_proto) \
      if skip_replicated_sharding else False
  if is_replicated:
    return val

  # To use xla_sharding.py, we must have a xla_data_pb2.OpSharding.
  tf_sharding_proto: xla_data_pb2.OpSharding = xla_data_pb2.OpSharding(
      type=int(hlo_sharding_proto.type),
      tile_assignment_dimensions=hlo_sharding_proto.tile_assignment_dimensions,
      tile_assignment_devices=hlo_sharding_proto.tile_assignment_devices,
      replicate_on_last_tile_dim=hlo_sharding_proto.replicate_on_last_tile_dim,
      last_tile_dims=hlo_sharding_proto.last_tile_dims)
  if tf_context.executing_eagerly():
    raise ValueError(
        "A jit function with sharded arguments or results must be used under a `tf.function` context. "
        "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning for a discussion")
  annotated = xla_sharding.Sharding(proto=tf_sharding_proto)
  return annotated.apply_to_tensor(val, use_sharding_op=True)
def _pjit(*args: TfVal,
          jaxpr: core.ClosedJaxpr,
          in_shardings: Sequence[sharding.XLACompatibleSharding],
          out_shardings: Sequence[sharding.XLACompatibleSharding],
          resource_env: maps.ResourceEnv,
          donated_invars,
          name: str,
          keep_unused: bool,
          inline: bool,
          _in_avals: Sequence[core.ShapedArray],
          _out_aval: Sequence[core.ShapedArray]) -> Sequence[TfVal]:
  """Inlines a pjit jaxpr, annotating its arguments and results with shardings.

  When XLA is disabled (`_thread_local_state.enable_xla` is false), replicated
  shardings are not annotated. `resource_env`, `keep_unused` and `inline` are
  not used.

  Returns:
    A tuple with one (possibly annotated) TF value per jaxpr output.
  """
  del donated_invars
  # Apply sharding annotation to the arguments
  sharded_args: Sequence[TfVal] = tuple(
      map(partial(_shard_value,
                  skip_replicated_sharding=not _thread_local_state.enable_xla),
          args, _in_avals, in_shardings))
  results = _interpret_jaxpr(jaxpr, *sharded_args,
                             extra_name_stack=util.wrap_name(name, "pjit"),
                             fresh_constant_cache=False)
  # Apply sharding annotation to the results; already a tuple.
  return tuple(
      map(partial(_shard_value,
                  skip_replicated_sharding=not _thread_local_state.enable_xla),
          results, _out_aval, out_shardings))


tf_impl_with_avals[pjit.pjit_p] = _pjit
def _pjit_sharding_constraint(arg: TfVal, *,
                              sharding: sharding.NamedSharding,
                              resource_env: maps.ResourceEnv,
                              _in_avals: Sequence[core.ShapedArray],
                              _out_aval: core.ShapedArray,
                              **kwargs) -> TfVal:
  """Annotates `arg` with `sharding`; replicated shardings are not skipped."""
  arg_aval = _in_avals[0]
  return _shard_value(arg, arg_aval, sharding,
                      skip_replicated_sharding=False)


tf_impl_with_avals[pjit.sharding_constraint_p] = _pjit_sharding_constraint
def _dimension_size_jax2tf(op: TfVal, *, dimension, _in_avals, _out_aval):
  """Returns the runtime size of `op` along `dimension` in `_out_aval`'s dtype."""
  dim_tf = tf.shape(op)[dimension]
  if dim_tf.dtype == _to_tf_dtype(_out_aval.dtype):
    return dim_tf
  return _convert_element_type(dim_tf, new_dtype=_out_aval.dtype,
                               weak_type=_out_aval.weak_type)


tf_impl_with_avals[shape_poly.dimension_size_p] = _dimension_size_jax2tf
def _dim_as_value_jax2tf(dim: shape_poly.DimSize):
  """Evaluates a (possibly symbolic) dimension size to a TF value."""
  (dim_tf,) = _eval_shape((dim,))
  return dim_tf


tf_impl[shape_poly.dim_as_value_p] = _dim_as_value_jax2tf
def _reduce_precision(x, *, exponent_bits, mantissa_bits):
  """Lowers lax.reduce_precision to tfxla.reduce_precision."""
  precision = dict(exponent_bits=exponent_bits, mantissa_bits=mantissa_bits)
  return tfxla.reduce_precision(x, **precision)


tf_impl[lax.reduce_precision_p] = _reduce_precision
def _register_checkpoint_pytrees():
  """Registers TF custom container types as pytrees.

  Containers assigned as attributes of a tf.Module are swapped for TF wrapper
  types. This probes those wrapper types and registers them with
  jax.tree_util. Unflattening a tuple wrapper yields a plain tuple, a list
  wrapper a plain list, and a dict wrapper a new dict wrapper.
  """
  probe = tf.Module()
  # The types here are automagically changed by TensorFlow's checkpointing
  # infrastructure.
  probe.a = (tf.Module(), tf.Module())
  probe.b = [tf.Module(), tf.Module()]
  probe.c = {"a": tf.Module()}
  tuple_wrapper, list_wrapper, dict_wrapper = (
      type(probe.a), type(probe.b), type(probe.c))

  # TF AutoTrackable swaps container types out for wrappers.
  assert tuple_wrapper is not tuple
  assert list_wrapper is not list
  assert dict_wrapper is not dict

  def flatten_sequence(xs):
    return tuple(xs), None

  def flatten_mapping(s):
    return tuple(s.values()), tuple(s.keys())

  def unflatten_mapping(keys, xs):
    return dict_wrapper(zip(keys, xs))

  jax.tree_util.register_pytree_node(
      tuple_wrapper, flatten_sequence, lambda _, xs: tuple(xs))
  jax.tree_util.register_pytree_node(
      list_wrapper, flatten_sequence, lambda _, xs: list(xs))
  jax.tree_util.register_pytree_node(
      dict_wrapper, flatten_mapping, unflatten_mapping)


_register_checkpoint_pytrees()