# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import abc
from functools import partial, reduce
import math
import operator as op
from typing import (Any, Callable, Hashable, Iterator, List, NamedTuple,
                    Set, Sequence, Tuple, Union)

import numpy as np

from jax import lax
from jax import numpy as jnp
from jax import tree_util

from jax._src import ad_util
from jax._src import api
from jax._src import basearray
from jax._src import config as config_lib
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import pretty_printer as pp
from jax._src import sharding_specs
from jax._src import typing
from jax._src.api import jit, vmap
from jax._src.config import config
from jax._src.dtypes import float0
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.lax import lax as lax_internal
from jax._src.lax import utils as lax_utils
from jax._src.lib.mlir import ir
from jax._src.lib import gpu_prng
from jax._src.lib import version as jaxlib_version
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.array_methods import (
    _array_operators, _set_array_base_attributes, _IndexUpdateHelper)
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding_impls import (
    NamedSharding, PmapSharding, GSPMDSharding, XLACompatibleSharding)
from jax._src.typing import Array
from jax._src.util import safe_map, safe_zip

# Shadow the builtins with the `safe_*` variants from jax._src.util; the
# builtins stay reachable as `unsafe_map` / `unsafe_zip`.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

Device = xc.Device
Shard = Any  # TODO(jakevdp): fix circular imports and import Shard

# Maps a bit width to the unsigned integer dtype of that width.
UINT_DTYPES = {
    8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64}  # type: ignore[has-type]

# -- PRNG implementation interface

class PRNGImpl(NamedTuple):
  """Specifies PRNG key shape and operations.

  A PRNG implementation is determined by a key type ``K`` and a
  collection of functions that operate on such keys. The key type
  ``K`` is an array type with element type uint32 and shape specified
  by ``key_shape``. The type signature of each operation is::

    seed :: int[] -> K
    fold_in :: K -> int[] -> K
    split[n] :: K -> K[n]
    random_bits[shape, bit_width] :: K -> uint[shape]

  A PRNG implementation is adapted to an array-like object of keys
  ``K`` by the ``PRNGKeyArray`` class, which should be created via the
  ``seed_with_impl`` function.
  """
  key_shape: core.Shape
  seed: Callable
  split: Callable
  random_bits: Callable
  fold_in: Callable
  tag: str = '?'

  def __hash__(self) -> int:
    # Hash on the tag only (the other fields are callables / a shape).
    return hash(self.tag)

  def __str__(self) -> str:
    return self.tag

  def pprint(self):
    """Return a pretty-printer document listing every field of this impl."""
    return (pp.text(f"{self.__class__.__name__} [{self.tag}]:") +
            pp.nest(2, pp.group(pp.brk() + pp.join(pp.brk(), [
              pp.text(f"{k} = {v}") for k, v in self._asdict().items()
            ]))))


# -- PRNG key arrays

def _check_prng_key_data(impl, key_data: typing.Array):
  """Validate that ``key_data`` can serve as raw key data for ``impl``.

  Raises:
    TypeError: if ``key_data`` lacks ``ndim``/``shape``/``dtype`` attributes,
      has ``ndim < 1``, has a shape that does not end with
      ``impl.key_shape``, or has a dtype other than uint32 or float0.
  """
  ndim = len(impl.key_shape)
  if not all(hasattr(key_data, attr) for attr in ['ndim', 'shape', 'dtype']):
    raise TypeError("JAX encountered invalid PRNG key data: expected key_data "
                    f"to have ndim, shape, and dtype attributes. Got {key_data}")
  if key_data.ndim < 1:
    raise TypeError("JAX encountered invalid PRNG key data: expected "
                    f"key_data.ndim >= 1; got ndim={key_data.ndim}")
  if key_data.shape[-ndim:] != impl.key_shape:
    raise TypeError("JAX encountered invalid PRNG key data: expected key_data.shape to "
                    f"end with {impl.key_shape}; got shape={key_data.shape} for {impl=}")
  if key_data.dtype not in [np.uint32, float0]:
    raise TypeError("JAX encountered invalid PRNG key data: expected key_data.dtype = uint32; "
                    f"got dtype={key_data.dtype}")


class PRNGKeyArrayMeta(abc.ABCMeta):
  """Metaclass for overriding PRNGKeyArray isinstance checks."""

  def __instancecheck__(cls, instance):
    # Anything exposing an `aval` that is a ShapedArray with a KeyTy dtype
    # counts as a key array; objects without `.aval` fall back to the
    # regular ABC instance check.
    try:
      return (isinstance(instance.aval, core.ShapedArray) and
              type(instance.aval.dtype) is KeyTy)
    except AttributeError:
      return super().__instancecheck__(instance)


class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):
  """An array whose elements are PRNG keys"""

  @abc.abstractmethod  # TODO(frostig): rename
  def unsafe_raw_array(self) -> PRNGKeyArray: ...

  @property
  @abc.abstractmethod
  def unsafe_buffer_pointer(self) -> int: ...

  @abc.abstractmethod
  def block_until_ready(self) -> PRNGKeyArray: ...

  @abc.abstractmethod
  def copy_to_host_async(self) -> None: ...

  @property
  @abc.abstractmethod
  def shape(self) -> Tuple[int, ...]: ...

  @property
  @abc.abstractmethod
  def ndim(self) -> int: ...

  @property
  @abc.abstractmethod
  def size(self) -> int: ...

  @property
  @abc.abstractmethod
  def dtype(self): ...

  @property
  @abc.abstractmethod
  def sharding(self): ...

  @property
  @abc.abstractmethod
  def at(self) -> _IndexUpdateHelper: ...

  @abc.abstractmethod
  def __len__(self) -> int: ...
  @abc.abstractmethod
  def __iter__(self) -> Iterator[PRNGKeyArray]: ...

  @abc.abstractmethod
  def reshape(self, newshape, order=None) -> PRNGKeyArray: ...

  @property
  @abc.abstractmethod
  def T(self) -> PRNGKeyArray: ...
  @abc.abstractmethod
  def __getitem__(self, _) -> PRNGKeyArray: ...
  @abc.abstractmethod
  def ravel(self, *_, **__) -> PRNGKeyArray: ...
  @abc.abstractmethod
  def squeeze(self, *_, **__) -> PRNGKeyArray: ...
  @abc.abstractmethod
  def swapaxes(self, *_, **__) -> PRNGKeyArray: ...
  @abc.abstractmethod
  def take(self, *_, **__) -> PRNGKeyArray: ...
  @abc.abstractmethod
  def transpose(self, *_, **__) -> PRNGKeyArray: ...
  @abc.abstractmethod
  def flatten(self, *_, **__) -> PRNGKeyArray: ...

  @property
  @abc.abstractmethod
  def is_fully_addressable(self) -> bool: ...
  @property
  @abc.abstractmethod
  def is_fully_replicated(self) -> bool: ...
  @abc.abstractmethod
  def device(self) -> Device: ...
  @abc.abstractmethod
  def devices(self) -> Set[Device]: ...
  @abc.abstractmethod
  def delete(self) -> None: ...
  @abc.abstractmethod
  def is_deleted(self) -> bool: ...
  @abc.abstractmethod
  def on_device_size_in_bytes(self) -> int: ...
  @property
  @abc.abstractmethod
  def addressable_shards(self) -> List[Shard]: ...
  @property
  @abc.abstractmethod
  def global_shards(self) -> List[Shard]: ...
  @abc.abstractmethod
  def addressable_data(self, index: int) -> PRNGKeyArray: ...

  # TODO(jakevdp): potentially add tolist(), tobytes(),
  #  device_buffer, device_buffers, __cuda_interface__()


class PRNGKeyArrayImpl(PRNGKeyArray):
  """An array of PRNG keys backed by an RNG implementation.

  This class lifts the definition of a PRNG, provided in the form of a
  ``PRNGImpl``, into an array-like pytree class. Instances of this
  class behave like an array whose base elements are keys, hiding the
  fact that keys are typically arrays (of ``uint32`` dtype) themselves.

  PRNGKeyArrays are also restricted relative to JAX arrays in that
  they do not expose arithmetic operations. They instead expose
  wrapper methods around the PRNG implementation functions (``split``,
  ``random_bits``, ``fold_in``).
  """

  impl: PRNGImpl
  _base_array: typing.Array

  def __init__(self, impl, key_data: Any):
    # This concrete class never wraps a tracer; key_data must also pass the
    # structural checks in _check_prng_key_data.
    assert not isinstance(key_data, core.Tracer)
    _check_prng_key_data(impl, key_data)
    self.impl = impl
    self._base_array = key_data

  # TODO(frostig): rename to unsafe_base_array, or just offer base_array attr?
  def unsafe_raw_array(self):
    """Access the raw numerical array that carries underlying key data.

    Returns:
      A uint32 JAX array whose leading dimensions are ``self.shape``.
    """
    return self._base_array

  def block_until_ready(self):
    """Block on the underlying base array, then return ``self``."""
    _ = self._base_array.block_until_ready()
    return self

  def copy_to_host_async(self):
    _ = self._base_array.copy_to_host_async()

  @property
  def aval(self):
    return keys_shaped_array(self.impl, self.shape)

  @property
  def shape(self):
    # Logical shape: the base array's shape minus the trailing key dimensions.
    return base_arr_shape_to_keys_shape(self.impl, self._base_array.shape)

  @property
  def size(self):
    return math.prod(self.shape)

  @property
  def ndim(self):
    return len(self.shape)

  @property
  def dtype(self):
    return KeyTy(self.impl)

  # The attributes below are forwarded unchanged from the underlying base
  # array via attrgetter-backed properties.
  _device = property(op.attrgetter('_base_array._device'))
  _committed = property(op.attrgetter('_base_array._committed'))
  device = property(op.attrgetter('_base_array.device'))  # type: ignore[assignment]
  devices = property(op.attrgetter('_base_array.devices'))  # type: ignore[assignment]
  is_fully_addressable = property(op.attrgetter('_base_array.is_fully_addressable'))  # type: ignore[assignment]
  is_fully_replicated = property(op.attrgetter('_base_array.is_fully_replicated'))  # type: ignore[assignment]
  delete = property(op.attrgetter('_base_array.delete'))  # type: ignore[assignment]
  is_deleted = property(op.attrgetter('_base_array.is_deleted'))  # type: ignore[assignment]
  on_device_size_in_bytes = property(op.attrgetter('_base_array.on_device_size_in_bytes'))  # type: ignore[assignment]
  unsafe_buffer_pointer = property(op.attrgetter('_base_array.unsafe_buffer_pointer'))  # type: ignore[assignment]

  def addressable_data(self, index: int) -> PRNGKeyArrayImpl:
    """Return the base array's ``addressable_data(index)`` re-wrapped as keys."""
    return PRNGKeyArrayImpl(self.impl, self._base_array.addressable_data(index))

  @property
  def addressable_shards(self) -> List[Shard]:
    # Rebuild each shard of the base array with its data wrapped as a key
    # array; device/sharding/global_shape are copied from the base shard.
    return [
        type(s)(
            device=s._device,
            sharding=s._sharding,
            global_shape=s._global_shape,
            data=PRNGKeyArrayImpl(self.impl, s._data),
        )
        for s in self._base_array.addressable_shards
    ]

  @property
  def global_shards(self) -> List[Shard]:
    # Same re-wrapping as addressable_shards, over the base array's
    # global_shards.
    return [
        type(s)(
            device=s._device,
            sharding=s._sharding,
            global_shape=s._global_shape,
            data=PRNGKeyArrayImpl(self.impl, s._data),
        )
        for s in self._base_array.global_shards
    ]

  @property
  def sharding(self):
    # Convert the base array's (physical) sharding to one over the logical
    # key-array shape.
    phys_sharding = self._base_array.sharding
    return KeyTyRules.logical_op_sharding(self.aval, phys_sharding)

  def _is_scalar(self):
    # A single key: the base array has exactly the key's own dimensions.
    base_ndim = len(self.impl.key_shape)
    return self._base_array.ndim == base_ndim

  def __len__(self):
    if self._is_scalar():
      raise TypeError('len() of unsized object')
    return len(self._base_array)

  def __iter__(self) -> Iterator[PRNGKeyArrayImpl]:
    if self._is_scalar():
      raise TypeError('iteration over a 0-d key array')
    # TODO(frostig): we may want to avoid iteration by slicing because
    # a very common use of iteration is `k1, k2 = split(key)`, and
    # slicing/indexing may be trickier to track for linearity checking
    # purposes. Maybe we can:
    # * introduce an unpack primitive+traceable (also allow direct use)
    # * unpack upfront into shape[0] many keyarray slices
    # * return iter over these unpacked slices
    # Whatever we do, we'll want to do it by overriding
    # ShapedArray._iter when the element type is KeyTy...
    return (PRNGKeyArrayImpl(self.impl, k) for k in iter(self._base_array))

  # TODO(frostig): are all of the stackable methods below (reshape,
  # concat, broadcast_to, expand_dims), and the stackable registration,
  # still needed? If, with some work, none are needed, then do we want
  # to remove stackables altogether? This may be the only application.

  def __repr__(self):
    return (f'{self.__class__.__name__}[{self.impl.tag}]'
            f' {{ {self._base_array} }}')

  def pprint(self):
    """Return a string rendering of this key array's shape and impl."""
    pp_keys = pp.text('shape = ') + pp.text(str(self.shape))
    pp_impl = pp.text('impl = ') + self.impl.pprint()
    return str(pp.group(
      pp.text('PRNGKeyArray:') +
      pp.nest(2, pp.brk() + pp_keys + pp.brk() + pp_impl)))

  def copy(self):
    return self.__class__(self.impl, self._base_array.copy())

  # Key arrays are explicitly unhashable.
  __hash__ = None  # type: ignore[assignment]
  __array_priority__ = 100

  # Overwritten immediately below
  @property
  def at(self)                  -> _IndexUpdateHelper: assert False
  @property
  def T(self)                   -> PRNGKeyArray: assert False
  def __getitem__(self, _)      -> PRNGKeyArray: assert False
  def flatten(self, *_, **__)   -> PRNGKeyArray: assert False
  def ravel(self, *_, **__)     -> PRNGKeyArray: assert False
  def reshape(self, *_, **__)   -> PRNGKeyArray: assert False
  def squeeze(self, *_, **__)   -> PRNGKeyArray: assert False
  def swapaxes(self, *_, **__)  -> PRNGKeyArray: assert False
  def take(self, *_, **__)      -> PRNGKeyArray: assert False
  def transpose(self, *_, **__) -> PRNGKeyArray: assert False

_set_array_base_attributes(PRNGKeyArrayImpl, include=[
    *(f"__{op}__" for op in _array_operators),
    'at', 'flatten', 'ravel', 'reshape',
    'squeeze', 'swapaxes', 'take', 'transpose', 'T'])
basearray.Array.register(PRNGKeyArrayImpl)

ad_util.jaxval_zeros_likers[PRNGKeyArrayImpl] = jnp.zeros_like  # type: ignore[has-type]


# TODO(frostig): remove, rerouting callers directly to random_seed
def seed_with_impl(impl: PRNGImpl, seed: Union[int, Array]) -> PRNGKeyArrayImpl:
  """Create a key array from ``seed`` using ``impl`` (thin wrapper over random_seed)."""
  return random_seed(seed, impl=impl)


def keys_shaped_array(impl, shape):
  """Return the ShapedArray aval for a key array of ``shape`` under ``impl``."""
  return core.ShapedArray(shape, KeyTy(impl))

# TODO(frostig): remove in favor of physical_aval call
def keys_aval_to_base_arr_aval(keys_aval):
  """Return ``core.physical_aval(keys_aval)``."""
  return core.physical_aval(keys_aval)

def base_arr_shape_to_keys_shape(impl, base_arr_shape):
  """Drop the trailing ``len(impl.key_shape)`` dimensions from a base-array shape."""
  base_ndim = len(impl.key_shape)
  return base_arr_shape[:-base_ndim]

def make_key_array_phys_sharding(aval, sharding, is_sharding_from_xla):
  """Extend a sharding over the logical key shape to the physical base array.

  The trailing key dimensions are appended as unsharded for ``PmapSharding``
  and ``NamedSharding``. Single-device shardings, and shardings flagged as
  coming from XLA, are returned unchanged. Any other sharding is converted
  through its HLO sharding into a ``GSPMDSharding``.
  """
  if dispatch.is_single_device_sharding(sharding):
    return sharding
  elif isinstance(sharding, PmapSharding):
    key_shape = aval.dtype.impl.key_shape
    trailing_sharding = [sharding_specs.NoSharding()] * len(key_shape)
    phys_sharding_spec = sharding_specs.ShardingSpec(
        sharding=(*sharding.sharding_spec.sharding, *trailing_sharding),
        mesh_mapping=sharding.sharding_spec.mesh_mapping)
    return PmapSharding(devices=sharding.devices,
                        sharding_spec=phys_sharding_spec)
  elif isinstance(sharding, NamedSharding):
    key_shape = aval.dtype.impl.key_shape
    trailing_spec = [None] * len(key_shape)
    return NamedSharding(
        sharding.mesh,
        PartitionSpec(*sharding.spec, *trailing_spec))
  elif is_sharding_from_xla:
    return sharding
  else:
    hlos = sharding._to_xla_hlo_sharding(aval.ndim)
    return GSPMDSharding(
        sharding._device_assignment,
        KeyTyRules.physical_hlo_sharding(aval, hlos))


class KeyTyRules:
  """Static rules attached to ``KeyTy`` (via ``KeyTy._rules``).

  Each rule translates between the logical key array and its physical
  uint32 base array.
  """

  @staticmethod
  def full(shape, fill_value, dtype):
    """Build a key array of ``shape`` filled with ``fill_value``."""
    physical_shape = (*shape, *dtype.impl.key_shape)
    if isinstance(fill_value, PRNGKeyArray):
      key_data = jnp.broadcast_to(random_unwrap(fill_value), physical_shape)
    else:
      key_data = lax.full(physical_shape, fill_value, dtype=np.dtype('uint32'))
    # TODO(frostig,mattjj,vanderplas,lenamartens): consider this consumed from
    # the outset.
    return random_wrap(key_data, impl=dtype.impl)

  @staticmethod
  def physical_element_aval(dtype) -> core.ShapedArray:
    return core.ShapedArray(dtype.impl.key_shape, jnp.dtype('uint32'))

  @staticmethod
  def physical_const(val) -> Array:
    return val.unsafe_raw_array()

  @staticmethod
  def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
    """Append one tile dimension of size 1 per key dimension to an HLO sharding."""
    key_shape = aval.dtype.impl.key_shape
    op_sharding_proto = hlo_sharding.to_proto()  # type: ignore
    new_op_sharding = op_sharding_proto.clone()
    tad = list(new_op_sharding.tile_assignment_dimensions)
    # When replicate_on_last_tile_dim is set, the last entry is popped and
    # re-appended after the new size-1 dims so that it stays last.
    suffix = [tad.pop()] if op_sharding_proto.replicate_on_last_tile_dim else []
    tad.extend([1] * len(key_shape) + suffix)
    new_op_sharding.tile_assignment_dimensions = tad
    return xc.HloSharding.from_proto(new_op_sharding)

  @staticmethod
  def logical_op_sharding(aval, phys_sharding) -> XLACompatibleSharding:
    """Strip the trailing key dimensions from a physical sharding."""
    if dispatch.is_single_device_sharding(phys_sharding):
      return phys_sharding
    elif isinstance(phys_sharding, PmapSharding):
      key_shape = aval.dtype.impl.key_shape
      logical_sharding_spec = sharding_specs.ShardingSpec(
          sharding=phys_sharding.sharding_spec.sharding[:-len(key_shape)],
          mesh_mapping=phys_sharding.sharding_spec.mesh_mapping)
      return PmapSharding(devices=phys_sharding.devices,
                          sharding_spec=logical_sharding_spec)
    elif isinstance(phys_sharding, NamedSharding):
      key_shape = aval.dtype.impl.key_shape
      return pxla.create_mesh_pspec_sharding(
          phys_sharding.mesh,
          PartitionSpec(*phys_sharding.spec[:-len(key_shape)]))
    else:
      key_shape = aval.dtype.impl.key_shape
      phys_op_sharding = phys_sharding._to_xla_hlo_sharding(
          aval.ndim + len(key_shape)).to_proto()
      logical_op_sharding = phys_op_sharding.clone()
      tad = list(logical_op_sharding.tile_assignment_dimensions)
      tad = tad[:-len(key_shape)]
      logical_op_sharding.tile_assignment_dimensions = tad
      return GSPMDSharding(phys_sharding._device_assignment,
                           xc.HloSharding.from_proto(logical_op_sharding))

  @staticmethod
  def result_handler(sticky_device, aval):
    def handler(_, buf):
      # Overwrite the buffer's aval with one built from its own shape/dtype
      # before wrapping it as a key array.
      buf.aval = core.ShapedArray(buf.shape, buf.dtype)
      return PRNGKeyArrayImpl(aval.dtype.impl, buf)
    return handler

  @staticmethod
  def local_sharded_result_handler(aval, sharding, indices):
    phys_aval = core.physical_aval(aval)
    key_shape = aval.dtype.impl.key_shape
    phys_handler_maker = pxla.local_result_handlers[core.ShapedArray]

    # set up a grounded sharding (with a grounded sharding spec)
    if isinstance(sharding, (PmapSharding, NamedSharding)):
      phys_sharding = make_key_array_phys_sharding(
          aval, sharding, is_sharding_from_xla=False)
    else:
      assert False, f'impossible sharding {sharding} in local sharded result handler'

    # set up grounded indices
    trailing_inds = [slice(None)] * len(key_shape)
    phys_indices = [(*inds, *trailing_inds) for inds in indices]

    # make a physical handler
    phys_handler = phys_handler_maker(phys_aval, phys_sharding, phys_indices)

    # set up a handler that calls the physical one and wraps back up
    def handler(bufs):
      return PRNGKeyArrayImpl(aval.dtype.impl, phys_handler(bufs))

    return handler

  @staticmethod
  def global_sharded_result_handler(aval, out_sharding, committed,
                                    is_out_sharding_from_xla):
    phys_aval = core.physical_aval(aval)
    phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]

    phys_sharding = make_key_array_phys_sharding(
        aval, out_sharding, is_out_sharding_from_xla)
    phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed,
                                      is_out_sharding_from_xla)
    def handler(bufs):
      return PRNGKeyArrayImpl(aval.dtype.impl, phys_handler(bufs))
    return handler

  @staticmethod
  def make_sharded_array(aval, sharding, arrays, committed):
    phys_aval = core.physical_aval(aval)
    phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
    phys_arrays = [random_unwrap(arr) for arr in arrays]

    phys_sharding = make_key_array_phys_sharding(aval, sharding, False)
    phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed, False)
    phys_result = phys_handler(phys_arrays)
    return PRNGKeyArrayImpl(aval.dtype.impl, phys_result)

  @staticmethod
  def device_put_sharded(vals, aval, sharding, devices):
    physical_aval = keys_aval_to_base_arr_aval(aval)
    physical_buffers = tree_util.tree_map(random_unwrap, vals)
    physical_sharding = make_key_array_phys_sharding(aval, sharding, False)
    physical_result = pxla.batched_device_put(physical_aval, physical_sharding, physical_buffers, list(devices))
    return random_wrap(physical_result, impl=aval.dtype.impl)

  @staticmethod
  def device_put_replicated(val, aval, sharding, devices):
    physical_aval = keys_aval_to_base_arr_aval(aval)
    assert len(xla.aval_to_xla_shapes(physical_aval)) == 1
    physical_buf = random_unwrap(val)
    physical_sharding = make_key_array_phys_sharding(aval, sharding, False)
    # The same buffer is passed once per device.
    physical_result = pxla.batched_device_put(physical_aval, physical_sharding, [physical_buf] * len(devices), devices)
    return random_wrap(physical_result, impl=aval.dtype.impl)


class KeyTy:
  """Opaque element dtype for key arrays, parameterized by a PRNG impl."""
  impl: Hashable  # prng.PRNGImpl. TODO(mattjj,frostig): protocol really
  _rules = KeyTyRules

  def __init__(self, impl):
    self.impl = impl

  @property
  def name(self) -> str:
    return f'key<{self.impl.tag}>'

  @property
  def itemsize(self) -> int:
    # Size in bytes of one key: number of uint32 words times the uint32
    # item size.
    return math.prod(self.impl.key_shape) * np.dtype('uint32').itemsize

  def __repr__(self) -> str:
    return self.name

  def __eq__(self, other):
    return type(other) is KeyTy and self.impl == other.impl

  def __hash__(self) -> int:
    return hash((self.__class__, self.impl))

dtypes.opaque_dtypes.add(KeyTy)


core.pytype_aval_mappings[PRNGKeyArrayImpl] = lambda x: x.aval
xla.pytype_aval_mappings[PRNGKeyArrayImpl] = lambda x: x.aval

xla.canonicalize_dtype_handlers[PRNGKeyArrayImpl] = lambda x: x


def key_array_shard_arg_handler(x: PRNGKeyArrayImpl, devices, indices, sharding):
  """Shard a key array argument by delegating to its base array's handler."""
  # TODO(frostig): Remove the need for `core.get_aval`.
  aval = core.get_aval(x)
  key_shape = aval.dtype.impl.key_shape
  arr = x.unsafe_raw_array()

  # TODO(yashkatariya,frostig): This assumes that the last dimensions are not
  # sharded. This is only true when enable_custom_prng is True.
  trailing_inds = [slice(None)] * len(key_shape)
  phys_indices = [(*inds, *trailing_inds) for inds in indices]
  phys_sharding = make_key_array_phys_sharding(
      aval, sharding, is_sharding_from_xla=False)
  return pxla.shard_arg_handlers[type(arr)](
      arr, devices, phys_indices, phys_sharding
  )


pxla.shard_arg_handlers[PRNGKeyArrayImpl] = key_array_shard_arg_handler


def key_array_constant_handler(x, canonicalize_dtypes):
  """Lower a key array constant through its base array's constant handler."""
  arr = x.unsafe_raw_array()
  return mlir.get_constant_handler(type(arr))(arr, canonicalize_dtypes)
mlir.register_constant_handler(PRNGKeyArrayImpl, key_array_constant_handler)


# -- primitives

def iterated_vmap_unary(n, f):
  """Apply ``api.vmap`` to ``f`` ``n`` times and return the result."""
  for _ in range(n):
    f = api.vmap(f)
  return f

# TODO(frostig): Revise the following two functions? These basically
# undo the singleton dimensions added by `batching.defbroadcasting`.
# It works, but introduces some possibly-redundant squeezes. Can we
# borrow from other broadcasting primitives instead?

def squeeze_vmap(f, left):
  """Return a binary function that squeezes axis 0 of one operand and vmaps over the other.

  If ``left`` is true the first operand is squeezed (and left unmapped);
  otherwise the second operand is.
  """
  def squeeze_vmap_f(x, y):
    if left:
      x = jnp.squeeze(x, axis=0)
      axes = (None, 0)
    else:
      y = jnp.squeeze(y, axis=0)
      axes = (0, None)
    return api.vmap(f, in_axes=axes, out_axes=0)(x, y)
  return squeeze_vmap_f

def iterated_vmap_binary_bcast(shape1, shape2, f):
  """Lift binary ``f`` over operands of shapes ``shape1`` and ``shape2``.

  If either shape is rank 0, only the other operand is vmapped over.
  Otherwise the ranks must match; each dimension pair must be equal or
  contain a 1 (the size-1 side is squeezed rather than mapped).
  """
  ndim1, ndim2 = len(shape1), len(shape2)
  if ndim1 == ndim2 == 0:
    return f
  if 0 in [ndim1, ndim2]:
    if ndim1 == 0:
      return lambda x, y: iterated_vmap_unary(ndim2, lambda y: f(x, y))(y)
    else:
      return lambda x, y: iterated_vmap_unary(ndim1, lambda x: f(x, y))(x)
  assert len(shape1) == len(shape2)
  # NOTE: `zip` here is the module-level alias for `safe_zip` (rebound at the
  # top of the file), not the builtin; `reversed` is applied to its result
  # directly.
  for sz1, sz2 in reversed(zip(shape1, shape2)):
    if sz1 == sz2:
      f = api.vmap(f, out_axes=0)
    else:
      assert sz1 == 1 or sz2 == 1, (sz1, sz2)
      f = squeeze_vmap(f, sz1 == 1)
  return f


def random_seed(seeds, impl):
  """Bind the ``random_seed`` primitive on ``seeds`` for the given ``impl``."""
  # Avoid overflow error in X32 mode by first converting ints to int64.
  # This breaks JIT invariance for large ints, but supports the common
  # use-case of instantiating with Python hashes in X32 mode.
  if isinstance(seeds, int):
    seeds_arr = jnp.asarray(np.int64(seeds))
  else:
    seeds_arr = jnp.asarray(seeds)
  return random_seed_p.bind(seeds_arr, impl=impl)

random_seed_p = core.Primitive('random_seed')
ad.defjvp_zero(random_seed_p)
batching.defvectorized(random_seed_p)

@random_seed_p.def_abstract_eval
def random_seed_abstract_eval(seeds_aval, *, impl):
  # One key per seed: the key array has the same logical shape as the seeds.
  return keys_shaped_array(impl, seeds_aval.shape)

@random_seed_p.def_impl
def random_seed_impl(seeds, *, impl):
  base_arr = random_seed_impl_base(seeds, impl=impl)
  return PRNGKeyArrayImpl(impl, base_arr)

def random_seed_impl_base(seeds, *, impl):
  # Map impl.seed over every dimension of `seeds`.
  seed = iterated_vmap_unary(seeds.ndim, impl.seed)
  return seed(seeds)

def random_seed_lowering(ctx, seeds, *, impl):
  aval, = ctx.avals_in
  seed = iterated_vmap_unary(aval.ndim, impl.seed)
  seed_lowering = mlir.lower_fun(seed, multiple_results=False)
  # Lower against the physical (base array) output avals.
  return mlir.delegate_lowering(
      ctx, seed_lowering, seeds,
      avals_out=map(keys_aval_to_base_arr_aval, ctx.avals_out))

mlir.register_lowering(random_seed_p, random_seed_lowering)


def random_split(keys, count):
  """Bind the ``random_split`` primitive on ``keys`` with static ``count``."""
  return random_split_p.bind(keys, count=count)

random_split_p = core.Primitive('random_split')
ad.defjvp_zero(random_split_p)
batching.defvectorized(random_split_p)

@random_split_p.def_abstract_eval
def random_split_abstract_eval(keys_aval, *, count):
  # Splitting appends a trailing dimension of size `count`.
  return keys_shaped_array(keys_aval.dtype.impl, (*keys_aval.shape, count))

@random_split_p.def_impl
def random_split_impl(keys, *, count):
  base_arr = random_split_impl_base(
      keys.impl, keys.unsafe_raw_array(), keys.ndim, count=count)
  return PRNGKeyArrayImpl(keys.impl, base_arr)

def random_split_impl_base(impl, base_arr, keys_ndim, *, count):
  split = iterated_vmap_unary(keys_ndim, lambda k: impl.split(k, count))
  return split(base_arr)

def random_split_lowering(ctx, keys, *, count):
  aval, = ctx.avals_in
  impl = aval.dtype.impl
  split = iterated_vmap_unary(aval.ndim, lambda k: impl.split(k, count))
  split_lowering = mlir.lower_fun(split, multiple_results=False)
  return mlir.delegate_lowering(
      ctx, split_lowering, keys,
      avals_in=[keys_aval_to_base_arr_aval(aval)],
      avals_out=map(keys_aval_to_base_arr_aval, ctx.avals_out))

mlir.register_lowering(random_split_p, random_split_lowering)


def random_fold_in(keys, msgs):
  """Bind the ``random_fold_in`` primitive on ``keys`` and ``jnp.asarray(msgs)``."""
  return random_fold_in_p.bind(keys, jnp.asarray(msgs))

random_fold_in_p = core.Primitive('random_fold_in')
ad.defjvp_zero(random_fold_in_p)
batching.defbroadcasting(random_fold_in_p)

@random_fold_in_p.def_abstract_eval
def random_fold_in_abstract_eval(keys_aval, msgs_aval):
  shape = lax_internal.broadcasting_shape_rule(
      'random_fold_in', keys_aval, msgs_aval)
  named_shape = lax_utils.standard_named_shape_rule(keys_aval, msgs_aval)
  return core.ShapedArray(shape, keys_aval.dtype, named_shape=named_shape)

@random_fold_in_p.def_impl
def random_fold_in_impl(keys, msgs):
  base_arr = random_fold_in_impl_base(
      keys.impl, keys.unsafe_raw_array(), msgs, keys.shape)
  return PRNGKeyArrayImpl(keys.impl, base_arr)

def random_fold_in_impl_base(impl, base_arr, msgs, keys_shape):
  fold_in = iterated_vmap_binary_bcast(
      keys_shape, np.shape(msgs), impl.fold_in)
  return fold_in(base_arr, msgs)

def random_fold_in_lowering(ctx, keys, msgs):
  keys_aval, msgs_aval = ctx.avals_in
  impl = keys_aval.dtype.impl
  fold_in = iterated_vmap_binary_bcast(
      keys_aval.shape, msgs_aval.shape, impl.fold_in)
  fold_in_lowering = mlir.lower_fun(fold_in, multiple_results=False)
  return mlir.delegate_lowering(
      ctx, fold_in_lowering, keys, msgs,
      avals_in=[keys_aval_to_base_arr_aval(keys_aval), msgs_aval],
      avals_out=map(keys_aval_to_base_arr_aval, ctx.avals_out))

mlir.register_lowering(random_fold_in_p, random_fold_in_lowering)


def random_bits(keys, bit_width, shape):
  """Bind ``random_bits`` after folding in the axis index of each named axis.

  Raises:
    ValueError: if a named axis in ``shape`` has a declared size that differs
      from ``lax.psum(1, name)``.
  """
  shape = core.as_named_shape(shape)
  for name, size in shape.named_items:
    # TODO(frostig,mattjj,apaszke): Is this real_size check necessary,
    # and is it meant to raise a user-facing ValueError? Should it be
    # an `assert` (or RuntimeError) instead? Why do we check it in
    # calls to `random_bits` instead of a more common parallelism path?
    real_size = lax.psum(1, name)
    if real_size != size:
      raise ValueError(f"The shape of axis {name} was specified as {size}, "
                       f"but it really is {real_size}")
    axis_index = lax.axis_index(name)
    keys = random_fold_in(keys, axis_index)
  return random_bits_p.bind(keys, bit_width=bit_width, shape=shape.positional)

random_bits_p = core.Primitive('random_bits')
ad.defjvp_zero(random_bits_p)
batching.defvectorized(random_bits_p)

@random_bits_p.def_abstract_eval
def random_bits_abstract_eval(keys_aval, *, bit_width, shape):
  # Output: the keys' logical shape followed by the requested shape, with
  # dtype uint<bit_width>.
  out_shape = (*keys_aval.shape, *shape)
  out_dtype = dtypes.dtype(f'uint{bit_width}')
  return core.ShapedArray(out_shape, out_dtype)

@random_bits_p.def_impl
def random_bits_impl(keys, *, bit_width, shape):
  return random_bits_impl_base(keys.impl, keys.unsafe_raw_array(), keys.ndim,
                               bit_width=bit_width, shape=shape)

def random_bits_impl_base(impl, base_arr, keys_ndim, *, bit_width, shape):
  bits = iterated_vmap_unary(
      keys_ndim, lambda k: impl.random_bits(k, bit_width, shape))
  return bits(base_arr)

def random_bits_lowering(ctx, keys, *, bit_width, shape):
  aval, = ctx.avals_in
  impl = aval.dtype.impl
  bits = iterated_vmap_unary(
      aval.ndim, lambda k: impl.random_bits(k, bit_width, shape))
  bits_lowering = mlir.lower_fun(bits, multiple_results=False)
  # Lower against the physical input aval, then copy the tokens produced by
  # the inner lowering back onto the outer context.
  ctx_new = ctx.replace(avals_in=[keys_aval_to_base_arr_aval(aval)])
  out = bits_lowering(ctx_new, keys)
  ctx.set_tokens_out(ctx_new.tokens_out)
  return out

mlir.register_lowering(random_bits_p, random_bits_lowering)


# The following wrap/unwrap primitives are at least a stopgap for
# backwards compatibility, namely when `config.jax_enable_custom_prng`
# is False. We need to convert key arrays to and from underlying
# uint32 base array, and we may need to do so under a jit. For
# example, we want to support:
#
#   keys = jax.jit(random.split)(key)
#
# where `key` and `keys` are both acceptably old-style uint32 arrays
# so long as enable_custom_prng is False.
# The way we handle this is
# that `random.split` adapts the input/output by converting to/from
# key arrays across its call to `random_split`. So we rely on these
# wrap/unwrap casting primitives to allow that conversion under jit.
#
# We may want to keep both around for testing and debugging escape
# hatches. We can rename them `unsafe` for emphasis, and/or issue a
# warning on entry to the traceable.
#
# TODO(frostig): Consider removal once we always enable_custom_prng.

def random_wrap(base_arr, *, impl):
  """Reinterpret a raw base array as a key array under ``impl``.

  Args:
    base_arr: raw key data; validated with ``_check_prng_key_data``.
    impl: the ``PRNGImpl`` whose ``key_shape`` the trailing dimensions of
      ``base_arr`` must match.

  Returns:
    The result of binding ``random_wrap_p``: a key array whose logical shape
    is ``base_arr.shape`` minus the trailing key dimensions.

  Raises:
    TypeError: if ``base_arr`` fails ``_check_prng_key_data``.
  """
  _check_prng_key_data(impl, base_arr)
  return random_wrap_p.bind(base_arr, impl=impl)

random_wrap_p = core.Primitive('random_wrap')
ad.defjvp_zero(random_wrap_p)

@random_wrap_p.def_abstract_eval
def random_wrap_abstract_eval(base_arr_aval, *, impl):
  shape = base_arr_shape_to_keys_shape(impl, base_arr_aval.shape)
  return keys_shaped_array(impl, shape)

@random_wrap_p.def_impl
def random_wrap_impl(base_arr, *, impl):
  return PRNGKeyArrayImpl(impl, base_arr)

def random_wrap_lowering(ctx, base_arr, *, impl):
  # Wrapping is a pure reinterpretation: the lowered value is the operand.
  return [base_arr]

def random_wrap_batch_rule(batched_args, batch_dims, *, impl):
  x, = batched_args
  d, = batch_dims
  # Move the batch dimension to the front so that the trailing key
  # dimensions stay trailing; the wrapped output is then batched along 0.
  x = batching.bdim_at_front(x, d, 1)
  return random_wrap(x, impl=impl), 0

mlir.register_lowering(random_wrap_p, random_wrap_lowering)
batching.primitive_batchers[random_wrap_p] = random_wrap_batch_rule


def random_unwrap(keys):
  """Expose the raw base array underlying a key array.

  Args:
    keys: a key array. Both concrete ``PRNGKeyArrayImpl`` instances and
      traced key arrays are accepted, so this can be called under ``jit``
      and ``vmap``.

  Returns:
    The result of binding ``random_unwrap_p``: the base array carrying the
    key data.

  Raises:
    TypeError: if ``keys`` is not a key array.
  """
  # Check against the abstract `PRNGKeyArray`, not the concrete
  # `PRNGKeyArrayImpl`: the Impl class never wraps tracers (its __init__
  # asserts so), while PRNGKeyArrayMeta.__instancecheck__ recognises any
  # value whose aval has a KeyTy dtype. Checking the Impl class would reject
  # every traced key array and make this primitive unusable under jit.
  if not isinstance(keys, PRNGKeyArray):
    raise TypeError(f'random_unwrap takes key array operand, got {type(keys)}')
  return random_unwrap_p.bind(keys)

random_unwrap_p = core.Primitive('random_unwrap')
ad.defjvp_zero(random_unwrap_p)
batching.defvectorized(random_unwrap_p)

@random_unwrap_p.def_abstract_eval
def random_unwrap_abstract_eval(keys_aval):
  return keys_aval_to_base_arr_aval(keys_aval)

@random_unwrap_p.def_impl
def random_unwrap_impl(keys):
  return keys.unsafe_raw_array()

def random_unwrap_lowering(ctx, keys):
  # Unwrapping is a pure reinterpretation: the lowered value is the operand.
  return [keys]

mlir.register_lowering(random_unwrap_p, random_unwrap_lowering)


# -- threefry2x32 PRNG implementation
def _is_threefry_prng_key(key: typing.Array) -> bool:
  """Return whether ``key`` has shape ``(2,)`` and dtype ``uint32``.

  Objects lacking ``shape``/``dtype`` attributes yield ``False`` instead of
  raising.
  """
  try:
    return key.shape == (2,) and key.dtype == np.uint32
  except AttributeError:
    return False


def threefry_seed(seed: typing.Array) -> typing.Array:
  """Create a single raw threefry PRNG key from an integer seed.

  Args:
    seed: a 64- or 32-bit integer used as the value of the key.

  Returns:
    The PRNG key contents, modeled as an array of shape (2,) and dtype
    uint32. The key is constructed from a 64-bit seed by effectively
    bit-casting to a pair of uint32 values (or from a 32-bit seed by
    first padding out with zeros).
  """
  return _threefry_seed(seed)


@partial(jit, inline=True)
def _threefry_seed(seed: typing.Array) -> typing.Array:
  """Jitted body of ``threefry_seed``.

  Raises:
    TypeError: if ``seed`` is not a scalar or does not have an integer dtype.
  """
  if seed.shape:
    raise TypeError(f"PRNG key seed must be a scalar; got {seed!r}.")
  if not np.issubdtype(seed.dtype, np.integer):
    raise TypeError(f"PRNG key seed must be an integer; got {seed!r}")
  # Each half is cast to uint32 and reshaped to length 1 so that the two
  # halves can be concatenated into the (2,) key.
  convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
  # High word: the seed shifted right by 32 bits.
  k1 = convert(
      lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
  with config_lib.numpy_dtype_promotion('standard'):
    # TODO(jakevdp): in X64 mode, this can generate 64-bit computations for 32-bit
    # inputs. We should avoid this.
    # Low word: the seed masked to its low 32 bits.
    k2 = convert(jnp.bitwise_and(seed, np.uint32(0xFFFFFFFF)))
  return lax.concatenate([k1, k2], 0)


def _make_rotate_left(dtype):
  """Build a bitwise rotate-left function for the integer ``dtype``.

  The returned ``_rotate_left(x, d)`` converts both operands to ``dtype`` if
  needed and computes ``(x << d) | (x >> (nbits - d))`` with a logical right
  shift, where ``nbits`` is the bit width of ``dtype``.

  Raises:
    TypeError: if ``dtype`` is not an integer dtype.
  """
  if not jnp.issubdtype(dtype, np.integer):
    raise TypeError("_rotate_left only accepts integer dtypes.")
  nbits = np.array(jnp.iinfo(dtype).bits, dtype)

  def _rotate_left(x, d):
    if lax.dtype(d) != dtype:
      d = lax.convert_element_type(d, dtype)
    if lax.dtype(x) != dtype:
      x = lax.convert_element_type(x, dtype)
    return lax.shift_left(x, d) | lax.shift_right_logical(x, nbits - d)
  return _rotate_left


### hash function and split

def _threefry2x32_abstract_eval(*args):
  """Abstract evaluation rule for ``threefry2x32_p``.

  All operands must be ``uint32``. When every operand is a ``ShapedArray`` the
  result has the broadcast shape of the operands (and their joined named
  shape); otherwise it is an ``UnshapedArray``. The same aval is returned
  twice, once per output.

  Raises:
    TypeError: if any operand does not have dtype ``uint32``.
  """
  if any(a.dtype != jnp.uint32 for a in args):
    raise TypeError("Arguments to threefry2x32 must have uint32 type, got {}"
                    .format(args))
  if all(isinstance(arg, core.ShapedArray) for arg in args):
    shape = lax_internal.broadcasting_shape_rule(*args)
    named_shape = core.join_named_shapes(*(a.named_shape for a in args))
    aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32), named_shape=named_shape)
  else:
    aval = core.UnshapedArray(jnp.dtype(jnp.uint32))
  return (aval,) * 2


rotate_left = _make_rotate_left(np.uint32)


def apply_round(v, rot):
  """Apply one round to the two-element list ``v`` and return a new list.

  The round computes ``v0 += v1``, rotates ``v1`` left by ``rot`` bits, then
  sets ``v1 = v0 ^ v1``. The input list is copied first, so the caller's list
  is not mutated.
  """
  v = v[:]
  v[0] = v[0] + v[1]
  v[1] = rotate_left(v[1], rot)
  v[1] = v[0] ^ v[1]
  return v


def rotate_list(xs):
  """Return ``xs`` rotated left by one position (first element moved last)."""
  return xs[1:] + xs[:1]


def rolled_loop_step(i, state):
  """Body function for the rolled (``lax.fori_loop``) threefry schedule.

  ``state`` is ``(x, ks, rotations)``. The step applies the rounds in
  ``rotations[0]``, adds ``ks[0]`` and ``ks[1] + (i + 1)`` to the two words,
  and returns the new words together with ``ks`` and ``rotations`` each
  rotated by one position for the next iteration.
  """
  x, ks, rotations = state
  for r in rotations[0]:
    x = apply_round(x, r)
  new_x = [x[0] + ks[0], x[1] + ks[1] + jnp.asarray(i + 1, dtype=np.uint32)]
  return new_x, rotate_list(ks), rotate_list(rotations)


def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
  """Apply the Threefry 2x32 hash to a pair of count words.

  Args:
    key1: first uint32 word of the key.
    key2: second uint32 word of the key.
    x1: uint32 array holding the first word of each count pair.
    x2: uint32 array holding the second word of each count pair.
    use_rolled_loops: if True, run the five round groups with
      ``lax.fori_loop`` over ``rolled_loop_step``; if False, emit the same
      schedule fully unrolled.

  Returns:
    A 2-tuple of arrays holding the two hashed output words.
  """
  x = [x1, x2]

  rotations = [np.array([13, 15, 26, 6], dtype=np.uint32),
               np.array([17, 29, 16, 24], dtype=np.uint32)]
  ks = [key1, key2, key1 ^ key2 ^ np.uint32(0x1BD11BDA)]

  x[0] = x[0] + ks[0]
  x[1] = x[1] + ks[1]

  if use_rolled_loops:
    # The loop starts from the once-rotated key schedule so that step ``i``
    # injects the same key words as the unrolled branch below.
    x, _, _ = lax.fori_loop(0, 5, rolled_loop_step, (x, rotate_list(ks), rotations))

  else:
    for r in rotations[0]:
      x = apply_round(x, r)
    x[0] = x[0] + ks[1]
    x[1] = x[1] + ks[2] + np.uint32(1)

    for r in rotations[1]:
      x = apply_round(x, r)
    x[0] = x[0] + ks[2]
    x[1] = x[1] + ks[0] + np.uint32(2)

    for r in rotations[0]:
      x = apply_round(x, r)
    x[0] = x[0] + ks[0]
    x[1] = x[1] + ks[1] + np.uint32(3)

    for r in rotations[1]:
      x = apply_round(x, r)
    x[0] = x[0] + ks[1]
    x[1] = x[1] + ks[2] + np.uint32(4)

    for r in rotations[0]:
      x = apply_round(x, r)
    x[0] = x[0] + ks[2]
    x[1] = x[1] + ks[0] + np.uint32(5)

  return tuple(x)


def _threefry2x32_gpu_lowering(lowering_func, ctx, k1, k2, x1, x2):
  """MLIR lowering of ``threefry2x32_p`` through a GPU kernel.

  Args:
    lowering_func: the kernel builder (registered below with the CUDA and
      ROCm builders from ``gpu_prng``).
    ctx: the lowering rule context; supplies input and output avals.
    k1, k2: IR values for the two key words.
    x1, x2: IR values for the two count words.

  Returns:
    The two output IR values. If the output shape contains a zero dimension,
    two zero-filled values are returned without calling ``lowering_func``.

  Raises:
    ValueError: if the output size is not a constant and jaxlib is older than
      0.4.9.
  """
  aval_out, aval_out_2 = ctx.avals_out
  assert aval_out == aval_out_2
  k1_aval, k2_aval, x1_aval, x2_aval = ctx.avals_in
  rank = len(aval_out.shape)
  if 0 in aval_out.shape:
    zeros = mlir.full_like_aval(ctx, 0, aval_out)
    return [zeros, zeros]

  def _broadcast(x, aval):
    # Broadcast an operand to the output shape, aligning trailing dimensions.
    return mlir.broadcast_in_dim(ctx, x, aval_out,
                                 broadcast_dimensions=range(rank - len(aval.shape), rank))

  out_len = reduce(op.mul, aval_out.shape, 1)
  if not core.is_constant_dim(out_len):
    # Non-constant size: compute the length and output shape as IR values.
    length = mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, [out_len]))
    length = mlir.hlo.ConvertOp(
        ir.RankedTensorType.get((1,), ir.IntegerType.get_signless(64)),
        length).result
    output_shape = mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, aval_out.shape))
  else:
    length = int(out_len)  # will be passed statically
    output_shape = None

  if (jaxlib_version >= (0, 4, 9)):
    return lowering_func(
        (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
        (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length,
        output_shape)
  else:
    if output_shape is not None:
      raise ValueError("native lowering with shape polymorphism "
                       "for threefry on GPU requires jaxlib version 0.4.9")
    return lowering_func(
        (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
        (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length)


threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(dispatch.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
# Default lowering uses the unrolled schedule; CPU uses the rolled loop; CUDA
# and ROCm dispatch to the GPU kernels.
mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=False),
    multiple_results=True))
mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=True),
    multiple_results=True), platform='cpu')
mlir.register_lowering(
    threefry2x32_p,
    partial(_threefry2x32_gpu_lowering, gpu_prng.cuda_threefry2x32),
    platform='cuda')
mlir.register_lowering(
    threefry2x32_p,
    partial(_threefry2x32_gpu_lowering, gpu_prng.rocm_threefry2x32),
    platform='rocm')


def iota_2x32_shape(shape):
  """Reshaped ``uint64`` iota, as two parallel ``uint32`` arrays.

  Setting aside representation, this function essentially computes the
  equivalent of::

    jax.lax.iota(dtype=np.uint64, size=math.prod(shape)).reshape(shape)

  However:

  * It returns two parallel ``uint32`` arrays instead of one
    ``uint64`` array. This renders it invariant under either setting of
    the system-wide ``jax_enable_x64`` configuration flag.

  * It lowers in a way such that the compiler's automatic SPMD
    partitioner recognizes its partitionability.

  For example::

    >>> import math
    >>> import numpy as np
    >>> from jax import lax
    >>> from jax._src import prng

    >>> prng.iota_2x32_shape((3, 4))
    [Array([[0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0]], dtype=uint32),
     Array([[ 0,  1,  2,  3],
            [ 4,  5,  6,  7],
            [ 8,  9, 10, 11]], dtype=uint32)]

    >>> def reshaped_iota(shape):
    ...   return lax.iota(size=math.prod(shape), dtype=np.uint32).reshape(shape)
    ...
    >>> reshaped_iota((3, 4))
    Array([[ 0,  1,  2,  3],
           [ 4,  5,  6,  7],
           [ 8,  9, 10, 11]], dtype=uint32)

  Args:
    shape: the output shape

  Returns:
    A pair of ``uint32`` arrays ``(counts_hi, counts_lo)``, both of
    shape ``shape``, representing the higher-order and lower-order 32
    bits of the 64 bit unsigned iota.
  """
  if len(shape) == 0:
    # Scalar case is handled here; the lowering asserts a non-empty shape.
    return (jnp.zeros((), np.dtype('uint32')),) * 2
  return iota_2x32_shape_p.bind(shape=shape)


iota_2x32_shape_p = core.Primitive('iota_2x32_shape')
iota_2x32_shape_p.multiple_results = True
iota_2x32_shape_p.def_impl(partial(dispatch.apply_primitive, iota_2x32_shape_p))


@iota_2x32_shape_p.def_abstract_eval
def iota_2x32_shape_abstract_eval(*, shape):
  """Abstract eval: two ``uint32`` arrays of the requested ``shape``."""
  return (core.ShapedArray(shape, np.dtype('uint32')),) * 2


def bcast_iotas_to_reshaped_iota(
    add: Callable[[ir.Value, ir.Value], ir.Value],
    mul: Callable[[core.DimSize, ir.Value], ir.Value],
    shape: core.Shape,
    iotas: Sequence[ir.Value]) -> ir.Value:
  """Combine per-dimension iotas into one linear index value.

  Computes ``sum(stride_d * iota_d)`` using the supplied ``add`` and ``mul``,
  where the stride of dimension ``d`` is the product of the sizes of all
  later dimensions (and 1 for the last dimension).
  """
  strides: core.Shape = (*(np.cumprod(shape[1:][::-1])[::-1]), 1)  # type: ignore
  return reduce(add, [mul(s, i) for i, s in zip(iotas, strides)])  # type: ignore


def iota_2x32_shape_lowering(ctx, *, shape):
  """MLIR lowering of ``iota_2x32_shape_p``.

  Builds one ``uint64`` iota per dimension, combines them into a linear index
  with ``bcast_iotas_to_reshaped_iota``, and returns the index shifted right
  by 32 bits and the unshifted index, each converted to the output aval's
  element type, as ``(counts_hi, counts_lo)``.
  """
  aval_out, _ = ctx.avals_out
  aval_u64 = core.ShapedArray(shape, np.dtype('uint64'))

  def _add(x: ir.Value, y: ir.Value) -> ir.Value:
    return mlir.hlo.AddOp(x, y).result

  def _mul(x: core.DimSize, y: ir.Value) -> ir.Value:
    # ``x`` is a stride: either a constant or a dimension expression that
    # must be evaluated to an IR value and converted to uint64.
    if core.is_constant_dim(x):
      x_const = mlir.ir_constant(np.array(x, np.dtype('uint64')),
                                 canonicalize_types=False)
    else:
      x_const, = mlir.eval_dynamic_shape(ctx, (x,))
      x_const = hlo.ConvertOp(
          ir.RankedTensorType.get(
              (),
              mlir.dtype_to_ir_type(np.dtype('uint64'))), x_const).result
    x_bcast = mlir.broadcast_in_dim(ctx, x_const, aval_u64,
                                    broadcast_dimensions=[])
    return mlir.hlo.MulOp(x_bcast, y).result

  assert len(shape) > 0
  iotas = [mlir.iota(ctx, aval_u64, dimension=dimension)
           for dimension in range(len(shape))]
  counts = bcast_iotas_to_reshaped_iota(_add, _mul, shape, iotas)
  shift = mlir.ir_constant(np.array(32, np.dtype('uint64')),
                           canonicalize_types=False)
  shift = mlir.broadcast_in_dim(ctx, shift, aval_u64,
                                broadcast_dimensions=[])
  counts_shifted = mlir.hlo.ShiftRightLogicalOp(counts, shift).result
  counts_lo = mlir.hlo.ConvertOp(mlir.aval_to_ir_type(aval_out), counts).result
  counts_hi = mlir.hlo.ConvertOp(mlir.aval_to_ir_type(aval_out),
                                 counts_shifted).result
  return counts_hi, counts_lo


mlir.register_lowering(iota_2x32_shape_p, iota_2x32_shape_lowering)


@partial(jit, inline=True)
def threefry_2x32(keypair, count):
  """Apply the Threefry 2x32 hash.

  Args:
    keypair: a pair of 32bit unsigned integers used for the key.
    count: an array of dtype uint32 used for the counts.

  Returns:
    An array of dtype uint32 with the same shape as `count`.

  Raises:
    TypeError: if the key words or ``count`` are not ``uint32``.
    core.InconclusiveDimensionOperation: if the parity of ``count.size``
      cannot be determined as a Python ``int``.
  """
  key1, key2 = keypair
  if not lax.dtype(key1) == lax.dtype(key2) == lax.dtype(count) == np.uint32:
    msg = "threefry_2x32 requires uint32 arguments, got {}"
    raise TypeError(msg.format([lax.dtype(x) for x in [key1, key2, count]]))

  odd_size = count.size % 2
  if not isinstance(odd_size, int):
    msg = ("jax.random functions have limited support for shape polymorphism. "
           "In particular, the product of the known dimensions must be even.")
    raise core.InconclusiveDimensionOperation(msg)

  # The flattened counts are split into two halves that form the two input
  # words; an odd-sized input is padded with one zero that is dropped again
  # from the output.
  if odd_size:
    x = list(jnp.split(jnp.concatenate([count.ravel(), np.uint32([0])]), 2))
  else:
    x = list(jnp.split(count.ravel(), 2))

  x = threefry2x32_p.bind(key1, key2, x[0], x[1])
  out = jnp.concatenate(x)
  assert out.dtype == np.uint32
  return lax.reshape(out[:-1] if odd_size else out, count.shape)


def threefry_split(key: typing.Array, num: core.DimSize) -> typing.Array:
  """Split a raw threefry ``key`` into ``num`` new keys.

  Unless ``core.is_special_dim_size(num)`` holds, ``num`` is first made
  concrete via ``operator.index``. The implementation used depends on
  ``config.jax_threefry_partitionable``.
  """
  if not core.is_special_dim_size(num):
    num = core.concrete_or_error(op.index, num)
  if config.jax_threefry_partitionable:
    return _threefry_split_foldlike(key, num)  # type: ignore
  else:
    return _threefry_split_original(key, num)  # type: ignore


@partial(jit, static_argnums=(1,), inline=True)
def _threefry_split_original(key, num) -> typing.Array:
  """Split by hashing ``iota(2 * num)`` and reshaping to ``(num, 2)``."""
  counts = lax.iota(np.uint32, num * 2)
  return lax.reshape(threefry_2x32(key, counts), (num, 2))


@partial(jit, static_argnums=(1,), inline=True)
def _threefry_split_foldlike(key, num) -> typing.Array:
  """Split by hashing ``iota_2x32_shape((num,))`` counters.

  The two output words are stacked along axis 1, giving shape ``(num, 2)``.
  """
  k1, k2 = key
  counts1, counts2 = iota_2x32_shape((num,))
  bits1, bits2 = threefry2x32_p.bind(k1, k2, counts1, counts2)
  return jnp.stack([bits1, bits2], axis=1)


def threefry_fold_in(key: typing.Array, data: typing.Array) -> typing.Array:
  """Fold scalar ``data`` into ``key``.

  ``data`` must be a scalar (checked with an assertion); it is cast to
  ``uint32`` before being hashed with ``key``.
  """
  assert not data.shape
  return _threefry_fold_in(key, jnp.uint32(data))


@partial(jit, inline=True)
def _threefry_fold_in(key, data):
  """Hash ``threefry_seed(data)`` under ``key`` with ``threefry_2x32``."""
  return threefry_2x32(key, threefry_seed(data))


def threefry_random_bits(key: typing.Array, bit_width, shape):
  """Sample uniform random bits of given width and shape using PRNG key.

  Raises:
    TypeError: if ``key`` is not a ``(2,)`` ``uint32`` array, or if
      ``bit_width`` is not one of 8, 16, 32 or 64.
  """
  if not _is_threefry_prng_key(key):
    raise TypeError("threefry_random_bits got invalid prng key.")
  if bit_width not in (8, 16, 32, 64):
    raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")

  if config.jax_threefry_partitionable:
    return _threefry_random_bits_partitionable(key, bit_width, shape)
  else:
    return _threefry_random_bits_original(key, bit_width, shape)


def _threefry_random_bits_partitionable(key: typing.Array, bit_width, shape):
  """Random bits generated from ``iota_2x32_shape(shape)`` counters.

  For 64-bit output the two hash words form the high and low halves; for
  narrower output the two words are XORed and, for 8/16 bits, converted to
  the target dtype.

  Raises:
    NotImplementedError: if all dimensions are constant and the total size
      exceeds ``2 ** 64``.
  """
  if all(core.is_constant_dim(d) for d in shape) and math.prod(shape) > 2 ** 64:
    raise NotImplementedError('random bits array of size exceeding 2 ** 64')

  k1, k2 = key
  counts1, counts2 = iota_2x32_shape(shape)
  bits1, bits2 = threefry2x32_p.bind(k1, k2, counts1, counts2)

  dtype = UINT_DTYPES[bit_width]
  if bit_width == 64:
    bits_hi = lax.convert_element_type(bits1, dtype)
    bits_lo = lax.convert_element_type(bits2, dtype)
    return lax.shift_left(bits_hi, dtype(32)) | bits_lo
  elif bit_width == 32:
    return bits1 ^ bits2
  else:
    return lax.convert_element_type(bits1 ^ bits2, dtype)


@partial(jit, static_argnums=(1, 2), inline=True)
def _threefry_random_bits_original(key: typing.Array, bit_width, shape):
  """Random bits generated from a flat ``uint32`` iota of counters.

  Enough 32-bit words are hashed to cover ``bit_width * size`` bits. When the
  word count is a constant that reaches ``iinfo(uint32).max``, the work is
  spread over several split keys. The words are then repacked to
  ``bit_width`` and reshaped to ``shape``.
  """
  size = math.prod(shape)
  # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
  # polymorphism
  max_count, r = divmod(bit_width * size, 32)
  if r > 0:
    max_count += 1

  if core.is_constant_dim(max_count):
    nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
  else:
    nblocks, rem = 0, max_count

  if not nblocks:
    bits = threefry_2x32(key, lax.iota(np.uint32, rem))
  else:
    keys = threefry_split(key, nblocks + 1)
    subkeys, last_key = keys[:-1], keys[-1]
    blocks = vmap(threefry_2x32, in_axes=(0, None))(subkeys, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
    last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
    bits = lax.concatenate([blocks.ravel(), last], 0)

  dtype = UINT_DTYPES[bit_width]
  if bit_width == 64:
    bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
    bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
  elif bit_width in [8, 16]:
    # this is essentially bits.view(dtype)[:size]
    bits = lax.bitwise_and(
      np.uint32(np.iinfo(dtype).max),
      lax.shift_right_logical(
        lax.broadcast(bits, (1,)),
        lax.mul(
          np.uint32(bit_width),
          lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0)
        )
      )
    )
    bits = lax.reshape(bits, ((max_count * 32 // bit_width),), (1, 0))
    bits = lax.convert_element_type(bits, dtype)[:size]
  return lax.reshape(bits, shape)


threefry_prng_impl = PRNGImpl(
    key_shape=(2,),
    seed=threefry_seed,
    split=threefry_split,
    random_bits=threefry_random_bits,
    fold_in=threefry_fold_in,
    tag='fry')


# -- RngBitGenerator PRNG implementation

# This code is experimental!
# https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator
# Notice that the RngBitGenerator operations are not guaranteed to be
# stable/deterministic across backends or compiler versions. Correspondingly, we
# reserve the right to change any of these implementations at any time!

def _rbg_seed(seed: typing.Array) -> typing.Array:
  """Build a 4-word RBG key: the threefry seed key concatenated with itself."""
  assert not seed.shape
  halfkey = threefry_seed(seed)
  return jnp.concatenate([halfkey, halfkey])


def _rbg_split(key: typing.Array, num: int) -> typing.Array:
  """Split an RBG key into ``num`` keys of 4 words each.

  The key is viewed as two threefry keys (``reshape(2, 2)``); each half is
  split with the threefry splitter selected by
  ``config.jax_threefry_partitionable`` and the results are recombined into
  shape ``(num, 4)``.
  """
  if config.jax_threefry_partitionable:
    _threefry_split = _threefry_split_foldlike
  else:
    _threefry_split = _threefry_split_original
  return vmap(
      _threefry_split, (0, None), 1)(key.reshape(2, 2), num).reshape(num, 4)


def _rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array:
  """Fold scalar ``data`` into both threefry halves of an RBG key."""
  assert not data.shape
  return vmap(_threefry_fold_in, (0, None), 0)(key.reshape(2, 2), data).reshape(4)


def _rbg_random_bits(key: typing.Array, bit_width: int, shape: Sequence[int]
                     ) -> typing.Array:
  """Sample random bits with ``lax.rng_bit_generator``.

  Args:
    key: raw RBG key; must have shape ``(4,)`` and dtype ``uint32``.
    bit_width: one of 8, 16, 32 or 64.
    shape: shape of the requested bits array.

  Returns:
    The bits produced by ``lax.rng_bit_generator`` with dtype
    ``UINT_DTYPES[bit_width]``.

  Raises:
    TypeError: if ``key`` has the wrong shape or dtype, or ``bit_width`` is
      unsupported.
  """
  # Both conditions must hold. The previous form,
  # ``not key.shape == (4,) and key.dtype == ...``, bound ``not`` to the shape
  # test only, so a key with a wrong dtype was never rejected.
  if key.shape != (4,) or key.dtype != jnp.dtype('uint32'):
    raise TypeError("_rbg_random_bits got invalid prng key.")
  if bit_width not in (8, 16, 32, 64):
    raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
  _, bits = lax.rng_bit_generator(key, shape, dtype=UINT_DTYPES[bit_width])
  return bits


rbg_prng_impl = PRNGImpl(
    key_shape=(4,),
    seed=_rbg_seed,
    split=_rbg_split,
    random_bits=_rbg_random_bits,
    fold_in=_rbg_fold_in,
    tag='rbg')


def _unsafe_rbg_split(key: typing.Array, num: int) -> typing.Array:
  """Split an RBG key using ``lax.rng_bit_generator`` output directly.

  Generates ``(10 * num, 4)`` words and keeps every tenth row, giving ``num``
  keys.
  """
  # treat 10 iterations of random bits as a 'hash function'
  _, keys = lax.rng_bit_generator(key, (10 * num, 4), dtype='uint32')
  return lax.slice_in_dim(
      keys, start_index=None, limit_index=None, stride=10)


def _unsafe_rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array:
  """Fold scalar ``data`` into an RBG key.

  Seeds a generator from ``data``, draws ``(10, 4)`` words, and XORs the last
  row into ``key``.
  """
  assert not data.shape
  _, random_bits = lax.rng_bit_generator(_rbg_seed(data), (10, 4), dtype='uint32')
  return key ^ random_bits[-1]


unsafe_rbg_prng_impl = PRNGImpl(
    key_shape=(4,),
    seed=_rbg_seed,
    split=_unsafe_rbg_split,
    random_bits=_rbg_random_bits,
    fold_in=_unsafe_rbg_fold_in,
    tag='urbg')