3RNN/Lib/site-packages/tensorflow/python/ops/random_ops_util.py

209 lines
7.2 KiB
Python
Raw Normal View History

2024-05-26 19:49:15 +02:00
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for random ops to share common usages."""
import enum
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
@tf_export("random.Algorithm", "random.experimental.Algorithm")
class Algorithm(enum.Enum):
"""A random-number-generation (RNG) algorithm.
Many random-number generators (e.g. the `alg` argument of
`tf.random.Generator` and `tf.random.stateless_uniform`) in TF allow
you to choose the algorithm used to generate the (pseudo-)random
numbers. You can set the algorithm to be one of the options below.
* `PHILOX`: The Philox algorithm introduced in the paper ["Parallel
Random Numbers: As Easy as 1, 2,
3"](https://www.thesalmons.org/john/random123/papers/random123sc11.pdf).
* `THREEFRY`: The ThreeFry algorithm introduced in the paper
["Parallel Random Numbers: As Easy as 1, 2,
3"](https://www.thesalmons.org/john/random123/papers/random123sc11.pdf).
* `AUTO_SELECT`: Allow TF to automatically select the algorithm
depending on the accelerator device. Note that with this option,
running the same TF program on different devices may result in
different random numbers. Also note that TF may select an
algorithm that is different from `PHILOX` and `THREEFRY`.
"""
# The numbers here must match framework/rng_alg.h
PHILOX = 1
THREEFRY = 2
AUTO_SELECT = 3
def convert_alg_to_int(alg):
"""Converts algorithm to an integer.
Args:
alg: can be one of these types: integer, Algorithm, Tensor, string. Allowed
strings are "philox" and "threefry".
Returns:
An integer, unless the input is a Tensor in which case a Tensor is returned.
"""
if isinstance(alg, int):
return alg
if isinstance(alg, Algorithm):
return alg.value
if isinstance(alg, tensor.Tensor):
return alg
if isinstance(alg, str):
# canonicalized alg
canon_alg = alg.strip().lower().replace("-", "").replace("_", "")
if canon_alg == "philox":
return Algorithm.PHILOX.value
elif canon_alg == "threefry":
return Algorithm.THREEFRY.value
elif canon_alg == "autoselect":
return Algorithm.AUTO_SELECT.value
else:
raise ValueError(unsupported_alg_error_msg(alg))
else:
raise TypeError(
f"Can't convert argument `alg` (of value {alg} and type {type(alg)}) "
"to int."
)
def _get_key_counter(seed, alg):
"""Calculates the key and counter to pass to raw RNG ops.
This function calculates the key and counter that will be passed to
the raw RNG ops like `StatelessRandomUniformV2`. Depending on the
input `alg`, the key and counter may be scrambled or copied from
`seed`. If `alg` is `"auto_select"`, the key and counter will be
determined at runtime based on device type.
Args:
seed: An integer tensor of shape [2]. The seed to calculate the key and
counter from.
alg: The RNG algorithm. See `tf.random.stateless_uniform` for an
explanation.
Returns:
A pair (key, counter) suitable for V2 stateless RNG ops like
`StatelessRandomUniformV2`.
"""
if alg == Algorithm.AUTO_SELECT.value:
key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
seed
)
elif alg == Algorithm.PHILOX.value:
key, counter = _philox_scramble_seed(seed)
elif alg == Algorithm.THREEFRY.value:
key = array_ops.reshape(
_uint32s_to_uint64(math_ops.cast(seed, dtypes.uint32)), [1]
)
counter = array_ops.zeros([1], dtypes.uint64)
else:
raise ValueError(unsupported_alg_error_msg(alg))
return key, counter
def get_key_counter_alg(seed, alg):
"""Calculates the key, counter and algorithm to pass to raw RNG ops.
This function calculates the key and counter, and determines the algorithm
that will be passed to the raw RNG ops like `StatelessRandomUniformV2`.
Depending on the input `alg`, the key and counter may be scrambled or copied
from `seed`. If `alg` is `"auto_select"`, the key and counter will be
determined at runtime based on device type.
Args:
seed: An integer tensor of shape [2]. The seed to calculate the key and
counter from.
alg: The RNG algorithm. See `tf.random.stateless_uniform` for an
explanation.
Returns:
A pair (key, counter, algorithm) suitable for V2 stateless RNG ops like
`StatelessRandomUniformV2`.
"""
if alg is None:
alg = Algorithm.AUTO_SELECT.value
alg = convert_alg_to_int(alg)
key, counter = _get_key_counter(seed, alg)
return key, counter, alg
def _uint32s_to_uint64(x):
return bitwise_ops.bitwise_or(
math_ops.cast(x[0], dtypes.uint64),
bitwise_ops.left_shift(
math_ops.cast(x[1], dtypes.uint64),
constant_op.constant(32, dtypes.uint64),
),
)
def unsupported_alg_error_msg(alg):
"""Produces the unsupported-algorithm error message."""
if isinstance(alg, int):
philox = Algorithm.PHILOX.value
threefry = Algorithm.THREEFRY.value
auto_select = Algorithm.AUTO_SELECT.value
elif isinstance(alg, str):
philox = "philox"
threefry = "threefry"
auto_select = "auto_select"
else:
philox = Algorithm.PHILOX
threefry = Algorithm.THREEFRY
auto_select = Algorithm.AUTO_SELECT
return (
f"Argument `alg` got unsupported value {alg}. Supported values are "
f"{philox} for the Philox algorithm, "
f"{threefry} for the ThreeFry algorithm, and "
f"{auto_select} for auto-selection."
)
def _philox_scramble_seed(seed):
"""Determines the key and counter for Philox PRNG with the given seed.
Args:
seed: An integer tensor of shape [2]. The seed to calculate the key and
counter from.
Returns:
A pair (key, counter) suitable for V2 stateless RNG ops like
`StatelessRandomUniformV2`.
"""
# the same scrambling procedure as core/kernels/stateless_random_ops.cc
key = constant_op.constant([0x02461E293EC8F720], dtypes.uint64)
counter = math_ops.cast(seed, dtypes.uint64)
mix = gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
[4],
key=key,
counter=counter,
dtype=dtypes.uint32,
alg=Algorithm.PHILOX.value,
)
key = array_ops.reshape(_uint32s_to_uint64(mix[:2]), [1])
counter = array_ops_stack.stack([0, _uint32s_to_uint64(mix[2:])], axis=0)
return key, counter