"""The DeterministicRandomTestTool.

The DeterministicRandomTestTool (from
www.tensorflow.org/guide/migrate/validate_correctness) is a tool used to make
random number generation semantics match between TF1.x graphs/sessions and
eager execution.
"""

import sys

import tensorflow.compat.v2 as tf

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export(v1=["keras.utils.DeterministicRandomTestTool"])
class DeterministicRandomTestTool(object):
    """DeterministicRandomTestTool is a testing tool.

    This tool is used to validate that random number generation semantics
    match between TF1.x graphs/sessions and eager execution.

    This is useful when you are migrating from TF 1.x to TF2 and need to make
    sure your computation is still happening correctly along the way. See the
    validating correctness migration guide for more info:
    https://www.tensorflow.org/guide/migrate/validate_correctness

    A DeterministicRandomTestTool object provides a context manager, `scope()`,
    that makes stateful random operations use the same seed across both TF1
    graphs/sessions and eager execution. The tool provides two testing modes:

    - `constant`, which uses the same seed for every single operation, no
      matter how many times it has been called; and
    - `num_random_ops`, which uses the number of previously-observed stateful
      random operations as the operation seed.

    The `num_random_ops` mode serves as a more sensitive validation check than
    the `constant` mode. It ensures that random number initializations are not
    accidentally reused (for example, when several weights would otherwise take
    on the same initializations).

    In the `num_random_ops` mode, the generated random numbers will depend on
    the ordering of random ops in the program. This applies both to the
    stateful random operations used for creating and initializing variables,
    and to the stateful random operations used in computation (such as for
    dropout layers).

    Example:

    ```python
    random_tool = DeterministicRandomTestTool(mode="num_random_ops")
    with random_tool.scope():
        ...  # build and run the computation under test
    ```
    """

    def __init__(self, seed: int = 42, mode: str = "constant"):
        """Creates the tool.

        Args:
            seed: Global random seed. `scope()` passes it to
                `tf.random.set_seed`, and it is the first element of every
                seed tuple returned by the patched `get_seed`.
            mode: Either `"constant"` or `"num_random_ops"`. Defaults to
                `"constant"`.

        Raises:
            ValueError: If `mode` is not `"constant"` or `"num_random_ops"`.
        """
        # A tuple (rather than a set) avoids a TypeError from hashing an
        # unhashable `mode`, so every invalid mode reports a ValueError.
        if mode not in ("constant", "num_random_ops"):
            raise ValueError(
                "Mode arg must be 'constant' or 'num_random_ops'. "
                f"Got: {mode}"
            )
        # The module in which `get_seed` is defined; `scope()` patches the
        # `get_seed` attribute on this module object.
        self.seed_implementation = sys.modules[tf.compat.v1.get_seed.__module__]
        self._mode = mode
        self._seed = seed
        self.operation_seed = 0
        # Operation seeds already handed out in `num_random_ops` mode; used
        # to detect (and refuse) reuse of a seed.
        self._observed_seeds = set()

    @property
    def operation_seed(self):
        """The operation seed the next patched `get_seed` call will return.

        In `num_random_ops` mode this increments after every call; in
        `constant` mode it is never changed by the tool. It may be set
        explicitly.
        """
        return self._operation_seed

    @operation_seed.setter
    def operation_seed(self, value):
        self._operation_seed = value

    def scope(self):
        """Sets the global random seed and patches TF's `get_seed`.

        Note that the global seed is set as soon as this method is called,
        not when the returned context manager is entered. The returned
        patcher replaces `get_seed` on `self.seed_implementation` with a
        wrapper that returns `(seed, operation_seed)`, where `seed` is the
        value passed to the constructor.

        Returns:
            A `mock.patch.object` patcher, intended to be used as a context
            manager (`with tool.scope(): ...`).
        """
        tf.random.set_seed(self._seed)

        def _get_seed(_):
            """Wraps TF get_seed to make deterministic random generation easier.

            This makes a variable's initialization (and calls that involve
            random number generation) depend only on how many random number
            generations were used in the scope so far, rather than on how many
            unrelated operations the graph contains.

            Args:
                _: The op-level seed TF passes to `get_seed`; ignored.

            Returns:
                Random seed tuple `(global_seed, op_seed)`.

            Raises:
                ValueError: In `num_random_ops` mode, if the current operation
                    seed has already been handed out.
            """
            op_seed = self._operation_seed
            if self._mode == "constant":
                # Re-set the global seed on every call. Per the
                # `tf.random.set_seed` docs this resets TF's internal
                # counters, so repeated ops start from the same state.
                tf.random.set_seed(op_seed)
            else:
                if op_seed in self._observed_seeds:
                    raise ValueError(
                        "This `DeterministicRandomTestTool` "
                        "object is trying to re-use the "
                        f"already-used operation seed {op_seed}. "
                        "It cannot guarantee random numbers will match "
                        "between eager and sessions when an operation seed "
                        "is reused. You most likely set "
                        "`operation_seed` explicitly but used a value that "
                        "caused the naturally-incrementing operation seed "
                        "sequences to overlap with an already-used seed."
                    )
                self._observed_seeds.add(op_seed)
                self._operation_seed += 1

            return (self._seed, op_seed)

        # mock.patch internal symbols to modify the behavior of TF APIs relying
        # on them
        return tf.compat.v1.test.mock.patch.object(
            self.seed_implementation, "get_seed", wraps=_get_seed
        )