209 lines
7.2 KiB
Python
209 lines
7.2 KiB
Python
|
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Utilities for random ops to share common usages."""
|
||
|
|
||
|
import enum
|
||
|
|
||
|
from tensorflow.python.framework import constant_op
|
||
|
from tensorflow.python.framework import dtypes
|
||
|
from tensorflow.python.framework import tensor
|
||
|
from tensorflow.python.ops import array_ops
|
||
|
from tensorflow.python.ops import array_ops_stack
|
||
|
from tensorflow.python.ops import bitwise_ops
|
||
|
from tensorflow.python.ops import gen_stateless_random_ops_v2
|
||
|
from tensorflow.python.ops import math_ops
|
||
|
from tensorflow.python.util.tf_export import tf_export
|
||
|
|
||
|
|
||
|
@tf_export("random.Algorithm", "random.experimental.Algorithm")
class Algorithm(enum.Enum):
  """A random-number-generation (RNG) algorithm.

  Many random-number generators (e.g. the `alg` argument of
  `tf.random.Generator` and `tf.random.stateless_uniform`) in TF allow
  you to choose the algorithm used to generate the (pseudo-)random
  numbers. You can set the algorithm to be one of the options below.

  * `PHILOX`: The Philox algorithm introduced in the paper ["Parallel
    Random Numbers: As Easy as 1, 2,
    3"](https://www.thesalmons.org/john/random123/papers/random123sc11.pdf).
  * `THREEFRY`: The ThreeFry algorithm introduced in the paper
    ["Parallel Random Numbers: As Easy as 1, 2,
    3"](https://www.thesalmons.org/john/random123/papers/random123sc11.pdf).
  * `AUTO_SELECT`: Allow TF to automatically select the algorithm
    depending on the accelerator device. Note that with this option,
    running the same TF program on different devices may result in
    different random numbers. Also note that TF may select an
    algorithm that is different from `PHILOX` and `THREEFRY`.
  """

  # The numbers here must match framework/rng_alg.h
  # Each member's integer `.value` is what gets handed to the raw stateless
  # RNG ops as their `alg` argument (see `convert_alg_to_int`).
  PHILOX = 1
  THREEFRY = 2
  AUTO_SELECT = 3
|
||
|
|
||
|
|
||
|
def convert_alg_to_int(alg):
  """Converts an RNG algorithm specification to its integer ID.

  Args:
    alg: The algorithm, given as one of:
      * a Python `int`, which is returned unchanged (it is not validated here);
      * an `Algorithm` member, whose `.value` is returned;
      * a `Tensor`, which is returned unchanged;
      * a string. Strings are matched case-insensitively after stripping
        surrounding whitespace and removing every `-` and `_`. The accepted
        names are therefore "philox", "threefry" and "auto_select", plus
        variants such as "Philox", "three-fry" or "AUTO-SELECT".

  Returns:
    An integer, unless the input is a Tensor in which case a Tensor is returned.

  Raises:
    ValueError: If `alg` is a string that does not name a supported algorithm.
    TypeError: If `alg` is not an int, `Algorithm`, Tensor or string.
  """
  if isinstance(alg, int):
    # `bool` is a subclass of `int`, so booleans also pass through here.
    return alg
  if isinstance(alg, Algorithm):
    return alg.value
  if isinstance(alg, tensor.Tensor):
    return alg
  if isinstance(alg, str):
    # Canonicalize so that e.g. "AUTO_SELECT", "auto-select" and
    # "autoselect" all map to the same algorithm.
    canon_alg = alg.strip().lower().replace("-", "").replace("_", "")
    by_canonical_name = {
        "philox": Algorithm.PHILOX,
        "threefry": Algorithm.THREEFRY,
        "autoselect": Algorithm.AUTO_SELECT,
    }
    member = by_canonical_name.get(canon_alg)
    if member is None:
      # Report the caller's original spelling, not the canonical form.
      raise ValueError(unsupported_alg_error_msg(alg))
    return member.value
  raise TypeError(
      f"Can't convert argument `alg` (of value {alg} and type {type(alg)}) "
      "to int."
  )
|
||
|
|
||
|
|
||
|
def _get_key_counter(seed, alg):
  """Derives the (key, counter) pair consumed by raw V2 stateless RNG ops.

  How the pair is obtained depends on the integer algorithm ID `alg`:

  * `Algorithm.AUTO_SELECT`: delegated to the
    `stateless_random_get_key_counter` op, so the key and counter are
    determined at runtime based on device type.
  * `Algorithm.PHILOX`: the seed is scrambled by `_philox_scramble_seed`.
  * `Algorithm.THREEFRY`: the seed, cast to uint32, is packed into one
    uint64 key of shape [1], and the counter is a uint64 zero of shape [1].

  Args:
    seed: An integer tensor of shape [2]. The seed to calculate the key and
      counter from.
    alg: The integer RNG algorithm ID. See `tf.random.stateless_uniform` for
      an explanation.

  Returns:
    A pair (key, counter) suitable for V2 stateless RNG ops like
    `StatelessRandomUniformV2`.

  Raises:
    ValueError: If `alg` is not one of the supported algorithm IDs.
  """
  if alg == Algorithm.AUTO_SELECT.value:
    # Defer the choice to the runtime; unpack so a plain tuple is returned.
    auto_key, auto_counter = (
        gen_stateless_random_ops_v2.stateless_random_get_key_counter(seed)
    )
    return auto_key, auto_counter

  if alg == Algorithm.PHILOX.value:
    philox_key, philox_counter = _philox_scramble_seed(seed)
    return philox_key, philox_counter

  if alg == Algorithm.THREEFRY.value:
    # ThreeFry keys directly off the two seed words; the counter starts at 0.
    seed_words = math_ops.cast(seed, dtypes.uint32)
    threefry_key = array_ops.reshape(_uint32s_to_uint64(seed_words), [1])
    threefry_counter = array_ops.zeros([1], dtypes.uint64)
    return threefry_key, threefry_counter

  raise ValueError(unsupported_alg_error_msg(alg))
|
||
|
|
||
|
|
||
|
def get_key_counter_alg(seed, alg):
  """Calculates the key, counter and algorithm to pass to raw RNG ops.

  This function calculates the key and counter, and determines the algorithm
  that will be passed to the raw RNG ops like `StatelessRandomUniformV2`.
  Depending on the input `alg`, the key and counter may be scrambled or copied
  from `seed`. If `alg` is `"auto_select"`, the key and counter will be
  determined at runtime based on device type.

  Args:
    seed: An integer tensor of shape [2]. The seed to calculate the key and
      counter from.
    alg: The RNG algorithm, in any form accepted by `convert_alg_to_int`
      (an integer, an `Algorithm` member, a Tensor or a string), or `None`,
      which means `Algorithm.AUTO_SELECT`. See `tf.random.stateless_uniform`
      for an explanation.

  Returns:
    A triple (key, counter, algorithm) suitable for V2 stateless RNG ops like
    `StatelessRandomUniformV2`. `algorithm` is `alg` converted by
    `convert_alg_to_int` (an integer, or a Tensor if `alg` was a Tensor).

  Raises:
    TypeError: If `alg` has a type `convert_alg_to_int` cannot convert.
    ValueError: If `alg` does not name a supported algorithm.
  """
  if alg is None:
    alg = Algorithm.AUTO_SELECT.value
  alg = convert_alg_to_int(alg)
  key, counter = _get_key_counter(seed, alg)
  return key, counter, alg
|
||
|
|
||
|
|
||
|
def _uint32s_to_uint64(x):
  """Packs two 32-bit words into a single uint64 tensor.

  Args:
    x: An indexable holding two words; `x[0]` becomes the low 32 bits and
      `x[1]` the high 32 bits of the result.

  Returns:
    A uint64 tensor equal to `x[0] | (x[1] << 32)`.
  """
  low_word = math_ops.cast(x[0], dtypes.uint64)
  high_word = math_ops.cast(x[1], dtypes.uint64)
  shift_bits = constant_op.constant(32, dtypes.uint64)
  shifted_high = bitwise_ops.left_shift(high_word, shift_bits)
  return bitwise_ops.bitwise_or(low_word, shifted_high)
|
||
|
|
||
|
|
||
|
def unsupported_alg_error_msg(alg):
  """Builds the error message reported for an unsupported `alg` value.

  The supported choices are listed in the same form as the rejected value:
  as names if `alg` is a string, as integer IDs if `alg` is an int, and as
  `Algorithm` members for anything else.

  Args:
    alg: The rejected algorithm value.

  Returns:
    The error message string.
  """
  if isinstance(alg, str):
    philox, threefry, auto_select = "philox", "threefry", "auto_select"
  else:
    choices = (Algorithm.PHILOX, Algorithm.THREEFRY, Algorithm.AUTO_SELECT)
    if isinstance(alg, int):
      choices = tuple(choice.value for choice in choices)
    philox, threefry, auto_select = choices
  return (
      f"Argument `alg` got unsupported value {alg}. Supported values are "
      f"{philox} for the Philox algorithm, "
      f"{threefry} for the ThreeFry algorithm, and "
      f"{auto_select} for auto-selection."
  )
|
||
|
|
||
|
|
||
|
def _philox_scramble_seed(seed):
  """Determines the key and counter for Philox PRNG with the given seed.

  The seed is not used as the Philox key directly. Instead it is scrambled:
  the seed (cast to uint64) serves as the counter of a single Philox draw of
  four uint32 words under a fixed key. The first two drawn words are packed
  into the new key. The last two are packed into the second element of the
  new counter, and the first element is 0.

  Args:
    seed: An integer tensor of shape [2]. The seed to calculate the key and
      counter from.

  Returns:
    A pair (key, counter) suitable for V2 stateless RNG ops like
    `StatelessRandomUniformV2`. `key` is a uint64 tensor of shape [1];
    `counter` is a tensor of shape [2].
  """
  # the same scrambling procedure as core/kernels/stateless_random_ops.cc
  # Because this must reproduce that kernel, the constant key and the word
  # layout below should not be changed independently of it.
  # Fixed key used only for the scrambling draw.
  key = constant_op.constant([0x02461E293EC8F720], dtypes.uint64)
  # The seed words act as the counter for the scrambling draw.
  counter = math_ops.cast(seed, dtypes.uint64)
  # One Philox draw of four uint32 words.
  mix = gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
      [4],
      key=key,
      counter=counter,
      dtype=dtypes.uint32,
      alg=Algorithm.PHILOX.value,
  )
  # Words 0-1 -> key (mix[0] low, mix[1] high), reshaped to [1].
  key = array_ops.reshape(_uint32s_to_uint64(mix[:2]), [1])
  # Words 2-3 -> second counter element; the first counter element is 0.
  counter = array_ops_stack.stack([0, _uint32s_to_uint64(mix[2:])], axis=0)
  return key, counter
|