# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizer utilities."""

import tensorflow.compat.v2 as tf

# isort: off
from tensorflow.python.platform import tf_logging as logging


def all_reduce_sum_gradients(grads_and_vars):
    """Returns all-reduced gradients aggregated via summation.

    Args:
        grads_and_vars: List of (gradient, variable) pairs.

    Returns:
        List of (gradient, variable) pairs where gradients have been
        all-reduced.
    """
    grads_and_vars = list(grads_and_vars)
    filtered_grads_and_vars = filter_empty_gradients(grads_and_vars)
    if filtered_grads_and_vars:
        if tf.__internal__.distribute.strategy_supports_no_merge_call():
            grads = [pair[0] for pair in filtered_grads_and_vars]
            reduced = tf.distribute.get_replica_context().all_reduce(
                tf.distribute.ReduceOp.SUM, grads
            )
        else:
            # TODO(b/183257003): Remove this branch
            reduced = tf.distribute.get_replica_context().merge_call(
                _all_reduce_sum_fn, args=(filtered_grads_and_vars,)
            )
    else:
        reduced = []
    # Copy 'reduced' but add None gradients back in.
    reduced_with_nones = []
    reduced_pos = 0
    for g, v in grads_and_vars:
        if g is None:
            reduced_with_nones.append((None, v))
        else:
            reduced_with_nones.append((reduced[reduced_pos], v))
            reduced_pos += 1
    assert reduced_pos == len(reduced), "Failed to add all gradients"
    return reduced_with_nones


def filter_empty_gradients(grads_and_vars):
    """Filters out `(grad, var)` pairs whose gradient is `None`."""
    grads_and_vars = tuple(grads_and_vars)
    if not grads_and_vars:
        return grads_and_vars

    filtered = []
    vars_with_empty_grads = []
    for grad, var in grads_and_vars:
        if grad is None:
            vars_with_empty_grads.append(var)
        else:
            filtered.append((grad, var))
    filtered = tuple(filtered)

    if not filtered:
        variable = [v.name for _, v in grads_and_vars]
        raise ValueError(
            f"No gradients provided for any variable: {variable}. "
            f"Provided `grads_and_vars` is {grads_and_vars}."
        )
    if vars_with_empty_grads:
        logging.warning(
            "Gradients do not exist for variables %s when minimizing the "
            "loss. If you're using `model.compile()`, did you forget to "
            "provide a `loss` argument?",
            [v.name for v in vars_with_empty_grads],
        )
    return filtered


def make_gradient_clipnorm_fn(clipnorm):
    """Creates a gradient transformation function for clipping by norm."""
    if clipnorm is None:
        return lambda grads_and_vars: grads_and_vars

    def gradient_clipnorm_fn(grads_and_vars):
        if isinstance(
            tf.distribute.get_strategy(),
            (
                tf.distribute.experimental.CentralStorageStrategy,
                tf.compat.v1.distribute.experimental.CentralStorageStrategy,
            ),
        ):
            raise ValueError(
                "`clipnorm` is not supported with `CentralStorageStrategy`. "
                f"The strategy used is {tf.distribute.get_strategy()}."
            )

        clipped_grads_and_vars = [
            (tf.clip_by_norm(g, clipnorm), v) for g, v in grads_and_vars
        ]
        return clipped_grads_and_vars

    return gradient_clipnorm_fn


def make_global_gradient_clipnorm_fn(clipnorm):
    """Creates a gradient transformation function for clipping by global norm."""
    if clipnorm is None:
        return lambda grads_and_vars: grads_and_vars

    def gradient_clipnorm_fn(grads_and_vars):
        if isinstance(
            tf.distribute.get_strategy(),
            (
                tf.distribute.experimental.CentralStorageStrategy,
                tf.compat.v1.distribute.experimental.CentralStorageStrategy,
            ),
        ):
            raise ValueError(
                "`global_clipnorm` is not supported with "
                "`CentralStorageStrategy`. "
                f"The strategy used is {tf.distribute.get_strategy()}."
            )

        grads, variables = zip(*grads_and_vars)
        clipped_grads, _ = tf.clip_by_global_norm(grads, clipnorm)
        clipped_grads_and_vars = list(zip(clipped_grads, variables))
        return clipped_grads_and_vars

    return gradient_clipnorm_fn


def make_gradient_clipvalue_fn(clipvalue):
    """Creates a gradient transformation function for clipping by value."""
    if clipvalue is None:
        return lambda grads_and_vars: grads_and_vars

    def gradient_clipvalue_fn(grads_and_vars):
        if isinstance(
            tf.distribute.get_strategy(),
            (
                tf.distribute.experimental.CentralStorageStrategy,
                tf.compat.v1.distribute.experimental.CentralStorageStrategy,
            ),
        ):
            raise ValueError(
                "`clipvalue` is not supported with "
                "`CentralStorageStrategy`. "
                f"The strategy used is {tf.distribute.get_strategy()}."
            )

        clipped_grads_and_vars = [
            (tf.clip_by_value(g, -clipvalue, clipvalue), v)
            for g, v in grads_and_vars
        ]
        return clipped_grads_and_vars

    return gradient_clipvalue_fn


def _all_reduce_sum_fn(distribution, grads_and_vars):
    # Helper passed to `merge_call` above: sums gradients across replicas.
    return distribution.extended.batch_reduce_to(
        tf.distribute.ReduceOp.SUM, grads_and_vars
    )
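

# A minimal usage sketch, assuming eager execution under the default
# (non-distributed) strategy. The tiny model, optimizer, and synthetic data
# below are illustrative placeholders showing how `filter_empty_gradients`
# and `make_global_gradient_clipnorm_fn` can be composed in a custom
# training step; they are not part of this module's API.
if __name__ == "__main__":
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    clip_fn = make_global_gradient_clipnorm_fn(clipnorm=1.0)

    x = tf.random.normal([8, 4])
    y = tf.random.normal([8, 1])

    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)

    # Drop variables without gradients, clip by global norm, then apply.
    grads_and_vars = filter_empty_gradients(
        zip(grads, model.trainable_variables)
    )
    grads_and_vars = clip_fn(grads_and_vars)
    optimizer.apply_gradients(grads_and_vars)
    print("loss:", float(loss))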