# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains testing utilities related to mixed precision."""

import tensorflow.compat.v2 as tf

from keras import regularizers
from keras.engine import base_layer


def create_identity_with_grad_check_fn(expected_gradient, expected_dtype=None):
    """Returns a function that asserts its gradient has a certain value.

    This serves as a hook to assert intermediate gradients have a certain
    value. This returns an identity function. The identity's gradient function
    is also the identity function, except it asserts that the gradient equals
    `expected_gradient` and has dtype `expected_dtype`.

    Args:
      expected_gradient: The gradient function asserts that the gradient is
        this value.
      expected_dtype: The gradient function asserts the gradient has this
        dtype.

    Returns:
      An identity function whose gradient function asserts the gradient has a
      certain value.
    """

    @tf.custom_gradient
    def _identity_with_grad_check(x):
        """Function that asserts its gradient has a certain value."""
        x = tf.identity(x)

        def grad(dx):
            """Gradient function that asserts the gradient has a certain
            value."""
            if expected_dtype:
                assert (
                    dx.dtype == expected_dtype
                ), f"dx.dtype should be {expected_dtype} but is: {dx.dtype}"
            expected_tensor = tf.convert_to_tensor(
                expected_gradient, dtype=dx.dtype, name="expected_gradient"
            )
            # The control dependency ensures the input is available before
            # the assertion runs. It's possible the dataset will throw a
            # StopIteration to indicate there is no more data, in which case
            # we don't want to run the assertion.
            with tf.control_dependencies([x]):
                assert_op = tf.compat.v1.assert_equal(dx, expected_tensor)
            with tf.control_dependencies([assert_op]):
                dx = tf.identity(dx)
            return dx

        return x, grad

    # Keras sometimes has trouble serializing Lambda layers with a decorated
    # function, so we define and return a non-decorated wrapper.
    def identity_with_grad_check(x):
        return _identity_with_grad_check(x)

    return identity_with_grad_check
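
# A minimal usage sketch (illustrative; the tape-based snippet is an
# assumption, not part of this file). With `y = f(x) * 3.0`, the upstream
# gradient reaching `f` is 3.0, so `expected_gradient=3.0` passes and any
# other expected value would trip the assertion:
#
#   f = create_identity_with_grad_check_fn(expected_gradient=3.0)
#   x = tf.constant(2.0)
#   with tf.GradientTape() as tape:
#       tape.watch(x)
#       y = f(x) * 3.0
#   tape.gradient(y, x)  # Runs the gradient assertion; returns 3.0.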


def create_identity_with_nan_gradients_fn(have_nan_gradients):
    """Returns a function that optionally has NaN gradients.

    This serves as a hook to introduce NaN gradients to a model. This returns
    an identity function. The identity's gradient function will check if the
    boolean tensor `have_nan_gradients` is True. If so, the gradient will be
    NaN. Otherwise, the gradient will also be the identity.

    Args:
      have_nan_gradients: A scalar boolean tensor. If True, gradients will be
        NaN. Otherwise, the gradient function is the identity function.

    Returns:
      An identity function whose gradient function will return NaNs, if
      `have_nan_gradients` is True.
    """

    @tf.custom_gradient
    def _identity_with_nan_gradients(x):
        """Function whose gradient is NaN iff `have_nan_gradients` is True."""
        x = tf.identity(x)

        def grad(dx):
            return tf.cond(
                have_nan_gradients, lambda: dx * float("NaN"), lambda: dx
            )

        return x, grad

    # Keras sometimes has trouble serializing Lambda layers with a decorated
    # function, so we define and return a non-decorated wrapper.
    def identity_with_nan_gradients(x):
        return _identity_with_nan_gradients(x)

    return identity_with_nan_gradients
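
# A minimal usage sketch (illustrative assumption, not part of this file):
# `have_nan_gradients` can be a `tf.Variable`, so a test can toggle NaN
# gradients on and off between steps, e.g. to exercise loss scaling:
#
#   have_nan_gradients = tf.Variable(False, dtype=tf.bool)
#   f = create_identity_with_nan_gradients_fn(have_nan_gradients)
#   x = tf.constant(1.0)
#   with tf.GradientTape() as tape:
#       tape.watch(x)
#       y = f(x)
#   tape.gradient(y, x)  # 1.0
#   have_nan_gradients.assign(True)  # Subsequent gradients become NaN.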


class AssertTypeLayer(base_layer.Layer):
    """A layer which asserts its inputs are a certain type."""

    def __init__(self, assert_type=None, **kwargs):
        self._assert_type = (
            tf.as_dtype(assert_type).name if assert_type else None
        )
        super().__init__(**kwargs)

    def assert_input_types(self, inputs):
        """Asserts `inputs` are of the correct type. Should be called in
        call()."""
        if self._assert_type:
            inputs_flattened = tf.nest.flatten(inputs)
            for inp in inputs_flattened:
                assert inp.dtype.base_dtype == self._assert_type, (
                    "Input tensor has type %s which does "
                    "not match assert type %s"
                    % (inp.dtype.name, self._assert_type)
                )


class MultiplyLayer(AssertTypeLayer):
    """A layer which multiplies its input by a scalar variable."""

    def __init__(
        self,
        regularizer=None,
        activity_regularizer=None,
        use_operator=False,
        var_name="v",
        **kwargs,
    ):
        """Initializes the MultiplyLayer.

        Args:
          regularizer: The weight regularizer on the scalar variable.
          activity_regularizer: The activity regularizer.
          use_operator: If True, multiply using the * operator. If False,
            multiply using tf.multiply.
          var_name: The name of the variable. It can be useful to pass a name
            other than 'v', to test having the attribute name (self.v) being
            different from the variable name.
          **kwargs: Passed to the AssertTypeLayer constructor.
        """
        self._regularizer = regularizer
        if isinstance(regularizer, dict):
            self._regularizer = regularizers.deserialize(
                regularizer, custom_objects=globals()
            )
        self._activity_regularizer = activity_regularizer
        if isinstance(activity_regularizer, dict):
            self._activity_regularizer = regularizers.deserialize(
                activity_regularizer, custom_objects=globals()
            )

        self._use_operator = use_operator
        self._var_name = var_name
        super().__init__(
            activity_regularizer=self._activity_regularizer, **kwargs
        )

    def build(self, _):
        self.v = self.add_weight(
            self._var_name,
            (),
            initializer="ones",
            regularizer=self._regularizer,
        )
        self.built = True

    def call(self, inputs):
        self.assert_input_types(inputs)
        return self._multiply(inputs, self.v)

    def _multiply(self, x, y):
        if self._use_operator:
            return x * y
        else:
            return tf.multiply(x, y)

    def get_config(self):
        config = super().get_config()
        config["regularizer"] = regularizers.serialize(self._regularizer)
        config["activity_regularizer"] = regularizers.serialize(
            self._activity_regularizer
        )
        config["use_operator"] = self._use_operator
        config["var_name"] = self._var_name
        config["assert_type"] = self._assert_type
        return config
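
# A minimal usage sketch (illustrative assumption, not part of this file):
# under a mixed precision policy, the inputs and the AutoCastVariable `v`
# are both float16 inside call(), which is what `assert_type` verifies:
#
#   layer = MultiplyLayer(assert_type=tf.float16, dtype="mixed_float16")
#   layer(tf.constant([1.0, 2.0], dtype=tf.float16))  # Passes the assertion.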


class MultiplyLayerWithoutAutoCast(MultiplyLayer):
    """Same as MultiplyLayer, but does not use AutoCastVariables."""

    def build(self, _):
        dtype = self.dtype
        if dtype in ("float16", "bfloat16"):
            dtype = "float32"
        self.v = self.add_weight(
            "v",
            (),
            initializer="ones",
            dtype=dtype,
            experimental_autocast=False,
            regularizer=self._regularizer,
        )
        self.built = True

    def call(self, inputs):
        self.assert_input_types(inputs)
        assert self.v.dtype in (tf.float32, tf.float64)
        return self._multiply(inputs, tf.cast(self.v, inputs.dtype))
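
# A minimal usage sketch (illustrative assumption, not part of this file):
# because the weight is created with `experimental_autocast=False`, it stays
# float32 under a mixed precision policy and call() casts it manually:
#
#   layer = MultiplyLayerWithoutAutoCast(
#       assert_type=tf.float16, dtype="mixed_float16"
#   )
#   layer(tf.constant([1.0], dtype=tf.float16))
#   layer.v.dtype  # tf.float32, while the compute dtype is float16.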


class IdentityRegularizer(regularizers.Regularizer):
    def __call__(self, x):
        assert x.dtype == tf.float32
        return tf.identity(x)

    def get_config(self):
        return {}


class ReduceSumRegularizer(regularizers.Regularizer):
    def __call__(self, x):
        return tf.reduce_sum(x)

    def get_config(self):
        return {}
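
# A minimal usage sketch (illustrative assumption, not part of this file):
# these regularizers plug into the test layers above. The assert in
# IdentityRegularizer checks that weight regularization runs in float32 even
# when the layer itself computes in float16:
#
#   layer = MultiplyLayer(
#       regularizer=IdentityRegularizer(), dtype="mixed_float16"
#   )
#   layer(tf.constant([1.0], dtype=tf.float16))
#   layer.losses  # Contains the float32 regularization loss.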