# Source: Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/experimental/jax2tf/tests/jax2tf_test.py
# (1817 lines, 68 KiB, Python; snapshot dated 2023-06-19 00:49:18 +02:00)
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for JAX2TF converted.
Specific JAX primitive conversion tests are in primitives_test."""
import collections
import contextlib
import math
import os
import re
from typing import Callable, Dict, Optional, Tuple
import unittest
from absl import logging
from absl.testing import absltest, parameterized
import jax
from jax import ad_checkpoint
from jax import dtypes
from jax import lax
from jax import numpy as jnp
from jax import sharding
from jax._src import core
from jax._src import source_info_util
from jax._src import test_util as jtu
from jax._src.lib import xla_client
from jax._src.lib.mlir.dialects import stablehlo
import jax._src.xla_bridge
from jax import config
from jax.experimental import jax2tf
from jax.experimental.jax2tf import jax_export
from jax.experimental.jax2tf.tests import tf_test_util
from jax.experimental.maps import xmap
from jax.experimental.shard_map import shard_map
from jax.experimental import pjit
from jax.interpreters import mlir
from jax.sharding import PartitionSpec as P
import numpy as np
import tensorflow as tf # type: ignore[import]
# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import]
# pylint: enable=g-direct-tensorflow-import
# Let JAX configuration options be set from absl command-line flags
# (presumably so the suite can be run with e.g. --jax_enable_x64, which the
# tests below consult via `config.jax_enable_x64` — confirm with jax.config).
config.parse_flags_with_absl()
class Jax2TfTest(tf_test_util.JaxToTfTestCase):
def test_empty(self):
  """A function that just returns its first argument converts correctly."""
  def f_jax(x, y):
    del y  # Deliberately ignored.
    return x
  self.ConvertAndCompare(f_jax, 0.7, 1)
def test_sin(self):
  """Converted jnp.sin matches NumPy eagerly, under XLA, and on a device."""
  f_tf = jax2tf.convert(jnp.sin)
  x = np.float32(.5)
  sin_x = np.sin(x)
  # Eager execution of the converted function.
  self.assertAllClose(sin_x, f_tf(x))
  # Inside a tf.function compiled with XLA.
  self.assertAllClose(sin_x, tf.function(f_tf, autograph=False,
                                         jit_compile=True)(x))
  # First TPU if any, else first GPU if any, else the first logical device.
  tf_preferred_device = (
      tf.config.list_logical_devices("TPU")
      + tf.config.list_logical_devices("GPU")
      + tf.config.list_logical_devices()
  )[0]
  logging.info("Running TF on %s", tf_preferred_device)
  # The following, with jit_compile=False, fails with native serialization
  # because TF executes the function where it is instantiated (for example,
  # an XlaCallModule op on CPU). The workaround is to wrap the call and add
  # the device assignment inside the tf.function.
  @tf.function(autograph=False, jit_compile=False)
  def f_tf_wrapped(x):
    with tf.device(tf_preferred_device.name):
      return f_tf(x)
  with tf.device(tf_preferred_device.name):
    self.assertAllClose(sin_x, f_tf_wrapped(x))
def test_basics(self):
  """A composition of two unary primitives converts and matches JAX."""
  def f_jax(x):
    return jnp.sin(jnp.cos(x))
  self.ConvertAndCompare(f_jax, 0.7)
def test_input_output_naming(self):
  """Converted functions name their graph inputs/outputs `jax2tf_arg_*`/`jax2tf_out*`.

  The same names must also appear under the `jax2tf_vjp/` scope created for
  the gradient computation.
  """
  @jax2tf.convert
  def f(xs, y):
    return [jnp.add(x, y) for x in xs]

  @tf.function(autograph=False)
  def u(xs, y):
    xs = tf.nest.map_structure(tf.convert_to_tensor, xs)
    with tf.GradientTape() as tape:
      tf.nest.map_structure(tape.watch, xs)
      y = f(xs, y)
      tape.gradient(y, xs)
      return y

  graph = u.get_concrete_function([1., 2., 3.], 4.).graph
  # Three list elements plus `y` give four args; three results.
  forward_names = ["jax2tf_arg_0", "jax2tf_arg_1", "jax2tf_arg_2",
                   "jax2tf_arg_3", "jax2tf_out", "jax2tf_out_1",
                   "jax2tf_out_2"]
  for op_name in forward_names:
    graph.get_operation_by_name(op_name)
  # There must be no extra arguments or results.
  for missing_name in ("jax2tf_arg_4", "jax2tf_out_3"):
    with self.assertRaises(KeyError):
      graph.get_operation_by_name(missing_name)
  vjp_names = ["jax2tf_vjp/jax2tf_arg_0", "jax2tf_vjp/jax2tf_arg_1",
               "jax2tf_vjp/jax2tf_arg_2", "jax2tf_vjp/jax2tf_arg_3",
               "jax2tf_vjp/jax2tf_out", "jax2tf_vjp/jax2tf_out_1",
               "jax2tf_vjp/jax2tf_out_2", "jax2tf_vjp/jax2tf_out_3"]
  for op_name in vjp_names:
    graph.get_operation_by_name(op_name)
def test_pytrees(self):
  """Functions taking and returning nested tuple/dict pytrees convert correctly."""
  def f_jax(x: Tuple[float, Dict[str, float]]) -> Tuple[float, Dict[str, float]]:
    scalar, mapping = x
    doubled = scalar * 2.
    tripled = {key: val * 3. for key, val in mapping.items()}
    return doubled, tripled
  self.ConvertAndCompare(f_jax, (.7, {"a": .8, "b": .9}))
def test_variable_input(self):
  """A tf.Variable may be passed directly to a converted function."""
  def f_jax(x):
    return jnp.sin(jnp.cos(x))
  f_tf = jax2tf.convert(f_jax)
  var = tf.Variable(0.7, dtype=jax2tf.dtype_of_val(0.7))
  self.assertIsInstance(f_tf(var), tf.Tensor)
  expected = f_jax(0.7)
  self.assertAllClose(expected, f_tf(var))
def test_jit(self):
  """A jax.jit-decorated function converts and matches JAX."""
  @jax.jit
  def f_jax(x):
    return jnp.sin(jnp.cos(x))
  self.ConvertAndCompare(f_jax, 0.7)
def test_nested_jit(self):
  """A jit nested inside another jit converts correctly."""
  inner_cos = jax.jit(jnp.cos)

  @jax.jit
  def f_jax(x):
    return jnp.sin(inner_cos(x))

  self.ConvertAndCompare(f_jax, 0.7)
def test_nested_jit_pytree(self):
  """A jitted function whose single argument is a tuple converts correctly."""
  @jax.jit
  def f_jax(xy):
    return xy[0] + xy[1]
  self.ConvertAndCompare(f_jax, (0.7, 0.8))
def test_nested_jit_is_compiled(self):
  """Check whether a nested jax.jit lowers to tf.function(jit_compile=True).

  The check looks for an `_XlaMustCompile` attribute on any node of the
  traced TF function graph.
  """
  def node_must_compile(node):
    try:
      node.get_attr("_XlaMustCompile")
    except ValueError:
      return False
    return True

  def has_xla_must_compile(f_tf, x):
    concrete = tf.function(f_tf, autograph=True).get_concrete_function(
        tf.convert_to_tensor(x))
    return any(node_must_compile(node)
               for node in concrete.graph._nodes_by_id.values())

  x = np.array(0.7)

  def f_no_jit(x):
    return x
  self.assertFalse(has_xla_must_compile(jax2tf.convert(f_no_jit), x))

  def f_jit(x):
    return jax.jit(jnp.sin)(x)
  # TODO(b/207464757): TF compilation is disabled, so for now even the jitted
  # function is expected NOT to carry `_XlaMustCompile`.
  self.assertFalse(has_xla_must_compile(jax2tf.convert(f_jit), x))
def test_converts_jax_arrays(self):
  """JAX arrays, including jax.pmap outputs, are accepted by tf.function."""
  identity_tf = tf.function(lambda x: x)
  self.assertEqual(identity_tf(jnp.zeros([])).numpy(), 0.)
  self.assertEqual(identity_tf(jnp.ones([])).numpy(), 1.)
  double_tf = tf.function(lambda x: x + x)
  self.assertEqual(double_tf(jnp.ones([])).numpy(), 2.)

  # Arrays produced by jax.pmap (formerly ShardedDeviceArray).
  num_devices = jax.local_device_count()

  def make_sharded(init_fn):
    return jax.pmap(lambda x: x)(init_fn([num_devices]))

  identity_tf = tf.function(lambda x: x)
  for init_fn in (jnp.zeros, jnp.ones):
    self.assertAllClose(identity_tf(make_sharded(init_fn)).numpy(),
                        init_fn([num_devices]))
@jtu.skip_on_devices("gpu")
def test_bfloat16_passed_by_tf(self):
  """A converted function can be traced with bfloat16 TF input signatures."""
  def f_jax(a, b):
    return a + b
  bf16_spec = tf.TensorSpec([512, 512], tf.bfloat16)
  f_tf = tf.function(jax2tf.convert(f_jax), autograph=False,
                     input_signature=[bf16_spec, bf16_spec])
  self.assertIsNotNone(f_tf.get_concrete_function())
@jtu.skip_on_devices("gpu")
def test_bfloat16_returned_by_jax(self):
  """A bfloat16 JAX result surfaces as a tf.bfloat16 tensor."""
  def f_jax(a, b):
    total = a + b
    return total.astype(jnp.bfloat16)
  result = jax2tf.convert(f_jax)(1., 2.)
  self.assertEqual(result.dtype, tf.bfloat16)
@jtu.skip_on_devices("gpu")
def test_bfloat16_tf_grad(self):
  """A TF gradient through a converted function can be traced for bfloat16 inputs."""
  def f_jax(a, b):
    return a + b

  def _tf_grad(a, b):
    with tf.GradientTape() as tape:
      tape.watch(a)
      result = jax2tf.convert(f_jax)(a, b)
    return result, tape.gradient(result, a)

  bf16_spec = tf.TensorSpec([512, 512], tf.bfloat16)
  f_tf = tf.function(_tf_grad, autograph=False,
                     input_signature=[bf16_spec, bf16_spec])
  self.assertIsNotNone(f_tf.get_concrete_function())
@jtu.sample_product(
    dtype=[np.int64, np.float64],
    with_function=[True, False],
)
def test_converts_64bit(self, dtype=np.int64, with_function=False):
  """Values that need 64 bits survive conversion when x64 mode is on."""
  if not config.jax_enable_x64:
    self.skipTest("requires x64 mode")
  big_const = np.full((5,), 2 ** 33, dtype=dtype)
  self.ConvertAndCompare(jnp.sin, big_const)
  f_conv = jax2tf.convert(jnp.sin)
  if with_function:
    f_conv = tf.function(f_conv, autograph=False)
  expected = jnp.sin(big_const)
  # Feed the converted function both a tf.Variable and a tf.Tensor.
  for to_tf in (tf.Variable, tf.constant):
    self.assertAllClose(expected, f_conv(to_tf(big_const)))
def test_64bit_behavior_enable_x64_readme(self):
  """README examples of dtype behavior when JAX_ENABLE_X64=1."""
  if not config.jax_enable_x64:
    self.skipTest("requires x64 mode")

  def make_tf_sin():
    return tf.function(jax2tf.convert(jnp.sin), autograph=False)

  # With x64 enabled, JAX and TF disagree on the default float type.
  self.assertEqual(tf.math.sin(3.14).dtype, tf.float32)
  self.assertEqual(jnp.sin(3.14).dtype, jnp.float64)
  # jax2tf.convert follows JAX.
  self.assertEqual(jax2tf.convert(jnp.sin)(3.14).dtype, tf.float64)
  # A float64 variable makes `sin` compute in float64 ...
  self.assertEqual(
      make_tf_sin()(tf.Variable(3.14, dtype=tf.float64)).dtype, tf.float64)
  # ... while TF's default float32 variable makes it compute in float32.
  self.assertEqual(make_tf_sin()(tf.Variable(3.14)).dtype, tf.float32)
def test_64bit_behavior_not_enable_x64_readme(self):
  """README examples of dtype behavior when JAX_ENABLE_X64=0."""
  if config.jax_enable_x64:
    self.skipTest("requires not x64 mode")
  pi64 = np.float64(3.14)
  # Without x64, JAX and TF agree on the default float type ...
  self.assertEqual(tf.math.sin(3.14).dtype, tf.float32)
  self.assertEqual(jnp.sin(3.14).dtype, jnp.float32)
  self.assertEqual(tf.math.sin(pi64).dtype, tf.float64)
  # ... but JAX narrows explicit 64-bit inputs to 32 bits.
  self.assertEqual(jnp.sin(pi64).dtype, jnp.float32)
  # jax2tf.convert follows JAX for Python, NumPy and TF inputs alike.
  self.assertEqual(jax2tf.convert(jnp.sin)(3.14).dtype, tf.float32)
  self.assertEqual(jax2tf.convert(jnp.sin)(pi64).dtype, tf.float32)
  sin_fn = tf.function(jax2tf.convert(jnp.sin), autograph=False)
  self.assertEqual(sin_fn(tf.Variable(3.14, dtype=tf.float64)).dtype,
                   tf.float32)
def test_function(self):
  """A jitted function converts and matches JAX."""
  def f(x):
    return jnp.sin(jnp.cos(x))
  self.ConvertAndCompare(jax.jit(f), 0.7)
@jtu.sample_product(with_function=[False, True])
def test_gradients_disabled(self, with_function=False):
  """Differentiating a with_gradient=False conversion raises LookupError."""
  f_tf = jax2tf.convert(jnp.tan, with_gradient=False)
  if with_function:
    f_tf = tf.function(f_tf, autograph=False)
  x = tf.ones([])
  expected_msg = ("Gradient explicitly disabled.*"
                  "The jax2tf-converted function does not support gradients")
  # Under tf.function the error surfaces when evaluating f_tf(x); in eager
  # mode it surfaces at tape.gradient(y, x). Both happen inside this block.
  with self.assertRaisesRegex(LookupError, expected_msg):
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = f_tf(x)
    _ = tape.gradient(y, x)
@jtu.sample_product(with_function=[False, True])
def test_gradients(self, with_function=True):
  """Gradients of a two-output function w.r.t. two tf.Variable inputs."""
  def f(x, y):
    return x * x, x * y
  f_tf = jax2tf.convert(f, with_gradient=True)
  if with_function:
    f_tf = tf.function(f_tf, autograph=False)
  # Both variables share the default float dtype for a Python float.
  default_float_type = jax2tf.dtype_of_val(4.)
  x = tf.Variable(4., dtype=default_float_type)
  y = tf.Variable(5., dtype=default_float_type)
  # persistent=True because tape.gradient is called four times.
  with tf.GradientTape(persistent=True) as tape:
    u, v = f_tf(x, y)
  # u = x*x: du/dx = 2x, du/dy = 0.   v = x*y: dv/dx = y, dv/dy = x.
  self.assertAllClose(2. * 4., tape.gradient(u, x))
  self.assertAllClose(0., tape.gradient(u, y))
  self.assertAllClose(5., tape.gradient(v, x))
  self.assertAllClose(4., tape.gradient(v, y))
@jtu.sample_product(with_function=[False, True])
def test_gradients_pytree(self, with_function=False):
  """Gradients through a function taking a tuple and returning a dict."""
  def f(xy: Tuple[float, float]) -> Dict[str, float]:
    first, second = xy
    return dict(one=first * first, two=first * second)
  f_tf = jax2tf.convert(f, with_gradient=True)
  if with_function:
    f_tf = tf.function(f_tf, autograph=False)
  float_dtype = jax2tf.dtype_of_val(4.)
  x = tf.Variable(4., dtype=float_dtype)
  y = tf.Variable(5., dtype=float_dtype)
  with tf.GradientTape(persistent=True) as tape:
    uv = f_tf((x, y))
  # one = x*x, two = x*y.
  expectations = (("one", x, 2. * 4.), ("one", y, 0.),
                  ("two", x, 5.), ("two", y, 4.))
  for out_key, var, want in expectations:
    self.assertAllClose(want, tape.gradient(uv[out_key], var))
def test_custom_pytree_readme(self):
  """README example: a user-registered pytree class as a converted argument.

  Covers eager and XLA-compiled execution, a SavedModel round-trip through a
  wrapper that takes only plain tensors, and TF gradients vs. jax.grad.
  """
  # Code examples from README.md
  class CustomPair:
    def __init__(self, a, b):
      self.a = a
      self.b = b
  # Flatten to the children (a, b) with no auxiliary data; unflatten rebuilds.
  jax.tree_util.register_pytree_node(CustomPair,
                                     lambda x: ((x.a, x.b), None),
                                     lambda _, ab: CustomPair(*ab))
  def f_jax(pair: CustomPair):
    return np.float32(2.) * pair.a + np.float32(3.) * pair.b
  f_tf = jax2tf.convert(f_jax)
  x = CustomPair(np.float32(4.), np.float32(5.))
  res_jax = f_jax(x)
  # TF execution works as long as JAX can flatten the arguments and results
  res_tf = f_tf(x)
  self.assertAllClose(res_jax, res_tf.numpy())
  res_tf_2 = tf.function(f_tf, autograph=False, jit_compile=True)(x)
  self.assertAllClose(res_jax, res_tf_2)
  # Wrapped TF function that uses only standard containers, so that it can
  # be given a TensorSpec input signature and saved.
  def f_tf_wrapped(a, b):
    return f_tf(CustomPair(a, b))
  # Try to put into SavedModel
  my_model = tf.Module()
  # Save a function that can take scalar inputs.
  my_model.f = tf.function(f_tf_wrapped, autograph=False,
                           input_signature=[tf.TensorSpec([], tf.float32),
                                            tf.TensorSpec([], tf.float32)])
  # id(my_model) makes the directory name distinct per model instance.
  model_dir = os.path.join(absltest.get_default_test_tmpdir(), str(id(my_model)))
  tf.saved_model.save(my_model, model_dir,
                      options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
  # Restoring (note: the restored model does *not* require JAX to run, just XLA).
  restored_model = tf.saved_model.load(model_dir)
  def restored_f(pair: CustomPair):
    return restored_model.f(pair.a, pair.b)
  res_tf_3 = restored_f(x)
  self.assertAllClose(res_jax, res_tf_3)
  # jax.grad returns a CustomPair of partial derivatives; compare with TF's
  # gradients w.r.t. one variable per pair field.
  grad_jax = jax.grad(f_jax)(x)
  x_v = [tf.Variable(x.a), tf.Variable(x.b)]
  with tf.GradientTape() as tape:
    res = f_tf_wrapped(*x_v)
  grad_tf = tape.gradient(res, x_v)
  self.assertAllClose(grad_jax.a, grad_tf[0])
  self.assertAllClose(grad_jax.b, grad_tf[1])
@jtu.sample_product(with_function=[False, True])
def test_gradients_with_ordered_dict_input(self, with_function=True):
  """Gradients w.r.t. tf.Variables passed inside a collections.OrderedDict."""
  def f(inputs):
    total = 0.0
    for leaf in inputs.values():
      total += jnp.sum(leaf)
    return total

  f_tf = jax2tf.convert(f, with_gradient=True)
  if with_function:
    f_tf = tf.function(f_tf, autograph=False)
  float_type = jax2tf.dtype_of_val(4.)
  x = tf.Variable([4.], dtype=float_type)
  y = tf.Variable([4., 5.], dtype=float_type)
  # Built with 'r' inserted before 'd' (not alphabetical order).
  inputs = collections.OrderedDict([('r', x), ('d', y)])
  with tf.GradientTape(persistent=True) as tape:
    u = f_tf(inputs)
  # The derivative of a sum is all-ones for every leaf.
  for var, expected in ((x, np.array([1.])), (y, np.array([1., 1.]))):
    self.assertAllClose(expected, tape.gradient(u, var).numpy())
@jtu.sample_product(with_function=[False, True])
def test_gradients_with_custom_jvp(self, with_function=True):
  """Check TF gradients for a function with a custom JVP rule.

  The rule reports d(x*x)/dx as 3*x rather than 2*x, so a gradient of 3*x
  shows the custom rule was honored.
  """
  @jax.custom_jvp
  def f(x):
    return x * x

  @f.defjvp
  def f_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    return f(x), 3. * x * x_dot  # tangent: 3 * x * x_t

  # Sanity-check the custom rule within JAX itself.
  self.assertAllClose(4. * 4., f(4.))
  self.assertAllClose(3. * 4., jax.grad(f)(4.))

  f_tf = jax2tf.convert(f, with_gradient=True)
  if with_function:
    f_tf = tf.function(f_tf, autograph=False)
  self.assertAllClose(4. * 4., f_tf(4.))

  x = tf.Variable(4., dtype=jax2tf.dtype_of_val(4.))
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = f_tf(x)
  self.assertAllClose(4. * 4., y)
  self.assertAllClose(3. * 4., tape.gradient(y, x))
@jtu.sample_product(with_function=[False, True])
def test_gradients_with_custom_vjp(self, with_function=True):
  """Check TF gradients for a function with a custom VJP rule.

  The backward pass scales the cotangent by 3*x (not 2*x), so a gradient of
  3*x shows the custom rule was honored.
  """
  @jax.custom_vjp
  def f(x):
    return x * x

  def f_fwd(x):  # a -> (b, residual)
    return f(x), 3. * x

  def f_bwd(residual, ct_b):  # (residual, CT b) -> (CT a,)
    return (residual * ct_b,)

  f.defvjp(f_fwd, f_bwd)

  # Sanity-check the custom rule within JAX itself.
  self.assertAllClose(4. * 4., f(4.))
  self.assertAllClose(3. * 4., jax.grad(f)(4.))

  f_tf = jax2tf.convert(f, with_gradient=True)
  if with_function:
    f_tf = tf.function(f_tf, autograph=False)
  self.assertAllClose(4. * 4., f_tf(4.))

  x = tf.Variable(4., dtype=jax2tf.dtype_of_val(4.))
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = f_tf(x)
  self.assertAllClose(4. * 4., y)
  self.assertAllClose(3. * 4., tape.gradient(y, x))
def test_gradient_with_float0_intermediate(self):
  """jax.grad through an integer-valued intermediate converts correctly."""
  def f(x, y):  # x is an int, y is a float
    return 2 * x + y

  def g(x):  # x: f32
    as_int = 3 * x.astype("int32")
    return 2. * f(as_int, x * 4.)

  self.ConvertAndCompare(jax.grad(g), 2.)
def test_gradient_with_float0_result(self):
  """An integer argument's float0 gradient converts to bool zeros in TF."""
  def f(x, y):  # x is an int, y is a float
    return 2 * x + y

  def g(x):  # x: i32
    return jnp.sum(2. * f(3 * x, 4. * jnp.array(x, jnp.dtype("float32"))))

  grad_g = jax.grad(g, allow_int=True)
  # Check both a Python scalar and a (3, 4) int32 array argument.
  for x in (2, np.ones((3, 4), dtype=np.int32)):
    d_dx_jax = grad_g(x)
    d_dx_tf = jax2tf.convert(grad_g)(x)
    self.assertEqual(d_dx_jax.dtype, dtypes.float0)
    self.assertAllClose(jnp.zeros(np.shape(d_dx_jax), np.bool_),
                        d_dx_tf.numpy())
@jtu.sample_product(with_function=[False, True])
def test_gradients_unused_argument_readme(self, with_function=False):
  """README example: gradients for unused and integer arguments.

  Compares native TF AD with AD of the jax2tf-converted function, under both
  UnconnectedGradients.NONE (the default) and UnconnectedGradients.ZERO.
  """
  # x1 and x3 are not used. x3 has integer type.
  def fn(x0, x1, x2, x3):
    return x0 * 0. + x2 * 2.
  # 13 (no decimal point) makes the last variable an integer variable.
  xs = [tf.Variable(x) for x in [10., 11., 12., 13]]
  with tf.GradientTape(persistent=True) as tape:
    res = fn(*xs)
  # Native TF: unused arguments get None gradients.
  g_tf_native = tape.gradient(res, xs)
  self.assertAllClose(g_tf_native[0].numpy(), np.float32(0.))
  self.assertIsNone(g_tf_native[1])
  self.assertAllClose(g_tf_native[2].numpy(), np.float32(2.))
  self.assertIsNone(g_tf_native[3])
  # Native TF with ZERO: unused arguments get zeros instead.
  g_tf_native_0 = tape.gradient(res, xs,
                                unconnected_gradients=tf.UnconnectedGradients.ZERO)
  self.assertAllClose(g_tf_native_0[0].numpy(), np.float32(0.))
  self.assertAllClose(g_tf_native_0[1].numpy(), np.float32(0.))
  self.assertAllClose(g_tf_native_0[2].numpy(), np.float32(2.))
  self.assertAllClose(g_tf_native_0[3].numpy(), np.int32(0))
  # Now with jax2tf.convert
  with tf.GradientTape(persistent=True) as tape:
    conv_fn = jax2tf.convert(fn, with_gradient=True)
    if with_function:
      conv_fn = tf.function(conv_fn, autograph=False)
    res = conv_fn(*xs)
  g_jax2tf = tape.gradient(res, xs)
  # Returns: 0., 0., 2., None
  # Note that the gradient for x1 is 0. (unlike native TF's None), while the
  # integer argument x3 still gets None.
  self.assertAllClose(g_jax2tf[0].numpy(), np.float32(0.))
  self.assertAllClose(g_jax2tf[1].numpy(), np.float32(0.))
  self.assertAllClose(g_jax2tf[2].numpy(), np.float32(2.))
  self.assertIsNone(g_jax2tf[3])
  # With ZERO, the converted function matches native TF's ZERO behavior.
  g_jax2tf = tape.gradient(res, xs,
                           unconnected_gradients=tf.UnconnectedGradients.ZERO)
  self.assertAllClose(g_jax2tf[0].numpy(), np.float32(0.))
  self.assertAllClose(g_jax2tf[1].numpy(), np.float32(0.))
  self.assertAllClose(g_jax2tf[2].numpy(), np.float32(2.))
  self.assertAllClose(g_jax2tf[3].numpy(), np.int32(0))
@jtu.sample_product(with_function=[False, True])
def test_gradients_int_argument(self, with_function=False):
  """Gradients w.r.t. a dict mixing float, int and bool leaves.

  An expanded version of test_gradients_unused_argument_readme; see
  https://github.com/google/jax/issues/6975. Each leaf kind appears as
  "used" (transformed into the result), "passthrough" (returned unchanged)
  and "unused" (dropped from the result). Native JAX AD, native TF AD and
  TF AD of the jax2tf-converted function are compared against each other.
  """
  state = dict(
      float_used=np.array([0.7, 0.9], dtype=np.float32),
      float_passthrough=np.float16(1.),
      float_unused=np.array([1.1, 2.2, 3.3], dtype=np.float32),
      int_used=np.int16(5),
      int_passthrough=np.int8(7),
      int_unused=np.array([1, 2, 3], dtype=np.uint32),
      bool_used=np.array([True, False, False, True], dtype=np.bool_),
      bool_passthrough=np.array([True, False, False, True, False], dtype=np.bool_),
      bool_unused=np.array([[True, False], [False, True]], dtype=np.bool_),
  )

  def jax_f(state):
    res = dict(state,
               float_used=2. * state["float_used"],
               int_used=3 * state["int_used"],
               bool_used=(state["bool_used"] == state["bool_used"]))
    del res["float_unused"]
    del res["int_unused"]
    del res["bool_unused"]
    return res

  args = (state,)
  res_jax = jax_f(*args)
  # Native JAX AD
  vjp_jax_fun, args_vjp = tf_test_util.TransformJaxVJP(jax_f, args, res_jax)
  grad_jax, = vjp_jax_fun(*args_vjp)

  def compare_with_overrides(*, what, expected, **expected_overrides):
    """Compare the gradient dict `what` against `expected`, key by key.

    `expected_overrides[k]` replaces the expected value for key `k`: the
    strings "ZERO", "ZERO_BOOL" and "ONE" build zeros / bool zeros / ones
    shaped like the actual value; any other value is used verbatim, and an
    expected value of None asserts that the actual gradient is None.
    """
    self.assertEqual(set(what.keys()), set(expected.keys()))
    for k, w in what.items():
      e = expected[k]
      if k in expected_overrides:
        if expected_overrides[k] == "ZERO":
          e = np.zeros_like(w)
        elif expected_overrides[k] == "ZERO_BOOL":
          e = np.zeros(np.shape(w), dtype=np.bool_)
        elif expected_overrides[k] == "ONE":
          e = np.ones_like(w)
        else:
          e = expected_overrides[k]
      if e is None:
        self.assertIsNone(w, msg=k)
        continue  # Nothing further to compare.
      self.assertIsNotNone(w, msg=k)
      # Keep a non-tensor actual value as-is; replacing it with `e` (as was
      # done previously) made the comparison below vacuous.
      w = w.numpy() if isinstance(w, tf.Tensor) else w
      e = e.numpy() if isinstance(e, tf.Tensor) else e
      try:
        self.assertAllClose(e, w, err_msg=k)
      except Exception:
        print(f"Failed at {k}")
        raise

  # Expected contents of `grad_jax` (documented here, not asserted):
  #   bool_passthrough, bool_unused, bool_used: float0 zeros, input shape
  #   float_passthrough: ones_like(state["float_passthrough"])
  #   float_unused: zeros_like(state["float_unused"])
  #   float_used: ones_like(state["float_used"]) * 2.
  #   int_passthrough, int_unused, int_used: float0 zeros, input shape

  # Now native TF gradients, only to test how native TF AD works
  _, (grad_tf_0,) = tf_test_util.ComputeTfValueAndGrad(
      jax_f, args, unconnected_gradients=tf.UnconnectedGradients.ZERO)
  compare_with_overrides(what=grad_tf_0,
                         expected=grad_jax,
                         float_unused="ZERO",
                         bool_used="ZERO", bool_passthrough="ONE", bool_unused="ZERO",
                         int_used="ZERO", int_passthrough="ONE", int_unused="ZERO")

  _, (grad_tf_None,) = tf_test_util.ComputeTfValueAndGrad(
      jax_f, args,
      unconnected_gradients=tf.UnconnectedGradients.NONE)
  compare_with_overrides(what=grad_tf_None,
                         expected=grad_tf_0,
                         float_unused=None, int_used=None, int_unused=None,
                         bool_used=None, bool_unused=None)

  f_tf_jax = jax2tf.convert(jax_f)
  if with_function:
    f_tf_jax = tf.function(f_tf_jax, autograph=False)

  _, (grad_tf_jax_0,) = tf_test_util.ComputeTfValueAndGrad(f_tf_jax, args)
  # Same results as TF native AD with tf.UnconnectedGradients.ZERO
  compare_with_overrides(what=grad_tf_jax_0,
                         expected=grad_tf_0,
                         int_passthrough="ZERO", bool_passthrough="ZERO")

  _, (grad_tf_jax_None,) = tf_test_util.ComputeTfValueAndGrad(
      f_tf_jax, args,
      unconnected_gradients=tf.UnconnectedGradients.NONE)
  compare_with_overrides(what=grad_tf_jax_None,
                         expected=grad_tf_0,
                         int_used=None, int_passthrough=None, int_unused=None,
                         bool_unused=None, bool_used=None, bool_passthrough=None)

  # Now convert the JAX gradient function itself.
  tf_vjp_jax_fun = jax2tf.convert(vjp_jax_fun)
  grad_tf_vjp_jax, = tf_vjp_jax_fun(*args_vjp)
  compare_with_overrides(what=grad_tf_vjp_jax,
                         expected=grad_tf_0,
                         bool_passthrough="ZERO_BOOL",
                         bool_unused="ZERO_BOOL", bool_used="ZERO_BOOL",
                         int_passthrough="ZERO_BOOL", int_unused="ZERO_BOOL",
                         int_used="ZERO_BOOL")
def test_readme_gradient_int(self):
  """README example: JAX vs. TF gradients w.r.t. an int16 argument.

  Only prints the results (no assertions); the comments record the values
  shown in the README.
  """
  x = np.array(2, dtype=np.int16)
  def f_jax(x):  # x: int16
    return x.astype(np.float32) * 2.
  print(jax.grad(f_jax, allow_int=True)(x))
  # returns a special `float0`: array((b'',), dtype=[('float0', 'V')])
  print(jax2tf.convert(jax.grad(f_jax, allow_int=True))(x))
  # returns a 0 with same shape as x, but with dtype int32
  # NOTE(review): test_gradient_with_float0_result compares the converted
  # float0 gradient against bool zeros; confirm the int32 claim above.
  def f_tf(x):  # x: int16
    return tf.cast(x, tf.float32) * 2.
  xv = tf.Variable(x)
  with tf.GradientTape(persistent=True) as tape:
    print(tape.gradient(f_tf(xv), xv))
    # returns None
    print(tape.gradient(f_tf(xv), xv,
                        unconnected_gradients=tf.UnconnectedGradients.ZERO))
    # returns 0 with the same shape and dtype as x
def test_convert_argument_non_callable_error(self):
  """jax2tf.convert rejects a non-callable argument with TypeError."""
  not_a_function = 5.
  with self.assertRaisesRegex(TypeError, "Expected a callable value"):
    jax2tf.convert(not_a_function)
def test_convert_argument_non_tensor_error(self):
  """Passing a non-array (here a function) to a converted function is a TypeError."""
  def identity(x):
    return x
  with self.assertRaisesRegex(TypeError,
                              "Argument.*is not a valid JAX type"):
    jax2tf.convert(identity)(lambda y: y)
def test_argument_eager_tensor(self):
  """The eager output of one converted function can feed another."""
  sin_tf = jax2tf.convert(jnp.sin)
  eager_result = sin_tf(1.)
  jax2tf.convert(jnp.cos)(eager_result)  # Must not raise.
def test_checkpoint_wrapper_types(self):
  """tf.Module-wrapped list/tuple/dict attributes still flatten as pytrees."""
  m = tf.Module()
  m.a = [tf.Module(), tf.Module()]
  m.b = (tf.Module(), tf.Module())
  m.c = {'a': tf.Module(), 'b': tf.Module()}
  # TF replaces the plain containers with its own tracking wrapper types ...
  for attr_name, plain_type in (("a", list), ("b", tuple), ("c", dict)):
    self.assertNotEqual(type(getattr(m, attr_name)), plain_type)
  # ... which JAX must still see as containers of two leaves each.
  for attr_name in ("a", "b", "c"):
    self.assertLen(jax.tree_util.tree_leaves(getattr(m, attr_name)), 2)
def test_issue_10586(self):
  """Variables in a dict passed to a converted call get same-shaped gradients."""
  class JaxModule(tf.Module):
    def __init__(self):
      self._params = {'w': tf.Variable(tf.ones([784, 10]), name='w'),
                      'b': tf.Variable(tf.ones([10]), name='b')}

    def __call__(self, x):
      def dense(p, x):
        return x @ p['w'] + p['b']
      return jax2tf.convert(dense)(self._params, x)

  net = JaxModule()
  images = tf.ones([1, 784])
  with tf.GradientTape() as tape:
    loss = tf.reduce_sum(net(images))
  params = tape.watched_variables()
  grads = tape.gradient(loss, params)
  for var, grad in zip(params, grads):
    self.assertEqual(var.shape, grad.shape, msg=var.name)
def test_custom_jvp(self):
  """Conversion of a function with a custom JVP, under several transforms."""
  @jax.custom_jvp
  def f(x):
    return x * x

  @f.defjvp
  def f_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    return f(x), 3. * x * x_dot

  for transform in (None, "jvp", "vmap", "jvp_vmap", "grad", "grad_vmap"):
    self.TransformConvertAndCompare(f, 0.7, transform)
def test_custom_vjp(self):
  """Conversion of a function with a custom VJP, under several transforms."""
  @jax.custom_vjp
  def f(x):
    return x * x

  def f_fwd(x):  # a -> (b, residual)
    return f(x), 3. * x

  def f_bwd(residual, ct_b):  # (residual, CT b) -> (CT a,)
    return (residual * ct_b,)

  f.defvjp(f_fwd, f_bwd)
  for transform in (None, "vmap", "grad", "grad_vmap"):
    self.TransformConvertAndCompare(f, 0.7, transform)
def test_remat(self):
  """The converted grad of a checkpointed function keeps rematerialization."""
  # The name `f` matters: it appears as "jax2tf_f_" in the expected HLO.
  def f(x1):
    return jnp.sin(jnp.sin(jnp.sin(x1)))

  remat_f = ad_checkpoint.checkpoint(f)
  # The computation of grad_f computes "sin" 5 times: 3 for the forward pass
  # and then 2 to rematerialize "x2" and "x3" in the backward pass.
  f_tf = jax2tf.convert(jax.grad(remat_f))
  f_tf_hlo = self.TfToHlo(f_tf, np.array(3.))
  if jax.config.jax_remat_opt_barrier:
    expected_pattern = r"opt-barrier"
  else:
    expected_pattern = r'transpose/jax2tf_f_/jvp/checkpoint/cond/branch_1_fun/Sin'
  self.assertRegex(f_tf_hlo, expected_pattern)
def test_remat_free_var(self):
  """A checkpointed closure that only returns a captured value converts."""
  def f(x):
    y = 2 * x

    @ad_checkpoint.checkpoint
    def g():
      return y

    return g()

  for transform in (None, "grad"):
    self.TransformConvertAndCompare(f, 3., transform)
def test_checkpoint_name(self):
  """ad_checkpoint.checkpoint_name does not break conversion."""
  def f_jax(x):
    sin_x = jnp.sin(x)
    return ad_checkpoint.checkpoint_name(sin_x, "sin")
  jax2tf.convert(f_jax)(1.)  # Must not raise.
def test_convert_nullary_func(self):
  """Nullary functions are converted to TF rather than constant-folded in JAX."""
  def f_jax():
    return jnp.sin(1.)
  f_tf = jax2tf.convert(f_jax)
  # With native serialization, the HLO we get from TF is constant-folded, so
  # there is no "sine(" to look for.
  if config.jax2tf_default_native_serialization:
    return
  self.assertIn("sine(", self.TfToHlo(f_tf))
def test_convert_of_nested_independent_jit(self):
  """A nested jit whose explicit argument does not depend on the input."""
  def func(x):
    def inner1(y):
      return x + y
    jitted = jax.jit(inner1)
    # The jitted call's argument is a constant; `x` enters only via closure.
    return jitted(1.)
  jax2tf.convert(func)(2.)
def test_convert_of_nested_dependent_jit(self):
  """A nested jit whose argument is the outer input itself."""
  def func(x):
    def inner1(y):
      return x + y
    jitted = jax.jit(inner1)
    # The jitted call's argument depends on `x`.
    return jitted(x)
  jax2tf.convert(func)(2.)  # Must not raise.
def test_jit_unused(self):
  """jit(keep_unused=False) with an ignored argument converts correctly."""
  def f_jax(x, y_unused):
    return x * np.float32(2.)
  x = np.float32(5.)
  y_unused = np.arange(7, dtype=np.int32)
  f_tf = jax2tf.convert(jax.jit(f_jax, keep_unused=False))
  res_tf = f_tf(x, y_unused)
  # The second argument is irrelevant to the result.
  self.assertAllClose(f_jax(x, None), res_tf)
@parameterized.named_parameters(
    dict(testcase_name=mode, mode=mode)
    for mode in ("eager", "graph", "compiled"))
def test_jit_unused_grad(self, mode="eager"):
  """Gradients of a jit(keep_unused=False) function in eager/graph/XLA mode.

  The dropped argument must get a None gradient.
  """
  def f_jax(x, y_unused):
    return x * np.float32(2.)

  x = np.float32(5.)
  y_unused = np.arange(7, dtype=np.int32)
  res_jax = f_jax(x, y_unused)
  f_tf = jax2tf.convert(jax.jit(f_jax, keep_unused=False))
  x_tf = tf.constant(x)
  y_unused_tf = tf.constant(y_unused)

  def grad_tf(x, y_unused):
    with tf.GradientTape() as tape:
      tape.watch(x)
      tape.watch(y_unused)
      res_tf = f_tf(x, y_unused)
    grad_tf_x, grad_tf_y = tape.gradient(res_tf, (x, y_unused))
    return res_tf, grad_tf_x, grad_tf_y

  # "eager" runs grad_tf directly; the other modes wrap it in a tf.function.
  tf_function_extra_kwargs = {"graph": {}, "compiled": {"jit_compile": True}}
  if mode in tf_function_extra_kwargs:
    grad_tf = tf.function(grad_tf, autograph=False,
                          **tf_function_extra_kwargs[mode])
  res_tf, grad_tf_x, grad_tf_y = grad_tf(x_tf, y_unused_tf)
  self.assertAllClose(res_jax, res_tf)
  self.assertAllClose(np.float32(2.), grad_tf_x)
  self.assertIsNone(grad_tf_y)
def test_nested_convert_error(self):
  """jax2tf.convert called inside a converted function (on tracers) fails."""
  def outer(y):
    inner_tf = jax2tf.convert(jnp.sin)
    return inner_tf(y)  # Inner convert takes tracer args
  with self.assertRaisesRegex(
      ValueError, "convert must be used outside all JAX transformations"):
    jax2tf.convert(outer)(np.ones((4,), dtype=np.float32))
def test_nested_convert_error_non_tracer(self):
  """The inner convert takes non-tracer arguments, and still must fail."""
  def outer(y):
    inner_tf = jax2tf.convert(jnp.sin)
    sin_1 = inner_tf(1.)  # Inner convert takes non-tracer arg
    return y + sin_1
  with self.assertRaisesRegex(
      ValueError, "convert must be used outside all JAX transformations"):
    jax2tf.convert(outer)(2.)
@jtu.sample_product(transform=["jit", "jvp", "grad", "vmap"])
def test_convert_under_transform_error(self, transform="vmap"):
  """jax2tf.convert applied to tracers under a JAX transform is rejected."""
  def outer(y):
    inner_tf = jax2tf.convert(jnp.sin)
    return inner_tf(y)  # Inner convert takes tracer args
  with self.assertRaisesRegex(
      ValueError, "convert must be used outside all JAX transformations"):
    self.TransformConvertAndCompare(outer, np.ones((4,)), transform)
@jtu.sample_product(transform=["jit", "jvp", "grad", "vmap"])
def test_convert_under_transform_error_non_tracer(self, transform="vmap"):
  """jax2tf.convert on constants under a JAX transform is also rejected."""
  def outer(y):
    inner_tf = jax2tf.convert(jnp.sin)
    sin_1 = inner_tf(1.)  # Inner convert takes non-tracer arg
    return y + sin_1
  with self.assertRaisesRegex(
      ValueError, "convert must be used outside all JAX transformations"):
    self.TransformConvertAndCompare(outer, np.ones((4,)), transform)
def test_name_scope(self):
  """jax.named_call names show up in the converted HLO / TF graph."""
  def run_tf():
    @jax.named_call
    def my_test_function_jax(x):
      return x * x

    def caller_jax(x):
      return my_test_function_jax(jnp.sin(x))

    return jax2tf.convert(caller_jax, with_gradient=False)(2.)

  if config.jax2tf_default_native_serialization:
    self.assertIn("my_test_function_jax/mul", self.TfToHlo(run_tf))
    return
  concrete = tf.function(run_tf, autograph=False).get_concrete_function()
  graph_def = str(concrete.graph.as_graph_def())
  # The jit wrapper scope may be named either pjit_fn_ or jit_fn_.
  if "my_test_function_jax/pjit_fn_/Mul" in graph_def:
    return
  self.assertIn("my_test_function_jax/jit_fn_/Mul", graph_def)
def test_bfloat16_constant(self):
  """bfloat16 scalar and array constants convert correctly."""
  # Re: https://github.com/google/jax/issues/3942
  def jax_fn_scalar(x):
    as_bf16 = x.astype(jnp.bfloat16)
    return as_bf16 * 2.

  def jax_fn_array(x):
    as_bf16 = x.astype(jnp.bfloat16)
    return as_bf16 * np.array([1.5, 2.5, 3.5], jnp.bfloat16)

  tf_fn_scalar = jax2tf.convert(jax_fn_scalar)
  self.assertAllClose(tf_fn_scalar(1.375).numpy(), jnp.bfloat16(2.750))
  tf_fn_array = jax2tf.convert(jax_fn_array)
  expected_array = np.array([4.5, 10, 17.5], jnp.bfloat16)
  self.assertAllClose(tf_fn_array(np.array([3, 4, 5])), expected_array)
def test_shared_constants(self):
  """Repeated uses of one NumPy constant yield a single large TF constant."""
  # See https://github.com/google/jax/issues/7992.
  if config.jax2tf_default_native_serialization:
    raise unittest.SkipTest("shared constants tests not interesting for native serialization")
  const = np.random.uniform(size=256).astype(np.float32)  # A shared constant

  def f(x):
    return x + const + const + const + const

  found = self.FindLargeTfConstants(jax2tf.convert(f), const)
  self.assertLen(found, 1)
def test_shared_constants_under_cond(self):
  """A constant used inside a cond and again outside it is not duplicated."""
  # See https://github.com/google/jax/issues/7992.
  if config.jax2tf_default_native_serialization:
    raise unittest.SkipTest("shared constants tests not interesting for native serialization")
  const_size = 512
  const = np.random.uniform(size=const_size).astype(np.float32)  # A shared constant
  x = np.ones((const_size,), dtype=np.float32)

  def f1(x):
    # Ensure that we first see the constants in the inside jaxpr
    branch_add = lambda x: x + const
    branch_mul = lambda x: x * const
    return lax.cond(x[0] >= 0., branch_add, branch_mul, x) + const

  def f2(x):
    return f1(x) + const  # The extra const should not cost anything

  f1_consts, f2_consts = (
      self.FindLargeTfConstants(jax2tf.convert(fn), x, at_least=const_size)
      for fn in (f1, f2))
  self.assertLen(f2_consts, len(f1_consts))
def test_shared_constants_under_scan(self):
  # Constants used inside a lax.scan body and outside of it should be shared.
  # See https://github.com/google/jax/issues/7992.
  if config.jax2tf_default_native_serialization:
    raise unittest.SkipTest("shared constants tests not interesting for native serialization")
  const_size = 512
  const = np.random.uniform(size=const_size).astype(np.float32)  # A shared constant
  xs = np.ones((8, const_size), dtype=np.float32)

  def f1(xs):
    def step(carry, x):
      return carry + x + const, None
    init = jnp.zeros((const_size,), dtype=np.float32)
    final_carry, _ = lax.scan(step, init, xs)
    return final_carry

  def f2(xs):
    return f1(xs) + const  # The extra const should not be saved

  def large_consts_of(fn):
    return self.FindLargeTfConstants(jax2tf.convert(fn), xs, at_least=const_size)

  f1_consts = large_consts_of(f1)
  f2_consts = large_consts_of(f2)
  self.assertLen(f2_consts, len(f1_consts))
def test_shared_constants_under_jit(self):
  # We do not share constants under jit.
  # NOTE(review): despite the comment above, exactly one large constant is
  # asserted below; confirm which statement reflects the intended behavior.
  if config.jax2tf_default_native_serialization:
    raise unittest.SkipTest("shared constants tests not interesting for native serialization")
  const = np.random.uniform(size=(16, 16)).astype(np.float32)  # A shared constant

  @jax.jit
  def g_jit(x):
    return x * const

  def f(x):
    return g_jit(x) + const + const

  graph_consts = self.FindLargeTfConstants(jax2tf.convert(f), const)
  self.assertLen(graph_consts, 1)
def test_shared_constants_randint(self):
  # randint has the property that the TF lowering of the randbits_p
  # primitive generates constants that did not exist in the Jaxpr. As such
  # it has created new errors related to the sharing of the constants.
  # The converted function uses randint under cond, while_loop, vmap-of-cond
  # and at top level, and its result must match the JAX result.
  if config.jax2tf_default_native_serialization:
    raise unittest.SkipTest("shared constants tests not interesting for native serialization")
  key = jax.random.PRNGKey(42)

  def f_nested_jax(x):
    # Lowering this will generate a tf.constant(shape=(1,), dtype=np.int32)
    # that was not already in the Jaxpr, and hence JAX did not get a chance
    # to share.
    return x + jax.random.randint(key, shape=x.shape,
                                  minval=0, maxval=100, dtype=np.int32)

  def batched_cond(row):
    return lax.cond(row[0] >= 2,
                    lambda: f_nested_jax(row),
                    lambda: f_nested_jax(row))

  def f_jax(x):
    acc = lax.cond(x[0] >= 2, lambda: f_nested_jax(x), lambda: f_nested_jax(x))
    acc += lax.while_loop(lambda x: f_nested_jax(x)[0] <= 0, f_nested_jax, x)
    # We also generate tf.while in the batching rule for cond
    acc += jax.vmap(batched_cond)(jnp.stack([x, x]))
    acc += f_nested_jax(x)
    return acc

  # Must be odd to trigger the failure (presumably the length of x).
  x = np.array([123, 456, 789], dtype=np.int32)
  f_tf = tf.function(jax2tf.convert(f_jax), autograph=False)
  self.assertAllClose(f_tf(x), f_jax(x))
def test_weak_types(self):
  # The Python scalar `2.` passed to the jitted multiply should be weakly
  # typed, so multiplying a bfloat16 input must not promote the result.
  mul = jax.jit(jnp.multiply)
  tf_fn = jax2tf.convert(lambda x: mul(x, 2.))
  bf16_arg = tf.constant(1.375, tf.bfloat16)
  result = tf_fn(bf16_arg).numpy()
  self.assertAllClose(result, jnp.bfloat16(2.750))
@jtu.sample_product(with_function=[False, True])
def test_kwargs(self, with_function=False):
  # Re: https://github.com/google/jax/issues/6791
  # A converted function must be callable with keyword-only arguments.
  def f_jax(*, x):
    return jnp.sum(x)

  converted = jax2tf.convert(f_jax)
  f_tf = tf.function(converted, autograph=False) if with_function else converted
  result = f_tf(x=np.zeros(3, dtype=np.float32))  # Call with kwargs.
  self.assertAllClose(result, np.zeros((), dtype=np.float32))
@jtu.sample_product(with_function=[False, True])
def test_grad_kwargs(self, with_function=False):
  # Re: https://github.com/google/jax/issues/6791
  # Gradients must flow through a tuple passed as a keyword-only argument.
  x = (np.zeros(3, dtype=np.float32),
       np.zeros(4, dtype=np.float32))

  def f_jax(*, x=(1., 2.)):
    return jnp.sum(x[0]) + 2. * jnp.sum(x[1])

  converted = jax2tf.convert(f_jax)
  f_tf = tf.function(converted, autograph=False) if with_function else converted
  xv = tf.nest.map_structure(tf.Variable, x)
  with tf.GradientTape() as tape:
    res = f_tf(x=xv)
  grad_tf = tape.gradient(res, xv)

  expected = (np.full_like(x[0], fill_value=1.),
              np.full_like(x[1], fill_value=2.))
  actual = (grad_tf[0].numpy(), grad_tf[1].numpy())
  self.assertAllClose(expected, actual)
@jtu.skip_on_flag("jax2tf_default_native_serialization", True)
def test_enable_xla(self):
  # Tests that enable_xla flag is properly scoped to a conversion.
  def fun(x):
    # lax.reduce is unlikely to ever be convertible with enable_xla=False
    return lax.reduce(x, np.float32(0), lambda v, acc: v + acc, dimensions=(0, 1))

  no_xla_error = "Call to reduce cannot be converted with enable_xla=False"

  tf_fun_with_xla = jax2tf.convert(fun, enable_xla=True)
  tf_fun_without_xla = jax2tf.convert(fun, enable_xla=False)
  x = np.ones((2, 3), dtype=np.float32)

  self.assertAllClose(fun(x), tf_fun_with_xla(x))
  with self.assertRaisesRegex(NotImplementedError, no_xla_error):
    tf_fun_without_xla(x)

  # Now in reverse order (we had bugs with the management of enable_xla global)
  tf_fun2_without_xla = jax2tf.convert(lambda x: fun(x), enable_xla=False)
  tf_fun2_with_xla = jax2tf.convert(lambda x: fun(x), enable_xla=True)

  with self.assertRaisesRegex(NotImplementedError, no_xla_error):
    tf_fun2_without_xla(x)
  self.assertAllClose(fun(x), tf_fun2_with_xla(x))
def test_device_array_arg(self):
  # The argument is a JAX array (from jnp.zeros) rather than a NumPy array.
  jax_arg = jnp.zeros((2, 3), jnp.float32)
  self.ConvertAndCompare(jnp.sin, jax_arg)
def test_randint(self):
  # A nullary function drawing a scalar random integer converts correctly.
  def randint():
    prng_key = jax.random.PRNGKey(42)
    return jax.random.randint(prng_key, shape=(), minval=0, maxval=1)

  self.ConvertAndCompare(randint)
def test_error_disallowed_custom_call(self):
  # Native serialization must reject custom calls to unsupported targets.
  if jtu.device_under_test() != "cpu":
    self.skipTest("Test intended for CPU only")
  # For now triangular_solve on CPU uses the unsupported "blas_strsm" target
  a = np.arange(16, dtype=np.float32).reshape((4, 4))
  b = np.arange(4, dtype=np.float32).reshape((4, 1))

  def solve(a, b):
    return jax.lax.linalg.triangular_solve(a, b, left_side=True)

  expected_error = "Cannot serialize code with custom calls whose targets .*"
  with self.assertRaisesRegex(ValueError, expected_error):
    jax2tf.convert(solve, native_serialization=True)(a, b)
def test_op_metadata_simple(self):
  # Checks the op metadata (source file/line, op_name, op_type) attached to
  # the TF op produced for a single jnp.sin. Currently always skipped by the
  # unconditional skipTest below.
  self.skipTest("include_xla_op_metadata not yet enabled")
  # A simple example
  # The user_frame is used to compute line numbers for ops in the test.
  # NOTE: the `source_line` values below are offsets from the `user_frame`
  # line; do not add or remove lines between it and the ops it refers to.
  user_frame = source_info_util.user_frame(source_info_util.current())
  def f_simple(x):
    return jnp.sin(x)
  x = np.ones((2, 3), np.float32)
  self.CheckOpMetadata(
      f_simple, x,
      [tf_test_util.OpMetadataGraph(tf_type="Sin",
                                    source_file=__file__,
                                    source_line=user_frame.start_line + 2,
                                    op_name="jax2tf(f_simple)/sin",
                                    op_type="sin")
       ]
  )
def test_op_metadata_sub_jit(self):
  # Checks the op metadata for ops in a caller and in a jitted callee; the
  # callee's op_name is expected to include "jit(f_callee)". Currently always
  # skipped by the unconditional skipTest below.
  self.skipTest("include_xla_op_metadata not yet enabled")
  # Calling a jitted-function
  # The user_frame is used to compute line numbers for ops in the test.
  # NOTE: the `source_line` values below are offsets from the `user_frame`
  # line; do not add or remove lines between it and the ops it refers to.
  user_frame = source_info_util.user_frame(source_info_util.current())
  def f_callee(x):
    return jnp.cos(x)
  def f_caller(x):
    y = jnp.tanh(x)
    z = jax.jit(f_callee)(y)
    return jnp.sin(z)
  x = np.ones((2, 3), np.float32)
  self.CheckOpMetadata(
      f_caller, x,
      [tf_test_util.OpMetadataGraph(tf_type="Tanh",
                                    source_file=__file__,
                                    source_line=user_frame.start_line + 4,
                                    op_name="jax2tf(f_caller)/tanh",
                                    op_type="tanh"),
       tf_test_util.OpMetadataGraph(tf_type="Cos",
                                    source_file=__file__,
                                    source_line=user_frame.start_line + 2,
                                    op_name="jax2tf(f_caller)/jit(f_callee)/cos",
                                    op_type="cos"),
       tf_test_util.OpMetadataGraph(tf_type="Sin",
                                    source_file=__file__,
                                    source_line=user_frame.start_line + 6,
                                    op_name="jax2tf(f_caller)/sin",
                                    op_type="sin"),
       ]
  )
def test_op_metadata_named(self):
  # Checks the op metadata for ops in a caller and in a callee wrapped with
  # jax.named_call; the callee's op_name is expected to include
  # "named(callee)". Currently always skipped by the unconditional skipTest.
  self.skipTest("include_xla_op_metadata not yet enabled")
  # Calling a jax.named_call
  # The user_frame is used to compute line numbers for ops in the test.
  # NOTE: the `source_line` values below are offsets from the `user_frame`
  # line; do not add or remove lines between it and the ops it refers to.
  user_frame = source_info_util.user_frame(source_info_util.current())
  def f_callee(x):
    return jnp.cos(x)
  def f_caller(x):
    y = jnp.tanh(x)
    z = jax.named_call(f_callee, name="callee")(y)
    return jnp.sin(z)
  x = np.ones((2, 3), np.float32)
  self.CheckOpMetadata(
      f_caller, x,
      [tf_test_util.OpMetadataGraph(tf_type="Tanh",
                                    source_file=__file__,
                                    source_line=user_frame.start_line + 4,
                                    op_name="jax2tf(f_caller)/tanh",
                                    op_type="tanh"),
       tf_test_util.OpMetadataGraph(tf_type="Cos",
                                    source_file=__file__,
                                    source_line=user_frame.start_line + 2,
                                    op_name="jax2tf(f_caller)/named(callee)/cos",
                                    op_type="cos"),
       tf_test_util.OpMetadataGraph(tf_type="Sin",
                                    source_file=__file__,
                                    source_line=user_frame.start_line + 6,
                                    op_name="jax2tf(f_caller)/sin",
                                    op_type="sin"),
       ]
  )
def test_op_metadata_while_and_cond(self):
  # Checks the op metadata for ops inside a lax.while_loop body, including an
  # op inside a lax.cond branch nested in that body. Currently always skipped
  # by the unconditional skipTest below.
  self.skipTest("include_xla_op_metadata not yet enabled")
  # An example with while and cond
  # The user_frame is used to compute line numbers for ops in the test.
  # NOTE: the `source_line` values below are offsets from the `user_frame`
  # line; do not add or remove lines between it and the ops it refers to.
  user_frame = source_info_util.user_frame(source_info_util.current())
  def f_while_cond(x):
    def body_fun(i_acc):
      i, acc = i_acc
      return (i + 1,
              (jnp.cos(acc) +
               lax.cond(jnp.mod(i, 2) == 0,
                        lambda acc: jnp.sin(acc),
                        lambda acc: acc,
                        acc)))
    _, acc = lax.while_loop(
        lambda i_acc: i_acc[0] <= 5,
        body_fun, (0, x))
    return acc
  x = np.ones((2, 3), np.float32)
  self.CheckOpMetadata(
      f_while_cond, x,
      [tf_test_util.OpMetadataGraph(tf_type="Cos",
                                    source_file=__file__,
                                    source_line=user_frame.start_line + 5,
                                    op_name="jax2tf(f_while_cond)/while/body/cos",
                                    op_type="cos"),
       tf_test_util.OpMetadataGraph(tf_type="Sin",
                                    source_file=__file__,
                                    source_line=user_frame.start_line + 7,
                                    op_name="jax2tf(f_while_cond)/while/body/branch_1_fun/sin",
                                    op_type="sin"),
       tf_test_util.OpMetadataGraph(tf_type="FloorMod",
                                    source_file=__file__,
                                    source_line=user_frame.start_line + 6,
                                    op_name="jax2tf(f_while_cond)/while/body/rem",
                                    op_type="rem"),
       ]
  )
def test_op_metadata_batched_while(self):
  # Checks the op metadata for ops in the body and predicate of a
  # jax.vmap-ed lax.while_loop. Currently always skipped by the unconditional
  # skipTest below.
  self.skipTest("include_xla_op_metadata not yet enabled")
  # An example with a batched (vmap-ed) while_loop.
  # The user_frame is used to compute line numbers for ops in the test.
  # NOTE: the `source_line` values below are offsets from the `user_frame`
  # line (the blank line inside `f_while` is counted); keep the layout fixed.
  user_frame = source_info_util.user_frame(source_info_util.current())
  @jax.vmap
  def f_while(x):
    def body_fun(carry):
      new_carry = jnp.sin(carry)  # We look for "sin" in the graph
      return new_carry

    carry = lax.while_loop(
        lambda carry: jnp.all(carry <= x),  # We look for "le" in the graph
        body_fun, x)
    return carry
  # while_loop returns a value with the structure of its init value (here the
  # single array `x`), so the result is bound directly rather than unpacked.
  # NOTE(review): since sin(c) <= c for c >= 0, the loop predicate appears to
  # stay true for these non-negative inputs; confirm the computation is only
  # lowered/compiled and never executed.
  shape = (3, 2)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  jax_comp = jax.xla_computation(f_while)(x)
  backend = jax._src.xla_bridge.get_backend()
  modules = backend.compile(jax_comp).hlo_modules()
  jax_opt_hlo = modules[0].to_string()
  # Log (rather than print) the optimized HLO to aid debugging.
  logging.info("JAX OPT HLO = %s", jax_opt_hlo)
  self.CheckOpMetadata(
      f_while, x,
      [tf_test_util.OpMetadataGraph(tf_type="Sin",
                                    source_file=__file__,
                                    source_line=user_frame.start_line + 4,
                                    op_name="jax2tf(f_while)/while/body/sin",
                                    op_type="sin"),
       tf_test_util.OpMetadataGraph(tf_type="LessEqual",
                                    source_file=__file__,
                                    source_line=user_frame.start_line + 8,
                                    op_name="jax2tf(f_while)/while/body_pred/le",
                                    op_type="le"),
       ]
  )
def test_op_metadata_disabled(self):
  self.skipTest("include_xla_op_metadata not yet enabled")
  # With include_xla_op_metadata=False, an empty list of op metadata is
  # expected.
  def f_simple(x):
    return jnp.sin(x)

  x = np.ones((2, 3), np.float32)
  expected_metadata = []
  self.CheckOpMetadata(f_simple, x, expected_metadata,
                       include_xla_op_metadata=False)
def assertAllOperationStartWith(self, g: tf.Graph, scope_name: str):
  """Asserts that the name of every operation in `g` starts with `scope_name`.

  Fails if `g` has no operations at all. Only the prefix is checked; how many
  times `scope_name` occurs within an operation name is NOT checked.

  Args:
    g: the graph whose operations are checked.
    scope_name: the required prefix of every operation name.
  """
  operations = g.get_operations()
  if not operations:
    self.fail("result is empty.")
  for op in operations:
    logging.info("tf op.name = %s", op.name)
    if not op.name.startswith(scope_name):
      self.fail(f"{op.name} does not start with {scope_name}.")
def test_name_scope_polymorphic(self):
  # Ops created by a shape-polymorphic converted function, called directly or
  # through tf.function, must all live under the enclosing TF name scope.
  if config.jax2tf_default_native_serialization and not config.jax_dynamic_shapes:
    self.skipTest("shape polymorphism but --jax_dynamic_shapes is not set.")

  def func_jax(x, y):
    return jnp.sin(x) + jnp.cos(y)

  func_tf = jax2tf.convert(
      func_jax, polymorphic_shapes="(b,...)", with_gradient=True)
  outer_scope = "output_a"

  def build_graph(call):
    # Builds a fresh graph, invoking `call(x, y)` under `outer_scope`.
    graph = tf.Graph()
    with graph.as_default():
      with tf.name_scope(outer_scope):
        x = tf.Variable(
            tf.zeros(shape=(1, 5), dtype=tf.dtypes.float32), name="x")
        y = tf.compat.v1.placeholder(tf.dtypes.float32, (None, 5), "y")
        _ = call(x, y)
    return graph

  direct_graph = build_graph(func_tf)
  self.assertAllOperationStartWith(direct_graph, outer_scope)

  # wrap tf.function
  wrapped_graph = build_graph(
      lambda x, y: tf.function(func_tf, jit_compile=True, autograph=False)(x, y))
  self.assertAllOperationStartWith(wrapped_graph, outer_scope)
def test_name_scope_cond(self):
  # Checks that ops from converting nested lax.cond calls (forward pass), and
  # from taking their gradient, are placed under the enclosing TF name scope.
  def f(x):
    def f_pos(x):
      with jax.named_scope("jax_f_pos"):
        return lax.cond(x < 1., jnp.cos, jnp.sin, x)
    with jax.named_scope("jax_f_outer"):
      return lax.cond(x > 0., f_pos, lambda x: x, x)

  @tf.function(jit_compile=True, autograph=False)
  def outer_forward():
    with tf.name_scope("tf_outer_forward"):
      x = 0.5
      f_tf = jax2tf.convert(f)
      _ = f_tf(x)

  g = outer_forward.get_concrete_function().graph
  self.assertAllOperationStartWith(g, "tf_outer_forward")
  # `_functions` is a private tf.Graph attribute; presumably it holds the
  # function graphs nested in `g` (e.g. the cond branches) -- confirm.
  for func in g._functions.values():
    # NOTE: "jax2tf_f_" is presumably derived from the name of `f`, so
    # renaming `f` changes the expected prefix.
    self.assertAllOperationStartWith(
        func.graph, "tf_outer_forward/jax2tf_f_/jax_f_outer")

  # Same check for the gradient computation.
  x = tf.Variable(0.5, name="tf_outer_back/x")

  @tf.function(jit_compile=True, autograph=False)
  def outer_back():
    with tf.name_scope("tf_outer_back"):
      f_tf = jax2tf.convert(f)
      with tf.GradientTape() as tape:
        res_tf = f_tf(x)
      _ = tape.gradient(res_tf, x)

  g = outer_back.get_concrete_function().graph
  self.assertAllOperationStartWith(g, "tf_outer_back")
  for func in g._functions.values():
    self.assertAllOperationStartWith(func.graph, "tf_outer_back")
def test_name_scope_while_loop(self):
  # Converting lax.while_loop to tf.while must not repeat the enclosing name
  # scope inside the names of the generated ops.
  def f(x):
    with tf.name_scope("outer_scope"):
      def condition(x):
        return jnp.sum(x, keepdims=False) < 100
      def body(x):
        return jnp.add(x, 2.0)
      result = jax.lax.while_loop(condition, body, x)
      return result

  tf_f = tf.function(jax2tf.convert(f), jit_compile=True, autograph=False)
  g = tf_f.get_concrete_function(tf.zeros((1, 3))).graph
  # Loop-invariant: the scope that must occur at most once in any op name.
  while_scope = f"outer_scope/jax2tf_{f.__name__}_/while"
  # `_functions` is a private tf.Graph attribute; presumably it holds the
  # function graphs nested in `g` (e.g. the while cond/body) -- confirm.
  for func in g._functions.values():
    for op in func.graph.get_operations():
      if op.name.count(while_scope) > 1:
        # The literal pieces are joined with an explicit space so the message
        # does not read "tf.while.See".
        self.fail(
            "tf graph has repeated name issue when converting lax.while to "
            f"tf.while. See op.name = {op.name}")
@parameterized.named_parameters(
    dict(testcase_name=(
        f"{'with_mesh_' if with_mesh else ''}"
        f"2={transform2 if transform2 != 'none' else ''}"
        f"_1={transform1 if transform1 != 'none' else ''}"
        f"{'_nullary' if nullary else ''}"),
         with_mesh=with_mesh, transform1=transform1,
         transform2=transform2, nullary=nullary)
    # Test transform2(transform1(func))
    for transform1 in [
        "none",
        "jit",
        "pjit", "pjit_in_shardings_None", "pjit_in_shardings_P",
        "pjit_in_shardings_Sharding",
        "shard_map", "xmap", "pmap"]
    for transform2 in (
        ["none", "pjit_in_shardings_None", "pjit_in_shardings_P",
         "pjit_in_shardings_Sharding"]
    )
    # Whether the function can be nullary
    for nullary in (
        # To reduce the number of tests
        [True, False] if transform2 == "none" else
        [False])
    # Whether we use a "with mesh"
    # NOTE(review): "base" is not one of the transform1 values listed above;
    # presumably "none" was intended. As written, with_mesh=False is only
    # generated for transform1 in ("jit", "pjit") with transform2 == "none".
    for with_mesh in (
        [True] if (transform1 not in ["base", "jit", "pjit"] or
                   transform2 != "none") else
        [False, True])
)
def test_cross_platform(self, with_mesh=True, transform1="pjit_in_shardings_P",
                        transform2="pjit_in_shardings_P", nullary=False):
  # Tests cross-lowering for
  #  with mesh:
  #   transform2(transform1(func))
  # The function is exported for the 'tpu' platform regardless of the device
  # under test, and the exported StableHLO is checked for a TPU-specific op.
  if transform2 == "none" and (
      transform1 == "shard_map" or
      transform1 in ["pjit_in_shardings_P", "pjit_in_shardings_Sharding"] and nullary):
    raise unittest.SkipTest("Skip because must have pjit at top level")

  x = np.ones((4, 6), dtype=np.float32)
  mesh = sharding.Mesh(jax.devices()[:1], ("a",))
  # cummax has distinctive lowering for TPU, using a reduce-window op
  func = lambda x: lax.cummax(x, axis=0, reverse=False)
  # For shard_map we cannot use cummax :-( because it does not have a
  # replication rule. But we use lax.all_gather which on TPU is lowered with
  # an all-gather op
  func_shard_map = lambda x: lax.all_gather(x, 'a', axis=1, tiled=True)

  def apply_transform(func, transform: str):
    # Returns `func` wrapped by the transformation named `transform`.
    # NOTE: every entry of this dict is constructed on each call, although only
    # the one selected by `transform` is returned. The "jit_in_shardings_*"
    # entries are not among the parameterized transform names above.
    transformed_func = dict(
        none=func,
        jit=jax.jit(func),
        jit_in_shardings_None=jax.jit(func, in_shardings=None),  # type: ignore
        jit_in_shardings_P=jax.jit(func, in_shardings=(P("a"),)),  # type: ignore
        jit_in_shardings_Sharding=jax.jit(
            func, in_shardings=(sharding.NamedSharding(mesh, P("a")),)),  # type: ignore
        pjit=pjit.pjit(func),
        pjit_in_shardings_None=pjit.pjit(func, in_shardings=None,
                                         out_shardings=None),
        pjit_in_shardings_P=pjit.pjit(func, in_shardings=(P("a"),),
                                      out_shardings=P("a")),
        pjit_in_shardings_Sharding=pjit.pjit(
            func,
            in_shardings=(sharding.NamedSharding(mesh, P("a")),),
            out_shardings=sharding.NamedSharding(mesh, P("a"))),
        shard_map=(
            shard_map(func, mesh, in_specs=(P("a", None),),
                      out_specs=P("a", None))),
        xmap=xmap(func, in_axes=({0: 'axis'},),
                  out_axes={0: 'axis'}, axis_resources={'axis': 'a'}),
        pmap=jax.pmap(func, in_axes=0, out_axes=0),
    )[transform]
    return transformed_func

  transformed1_func = apply_transform(
      (func_shard_map if transform1 == "shard_map" else func),
      transform1)
  assert transform2 not in ["xmap", "shard_map"]
  transformed2_func = apply_transform(transformed1_func, transform2)

  # NOTE(review): transform2 is never "pjit" in the parameter list above, so
  # only transform2 == "none" can trigger this skip.
  if transform1 == "xmap" and transform2 in ["pjit", "none"]:
    raise unittest.SkipTest("TODO: pjit(xmap) with unspecified shardings crashes")

  if transform1 == "pmap":
    x = x.reshape((1, -1))  # Since we use 1 device
  if not nullary:
    func_to_convert = transformed2_func
    args = [x]
  else:
    func_to_convert = lambda: transformed2_func(jnp.ones(x.shape,
                                                         dtype=x.dtype))
    args = []

  # Every pmap case is skipped here.
  if transform1 == "pmap":
    if nullary:
      raise unittest.SkipTest("Cannot lower nested pmap: jit-of-pmap warning")
    raise unittest.SkipTest("TODO: figure out how to invoke pmap from TF")

  # NOTE(review): f_tf is constructed but not invoked anywhere in this test;
  # the checks below use jax_export directly.
  f_tf = jax2tf.convert(func_to_convert,
                        native_serialization=True,
                        native_serialization_platforms=('tpu',))
  f_tf = tf.function(f_tf, jit_compile=True, autograph=False)
  with contextlib.ExitStack() as stack:
    if with_mesh:
      stack.enter_context(mesh)
    # Run the JAX native version, to check it works, and to fill caches.
    _ = func_to_convert(*args)
    exported = jax_export.export(
        func_to_convert,
        lowering_platform='tpu',
        strict_checks=True
    )(*(core.ShapedArray(a.shape, a.dtype) for a in args))

  # shard_map uses all_gather; every other transform wraps cummax, which the
  # comments above say lowers to reduce-window on TPU.
  if transform1 == "shard_map":
    self.assertIn("stablehlo.all_gather", str(exported.mlir_module))
  else:
    self.assertIn("stablehlo.reduce_window", str(exported.mlir_module))
def test_cross_platform_error(self):
  # A function serialized only for TPU runs on TPU. Elsewhere the TF graph can
  # still be built, but calling the function raises NotFoundError.
  f_tf = jax2tf.convert(jnp.sin, native_serialization=True,
                        native_serialization_platforms=('tpu',))
  x = np.float32(.5)
  if jtu.device_under_test() == "tpu":
    self.assertAllClose(jnp.sin(x), f_tf(x))
    return

  # We can construct the tf.Graph
  f_tf_fun = tf.function(f_tf, jit_compile=True, autograph=False)
  graph_def = f_tf_fun.get_concrete_function(x).graph.as_graph_def()
  self.assertIn("XlaCallModule", str(graph_def))

  platform_error = ("The current platform .* is not among the platforms "
                    "required by the module")
  with self.assertRaisesRegex(tf.errors.NotFoundError, platform_error):
    f_tf(x)
@jtu.ignore_warning(message="using native_serialization_platforms without native_serialization")
def test_native_parameters_for_non_native(self):
  # We can use the native_serialization_platforms even for non-native
  # serialization.
  x = np.float32(.5)
  f_tf = jax2tf.convert(jnp.sin,
                        native_serialization_platforms=('cpu',))

  # Run the TF code on CPU
  cpu_devices = tf.config.list_logical_devices("CPU")
  self.assertNotEmpty(cpu_devices)
  with tf.device(cpu_devices[0]):
    self.assertAllClose(jnp.sin(x), f_tf(x))

  # native_serialization_strict_checks is likewise accepted.
  f_tf = jax2tf.convert(jnp.sin,
                        native_serialization_strict_checks=False)
  self.assertAllClose(jnp.sin(x), f_tf(x))
def test_native_serialization_grad(self):
  # Check that the grad function uses the same native serialization parameters
  # as the primal function.
  f_tf = jax2tf.convert(jnp.sin, native_serialization=True,
                        native_serialization_platforms=('tpu',))
  x = np.arange(4, dtype=np.float32)
  x_v = tf.Variable(x)

  @tf.function(autograph=False)
  def f_grad_tf(x_v):
    with tf.GradientTape() as tape:
      tape.watch(x_v)
      res_tf = f_tf(x_v)
    return tape.gradient(res_tf, x_v)

  # Make sure that we have 2x XlaCallModule in the graph of the gradient
  # function
  concrete = tf.function(f_grad_tf, autograph=False).get_concrete_function(x)
  graph_def = concrete.graph.as_graph_def()
  logging.info("Found graph_def: %s", graph_def)
  call_module_ops = re.findall(r'op:\s*"XlaCallModule"', str(graph_def))
  self.assertLen(call_module_ops, 2)

  if jtu.device_under_test() == "tpu":
    return
  with self.assertRaisesRegex(
      tf.errors.NotFoundError,
      r"The current platform .* is not among the platforms required by the module: \[TPU\]"):
    f_grad_tf(x_v)
def test_effects_error(self):
  # Native serialization must reject functions with host callbacks, both
  # unordered and ordered debug prints.
  def f_jax(x):
    jax.debug.print("{}", x)
    return jnp.sin(x)

  def f_ordered_jax(x):
    jax.debug.print("{}", x, ordered=True)
    return jnp.sin(x)

  for fun in (f_jax, f_ordered_jax):
    with self.assertRaisesRegex(NotImplementedError,
                                "serialization of host_callbacks is not yet implemented"):
      jax2tf.convert(fun, native_serialization=True)(np.float32(42.))
def test_tuple_args(self):
  # On TPU if we have more than 2000 arguments, we pass them as a tuple.
  # This is a compiler option, and should have no effect on lowering.
  if jtu.device_under_test() != "tpu":
    raise unittest.SkipTest("Test enabled on TPU only")

  def f_jax(*many_args):
    total = 0.
    for arg in many_args:
      total += arg
    return total

  many_args = [np.float32(i) for i in range(2001)]

  # Test that we do set lowered.compile_args[tuple_args]
  lowered = jax.jit(f_jax).lower(*many_args)
  self.assertTrue(lowered._lowering.compile_args["tuple_args"])

  tf_res = jax2tf.convert(f_jax, native_serialization=True)(*many_args)
  self.assertAllClose(f_jax(*many_args), tf_res)
def test_nested_convert(self):
  # Test call sequence: convert -> call_tf -> convert.
  @jax.jit
  def f_jax(x):
    return x + 1

  inputs = np.ones(10, dtype=np.float32)
  expected = f_jax(inputs)

  f_tf = jax2tf.convert(f_jax, native_serialization=True)
  self.assertAllClose(expected, f_tf(inputs))

  f_jax_nested = jax2tf.call_tf(f_tf)
  self.assertAllClose(expected, f_jax_nested(inputs))

  f_tf_nested = jax2tf.convert(f_jax_nested, native_serialization=True)
  self.assertAllClose(expected, f_tf_nested(inputs))
def get_serialized_computation(
    f_jax: Callable,
    *args,
    abstracted_axes: Optional[Tuple[Dict[int, str]]] = None,
    use_pjit: bool = False,
    in_shardings=None,
    out_shardings=None) -> Tuple[bytes, int]:
  """Lowers `f_jax` on `args` and serializes it as a portable StableHLO artifact.

  Args:
    f_jax: the function to lower.
    *args: the arguments used for lowering.
    abstracted_axes: forwarded to `jax.jit`; not supported with `use_pjit`.
    use_pjit: if True, lower with `pjit.pjit` using `in_shardings` and
      `out_shardings` instead of `jax.jit`.
    in_shardings: input shardings for `pjit.pjit` (only used with `use_pjit`).
    out_shardings: output shardings for `pjit.pjit` (only used with `use_pjit`).

  Returns:
    A pair of the serialized module (the value returned by
    `serialize_portable_artifact`) and the XlaCallModule version (5) to pass
    along with it.

  Raises:
    ValueError: if both `use_pjit` and `abstracted_axes` are given.
  """
  if use_pjit:
    if abstracted_axes:
      raise ValueError("abstracted_axes is not supported together with use_pjit")
    lowered = pjit.pjit(
        f_jax, in_shardings=in_shardings, out_shardings=out_shardings
    ).lower(*args)
  else:
    lowered = jax.jit(f_jax, abstracted_axes=abstracted_axes).lower(*args)
  mlir_module = lowered._lowering.stablehlo()
  xla_call_module_version = 5
  # module_to_bytecode produces the MLIR bytecode, not a textual module.
  mlir_bytecode = mlir.module_to_bytecode(mlir_module)
  if stablehlo.get_api_version() < 4:
    target_version = stablehlo.get_earliest_forward_compatible_version()
  else:
    # See comments next to the usage of stablehlo.get_minimum_version() in
    # jax_export.py for an explanation how it works.
    target_version = stablehlo.get_minimum_version()
  mlir_module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
      mlir_bytecode, target_version)
  return mlir_module_serialized, xla_call_module_version
class XlaCallModuleTest(tf_test_util.JaxToTfTestCase):
  """Unit tests for XlaCallModule. Will move these eventually to TF."""

  @staticmethod
  def _as_numpy(res):
    """Maps every TF tensor in the (nested) result `res` to its NumPy value."""
    return tf.nest.map_structure(lambda t: t.numpy(), res)

  def test_simple(self):
    def f_jax(x):
      return jnp.sin(x)

    x = np.ones((2, 3), dtype=np.float32)
    jax_res = f_jax(x)
    module, version = get_serialized_computation(f_jax, x)
    res = tfxla.call_module([x],
                            version=version,
                            module=module,
                            Tout=[jax_res.dtype],
                            Sout=[jax_res.shape])
    self.assertAllClose(self._as_numpy(res), [jax_res])

  def test_while(self):
    # With nested computation
    def f_jax(count, x):
      def cond_fun(carry):
        return carry[0] < count

      def body_fun(carry):
        return (carry[0] + 1, carry[1] + 1.)

      return lax.while_loop(cond_fun, body_fun, (0, x))[1]

    count = np.int32(5)
    x = np.ones((2, 3), dtype=np.float32)
    jax_res = f_jax(count, x)
    module, version = get_serialized_computation(f_jax, count, x)
    res = tfxla.call_module([count, x],
                            version=version,
                            module=module,
                            Tout=[jax_res.dtype],
                            Sout=[jax_res.shape])
    self.assertAllClose(self._as_numpy(res), [jax_res])

  def test_multiple_args_results(self):
    def f_jax(x1, x2):
      return (jnp.sin(x1), jnp.cos(x2))

    x1 = np.ones((2, 3), dtype=np.float32)
    x2 = np.ones((3, 4), dtype=np.float32)
    jax_res = f_jax(x1, x2)
    module, version = get_serialized_computation(f_jax, x1, x2)

    @tf.function(jit_compile=True, autograph=False)
    def f_tf(x1_tf, x2_tf):
      return tfxla.call_module([x1_tf, x2_tf],
                               version=version,
                               module=module,
                               Tout=[r.dtype for r in jax_res],
                               Sout=[r.shape for r in jax_res])

    res = f_tf(x1, x2)
    self.assertAllClose(self._as_numpy(res), jax_res)

  @jtu.with_mesh([("x", 2)])
  def test_pjit_basic1D(self):
    def func_jax(x, y):
      return x + y

    shape = (8, 10)
    x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    in_shardings = (P("x"), P("x"))
    out_shardings = None
    res_jax = pjit.pjit(
        func_jax,
        in_shardings=in_shardings,
        out_shardings=out_shardings,
    )(x, x)
    module, version = get_serialized_computation(
        func_jax, x, x,
        use_pjit=True,
        in_shardings=in_shardings,
        out_shardings=out_shardings)

    @tf.function(jit_compile=True, autograph=False)
    def f_tf(x_tf, y_tf):
      return tfxla.call_module([x_tf, y_tf],
                               version=version,
                               module=module,
                               Tout=[x.dtype],
                               Sout=[x.shape])

    res_tf = f_tf(x, x)[0]
    self.assertAllClose(res_tf.numpy(), res_jax)
@jtu.with_config(jax_enable_custom_prng=True)
class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
  """jax2tf conversion of random functions with jax_enable_custom_prng on."""

  def test_key_argument(self):
    # Passing raw key data is accepted, but warns with a FutureWarning.
    def func(key):
      return jax.random.uniform(key, ())

    key = jax.random.PRNGKey(0)
    key_raw = jax.random.key_data(key)
    with self.assertWarnsRegex(FutureWarning, "Raw arrays as random keys.*"):
      tf_result = jax2tf.convert(func)(key_raw)
    self.assertEqual(tf_result, func(key))

  def test_key_from_seed(self):
    def func(seed):
      return jax.random.uniform(jax.random.PRNGKey(seed), ())

    seed = 1701
    tf_result = jax2tf.convert(func)(seed)
    self.assertEqual(tf_result, func(seed))

  def test_key_closure(self):
    global_key = jax.random.PRNGKey(0)

    def func():
      # Include nontrivial shape operations to catch tracing bugs.
      key = global_key.reshape(1).squeeze()
      return jax.random.uniform(key)

    tf_result = jax2tf.convert(func)()
    self.assertEqual(tf_result, func())
if __name__ == "__main__":
  # TODO: Remove once tensorflow is 2.10.0 everywhere.
  # When this TF build lacks tfxla.optimization_barrier, turn off the
  # jax_remat_opt_barrier config option before running the tests.
  tf_has_opt_barrier = hasattr(tfxla, "optimization_barrier")
  if not tf_has_opt_barrier:
    jax.config.update("jax_remat_opt_barrier", False)
  absltest.main(testLoader=jtu.JaxTestLoader())