# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains testing utilities related to mixed precision."""
import tensorflow.compat.v2 as tf
from keras import regularizers
from keras.engine import base_layer


def create_identity_with_grad_check_fn(expected_gradient, expected_dtype=None):
    """Returns a function that asserts its gradient has a certain value.

    This serves as a hook to assert intermediate gradients have a certain
    value. This returns an identity function. The identity's gradient function
    is also the identity function, except it asserts that the gradient equals
    `expected_gradient` and has dtype `expected_dtype`.

    Args:
      expected_gradient: The gradient function asserts that the gradient is
        this value.
      expected_dtype: The gradient function asserts the gradient has this
        dtype.

    Returns:
      An identity function whose gradient function asserts the gradient has a
      certain value.
    """
    @tf.custom_gradient
    def _identity_with_grad_check(x):
        """Function that asserts its gradient has a certain value."""
        x = tf.identity(x)

        def grad(dx):
            """Gradient function that asserts the gradient has a certain
            value."""
            if expected_dtype:
                assert (
                    dx.dtype == expected_dtype
                ), f"dx.dtype should be {expected_dtype} but is: {dx.dtype}"
            expected_tensor = tf.convert_to_tensor(
                expected_gradient, dtype=dx.dtype, name="expected_gradient"
            )
            # Control dependency is to ensure input is available. It's
            # possible the dataset will throw a StopIteration to indicate
            # there is no more data, in which case we don't want to run the
            # assertion.
            with tf.control_dependencies([x]):
                assert_op = tf.compat.v1.assert_equal(dx, expected_tensor)
            with tf.control_dependencies([assert_op]):
                dx = tf.identity(dx)
            return dx

        return x, grad

    # Keras sometimes has trouble serializing Lambda layers with a decorated
    # function. So we define and return a non-decorated function.
    def identity_with_grad_check(x):
        return _identity_with_grad_check(x)

    return identity_with_grad_check
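
# A minimal usage sketch (illustrative only, not part of the original test
# utilities): the returned function is an identity in the forward pass, while
# its custom gradient asserts on the incoming gradient's value and dtype.
#
#   identity = create_identity_with_grad_check_fn(
#       expected_gradient=2.0, expected_dtype=tf.float32
#   )
#   x = tf.constant(3.0)
#   with tf.GradientTape() as tape:
#       tape.watch(x)
#       y = 2.0 * identity(x)
#   tape.gradient(y, x)  # The gradient flowing into the identity is 2.0.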


def create_identity_with_nan_gradients_fn(have_nan_gradients):
    """Returns a function that optionally has NaN gradients.

    This serves as a hook to introduce NaN gradients to a model. This returns
    an identity function. The identity's gradient function will check if the
    boolean tensor `have_nan_gradients` is True. If so, the gradient will be
    NaN. Otherwise, the gradient will also be the identity.

    Args:
      have_nan_gradients: A scalar boolean tensor. If True, gradients will be
        NaN. Otherwise, the gradient function is the identity function.

    Returns:
      An identity function whose gradient function will return NaNs if
      `have_nan_gradients` is True.
    """

    @tf.custom_gradient
    def _identity_with_nan_gradients(x):
        """Function whose gradient is NaN iff `have_nan_gradients` is True."""
        x = tf.identity(x)

        def grad(dx):
            return tf.cond(
                have_nan_gradients, lambda: dx * float("NaN"), lambda: dx
            )

        return x, grad

    # Keras sometimes has trouble serializing Lambda layers with a decorated
    # function. So we define and return a non-decorated function.
    def identity_with_nan_gradients(x):
        return _identity_with_nan_gradients(x)

    return identity_with_nan_gradients
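
# A minimal usage sketch (illustrative only, not part of the original test
# utilities): toggling the boolean variable switches the backward pass between
# the identity and NaN gradients, which is useful for exercising loss scaling.
#
#   have_nan_gradients = tf.Variable(False, trainable=False)
#   identity = create_identity_with_nan_gradients_fn(have_nan_gradients)
#   x = tf.constant(1.0)
#   with tf.GradientTape() as tape:
#       tape.watch(x)
#       y = identity(x)
#   tape.gradient(y, x)  # 1.0
#   have_nan_gradients.assign(True)
#   with tf.GradientTape() as tape:
#       tape.watch(x)
#       y = identity(x)
#   tape.gradient(y, x)  # NaN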


class AssertTypeLayer(base_layer.Layer):
    """A layer which asserts its inputs are a certain type."""

    def __init__(self, assert_type=None, **kwargs):
        self._assert_type = (
            tf.as_dtype(assert_type).name if assert_type else None
        )
        super().__init__(**kwargs)

    def assert_input_types(self, inputs):
        """Asserts `inputs` are of the correct type. Should be called in
        call()."""
        if self._assert_type:
            inputs_flattened = tf.nest.flatten(inputs)
            for inp in inputs_flattened:
                assert inp.dtype.base_dtype == self._assert_type, (
                    "Input tensor has type %s which does "
                    "not match assert type %s"
                    % (inp.dtype.name, self._assert_type)
                )
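
# A minimal usage sketch (illustrative only; `_PassthroughLayer` is a
# hypothetical subclass, not part of this file): subclasses call
# `assert_input_types` at the top of `call` so tests fail loudly if an input
# arrives in an unexpected dtype, e.g. one a mixed-precision policy should
# have cast to float16.
#
#   class _PassthroughLayer(AssertTypeLayer):
#       def call(self, inputs):
#           self.assert_input_types(inputs)
#           return inputs
#
#   layer = _PassthroughLayer(assert_type="float16", dtype="mixed_float16")
#   layer(tf.ones((2, 2), dtype="float16"))  # Passes the dtype assertion.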


class MultiplyLayer(AssertTypeLayer):
    """A layer which multiplies its input by a scalar variable."""

    def __init__(
        self,
        regularizer=None,
        activity_regularizer=None,
        use_operator=False,
        var_name="v",
        **kwargs,
    ):
        """Initializes the MultiplyLayer.

        Args:
          regularizer: The weight regularizer on the scalar variable.
          activity_regularizer: The activity regularizer.
          use_operator: If True, multiply using the * operator. If False,
            multiply using tf.multiply.
          var_name: The name of the variable. It can be useful to pass a name
            other than 'v', to test having the attribute name (self.v) being
            different from the variable name.
          **kwargs: Passed to AssertTypeLayer constructor.
        """
        self._regularizer = regularizer
        if isinstance(regularizer, dict):
            self._regularizer = regularizers.deserialize(
                regularizer, custom_objects=globals()
            )
        self._activity_regularizer = activity_regularizer
        if isinstance(activity_regularizer, dict):
            self._activity_regularizer = regularizers.deserialize(
                activity_regularizer, custom_objects=globals()
            )
        self._use_operator = use_operator
        self._var_name = var_name
        super().__init__(
            activity_regularizer=self._activity_regularizer, **kwargs
        )

    def build(self, _):
        self.v = self.add_weight(
            self._var_name,
            (),
            initializer="ones",
            regularizer=self._regularizer,
        )
        self.built = True

    def call(self, inputs):
        self.assert_input_types(inputs)
        return self._multiply(inputs, self.v)

    def _multiply(self, x, y):
        if self._use_operator:
            return x * y
        else:
            return tf.multiply(x, y)

    def get_config(self):
        config = super().get_config()
        config["regularizer"] = regularizers.serialize(self._regularizer)
        config["activity_regularizer"] = regularizers.serialize(
            self._activity_regularizer
        )
        config["use_operator"] = self._use_operator
        config["var_name"] = self._var_name
        config["assert_type"] = self._assert_type
        return config
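
# A minimal usage sketch (illustrative only, not part of the original test
# utilities): under a mixed_float16 policy the scalar weight is created in
# float32 and auto-cast to float16 inside `call`, so the dtype assertion can
# verify that the layer computes in half precision.
#
#   tf.keras.mixed_precision.set_global_policy("mixed_float16")
#   layer = MultiplyLayer(assert_type=tf.float16)
#   y = layer(tf.ones((2, 2)))
#   layer.v.dtype  # float32 (variable dtype)
#   y.dtype        # float16 (compute dtype)
#   tf.keras.mixed_precision.set_global_policy("float32")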


class MultiplyLayerWithoutAutoCast(MultiplyLayer):
    """Same as MultiplyLayer, but does not use AutoCastVariables."""

    def build(self, _):
        dtype = self.dtype
        if dtype in ("float16", "bfloat16"):
            dtype = "float32"
        self.v = self.add_weight(
            "v",
            (),
            initializer="ones",
            dtype=dtype,
            experimental_autocast=False,
            regularizer=self._regularizer,
        )
        self.built = True

    def call(self, inputs):
        self.assert_input_types(inputs)
        assert self.v.dtype in (tf.float32, tf.float64)
        return self._multiply(inputs, tf.cast(self.v, inputs.dtype))
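
# A minimal usage sketch (illustrative only, not part of the original test
# utilities): with `experimental_autocast=False` the variable stays float32
# even under a mixed_float16 policy, so the layer must cast it manually
# before multiplying.
#
#   tf.keras.mixed_precision.set_global_policy("mixed_float16")
#   layer = MultiplyLayerWithoutAutoCast(assert_type=tf.float16)
#   y = layer(tf.ones((2, 2)))  # v is read as float32 and cast to float16.
#   tf.keras.mixed_precision.set_global_policy("float32")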


class IdentityRegularizer(regularizers.Regularizer):
    def __call__(self, x):
        assert x.dtype == tf.float32
        return tf.identity(x)

    def get_config(self):
        return {}


class ReduceSumRegularizer(regularizers.Regularizer):
    def __call__(self, x):
        return tf.reduce_sum(x)

    def get_config(self):
        return {}
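
# A minimal usage sketch (illustrative only, not part of the original test
# utilities): these regularizers can be attached to MultiplyLayer to check
# that regularization losses are computed in float32 even when the layer
# itself computes in float16.
#
#   layer = MultiplyLayer(
#       assert_type=tf.float16,
#       regularizer=IdentityRegularizer(),
#       dtype="mixed_float16",
#   )
#   layer(tf.ones((2, 2)))
#   layer.losses  # [1.0], since the scalar weight is initialized to ones.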