# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for JAX primitive coverage.

The bulk of the testing is done by `test_prim`, which is parameterized by
about 2000+ test harnesses. See the `primitive_harness.py` docstring for a
description of test harnesses. That module also contains the definitions of
all the test harnesses, and a specification of which are only partially
implemented for JAX.

For each harness, we convert the JAX function to Tensorflow and then we run
it on the same inputs in "eager", "graph", or "compiled" mode and we check
that we get the same result as in JAX
(see `tf_test_util.ConvertAndCompare`).

Some harnesses need specific tolerances, or even custom equality assertions.
Also, for some harnesses we need to specify some data types that result in
Tensorflow errors (for some devices and compilation modes). These
limitations are captured as jax2tf_limitations.Jax2TfLimitation objects.

From the limitations objects, we generate a
[report](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md).
The report has instructions for how to re-generate it.

If a harness run fails with an error, and a limitation that matches the
device and data types is found, the error is logged but does not abort the
test. If a harness run succeeds and there are matching limitations, the
test emits a warning.
If you want to turn these warnings into errors, you'd have to uncomment an assertion in `tf_test_util.ConvertAndCompare`. IMPORTANT: If you need to customize the testing of a particular primitive conversion, you must create a class method in jax2tf_limitations.jax2tf_limitations, with the same name as the harness.group_name (typically the same as the primitive name). That class method should return the list of Jax2TfLimitation objects for the harness. See `jax2tf_limitations.limitations_for_harness`. If a group name does not need limitations, then it must be listed in the `jax2tf_limitations.harness_groups_no_limitations`. """ import datetime import os from typing import Any, Dict, Tuple import unittest from absl import logging from absl.testing import absltest from absl.testing import parameterized import jax from jax import dtypes from jax import numpy as jnp from jax._src import test_util as jtu from jax import config from jax.experimental import jax2tf from jax.interpreters import mlir from jax._src.interpreters import xla import numpy as np import tensorflow as tf # type: ignore[import] config.parse_flags_with_absl() # Import after parsing flags from jax.experimental.jax2tf.tests import tf_test_util from jax.experimental.jax2tf.tests.jax2tf_limitations import Jax2TfLimitation from jax.experimental.jax2tf.tests import primitive_harness DType = Any REDUCE = ( jnp.all, jnp.any, jnp.max, jnp.min, jnp.prod, jnp.sum, ) class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase): # This test runs for all primitive harnesses. For each primitive "xxx" the # test will be called "test_prim_xxx_..." and the custom parameters for # the test are defined in the class method "jax2tf_limitations.Jax2TfLimitation.xxx". # See more details in the comment at top of file and in Jax2TfLimitation class. # If you want to run this test for only one harness, add parameter # `one_containing="foo"` to parameterized below. 
@primitive_harness.parameterized( primitive_harness.all_harnesses, include_jax_unimpl=False, #one_containing="", ) @jtu.ignore_warning( category=UserWarning, message="Using reduced precision for gradient.*") def test_prim(self, harness: primitive_harness.Harness): limitations = Jax2TfLimitation.limitations_for_harness(harness) device = jtu.device_under_test() limitations = tuple(filter(lambda l: l.filter(device=device, dtype=harness.dtype), limitations)) func_jax = harness.dyn_fun args = harness.dyn_args_maker(self.rng()) enable_xla = harness.params.get("enable_xla", True) if config.jax2tf_default_native_serialization and not enable_xla: raise unittest.SkipTest("native_serialization not supported with enable_xla=False") if ("eigh" == harness.group_name and np.complex64 == harness.dtype and device == "tpu"): raise unittest.SkipTest("b/264716764: error on tf.cast from c64 to f32") if (config.jax2tf_default_native_serialization and device == "gpu" and "lu" in harness.fullname): raise unittest.SkipTest("b/269388847: lu failures on GPU") def skipCustomCallTest(target: str): raise unittest.SkipTest( f"TODO(b/272239584): custom call target not guaranteed stable: {target}") if config.jax2tf_default_native_serialization: if device == "cpu": if "cholesky_shape" in harness.fullname: skipCustomCallTest("lapack_spotrf, lapack_dpotrf, lapack_zpotrf, lapack_cpotrf") if "eig_shape" in harness.fullname: skipCustomCallTest("lapack_cgeev, lapack_sgeev, lapack_dgeev, lapack_zgeev") if "lu_shape" in harness.fullname: skipCustomCallTest("lapack_zgetrf, lapack_sgetrf") if "svd_shape" in harness.fullname: skipCustomCallTest("lapack_sgesdd, lapack_zgesdd, lapack_cgesdd") if "triangular_solve_" in harness.fullname: skipCustomCallTest("blas_ctrsm, blas_dtrsm, blas_ztrsm, blas_strsm") if "custom_linear_solve" in harness.fullname: skipCustomCallTest("lapack_sgetrf, lapack_dgetrf") elif device == "gpu": if "custom_linear_solve_" in harness.fullname: skipCustomCallTest("cusolver_geqrf, 
cublas_geqrf_batched") if "svd_shape" in harness.fullname: skipCustomCallTest("cusolver_gesvdj") if "tridiagonal_solve_shape" in harness.fullname: skipCustomCallTest("cusparse_gtsv2_f32, cusparse_gtsv2_f64") associative_scan_reductions = harness.params.get("associative_scan_reductions", False) try: with jax.jax2tf_associative_scan_reductions(associative_scan_reductions): self.ConvertAndCompare(func_jax, *args, limitations=limitations, enable_xla=enable_xla) except Exception as e: # TODO(b/264596006): custom calls are not registered properly with TF in OSS if (config.jax2tf_default_native_serialization and "does not work with custom calls" in str(e)): logging.warning("Supressing error %s", e) raise unittest.SkipTest("b/264596006: custom calls in native serialization fail in TF") else: raise e def test_primitive_coverage(self): """Fail if there are JAX primitives that are not implemented.""" # Harvest primitives from XLA translation tables all_primitives = ( set(xla._translations) | set(xla._backend_specific_translations["cpu"]) | set(xla._backend_specific_translations["gpu"]) | set(xla._backend_specific_translations["tpu"]) | set(mlir._lowerings) | set(mlir._platform_specific_lowerings["cpu"]) | set(mlir._platform_specific_lowerings["gpu"]) | set(mlir._platform_specific_lowerings["tpu"])) tf_impl = set(jax.experimental.jax2tf.jax2tf.tf_impl) | set( jax.experimental.jax2tf.jax2tf.tf_impl_with_avals) tf_not_yet_impl = set(jax.experimental.jax2tf.jax2tf.tf_not_yet_impl) all_primitives = tuple(sorted(all_primitives, key=str)) for p in all_primitives: if p.name == "axis_index": continue if p.name == "sharding_constraint": continue # TODO: Remove once tensorflow is 2.10.0 everywhere. 
if p.name == "optimization_barrier": continue if p.name == "debug_callback": # TODO(sharadmv,necula): enable debug callbacks in TF continue if p.name in tf_not_yet_impl: self.assertNotIn( p, tf_impl) # Should not be in both tf_impl and tf_not_yet_impl else: self.assertIn(p, tf_impl) def test_generate_limitations_doc(self): """Generates primitives_with_limited_support.md. See the doc for instructions. """ harnesses = [ h for h in primitive_harness.all_harnesses if h.filter(h, include_jax_unimpl=True) ] print(f"Found {len(harnesses)} test harnesses that work in JAX") def unique_hash(h: primitive_harness.Harness, l: Jax2TfLimitation): return (h.group_name, l.description, l.devices, tuple(np.dtype(d).name for d in l.dtypes), l.modes) unique_limitations: Dict[Any, Tuple[primitive_harness.Harness, Jax2TfLimitation]] = {} for h in harnesses: for l in h.jax_unimplemented: if l.enabled: # Fake a Jax2TFLimitation from the Limitation tfl = Jax2TfLimitation(description="Not implemented in JAX: " + l.description, devices = l.devices, dtypes = l.dtypes, expect_tf_error = False, skip_tf_run = True) unique_limitations[hash(unique_hash(h, tfl))] = (h, tfl) for h in harnesses: for l in Jax2TfLimitation.limitations_for_harness(h): unique_limitations[hash(unique_hash(h, l))] = (h, l) print(f"Found {len(unique_limitations)} unique limitations") tf_error_table = [ """ | Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes | | --- | --- | --- | --- | --- |""" ] tf_numerical_discrepancies_table = list(tf_error_table) # a copy for h, l in sorted( unique_limitations.values(), key=lambda pair: unique_hash(*pair)): devices = ", ".join(sorted(l.devices)) modes = ", ".join(sorted(l.modes)) description = l.description if l.skip_comparison: description = "Numeric comparison disabled: " + description if l.expect_tf_error: description = "TF error: " + description if l.skip_tf_run: description = "TF test skipped: " + description if 
l.skip_tf_run or l.expect_tf_error: to_table = tf_error_table elif l.skip_comparison or l.custom_assert: to_table = tf_numerical_discrepancies_table else: continue to_table.append( f"| {h.group_name} | {description} | " f"{primitive_harness.dtypes_to_str(l.dtypes, empty_means_all=True)} | {devices} | {modes} |" ) if not os.environ.get("JAX_OUTPUT_LIMITATIONS_DOC"): raise unittest.SkipTest( "Set JAX_OUTPUT_LIMITATIONS_DOC=1 to enable the generation of the documentation" ) # The CPU has more supported types, and harnesses self.assertEqual("cpu", jtu.device_under_test()) self.assertTrue( config.x64_enabled, "Documentation generation must be run with JAX_ENABLE_X64=1") with open( os.path.join( os.path.dirname(__file__), "../g3doc/primitives_with_limited_support.md.template")) as f: template = f.read() output_file = os.path.join( os.path.dirname(__file__), "../g3doc/primitives_with_limited_support.md") with open(output_file, "w") as f: f.write(template.replace("{{generation_date}}", str(datetime.date.today())) \ .replace("{{tf_error_table}}", "\n".join(tf_error_table)) \ .replace("{{tf_numerical_discrepancies_table}}", "\n".join(tf_numerical_discrepancies_table)) \ ) # The rest of the test are checking special cases @parameterized.named_parameters( dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax) for f_jax in [jnp.add, jnp.subtract, jnp.multiply, jnp.divide, jnp.less, jnp.less_equal, jnp.equal, jnp.greater, jnp.greater_equal, jnp.not_equal, jnp.maximum, jnp.minimum]) def test_type_promotion(self, f_jax=jnp.add): # We only test a few types here, as tensorflow does not support many # types like uint* or bool in binary ops. 
types = [dtypes.bfloat16, np.int32, np.int64, np.float32] for x_dtype in types: for y_dtype in types: x = np.array([1, 2], dtype=x_dtype) y = np.array([3, 4], dtype=y_dtype) self.ConvertAndCompare(f_jax, x, y) def test_integer_div(self): x = jnp.array([-4, -3, -1, 0, 1, 3, 6]) y = np.int32(3) self.ConvertAndCompare(jnp.floor_divide, x, y) expected = jnp.floor_divide(x, y) if not config.jax2tf_default_native_serialization: # With native serialization TF1 seems to want to run the converted code # on the CPU even when the default backend is the TPU. # Try it with TF 1 as well (#5831) with tf.compat.v1.Session() as sess: tf1_res = sess.run(jax2tf.convert(jnp.floor_divide)(x, y)) self.assertAllClose(expected, tf1_res) def test_boolean_gather(self): values = np.array([[True, True], [False, True], [False, False]], dtype=np.bool_) indices = np.array([0, 1], dtype=np.int32) for axis in [0, 1]: f_jax = jax.jit(lambda v, i: jnp.take(v, i, axis=axis)) # pylint: disable=cell-var-from-loop self.ConvertAndCompare(f_jax, values, indices) def test_gather_rank_change(self): params = jnp.array([[1.0, 1.5, 2.0], [2.0, 2.5, 3.0], [3.0, 3.5, 4.0]]) indices = jnp.array([[1, 1, 2], [0, 1, 0]]) f_jax = jax.jit(lambda i: params[i]) self.ConvertAndCompare(f_jax, indices) @jtu.sample_product(f_jax=REDUCE) def test_reduce_ops_with_numerical_input(self, f_jax): values = np.array([1, 2, 3], dtype=np.float32) self.ConvertAndCompare(f_jax, values) @jtu.sample_product(op=["add", "max", "min", "multiply", "set"]) def test_scatter_static(self, op): values = np.ones((5, 6), dtype=np.float32) update = np.float32(6.) f_jax = jax.jit(lambda v, u: getattr(v.at[::2, 3:], op)(u)) self.ConvertAndCompare(f_jax, values, update) @jtu.sample_product(f_jax=REDUCE) def test_reduce_ops_with_boolean_input(self, f_jax): values = np.array([True, False, True], dtype=np.bool_) self.ConvertAndCompare(f_jax, values) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())