# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities related to distributed training."""

# pylint:disable=protected-access
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.keras import backend
from tensorflow.python.ops import variables


# TODO(b/118776054): Currently we support global batch size for TPUStrategy and
# core MirroredStrategy only. Remove this check when contrib MirroredStrategy is
# no longer needed.
def global_batch_size_supported(distribution_strategy):
  """Returns whether `distribution_strategy` supports a global batch size.

  Args:
    distribution_strategy: A `tf.distribute.Strategy` instance.

  Returns:
    The strategy's private `_global_batch_size` flag, which is truthy when
    the strategy works with global (rather than per-replica) batch sizes.
  """
  extended = distribution_strategy.extended
  return extended._global_batch_size  # pylint: disable=protected-access


def call_replica_local_fn(fn, *args, **kwargs):
  """Call a function that uses replica-local variables.

  Correctly dispatches `fn` when invoked from a cross-replica context: if a
  non-TPU strategy is active, `fn` is run once per replica via
  `call_for_each_replica`; otherwise it is called directly.

  Args:
    fn: The function to call.
    *args: Positional arguments to the `fn`.
    **kwargs: Keyword argument to `fn`. A `strategy` keyword, if present, is
      consumed here and used as the distribution strategy instead of the
      ambient one.

  Returns:
    The result of calling `fn`.
  """
  # TODO(b/132666209): Remove this function when we support assign_*
  # for replica-local variables.
  if 'strategy' in kwargs:
    # An explicitly supplied strategy wins, and must not leak through to `fn`.
    strategy = kwargs.pop('strategy')
  elif ds_context.has_strategy():
    strategy = ds_context.get_strategy()
  else:
    strategy = None

  # TODO(b/120571621): TPUStrategy does not implement replica-local variables.
  uses_tpu = backend.is_tpu_strategy(strategy)
  if strategy and not uses_tpu and ds_context.in_cross_replica_context():
    with strategy.scope():
      return strategy.extended.call_for_each_replica(fn, args, kwargs)
  return fn(*args, **kwargs)


def is_distributed_variable(v):
  """Returns whether `v` is a distributed variable.

  A distributed variable is a value that is both a `DistributedValues`
  container and a `tf.Variable`.
  """
  if not isinstance(v, values_lib.DistributedValues):
    return False
  return isinstance(v, variables.Variable)
|