Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/experimental/jax2tf/tests/primitive_harness.py
2023-06-19 00:49:18 +02:00

3345 lines
120 KiB
Python

# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines test inputs and invocations for JAX primitives.
A primitive harness encodes one use case for one JAX numeric primitives. It
describes how to generate the inputs and parameters and how to invoke the
JAX primitive. A primitive harness can be used in multiple kinds of tests.
For example, we can use the harnesses to check
that each primitive is compiled correctly, or that we can apply a certain
transformation, e.g., `vmap`.
See the `Harness` class below for how to define a harness, describing one
use case of one primitive.
Some use cases are known to be partially implemented
in JAX, e.g., because of an implementation limitation. We do have harnesses
for those cases too, but there is a mechanism to filter them out.
Instead of writing this information as conditions inside one
particular test, we write them as `Limitation` objects that can be reused in
multiple tests and can also be used to generate documentation, e.g.,
the report of [unsupported and
partially-implemented JAX
primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md)
The limitations are used to filter out from tests the harnesses that are known
to fail. A Limitation is specific to a harness.
"""
import operator
import os
from functools import partial
from typing import (Any, Callable, Dict, Iterable, List, Optional,
NamedTuple, Sequence, Tuple, Union)
from absl import testing
import numpy as np
import jax
from jax import config
from jax import dtypes
from jax import lax
from jax import numpy as jnp
from jax._src import ad_util
from jax._src import dispatch
from jax._src import prng
from jax._src import test_util as jtu
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lax import windowed_reductions as lax_windowed_reductions
from jax._src.lib import xla_client
from jax._src import random as jax_random
FLAGS = config.FLAGS
Rng = Any # A random number generator
DType = Any
class RandArg(NamedTuple):
"""Descriptor for a randomly generated argument.
See description of `Harness`.
"""
shape: Tuple[int, ...]
dtype: DType
class StaticArg(NamedTuple):
"""Descriptor for a static argument.
See description of `Harness`.
"""
value: Any
class CustomArg(NamedTuple):
"""Descriptor for a dynamic argument generated from a PRNG factory.
See description of `Harness`.
"""
make: Callable[[Rng], Any] # Called with a Rng to make a tensor
ArgDescriptor = Union[RandArg, StaticArg, CustomArg, Any]
class Harness:
"""Specifies inputs and callable for a primitive.
See the module docstring for an introduction to harnesses.
A harness is conceptually a callable and a list of arguments, that together
exercise a use case. The harness can optionally have additional parameters
that can be used by the test.
The arguments are specified through argument descriptors. An argument
descriptor can be:
* a numeric value or ndarray, or
* an instance of ``RandArg(shape, dtype)`` to be used with a PRNG to
generate random tensor of the given shape and type, or
* an instance of ``CustomArg(fun)`` to be used with a PRNG, or
* an instance of ``StaticArg(value)``. Often these are the non-array
arguments, e.g., a shape.
The given callable will be passed one argument corresponding to each
argument descriptor, e.g., `harness.fun(* harness.args_maker(rng))`.
However, in many applications we only care about the non-static arguments.
For that purpose, you can use `harness.dyn_fun(*
harness.dyn_args_maked(rng))`,
where `harness.dyn_fun` is `harness.fun` specialized to the static arguments.
For example, a harness for ``lax.take(arr, indices, axis=None)`` may want
to expose as external (non-static) argument the array and the indices, and
keep the axis as a static argument (technically specializing the `take` to
a axis):
Harness(lax.slice_p,
f"take_axis={axis}",
lax.take,
[RandArg((2, 4), np.float32), np.array([-1, 0, 1]),
StaticArg(axis)],
axis=axis)
Each harness can have a list of Limitations that describe the cases when
the harness may not be fully implemented.
"""
# The group name most often is the primitive name.
group_name: str
# Descriptive name of the harness, used as a testcase_name. Unique in a group.
name: str
# The function taking all arguments (static and dynamic).
fun: Callable
# Describes how to construct arguments, see the class docstring.
arg_descriptors: Sequence[ArgDescriptor]
dtype: DType
# A set of limitations describing the cases that are not supported or
# partially implemented in JAX for this harness.
jax_unimplemented: Sequence["Limitation"]
rng_factory: Callable
# Carry some arbitrary parameters that the test can access.
params: Dict[str, Any]
def __init__(self,
group_name,
name,
fun,
arg_descriptors,
*,
dtype,
rng_factory=jtu.rand_default,
jax_unimplemented: Sequence["Limitation"] = (),
**params):
"""See class docstring."""
self.group_name = group_name
self.name = name
self.fun = fun # type: ignore[assignment]
self.arg_descriptors = arg_descriptors
self.rng_factory = rng_factory # type: ignore[assignment]
self.jax_unimplemented = jax_unimplemented
self.dtype = dtype
self.params = params
def __str__(self):
return self.fullname
@property
def fullname(self):
return self.name if self.group_name is None else f"{self.group_name}_{self.name}"
def _arg_maker(self, arg_descriptor, rng: Rng):
if isinstance(arg_descriptor, StaticArg):
return arg_descriptor.value
if isinstance(arg_descriptor, RandArg):
return self.rng_factory(rng)(arg_descriptor.shape, arg_descriptor.dtype)
if isinstance(arg_descriptor, CustomArg):
return arg_descriptor.make(rng)
return arg_descriptor
def args_maker(self, rng: Rng) -> Sequence:
"""All-argument maker, including the static ones."""
return [self._arg_maker(ad, rng) for ad in self.arg_descriptors]
def dyn_args_maker(self, rng: Rng) -> Sequence:
"""A dynamic-argument maker, for use with `dyn_fun`."""
return [
self._arg_maker(ad, rng)
for ad in self.arg_descriptors
if not isinstance(ad, StaticArg)
]
def dyn_fun(self, *dyn_args):
"""Invokes `fun` given just the dynamic arguments."""
all_args = self._args_from_dynargs(dyn_args)
return self.fun(*all_args)
def _args_from_dynargs(self, dyn_args: Sequence) -> Sequence:
"""All arguments, including the static ones."""
next_dynamic_argnum = 0
all_args = []
for ad in self.arg_descriptors:
if isinstance(ad, StaticArg):
all_args.append(ad.value)
else:
all_args.append(dyn_args[next_dynamic_argnum])
next_dynamic_argnum += 1
return all_args
def filter(self,
device_under_test: str,
*,
include_jax_unimpl: bool = False,
one_containing: Optional[str] = None) -> bool:
if not include_jax_unimpl:
if any([
device_under_test in l.devices
for l in self.jax_unimplemented
if l.filter(device=device_under_test, dtype=self.dtype)
]):
return False
if one_containing is not None and one_containing not in self.fullname:
return False
return True
def dtypes_to_str(dtype_list: Sequence[DType], empty_means_all=False) -> str:
"""User-friendly description of a set of dtypes"""
if not dtype_list and empty_means_all:
return "all"
names = {np.dtype(dt).name for dt in dtype_list}
signed = {"int8", "int16", "int32", "int64"}
if all([t in names for t in signed]):
names = (names - signed) | {"signed"}
integers = {"uint8", "uint16", "uint32", "uint64"}
if all([t in names for t in integers]):
names = (names - integers) | {"unsigned"}
integer = {"signed", "unsigned"}
if all([t in names for t in integer]):
names = (names - integer) | {"integer"}
floating = {"bfloat16", "float16", "float32", "float64"}
if all([t in names for t in floating]):
names = (names - floating) | {"floating"}
complex = {"complex64", "complex128"}
if all([t in names for t in complex]):
names = (names - complex) | {"complex"}
inexact = {"floating", "complex"}
if all([t in names for t in inexact]):
names = (names - inexact) | {"inexact"}
all_types = {"integer", "inexact", "bool"}
if all([t in names for t in all_types]):
names = (names - all_types) | {"all"}
return ", ".join(sorted(list(names)))
##### All harnesses in this file.
all_harnesses: List[Harness] = []
def define(
group_name, # Should be the primitive name, as much as possible
name,
fun,
arg_descriptors,
*,
dtype,
rng_factory=jtu.rand_default,
jax_unimplemented: Sequence["Limitation"] = (),
**params):
"""Defines a harness and stores it in `all_harnesses`. See Harness."""
group_name = str(group_name)
h = Harness(
group_name,
name,
fun,
arg_descriptors,
rng_factory=rng_factory,
jax_unimplemented=jax_unimplemented,
dtype=dtype,
**params)
all_harnesses.append(h)
class Limitation:
"""Encodes conditions under which a harness is limited, e.g., not runnable in JAX.
See the module docstring for an introduction to harnesses and limitations.
"""
def __init__(
self,
description: str,
*,
enabled: bool = True,
devices: Union[str, Sequence[str]] = ("cpu", "gpu", "tpu"),
dtypes: Sequence[DType] = (),
skip_run: bool = False,
):
"""Args:
description: text to augment the harness group name with the description
of the limitation. Used for reports.
enabled: whether this limitation is enabled for the harness in which
it appears. This is only used during testing to know whether to ignore
harness errors. Use this sparingly, prefer `devices` and
`dtypes` for enabled conditions that are included in reports.
devices: a device type (string) or a sequence of device types
for which this applies. By default, it applies to all devices types.
Used for filtering during harness execution, and for reports.
dtypes: the sequence of dtypes for which this applies. An empty sequence
denotes all dtypes. Used for filtering during harness execution, and
for reports.
skip_run: this harness should not even be invoked (typically because it
results in a crash). This should be rare.
"""
assert isinstance(description, str), f"{description}"
self.description = description
self.skip_run = skip_run
if isinstance(devices, str):
devices = (devices,)
else:
devices = tuple(devices)
self.devices = devices
assert isinstance(dtypes, Iterable)
dtypes = tuple(dtypes)
self.dtypes = dtypes
self.enabled = enabled # Does it apply to the current harness?
def __str__(self):
return (f"\"{self.description}\" devices={self.devices} "
f"dtypes={[np.dtype(dt).name for dt in self.dtypes]}" +
(" (skip_run) " if self.skip_run else ""))
__repr__ = __str__
def filter(self,
device: Optional[str] = None,
dtype: Optional[DType] = None) -> bool:
"""Check that a limitation is enabled for the given dtype and device."""
return (self.enabled and
(not self.dtypes or dtype is None or dtype in self.dtypes) and
(device is None or device in self.devices))
def parameterized(harnesses: Iterable[Harness],
*,
one_containing: Optional[str] = None,
include_jax_unimpl: bool = False):
"""Decorator for tests.
The tests receive a `harness` argument.
The `JAX_TEST_HARNESS_ONE_CONTAINING` environment variable is useful for
debugging. If given, then picks only one harness whose name contains the
string. The whole set of parameterized tests is reduced to one test,
whose name is not decorated to make it easier to pick with an IDE for
running.
"""
one_containing = one_containing or os.environ.get(
"JAX_TEST_HARNESS_ONE_CONTAINING")
cases = tuple(
# Change the testcase name to include the harness name.
dict(
testcase_name=harness.fullname if one_containing is None else "",
harness=harness) for harness in harnesses if harness.filter(
jtu.device_under_test(),
one_containing=one_containing,
include_jax_unimpl=include_jax_unimpl))
if one_containing is not None:
if not cases:
raise ValueError(
f"Cannot find test case with name containing {one_containing}."
"Names are:"
"\n".join([harness.fullname for harness in harnesses]))
cases = cases[0:1]
if not cases:
# We filtered out all the harnesses.
return jtu.skip_on_devices(jtu.device_under_test())
return testing.parameterized.named_parameters(*cases)
###############################################################################
### Harness definitions ###
###############################################################################
def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype):
define(
str(prim),
f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
prim.bind, [RandArg(shape, dtype)],
prim=prim,
dtype=dtype,
shape=shape)
for dtype in (set(jtu.dtypes.all) -
set(jtu.dtypes.all_unsigned + jtu.dtypes.boolean)):
_make_unary_elementwise_harness(prim=lax.abs_p, dtype=dtype)
for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex:
_make_unary_elementwise_harness(prim=lax.acosh_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.asinh_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.atanh_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.acos_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.atan_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.asin_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.cos_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.cosh_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.exp_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.expm1_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.log_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.log1p_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.rsqrt_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.sin_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.sinh_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype)
for dtype in jtu.dtypes.all_floating:
_make_unary_elementwise_harness(prim=lax.bessel_i0e_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.bessel_i1e_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.ceil_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.erf_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.erf_inv_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.erfc_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.floor_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.is_finite_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.lgamma_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.digamma_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.cbrt_p, dtype=dtype)
for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean):
_make_unary_elementwise_harness(prim=lax.neg_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.sign_p, dtype=dtype)
define(
str(lax.sign_p),
f"special_0_dtype={jtu.dtype_str(dtype)}",
lax.sign_p.bind, [StaticArg(np.zeros((2, 2), dtype=dtype))],
prim=lax.sign_p,
dtype=dtype)
def _make_round_harness(name,
*,
shape=(100, 100),
dtype=np.float32,
rounding_method=lax.RoundingMethod.AWAY_FROM_ZERO,
operand=None):
operand = operand if operand is not None else RandArg(shape, dtype)
define(
"round",
f"{name}_shape={jtu.format_shape_dtype_string(operand.shape, operand.dtype)}_roundingmethod={rounding_method}",
lax.round, [operand, StaticArg(rounding_method)],
dtype=dtype,
operand=operand,
rounding_method=rounding_method)
# Validate dtypes
for dtype in jtu.dtypes.all_floating:
_make_round_harness("dtypes", dtype=dtype)
for rounding_method in [
lax.RoundingMethod.AWAY_FROM_ZERO, lax.RoundingMethod.TO_NEAREST_EVEN
]:
operand = np.array([[0.5, 1.2, 1.5, 1.7, 2.5], [-0.5, -1.2, -1.5, -1.7, -2.5]], dtype=np.float32)
_make_round_harness(
"rounding_methods", operand=operand, rounding_method=rounding_method)
def _make_convert_element_type_harness(name,
*,
shape=(100, 100),
dtype=np.float32,
new_dtype=np.float32):
define(
"convert_element_type",
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_olddtype={jtu.dtype_str(dtype)}_newdtype={jtu.dtype_str(new_dtype)}",
lambda arg: (lax.convert_element_type_p.bind(
arg, new_dtype=np.dtype(new_dtype), weak_type=False)),
[RandArg(shape, dtype)],
shape=shape,
dtype=dtype,
new_dtype=new_dtype)
for old_dtype in jtu.dtypes.all:
# TODO(bchetioui): JAX behaves weirdly when old_dtype corresponds to floating
# point numbers and new_dtype is an unsigned integer. See issue
# https://github.com/google/jax/issues/5082 for details.
for new_dtype in (jtu.dtypes.all
if not (dtypes.issubdtype(old_dtype, np.floating) or
dtypes.issubdtype(old_dtype, np.complexfloating))
else set(jtu.dtypes.all) - set(jtu.dtypes.all_unsigned)):
_make_convert_element_type_harness(
"dtypes_to_dtypes", dtype=old_dtype, new_dtype=new_dtype)
def _make_integer_pow_harness(name, *, shape=(20, 30), dtype=np.int32, y=3):
define(
"integer_pow",
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{y=}",
lax.integer_pow,
[RandArg(shape, dtype), StaticArg(y)],
shape=shape,
dtype=dtype,
y=y)
for dtype in [d for d in jtu.dtypes.all if d not in jtu.dtypes.boolean]:
# Validate dtypes and y values for some special cases.
for y in range(-3, 5):
if np.issubdtype(dtype, np.integer) and y < 0:
continue # No negative powers for integers
_make_integer_pow_harness("dtypes", dtype=dtype, y=y)
# Validate overflow behavior by dtype. 127
_make_integer_pow_harness("overflow", y=127, dtype=dtype)
for dtype in jtu.dtypes.all_inexact:
# Validate negative y by dtype
_make_integer_pow_harness("negative_overflow", y=-127, dtype=dtype)
def _make_pow_harness(name,
*,
shapes=((20, 30), (20, 30)),
dtype=np.float32,
lhs=None,
rhs=None):
lhs = RandArg(shapes[0], dtype) if lhs is None else lhs
rhs = RandArg(shapes[1], dtype) if rhs is None else rhs
define(
"pow",
f"{name}_lhs={jtu.format_shape_dtype_string(lhs.shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs.shape, dtype)}",
lax.pow, [lhs, rhs],
lhs=lhs,
rhs=rhs,
dtype=dtype)
for dtype in jtu.dtypes.all_inexact:
# Validate dtypes
_make_pow_harness("dtypes", dtype=dtype)
# Validate broadcasting behavior
for shapes in [
((), (4, 5, 6)), # broadcasting lhs
((4, 5, 6), ()), # broadcasting rhs
((4, 1, 6), (4, 5, 6)), # broadcasting lhs on a specific axis
((4, 5, 6), (4, 1, 6)), # broadcasting rhs on a specific axis
]:
_make_pow_harness("broadcast", shapes=shapes)
def _make_reshape_harness(name,
*,
shape=(2, 3),
new_sizes=(3, 2),
dimensions=(0, 1),
dtype=np.float32):
define(
"reshape",
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_newsizes={new_sizes}_{dimensions=}",
lax.reshape,
[RandArg(shape, dtype),
StaticArg(new_sizes),
StaticArg(dimensions)],
shape=shape,
dtype=dtype,
new_sizes=new_sizes,
dimensions=dimensions)
# Validate dtypes
for dtype in jtu.dtypes.all:
_make_reshape_harness("dtypes", dtype=dtype)
# Validate new_sizes
for shape, new_sizes, dimensions in [
((3, 4, 5), (3, 20), (0, 1, 2)), # merging two dimensions
((3, 4, 5), (4, 15), (0, 1, 2)), # changing leading dimension
]:
_make_reshape_harness(
"new_sizes", shape=shape, new_sizes=new_sizes, dimensions=dimensions)
# Validate dimensions collapsing order
for shape, new_sizes, dimensions in [
((3, 4, 5), (3, 20), (2, 1, 0)), # transpose shape (0, 1, 2) into (2, 1, 0)
((3, 4, 5), (3, 20), (2, 0, 1)), # transpose shape (0, 1, 2) into (2, 0, 1)
]:
_make_reshape_harness(
"dimensions", shape=shape, new_sizes=new_sizes, dimensions=dimensions)
def _make_rev_harness(name, *, shape=(4, 5), dtype=np.float32, dimensions=(0,)):
define(
"rev",
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{dimensions=}",
lax.rev,
[RandArg(shape, dtype), StaticArg(dimensions)],
shape=shape,
dtype=dtype,
dimensions=dimensions)
# Validate dtypes
for dtype in jtu.dtypes.all:
_make_rev_harness("dtypes", dtype=dtype)
# Validate dimensions
for shape, dimensions in [
((3, 4, 5), ()), # empty dimensions
((3, 4, 5), (0, 2)), # some dimensions
((3, 4, 5), (0, 1, 2)), # all dimensions (ordered)
((3, 4, 5), (2, 0, 1)), # all dimensions (unordered)
]:
_make_rev_harness("dimensions", shape=shape, dimensions=dimensions)
def _make_device_put_harness(name,
*,
shape=(3, 4),
dtype=np.float32,
device=None):
_device_fn = lambda: jax.devices(device)[0] if device is not None else None
define(
"device_put",
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{device=}",
lambda x: dispatch.device_put_p.bind(x, device=_device_fn(), src=None),
[RandArg(shape, dtype)],
shape=shape,
dtype=dtype,
device=device)
# Validate dtypes
for dtype in jtu.dtypes.all:
_make_device_put_harness("dtypes", dtype=dtype)
# Validate devices
_make_device_put_harness("devices", device="cpu")
def _make_bitcast_convert_type_harness(name,
*,
shape=(2, 3),
dtype=np.float32,
new_dtype=np.float32):
define(
"bitcast_convert_type",
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_newdtype={np.dtype(new_dtype).name}",
lambda x: lax.bitcast_convert_type_p.bind(x,
new_dtype=np.dtype(new_dtype)),
[RandArg(shape, dtype)],
shape=shape,
dtype=dtype,
new_dtype=new_dtype)
def _can_bitcast(dtype, target_dtype):
def _to_equivalence_class(dtype):
if dtypes.issubdtype(dtype, np.integer):
return dtypes.iinfo(dtype).bits
elif dtypes.issubdtype(dtype, np.floating):
return dtypes.finfo(dtype).bits
else:
assert dtype == np.bool_ or dtypes.issubdtype(dtype, np.complexfloating)
# Complex and boolean types can only be cast to themselves
return np.dtype(dtype).name
return _to_equivalence_class(dtype) == _to_equivalence_class(target_dtype)
# Validate dtypes combinations
for dtype in jtu.dtypes.all:
for new_dtype in filter(partial(_can_bitcast, dtype), jtu.dtypes.all):
_make_bitcast_convert_type_harness(
"dtypes_to_new_dtypes", dtype=dtype, new_dtype=new_dtype)
def _make_add_any_harness(name, *, shapes=((2,), (2,)), dtype=np.float32):
define(
ad_util.add_any_p,
f"{name}_lhs={jtu.format_shape_dtype_string(shapes[0], dtype)}_rhs={jtu.format_shape_dtype_string(shapes[1], dtype)}",
ad_util.add_jaxvals_p.bind,
list(map(lambda s: RandArg(s, dtype), shapes)),
dtype=dtype,
shapes=shapes)
for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean):
_make_add_any_harness("dtypes", dtype=dtype)
for dtype in jtu.dtypes.all:
shape: Tuple[int, ...] = (20, 20)
define(
ad_util.stop_gradient_p,
f"{jtu.format_shape_dtype_string(shape, dtype)}",
ad_util.stop_gradient_p.bind, [RandArg(shape, dtype)],
shape=shape,
dtype=dtype)
_LAX_COMPARATORS = dict(eq=jnp.equal, ne=jnp.not_equal,
ge=jnp.greater_equal, gt=jnp.greater,
le=jnp.less_equal, lt=jnp.less)
def _make_comparator_harness(name,
*,
dtype=np.float32,
op=lax.eq_p,
op_name="eq",
lhs_shape=(),
rhs_shape=()):
define(
op_name,
f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}",
lambda *args: op(*args),
[RandArg(lhs_shape, dtype),
RandArg(rhs_shape, dtype)],
op=op,
op_name=op_name,
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
dtype=dtype)
for op_name, op in _LAX_COMPARATORS.items():
for dtype in (jtu.dtypes.all if op in [lax.eq_p, lax.ne_p] else
set(jtu.dtypes.all)):
# Validate dtypes
_make_comparator_harness("dtypes", dtype=dtype, op=op, op_name=op_name)
# Validate broadcasting behavior
for lhs_shape, rhs_shape in [
((), (2, 3)), # broadcast scalar
((1, 2), (3, 2)), # broadcast along specific axis
]:
_make_comparator_harness(
"broadcasting", lhs_shape=lhs_shape, rhs_shape=rhs_shape,
op=op, op_name=op_name)
for dtype in jtu.dtypes.all:
shape = (3, 4, 5)
define(
"zeros_like",
f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
ad_util.zeros_like_p.bind, [RandArg(shape, dtype)],
shape=shape,
dtype=dtype)
for dtype in jtu.dtypes.all_integer + jtu.dtypes.all_unsigned:
if np.issubdtype(dtype, np.unsignedinteger):
arg = np.array([0, 1, 2], dtype=dtype)
else:
arg = np.array([-1, -2, 0, 1], dtype=dtype)
define(
"population_count",
f"{jtu.dtype_str(dtype)}",
lax.population_count, [arg],
dtype=dtype)
def _get_max_identity(dtype):
if dtypes.issubdtype(dtype, np.inexact):
return np.array(-np.inf, dtype)
elif dtypes.issubdtype(dtype, np.integer):
return np.array(dtypes.iinfo(dtype).min, dtype)
elif dtypes.issubdtype(dtype, np.bool_):
return np.array(False, np.bool_)
def _get_min_identity(dtype):
if dtypes.issubdtype(dtype, np.inexact):
return np.array(np.inf, dtype)
elif dtypes.issubdtype(dtype, np.integer):
return np.array(dtypes.iinfo(dtype).max, dtype)
elif dtypes.issubdtype(dtype, np.bool_):
return np.array(True, np.bool_)
def _make_argminmax_harness(prim,
name,
*,
shape=(15,),
dtype=jnp.float32,
axes=(0,),
index_dtype=np.int32,
arr=None,
works_without_xla=True):
arr = arr if arr is not None else RandArg(shape, dtype)
dtype, shape = arr.dtype, arr.shape
index_dtype = dtypes.canonicalize_dtype(index_dtype)
for enable_xla in [True, False]:
define(
prim,
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{axes=}_indexdtype={index_dtype}_enablexla={enable_xla}",
lambda arg: prim.bind(arg, axes=axes, index_dtype=index_dtype), [arr],
shape=shape,
dtype=dtype,
axes=axes,
index_dtype=index_dtype,
prim=prim,
enable_xla=enable_xla)
for prim in [lax.argmin_p, lax.argmax_p]:
for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex):
# Validate dtypes for each primitive
_make_argminmax_harness(prim, "dtypes", dtype=dtype)
# Validate axes for each primitive; non major axis
_make_argminmax_harness(prim, "axes", shape=(18, 12), axes=(1,))
# Validate index dtype for each primitive
for index_dtype in jtu.dtypes.all_integer + jtu.dtypes.all_unsigned:
_make_argminmax_harness(prim, "index_dtype", index_dtype=index_dtype)
# Some special cases, with equal elements and NaN
for name, operand in [
("nan_0", np.array([np.nan, np.nan, 2., -2., -np.nan, -np.nan], np.float32)),
("nan_1", np.array([np.nan, -np.nan, 2., -2.], np.float32)),
("inf_0", np.array([2., np.inf, np.inf, -2.], np.float32)),
("inf_1", np.array([2., np.inf, -np.inf, -2.], np.float32)),
("inf_2", np.array([2., -np.inf, np.inf, -2.], np.float32)),
("inf_3", np.array([2., -np.inf, -np.inf, -2.], np.float32)),
("nan_inf_0", np.array([2., np.nan, np.inf, -2.], np.float32)),
("nan_inf_1", np.array([2., np.nan, -np.inf, -2.], np.float32)),
("equal", np.array([2., 2., 2.], np.int32)),
("singleton", np.array([1.], np.float32)),
]:
_make_argminmax_harness(prim, f"special_{name}", shape=operand.shape,
arr=operand)
# TODO(bchetioui): the below documents a limitation of argmin and argmax when a
# dimension of the input is too large. However, it is not categorizable as it
# seems that the converter fails before reaching the actual primitive call. This
# suggests that we may need to harden the converter to handle inputs this big.
# + tuple( # Document limitation in case of too large axis
# _make_argminmax_harness("overflow_axis", prim=prim,
# arr=np.ones((2**31,), dtype=np.uint8))
# for prim in [lax.argmin_p, lax.argmax_p]
# )
def _make_iota_harness(name, *, shape=(2, 3), dtype=np.float32, dimension=0):
define(
lax.iota_p,
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{dimension=}",
lambda dtype, shape, dim:
(lax.iota_p.bind(dtype=np.dtype(dtype), shape=shape, dimension=dim)),
[StaticArg(dtype),
StaticArg(shape),
StaticArg(dimension)],
shape=shape,
dtype=dtype,
dimension=dimension)
for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean):
_make_iota_harness("dtypes", dtype=dtype)
# Validate broadcasting
for shape, dimension in [
((4, 8, 1, 1), 1), # broadcasting along non-major dimension
((4, 8, 1, 1), 2), # broadcasting along dimension == 1
]:
_make_iota_harness("broadcasting", shape=shape, dimension=dimension)
def _make_div_rem_harness(prim,
name,
*,
shapes=((2,), (2,)),
dtype=np.float32,
arrs=(None, None)):
lhs, rhs = arrs
def _make_non_zero(rng):
return jtu.rand_nonzero(rng)(shapes[1], dtype)
lhs = RandArg(shapes[0], dtype) if lhs is None else lhs
rhs_shape = rhs.shape if rhs is not None else shapes[1]
rhs_dtype = rhs.dtype if rhs is not None else dtype
rhs = CustomArg(_make_non_zero) if rhs is None else rhs
define(
prim,
f"{name}_lhs={jtu.format_shape_dtype_string(lhs.shape, lhs.dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)}",
prim.bind, [lhs, rhs],
dtype=dtype,
lhs=lhs,
rhs=rhs,
prim=prim)
for prim in [lax.div_p, lax.rem_p]:
for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean) - (
set() if prim is lax.div_p else set(jtu.dtypes.complex)):
_make_div_rem_harness(prim, "dtypes", dtype=dtype)
# Validate broadcasting
for shapes in [
((2, 1, 3), (2, 4, 3)), # broadcast dividend
((2, 4, 3), (2, 1, 3)), # broadcast divisor
]:
_make_div_rem_harness(prim, "broadcast", shapes=shapes)
# Validate singularity points
for name, arrs in [
("positive_by_0", (np.ones(
(2,), dtype=np.float32), np.zeros((2,), dtype=np.float32))),
# TODO: this fails on CPU, different result
# ("positive_by_0_int32", (np.ones((2,), dtype=np.int32),
# np.zeros((2,), dtype=np.int32))),
("negative_by_0", (-np.ones(
(2,), dtype=np.float32), np.zeros((2,), dtype=np.float32))),
("0_by_0", (np.zeros(
(2,), dtype=np.float32), np.zeros((2,), dtype=np.float32))),
("inf_by_inf", (np.array([np.inf], dtype=np.float32),
np.array([np.inf], dtype=np.float32))),
]:
_make_div_rem_harness(prim, f"singularity_{name}", arrs=arrs)
def _make_binary_elementwise_harnesses(prim,
dtypes,
default_dtype=np.float32,
broadcasting_dtypes=None,
jax_unimplemented=lambda **kwargs: []):
def _make(name, *, shapes=((20, 20), (20, 20)), dtype):
lhs_shape, rhs_shape = shapes
define(
prim,
f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}",
prim.bind, [RandArg(lhs_shape, dtype),
RandArg(rhs_shape, dtype)],
jax_unimplemented=jax_unimplemented(
dtype=dtype, prim=prim, shapes=shapes),
prim=prim,
dtype=dtype,
shapes=shapes)
broadcasting_dtypes = broadcasting_dtypes or (default_dtype,)
return (
# Validate dtypes
tuple(_make("dtypes", dtype=dtype) for dtype in dtypes) +
# Validate broadcasting
tuple(_make("broadcasting", dtype=dtype, shapes=shapes)
for shapes in [
((20, 20), (1, 20)), # broadcasting rhs
((1, 20), (20, 20)), # broadcasting lhs
]
for dtype in broadcasting_dtypes)
)
_make_binary_elementwise_harnesses(
prim=lax.add_p, dtypes=set(jtu.dtypes.all) - set(jtu.dtypes.boolean))
_make_binary_elementwise_harnesses(
prim=lax.mul_p, dtypes=set(jtu.dtypes.all) - set(jtu.dtypes.boolean))
_make_binary_elementwise_harnesses(
prim=lax.atan2_p, dtypes=jtu.dtypes.all_floating)
_make_binary_elementwise_harnesses(
prim=lax.igamma_p, dtypes=jtu.dtypes.all_floating)
_make_binary_elementwise_harnesses(
prim=lax.igammac_p, dtypes=jtu.dtypes.all_floating)
_make_binary_elementwise_harnesses(
prim=lax.nextafter_p, dtypes=jtu.dtypes.all_floating)
_make_binary_elementwise_harnesses(
prim=lax.and_p,
default_dtype=np.int32,
dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned +
jtu.dtypes.boolean)
_make_binary_elementwise_harnesses(
prim=lax.or_p,
default_dtype=np.int32,
dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned +
jtu.dtypes.boolean)
_make_binary_elementwise_harnesses(
prim=lax.xor_p,
default_dtype=np.int32,
dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned +
jtu.dtypes.boolean)
_make_binary_elementwise_harnesses(
prim=lax.shift_left_p,
default_dtype=np.int32,
dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned)
_make_binary_elementwise_harnesses(
prim=lax.shift_right_logical_p,
default_dtype=np.int32,
dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned)
_make_binary_elementwise_harnesses(
prim=lax.shift_right_arithmetic_p,
default_dtype=np.int32,
dtypes=jtu.dtypes.all_integer + jtu.dtypes.all_unsigned)
_make_binary_elementwise_harnesses(
prim=lax.sub_p, dtypes=set(jtu.dtypes.all) - set(jtu.dtypes.boolean))
for minmax_p in [lax.min_p, lax.max_p]:
_make_binary_elementwise_harnesses(
prim=minmax_p, dtypes=jtu.dtypes.all,
broadcasting_dtypes=(np.float32, np.complex64, np.complex128))
# Validate special cases nan/inf/-inf
for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex:
define(
minmax_p,
f"inf_nan_{jtu.dtype_str(dtype)}",
minmax_p.bind, [np.array([[np.nan, np.nan, np.nan],
[np.inf, np.inf, np.inf],
[-np.inf, -np.inf, -np.inf]], dtype=dtype),
np.array([[np.nan, np.inf, -np.inf],
[np.nan, np.inf, -np.inf],
[np.nan, np.inf, -np.inf]], dtype=dtype)],
prim=minmax_p,
dtype=dtype)
def _make_broadcast_in_dim_harness(name,
*,
dtype=np.float32,
shape=(2,),
outshape=(2,),
broadcast_dimensions=(0,)):
define(
lax.broadcast_in_dim_p,
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{outshape=}_broadcastdimensions={broadcast_dimensions}",
lambda operand: lax.broadcast_in_dim_p.bind(
operand, shape=outshape, broadcast_dimensions=broadcast_dimensions),
[RandArg(shape, dtype)],
shape=shape,
dtype=dtype,
outshape=outshape,
broadcast_dimensions=broadcast_dimensions)
for dtype in jtu.dtypes.all:
_make_broadcast_in_dim_harness("dtypes", dtype=dtype)
# Validate parameter combinations
for shape, outshape, broadcast_dimensions in [
[(2,), (3, 2), (1,)], # add major dimension
[(2,), (2, 3), (0,)], # add inner dimension
[(), (2, 3), ()], # use scalar shape
[(1, 2), (4, 3, 2), (0, 2)], # map size 1 dim to different output dim value
]:
_make_broadcast_in_dim_harness(
"parameter_combinations",
shape=shape,
outshape=outshape,
broadcast_dimensions=broadcast_dimensions)
for dtype in jtu.dtypes.all_floating:
for arg1, arg2, arg3 in [
(np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.3, 1, 1.4, 1.6], dtype=dtype),
np.array([-1.6, 1.4, 1.0, 0.0, 0.2, 0.1, 1, 1.4, -1.6], dtype=dtype),
np.array([1.0, -1.0, 2.0, 1.0, 0.3, 0.3, -1.0, 2.4, 1.6], dtype=dtype))
]:
define(
lax.regularized_incomplete_beta_p,
f"_{jtu.dtype_str(dtype)}",
lax.betainc, [arg1, arg2, arg3],
dtype=dtype)
## GATHER
# Validate dtypes
for dtype in set(jtu.dtypes.all):
indices = np.array(2, dtype=np.int32)
shape = (10,)
axis = 0
define(
lax.gather_p,
f"dtypes_shape={jtu.format_shape_dtype_string(shape, dtype)}_{axis=}_enable_xla=True",
lambda a, i, axis: jnp.take(a, i, axis=axis),
[RandArg(shape, dtype), indices,
StaticArg(axis)],
dtype=dtype,
enable_xla=True)
# Construct gather harnesses using take
_gather_input = np.arange(1000, dtype=np.float32).reshape((10, 10, 10))
for indices, index_oob, indices_name in [
# Ensure each set of indices has a distinct name
(np.array(2, dtype=np.int32), False, "1"),
(np.array([2], dtype=np.int32), False, "2"),
(np.array([2, 4], dtype=np.int32), False, "3"),
(np.array([2, 4], dtype=np.uint32), False, "3_uint32"), # uint32 indices
(np.array([[2, 4], [5, 6]], dtype=np.int32), False, "4"),
(np.array([[0], [1], [10]], dtype=np.int32), True, "5_oob"), # Index out of bounds too high
(np.array([[0, 1], [2, -1]], dtype=np.int32), False, "6_neg"), # Negative index is from the end
(np.array([0, 1, 2, 3, -10], dtype=np.int32), False, "7_neg"), # Index out of bounds, but works
(np.array([[[0], [1]], [[3], [-11]]], dtype=np.int32), True, "8_neg_oob") # Index out of bounds, too low
]:
for axis in [0, 1, 2]:
for enable_xla in [True, False]:
for mode in ["clip", "fill"]:
define(
lax.gather_p,
f"from_take_{indices_name=}_{axis=}_{enable_xla=}_{mode=!s}",
lambda a, i, axis, mode: jnp.take(a, i, axis=axis, mode=mode),
[_gather_input, indices, StaticArg(axis), StaticArg(mode)],
dtype=_gather_input.dtype,
enable_xla=enable_xla,
index_oob=index_oob,
mode=mode)
# Construct gather harnesses using array indexing and slicing.
for slices, name in [
((0,), "[0]"),
((0, 1), "[0,1]"),
((slice(0, 10), 2, slice(0, 10)), "[:,:2,:]"),
((slice(2, 5), 5), "[2:5,5]"),
((-1, -5, -200), "[-1,-5,-200]"),
((slice(5, -2), 300), "[5:-2,300]"),
]:
for enable_xla in [False, True]:
define(
lax.gather_p,
f"from_slicing_{name=}_{enable_xla=}",
lambda arr, *s: jnp.array(arr).__getitem__(*s),
[_gather_input, StaticArg(slices)],
dtype=_gather_input.dtype,
enable_xla=enable_xla)
# Directly from lax.gather in lax_test.py.
for shape, idxs, dnums, slice_sizes, needs_xla in [
((5,), np.array([[0], [2]]),
lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,),
start_index_map=(0,)), (1,), False),
((10,), np.array([[0], [0], [0]]),
lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(),
start_index_map=(0,)), (2,), True),
((10, 5,), np.array([[0], [2], [1]]),
lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,),
start_index_map=(0,)), (1, 3), True),
((10, 6), np.array([[0, 2], [1, 0]]),
lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,),
start_index_map=(0, 1)), (1, 3), True),
]:
dtype = np.float32
for enable_xla in ([True] if needs_xla else [True, False]):
define(
lax.gather_p,
f"{shape=}_idxs_shape={idxs.shape}_{dnums=}_{slice_sizes=}_{enable_xla=}",
lambda op, idxs, dnums, slice_sizes: lax.gather(
op, idxs, dimension_numbers=dnums, slice_sizes=slice_sizes),
[RandArg(shape, dtype), idxs,
StaticArg(dnums),
StaticArg(slice_sizes)],
dtype=dtype,
enable_xla=enable_xla)
# Test cases lax.gather with non-empty batch_dims. This is for instance
# triggered when executing `jax.vmap(lax.dynamic_slice)`.
# We currently only support the case where we have a single batch dimension.
dnums_2d = lax.GatherDimensionNumbers(
offset_dims=(1,),
collapsed_slice_dims=(0,), # Batch dimension.
start_index_map=(0, 1))
dnums_2d_2 = lax.GatherDimensionNumbers(
offset_dims=(),
collapsed_slice_dims=(0, 1), # Only first dim is batch, collapse both.
start_index_map=(0, 1))
dnums_3d = lax.GatherDimensionNumbers(
offset_dims=(1, 2),
collapsed_slice_dims=(0,), # Batch dimension.
start_index_map=(0, 1, 2))
for op_shape, start_indices, slice_sizes, dnums in [
((4, 6), [[0, 1], [1, 2], [2, 3], [3, 2]], (1, 3), dnums_2d),
((1, 2), [[0, 1]], (1, 1), dnums_2d_2),
((2, 6, 3), [[0, 1, 0], [1, 2, 0]], (1, 3, 3), dnums_3d),
# Test out of bounds behavior.
((3, 10), [[0, 0], [1, 8], [2, 0]], (1, 5), dnums_2d),
((2, 3, 3), [[0, 1, 0], [1, 2, 1]], (1, 3, 2), dnums_3d)]:
start_indices = np.array(start_indices)
for enable_xla in [True, False]:
define(
lax.gather_p,
f"batchdims_shape={op_shape}_start_indices_shape={start_indices.shape}_{slice_sizes=}_{enable_xla=}",
lambda op, idxs, dnums, slice_sizes: lax.gather(
op, idxs, dimension_numbers=dnums, slice_sizes=slice_sizes),
[RandArg(op_shape, np.float32), start_indices,
StaticArg(dnums),
StaticArg(slice_sizes)],
dtype=np.float32,
enable_xla=enable_xla)
# Test cases lax.gather with non-empty 2D batch_dims. This is for instance
# triggered when executing `jax.vmap(jax.vmap(lax.dynamic_slice))`.
gather_2d_bd = lax.GatherDimensionNumbers(
offset_dims=(2, 3, 4), collapsed_slice_dims=(), start_index_map=(1, 2)
)
gather_2d_bd_nid = lax.GatherDimensionNumbers(
offset_dims=(2, 3, 4), collapsed_slice_dims=(), start_index_map=(2, 1)
)
gather_3d_bd = lax.GatherDimensionNumbers(
offset_dims=(1, 2, 3, 4), collapsed_slice_dims=(), start_index_map=(1, 2, 3)
)
gather_2d_bd2 = lax.GatherDimensionNumbers(
offset_dims=(2, 3, 4), collapsed_slice_dims=(), start_index_map=(1, 2)
)
for op_shape, start_indices, slice_sizes, dnums in [
((10, 10, 10), [[[0, 1], [1, 0], [0, 1], [1, 0]]], (2, 2, 2), gather_2d_bd,), # non-contigous 2d batch dims
((10, 10, 10), [[[-1, 1], [1, -1], [0, -1], [-1, 0]]], (2, 2, 3), gather_2d_bd,), # negative indices
((10, 10, 10), [[[1, 1], [1, 1], [0, 1], [2147483647, 0]]], (2, 3, 3), gather_2d_bd,), # oob indices
((10, 10, 10), [[[0, 1], [1, 0], [0, 1], [1, 0]]], (2, 2, 2), gather_2d_bd_nid,), # start_index_map not identity
((10, 10, 10), [[[0, 1], [1, 0], [0, 1], [1, 0]]], (10, 10, 10), gather_2d_bd,), # oob behavior coming from slice_sizes
((10,10,10,10), [[[[0, 1,0], [1, 0,0], [0, 1,0], [1, 0,0]]]], (10,2,2,2), gather_3d_bd), # test 3d batch dims
((10,10,10), [[[0, 1], [1, 0], [0, 1], [1, 0]]], (10,2,2), gather_2d_bd2), # test contiguous 2d batch dims
]:
start_indices = np.array(start_indices)
for enable_xla in [True, False]:
define(
"gather",
f"op_shape={op_shape}_offset_dims={dnums.offset_dims}_start_index_map={dnums.start_index_map}_start_indices_shape={start_indices.shape}_{slice_sizes=}_{enable_xla=}",
lambda op, idxs, dnums, slice_sizes: lax.gather(
op, idxs, dimension_numbers=dnums, slice_sizes=slice_sizes
),
[
RandArg(op_shape, np.float32),
start_indices,
StaticArg(dnums),
StaticArg(slice_sizes),
],
dtype=np.float32,
enable_xla=enable_xla,
)
def _make_scatter_harness(name,
*,
shape=(5,),
f_lax=lax.scatter_min,
indices_are_sorted=False,
unique_indices=False,
scatter_indices=np.array([[0], [2]]),
update_shape=(2,),
mode=lax.GatherScatterMode.FILL_OR_DROP,
dtype=np.float32,
dimension_numbers=((), (0,), (0,)),
enable_and_disable_xla=False):
dimension_numbers = lax.ScatterDimensionNumbers(*dimension_numbers)
xla_options = [True, False] if enable_and_disable_xla else [True]
for enable_xla in xla_options:
define(
f_lax.__name__,
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_scatterindices={scatter_indices.tolist()}_updateshape={update_shape}_updatewindowdims={dimension_numbers.update_window_dims}_insertedwindowdims={dimension_numbers.inserted_window_dims}_scatterdimstooperanddims={dimension_numbers.scatter_dims_to_operand_dims}_indicesaresorted={indices_are_sorted}_uniqueindices={unique_indices}_{mode=!s}_enablexla={enable_xla}"
.replace(" ", ""),
partial(
f_lax,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=mode), [
RandArg(shape, dtype),
StaticArg(scatter_indices),
RandArg(update_shape, dtype),
StaticArg(dimension_numbers)
],
jax_unimplemented=[
Limitation(
"unimplemented",
dtypes=[np.bool_],
enabled=(f_lax in [lax.scatter_add, lax.scatter_mul])),
],
f_lax=f_lax,
shape=shape,
dtype=dtype,
scatter_indices=scatter_indices,
update_shape=update_shape,
dimension_numbers=dimension_numbers,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=mode,
enable_xla=enable_xla)
# Validate dtypes
for dtype in jtu.dtypes.all:
for f_lax in [
lax.scatter, lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min
]:
_make_scatter_harness("dtypes", dtype=dtype, f_lax=f_lax)
# Validate f_lax/update_jaxpr
# We explicitly decide against testing lax.scatter, as its reduction function
# is lambda x, y: y, which is not commutative and thus makes results
# non-deterministic when an index into the operand is updated several times.
# Validate shapes, dimension numbers and scatter indices. All are in bounds.
for shape, scatter_indices, update_shape, dimension_numbers in [
((10,), [[0], [0], [0]], (3, 2), ((1,), (), (0,))),
((10, 5), [[0], [2], [1]], (3, 3), ((1,), (0,), (0,)))
]:
_make_scatter_harness(
"shapes_and_dimension_numbers",
shape=shape,
update_shape=update_shape,
scatter_indices=np.array(scatter_indices),
dimension_numbers=dimension_numbers)
# Validate sorted indices
_make_scatter_harness("indices_are_sorted", indices_are_sorted=True)
# Validate unique_indices
# `unique_indices` does not affect correctness, only performance, and thus
# does not need to be tested here. If/when it will make sense to add a test
# with `unique_indices` = True, particular care will have to be taken with
# regards to the choice of parameters, as the results are only predictable
# when all the indices to be updated are pairwise non-overlapping. Identifying
# such cases is non-trivial.
# _make_scatter_harness("unique_indices", unique_indices=False)
# Validate different out-of-bounds modes
for mode in (lax.GatherScatterMode.PROMISE_IN_BOUNDS,
lax.GatherScatterMode.FILL_OR_DROP):
for f_lax in [
lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min, lax.scatter
]:
_make_scatter_harness("modes_in_bounds",
f_lax=f_lax,
mode=mode)
_make_scatter_harness("modes_out_of_bounds", mode=mode,
shape=(1, 5),
f_lax=f_lax,
scatter_indices=np.array([10]),
update_shape=(1,),
dimension_numbers=((0,), (1,), (1,)),
enable_and_disable_xla=True)
# Validate no XLA scatters
for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex) - set(jtu.dtypes.boolean):
for f_lax in [
lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min, lax.scatter
]:
for shape, scatter_indices, update_shape, dimension_numbers in [
((1,), [0], (), ((), (0,), (0,))), # zero case
((1, 1), [0], (1,), ((0,), (0,), (0,))),
((1, 1, 1), [0], (1, 1), ((0, 1), (0,), (0,))),
((1, 50, 3), [32], (1, 3), ((0, 1), (1,), (1,))),
((1, 2, 3), [1], (1, 3), ((0, 1), (1,), (1,))), # slice 2nd dim
((1, 2, 3), [0], (2, 3), ((0, 1), (0,), (0,))), # slice 1st dim
((1, 2, 3), [1, 2], (1,), ((0,), (1, 2), (1, 2))), # 2nd and 3rd
((4, 2, 3), [3, 2], (2,), ((0,), (0, 2), (0, 2))), # 1st and 3rd
((4, 2, 3, 5), [0, 4], (4, 3), ((0, 1), (1, 3), (1, 3))), # 2nd and 4th
((5, 6, 7), [[0, 1], [2, 3]], (2, 7),
((1,), (0, 1), (0, 1))), # .at[((3,4),(5,5))] shapes
((5, 6, 7), [[[0], [1]], [[2], [3]]], (5, 2, 2, 7),
((0, 3), (1,), (1,))), # .at[:, ((3,4),(5,5))] shapes
((5, 6, 7), [[[0, 1], [2, 3]], [[4, 0], [1, 2]]], (5, 2, 2),
((0,), (1, 2), (1, 2))), # .at[:, ((3,4),(5,5)), 3] shapes
((1, 125), [0], (1,), ((0,), (1,), (1,))),
]:
for mode in (lax.GatherScatterMode.PROMISE_IN_BOUNDS,
lax.GatherScatterMode.FILL_OR_DROP):
_make_scatter_harness(
"no_xla_unique_indices",
shape=shape,
f_lax=f_lax,
unique_indices=True,
scatter_indices=np.array(scatter_indices),
update_shape=update_shape,
mode=mode,
dtype=dtype,
dimension_numbers=dimension_numbers,
enable_and_disable_xla=True)
# Validate no XLA scatters with non-unique indices with an indexing depth of 1.
for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex) - set(jtu.dtypes.boolean):
# Note that lax.scatter currently does not work here.
for f_lax in [
lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min
]:
for shape, scatter_indices, update_shape, dimension_numbers in [
((1,), [[0],[0]], (2,), ((), (0,), (0,))), # .at[((0,0),)]
((3,), [[1],[0],[1]], (3,), ((), (0,), (0,))), # .at[((1,0,1),)]
((2, 3), [[[2],[2],[2]]], (2, 1, 3), ((0,), (1,), (1,))), # 2nd dim, .at[:, ((2,2,2),)]
((3, 5, 40), [[1],[1]], (3, 5, 2), ((0, 1), (2,), (2,))),
((3, 5, 4), [[1],[1]], (3, 2, 4), ((0, 2), (1,), (1,))),
]:
for mode in (lax.GatherScatterMode.PROMISE_IN_BOUNDS,
lax.GatherScatterMode.FILL_OR_DROP):
_make_scatter_harness(
"no_xla_non_unique_indices",
shape=shape,
f_lax=f_lax,
unique_indices=False,
scatter_indices=np.array(scatter_indices),
update_shape=update_shape,
mode=mode,
dtype=dtype,
dimension_numbers=dimension_numbers,
enable_and_disable_xla=True)
for dtype in jtu.dtypes.all:
arg_shape = (2, 3)
for pads in [
[(0, 0, 0), (0, 0, 0)], # no padding
[(1, 1, 0), (2, 2, 0)], # only positive edge padding
[(1, 2, 1), (0, 1, 0)], # edge padding and interior padding
[(0, 0, 0), (-1, -1, 0)], # negative padding
[(0, 0, 0), (-2, -2, 4)], # add big dilation then remove from edges
[(0, 0, 0), (-2, -3, 1)], # remove everything in one dimension
]:
for enable_xla in [True, False]:
define(
lax.pad_p,
f"inshape={jtu.format_shape_dtype_string(arg_shape, dtype)}_{pads=}_{enable_xla=}",
lax.pad,
[RandArg(arg_shape, dtype),
np.array(0, dtype),
StaticArg(pads)],
rng_factory=jtu.rand_small,
arg_shape=arg_shape,
dtype=dtype,
pads=pads,
enable_xla=enable_xla)
def _make_select_n_harness(name,
*,
shape_pred=(2, 3),
shape_args=(2, 3),
dtype=np.float32):
define(
lax.select_n_p,
f"{name}_shapepred={jtu.format_shape_dtype_string(shape_pred, np.bool_)}_shapeargs={jtu.format_shape_dtype_string(shape_args, dtype)}",
lax.select_n, [
RandArg(shape_pred, np.bool_),
RandArg(shape_args, dtype),
RandArg(shape_args, dtype)
],
shape_pred=shape_pred,
shape_args=shape_args,
dtype=dtype)
define(
lax.select_n_p,
f"{name}_shapepred={jtu.format_shape_dtype_string(shape_pred, np.int32)}_shapeargs={jtu.format_shape_dtype_string(shape_args, dtype)}",
lax.select_n, [
CustomArg(lambda rng: jtu.rand_int(rng, high=3)(shape_args, np.int32)),
RandArg(shape_args, dtype),
RandArg(shape_args, dtype),
RandArg(shape_args, dtype),
],
shape_pred=shape_pred,
shape_args=shape_args,
dtype=dtype)
for dtype in jtu.dtypes.all:
_make_select_n_harness("dtypes", dtype=dtype)
# Validate shapes
_make_select_n_harness("shapes", shape_pred=(), shape_args=(18,))
def _make_transpose_harness(name,
*,
shape=(2, 3),
permutation=(1, 0),
dtype=np.float32):
define(
lax.transpose_p,
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{permutation=}"
.replace(" ", ""),
lambda x: lax.transpose_p.bind(x, permutation=permutation),
[RandArg(shape, dtype)],
shape=shape,
dtype=dtype,
permutation=permutation)
for dtype in jtu.dtypes.all:
_make_transpose_harness("dtypes", dtype=dtype)
# Validate permutations
for shape, permutation in [
((2, 3, 4), (0, 1, 2)), # identity
((2, 3, 4), (1, 2, 0)), # transposition
]:
_make_transpose_harness("permutations", shape=shape, permutation=permutation)
## CUMREDUCE
def _make_cumreduce_harness(name,
*,
f_jax=lax_control_flow.cummin,
shape=(8, 9),
dtype=np.float32,
axis=0,
reverse=False):
limitations = []
define(
f_jax.__name__,
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{axis=}_{reverse=}",
f_jax, [RandArg(shape, dtype),
StaticArg(axis),
StaticArg(reverse)],
jax_unimplemented=limitations,
f_jax=f_jax,
shape=shape,
dtype=dtype,
axis=axis,
reverse=reverse,
associative_scan_reductions=False
)
define(
f_jax.__name__,
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_associative_scan_reductions_{axis=}_{reverse=}",
f_jax, [RandArg(shape, dtype),
StaticArg(axis),
StaticArg(reverse)],
jax_unimplemented=limitations,
f_jax=f_jax,
shape=shape,
dtype=dtype,
axis=axis,
reverse=reverse,
associative_scan_reductions=True
)
# Validate dtypes for each function
for f_jax in [
lax_control_flow.cummin,
lax_control_flow.cummax,
lax_control_flow.cumlogsumexp,
lax_control_flow.cumsum,
lax_control_flow.cumprod,
]:
for dtype in jtu.dtypes.all:
# cumlogsumexp is only defined for floating point types.
if (f_jax == lax_control_flow.cumlogsumexp and
not np.issubdtype(dtype, np.floating)):
continue
if dtype == np.bool_:
continue
_make_cumreduce_harness("dtype_by_fun", dtype=dtype, f_jax=f_jax)
# Validate axis for each function
shape = (8, 9)
for axis in range(len(shape)):
_make_cumreduce_harness("axis_by_fun", axis=axis, f_jax=f_jax, shape=shape)
# Validate reverse for each function
_make_cumreduce_harness("reverse", reverse=True, f_jax=f_jax)
### TOP_K
def _make_top_k_harness(name,
*,
operand=None,
shape=(5, 3),
dtype=np.float32,
k=2):
if operand is None:
operand = RandArg(shape, dtype)
define(
lax.top_k_p,
f"{name}_inshape={jtu.format_shape_dtype_string(operand.shape, operand.dtype)}_{k=}",
lax.top_k, [operand, StaticArg(k)],
shape=operand.shape,
dtype=operand.dtype,
k=k)
for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex):
# Validate dtypes
_make_top_k_harness("dtypes", dtype=dtype)
# Validate implicit properties of the sort
for name, operand, k in [("stability",
np.array([5, 7, 5, 8, 8, 5], dtype=np.int32), 3),
("sort_inf_nan",
np.array([+np.inf, np.nan, -np.nan, -np.inf, 3],
dtype=np.float32), 5)]:
_make_top_k_harness(name, operand=operand, k=k)
### SORT
def _make_sort_harness(name,
*,
operands=None,
shape=(5, 7),
dtype=np.float32,
dimension=0,
is_stable=False,
num_keys=1):
if operands is None:
operands = [RandArg(shape, dtype)]
define(
lax.sort_p,
f"{name}_num_arrays={len(operands)}_shape={jtu.format_shape_dtype_string(operands[0].shape, operands[0].dtype)}_axis={dimension}_isstable={is_stable}_{num_keys=}",
lambda *args: lax.sort_p.bind(
*args[:-3], dimension=args[-3], is_stable=args[-2], num_keys=args[-1]
), [
*operands,
StaticArg(dimension),
StaticArg(is_stable),
StaticArg(num_keys)
],
shape=operands[0].shape,
dimension=dimension,
dtype=operands[0].dtype,
is_stable=is_stable,
num_keys=num_keys,
num_arrays=len(operands))
_lax_sort_multiple_array_shape = (100,)
# In order to test lexicographic ordering and sorting stability, the first
# array contains only integers 0 and 1
_lax_sort_multiple_array_first_arg = (
np.random.uniform(0, 2, _lax_sort_multiple_array_shape).astype(np.int32))
# Validate dtypes
for dtype in jtu.dtypes.all:
_make_sort_harness("dtypes", dtype=dtype)
# Validate dimensions
for dimension in [0, 1]:
_make_sort_harness("dimensions", dimension=dimension)
# Validate stable sort
_make_sort_harness("is_stable", is_stable=True)
# Potential edge cases
for operands, dimension in [
([np.array([+np.inf, np.nan, -np.nan, -np.inf, 2], dtype=np.float32)], 0)
]:
_make_sort_harness("edge_cases", operands=operands, dimension=dimension)
# Validate multiple arrays, num_keys, and is_stable
for is_stable in [False, True]:
for operands in (
[
_lax_sort_multiple_array_first_arg,
RandArg(_lax_sort_multiple_array_shape, np.int32)
],
[
_lax_sort_multiple_array_first_arg,
RandArg(_lax_sort_multiple_array_shape, np.int32),
RandArg(_lax_sort_multiple_array_shape, np.float32)
],
):
for num_keys in range(1, len(operands) + 1):
_make_sort_harness(
"multiple_arrays",
operands=operands,
num_keys=num_keys,
is_stable=is_stable,
shape=_lax_sort_multiple_array_first_arg.shape,
dtype=_lax_sort_multiple_array_first_arg.dtype)
def _make_cholesky_arg(shape, dtype, rng):
a = jtu.rand_default(rng)(shape, dtype)
return np.matmul(a, jnp.conj(np.swapaxes(a, -1, -2)))
for dtype in jtu.dtypes.all_inexact:
for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)]:
define(
lax.linalg.cholesky_p,
f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
lambda *args: lax.linalg.cholesky_p.bind(*args),
[CustomArg(partial(_make_cholesky_arg, shape, dtype))],
jax_unimplemented=[
Limitation(
"unimplemented", dtypes=[np.float16], devices=("cpu", "gpu"))
],
shape=shape,
dtype=dtype)
for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex:
for shape in [(1, 1), (3, 3), (3, 4), (2, 10, 5), (2, 200, 100)]:
for full_matrices in [False, True]:
define(
lax.linalg.qr_p,
f"multi_array_shape={jtu.format_shape_dtype_string(shape, dtype)}_fullmatrices={full_matrices}",
partial(lax.linalg.qr, full_matrices=full_matrices),
[RandArg(shape, dtype)],
# See jax._src.lib.lapack.geqrf for the list of compatible types
jax_unimplemented=[
Limitation(
"unimplemented",
devices=("cpu", "gpu"),
dtypes=[np.float16, dtypes.bfloat16]),
],
shape=shape,
dtype=dtype,
full_matrices=full_matrices)
def _make_fft_harness(name,
*,
shape=(14, 15, 16, 17),
dtype=np.float32,
fft_type=xla_client.FftType.FFT,
fft_lengths=(17,)):
def _fft_rng_factory(dtype):
_all_integers = (
jtu.dtypes.all_integer + jtu.dtypes.all_unsigned + jtu.dtypes.boolean)
# For integer types, use small values to keep the errors small
if dtype in _all_integers:
return jtu.rand_small
else:
return jtu.rand_default
define(
lax.fft_p,
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_ffttype={fft_type}_fftlengths={fft_lengths}",
lambda *args: lax.fft_p.bind(
args[0], fft_type=args[1], fft_lengths=args[2]),
[RandArg(shape, dtype),
StaticArg(fft_type),
StaticArg(fft_lengths)],
rng_factory=_fft_rng_factory(dtype),
shape=shape,
dtype=dtype,
fft_type=fft_type,
fft_lengths=fft_lengths)
# FFT, IFFT, RFFT, IRFFT
for fft_type in list(map(xla_client.FftType, [0, 1, 2, 3])):
# Validate dtypes per FFT type
for dtype in (jtu.dtypes.floating
if fft_type == xla_client.FftType.RFFT else jtu.dtypes.complex):
shape = (14, 15, 16, 17)
for fft_lengths in [
(shape[-1],) if fft_type != xla_client.FftType.IRFFT else
((shape[-1] - 1) * 2,)
]:
_make_fft_harness(
"dtypes",
shape=shape,
dtype=dtype,
fft_type=fft_type,
fft_lengths=fft_lengths)
# And with a 0 shape
_make_fft_harness(
"dtypes_zero",
shape=(14, 15, 0, 17),
dtype=dtype,
fft_type=fft_type,
fft_lengths=fft_lengths)
# Validate dimensions per FFT type
for dtype in [
np.float32 if fft_type == xla_client.FftType.RFFT else np.complex64
]:
for dims in [1, 2, 3]:
for fft_lengths in [
shape[-dims:] if fft_type != xla_client.FftType.IRFFT else
shape[-dims:-1] + ((shape[-1] - 1) * 2,)
]:
_make_fft_harness(
"dims",
shape=shape,
fft_type=fft_type,
fft_lengths=fft_lengths,
dtype=dtype)
for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex:
for shape in [(2, 2), (2, 7), (29, 29), (2, 3, 53), (2, 3, 29, 7)]:
for full_matrices in [False, True]:
for compute_uv in [False, True]:
define(
lax.linalg.svd_p,
f"shape={jtu.format_shape_dtype_string(shape, dtype)}_fullmatrices={full_matrices}_computeuv={compute_uv}",
lambda *args: lax.linalg.svd_p.bind(
args[0], full_matrices=args[1], compute_uv=args[2]), [
RandArg(shape, dtype),
StaticArg(full_matrices),
StaticArg(compute_uv)
],
jax_unimplemented=[
Limitation(
"unimplemented",
devices=("cpu", "gpu"),
dtypes=[np.float16, dtypes.bfloat16]),
],
shape=shape,
dtype=dtype,
full_matrices=full_matrices,
compute_uv=compute_uv)
for dtype in jtu.dtypes.all_inexact:
for shape in [(0, 0), (5, 5), (2, 6, 6)]:
for compute_left_eigenvectors in [False, True]:
for compute_right_eigenvectors in [False, True]:
define(
lax.linalg.eig_p,
f"shape={jtu.format_shape_dtype_string(shape, dtype)}_computelefteigenvectors={compute_left_eigenvectors}_computerighteigenvectors={compute_right_eigenvectors}",
partial(lax.linalg.eig, compute_left_eigenvectors=compute_left_eigenvectors, compute_right_eigenvectors=compute_right_eigenvectors),
[
RandArg(shape, dtype),
],
jax_unimplemented=[
Limitation(
"only supported on CPU in JAX", devices=("tpu", "gpu")),
Limitation(
"unimplemented",
devices="cpu",
dtypes=[np.float16, dtypes.bfloat16])
],
shape=shape,
dtype=dtype,
compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)
def _make_triangular_eigh_operand(shape, dtype, lower: bool, rng: Rng):
# For testing eigh we use triangular matrices
operand = jtu.rand_default(rng)(shape, dtype)
# Make operand self-adjoint
operand = (operand + np.conj(np.swapaxes(operand, -1, -2))) / 2
# Make operand lower/upper triangular
return operand # np.tril(operand) if lower else np.triu(operand)
for dtype in jtu.dtypes.all_inexact:
# The eigh implementation for TPU uses different lowering for n >= 256
# TODO: add a test case for n=300. First attempt resulted in significant
# numerical differences.
# And the implementation on GPU uses different lowering for n >= 32
for shape in [(0, 0), (50, 50), (2, 20, 20)]:
for lower in [False, True]:
define(
lax.linalg.eigh_p,
f"shape={jtu.format_shape_dtype_string(shape, dtype)}_{lower=}",
# Make operand lower/upper triangular
lambda operand, lower, symmetrize_input: (lax.linalg.eigh(
jnp.tril(operand)
if lower else jnp.triu(operand), lower=lower, symmetrize_input=symmetrize_input)),
[
CustomArg(
partial(_make_triangular_eigh_operand, shape, dtype, lower)),
StaticArg(lower),
StaticArg(False)
],
jax_unimplemented=[
Limitation(
"unimplemented",
devices=("cpu", "gpu"),
dtypes=[np.float16, dtypes.bfloat16]),
],
shape=shape,
dtype=dtype,
lower=lower)
for dtype in jtu.dtypes.all_inexact:
for shape in [
(5, 5), # square
(3, 5, 5), # batched
(3, 5), # non-square
]:
define(
lax.linalg.lu_p,
f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
lax.linalg.lu, [RandArg(shape, dtype)],
jax_unimplemented=[
Limitation("unimplemented", dtypes=[np.float16, dtypes.bfloat16])
],
shape=shape,
dtype=dtype)
def _make_triangular_solve_harness(name,
*,
left_side=True,
lower=False,
ab_shapes=((4, 4), (4, 1)),
dtype=np.float32,
transpose_a=False,
conjugate_a=False,
unit_diagonal=False):
a_shape, b_shape = ab_shapes
f_lax = lambda a, b: (
lax.linalg.triangular_solve_p.bind(
a,
b,
left_side=left_side,
lower=lower,
transpose_a=transpose_a,
conjugate_a=conjugate_a,
unit_diagonal=unit_diagonal))
define(
lax.linalg.triangular_solve_p,
f"{name}_a={jtu.format_shape_dtype_string(a_shape, dtype)}_b={jtu.format_shape_dtype_string(b_shape, dtype)}_leftside={left_side}_{lower=}_transposea={transpose_a}_conjugatea={conjugate_a}_unitdiagonal={unit_diagonal}",
f_lax, [RandArg(a_shape, dtype),
RandArg(b_shape, dtype)],
jax_unimplemented=[
Limitation("unimplemented", devices="gpu", dtypes=[np.float16]),
],
dtype=dtype,
a_shape=a_shape,
b_shape=b_shape,
left_side=left_side,
lower=lower,
tranpose_a=transpose_a,
conjugate_a=conjugate_a,
unit_diagonal=unit_diagonal)
# Validate dtypes
# This first harness runs the tests for all dtypes using default values for
# all the other parameters, except unit_diagonal (to ensure that
# tf.linalg.set_diag works reliably for all dtypes). Variations of other
# parameters can thus safely skip testing their corresponding default value.
# Note that this validates solving on the left.
for dtype in jtu.dtypes.all_inexact:
for unit_diagonal in [False, True]:
_make_triangular_solve_harness(
"dtypes", dtype=dtype, unit_diagonal=unit_diagonal)
# Validate shapes when solving on the right
for ab_shapes in [
((4, 4), (1, 4)), # standard
((2, 8, 8), (2, 10, 8)), # batched
]:
_make_triangular_solve_harness(
"shapes_right", ab_shapes=ab_shapes, left_side=False)
# Validate transformations of a complex matrix
for lower in [False, True]:
for transpose_a in [False, True]:
for conjugate_a in [False, True]:
_make_triangular_solve_harness(
"complex_transformations",
dtype=np.complex64,
lower=lower,
transpose_a=transpose_a,
conjugate_a=conjugate_a)
# Validate transformations of a real matrix
for lower in [False, True]:
for transpose_a in [False, True]:
# conjugate_a is irrelevant for real dtypes, and is thus omitted
_make_triangular_solve_harness(
"real_transformations",
dtype=np.float32,
lower=lower,
transpose_a=transpose_a)
def _make_linear_solve_harnesses():
def linear_solve(a, b, solve, transpose_solve=None, symmetric=False):
matvec = partial(lax.dot, a, precision=lax.Precision.HIGHEST)
return lax.custom_linear_solve(matvec, b, solve, transpose_solve, symmetric)
def explicit_jacobian_solve(matvec, b):
return lax.stop_gradient(jnp.linalg.solve(jax.jacobian(matvec)(b), b))
def _make_harness(name,
*,
shape=(4, 4),
dtype=np.float32,
symmetric=False,
solvers=(explicit_jacobian_solve, explicit_jacobian_solve)):
solve, transpose_solve = solvers
transpose_solve_name = transpose_solve.__name__ if transpose_solve else None
def _make_first_argument(rng):
a = jtu.rand_default(rng)(shape, dtype)
if symmetric:
a = a + a.T
return a
define(
lax.linear_solve_p,
f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_b={jtu.format_shape_dtype_string(shape[:-1], dtype)}_solve={solve.__name__}_transposesolve={transpose_solve_name}_{symmetric=}",
linear_solve, [
CustomArg(_make_first_argument),
RandArg(shape[:-1], dtype),
StaticArg(solve),
StaticArg(transpose_solve),
StaticArg(symmetric)
],
shape=shape,
dtype=dtype,
solve=solve,
transpose_solve=transpose_solve,
symmetric=symmetric)
for dtype in jtu.dtypes.all_floating:
if not dtype in [np.float16, dtypes.bfloat16]:
_make_harness("dtypes", dtype=dtype)
# Validate symmetricity
_make_harness("symmetric", symmetric=True)
# Validate removing transpose_solve
_make_harness("transpose_solve", solvers=(explicit_jacobian_solve, None))
_make_linear_solve_harnesses()
# tridiagonal_solve_p
for dtype in [np.float32, np.float64]:
define(
lax.linalg.tridiagonal_solve_p,
f"shape={jtu.format_shape_dtype_string((3,), dtype)}",
lax.linalg.tridiagonal_solve,
[ np.array([0.0, 2.0, 3.0], dtype=dtype),
np.ones(3, dtype=dtype),
np.array([1.0, 2.0, 0.0], dtype=dtype),
np.ones([3, 1], dtype=dtype)],
dtype=dtype)
def _make_slice_harness(name,
shape=(3,),
start_indices=(1,),
limit_indices=(2,),
strides=None,
dtype=np.float32):
define(
lax.slice_p,
f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_{start_indices=}_{limit_indices=}_{strides=}",
# type: ignore
lax.slice,
[
RandArg(shape, dtype), # type: ignore
StaticArg(start_indices), # type: ignore
StaticArg(limit_indices), # type: ignore
StaticArg(strides)
], # type: ignore
dtype=dtype,
shape=shape, # type: ignore
start_indices=start_indices, # type: ignore
limit_indices=limit_indices) # type: ignore
# Test first all dtypes
for dtype in jtu.dtypes.all:
_make_slice_harness("dtypes", dtype=dtype)
# Now test many shapes
for shape, start_indices, limit_indices, strides in [
((3,), (1,), (2,), None),
((7,), (4,), (7,), None),
((5,), (1,), (5,), (2,)),
((8,), (1,), (6,), (2,)),
((5, 3), (1, 1), (3, 2), None),
((5, 3), (1, 1), (3, 1), None),
((7, 5, 3), (4, 0, 1), (7, 1, 3), None),
((5, 3), (1, 1), (2, 1), (1, 1)),
((5, 3), (1, 1), (5, 3), (2, 1)),
]:
_make_slice_harness(
"shapes",
shape=shape,
start_indices=start_indices,
limit_indices=limit_indices,
strides=strides)
def _make_complex_harness(name, *, shapes=((3, 4), (3, 4)), dtype=np.float32):
define(
lax.complex_p,
f"{name}_lhs={jtu.format_shape_dtype_string(shapes[0], dtype)}_rhs={jtu.format_shape_dtype_string(shapes[1], dtype)}",
lax.complex_p.bind,
[RandArg(shapes[0], dtype),
RandArg(shapes[1], dtype)],
shapes=shapes,
dtype=dtype)
for dtype in jtu.dtypes.floating:
_make_complex_harness("dtypes", dtype=dtype)
# Validate broadcasting
for shapes in [
((3, 2), (3, 1)), # broadcast imaginary part
((3, 1), (3, 2)), # broadcast real part
]:
_make_complex_harness("broadcast", shapes=shapes)
def _make_conj_harness(name, *, shape=(3, 4), dtype=np.float32, **kwargs):
define(
lax.conj_p,
f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}_{kwargs=}"
.replace(" ", ""),
lambda x: lax.conj_p.bind(x, **kwargs), [RandArg(shape, dtype)],
shape=shape,
dtype=dtype,
**kwargs)
for dtype in jtu.dtypes.floating + jtu.dtypes.complex:
_make_conj_harness("dtypes", dtype=dtype)
# Validate kwargs
_make_conj_harness("kwargs", _input_dtype=np.float32)
def _make_real_imag_harness(prim, name, *, shape=(2, 3), dtype=np.float32):
define(
prim,
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}",
prim.bind, [RandArg(shape, dtype)],
shape=shape,
dtype=dtype,
prim=prim)
for dtype in jtu.dtypes.complex:
for prim in [lax.real_p, lax.imag_p]:
_make_real_imag_harness(prim, "dtypes", dtype=dtype)
def _make_dynamic_slice_harness(name,
shape=(3,),
start_indices=(1,),
limit_indices=(2,),
dtype=np.float32):
for enable_xla in [False, True]:
define(
lax.dynamic_slice_p,
f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_{start_indices=}_{limit_indices=}_enablexla={enable_xla}",
# type: ignore
lax.dynamic_slice,
[
RandArg(shape, dtype), # type: ignore
np.array(list(start_indices)),
StaticArg(tuple(map(operator.sub, limit_indices, start_indices)))
], # type: ignore
dtype=dtype,
shape=shape, # type: ignore
start_indices=start_indices, # type: ignore
limit_indices=limit_indices, # type: ignore
enable_xla=enable_xla)
# Test first all dtypes
for dtype in jtu.dtypes.all:
_make_dynamic_slice_harness("dtypes", dtype=dtype)
# Now test many shapes
for shape, start_indices, limit_indices in [
((3,), (1,), (2,)),
((7,), (4,), (7,)),
((5,), (1,), (5,)),
((8,), (1,), (6,)),
((5, 3), (1, 1), (3, 2)),
((7, 5, 3), (4, 0, 1), (7, 1, 3)),
((5, 3), (1, 1), (2, 1)),
((3,), (np.array(1, np.uint32),), (np.array(2,
np.uint32),)), # uint32 indices
((3,), (np.array(1, np.uint8),), (np.array(2, np.uint8),)), # uint8 indices
# out-of-bounds cases, allowed for dynamic_slice
((5,), (-1,), (0,)),
((5,), (-1,), (1,)),
((5,), (-4,), (-2,)),
((5,), (-5,), (-2,)),
((5,), (-6,), (-5,)),
((5,), (-10,), (-9,)),
((5,), (-100,), (-99,)),
((5,), (5,), (6,)),
((5,), (10,), (11,)),
((5,), (3,), (6,))
]:
_make_dynamic_slice_harness(
"shapes",
shape=shape,
start_indices=start_indices,
limit_indices=limit_indices)
def _make_dynamic_update_slice_harness(name,
shape=(3,),
start_indices=(1,),
dtype=np.float32,
update_shape=(1,)):
for enable_xla in [False, True]:
define(
lax.dynamic_update_slice_p,
(
f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}" # type: ignore
f"_update={jtu.format_shape_dtype_string(update_shape, dtype)}"
f"_{start_indices=}_{enable_xla=}"),
lax.dynamic_update_slice,
[
RandArg(shape, dtype), # type: ignore
RandArg(update_shape, dtype), # type: ignore
np.array(start_indices)
], # type: ignore
dtype=dtype,
shape=shape, # type: ignore
start_indices=start_indices, # type: ignore
update_shape=update_shape, # type: ignore
enable_xla=enable_xla)
# Test first all dtypes
for dtype in jtu.dtypes.all:
_make_dynamic_update_slice_harness("dtypes", dtype=dtype)
# Now test many shapes
for shape, start_indices, update_shape in [
((3,), (1,), (1,)),
((5, 3), (1, 1), (3, 1)),
((7, 5, 3), (4, 1, 0), (2, 0, 1)),
((3,), (np.array(1, np.uint32),), (1,)), # uint32 indices
((3,), (np.array(1, np.uint8),), (1,)), # uint8 indices
((3,), (-1,), (1,)), # out-of-bounds
((3,), (10,), (1,)), # out-of-bounds
((3,), (10,), (2,)), # out-of-bounds
]:
_make_dynamic_update_slice_harness(
"shapes",
shape=shape,
start_indices=start_indices,
update_shape=update_shape)
def _make_squeeze_harness(name,
shape=(1, 2),
dimensions=(0,),
dtype=np.float32):
define(
lax.squeeze_p,
f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}_{dimensions=}", # type: ignore
lax.squeeze,
[RandArg(shape, dtype), StaticArg(dimensions)], # type: ignore[has-type]
dtype=dtype,
arg_shape=shape,
dimensions=dimensions) # type: ignore[has-type]
# Test first all dtypes
for dtype in set(jtu.dtypes.all):
_make_squeeze_harness("dtypes", dtype=dtype)
# Now test many shapes
for shape, dimensions in [
((1,), (0,)),
((1,), (-1,)),
((2, 1, 4), (1,)),
((2, 1, 4), (-2,)),
((2, 1, 3, 1), (1,)),
((2, 1, 3, 1), (1, 3)),
((2, 1, 3, 1), (3,)),
((2, 1, 3, 1), (1, -1)),
]:
_make_squeeze_harness("shapes", shape=shape, dimensions=dimensions)
def _make_select_and_scatter_add_harness(name,
*,
shape=(2, 4, 6),
dtype=np.float32,
select_prim=lax.ge_p,
window_dimensions=(2, 2, 2),
window_strides=(1, 1, 1),
padding=((0, 0), (0, 0), (0, 0)),
nb_inactive_dims=0):
ones = (1,) * len(shape)
cotangent_shape = jax.eval_shape(
lambda x: lax_windowed_reductions._select_and_gather_add(
x, x, lax.ge_p, window_dimensions, window_strides, padding,
ones, ones),
np.ones(shape, dtype)).shape
define(
lax.select_and_scatter_add_p,
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_selectprim={select_prim}_windowdimensions={window_dimensions}_windowstrides={window_strides}_{padding=!s}",
lax_windowed_reductions._select_and_scatter_add, [
RandArg(cotangent_shape, dtype),
RandArg(shape, dtype),
StaticArg(select_prim),
StaticArg(window_dimensions),
StaticArg(window_strides),
StaticArg(padding)
],
jax_unimplemented=[
Limitation(
"works only for 2 or more inactive dimensions",
devices="tpu",
enabled=(nb_inactive_dims < 2))
],
shape=shape,
dtype=dtype,
select_prim=select_prim,
window_dimensions=window_dimensions,
window_strides=window_strides,
padding=padding)
for dtype in set(jtu.dtypes.all) - {np.complex64, np.complex128}:
_make_select_and_scatter_add_harness("dtypes", dtype=dtype)
# Validate different reduction primitives
_make_select_and_scatter_add_harness("select_prim", select_prim=lax.le_p)
# Validate padding
for padding in [
# TODO(bchetioui): commented out the test based on
# https://github.com/google/jax/issues/4690
# ((1, 2), (2, 3), (3, 4)) # non-zero padding
((1, 1), (1, 1), (1, 1)) # non-zero padding
]:
_make_select_and_scatter_add_harness("padding", padding=padding)
# Validate window_dimensions; uneven dimensions
_make_select_and_scatter_add_harness(
"window_dimensions", window_dimensions=(1, 2, 3))
# Validate window_strides
# smaller than/same as/bigger than corresponding window dimension
_make_select_and_scatter_add_harness("window_strides", window_strides=(1, 2, 3))
# Validate dtypes on TPU
for dtype in set(jtu.dtypes.all) - {
np.bool_, np.complex64, np.complex128, np.int8, np.uint8}:
for window_strides, window_dimensions, nb_inactive_dims in [((1, 2, 1),
(1, 3, 1), 2)]:
_make_select_and_scatter_add_harness(
"tpu_dtypes",
dtype=dtype,
nb_inactive_dims=nb_inactive_dims,
window_strides=window_strides,
window_dimensions=window_dimensions)
def _make_select_and_gather_add_harness(name,
*,
shape=(4, 6),
dtype=np.float32,
select_prim=lax.le_p,
padding="VALID",
window_dimensions=(2, 2),
window_strides=(1, 1),
base_dilation=(1, 1),
window_dilation=(1, 1)):
if isinstance(padding, str):
padding = tuple(
lax.padtype_to_pads(shape, window_dimensions, window_strides, padding))
define(
lax.select_and_gather_add_p,
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_selectprim={select_prim}_windowdimensions={window_dimensions}_windowstrides={window_strides}_{padding=!s}_basedilation={base_dilation}_windowdilation={window_dilation}",
lax_windowed_reductions._select_and_gather_add, [
RandArg(shape, dtype),
RandArg(shape, dtype),
StaticArg(select_prim),
StaticArg(window_dimensions),
StaticArg(window_strides),
StaticArg(padding),
StaticArg(base_dilation),
StaticArg(window_dilation)
],
shape=shape,
dtype=dtype,
window_dimensions=window_dimensions,
window_strides=window_strides,
padding=padding,
base_dilation=base_dilation,
window_dilation=window_dilation)
for dtype in jtu.dtypes.all_floating:
for select_prim in [lax.ge_p, lax.le_p]:
_make_select_and_gather_add_harness(
"dtypes", dtype=dtype, select_prim=select_prim)
# Validate selection primitives
_make_select_and_gather_add_harness("select_prim", select_prim=lax.ge_p)
# Validate window dimensions
_make_select_and_gather_add_harness(
"window_dimensions", window_dimensions=(2, 3))
# Validate window strides
_make_select_and_gather_add_harness("window_strides", window_strides=(2, 3))
# Validate padding
_make_select_and_gather_add_harness("padding", padding="SAME")
# Validate dilations
for base_dilation, window_dilation in [
((2, 3), (1, 1)), # base dilation, no window dilation
((1, 1), (2, 3)), # no base dilation, window dilation
((2, 3), (3, 2)) # base dilation, window dilation
]:
_make_select_and_gather_add_harness(
"dilations", base_dilation=base_dilation, window_dilation=window_dilation)
def _make_reduce_harness(name, *,
shape=(4, 6), # The shape of all operands
nr_operands=1, # How many operands
computation=lax.add, # Takes Tuple(op1, [op2,]) and Tuple(init_val1, [init_val2]). Returns Tuple(out_val1, [out_val2]).
dimensions: Sequence[int] = (0,),
init_value=0, # The init value for first operand
dtype=np.float32): # The dtype of first operand
def reducer(*args):
init_val = np.array(init_value, dtype=dtype)
init_values = [init_val]
if nr_operands == 2:
init_values.append(np.int32(0.))
return lax.reduce(args[0:nr_operands], tuple(init_values),
computation, dimensions)
define(
lax.reduce_p,
f"gen_{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_initvalue={init_value}_{nr_operands=}_{dimensions=}".replace(" ", ""),
reducer,
[
RandArg(shape, dtype),
# Second operand (optional, always i32). We cannot mix multiple float
# types in XLA.
RandArg(shape, np.int32),
],
shape=shape,
dtype=dtype,
init_value=init_value,
computation=computation,
dimensions=dimensions)
for dtype in jtu.dtypes.all:
for name, nr_operands, computation, init_value in [
("add_scalar", 1,
lambda ops, inits: (lax.add(ops[0], inits[0]),), 3),
# Compute the max (starting with 3) and the min (from 0), in parallel
("max_min", 2,
lambda ops, inits: (lax.max(ops[0], inits[0]),
lax.min(ops[1], inits[1])), 3),
]:
if not (dtype == np.bool_ and name == "add_scalar"):
_make_reduce_harness(name, nr_operands=nr_operands,
computation=computation, init_value=init_value,
dtype=dtype)
# Test the dimensions, but only for int32 (to keep the # of tests small)
if dtype == np.int32:
_make_reduce_harness(name, nr_operands=nr_operands,
computation=computation, init_value=init_value,
dimensions=(1,),
dtype=dtype)
_make_reduce_harness(name, nr_operands=nr_operands,
computation=computation, init_value=init_value,
dimensions=(0, 1),
dtype=dtype)
def _make_reduce_window_harness(name,
*,
shape=(4, 6),
base_dilation=(1, 1),
computation=lax.add,
window_dimensions=(2, 2),
window_dilation=(1, 1),
init_value=0,
window_strides=(1, 1),
dtype=np.float32,
padding=((0, 0), (0, 0)),
requires_xla=False):
prim_name = f"reduce_window_{computation.__name__}"
limitations = []
xla_opts = [True] if requires_xla else [True, False]
for enable_xla in xla_opts:
define(
prim_name,
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_initvalue={init_value}_windowdimensions={window_dimensions}_windowstrides={window_strides}_{padding=!s}_basedilation={base_dilation}_windowdilation={window_dilation}_enablexla={enable_xla}"
.replace(" ", ""),
lax.reduce_window,
[
RandArg(shape, dtype),
# Must be static to trigger the picking of the reducers
StaticArg(np.array(init_value, dtype=dtype)),
StaticArg(computation),
StaticArg(window_dimensions),
StaticArg(window_strides),
StaticArg(padding),
StaticArg(base_dilation),
StaticArg(window_dilation)
],
jax_unimplemented=limitations,
shape=shape,
dtype=dtype,
init_value=np.array(init_value, dtype=dtype),
computation=computation,
window_dimensions=window_dimensions,
window_strides=window_strides,
padding=padding,
base_dilation=base_dilation,
window_dilation=window_dilation,
enable_xla=enable_xla)
def requires_xla_for_reduce(name, dtype):
if name not in ["min", "max", "add"]:
return True
if name in ["min", "max"] and dtype in [
np.bool_, np.uint32, np.uint64, np.complex64, np.complex128
]:
return True
if name == "min" and dtype in [np.uint8, np.uint16]:
return True
if name == "add" and dtype not in [np.float16, np.float32, np.float64]:
return True
return False
# Validate dtypes across all execution paths
# This first harness runs the tests for all dtypes using default values for
# the other parameters (outside of computation and its init_value), through
# several execution paths. Variations of other parameters can thus safely
# skip testing their corresponding default value.
for dtype in jtu.dtypes.all:
for computation, init_value in [
(lax.min, _get_min_identity(dtype)), # path through reduce_window_min
(lax.max, _get_max_identity(dtype)), # path through TF reduce_window_max
(lax.max, 1), # path through reduce_window
(lax.add, 0), # path_through reduce_window_sum
(lax.add, 1), # path through reduce_window
(lax.mul, 0), # path through reduce_window
(lax.mul, 1), # path through reduce_window
(lax.mul, 2), # path through reduce_window
]:
if dtype == np.bool_ and computation in [lax.add, lax.mul]:
continue
_make_reduce_window_harness(
"dtypes",
dtype=dtype,
computation=computation,
init_value=init_value,
requires_xla=requires_xla_for_reduce(computation.__name__, dtype))
# Validate window_dimensions
_make_reduce_window_harness("window_dimensions", window_dimensions=(1, 1))
# Validate window_strides
_make_reduce_window_harness("window_strides", window_strides=(1, 2))
# Validate padding
_make_reduce_window_harness("padding", padding=((1, 2), (0, 3)),
requires_xla=True)
# Validate base_dilation
_make_reduce_window_harness("base_dilation", base_dilation=(1, 2),
requires_xla=True)
# Validate window_dilation
_make_reduce_window_harness("window_dilation", window_dilation=(1, 2))
# Validate batch and channel dimensions behavior. lax.reduce_window accepts
# inputs that either have or do not have batch and channel dimensions.
# N=batch, DHW=spatial, C=channel.
# Without XLA only supports 1D/2D reductions.
for shape, window_dimensions, requires_xla in [
((2,), (2,), False), # W
((2, 1), (2, 1), False), # WC
((1, 2), (1, 2), False), # NW
((1, 2, 1), (1, 2, 1), False), # NWC
((2, 4), (2, 2), False), # HW
((1, 2, 4, 1), (1, 2, 2, 1), False), # NHWC
((2, 4, 3), (2, 2, 2), True), # DHW
((1, 4, 3, 2, 1), (1, 2, 2, 2, 1), True) # NDHWC
]:
_make_reduce_window_harness(
"batch_channel_dims",
computation=lax.max,
shape=shape,
dtype=np.float32,
init_value=-np.inf,
base_dilation=tuple([1] * len(shape)),
window_dilation=tuple([1] * len(shape)),
padding=tuple([(0, 0)] * len(shape)),
window_strides=tuple([1] * len(shape)),
window_dimensions=window_dimensions,
requires_xla=requires_xla)
for computation, id_value in [(lax.max, _get_max_identity(np.float32)),
(lax.min, _get_min_identity(np.float32)),
(lax.add, 0.)]:
_make_reduce_window_harness(
"same_padding",
shape=(112, 112),
init_value=id_value,
computation=computation,
window_dimensions=(3, 3),
window_strides=(2, 2),
padding="SAME")
# A few additional test cases for manual padding, which is applied when calling
# reduce_window with lax.add, SAME padding and window_dimensions != (1, 1, ...).
for window_dimensions, window_strides in [((2, 2), (1, 1)), ((3, 3), (2, 2)),
((13, 13), (5, 6))]:
_make_reduce_window_harness(
"manual_padding",
shape=(12, 12),
init_value=0.,
computation=lax.add,
window_dimensions=window_dimensions,
window_strides=window_strides,
padding="SAME")
_make_reduce_window_harness(
"init_value_1d",
shape=(1, 1600),
init_value=1.0,
computation=lax.min,
window_dimensions=[1, 401],
window_strides=[1, 160],
padding="VALID",
requires_xla=False)
def _make_reducer_harness(prim,
name,
*,
shape=(2, 3),
axes=(0,),
dtype=np.int32):
define(
prim,
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}",
lambda arg: prim.bind(arg, axes=axes), [RandArg(shape, dtype)],
prim=prim,
shape=shape,
dtype=dtype,
axes=axes)
for prim in [
lax.reduce_sum_p, lax.reduce_prod_p, lax.reduce_max_p, lax.reduce_min_p,
lax.reduce_or_p, lax.reduce_and_p
]:
for dtype in {
lax.reduce_sum_p: set(jtu.dtypes.all) - set(jtu.dtypes.boolean),
lax.reduce_prod_p: set(jtu.dtypes.all) - set(jtu.dtypes.boolean),
lax.reduce_max_p: jtu.dtypes.all,
lax.reduce_min_p: jtu.dtypes.all,
lax.reduce_or_p: jtu.dtypes.boolean,
lax.reduce_and_p: jtu.dtypes.boolean
}[prim]:
_make_reducer_harness(prim, "dtypes", dtype=dtype)
for dtype in (np.float32, np.float64):
for shape in ((), (3,)):
define(
"random_gamma",
f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
jax.jit(jax_random.gamma),
[np.array([42, 43], dtype=np.uint32),
RandArg(shape, dtype)],
dtype=dtype)
for key_i, key in enumerate([
np.array([0, 0], dtype=np.uint32),
np.array([42, 43], dtype=np.uint32),
np.array([0xFFFFFFFF, 0], dtype=np.uint32),
np.array([0, 0xFFFFFFFF], dtype=np.uint32),
np.array([0xFFFFFFFF, 0xFFFFFFFF], dtype=np.uint32)
]):
if jax.config.jax_enable_custom_prng:
def wrap_and_split(key):
key = prng.random_wrap(key, impl=jax.random.default_prng_impl())
result = jax.random.split(key, 2)
return prng.random_unwrap(result)
else:
def wrap_and_split(key):
return jax.random.split(key, 2)
define(
"random_split",
f"i={key_i}",
jax.jit(wrap_and_split), [key],
dtype=key.dtype)
# A few library functions from jax.random
for dtype in jtu.dtypes.all_floating:
for shape in ((8,), (5, 4)):
for axis in range(len(shape)):
define(
"random_categorical",
f"shape={jtu.format_shape_dtype_string(shape, dtype)}_{axis=}",
jax.random.categorical,
[np.array([42, 43], dtype=np.uint32), RandArg(shape, dtype),
StaticArg(axis)],
dtype=dtype,
axis=axis)
for dtype in jtu.dtypes.all_floating:
for shape in ((), (5, 4), (32,)):
define(
"random_uniform",
f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
jax.random.uniform,
[np.array([42, 43], dtype=np.uint32),
StaticArg(shape), StaticArg(dtype)],
dtype=dtype)
for dtype in jtu.dtypes.all_integer:
for shape in ((), (5, 4), (32,)):
maxval = {
np.uint8: 256, # Borderline
}.get(dtype, 5)
define(
"random_randint",
f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
jax.random.randint,
[np.array([42, 43], dtype=np.uint32),
StaticArg(shape),
StaticArg(-5), # minval
StaticArg(maxval),
StaticArg(dtype)],
dtype=dtype)
def _make_clamp_harness(name,
*,
min_shape=(),
operand_shape=(2, 3),
max_shape=(),
dtype=np.float32,
min_max=None):
min_arr, max_arr = (
min_max if min_max is not None else
[RandArg(min_shape, dtype),
RandArg(max_shape, dtype)])
define(
lax.clamp_p,
f"{name}_min={jtu.format_shape_dtype_string(min_arr.shape, min_arr.dtype)}_operand={jtu.format_shape_dtype_string(operand_shape, dtype)}_max={jtu.format_shape_dtype_string(max_arr.shape, max_arr.dtype)}",
lax.clamp, [min_arr, RandArg(operand_shape, dtype), max_arr],
min_shape=min_arr.shape,
operand_shape=operand_shape,
max_shape=max_arr.shape,
dtype=dtype,
jax_unimplemented=[
Limitation(
"unimplemented",
dtypes=[np.bool_, np.complex64, np.complex128])],
)
for dtype in set(jtu.dtypes.all):
_make_clamp_harness("dtypes", dtype=dtype)
# Validate broadcasting of min/max arrays
for min_shape, operand_shape, max_shape in [
((), (2, 3), (2, 3)), # no broadcasting for max
((2, 3), (2, 3), ()), # no broadcasting for min
((2, 3), (2, 3), (2, 3)), # no broadcasting
]:
_make_clamp_harness(
"broadcasting",
min_shape=min_shape,
max_shape=max_shape,
operand_shape=operand_shape)
# Validate clamping when minval > maxval, and when minval < maxval
for is_ordered, min_arr, max_arr in [
(False, np.array(4., dtype=np.float32), np.array(1., dtype=np.float32)),
(True, np.array(1., dtype=np.float32), np.array(4., dtype=np.float32))
]:
_make_clamp_harness(
f"order={is_ordered}", min_max=(min_arr, max_arr), dtype=np.float32)
def _make_dot_general_harness(name,
*,
lhs_shape=(3, 4),
rhs_shape=(4, 2),
dtype=np.float32,
precision=None,
dimension_numbers=(((1,), (0,)), ((), ())),
preferred_element_type=None):
suffix = ""
if precision is not None:
suffix += f"_{precision=}"
if preferred_element_type is not None:
suffix += f"_preferred={jtu.dtype_str(preferred_element_type)}"
define(
lax.dot_general_p,
f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_dimensionnumbers={dimension_numbers}{suffix}"
.replace(" ", ""),
lax.dot_general,
[
RandArg(lhs_shape, dtype),
RandArg(rhs_shape, dtype),
StaticArg(dimension_numbers),
StaticArg(precision),
StaticArg(preferred_element_type)
],
dtype=dtype,
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
jax_unimplemented=[
Limitation("preferred_element_type must match dtype for floating point",
devices="gpu",
dtypes=[np.float16, dtypes.bfloat16, np.float32, np.float64, np.complex64, np.complex128],
enabled=(preferred_element_type is not None and preferred_element_type != dtype))
]
)
# There are two execution paths in the conversion of dot_general. The main path
# uses tf.einsum, while special cases use tf.linalg.matmul. For that reason,
# the below tests are designed to perform the same checks on both execution
# paths.
# Validate dtypes and precision
# This first harness runs the tests for all dtypes and precisions using
# default values for all the other parameters. Variations of other parameters
# can thus safely skip testing their corresponding default value.
for dtype in jtu.dtypes.all:
for precision in [
None, lax.Precision.DEFAULT, lax.Precision.HIGH, lax.Precision.HIGHEST
]:
for lhs_shape, rhs_shape, dimension_numbers in [
((3, 4), (4, 2), (((1,), (0,)), ((), ()))),
((1, 3, 4), (1, 4, 3), (((2, 1), (1, 2)), ((0,), (0,)))),
# Some batch dimensions
((7, 3, 4), (7, 4), (((2,), (1,)), ((0,), (0,)))),
]:
_make_dot_general_harness(
"dtypes_and_precision",
precision=precision,
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
dtype=dtype)
# The other tests are only for float32.
# Validate batch dimensions
for lhs_shape, rhs_shape, dimension_numbers in [
# Unique pattern that can go through tf.linalg.matmul
((4, 4, 3, 3, 4), (4, 4, 3, 4, 2), (((4,), (3,)), ((0, 1, 2), (0, 1, 2)))),
# Main path with out of order batch dimensions
((8, 4, 3, 3, 4), (4, 8, 3, 4, 2), (((4, 3), (3, 2)), ((0, 1), (1, 0))))
]:
_make_dot_general_harness(
"batch_dimensions",
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers)
# Validate squeezing behavior for matmul path
for lhs_shape, rhs_shape, dimension_numbers in [
((4,), (4, 4), (((0,), (0,)), ((), ()))), # (1, 4) -> (4,)
((4, 4), (4,), (((1,), (0,)), ((), ()))), # (4, 1) -> (4,)
((4,), (4,), (((0,), (0,)), ((), ()))), # (1, 1) -> ()
]:
_make_dot_general_harness(
"squeeze",
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers)
# Validate preferred element type
# From lax_test.py
preferred_type_combinations = [(np.float16, np.float16), (np.float16,
np.float32),
(np.float16, np.float64),
(dtypes.bfloat16, np.float32),
(dtypes.bfloat16, np.float64),
(np.float32, np.float32),
(np.float32, np.float64), (np.int8, np.int16),
(np.int8, np.int32), (np.int8, np.int64),
(np.int16, np.int32), (np.int16, np.int64),
(np.int32, np.int32), (np.int32, np.int64),
(np.complex64, np.complex128)]
for lhs_shape in [(3,), (4, 3)]:
for rhs_shape in [(3,), (3, 6)]:
for dtype, preferred_element_type in preferred_type_combinations:
_make_dot_general_harness(
"preferred",
dtype=dtype,
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
dimension_numbers=(((len(lhs_shape) - 1,), (0,)), ((), ())),
preferred_element_type=preferred_element_type)
def _make_concatenate_harness(name,
*,
shapes=[(2, 3), (2, 3)],
dimension=0,
dtype=np.float32):
shapes_str = "_".join(jtu.format_shape_dtype_string(s, dtype) for s in shapes)
define(
lax.concatenate_p,
f"{name}_shapes={shapes_str}_{dimension=}",
lambda *args: lax.concatenate_p.bind(*args, dimension=dimension),
[RandArg(shape, dtype) for shape in shapes],
shapes=shapes,
dtype=dtype,
dimension=dimension)
for dtype in jtu.dtypes.all:
_make_concatenate_harness("dtypes", dtype=dtype)
# Validate dimension; non-major axis
_make_concatenate_harness("dimension", dimension=1)
# Validate > 2 operands
for shapes in [
[(2, 3, 4), (3, 3, 4), (4, 3, 4)], # 3 operands
]:
_make_concatenate_harness("nb_operands", shapes=shapes)
def _make_conv_harness(name,
*,
lhs_shape=(2, 3, 9, 10),
rhs_shape=(3, 3, 4, 5),
dtype=np.float32,
window_strides=(1, 1),
precision=None,
padding=((0, 0), (0, 0)),
lhs_dilation=(1, 1),
rhs_dilation=(1, 1),
feature_group_count=1,
dimension_numbers=("NCHW", "OIHW", "NCHW"),
batch_group_count=1,
preferred_element_type=None,
works_without_xla=False):
enable_xla_cases = [True, False] if works_without_xla else [True]
for enable_xla in enable_xla_cases:
define(
lax.conv_general_dilated_p,
f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_windowstrides={window_strides}_{padding=!s}_lhsdilation={lhs_dilation}_rhsdilation={rhs_dilation}_dimensionnumbers={dimension_numbers}_featuregroupcount={feature_group_count}_batchgroupcount={batch_group_count}_{precision=}_preferred={jtu.dtype_str(preferred_element_type)}_enablexla={enable_xla}"
.replace(" ", ""),
lax.conv_general_dilated,
[
RandArg(lhs_shape, dtype),
RandArg(rhs_shape, dtype),
StaticArg(window_strides),
StaticArg(padding),
StaticArg(lhs_dilation),
StaticArg(rhs_dilation),
StaticArg(dimension_numbers),
StaticArg(feature_group_count),
StaticArg(batch_group_count),
StaticArg(precision),
StaticArg(preferred_element_type),
],
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
dtype=dtype,
window_strides=window_strides,
padding=padding,
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
dimension_numbers=dimension_numbers,
feature_group_count=feature_group_count,
batch_group_count=batch_group_count,
precision=precision,
preferred_element_type=preferred_element_type,
enable_xla=enable_xla,
jax_unimplemented=[
# b/183565702 - no integer convolutions for GPU
Limitation(
"preferred_element_type not implemented for integers",
devices="gpu",
dtypes=(np.int8, np.int16, np.int32),
enabled=(preferred_element_type in [np.int16, np.int32,
np.int64])),
],
)
# Validate dtypes and precision
for dtype in jtu.dtypes.all_inexact:
for precision in [
None, lax.Precision.DEFAULT, lax.Precision.HIGH, lax.Precision.HIGHEST
]:
# This first harness runs the tests for all dtypes and precisions using
# default values for all the other parameters. Variations of other parameters
# can thus safely skip testing their corresponding default value.
_make_conv_harness(
"dtype_precision",
dtype=dtype,
precision=precision)
# Validate preferred_element_type
for dtype, preferred_element_type in preferred_type_combinations:
works_without_xla = dtype == np.float32 and preferred_element_type == np.float32
_make_conv_harness(
"preferred", dtype=dtype,
preferred_element_type=preferred_element_type,
works_without_xla=works_without_xla)
# Validate variations of feature_group_count and batch_group_count
for batch_group_count, feature_group_count in [
(1, 2), # feature_group_count != 1
(2, 1), # batch_group_count != 1
]:
for lhs_shape, rhs_shape in [
((2 * batch_group_count, 3 * feature_group_count, 9, 10),
(3 * feature_group_count * batch_group_count, 3, 4, 5))
]:
_make_conv_harness(
"group_counts",
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
feature_group_count=feature_group_count,
batch_group_count=batch_group_count)
# --- BEGIN Tests for conv_general_dilated with works_without_xla=True ---
# Validate Conv1D.
_make_conv_harness(
"conv1d",
lhs_shape=(2, 3, 10),
rhs_shape=(3, 3, 5),
window_strides=(1,),
padding=((0, 0),),
lhs_dilation=(1,),
rhs_dilation=(1,),
dimension_numbers=("NCH", "OIH", "NCH"),
works_without_xla=True)
# feature_group_count is supported for enable_xla=False only if we are doing a
# depthwise convolution, i.e.: in_channels == feature_group_count.
# See explanation of depthwise convolution at
# https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
_make_conv_harness(
"depthwise2d",
lhs_shape=(2, 3, 9, 9), # "NCHW": in_channels == 3
rhs_shape=(12, 1, 3, 3), # "OIHW": channel_multiplier = 12/3 = 4
feature_group_count=3,
works_without_xla=True)
_make_conv_harness(
"depthwise2d_dilated",
lhs_shape=(2, 3, 9, 9), # "NCHW": in_channels == 3
rhs_shape=(12, 1, 3, 3), # "OIHW": channel_multiplier = 12/3 = 4
feature_group_count=3,
lhs_dilation=(1, 1),
rhs_dilation=(2, 1),
works_without_xla=True)
_make_conv_harness(
"depthwise1d",
lhs_shape=(2, 3, 9), # "NCH": in_channels == 3
rhs_shape=(12, 1, 3), # "OIH": channel_multiplier = 12/3 = 4
feature_group_count=3,
lhs_dilation=(1,),
rhs_dilation=(1,),
window_strides=(1, ),
padding=((0, 0),),
dimension_numbers=("NCH", "OIH", "NCH"),
works_without_xla=True)
_make_conv_harness(
"depthwise1d_dilated",
lhs_shape=(2, 3, 9), # "NCH": in_channels == 3
rhs_shape=(12, 1, 3), # "OIH": channel_multiplier = 12/3 = 4
feature_group_count=3,
lhs_dilation=(1,),
rhs_dilation=(2,),
window_strides=(1,),
padding=((0, 0),),
dimension_numbers=("NCH", "OIH", "NCH"),
works_without_xla=True)
# Validate variations of window_strides
for window_strides in [(2, 3)]:
_make_conv_harness(
"window_strides",
window_strides=window_strides,
works_without_xla=True)
# Validate variations of padding
for padding in [
((1, 2), (0, 0)), # padding only one spatial axis
((1, 2), (2, 1)) # padding on both spatial axes
]:
_make_conv_harness("padding", padding=padding, works_without_xla=True)
# Validate variations of dilations
for lhs_dilation, rhs_dilation in [
((1, 1), (2, 2)), # dilation only on RHS (atrous)
]:
_make_conv_harness(
"dilations",
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
works_without_xla=True)
# Simulate a call from lax.conv_transpose.
_make_conv_harness(
"conv_tranpose2d_valid_padding",
lhs_shape=(1, 16, 16, 2),
rhs_shape=(2, 3, 2, 2),
window_strides=(1, 1),
lhs_dilation=(2, 2),
padding=((1, 1), (2, 2)),
dimension_numbers=("NHWC", "HWIO", "NHWC"),
works_without_xla=True)
# Simulate a call from lax.conv_transpose.
_make_conv_harness(
"conv_tranpose1d_valid_padding",
lhs_shape=(1, 16, 2),
rhs_shape=(3, 2, 2),
window_strides=(1,),
lhs_dilation=(2,),
rhs_dilation=(1,),
padding=((2, 2), ),
dimension_numbers=("NHC", "HIO", "NHC"),
works_without_xla=True)
_make_conv_harness(
"conv_tranpose1d_same_padding",
lhs_shape=(1, 16, 2),
rhs_shape=(3, 2, 2),
window_strides=(1,),
lhs_dilation=(2,),
rhs_dilation=(1,),
padding=((2, 1), ),
dimension_numbers=("NHC", "HIO", "NHC"),
works_without_xla=True)
_make_conv_harness(
"conv_tranpose2d_same_padding",
lhs_shape=(1, 16, 16, 2),
rhs_shape=(2, 3, 2, 2),
window_strides=(1, 1),
lhs_dilation=(2, 2),
padding=((1, 1), (2, 1)),
dimension_numbers=("NHWC", "HWIO", "NHWC"),
works_without_xla=True)
# Validate rhs > lhs.
# One dimension of rhs is bigger than lhs.
_make_conv_harness(
"rhs_oob",
lhs_shape=(2, 3, 9, 10),
rhs_shape=(3, 3, 10, 5),
works_without_xla=True)
# Effective rhs size is too big after applying rhs_dilation.
_make_conv_harness(
"rhs_oob_after_dilation",
lhs_shape=(2, 3, 9, 10),
rhs_shape=(3, 3, 4, 5),
rhs_dilation=(2, 3),
works_without_xla=True)
# Effective rhs is too big after applying input padding.
_make_conv_harness(
"rhs_oob_after_pading",
lhs_shape=(1, 3, 2, 2),
rhs_shape=(64, 3, 7, 7),
window_strides=(2, 2),
padding=((3, 3), (3, 3)),
works_without_xla=True)
# rhs out of bounds with "SAME" padding.
_make_conv_harness(
"rhs_oob_same_padding",
lhs_shape=(1, 1, 16, 1),
padding="SAME",
rhs_shape=(4, 1, 1, 2),
dimension_numbers=("NHWC", "HWIO", "NHWC"), # TF default
works_without_xla=True)
# Dimension numbers and corresponding permutation
for dimension_numbers, lhs_shape, rhs_shape in [
(("NHWC", "HWIO", "NHWC"), (2, 9, 10, 3), (4, 5, 3, 3)), # TF default
(("NCHW", "HWIO", "NHWC"), (2, 3, 9, 10), (4, 5, 3, 3)), # custom
]:
_make_conv_harness(
"dimension_numbers",
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
works_without_xla=True)
for padding, lhs_dilation, rhs_dilation in [
("VALID", (1, 1), (1, 1)), # no dilation with "VALID" padding
("SAME", (1, 1), (1, 1)), # no dilation with "SAME" padding
("VALID", (1, 1), (1, 2)), # dilation only on RHS with "VALID" padding
("SAME", (1, 1), (1, 2)), # dilation only on RHS with "SAME" padding
([(1, 2), (0, 1)], (1, 1), (1, 2))
]:
for dimension_numbers, lhs_shape, rhs_shape in [
(("NHWC", "HWIO", "NHWC"), (1, 28, 28, 1), (3, 3, 1, 16)), # TF default
(("NCHW", "HWIO", "NCHW"), (1, 1, 28, 28), (3, 3, 1, 16)),
]:
_make_conv_harness(
"tf_conversion_path_2d",
lhs_shape=lhs_shape,
padding=padding,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
window_strides=(1, 1),
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
works_without_xla=True)
for padding, lhs_dilation, rhs_dilation in [
("VALID", (1,), (1,)), # no dilation with "VALID" padding
("SAME", (1,), (1,)), # no dilation with "SAME" padding
("VALID", (1,), (2,)), # dilation only on RHS with "VALID" padding
("SAME", (1,), (2,)), # dilation only on RHS with "SAME" padding
]:
for dimension_numbers, lhs_shape, rhs_shape in [
(("NWC", "WIO", "NWC"), (1, 28, 1), (3, 1, 16)), # TF default
]:
_make_conv_harness(
"tf_conversion_path_1d",
lhs_shape=lhs_shape,
padding=padding,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
window_strides=(1,),
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
works_without_xla=True)
# --- END Tests for conv_general_dilated with works_without_xla=True ---
for lhs_dilation, rhs_dilation in [
# Note: LHS dilation does work for enable_xla=False, but only if
# padding=='VALID' (see test above for conv_transpose2d_valid_padding).
((2, 2), (1, 1)), # dilation only on LHS (transposed)
((2, 3), (3, 2)) # dilation on both LHS and RHS (transposed & atrous)
]:
_make_conv_harness(
"dilations", lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation)
for padding, lhs_dilation, rhs_dilation in [
("VALID", (1, 1, 1), (1, 1, 1)), # no dilation with "VALID" padding
("SAME", (1, 1, 1), (1, 1, 1)), # no dilation with "SAME" padding
("VALID", (1, 1, 1), (1, 1,
2)), # dilation only on RHS with "VALID" padding
("SAME", (1, 1, 1), (1, 1, 2)), # dilation only on RHS with "SAME" padding
]:
for dimension_numbers, lhs_shape, rhs_shape in [
# TF default
(("NDHWC", "DHWIO", "NDHWC"), (1, 4, 28, 28, 1), (2, 3, 3, 1, 16)),
]:
_make_conv_harness(
"tf_conversion_path_3d",
lhs_shape=lhs_shape,
padding=padding,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
window_strides=(1, 1, 1),
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation)
key_types = [((4,), np.uint32)]
if config.jax_enable_x64:
key_types.append(((2,), np.uint64))
for algorithm in [lax.RandomAlgorithm.RNG_THREE_FRY,
lax.RandomAlgorithm.RNG_PHILOX,
lax.RandomAlgorithm.RNG_DEFAULT]:
for dtype in [np.uint32, np.uint64]:
for shape in [(), (5, 7), (100, 100)]:
for key_shape, key_dtype in key_types:
define(
lax.rng_bit_generator_p,
f"{key_dtype=}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{algorithm=}",
lambda key, shape, dtype, algorithm: lax.rng_bit_generator(key, shape, dtype=dtype,
algorithm=algorithm),
[RandArg(key_shape, key_dtype),
StaticArg(shape), StaticArg(dtype), StaticArg(algorithm)],
shape=shape,
dtype=dtype,
algorithm=algorithm)
def _make_iota_2x32_shape_harness(shape):
shapestr = ','.join(str(dim) for dim in shape)
define(
prng.iota_2x32_shape_p,
f"shape=({shapestr})",
lambda shape: prng.iota_2x32_shape_p.bind(shape=shape),
[StaticArg(shape)],
dtype=jnp.uint32,
shape=shape)
for shape in [(3,), (5, 7, 4), (100, 100)]:
_make_iota_2x32_shape_harness(shape)
for in_dtype in jtu.dtypes.all_floating:
for out_dtype in jtu.dtypes.all_floating:
out_iinfo = dtypes.finfo(out_dtype)
for shape in [(), (5, 7)]:
define(
lax.reduce_precision_p,
f"in={jtu.format_shape_dtype_string(shape, in_dtype)}_out={jtu.format_shape_dtype_string(shape, out_dtype)}",
lambda x, exp_bits, mant_bits: lax.reduce_precision(x,
exponent_bits=exp_bits,
mantissa_bits=mant_bits),
[RandArg(shape, in_dtype),
StaticArg(out_iinfo.nexp), StaticArg(out_iinfo.nmant)],
shape=shape,
dtype=in_dtype,
out_dtype=out_dtype)
# approx_top_k
for is_max in [True, False]:
for dtype in jtu.dtypes.all_floating:
# There are different lowerings for sizes < 1024 for rank-1 and 128 for higher
# rank.
for shape in [(32,), (2048,), (16, 256)]:
define(
lax.approx_top_k_p,
f"large={np.prod(shape) >= 1024}_max={is_max}_op={jtu.format_shape_dtype_string(shape, dtype)}",
lambda operand, is_max: lax.approx_top_k_p.bind(
operand, k=4, reduction_dimension=-1, recall_target=0.95,
is_max_k=is_max,
reduction_input_size_override=-1,
aggregate_to_topk=True),
[RandArg(shape, dtype), StaticArg(is_max)],
dtype=dtype,
is_max=is_max,
shape=shape)