# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizer utilities."""

import tensorflow.compat.v2 as tf

# isort: off
from tensorflow.python.platform import tf_logging as logging


def all_reduce_sum_gradients(grads_and_vars):
    """Returns all-reduced gradients aggregated via summation.

    Args:
        grads_and_vars: List of (gradient, variable) pairs.

    Returns:
        List of (gradient, variable) pairs where gradients have been
        all-reduced.
    """
    grads_and_vars = list(grads_and_vars)
    filtered_grads_and_vars = filter_empty_gradients(grads_and_vars)
    if filtered_grads_and_vars:
        if tf.__internal__.distribute.strategy_supports_no_merge_call():
            grads = [pair[0] for pair in filtered_grads_and_vars]
            reduced = tf.distribute.get_replica_context().all_reduce(
                tf.distribute.ReduceOp.SUM, grads
            )
        else:
            # TODO(b/183257003): Remove this branch
            reduced = tf.distribute.get_replica_context().merge_call(
                _all_reduce_sum_fn, args=(filtered_grads_and_vars,)
            )
    else:
        reduced = []
    # Copy 'reduced' but add None gradients back in.
    reduced_with_nones = []
    reduced_pos = 0
    for g, v in grads_and_vars:
        if g is None:
            reduced_with_nones.append((None, v))
        else:
            reduced_with_nones.append((reduced[reduced_pos], v))
            reduced_pos += 1
    assert reduced_pos == len(reduced), "Failed to add all gradients"
    return reduced_with_nones
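

# A minimal, illustrative usage sketch; the `_example_*` helper below is not
# part of the original module. `all_reduce_sum_gradients` must be called
# inside `strategy.run`, where a replica context is active; the strategy and
# gradient values here are placeholders.
def _example_all_reduce_sum_gradients():
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        var = tf.Variable(1.0)

    def replica_step():
        # Each replica contributes its local gradient; gradients are summed
        # across replicas, and `None` gradients pass through unchanged.
        return all_reduce_sum_gradients([(tf.constant(0.5), var)])

    return strategy.run(replica_step)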


def filter_empty_gradients(grads_and_vars):
    """Filter out `(grad, var)` pairs that have a gradient equal to `None`."""
    grads_and_vars = tuple(grads_and_vars)
    if not grads_and_vars:
        return grads_and_vars

    filtered = []
    vars_with_empty_grads = []
    for grad, var in grads_and_vars:
        if grad is None:
            vars_with_empty_grads.append(var)
        else:
            filtered.append((grad, var))
    filtered = tuple(filtered)

    if not filtered:
        variable = [v.name for _, v in grads_and_vars]
        raise ValueError(
            f"No gradients provided for any variable: {variable}. "
            f"Provided `grads_and_vars` is {grads_and_vars}."
        )
    if vars_with_empty_grads:
        logging.warning(
            "Gradients do not exist for variables %s when minimizing the "
            "loss. If you're using `model.compile()`, did you forget to "
            "provide a `loss` argument?",
            [v.name for v in vars_with_empty_grads],
        )
    return filtered
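

# A minimal, illustrative usage sketch; the `_example_*` helper below is not
# part of the original module, and the variables are placeholders. Pairs
# whose gradient is `None` are dropped and the skipped variables are logged.
def _example_filter_empty_gradients():
    v1 = tf.Variable(1.0, name="v1")
    v2 = tf.Variable(2.0, name="v2")
    pairs = [(tf.constant(0.1), v1), (None, v2)]
    # Returns a tuple holding only the (grad, v1) pair; a warning is logged
    # saying that no gradient exists for v2.
    return filter_empty_gradients(pairs)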


def make_gradient_clipnorm_fn(clipnorm):
    """Creates a gradient transformation function for clipping by norm."""
    if clipnorm is None:
        return lambda grads_and_vars: grads_and_vars

    def gradient_clipnorm_fn(grads_and_vars):

        if isinstance(
            tf.distribute.get_strategy(),
            (
                tf.distribute.experimental.CentralStorageStrategy,
                tf.compat.v1.distribute.experimental.CentralStorageStrategy,
            ),
        ):
            raise ValueError(
                "`clipnorm` is not supported with `CentralStorageStrategy`. "
                f"The strategy used is {tf.distribute.get_strategy()}."
            )

        clipped_grads_and_vars = [
            (tf.clip_by_norm(g, clipnorm), v) for g, v in grads_and_vars
        ]
        return clipped_grads_and_vars

    return gradient_clipnorm_fn
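

# A minimal, illustrative usage sketch; the `_example_*` helper below is not
# part of the original module, and the values are placeholders. Each gradient
# is clipped to the given L2 norm independently.
def _example_gradient_clipnorm():
    v = tf.Variable([0.0, 0.0])
    grads_and_vars = [(tf.constant([3.0, 4.0]), v)]  # gradient norm is 5.0
    clip_fn = make_gradient_clipnorm_fn(clipnorm=1.0)
    # The gradient is rescaled to unit norm, i.e. [0.6, 0.8].
    return clip_fn(grads_and_vars)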


def make_global_gradient_clipnorm_fn(clipnorm):
    """Creates a gradient transformation function for global-norm clipping."""
    if clipnorm is None:
        return lambda grads_and_vars: grads_and_vars

    def gradient_clipnorm_fn(grads_and_vars):

        if isinstance(
            tf.distribute.get_strategy(),
            (
                tf.distribute.experimental.CentralStorageStrategy,
                tf.compat.v1.distribute.experimental.CentralStorageStrategy,
            ),
        ):
            raise ValueError(
                "`global_clipnorm` is not supported with "
                "`CentralStorageStrategy`. "
                f"The strategy used is {tf.distribute.get_strategy()}."
            )

        grads, variables = zip(*grads_and_vars)
        clipped_grads, _ = tf.clip_by_global_norm(grads, clipnorm)
        clipped_grads_and_vars = list(zip(clipped_grads, variables))
        return clipped_grads_and_vars

    return gradient_clipnorm_fn
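

# A minimal, illustrative usage sketch; the `_example_*` helper below is not
# part of the original module, and the values are placeholders. All gradients
# are rescaled jointly so their combined (global) L2 norm does not exceed
# `clipnorm`.
def _example_global_gradient_clipnorm():
    v1 = tf.Variable([0.0, 0.0])
    v2 = tf.Variable([0.0])
    grads_and_vars = [
        (tf.constant([3.0, 0.0]), v1),
        (tf.constant([4.0]), v2),
    ]  # global norm is sqrt(3**2 + 4**2) = 5.0
    clip_fn = make_global_gradient_clipnorm_fn(clipnorm=1.0)
    # Every gradient is scaled by 1.0 / 5.0.
    return clip_fn(grads_and_vars)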


def make_gradient_clipvalue_fn(clipvalue):
    """Creates a gradient transformation function for clipping by value."""
    if clipvalue is None:
        return lambda grads_and_vars: grads_and_vars

    def gradient_clipvalue_fn(grads_and_vars):

        if isinstance(
            tf.distribute.get_strategy(),
            (
                tf.distribute.experimental.CentralStorageStrategy,
                tf.compat.v1.distribute.experimental.CentralStorageStrategy,
            ),
        ):
            raise ValueError(
                "`clipvalue` is not supported with `CentralStorageStrategy`. "
                f"The strategy used is {tf.distribute.get_strategy()}."
            )

        clipped_grads_and_vars = [
            (tf.clip_by_value(g, -clipvalue, clipvalue), v)
            for g, v in grads_and_vars
        ]
        return clipped_grads_and_vars

    return gradient_clipvalue_fn
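

# A minimal, illustrative usage sketch; the `_example_*` helper below is not
# part of the original module, and the values are placeholders. Each gradient
# element is clipped into the range [-clipvalue, clipvalue].
def _example_gradient_clipvalue():
    v = tf.Variable([0.0, 0.0, 0.0])
    grads_and_vars = [(tf.constant([-2.0, 0.5, 3.0]), v)]
    clip_fn = make_gradient_clipvalue_fn(clipvalue=1.0)
    # The gradient becomes [-1.0, 0.5, 1.0].
    return clip_fn(grads_and_vars)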


def _all_reduce_sum_fn(distribution, grads_and_vars):
    return distribution.extended.batch_reduce_to(
        tf.distribute.ReduceOp.SUM, grads_and_vars
    )