# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
"""Library for running a computation across multiple devices.
|
|
|
|
The intent of this library is that you can write an algorithm in a stylized way
|
|
and it will be usable with a variety of different `tf.distribute.Strategy`
|
|
implementations. Each descendant will implement a different strategy for
|
|
distributing the algorithm across multiple devices/machines. Furthermore, these
|
|
changes can be hidden inside the specific layers and other library classes that
|
|
need special treatment to run in a distributed setting, so that most users'
|
|
model definition code can run unchanged. The `tf.distribute.Strategy` API works
|
|
the same way with eager and graph execution.
|
|
|
|
*Guides*
|
|
|
|
* [TensorFlow v2.x](https://www.tensorflow.org/guide/distributed_training)
|
|
* [TensorFlow
|
|
v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute_strategy.ipynb)
|
|
|
|
*Tutorials*
|
|
|
|
* [Distributed Training
|
|
Tutorials](https://www.tensorflow.org/tutorials/distribute/)
|
|
|
|
The tutorials cover how to use `tf.distribute.Strategy` to do distributed
|
|
training with native Keras APIs, and custom training loops.
|
|
They also cover how to save/load model when using `tf.distribute.Strategy`.
|
|
|
|
*Glossary*

* _Data parallelism_ is where we run multiple copies of the model
  on different slices of the input data. This is in contrast to
  _model parallelism_ where we divide up a single copy of a model
  across multiple devices.
  Note: we only support data parallelism for now, but
  hope to add support for model parallelism in the future.
* A _device_ is a CPU or accelerator (e.g. GPUs, TPUs) on some machine that
  TensorFlow can run operations on (see e.g. `tf.device`). You may have multiple
  devices on a single machine, or be connected to devices on multiple
  machines. Devices used to run computations are called _worker devices_.
  Devices used to store variables are _parameter devices_. For some strategies,
  such as `tf.distribute.MirroredStrategy`, the worker and parameter devices
  will be the same (see mirrored variables below). For others they will be
  different. For example, `tf.distribute.experimental.CentralStorageStrategy`
  puts the variables on a single device (which may be a worker device or may be
  the CPU), and `tf.distribute.experimental.ParameterServerStrategy` puts the
  variables on separate machines called _parameter servers_ (see below).
* A _replica_ is one copy of the model, running on one slice of the
  input data. Right now each replica is executed on its own
  worker device, but once we add support for model parallelism
  a replica may span multiple worker devices.
* A _host_ is the CPU device on a machine with worker devices, typically
  used for running input pipelines.
* A _worker_ is defined to be the physical machine(s) containing the physical
  devices (e.g. GPUs, TPUs) on which the replicated computation is executed. A
  worker may contain one or more replicas, but contains at least one
  replica. Typically one worker will correspond to one machine, but in the case
  of very large models with model parallelism, one worker may span multiple
  machines. We typically run one input pipeline per worker, feeding all the
  replicas on that worker.
* _Synchronous_, or more commonly _sync_, training is where the updates from
  each replica are aggregated together before updating the model variables. This
  is in contrast to _asynchronous_, or _async_ training, where each replica
  updates the model variables independently. You may also have replicas
  partitioned into groups which are in sync within each group but async between
  groups.
* _Parameter servers_: These are machines that hold a single copy of
  parameters/variables, used by some strategies (right now just
  `tf.distribute.experimental.ParameterServerStrategy`). All replicas that want
  to operate on a variable retrieve it at the beginning of a step and send an
  update to be applied at the end of the step. These can in principle support
  either sync or async training, but right now we only have support for async
  training with parameter servers. Compare to
  `tf.distribute.experimental.CentralStorageStrategy`, which puts all variables
  on a single device on the same machine (and does sync training), and
  `tf.distribute.MirroredStrategy`, which mirrors variables to multiple devices
  (see below).

* _Replica context_ vs. _Cross-replica context_ vs _Update context_

  A _replica context_ applies
  when you execute the computation function that was called with `strategy.run`.
  Conceptually, you're in replica context when executing the computation
  function that is being replicated.

  An _update context_ is entered in a `tf.distribute.StrategyExtended.update`
  call.

  A _cross-replica context_ is entered when you enter a `strategy.scope`. This
  is useful for calling `tf.distribute.Strategy` methods which operate across
  the replicas (like `reduce_to()`). By default you start in a _replica context_
  (the "default single _replica context_") and then some methods can switch you
  back and forth.

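  For illustration, a minimal sketch of how the context switches (the
  two-GPU `MirroredStrategy` and the device list are illustrative
  assumptions, not part of the original text):

  ```
  strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

  # Outside any scope: the default single replica context.
  assert tf.distribute.get_replica_context() is not None

  with strategy.scope():
    # Entering the scope switches to the cross-replica context.
    assert tf.distribute.in_cross_replica_context()
    assert tf.distribute.get_replica_context() is None

    def fn():
      # Inside `strategy.run` each replica is back in a replica context.
      ctx = tf.distribute.get_replica_context()
      return ctx.replica_id_in_sync_group

    strategy.run(fn)
  ```
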
* _Distributed value_: A distributed value is represented by the base class
  `tf.distribute.DistributedValues`. `tf.distribute.DistributedValues` is useful
  to represent values on multiple devices, and it contains a map from replica id
  to values. Two representative types of `tf.distribute.DistributedValues`
  are `tf.types.experimental.PerReplica` and `tf.types.experimental.Mirrored`
  values.

  `PerReplica` values exist on the worker devices, with a different value for
  each replica. They are produced by iterating through a distributed dataset
  returned by `tf.distribute.Strategy.experimental_distribute_dataset` and
  `tf.distribute.Strategy.distribute_datasets_from_function`. They are also the
  typical result returned by `tf.distribute.Strategy.run`.

  `Mirrored` values are like `PerReplica` values, except we know that the value
  on all replicas is the same. `Mirrored` values are kept synchronized by the
  distribution strategy in use, while `PerReplica` values are left
  unsynchronized. `Mirrored` values typically represent model weights. We can
  safely read a `Mirrored` value in a cross-replica context by using the value
  on any replica, while `PerReplica` values can only be read within a replica
  context.

* _Unwrapping_ and _merging_: Consider calling a function `fn` on multiple
  replicas, like `strategy.run(fn, args=[w])` with an
  argument `w` that is a `tf.distribute.DistributedValues`. This means `w` will
  have a map taking replica id `0` to `w0`, replica id `1` to `w1`, etc.
  `strategy.run()` unwraps `w` before calling `fn`, so it calls `fn(w0)` on
  device `d0`, `fn(w1)` on device `d1`, etc. It then merges the return
  values from `fn()`, which leads to one common object if the returned values
  are the same object from every replica, or a `DistributedValues` object
  otherwise.

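  A small sketch of unwrapping and merging (the two-GPU strategy and the
  values are illustrative assumptions):

  ```
  strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

  # `w` maps replica id 0 -> 0 and replica id 1 -> 1.
  w = strategy.experimental_distribute_values_from_function(
      lambda ctx: ctx.replica_id_in_sync_group)

  # `run` unwraps `w`: it calls the function with 0 on replica 0 and with 1
  # on replica 1, then merges the per-replica return values.
  result = strategy.run(lambda x: x + 1, args=(w,))
  print(strategy.experimental_local_results(result))  # one value per replica
  ```
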
* _Reductions_ and _all-reduce_: A _reduction_ is a method of aggregating
  multiple values into one value, like "sum" or "mean". If a strategy is doing
  sync training, we will perform a reduction on the gradients to a parameter
  from all replicas before applying the update. _All-reduce_ is an algorithm for
  performing a reduction on values from multiple devices and making the result
  available on all of those devices.

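  For example, a minimal sketch of reducing a per-replica result in the
  cross-replica context (two replicas are assumed; the values are
  illustrative):

  ```
  strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

  def step_fn():
    # Each replica contributes its own value.
    ctx = tf.distribute.get_replica_context()
    return tf.cast(ctx.replica_id_in_sync_group, tf.float32)

  per_replica = strategy.run(step_fn)
  # Aggregate the per-replica values into a single tensor ("sum" reduction).
  total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
  ```
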
* _Mirrored variables_: These are variables that are created on multiple
  devices, where we keep the variables in sync by applying the same
  updates to every copy. Mirrored variables are created with
  `tf.Variable(...synchronization=tf.VariableSynchronization.ON_WRITE...)`.
  Normally they are only used in synchronous training.

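  For example (a minimal sketch; the two-GPU strategy and the aggregation
  choice are illustrative assumptions):

  ```
  strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  with strategy.scope():
    # One component variable is created per replica and kept in sync.
    v = tf.Variable(
        1.0, synchronization=tf.VariableSynchronization.ON_WRITE,
        aggregation=tf.VariableAggregation.MEAN)
  ```
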
* _SyncOnRead variables_

  _SyncOnRead variables_ are created by
  `tf.Variable(...synchronization=tf.VariableSynchronization.ON_READ...)`, and
  they are created on multiple devices. In replica context, each
  component variable on the local replica can perform reads and writes without
  synchronization with each other. When the
  _SyncOnRead variable_ is read in cross-replica context, the values from the
  component variables are aggregated and returned.

  _SyncOnRead variables_ add a lot of complexity to the underlying logic, so we
  do not encourage users to instantiate and use _SyncOnRead variables_ on their
  own. We have mainly used _SyncOnRead variables_ for use cases such as batch
  norm and metrics. For performance reasons, we often don't need to keep these
  statistics in sync every step and they can be accumulated on each replica
  independently. The only time we want to sync them is reporting or
  checkpointing, which typically happens in cross-replica context. _SyncOnRead
  variables_ are also often used by advanced users who want to control when
  variable values are aggregated. For example, users sometimes want to maintain
  gradients independently on each replica for a couple of steps without
  aggregation.

* _Distribute-aware layers_

  Layers are generally called in a replica context, except when defining a
  Keras functional model. `tf.distribute.in_cross_replica_context` will let you
  determine which case you are in. The `tf.distribute.get_replica_context`
  function returns the default replica context outside a strategy scope, `None`
  when in a cross-replica context (e.g. inside a strategy scope but outside
  `tf.distribute.Strategy.run`), and a `tf.distribute.ReplicaContext` object
  when in a replica context inside a `tf.distribute.Strategy.run` function. The
  `ReplicaContext` object has an `all_reduce` method for aggregating across all
  replicas.

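  For example, a library function can aggregate a per-replica statistic with
  `all_reduce` (a minimal sketch; the two-GPU strategy and the constant input
  are illustrative assumptions):

  ```
  strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

  def replica_fn(x):
    ctx = tf.distribute.get_replica_context()
    if ctx is not None:
      # In a replica context: average `x` across all replicas.
      return ctx.all_reduce(tf.distribute.ReduceOp.MEAN, x)
    return x

  strategy.run(replica_fn, args=(tf.constant(1.0),))
  ```
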
Note that we provide a default version of `tf.distribute.Strategy` that is
used when no other strategy is in scope and provides the same API with
reasonable default behavior.
"""
# pylint: enable=line-too-long

import collections
import contextlib
import copy
import enum  # pylint: disable=g-bad-import-order
import functools
import threading
import weakref

import six

from tensorflow.python import tf2
from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context as eager_context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import monitoring
from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor as tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import ref_variable
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variable_v1
from tensorflow.python.platform import tf_logging
from tensorflow.python.trackable import base as trackable
from tensorflow.python.types import distribute as ds_types
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls

# ------------------------------------------------------------------------------
|
|
# Context tracking whether in a strategy.update() or .update_non_slot() call.
|
|
|
|
|
|
_update_replica_id = threading.local()
|
|
|
|
|
|
def get_update_replica_id():
|
|
"""Get the current device if in a `tf.distribute.Strategy.update()` call."""
|
|
try:
|
|
return _update_replica_id.current
|
|
except AttributeError:
|
|
return None
|
|
|
|
|
|
class UpdateContext(object):
|
|
"""Context manager when you are in `update()` or `update_non_slot()`."""
|
|
|
|
__slots__ = ["_replica_id", "_old_replica_id"]
|
|
|
|
def __init__(self, replica_id):
|
|
self._replica_id = replica_id
|
|
self._old_replica_id = None
|
|
|
|
def __enter__(self):
|
|
self._old_replica_id = get_update_replica_id()
|
|
_update_replica_id.current = self._replica_id
|
|
|
|
def __exit__(self, exception_type, exception_value, traceback):
|
|
del exception_type, exception_value, traceback
|
|
_update_replica_id.current = self._old_replica_id
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# Internal API for validating the current thread mode
|
|
|
|
|
|
def _require_cross_replica_or_default_context_extended(extended,
|
|
error_message=None):
|
|
"""Verify in cross-replica context."""
|
|
context = _get_per_thread_mode()
|
|
cross_replica = context.cross_replica_context
|
|
if cross_replica is not None and cross_replica.extended is extended:
|
|
return
|
|
if context is _get_default_replica_mode():
|
|
return
|
|
strategy = extended._container_strategy() # pylint: disable=protected-access
|
|
# We have an error to report, figure out the right message.
|
|
if context.strategy is not strategy:
|
|
_wrong_strategy_scope(strategy, context)
|
|
assert cross_replica is None
|
|
if not error_message:
|
|
error_message = ("Method requires being in cross-replica context, use "
|
|
"get_replica_context().merge_call()")
|
|
raise RuntimeError(error_message)
|
|
|
|
|
|
def _wrong_strategy_scope(strategy, context):
|
|
# Figure out the right error message.
|
|
if not has_strategy():
|
|
raise RuntimeError(
|
|
'Need to be inside "with strategy.scope()" for %s' %
|
|
(strategy,))
|
|
else:
|
|
raise RuntimeError(
|
|
"Mixing different tf.distribute.Strategy objects: %s is not %s" %
|
|
(context.strategy, strategy))
|
|
|
|
|
|
def require_replica_context(replica_ctx):
|
|
"""Verify in `replica_ctx` replica context."""
|
|
context = _get_per_thread_mode()
|
|
if context.replica_context is replica_ctx: return
|
|
# We have an error to report, figure out the right message.
|
|
if context.replica_context is None:
|
|
raise RuntimeError("Need to be inside `call_for_each_replica()`")
|
|
if context.strategy is replica_ctx.strategy:
|
|
# Two different ReplicaContexts with the same tf.distribute.Strategy.
|
|
raise RuntimeError("Mismatching ReplicaContext.")
|
|
raise RuntimeError(
|
|
"Mismatching tf.distribute.Strategy objects: %s is not %s." %
|
|
(context.strategy, replica_ctx.strategy))
|
|
|
|
|
|
def _require_strategy_scope_strategy(strategy):
|
|
"""Verify in a `strategy.scope()` in this thread."""
|
|
context = _get_per_thread_mode()
|
|
if context.strategy is strategy: return
|
|
_wrong_strategy_scope(strategy, context)
|
|
|
|
|
|
def _require_strategy_scope_extended(extended):
|
|
"""Verify in a `distribution_strategy.scope()` in this thread."""
|
|
context = _get_per_thread_mode()
|
|
if context.strategy.extended is extended: return
|
|
# Report error.
|
|
strategy = extended._container_strategy() # pylint: disable=protected-access
|
|
_wrong_strategy_scope(strategy, context)
|
|
|
|
|
|
_creating_default_strategy_singleton = False
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# Internal API for setting the current thread mode as being either in a
|
|
# replica or cross-replica context for a particular tf.distribute.Strategy.
|
|
|
|
|
|
class _ThreadMode(object):
|
|
|
|
def __init__(self, dist, cross, replica):
|
|
self.strategy = dist
|
|
self.cross_replica_context = cross
|
|
self.replica_context = replica
|
|
|
|
|
|
class _CrossReplicaThreadMode(_ThreadMode):
|
|
|
|
def __init__(self, strategy):
|
|
_ThreadMode.__init__(self, strategy, strategy, None)
|
|
|
|
|
|
class _InReplicaThreadMode(_ThreadMode):
|
|
|
|
def __init__(self, replica_ctx):
|
|
_ThreadMode.__init__(self, replica_ctx.strategy, None, replica_ctx)
|
|
|
|
|
|
def _push_per_thread_mode(context):
|
|
ops.get_default_graph()._distribution_strategy_stack.append(context) # pylint: disable=protected-access
|
|
|
|
|
|
def _pop_per_thread_mode():
|
|
ops.get_default_graph()._distribution_strategy_stack.pop(-1) # pylint: disable=protected-access
|
|
|
|
|
|
class _DefaultReplicaThreadMode(_ThreadMode):
|
|
"""Type of default value returned by `_get_per_thread_mode()`.
|
|
|
|
Used when the thread-local stack is empty.
|
|
"""
|
|
|
|
def __init__(self):
|
|
_ThreadMode.__init__(self, _get_default_strategy(), None,
|
|
_get_default_replica_context())
|
|
|
|
|
|
def _get_per_thread_mode():
|
|
try:
|
|
return ops.get_default_graph()._distribution_strategy_stack[-1] # pylint: disable=protected-access
|
|
except (AttributeError, IndexError):
|
|
return _get_default_replica_mode()
|
|
|
|
|
|
_variable_sync_on_read_context = threading.local()
|
|
|
|
|
|
@tf_export("__internal__.distribute.variable_sync_on_read_context", v1=[])
|
|
@contextlib.contextmanager
|
|
def variable_sync_on_read_context():
|
|
"""A context that forces SyncOnReadVariable to aggregate upon reading.
|
|
|
|
This context is useful if one wants to read the aggregated value out of a
|
|
SyncOnReadVariable in replica context. By default the aggregation is turned
|
|
off per the definition of SyncOnReadVariable.
|
|
|
|
When reading a SyncOnReadVariable in cross-replica context, aggregation is
|
|
always turned on so there is no need for such context.
|
|
|
|
By reading a SyncOnReadVariable, we mean:
|
|
1. Convert the variable to a tensor using `convert_to_tensor`.
|
|
2. Calling `variable.value()` or `variable.read_value()`.
|
|
|
|
Example usage:
|
|
|
|
```
|
|
strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
|
|
with strategy.scope():
|
|
v = tf.Variable(1.0, synchronization=tf.VariableSynchronization.ON_READ,
|
|
aggregation=tf.VariableAggregation.SUM)
|
|
|
|
def replica_fn():
|
|
return v + 10.0
|
|
|
|
non_aggregated = strategy.run(replica_fn)
|
|
print(non_aggregated) # PerReplica: {0: 11.0, 1: 11.0}
|
|
|
|
def replica_fn():
|
|
with variable_sync_on_read_context():
|
|
return v + 10.0
|
|
|
|
aggregated = strategy.run(replica_fn)
|
|
print(aggregated) # PerReplica: {0: 12.0, 1: 12.0}
|
|
```
|
|
|
|
Yields:
|
|
Context manager for aggregating SyncOnReadVariable upon reading.
|
|
"""
|
|
try:
|
|
_variable_sync_on_read_context.entered = True
|
|
yield
|
|
finally:
|
|
_variable_sync_on_read_context.entered = False
|
|
|
|
|
|
def in_variable_sync_on_read_context():
|
|
try:
|
|
return _variable_sync_on_read_context.entered
|
|
except AttributeError:
|
|
return False
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# Public API for accessing the current thread mode
|
|
|
|
|
|
@tf_export("distribute.get_replica_context")
|
|
def get_replica_context():
|
|
"""Returns the current `tf.distribute.ReplicaContext` or `None`.
|
|
|
|
Returns `None` if in a cross-replica context.
|
|
|
|
Note that execution:
|
|
|
|
1. starts in the default (single-replica) replica context (this function
|
|
will return the default `ReplicaContext` object);
|
|
2. switches to cross-replica context (in which case this will return
|
|
`None`) when entering a `with tf.distribute.Strategy.scope():` block;
|
|
3. switches to a (non-default) replica context inside `strategy.run(fn, ...)`;
|
|
4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then
|
|
inside `merge_fn` you are back in the cross-replica context (and again
|
|
this function will return `None`).
|
|
|
|
Most `tf.distribute.Strategy` methods may only be executed in
|
|
a cross-replica context, in a replica context you should use the
|
|
API of the `tf.distribute.ReplicaContext` object returned by this
|
|
method instead.
|
|
|
|
```
|
|
assert tf.distribute.get_replica_context() is not None # default
|
|
with strategy.scope():
|
|
assert tf.distribute.get_replica_context() is None
|
|
|
|
def f():
|
|
replica_context = tf.distribute.get_replica_context() # for strategy
|
|
assert replica_context is not None
|
|
tf.print("Replica id: ", replica_context.replica_id_in_sync_group,
|
|
" of ", replica_context.num_replicas_in_sync)
|
|
|
|
strategy.run(f)
|
|
```
|
|
|
|
Returns:
|
|
The current `tf.distribute.ReplicaContext` object when in a replica context
|
|
scope, else `None`.
|
|
|
|
Within a particular block, exactly one of these two things will be true:
|
|
|
|
* `get_replica_context()` returns non-`None`, or
|
|
* `tf.distribute.is_cross_replica_context()` returns True.
|
|
"""
|
|
return _get_per_thread_mode().replica_context
|
|
|
|
|
|
def get_cross_replica_context():
|
|
"""Returns the current tf.distribute.Strategy if in a cross-replica context.
|
|
|
|
DEPRECATED: Please use `in_cross_replica_context()` and
|
|
`get_strategy()` instead.
|
|
|
|
Returns:
|
|
Returns the current `tf.distribute.Strategy` object in a cross-replica
|
|
context, or `None`.
|
|
|
|
Exactly one of `get_replica_context()` and `get_cross_replica_context()`
|
|
will return `None` in a particular block.
|
|
"""
|
|
return _get_per_thread_mode().cross_replica_context
|
|
|
|
|
|
@tf_export("distribute.in_cross_replica_context")
|
|
def in_cross_replica_context():
|
|
"""Returns `True` if in a cross-replica context.
|
|
|
|
See `tf.distribute.get_replica_context` for details.
|
|
|
|
```
|
|
assert not tf.distribute.in_cross_replica_context()
|
|
with strategy.scope():
|
|
assert tf.distribute.in_cross_replica_context()
|
|
|
|
def f():
|
|
assert not tf.distribute.in_cross_replica_context()
|
|
|
|
strategy.run(f)
|
|
```
|
|
|
|
Returns:
|
|
`True` if in a cross-replica context (`get_replica_context()` returns
|
|
`None`), or `False` if in a replica context (`get_replica_context()` returns
|
|
non-`None`).
|
|
"""
|
|
return _get_per_thread_mode().cross_replica_context is not None
|
|
|
|
|
|
@tf_export("distribute.get_strategy")
|
|
def get_strategy() -> "StrategyBase":
|
|
"""Returns the current `tf.distribute.Strategy` object.
|
|
|
|
Typically only used in a cross-replica context:
|
|
|
|
```
|
|
if tf.distribute.in_cross_replica_context():
|
|
strategy = tf.distribute.get_strategy()
|
|
...
|
|
```
|
|
|
|
Returns:
|
|
A `tf.distribute.Strategy` object. Inside a `with strategy.scope()` block,
|
|
it returns `strategy`, otherwise it returns the default (single-replica)
|
|
`tf.distribute.Strategy` object.
|
|
"""
|
|
return _get_per_thread_mode().strategy
|
|
|
|
|
|
@tf_export("distribute.has_strategy")
|
|
def has_strategy():
|
|
"""Return if there is a current non-default `tf.distribute.Strategy`.
|
|
|
|
```
|
|
assert not tf.distribute.has_strategy()
|
|
with strategy.scope():
|
|
assert tf.distribute.has_strategy()
|
|
```
|
|
|
|
Returns:
|
|
True if inside a `with strategy.scope():`.
|
|
"""
|
|
return get_strategy() is not _get_default_strategy()
|
|
|
|
|
|
def get_strategy_and_replica_context():
|
|
per_thread_mode = _get_per_thread_mode()
|
|
return (per_thread_mode.strategy, per_thread_mode.replica_context)
|
|
|
|
|
|
@tf_export("distribute.experimental_set_strategy")
|
|
def experimental_set_strategy(strategy):
|
|
"""Set a `tf.distribute.Strategy` as current without `with strategy.scope()`.
|
|
|
|
```
|
|
tf.distribute.experimental_set_strategy(strategy1)
|
|
f()
|
|
tf.distribute.experimental_set_strategy(strategy2)
|
|
g()
|
|
tf.distribute.experimental_set_strategy(None)
|
|
h()
|
|
```
|
|
|
|
is equivalent to:
|
|
|
|
```
|
|
with strategy1.scope():
|
|
f()
|
|
with strategy2.scope():
|
|
g()
|
|
h()
|
|
```
|
|
|
|
In general, you should use the `with strategy.scope():` API, but this
|
|
alternative may be convenient in notebooks where you would have to put
|
|
each cell in a `with strategy.scope():` block.
|
|
|
|
Note: This should only be called outside of any TensorFlow scope to
|
|
avoid improper nesting.
|
|
|
|
Args:
|
|
strategy: A `tf.distribute.Strategy` object or None.
|
|
|
|
Raises:
|
|
RuntimeError: If called inside a `with strategy.scope():`.
|
|
"""
|
|
old_scope = ops.get_default_graph()._global_distribute_strategy_scope # pylint: disable=protected-access
|
|
if old_scope is not None:
|
|
old_scope.__exit__(None, None, None)
|
|
ops.get_default_graph()._global_distribute_strategy_scope = None # pylint: disable=protected-access
|
|
if has_strategy():
|
|
raise RuntimeError(
|
|
"Must not be called inside a `tf.distribute.Strategy` scope.")
|
|
if strategy is not None:
|
|
new_scope = strategy.scope()
|
|
new_scope.__enter__()
|
|
ops.get_default_graph()._global_distribute_strategy_scope = new_scope # pylint: disable=protected-access
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# Internal helpers.
|
|
|
|
|
|
@contextlib.contextmanager
|
|
def enter_or_assert_strategy(strategy):
|
|
if has_strategy():
|
|
_assert_strategy(strategy)
|
|
yield
|
|
else:
|
|
with strategy.scope():
|
|
yield
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# Defaults that are used when no tf.distribute.Strategy is explicitly created.
|
|
# We create them lazily in a function so that we can workaround the circular
|
|
# dependency on distribute_lib. See lazy loader at the top of this file.
|
|
|
|
_defaults = {
|
|
"strategy": None,
|
|
"replica_context": None,
|
|
"replica_mode": None
|
|
}
|
|
# Note: These need to be different locks since _get_default_replica_context
# calls _get_default_strategy inside its lock, and having them share the same
# lock can lead to deadlock.
|
|
_default_strategy_lock = threading.Lock()
|
|
_default_replica_context_lock = threading.Lock()
|
|
_default_replica_mode_lock = threading.Lock()
|
|
|
|
|
|
def _assert_strategy(strategy):
|
|
if not has_strategy():
|
|
raise RuntimeError('Need to be inside "with strategy.scope()" for %s' %
|
|
(strategy,))
|
|
current_strategy = get_strategy()
|
|
if current_strategy is not strategy:
|
|
raise RuntimeError(
|
|
"Mixing different tf.distribute.Strategy objects: %s is not %s" %
|
|
(current_strategy, strategy))
|
|
|
|
|
|
def _get_default_strategy():
|
|
if _defaults["strategy"] is None:
|
|
# Avoid race condition causing two defaults to be created
|
|
with _default_strategy_lock:
|
|
if _defaults["strategy"] is None:
|
|
# pylint: disable=protected-access
|
|
# Make sure distribute_lib module is loaded by accessing some member.
|
|
global _creating_default_strategy_singleton
|
|
_creating_default_strategy_singleton = True
|
|
if tf2.enabled():
|
|
_defaults["strategy"] = _DefaultDistributionStrategy()
|
|
else:
|
|
_defaults["strategy"] = (
|
|
_DefaultDistributionStrategyV1())
|
|
_creating_default_strategy_singleton = False
|
|
# pylint: enable=protected-access
|
|
return _defaults["strategy"]
|
|
|
|
|
|
def _get_default_replica_context():
|
|
if _defaults["replica_context"] is None:
|
|
# Avoid race condition causing two defaults to be created
|
|
with _default_replica_context_lock:
|
|
if _defaults["replica_context"] is None:
|
|
# pylint: disable=protected-access
|
|
_defaults["replica_context"] = _DefaultReplicaContext(
|
|
_get_default_strategy(), replica_id_in_sync_group=0)
|
|
# pylint: enable=protected-access
|
|
return _defaults["replica_context"]
|
|
|
|
|
|
def _get_default_replica_mode():
|
|
if _defaults["replica_mode"] is None:
|
|
# Avoid race condition causing two defaults to be created
|
|
with _default_replica_mode_lock:
|
|
if _defaults["replica_mode"] is None:
|
|
_defaults["replica_mode"] = _DefaultReplicaThreadMode()
|
|
return _defaults["replica_mode"]
|
|
|
|
|
|
# Aliases for compatibility with old names.
|
|
get_distribution_strategy = get_strategy
|
|
has_distribution_strategy = has_strategy
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# Internal context managers used to implement the DistributionStrategy
|
|
# base class
|
|
|
|
|
|
class _CurrentDistributionContext(object):
|
|
"""Context manager setting the current `tf.distribute.Strategy`.
|
|
|
|
Also: overrides the variable creator and optionally the current device.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
strategy,
|
|
var_creator_scope,
|
|
var_scope=None,
|
|
resource_creator_scope=None,
|
|
default_device_scope=None,
|
|
):
|
|
self._context = _CrossReplicaThreadMode( # pylint: disable=protected-access
|
|
strategy)
|
|
self._var_creator_scope = var_creator_scope
|
|
self._var_scope = var_scope
|
|
self._resource_creator_scope = resource_creator_scope
|
|
if default_device_scope:
|
|
self._device_scope = default_device_scope
|
|
else:
|
|
self._device_scope = None
|
|
self._same_scope_again_count = 0
|
|
|
|
def __enter__(self):
|
|
# Allow this scope to be entered if this strategy is already in scope.
|
|
if has_strategy():
|
|
_require_cross_replica_or_default_context_extended(
|
|
self._context.strategy.extended)
|
|
self._same_scope_again_count += 1
|
|
else:
|
|
_push_per_thread_mode(self._context)
|
|
if self._var_scope:
|
|
self._var_scope.__enter__()
|
|
self._var_creator_scope.__enter__()
|
|
if self._resource_creator_scope:
|
|
nest.map_structure(lambda scope: scope.__enter__(),
|
|
self._resource_creator_scope)
|
|
if self._device_scope:
|
|
self._device_scope.__enter__()
|
|
return self._context.strategy
|
|
|
|
def __exit__(self, exception_type, exception_value, traceback):
|
|
if hasattr(self._context.strategy.extended, "_lazy_variable_tracker"):
|
|
self._context.strategy.extended._lazy_variable_tracker.initialize_all()
|
|
|
|
if self._same_scope_again_count > 0:
|
|
self._same_scope_again_count -= 1
|
|
return
|
|
if self._device_scope:
|
|
try:
|
|
self._device_scope.__exit__(exception_type, exception_value, traceback)
|
|
except RuntimeError as e:
|
|
six.raise_from(
|
|
RuntimeError("Device scope nesting error: move call to "
|
|
"tf.distribute.set_strategy() out of `with` scope."),
|
|
e)
|
|
|
|
try:
|
|
self._var_creator_scope.__exit__(
|
|
exception_type, exception_value, traceback)
|
|
except RuntimeError as e:
|
|
six.raise_from(
|
|
RuntimeError("Variable creator scope nesting error: move call to "
|
|
"tf.distribute.set_strategy() out of `with` scope."),
|
|
e)
|
|
|
|
if self._resource_creator_scope:
|
|
try:
|
|
if isinstance(self._resource_creator_scope, list):
|
|
reversed_resource_creator_scope = self._resource_creator_scope[::-1]
|
|
nest.map_structure(
|
|
lambda scope: scope.__exit__(exception_type, exception_value, # pylint:disable=g-long-lambda
|
|
traceback),
|
|
reversed_resource_creator_scope)
|
|
|
|
else:
|
|
self._resource_creator_scope.__exit__(exception_type, exception_value,
|
|
traceback)
|
|
except RuntimeError as e:
|
|
six.raise_from(
|
|
RuntimeError("Resource creator scope nesting error: move call "
|
|
"to tf.distribute.set_strategy() out of `with` "
|
|
"scope."), e)
|
|
|
|
if self._var_scope:
|
|
try:
|
|
self._var_scope.__exit__(exception_type, exception_value, traceback)
|
|
except RuntimeError as e:
|
|
six.raise_from(
|
|
RuntimeError("Variable scope nesting error: move call to "
|
|
"tf.distribute.set_strategy() out of `with` scope."),
|
|
e)
|
|
_pop_per_thread_mode()
|
|
|
|
|
|
# TODO(yuefengz): add more replication modes.
|
|
@tf_export("distribute.InputReplicationMode")
|
|
class InputReplicationMode(enum.Enum):
|
|
"""Replication mode for input function.
|
|
|
|
* `PER_WORKER`: The input function will be called on each worker
|
|
independently, creating as many input pipelines as number of workers.
|
|
Replicas will dequeue from the local Dataset on their worker.
|
|
`tf.distribute.Strategy` doesn't manage any state sharing between such
|
|
separate input pipelines.
|
|
* `PER_REPLICA`: The input function will be called on each replica separately.
|
|
`tf.distribute.Strategy` doesn't manage any state sharing between such
|
|
separate input pipelines.
|
|
"""
|
|
PER_WORKER = "PER_WORKER"
|
|
PER_REPLICA = "PER_REPLICA"
|
|
|
|
|
|
@tf_export("distribute.InputContext")
|
|
class InputContext(object):
|
|
"""A class wrapping information needed by an input function.
|
|
|
|
This is a context class that is passed to the user's input function and
|
|
contains information about the compute replicas and input pipelines. The
|
|
number of compute replicas (in sync training) helps compute the local batch
|
|
size from the desired global batch size for each replica. The input pipeline
|
|
information can be used to return a different subset of the input in each
|
|
replica (for e.g. shard the input pipeline, use a different input
|
|
source etc).
|
|
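For example, a minimal sketch of an input function that uses the context to
shard and batch its data (the strategy, data, and batch size below are
illustrative assumptions):

```
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

def dataset_fn(input_context):
  global_batch_size = 64
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  dataset = tf.data.Dataset.from_tensor_slices(list(range(1024)))
  # Give each input pipeline a disjoint shard of the data.
  dataset = dataset.shard(input_context.num_input_pipelines,
                          input_context.input_pipeline_id)
  return dataset.batch(batch_size)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
```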
"""
|
|
|
|
__slots__ = [
|
|
"_num_input_pipelines", "_input_pipeline_id", "_num_replicas_in_sync"
|
|
]
|
|
|
|
def __init__(self,
|
|
num_input_pipelines=1,
|
|
input_pipeline_id=0,
|
|
num_replicas_in_sync=1):
|
|
"""Initializes an InputContext object.
|
|
|
|
Args:
|
|
num_input_pipelines: the number of input pipelines in a cluster.
|
|
input_pipeline_id: the current input pipeline id, should be an int in
|
|
[0,`num_input_pipelines`).
|
|
num_replicas_in_sync: the number of replicas that are in sync.
|
|
"""
|
|
self._num_input_pipelines = num_input_pipelines
|
|
self._input_pipeline_id = input_pipeline_id
|
|
self._num_replicas_in_sync = num_replicas_in_sync
|
|
|
|
@property
|
|
def num_replicas_in_sync(self):
|
|
"""Returns the number of compute replicas in sync."""
|
|
return self._num_replicas_in_sync
|
|
|
|
@property
|
|
def input_pipeline_id(self):
|
|
"""Returns the input pipeline ID."""
|
|
return self._input_pipeline_id
|
|
|
|
@property
|
|
def num_input_pipelines(self):
|
|
"""Returns the number of input pipelines."""
|
|
return self._num_input_pipelines
|
|
|
|
def get_per_replica_batch_size(self, global_batch_size):
|
|
"""Returns the per-replica batch size.
|
|
|
|
Args:
|
|
global_batch_size: the global batch size which should be divisible by
|
|
`num_replicas_in_sync`.
|
|
|
|
Returns:
|
|
the per-replica batch size.
|
|
|
|
Raises:
|
|
ValueError: if `global_batch_size` not divisible by
|
|
`num_replicas_in_sync`.
|
|
"""
|
|
if global_batch_size % self._num_replicas_in_sync != 0:
|
|
raise ValueError("The `global_batch_size` %r is not divisible by "
|
|
"`num_replicas_in_sync` %r " %
|
|
(global_batch_size, self._num_replicas_in_sync))
|
|
return global_batch_size // self._num_replicas_in_sync
|
|
|
|
def __str__(self):
|
|
return "tf.distribute.InputContext(input pipeline id {}, total: {})".format(
|
|
self.input_pipeline_id, self.num_input_pipelines)
|
|
|
|
|
|
@tf_export("distribute.experimental.ValueContext", v1=[])
|
|
class ValueContext(object):
|
|
"""A class wrapping information needed by a distribute function.
|
|
|
|
This is a context class that is passed to the `value_fn` in
|
|
`strategy.experimental_distribute_values_from_function` and contains
|
|
information about the compute replicas. The `num_replicas_in_sync` and
|
|
`replica_id` can be used to customize the value on each replica.
|
|
|
|
Example usage:
|
|
|
|
1. Directly constructed.
|
|
|
|
>>> def value_fn(context):
|
|
... return context.replica_id_in_sync_group/context.num_replicas_in_sync
|
|
>>> context = tf.distribute.experimental.ValueContext(
|
|
... replica_id_in_sync_group=2, num_replicas_in_sync=4)
|
|
>>> per_replica_value = value_fn(context)
|
|
>>> per_replica_value
|
|
0.5
|
|
|
|
2. Passed in by `experimental_distribute_values_from_function`. {: value=2}
|
|
|
|
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
|
|
>>> def value_fn(value_context):
|
|
... return value_context.num_replicas_in_sync
|
|
>>> distributed_values = (
|
|
... strategy.experimental_distribute_values_from_function(
|
|
... value_fn))
|
|
>>> local_result = strategy.experimental_local_results(distributed_values)
|
|
>>> local_result
|
|
(2, 2)
|
|
|
|
"""
|
|
|
|
__slots__ = ["_replica_id_in_sync_group", "_num_replicas_in_sync"]
|
|
|
|
def __init__(self,
|
|
replica_id_in_sync_group=0,
|
|
num_replicas_in_sync=1):
|
|
"""Initializes a ValueContext object.
|
|
|
|
Args:
|
|
replica_id_in_sync_group: the current replica_id, should be an int in
|
|
[0,`num_replicas_in_sync`).
|
|
num_replicas_in_sync: the number of replicas that are in sync.
|
|
"""
|
|
self._replica_id_in_sync_group = replica_id_in_sync_group
|
|
self._num_replicas_in_sync = num_replicas_in_sync
|
|
|
|
@property
|
|
def num_replicas_in_sync(self):
|
|
"""Returns the number of compute replicas in sync."""
|
|
return self._num_replicas_in_sync
|
|
|
|
@property
|
|
def replica_id_in_sync_group(self):
|
|
"""Returns the replica ID."""
|
|
return self._replica_id_in_sync_group
|
|
|
|
def __str__(self):
|
|
return (("tf.distribute.ValueContext(replica id {}, "
|
|
" total replicas in sync: ""{})")
|
|
.format(self.replica_id_in_sync_group, self.num_replicas_in_sync))
|
|
|
|
|
|
@tf_export("distribute.RunOptions")
|
|
class RunOptions(
|
|
collections.namedtuple("RunOptions", [
|
|
"experimental_enable_dynamic_batch_size",
|
|
"experimental_bucketizing_dynamic_shape",
|
|
"experimental_xla_options",
|
|
])):
|
|
"""Run options for `strategy.run`.
|
|
|
|
This can be used to hold some strategy specific configs.
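
For example, a minimal sketch of passing options to `strategy.run` (these
options are only honored by `TPUStrategy`; the strategy setup, `replica_fn`,
and `x` are assumed to exist and are omitted here):

```
options = tf.distribute.RunOptions(
    experimental_enable_dynamic_batch_size=False)
# per_replica_result = strategy.run(replica_fn, args=(x,), options=options)
```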

Attributes:
  experimental_enable_dynamic_batch_size: Boolean. Only applies to
    TPUStrategy. Defaults to True. If True, TPUStrategy will enable dynamic
    padder to support dynamic batch size for the inputs. Otherwise only static
    shape inputs are allowed.
  experimental_bucketizing_dynamic_shape: Boolean. Only applies to
    TPUStrategy. Defaults to False. If True, TPUStrategy will automatically
    bucketize inputs passed into `run` if the input shape is
    dynamic. This is a performance optimization to reduce XLA recompilation,
    which should not have impact on correctness.
  experimental_xla_options: A `tf.tpu.XLAOptions` instance. Only applies to
    TPUStrategy. Controls the XLA compiling options on TPUs. Defaults to None.
"""
|
|
|
|
def __new__(cls,
|
|
experimental_enable_dynamic_batch_size=True,
|
|
experimental_bucketizing_dynamic_shape=False,
|
|
experimental_xla_options=None):
|
|
return super(RunOptions,
|
|
cls).__new__(cls, experimental_enable_dynamic_batch_size,
|
|
experimental_bucketizing_dynamic_shape,
|
|
experimental_xla_options)
|
|
|
|
|
|
@tf_export("distribute.InputOptions", v1=[])
|
|
class InputOptions(
|
|
collections.namedtuple("InputOptions", [
|
|
"experimental_fetch_to_device",
|
|
"experimental_replication_mode",
|
|
"experimental_place_dataset_on_device",
|
|
"experimental_per_replica_buffer_size",
|
|
])):
|
|
"""Run options for `experimental_distribute_dataset(s_from_function)`.
|
|
|
|
This can be used to hold some strategy specific configs.
|
|
|
|
```python
# Setup TPUStrategy
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

dataset = tf.data.Dataset.range(16)
distributed_dataset_on_host = (
    strategy.experimental_distribute_dataset(
        dataset,
        tf.distribute.InputOptions(
            experimental_replication_mode=
            tf.distribute.InputReplicationMode.PER_WORKER,
            experimental_place_dataset_on_device=False,
            experimental_per_replica_buffer_size=1)))
```
|
|
|
|
Attributes:
|
|
experimental_fetch_to_device: Boolean. If True, dataset
|
|
elements will be prefetched to accelerator device memory. When False,
|
|
dataset elements are prefetched to host device memory. Must be False when
|
|
using TPUEmbedding API. experimental_fetch_to_device can only be used
|
|
with experimental_replication_mode=PER_WORKER. Default behavior is same as
|
|
setting it to True.
|
|
experimental_replication_mode: Replication mode for the input function.
Currently, InputReplicationMode.PER_REPLICA is only supported with
tf.distribute.MirroredStrategy via
`experimental_distribute_datasets_from_function`.
The default value is InputReplicationMode.PER_WORKER.
|
|
experimental_place_dataset_on_device: Boolean. Default to False. When True,
|
|
dataset will be placed on the device, otherwise it will remain on the
|
|
host. experimental_place_dataset_on_device=True can only be used with
|
|
experimental_replication_mode=PER_REPLICA
|
|
experimental_per_replica_buffer_size: Integer. Default to 1. Indicates the
|
|
prefetch buffer size in the replica device memory. Users can set it
|
|
to 0 to completely disable prefetching behavior, or a number greater than
|
|
1 to enable larger buffer size. Note that this option is still
|
|
valid with `experimental_fetch_to_device=False`.
|
|
"""
|
|
|
|
def __new__(cls,
|
|
experimental_fetch_to_device=None,
|
|
experimental_replication_mode=InputReplicationMode.PER_WORKER,
|
|
experimental_place_dataset_on_device=False,
|
|
experimental_per_replica_buffer_size=1):
|
|
if experimental_fetch_to_device is None:
|
|
experimental_fetch_to_device = True
|
|
|
|
return super(InputOptions,
|
|
cls).__new__(cls, experimental_fetch_to_device,
|
|
experimental_replication_mode,
|
|
experimental_place_dataset_on_device,
|
|
experimental_per_replica_buffer_size)
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# Base classes for all distribution strategies.
|
|
|
|
|
|
# Base class for v1 Strategy and v2 Strategy classes. For API's specific to
|
|
# v1/v2 Strategy, add to implementing classes of StrategyBase.
|
|
# pylint: disable=line-too-long
|
|
class StrategyBase(object):
|
|
"""A state & compute distribution policy on a list of devices.
|
|
|
|
See [the guide](https://www.tensorflow.org/guide/distributed_training)
|
|
for overview and examples. See `tf.distribute.StrategyExtended` and
|
|
[`tf.distribute`](https://www.tensorflow.org/api_docs/python/tf/distribute)
|
|
for a glossary of concepts mentioned on this page such as "per-replica",
|
|
_replica_, and _reduce_.
|
|
|
|
In short:
|
|
|
|
* To use it with Keras `compile`/`fit`,
|
|
[please
|
|
read](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_keras).
|
|
* Otherwise, use `tf.distribute.Strategy.scope` to specify that a
|
|
strategy should be used when building and executing your model.
|
|
(This puts you in the "cross-replica context" for this strategy, which
|
|
means the strategy is put in control of things like variable placement.)
|
|
* If you are writing a custom training loop, you will need to call a few more
|
|
methods,
|
|
[see the
|
|
guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_custom_training_loops):
|
|
|
|
* Start by creating a `tf.data.Dataset` normally.
|
|
* Use `tf.distribute.Strategy.experimental_distribute_dataset` to convert
|
|
a `tf.data.Dataset` to something that produces "per-replica" values.
|
|
If you want to manually specify how the dataset should be partitioned
|
|
across replicas, use
|
|
`tf.distribute.Strategy.distribute_datasets_from_function`
|
|
instead.
|
|
* Use `tf.distribute.Strategy.run` to run a function
|
|
once per replica, taking values that may be "per-replica" (e.g.
|
|
from a `tf.distribute.DistributedDataset` object) and returning
|
|
"per-replica" values.
|
|
This function is executed in "replica context", which means each
|
|
operation is performed separately on each replica.
|
|
* Finally use a method (such as `tf.distribute.Strategy.reduce`) to
|
|
convert the resulting "per-replica" values into ordinary `Tensor`s.
|
|
|
|
A custom training loop can be as simple as:
|
|
|
|
```
|
|
with my_strategy.scope():
|
|
@tf.function
|
|
def distribute_train_epoch(dataset):
|
|
def replica_fn(input):
|
|
# process input and return result
|
|
return result
|
|
|
|
total_result = 0
|
|
for x in dataset:
|
|
per_replica_result = my_strategy.run(replica_fn, args=(x,))
|
|
total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
|
|
per_replica_result, axis=None)
|
|
return total_result
|
|
|
|
dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
|
|
for _ in range(EPOCHS):
|
|
train_result = distribute_train_epoch(dist_dataset)
|
|
```
|
|
|
|
This takes an ordinary `dataset` and `replica_fn` and runs it
|
|
distributed using a particular `tf.distribute.Strategy` named
|
|
`my_strategy` above. Any variables created in `replica_fn` are created
|
|
using `my_strategy`'s policy, and library functions called by
|
|
`replica_fn` can use the `get_replica_context()` API to implement
|
|
distributed-specific behavior.
|
|
|
|
You can use the `reduce` API to aggregate results across replicas and use
|
|
this as a return value from one iteration over a
|
|
`tf.distribute.DistributedDataset`. Or
|
|
you can use `tf.keras.metrics` (such as loss, accuracy, etc.) to
|
|
accumulate metrics across steps in a given epoch.
|
|
|
|
See the
|
|
[custom training loop
|
|
tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training)
|
|
for a more detailed example.
|
|
|
|
Note: `tf.distribute.Strategy` currently does not support TensorFlow's
|
|
partitioned variables (where a single variable is split across multiple
|
|
devices) at this time.
|
|
"""
|
|
# pylint: enable=line-too-long
|
|
|
|
# TODO(josh11b): Partitioned computations, state; sharding
|
|
# TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling
|
|
|
|
def __init__(self, extended):
|
|
self._extended = extended
|
|
|
|
self._scale_loss_for_estimator = False
|
|
|
|
if not hasattr(extended, "_retrace_functions_for_each_device"):
|
|
# pylint: disable=protected-access
|
|
# `extended._retrace_functions_for_each_device` dictates
|
|
# whether the same function will be retraced when it is called on
|
|
# different devices.
|
|
try:
|
|
extended._retrace_functions_for_each_device = (
|
|
len(extended.worker_devices) > 1)
|
|
distribution_strategy_replica_gauge.get_cell("num_replicas").set(
|
|
self.num_replicas_in_sync)
|
|
except: # pylint: disable=bare-except
|
|
# Default for the case where extended.worker_devices can't return
|
|
# a sensible value.
|
|
extended._retrace_functions_for_each_device = True
|
|
|
|
# Below are the dicts of axis(int) -> `tf.function`.
|
|
self._mean_reduce_helper_fns = {}
|
|
self._reduce_sum_fns = {}
|
|
|
|
# Whether this strategy is designed to work with `ClusterCoordinator`.
|
|
self._should_use_with_coordinator = False
|
|
|
|
@property
|
|
def extended(self):
|
|
"""`tf.distribute.StrategyExtended` with additional methods."""
|
|
return self._extended
|
|
|
|
@tf_contextlib.contextmanager
|
|
def _scale_loss_for_estimator_enabled(self):
|
|
"""Scope which sets a flag used for scaling losses in optimizer.
|
|
|
|
Yields:
|
|
`_scale_loss_for_estimator_enabled` is a context manager with a
|
|
side effect, but doesn't return a value.
|
|
"""
|
|
self._scale_loss_for_estimator = True
|
|
try:
|
|
yield
|
|
finally:
|
|
self._scale_loss_for_estimator = False
|
|
|
|
# pylint: disable=line-too-long
|
|
def scope(self):
|
|
"""Context manager to make the strategy current and distribute variables.
|
|
|
|
This method returns a context manager, and is used as follows:
|
|
|
|
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
|
|
>>> # Variable created inside scope:
|
|
>>> with strategy.scope():
|
|
... mirrored_variable = tf.Variable(1.)
|
|
>>> mirrored_variable
|
|
MirroredVariable:{
|
|
0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
|
|
1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
|
|
}
|
|
>>> # Variable created outside scope:
|
|
>>> regular_variable = tf.Variable(1.)
|
|
>>> regular_variable
|
|
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
|
|
|
|
_What happens when Strategy.scope is entered?_
|
|
|
|
* `strategy` is installed in the global context as the "current" strategy.
|
|
Inside this scope, `tf.distribute.get_strategy()` will now return this
|
|
strategy. Outside this scope, it returns the default no-op strategy.
|
|
* Entering the scope also enters the "cross-replica context". See
|
|
`tf.distribute.StrategyExtended` for an explanation on cross-replica and
|
|
replica contexts.
|
|
* Variable creation inside `scope` is intercepted by the strategy. Each
|
|
strategy defines how it wants to affect the variable creation. Sync
|
|
strategies like `MirroredStrategy`, `TPUStrategy` and
|
|
`MultiWorkerMirroredStrategy` create variables replicated on each replica,
|
|
whereas `ParameterServerStrategy` creates variables on the parameter
|
|
servers. This is done using a custom `tf.variable_creator_scope`.
|
|
* In some strategies, a default device scope may also be entered: in
|
|
`MultiWorkerMirroredStrategy`, a default device scope of "/CPU:0" is
|
|
entered on each worker.
|
|
|
|
Note: Entering a scope does not automatically distribute a computation, except
in the case of high-level training frameworks like Keras `model.fit`. If
you're not using `model.fit`, you
need to use the `strategy.run` API to explicitly distribute that computation.
|
|
See an example in the [custom training loop tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training).
|
|
|
|
|
|
_What should be in scope and what should be outside?_
|
|
|
|
There are a number of requirements on what needs to happen inside the scope.
|
|
However, in places where we have information about which strategy is in use,
|
|
we often enter the scope for the user, so they don't have to do it
|
|
explicitly (i.e. calling those either inside or outside the scope is OK).
|
|
|
|
* Anything that creates variables that should be distributed variables
|
|
must be called in a `strategy.scope`. This can be accomplished either by
|
|
directly calling the variable creating function within the scope context,
|
|
or by relying on another API like `strategy.run` or `keras.Model.fit` to
|
|
automatically enter it for you. Any variable that is created outside scope
|
|
will not be distributed and may have performance implications. Some common
|
|
objects that create variables in TF are Models, Optimizers, Metrics. Such
|
|
objects should always be initialized in the scope, and any functions
|
|
that may lazily create variables (e.g., `Model.__call__()`, tracing a
|
|
`tf.function`, etc.) should similarly be called within scope. Another
|
|
source of variable creation can be a checkpoint restore - when variables
|
|
are created lazily. Note that any variable created inside a strategy
|
|
captures the strategy information. So reading and writing to these
|
|
variables outside the `strategy.scope` can also work seamlessly, without
|
|
the user having to enter the scope.
|
|
* Some strategy APIs (such as `strategy.run` and `strategy.reduce`) which
|
|
require to be in a strategy's scope, enter the scope automatically, which
|
|
means when using those APIs you don't need to explicitly enter the scope
|
|
yourself.
|
|
* When a `tf.keras.Model` is created inside a `strategy.scope`, the Model
|
|
object captures the scope information. When high level training framework
|
|
methods such as `model.compile`, `model.fit`, etc. are then called, the
|
|
captured scope will be automatically entered, and the associated strategy
|
|
will be used to distribute the training etc. See a detailed example in
|
|
[distributed keras tutorial](https://www.tensorflow.org/tutorials/distribute/keras).
|
|
WARNING: Simply calling `model(..)` does not automatically enter the
|
|
captured scope -- only high level training framework APIs support this
|
|
behavior: `model.compile`, `model.fit`, `model.evaluate`, `model.predict`
|
|
and `model.save` can all be called inside or outside the scope.
|
|
* The following can be either inside or outside the scope:
|
|
* Creating the input datasets
|
|
* Defining `tf.function`s that represent your training step
|
|
* Saving APIs such as `tf.saved_model.save`. Loading creates variables,
|
|
so that should go inside the scope if you want to train the model in a
|
|
distributed way.
|
|
* Checkpoint saving. As mentioned above - `checkpoint.restore` may
|
|
sometimes need to be inside scope if it creates variables.
|
|
|
|
Returns:
|
|
A context manager.
|
|
"""
|
|
return self._extended._scope(self) # pylint: disable=protected-access
|
|
# pylint: enable=line-too-long
|
|
|
|
@doc_controls.do_not_doc_inheritable # DEPRECATED, moving to `extended`
|
|
@deprecated(None, "use extended.colocate_vars_with() instead.")
|
|
def colocate_vars_with(self, colocate_with_variable):
|
|
"""DEPRECATED: use extended.colocate_vars_with() instead."""
|
|
return self._extended.colocate_vars_with(colocate_with_variable)
|
|
|
|
@doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only
|
|
def make_dataset_iterator(self, dataset):
|
|
"""DEPRECATED TF 1.x ONLY."""
|
|
return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access
|
|
|
|
@doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only
|
|
def make_input_fn_iterator(self,
|
|
input_fn,
|
|
replication_mode=InputReplicationMode.PER_WORKER):
|
|
"""DEPRECATED TF 1.x ONLY."""
|
|
if replication_mode != InputReplicationMode.PER_WORKER:
|
|
raise ValueError(
|
|
"Input replication mode not supported: %r" % replication_mode)
|
|
with self.scope():
|
|
return self.extended._make_input_fn_iterator( # pylint: disable=protected-access
|
|
input_fn, replication_mode=replication_mode)
|
|
|
|
@doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only
|
|
@deprecated(None, "use run() instead")
|
|
def experimental_run(self, fn, input_iterator=None):
|
|
"""DEPRECATED TF 1.x ONLY."""
|
|
with self.scope():
|
|
args = (input_iterator.get_next(),) if input_iterator is not None else ()
|
|
return self.run(fn, args=args)
|
|
|
|
def experimental_distribute_dataset(self, dataset, options=None):
|
|
# pylint: disable=line-too-long
|
|
"""Creates `tf.distribute.DistributedDataset` from `tf.data.Dataset`.
|
|
|
|
The returned `tf.distribute.DistributedDataset` can be iterated over
|
|
similar to regular datasets.
|
|
NOTE: The user cannot add any more transformations to a
|
|
`tf.distribute.DistributedDataset`. You can only create an iterator or
|
|
examine the `tf.TypeSpec` of the data generated by it. See API docs of
|
|
`tf.distribute.DistributedDataset` to learn more.
|
|
|
|
The following is an example:
|
|
|
|
>>> global_batch_size = 2
|
|
>>> # Passing the devices is optional.
|
|
... strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
|
|
>>> # Create a dataset
|
|
... dataset = tf.data.Dataset.range(4).batch(global_batch_size)
|
|
>>> # Distribute that dataset
|
|
... dist_dataset = strategy.experimental_distribute_dataset(dataset)
|
|
>>> @tf.function
|
|
... def replica_fn(input):
|
|
... return input*2
|
|
>>> result = []
|
|
>>> # Iterate over the `tf.distribute.DistributedDataset`
|
|
... for x in dist_dataset:
|
|
... # process dataset elements
|
|
... result.append(strategy.run(replica_fn, args=(x,)))
|
|
>>> print(result)
|
|
[PerReplica:{
|
|
0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>,
|
|
1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])>
|
|
}, PerReplica:{
|
|
0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>,
|
|
1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])>
|
|
}]
|
|
|
|
|
|
Three key actions happening under the hood of this method are batching,
|
|
sharding, and prefetching.
|
|
|
|
In the code snippet above, `dataset` is batched by `global_batch_size`, and
|
|
calling `experimental_distribute_dataset` on it rebatches `dataset` to a
|
|
new batch size that is equal to the global batch size divided by the number
|
|
of replicas in sync. We iterate through it using a Pythonic for loop.
|
|
`x` is a `tf.distribute.DistributedValues` containing data for all replicas,
|
|
and each replica gets data of the new batch size.
|
|
`tf.distribute.Strategy.run` will take care of feeding the right per-replica
|
|
data in `x` to the right `replica_fn` executed on each replica.
|
|
|
|
Sharding includes autosharding across multiple workers and within every
worker. First, in multi-worker distributed training (i.e. when you use
|
|
`tf.distribute.experimental.MultiWorkerMirroredStrategy`
|
|
or `tf.distribute.TPUStrategy`), autosharding a dataset over a set of
|
|
workers means that each worker is assigned a subset of the entire dataset
|
|
(if the right `tf.data.experimental.AutoShardPolicy` is set). This is to
|
|
ensure that at each step, a global batch size of non-overlapping dataset
|
|
elements will be processed by each worker. Autosharding has a couple of
|
|
different options that can be specified using
|
|
`tf.data.experimental.DistributeOptions`. Then, sharding within each worker
|
|
means the method will split the data among all the worker devices (if more
|
|
than one is present). This will happen regardless of multi-worker
|
|
autosharding.
|
|
|
|
Note: for autosharding across multiple workers, the default mode is
|
|
`tf.data.experimental.AutoShardPolicy.AUTO`. This mode
|
|
will attempt to shard the input dataset by files if the dataset is
|
|
being created out of reader datasets (e.g. `tf.data.TFRecordDataset`,
|
|
`tf.data.TextLineDataset`, etc.) or otherwise shard the dataset by data,
|
|
where each of the workers will read the entire dataset and only process the
|
|
shard assigned to it. However, if you have less than one input file per
|
|
worker, we suggest that you disable dataset autosharding across workers by
|
|
setting the `tf.data.experimental.DistributeOptions.auto_shard_policy` to be
|
|
`tf.data.experimental.AutoShardPolicy.OFF`.
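For example, one way to turn off autosharding across workers (a minimal
sketch using the `tf.data.Options` API; `dataset` is assumed to already be
batched by the global batch size) is:

```python
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.OFF)
dataset = dataset.with_options(options)
dist_dataset = strategy.experimental_distribute_dataset(dataset)
```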
|
|
|
|
By default, this method adds a prefetch transformation at the end of the
|
|
user-provided `tf.data.Dataset` instance. The `buffer_size` argument of the
prefetch transformation is equal to the number of replicas in sync.
|
|
|
|
If the above batch splitting and dataset sharding logic is undesirable,
|
|
please use
|
|
`tf.distribute.Strategy.distribute_datasets_from_function`
|
|
instead, which does not do any automatic batching or sharding for you.
|
|
|
|
Note: If you are using TPUStrategy, the order in which the data is processed
|
|
by the workers when using
|
|
`tf.distribute.Strategy.experimental_distribute_dataset` or
|
|
`tf.distribute.Strategy.distribute_datasets_from_function` is
|
|
not guaranteed. Deterministic ordering is typically required if you are using
`tf.distribute` to scale prediction. You can, however, insert an index for
|
|
each element in the batch and order outputs accordingly. Refer to [this
|
|
snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
|
|
for an example of how to order outputs.
|
|
|
|
Note: Stateful dataset transformations are currently not supported with
|
|
`tf.distribute.experimental_distribute_dataset` or
|
|
`tf.distribute.distribute_datasets_from_function`. Any stateful
|
|
ops that the dataset may have are currently ignored. For example, if your
|
|
dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image,
|
|
then you have a dataset graph that depends on state (i.e. the random seed) on
|
|
the local machine where the python process is being executed.
|
|
|
|
For a tutorial on more usage and properties of this method, refer to the
|
|
[tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_dataset).
|
|
If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).
|
|
|
|
Args:
|
|
dataset: `tf.data.Dataset` that will be sharded across all replicas using
|
|
the rules stated above.
|
|
options: `tf.distribute.InputOptions` used to control options on how this
|
|
dataset is distributed.
|
|
|
|
Returns:
|
|
A `tf.distribute.DistributedDataset`.
|
|
"""
|
|
distribution_strategy_input_api_counter.get_cell(
|
|
self.__class__.__name__, "distribute_dataset").increase_by(1)
|
|
# pylint: enable=line-too-long
|
|
return self._extended._experimental_distribute_dataset(dataset, options) # pylint: disable=protected-access
|
|
|
|
def distribute_datasets_from_function(self, dataset_fn, options=None):
|
|
# pylint: disable=line-too-long
|
|
"""Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.
|
|
|
|
The argument `dataset_fn` that users pass in is an input function that has a
|
|
`tf.distribute.InputContext` argument and returns a `tf.data.Dataset`
|
|
instance. It is expected that the returned dataset from `dataset_fn` is
|
|
already batched by per-replica batch size (i.e. global batch size divided by
|
|
the number of replicas in sync) and sharded.
|
|
`tf.distribute.Strategy.distribute_datasets_from_function` does
|
|
not batch or shard the `tf.data.Dataset` instance
|
|
returned from the input function. `dataset_fn` will be called on the CPU
|
|
device of each of the workers and each generates a dataset where every
|
|
replica on that worker will dequeue one batch of inputs (i.e. if a worker
|
|
has two replicas, two batches will be dequeued from the `Dataset` every
|
|
step).
|
|
|
|
This method can be used for several purposes. First, it allows you to
|
|
specify your own batching and sharding logic. (In contrast,
|
|
`tf.distribute.experimental_distribute_dataset` does batching and sharding
|
|
for you.) For example, where
|
|
`experimental_distribute_dataset` is unable to shard the input files, this
|
|
method might be used to manually shard the dataset (avoiding the slow
|
|
fallback behavior in `experimental_distribute_dataset`). In cases where the
|
|
dataset is infinite, this sharding can be done by creating dataset replicas
|
|
that differ only in their random seed.
|
|
|
|
The `dataset_fn` should take a `tf.distribute.InputContext` instance where
|
|
information about batching and input replication can be accessed.
|
|
|
|
You can use the `element_spec` property of the
|
|
`tf.distribute.DistributedDataset` returned by this API to query the
|
|
`tf.TypeSpec` of the elements returned by the iterator. This can be used to
|
|
set the `input_signature` property of a `tf.function`. Follow
|
|
`tf.distribute.DistributedDataset.element_spec` to see an example.
|
|
|
|
IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
|
|
per-replica batch size, unlike `experimental_distribute_dataset`, which uses
|
|
the global batch size. This may be computed using
|
|
`input_context.get_per_replica_batch_size`.
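For example, a minimal sketch of such a `dataset_fn` (here `global_batch_size`
and `input_files` are assumed to be defined by the caller):

```python
def dataset_fn(input_context):
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  d = tf.data.TFRecordDataset(input_files)
  # Shard manually across input pipelines, then batch per replica.
  d = d.shard(input_context.num_input_pipelines,
              input_context.input_pipeline_id)
  return d.batch(batch_size).prefetch(2)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
```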
|
|
|
|
Note: If you are using TPUStrategy, the order in which the data is processed
|
|
by the workers when using
|
|
`tf.distribute.Strategy.experimental_distribute_dataset` or
|
|
`tf.distribute.Strategy.distribute_datasets_from_function` is
|
|
not guaranteed. Deterministic ordering is typically required if you are using
`tf.distribute` to scale prediction. You can, however, insert an index for
|
|
each element in the batch and order outputs accordingly. Refer to [this
|
|
snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
|
|
for an example of how to order outputs.
|
|
|
|
Note: Stateful dataset transformations are currently not supported with
|
|
`tf.distribute.experimental_distribute_dataset` or
|
|
`tf.distribute.distribute_datasets_from_function`. Any stateful
|
|
ops that the dataset may have are currently ignored. For example, if your
|
|
dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image,
|
|
then you have a dataset graph that depends on state (i.e. the random seed) on
|
|
the local machine where the python process is being executed.
|
|
|
|
For a tutorial on more usage and properties of this method, refer to the
|
|
[tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_datasets_from_function).
|
|
If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).
|
|
|
|
Args:
|
|
dataset_fn: A function taking a `tf.distribute.InputContext` instance and
|
|
returning a `tf.data.Dataset`.
|
|
options: `tf.distribute.InputOptions` used to control options on how this
|
|
dataset is distributed.
|
|
|
|
Returns:
|
|
A `tf.distribute.DistributedDataset`.
|
|
"""
|
|
distribution_strategy_input_api_counter.get_cell(
|
|
self.__class__.__name__,
|
|
"distribute_datasets_from_function").increase_by(1)
|
|
# pylint: enable=line-too-long
|
|
return self._extended._distribute_datasets_from_function( # pylint: disable=protected-access
|
|
dataset_fn, options)
|
|
|
|
# TODO(b/162776748): Remove deprecated symbol.
|
|
@doc_controls.do_not_doc_inheritable
|
|
@deprecation.deprecated(None, "rename to distribute_datasets_from_function")
|
|
def experimental_distribute_datasets_from_function(self,
|
|
dataset_fn,
|
|
options=None):
|
|
return self.distribute_datasets_from_function(dataset_fn, options)
|
|
|
|
def run(self, fn, args=(), kwargs=None, options=None):
|
|
"""Invokes `fn` on each replica, with the given arguments.
|
|
|
|
This method is the primary way to distribute your computation with a
|
|
tf.distribute object. It invokes `fn` on each replica. If `args` or `kwargs`
|
|
have `tf.distribute.DistributedValues`, such as those produced by a
|
|
`tf.distribute.DistributedDataset` from
|
|
`tf.distribute.Strategy.experimental_distribute_dataset` or
|
|
`tf.distribute.Strategy.distribute_datasets_from_function`,
|
|
when `fn` is executed on a particular replica, it will be executed with the
|
|
component of `tf.distribute.DistributedValues` that correspond to that
|
|
replica.
|
|
|
|
`fn` is invoked under a replica context. `fn` may call
|
|
`tf.distribute.get_replica_context()` to access members such as
|
|
`all_reduce`. Please see the module-level docstring of tf.distribute for the
|
|
concept of replica context.
|
|
|
|
All arguments in `args` or `kwargs` can be a nested structure of tensors,
|
|
e.g. a list of tensors, in which case `args` and `kwargs` will be passed to
|
|
the `fn` invoked on each replica. Or `args` or `kwargs` can be
|
|
`tf.distribute.DistributedValues` containing tensors or composite tensors
(e.g. `tf.SparseTensor` or `tf.RaggedTensor`), in which case each `fn` call
|
|
will get the component of a `tf.distribute.DistributedValues` corresponding
|
|
to its replica. Note that arbitrary Python values that are not of the types
|
|
above are not supported.
|
|
|
|
IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and
|
|
whether eager execution is enabled, `fn` may be called one or more times. If
|
|
`fn` is annotated with `tf.function` or `tf.distribute.Strategy.run` is
|
|
called inside a `tf.function` (eager execution is disabled inside a
|
|
`tf.function` by default), `fn` is called once per replica to generate a
|
|
TensorFlow graph, which will then be reused for execution with new inputs.
|
|
Otherwise, if eager execution is enabled, `fn` will be called once per
|
|
replica every step just like regular python code.
|
|
|
|
Example usage:
|
|
|
|
1. Constant tensor input.
|
|
|
|
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
|
|
>>> tensor_input = tf.constant(3.0)
|
|
>>> @tf.function
|
|
... def replica_fn(input):
|
|
... return input*2.0
|
|
>>> result = strategy.run(replica_fn, args=(tensor_input,))
|
|
>>> result
|
|
PerReplica:{
|
|
0: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>,
|
|
1: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
|
|
}
|
|
|
|
2. DistributedValues input. {: value=2}
|
|
|
|
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
|
|
>>> @tf.function
|
|
... def run():
|
|
... def value_fn(value_context):
|
|
... return value_context.num_replicas_in_sync
|
|
... distributed_values = (
|
|
... strategy.experimental_distribute_values_from_function(
|
|
... value_fn))
|
|
... def replica_fn2(input):
|
|
... return input*2
|
|
... return strategy.run(replica_fn2, args=(distributed_values,))
|
|
>>> result = run()
|
|
>>> result
|
|
<tf.Tensor: shape=(), dtype=int32, numpy=4>
|
|
|
|
3. Use `tf.distribute.ReplicaContext` to allreduce values. {: value=3}
|
|
|
|
>>> strategy = tf.distribute.MirroredStrategy(["gpu:0", "gpu:1"])
|
|
>>> @tf.function
|
|
... def run():
|
|
... def value_fn(value_context):
|
|
... return tf.constant(value_context.replica_id_in_sync_group)
|
|
... distributed_values = (
|
|
... strategy.experimental_distribute_values_from_function(
|
|
... value_fn))
|
|
... def replica_fn(input):
|
|
... return tf.distribute.get_replica_context().all_reduce(
|
|
... "sum", input)
|
|
... return strategy.run(replica_fn, args=(distributed_values,))
|
|
>>> result = run()
|
|
>>> result
|
|
PerReplica:{
|
|
0: <tf.Tensor: shape=(), dtype=int32, numpy=1>,
|
|
1: <tf.Tensor: shape=(), dtype=int32, numpy=1>
|
|
}
|
|
|
|
Args:
|
|
fn: The function to run on each replica.
|
|
args: Optional positional arguments to `fn`. Its element can be a tensor,
|
|
a nested structure of tensors or a `tf.distribute.DistributedValues`.
|
|
kwargs: Optional keyword arguments to `fn`. Its element can be a tensor,
|
|
a nested structure of tensors or a `tf.distribute.DistributedValues`.
|
|
options: An optional instance of `tf.distribute.RunOptions` specifying
|
|
the options to run `fn`.
|
|
|
|
Returns:
|
|
Merged return value of `fn` across replicas. The structure of the return
|
|
value is the same as the return value from `fn`. Each element in the
|
|
structure can either be `tf.distribute.DistributedValues` or `Tensor`
objects (for example, if running on a single replica).
|
|
"""
|
|
del options
|
|
|
|
if not isinstance(args, (list, tuple)):
|
|
raise ValueError(
|
|
"positional args must be a list or tuple, got {}".format(type(args)))
|
|
|
|
with self.scope():
|
|
# tf.distribute supports Eager functions, so AutoGraph should not be
|
|
# applied when the caller is also in Eager mode.
|
|
fn = autograph.tf_convert(
|
|
fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
|
|
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
|
|
|
|
def reduce(self, reduce_op, value, axis):
|
|
"""Reduce `value` across replicas and return result on current device.
|
|
|
|
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
|
|
>>> def step_fn():
|
|
... i = tf.distribute.get_replica_context().replica_id_in_sync_group
|
|
... return tf.identity(i)
|
|
>>>
|
|
>>> per_replica_result = strategy.run(step_fn)
|
|
>>> total = strategy.reduce("SUM", per_replica_result, axis=None)
|
|
>>> total
|
|
<tf.Tensor: shape=(), dtype=int32, numpy=1>
|
|
|
|
To see how this would look with multiple replicas, consider the same
|
|
example with MirroredStrategy with 2 GPUs:
|
|
|
|
```python
|
|
strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
|
|
def step_fn():
|
|
i = tf.distribute.get_replica_context().replica_id_in_sync_group
|
|
return tf.identity(i)
|
|
|
|
per_replica_result = strategy.run(step_fn)
|
|
# Check devices on which per replica result is:
|
|
strategy.experimental_local_results(per_replica_result)[0].device
|
|
# /job:localhost/replica:0/task:0/device:GPU:0
|
|
strategy.experimental_local_results(per_replica_result)[1].device
|
|
# /job:localhost/replica:0/task:0/device:GPU:1
|
|
|
|
total = strategy.reduce("SUM", per_replica_result, axis=None)
|
|
# Check device on which reduced result is:
|
|
total.device
|
|
# /job:localhost/replica:0/task:0/device:CPU:0
|
|
|
|
```
|
|
|
|
This API is typically used for aggregating the results returned from
|
|
different replicas, for reporting etc. For example, loss computed from
|
|
different replicas can be averaged using this API before printing.
|
|
|
|
Note: The result is copied to the "current" device - which would typically
|
|
be the CPU of the worker on which the program is running. For `TPUStrategy`,
|
|
it is the first TPU host. For multi-client `MultiWorkerMirroredStrategy`,
this is the CPU of each worker.
|
|
|
|
There are a number of different tf.distribute APIs for reducing values
|
|
across replicas:
|
|
* `tf.distribute.ReplicaContext.all_reduce`: This differs from
|
|
`Strategy.reduce` in that it is for replica context and does
|
|
not copy the results to the host device. `all_reduce` should typically be
used for reductions inside the training step, such as gradients.
|
|
* `tf.distribute.StrategyExtended.reduce_to` and
|
|
`tf.distribute.StrategyExtended.batch_reduce_to`: These APIs are more
|
|
advanced versions of `Strategy.reduce` as they allow customizing the
|
|
destination of the result. They are also called in cross replica context.
|
|
|
|
_What should axis be?_
|
|
|
|
Given a per-replica value returned by `run`, say a
|
|
per-example loss, the batch will be divided across all the replicas. This
|
|
function allows you to aggregate across replicas and optionally also across
|
|
batch elements by specifying the axis parameter accordingly.
|
|
|
|
For example, if you have a global batch size of 8 and 2
|
|
replicas, values for examples `[0, 1, 2, 3]` will be on replica 0 and
|
|
`[4, 5, 6, 7]` will be on replica 1. With `axis=None`, `reduce` will
|
|
aggregate only across replicas, returning `[0+4, 1+5, 2+6, 3+7]`.
|
|
This is useful when each replica is computing a scalar or some other value
|
|
that doesn't have a "batch" dimension (like a gradient or loss).
|
|
```
|
|
strategy.reduce("sum", per_replica_result, axis=None)
|
|
```
|
|
|
|
Sometimes, you will want to aggregate across both the global batch _and_
|
|
all replicas. You can get this behavior by specifying the batch
|
|
dimension as the `axis`, typically `axis=0`. In this case it would return a
|
|
scalar `0+1+2+3+4+5+6+7`.
|
|
```
|
|
strategy.reduce("sum", per_replica_result, axis=0)
|
|
```
|
|
|
|
If there is a last partial batch, you will need to specify an axis so
|
|
that the resulting shape is consistent across replicas. So if the last
|
|
batch has size 6 and it is divided into [0, 1, 2, 3] and [4, 5], you
|
|
would get a shape mismatch unless you specify `axis=0`. If you specify
|
|
`tf.distribute.ReduceOp.MEAN`, using `axis=0` will use the correct
|
|
denominator of 6. Contrast this with computing `reduce_mean` to get a
scalar value on each replica and then using this function to average those
means, which would weigh some values `1/8` and others `1/4`.
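For example (a minimal sketch; `per_replica_result` stands for the value
returned by `Strategy.run` for the last partial batch described above):

```python
# Global batch of 6 split as [0, 1, 2, 3] and [4, 5] across two replicas.
# With "MEAN" and axis=0 the denominator is 6, so the result is
# (0 + 1 + 2 + 3 + 4 + 5) / 6 = 2.5.
mean = strategy.reduce("MEAN", per_replica_result, axis=0)
```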
|
|
|
|
Args:
|
|
reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
|
|
be combined. Allows using string representation of the enum such as
|
|
"SUM", "MEAN".
|
|
value: a `tf.distribute.DistributedValues` instance, e.g. returned by
|
|
`Strategy.run`, to be combined into a single tensor. It can also be a
|
|
regular tensor when used with `OneDeviceStrategy` or default strategy.
|
|
axis: specifies the dimension to reduce along within each
|
|
replica's tensor. Should typically be set to the batch dimension, or
|
|
`None` to only reduce across replicas (e.g. if the tensor has no batch
|
|
dimension).
|
|
|
|
Returns:
|
|
A `Tensor`.
|
|
"""
|
|
# TODO(josh11b): support `value` being a nest.
|
|
_require_cross_replica_or_default_context_extended(self._extended)
|
|
if isinstance(reduce_op, six.string_types):
|
|
reduce_op = reduce_util.ReduceOp(reduce_op.upper())
|
|
if axis is None:
|
|
return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access
|
|
if reduce_op == reduce_util.ReduceOp.SUM:
|
|
|
|
def reduce_sum(v):
|
|
return math_ops.reduce_sum(v, axis=axis)
|
|
|
|
if eager_context.executing_eagerly():
|
|
# As some strategies (e.g. TPUStrategy) don't support pure eager
# execution, wrap `reduce_sum` with a `tf.function` so it can be
# run from eager mode. Cache the tf.function by `axis` to avoid
# retracing the same function.
|
|
if axis not in self._reduce_sum_fns:
|
|
self._reduce_sum_fns[axis] = def_function.function(reduce_sum)
|
|
value = self.run(self._reduce_sum_fns[axis], args=(value,))
|
|
else:
|
|
value = self.run(reduce_sum, args=(value,))
|
|
|
|
return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access
|
|
if reduce_op != reduce_util.ReduceOp.MEAN:
|
|
raise TypeError("Expected `reduce_op` to be a `tf.distribute.ReduceOp`, "
|
|
"not: %r" % reduce_op)
|
|
|
|
def mean_reduce_helper(v, axes=axis):
|
|
"""Computes the numerator and denominator on each replica."""
|
|
numer = math_ops.reduce_sum(v, axis=axes)
|
|
def dimension(axis):
|
|
if v.shape.rank is not None:
|
|
# Note(joshl): We support axis < 0 to be consistent with the
|
|
# tf.math.reduce_* operations.
|
|
if axis < 0:
|
|
if axis + v.shape.rank < 0:
|
|
raise ValueError(
|
|
"`axis` = %r out of range for `value` with rank %d" %
|
|
(axis, v.shape.rank))
|
|
axis += v.shape.rank
|
|
elif axis >= v.shape.rank:
|
|
raise ValueError(
|
|
"`axis` = %r out of range for `value` with rank %d" %
|
|
(axis, v.shape.rank))
|
|
# TF v2 returns `None` for unknown dimensions and an integer for
|
|
# known dimension, whereas TF v1 returns tensor_shape.Dimension(None)
|
|
# or tensor_shape.Dimension(integer). `dimension_value` hides this
|
|
# difference, always returning `None` or an integer.
|
|
dim = tensor_shape.dimension_value(v.shape[axis])
|
|
if dim is not None:
|
|
# By returning a python value in the static shape case, we can
|
|
# maybe get a fast path for reducing the denominator.
|
|
# TODO(b/151871486): Remove array_ops.identity after we fallback to
|
|
# simple reduction if inputs are all on CPU.
|
|
return array_ops.identity(
|
|
constant_op.constant(dim, dtype=dtypes.int64))
|
|
elif axis < 0:
|
|
axis = axis + array_ops.rank(v)
|
|
# TODO(b/151871486): Remove array_ops.identity after we fallback to
|
|
# simple reduction if inputs are all on CPU.
|
|
return array_ops.identity(
|
|
array_ops.shape_v2(v, out_type=dtypes.int64)[axis])
|
|
if isinstance(axis, six.integer_types):
|
|
denom = dimension(axis)
|
|
elif isinstance(axis, (tuple, list)):
|
|
denom = math_ops.reduce_prod([dimension(a) for a in axes])
|
|
else:
|
|
raise TypeError(
|
|
"Expected `axis` to be an integer, tuple or list not: %r" % axis)
|
|
# TODO(josh11b): Should we cast denom to v.dtype here instead of after the
|
|
# reduce is complete?
|
|
return numer, denom
|
|
|
|
if eager_context.executing_eagerly():
|
|
# As some strategies (e.g. TPUStrategy) don't support pure eager
# execution, wrap `mean_reduce_helper` with a `tf.function` so it can
# be run from eager mode. Cache the tf.function by `axis` to avoid
# retracing the same function.
|
|
if axis not in self._mean_reduce_helper_fns:
|
|
self._mean_reduce_helper_fns[axis] = def_function.function(
|
|
mean_reduce_helper)
|
|
numer, denom = self.run(self._mean_reduce_helper_fns[axis], args=(value,))
|
|
else:
|
|
numer, denom = self.run(mean_reduce_helper, args=(value,))
|
|
|
|
# TODO(josh11b): Should batch reduce here instead of doing two.
|
|
numer = self._extended._reduce(reduce_util.ReduceOp.SUM, numer) # pylint: disable=protected-access
|
|
denom = self._extended._reduce(reduce_util.ReduceOp.SUM, denom) # pylint: disable=protected-access
|
|
denom = math_ops.cast(denom, numer.dtype)
|
|
return math_ops.truediv(numer, denom)
|
|
|
|
@doc_controls.do_not_doc_inheritable # DEPRECATED
|
|
@deprecated(None, "use `experimental_local_results` instead.")
|
|
def unwrap(self, value):
|
|
"""Returns the list of all local per-replica values contained in `value`.
|
|
|
|
DEPRECATED: Please use `experimental_local_results` instead.
|
|
|
|
Note: This only returns values on the workers initiated by this client.
|
|
When using a `tf.distribute.Strategy` like
|
|
`tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
|
|
will be its own client, and this function will only return values
|
|
computed on that worker.
|
|
|
|
Args:
|
|
value: A value returned by `experimental_run()`,
|
|
`extended.call_for_each_replica()`, or a variable created in `scope`.
|
|
|
|
Returns:
|
|
A tuple of values contained in `value`. If `value` represents a single
|
|
value, this returns `(value,)`.
|
|
"""
|
|
return self._extended._local_results(value) # pylint: disable=protected-access
|
|
|
|
def experimental_local_results(self, value):
|
|
"""Returns the list of all local per-replica values contained in `value`.
|
|
|
|
Note: This only returns values on the worker initiated by this client.
|
|
When using a `tf.distribute.Strategy` like
|
|
`tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
|
|
will be its own client, and this function will only return values
|
|
computed on that worker.
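For example (a minimal sketch; `step_fn` stands for any replica function, as
in `tf.distribute.Strategy.run`):

```python
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
per_replica_result = strategy.run(step_fn)
# A tuple with one entry per local replica, e.g. (result_on_gpu0, result_on_gpu1).
local_results = strategy.experimental_local_results(per_replica_result)
```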
|
|
|
|
Args:
|
|
value: A value returned by `experimental_run()`, `run()`, or a variable
|
|
created in `scope`.
|
|
|
|
Returns:
|
|
A tuple of values contained in `value` where the i-th element corresponds to
the i-th replica. If `value` represents a single value, this returns
`(value,)`.
|
|
"""
|
|
return self._extended._local_results(value) # pylint: disable=protected-access
|
|
|
|
@doc_controls.do_not_doc_inheritable # DEPRECATED: TF v1.x only
|
|
def group(self, value, name=None):
|
|
"""Shortcut for `tf.group(self.experimental_local_results(value))`."""
|
|
return self._extended._group(value, name) # pylint: disable=protected-access
|
|
|
|
@property
|
|
def num_replicas_in_sync(self):
|
|
"""Returns number of replicas over which gradients are aggregated."""
|
|
return self._extended._num_replicas_in_sync # pylint: disable=protected-access
|
|
|
|
@doc_controls.do_not_doc_inheritable # DEPRECATED: see doc string
|
|
@deprecated(None, "use `update_config_proto` instead.")
|
|
def configure(self,
|
|
session_config=None,
|
|
cluster_spec=None,
|
|
task_type=None,
|
|
task_id=None):
|
|
# pylint: disable=g-doc-return-or-yield,g-doc-args
|
|
"""DEPRECATED: use `update_config_proto` instead.
|
|
|
|
Configures the strategy class.
|
|
|
|
DEPRECATED: This method's functionality has been split into the strategy
|
|
constructor and `update_config_proto`. In the future, we will allow passing
|
|
cluster and config_proto to the constructor to configure the strategy. And
|
|
`update_config_proto` can be used to update the config_proto based on the
|
|
specific strategy.
|
|
"""
|
|
return self._extended._configure( # pylint: disable=protected-access
|
|
session_config, cluster_spec, task_type, task_id)
|
|
|
|
@doc_controls.do_not_generate_docs # DEPRECATED
|
|
def update_config_proto(self, config_proto):
|
|
"""DEPRECATED TF 1.x ONLY."""
|
|
return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access
|
|
|
|
def __deepcopy__(self, memo):
|
|
# First do a regular deepcopy of `self`.
|
|
cls = self.__class__
|
|
result = cls.__new__(cls)
|
|
memo[id(self)] = result
|
|
for k, v in self.__dict__.items():
|
|
setattr(result, k, copy.deepcopy(v, memo))
|
|
# One little fix-up: we want `result._extended` to reference `result`
|
|
# instead of `self`.
|
|
result._extended._container_strategy_weakref = weakref.ref(result) # pylint: disable=protected-access
|
|
return result
|
|
|
|
def __copy__(self):
|
|
raise RuntimeError("Must only deepcopy DistributionStrategy.")
|
|
|
|
@property
|
|
def cluster_resolver(self):
|
|
"""Returns the cluster resolver associated with this strategy.
|
|
|
|
In general, when using a multi-worker `tf.distribute` strategy such as
|
|
`tf.distribute.experimental.MultiWorkerMirroredStrategy` or
|
|
`tf.distribute.TPUStrategy()`, there is a
|
|
`tf.distribute.cluster_resolver.ClusterResolver` associated with the
|
|
strategy used, and such an instance is returned by this property.
|
|
|
|
Strategies that intend to have an associated
|
|
`tf.distribute.cluster_resolver.ClusterResolver` must set the
|
|
relevant attribute, or override this property; otherwise, `None` is returned
|
|
by default. Those strategies should also provide information regarding what
|
|
is returned by this property.
|
|
|
|
Single-worker strategies usually do not have a
|
|
`tf.distribute.cluster_resolver.ClusterResolver`, and in those cases this
|
|
property will return `None`.
|
|
|
|
The `tf.distribute.cluster_resolver.ClusterResolver` may be useful when the
|
|
user needs to access information such as the cluster spec, task type or task
|
|
id. For example,
|
|
|
|
```python
|
|
|
|
os.environ['TF_CONFIG'] = json.dumps({
|
|
'cluster': {
|
|
'worker': ["localhost:12345", "localhost:23456"],
|
|
'ps': ["localhost:34567"]
|
|
},
|
|
'task': {'type': 'worker', 'index': 0}
|
|
})
|
|
|
|
# This implicitly uses TF_CONFIG for the cluster and current task info.
|
|
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
|
|
|
|
...
|
|
|
|
if strategy.cluster_resolver.task_type == 'worker':
|
|
# Perform something that's only applicable on workers. Since we set this
|
|
# as a worker above, this block will run on this particular instance.
|
|
elif strategy.cluster_resolver.task_type == 'ps':
|
|
# Perform something that's only applicable on parameter servers. Since we
|
|
# set this as a worker above, this block will not run on this particular
|
|
# instance.
|
|
```
|
|
|
|
For more information, please see
|
|
`tf.distribute.cluster_resolver.ClusterResolver`'s API docstring.
|
|
|
|
Returns:
|
|
The cluster resolver associated with this strategy. Returns `None` if a
|
|
cluster resolver is not applicable or available in this strategy.
|
|
"""
|
|
if hasattr(self.extended, "_cluster_resolver"):
|
|
return self.extended._cluster_resolver # pylint: disable=protected-access
|
|
return None
|
|
|
|
|
|
@tf_export("distribute.Strategy", v1=[]) # pylint: disable=g-missing-docstring
|
|
class Strategy(StrategyBase):
|
|
|
|
__doc__ = StrategyBase.__doc__
|
|
|
|
def experimental_distribute_values_from_function(self, value_fn):
|
|
"""Generates `tf.distribute.DistributedValues` from `value_fn`.
|
|
|
|
This function is to generate `tf.distribute.DistributedValues` to pass
|
|
into `run`, `reduce`, or other methods that take
|
|
distributed values when not using datasets.
|
|
|
|
Args:
|
|
value_fn: The function to run to generate values. It is called for
|
|
each replica with `tf.distribute.ValueContext` as the sole argument. It
|
|
must return a Tensor or a type that can be converted to a Tensor.
|
|
Returns:
|
|
A `tf.distribute.DistributedValues` containing a value for each replica.
|
|
|
|
Example usage:
|
|
|
|
1. Return constant value per replica:
|
|
|
|
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
|
|
>>> def value_fn(ctx):
|
|
... return tf.constant(1.)
|
|
>>> distributed_values = (
|
|
... strategy.experimental_distribute_values_from_function(
|
|
... value_fn))
|
|
>>> local_result = strategy.experimental_local_results(
|
|
... distributed_values)
|
|
>>> local_result
|
|
(<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
|
|
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
|
|
|
|
2. Distribute values in array based on replica_id: {: value=2}
|
|
|
|
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
|
|
>>> array_value = np.array([3., 2., 1.])
|
|
>>> def value_fn(ctx):
|
|
... return array_value[ctx.replica_id_in_sync_group]
|
|
>>> distributed_values = (
|
|
... strategy.experimental_distribute_values_from_function(
|
|
... value_fn))
|
|
>>> local_result = strategy.experimental_local_results(
|
|
... distributed_values)
|
|
>>> local_result
|
|
(3.0, 2.0)
|
|
|
|
3. Specify values using num_replicas_in_sync: {: value=3}
|
|
|
|
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
|
|
>>> def value_fn(ctx):
|
|
... return ctx.num_replicas_in_sync
|
|
>>> distributed_values = (
|
|
... strategy.experimental_distribute_values_from_function(
|
|
... value_fn))
|
|
>>> local_result = strategy.experimental_local_results(
|
|
... distributed_values)
|
|
>>> local_result
|
|
(2, 2)
|
|
|
|
4. Place values on devices and distribute: {: value=4}
|
|
|
|
```
|
|
strategy = tf.distribute.TPUStrategy()
|
|
worker_devices = strategy.extended.worker_devices
|
|
multiple_values = []
|
|
for i in range(strategy.num_replicas_in_sync):
|
|
with tf.device(worker_devices[i]):
|
|
multiple_values.append(tf.constant(1.0))
|
|
|
|
def value_fn(ctx):
|
|
return multiple_values[ctx.replica_id_in_sync_group]
|
|
|
|
distributed_values = strategy.experimental_distribute_values_from_function(
    value_fn)
|
|
```
|
|
|
|
"""
|
|
return self._extended._experimental_distribute_values_from_function( # pylint: disable=protected-access
|
|
value_fn)
|
|
|
|
def gather(self, value, axis):
|
|
# pylint: disable=line-too-long, protected-access
|
|
"""Gather `value` across replicas along `axis` to the current device.
|
|
|
|
Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like
|
|
object `value`, this API gathers and concatenates `value` across replicas
|
|
along the `axis`-th dimension. The result is copied to the "current" device,
|
|
which would typically be the CPU of the worker on which the program is
|
|
running. For `tf.distribute.TPUStrategy`, it is the first TPU host. For
|
|
multi-client `tf.distribute.MultiWorkerMirroredStrategy`, this is the CPU of
|
|
each worker.
|
|
|
|
This API can only be called in the cross-replica context. For a counterpart
|
|
in the replica context, see `tf.distribute.ReplicaContext.all_gather`.
|
|
|
|
Note: For all strategies except `tf.distribute.TPUStrategy`, the input
|
|
`value` on different replicas must have the same rank, and their shapes must
|
|
be the same in all dimensions except the `axis`-th dimension. In other
|
|
words, their shapes cannot be different in a dimension `d` where `d` does
|
|
not equal the `axis` argument. For example, given a
|
|
`tf.distribute.DistributedValues` with component tensors of shape
|
|
`(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call
|
|
`gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or
|
|
`gather(..., axis=2, ...)`. However, for `tf.distribute.TPUStrategy.gather`,
|
|
all tensors must have exactly the same rank and same shape.
|
|
|
|
Note: Given a `tf.distribute.DistributedValues` `value`, its component
|
|
tensors must have a non-zero rank. Otherwise, consider using
|
|
`tf.expand_dims` before gathering them.
|
|
|
|
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
|
|
>>> # A DistributedValues with component tensor of shape (2, 1) on each replica
|
|
... distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(tf.constant([[1], [2]])))
|
|
>>> @tf.function
|
|
... def run():
|
|
... return strategy.gather(distributed_values, axis=0)
|
|
>>> run()
|
|
<tf.Tensor: shape=(4, 1), dtype=int32, numpy=
|
|
array([[1],
|
|
[2],
|
|
[1],
|
|
[2]], dtype=int32)>
|
|
|
|
|
|
Consider the following example for more combinations:
|
|
|
|
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
|
|
>>> single_tensor = tf.reshape(tf.range(6), shape=(1,2,3))
|
|
>>> distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(single_tensor))
|
|
>>> @tf.function
|
|
... def run(axis):
|
|
... return strategy.gather(distributed_values, axis=axis)
|
|
>>> axis=0
|
|
>>> run(axis)
|
|
<tf.Tensor: shape=(4, 2, 3), dtype=int32, numpy=
|
|
array([[[0, 1, 2],
|
|
[3, 4, 5]],
|
|
[[0, 1, 2],
|
|
[3, 4, 5]],
|
|
[[0, 1, 2],
|
|
[3, 4, 5]],
|
|
[[0, 1, 2],
|
|
[3, 4, 5]]], dtype=int32)>
|
|
>>> axis=1
|
|
>>> run(axis)
|
|
<tf.Tensor: shape=(1, 8, 3), dtype=int32, numpy=
|
|
array([[[0, 1, 2],
|
|
[3, 4, 5],
|
|
[0, 1, 2],
|
|
[3, 4, 5],
|
|
[0, 1, 2],
|
|
[3, 4, 5],
|
|
[0, 1, 2],
|
|
[3, 4, 5]]], dtype=int32)>
|
|
>>> axis=2
|
|
>>> run(axis)
|
|
<tf.Tensor: shape=(1, 2, 12), dtype=int32, numpy=
|
|
array([[[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2],
|
|
[3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5]]], dtype=int32)>
|
|
|
|
|
|
Args:
|
|
value: a `tf.distribute.DistributedValues` instance, e.g. returned by
|
|
`Strategy.run`, to be combined into a single tensor. It can also be a
|
|
regular tensor when used with `tf.distribute.OneDeviceStrategy` or the
|
|
default strategy. The tensors that constitute the DistributedValues
|
|
can only be dense tensors with non-zero rank, NOT a `tf.IndexedSlices`.
|
|
axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
|
|
range [0, rank(value)).
|
|
|
|
Returns:
|
|
A `Tensor` that's the concatenation of `value` across replicas along
|
|
`axis` dimension.
|
|
"""
|
|
# pylint: enable=line-too-long
|
|
error_message = ("tf.distribute.Strategy.gather method requires "
|
|
"cross-replica context, use "
|
|
"get_replica_context().all_gather() instead")
|
|
_require_cross_replica_or_default_context_extended(self._extended,
|
|
error_message)
|
|
dst = device_util.current(
|
|
) or self._extended._default_device or "/device:CPU:0"
|
|
if isinstance(value, indexed_slices.IndexedSlices):
|
|
raise NotImplementedError("gather does not support IndexedSlices")
|
|
return self._extended._local_results(
|
|
self._extended._gather_to(value, dst, axis))[0]
|
|
|
|
|
|
# TF v1.x version has additional deprecated APIs
|
|
@tf_export(v1=["distribute.Strategy"])
|
|
class StrategyV1(StrategyBase):
|
|
"""A list of devices with a state & compute distribution policy.
|
|
|
|
See [the guide](https://www.tensorflow.org/guide/distribute_strategy)
|
|
for overview and examples.
|
|
|
|
Note: Not all `tf.distribute.Strategy` implementations currently support
|
|
TensorFlow's partitioned variables (where a single variable is split across
|
|
multiple devices) at this time.
|
|
"""
|
|
|
|
def make_dataset_iterator(self, dataset):
|
|
"""Makes an iterator for input provided via `dataset`.
|
|
|
|
DEPRECATED: This method is not available in TF 2.x.
|
|
|
|
Data from the given dataset will be distributed evenly across all the
|
|
compute replicas. We will assume that the input dataset is batched by the
|
|
global batch size. With this assumption, we will make a best effort to
|
|
divide each batch across all the replicas (one or more workers).
|
|
If this effort fails, an error will be raised, and the user should instead
use `make_input_fn_iterator`, which provides more control and
does not try to divide a batch across replicas.
|
|
|
|
The user could also use `make_input_fn_iterator` if they want to
|
|
customize which input is fed to which replica/worker etc.
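For example, a minimal sketch of the deprecated TF 1.x flow (`replica_fn` and
`global_batch_size` are assumed to be defined by the caller):

```python
dataset = tf.data.Dataset.range(8).batch(global_batch_size)
iterator = strategy.make_dataset_iterator(dataset)
init_ops = iterator.initialize()  # run these ops in a session in graph mode
replica_results = strategy.experimental_run(replica_fn, iterator)
```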
|
|
|
|
Args:
|
|
dataset: `tf.data.Dataset` that will be distributed evenly across all
|
|
replicas.
|
|
|
|
Returns:
|
|
A `tf.distribute.InputIterator` which returns inputs for each step of the
|
|
computation. User should call `initialize` on the returned iterator.
|
|
"""
|
|
return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access
|
|
|
|
def make_input_fn_iterator(self, # pylint: disable=useless-super-delegation
|
|
input_fn,
|
|
replication_mode=InputReplicationMode.PER_WORKER):
|
|
"""Returns an iterator split across replicas created from an input function.
|
|
|
|
DEPRECATED: This method is not available in TF 2.x.
|
|
|
|
The `input_fn` should take a `tf.distribute.InputContext` object where
|
|
information about batching and input sharding can be accessed:
|
|
|
|
```
|
|
def input_fn(input_context):
|
|
batch_size = input_context.get_per_replica_batch_size(global_batch_size)
|
|
d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
|
|
return d.shard(input_context.num_input_pipelines,
|
|
input_context.input_pipeline_id)
|
|
with strategy.scope():
|
|
iterator = strategy.make_input_fn_iterator(input_fn)
|
|
replica_results = strategy.experimental_run(replica_fn, iterator)
|
|
```
|
|
|
|
The `tf.data.Dataset` returned by `input_fn` should have a per-replica
|
|
batch size, which may be computed using
|
|
`input_context.get_per_replica_batch_size`.
|
|
|
|
Args:
|
|
input_fn: A function taking a `tf.distribute.InputContext` object and
|
|
returning a `tf.data.Dataset`.
|
|
replication_mode: an enum value of `tf.distribute.InputReplicationMode`.
|
|
Only `PER_WORKER` is supported currently, which means there will be
|
|
a single call to `input_fn` per worker. Replicas will dequeue from the
|
|
local `tf.data.Dataset` on their worker.
|
|
|
|
Returns:
|
|
An iterator object that should first be `.initialize()`-ed. It may then
|
|
either be passed to `strategy.experimental_run()` or you can
|
|
`iterator.get_next()` to get the next value to pass to
|
|
`strategy.extended.call_for_each_replica()`.
|
|
"""
|
|
return super(StrategyV1, self).make_input_fn_iterator(
|
|
input_fn, replication_mode)
|
|
|
|
def experimental_make_numpy_dataset(self, numpy_input, session=None):
|
|
"""Makes a tf.data.Dataset for input provided via a numpy array.
|
|
|
|
This avoids adding `numpy_input` as a large constant in the graph,
|
|
and copies the data to the machine or machines that will be processing
|
|
the input.
|
|
|
|
Note that you will likely need to use
|
|
tf.distribute.Strategy.experimental_distribute_dataset
|
|
with the returned dataset to further distribute it with the strategy.
|
|
|
|
Example:
|
|
```
|
|
numpy_input = np.ones([10], dtype=np.float32)
|
|
dataset = strategy.experimental_make_numpy_dataset(numpy_input)
|
|
dist_dataset = strategy.experimental_distribute_dataset(dataset)
|
|
```
|
|
|
|
Args:
|
|
numpy_input: A nest of NumPy input arrays that will be converted into a
|
|
dataset. Note that lists of NumPy arrays are stacked, as that is normal
|
|
`tf.data.Dataset` behavior.
|
|
session: (TensorFlow v1.x graph execution only) A session used for
|
|
initialization.
|
|
|
|
Returns:
|
|
A `tf.data.Dataset` representing `numpy_input`.
|
|
"""
|
|
return self.extended.experimental_make_numpy_dataset(
|
|
numpy_input, session=session)
|
|
|
|
@deprecated(
|
|
None,
|
|
"This method is not available in TF 2.x. Please switch to using `run` instead."
|
|
)
|
|
def experimental_run(self, fn, input_iterator=None): # pylint: disable=useless-super-delegation
|
|
"""Runs ops in `fn` on each replica, with inputs from `input_iterator`.
|
|
|
|
DEPRECATED: This method is not available in TF 2.x. Please switch
|
|
to using `run` instead.
|
|
|
|
When eager execution is enabled, executes ops specified by `fn` on each
|
|
replica. Otherwise, builds a graph to execute the ops on each replica.
|
|
|
|
Each replica will take a single, different input from the inputs provided by
|
|
one `get_next` call on the input iterator.
|
|
|
|
`fn` may call `tf.distribute.get_replica_context()` to access members such
|
|
as `replica_id_in_sync_group`.
|
|
|
|
IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being
|
|
used, and whether eager execution is enabled, `fn` may be called one or more
|
|
times (once for each replica).
|
|
|
|
Args:
|
|
fn: The function to run. The inputs to the function must match the outputs
|
|
of `input_iterator.get_next()`. The output must be a `tf.nest` of
|
|
`Tensor`s.
|
|
input_iterator: (Optional) input iterator from which the inputs are taken.
|
|
|
|
Returns:
|
|
Merged return value of `fn` across replicas. The structure of the return
|
|
value is the same as the return value from `fn`. Each element in the
|
|
structure can either be `PerReplica` (if the values are unsynchronized),
|
|
`Mirrored` (if the values are kept in sync), or `Tensor` (if running on a
|
|
single replica).
|
|
"""
|
|
return super(StrategyV1, self).experimental_run(
|
|
fn, input_iterator)
|
|
|
|
def reduce(self, reduce_op, value, axis=None):
|
|
return super(StrategyV1, self).reduce(reduce_op, value, axis)
|
|
|
|
reduce.__doc__ = StrategyBase.reduce.__doc__
|
|
|
|
def update_config_proto(self, config_proto):
|
|
"""Returns a copy of `config_proto` modified for use with this strategy.
|
|
|
|
DEPRECATED: This method is not available in TF 2.x.
|
|
|
|
The updated config contains settings needed to run the strategy, e.g.
|
|
configuration to run collective ops, or device filters to improve
|
|
distributed training performance.
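For example, a minimal sketch of typical TF 1.x usage (the surrounding
training code is omitted):

```python
session_config = tf.compat.v1.ConfigProto()
session_config = strategy.update_config_proto(session_config)
with tf.compat.v1.Session(config=session_config) as sess:
  ...  # build and run the distributed graph with this session
```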
|
|
|
|
Args:
|
|
config_proto: a `tf.ConfigProto` object.
|
|
|
|
Returns:
|
|
The updated copy of the `config_proto`.
|
|
"""
|
|
return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access
|
|
|
|
|
|
# NOTE(josh11b): For any strategy that needs to support tf.compat.v1,
|
|
# instead descend from StrategyExtendedV1.
|
|
@tf_export("distribute.StrategyExtended", v1=[])
|
|
class StrategyExtendedV2(object):
|
|
"""Additional APIs for algorithms that need to be distribution-aware.
|
|
|
|
Note: For most usage of `tf.distribute.Strategy`, there should be no need to
|
|
call these methods, since TensorFlow libraries (such as optimizers) already
|
|
call these methods when needed on your behalf.
|
|
|
|
|
|
Some common use cases of functions on this page:
|
|
|
|
* _Locality_
|
|
|
|
`tf.distribute.DistributedValues` can have the same _locality_ as a
|
|
_distributed variable_, which leads to a mirrored value residing on the same
|
|
devices as the variable (as opposed to the compute devices). Such values may
|
|
be passed to a call to `tf.distribute.StrategyExtended.update` to update the
|
|
value of a variable. You may use
|
|
`tf.distribute.StrategyExtended.colocate_vars_with` to give a variable the
|
|
same locality as another variable. You may convert a "PerReplica" value to a
|
|
variable's locality by using `tf.distribute.StrategyExtended.reduce_to` or
|
|
`tf.distribute.StrategyExtended.batch_reduce_to`.
|
|
|
|
* _How to update a distributed variable_
|
|
|
|
A distributed variable is a variable created on multiple devices. As discussed
in the [glossary](https://www.tensorflow.org/api_docs/python/tf/distribute),
mirrored variables and SyncOnRead variables are two examples. The standard
|
|
pattern for updating distributed variables is to:
|
|
|
|
1. In your function passed to `tf.distribute.Strategy.run`,
|
|
compute a list of (update, variable) pairs. For example, the update might
|
|
be a gradient of the loss with respect to the variable.
|
|
2. Switch to cross-replica mode by calling
|
|
`tf.distribute.get_replica_context().merge_call()` with the updates and
|
|
variables as arguments.
|
|
3. Call
|
|
`tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v)`
|
|
(for one variable) or `tf.distribute.StrategyExtended.batch_reduce_to`
|
|
(for a list of variables) to sum the updates.
|
|
4. Call `tf.distribute.StrategyExtended.update(v)` for each variable to update
|
|
its value.
|
|
|
|
Steps 2 through 4 are done automatically by class
|
|
`tf.keras.optimizers.Optimizer` if you call its
|
|
`tf.keras.optimizers.Optimizer.apply_gradients` method in a replica context.
|
|
|
|
In fact, a higher-level way to update a distributed variable is to call
`assign` on the variable as you would with a regular `tf.Variable`.
|
|
You can call the method in both _replica context_ and _cross-replica context_.
|
|
For a _mirrored variable_, calling `assign` in _replica context_ requires you
|
|
to specify the `aggregation` type in the variable constructor. In that case,
|
|
the context switching and sync described in steps 2 through 4 are handled for
|
|
you. If you call `assign` on _mirrored variable_ in _cross-replica context_,
|
|
you can only assign a single value or assign values from another mirrored
|
|
variable or a mirrored `tf.distribute.DistributedValues`. For a _SyncOnRead
|
|
variable_, in _replica context_, you can simply call `assign` on it and no
|
|
aggregation happens under the hood. In _cross-replica context_, you can only
|
|
assign a single value to a SyncOnRead variable. One example case is restoring
|
|
from a checkpoint: if the `aggregation` type of the variable is
|
|
`tf.VariableAggregation.SUM`, it is assumed that replica values were added
|
|
before checkpointing, so at the time of restoring, the value is divided by
|
|
the number of replicas and then assigned to each replica; if the `aggregation`
|
|
type is `tf.VariableAggregation.MEAN`, the value is assigned to each replica
|
|
directly.
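For example, a minimal sketch of updating a mirrored variable via `assign` in
replica context (note that the `aggregation` argument is required for this
pattern):

```python
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  v = tf.Variable(0.0, aggregation=tf.VariableAggregation.MEAN)

def step_fn():
  # Each replica assigns its own value; the MEAN aggregation reconciles them.
  replica_id = tf.distribute.get_replica_context().replica_id_in_sync_group
  v.assign(tf.cast(replica_id, tf.float32))

strategy.run(step_fn)
```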
|
|
|
|
"""
|
|
|
|
def __init__(self, container_strategy):
|
|
self._container_strategy_weakref = weakref.ref(container_strategy)
|
|
self._default_device = None
|
|
# This property is used to determine if we should set drop_remainder=True
|
|
# when creating Datasets from numpy array inputs.
|
|
self._require_static_shapes = False
|
|
|
|
def _resource_creator_scope(self):
|
|
"""Returns one or a list of ops.resource_creator_scope for some Strategy."""
|
|
return None
|
|
|
|
def _default_device_scope(self):
|
|
if self._default_device:
|
|
return ops.device(self._default_device)
|
|
return None
|
|
|
|
def _container_strategy(self):
|
|
"""Get the containing `tf.distribute.Strategy`.
|
|
|
|
This should not generally be needed except when creating a new
|
|
`ReplicaContext` and to validate that the caller is in the correct
|
|
`scope()`.
|
|
|
|
Returns:
|
|
The `tf.distribute.Strategy` such that `strategy.extended` is `self`.
|
|
"""
|
|
container_strategy = self._container_strategy_weakref()
|
|
assert container_strategy is not None
|
|
return container_strategy
|
|
|
|
def _scope(self, strategy):
|
|
"""Implementation of tf.distribute.Strategy.scope()."""
|
|
|
|
def creator_with_resource_vars(next_creator, **kwargs):
|
|
"""Variable creator to use in `_CurrentDistributionContext`."""
|
|
if ops.inside_function():
|
|
if_graph_building = "graph_building"
|
|
else:
|
|
if_graph_building = "not_graph_building"
|
|
|
|
with monitoring.MonitoredTimer(distributed_variable_creation_time_counter.get_cell(strategy.__class__.__name__, if_graph_building)):
|
|
_require_strategy_scope_extended(self)
|
|
kwargs["use_resource"] = True
|
|
kwargs["distribute_strategy"] = strategy
|
|
|
|
# Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid
|
|
# dereferencing a `Tensor` that is without a `name`. We still need to
|
|
# propagate the metadata it's holding.
|
|
if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
|
|
checkpoint_restore_uid = kwargs[
|
|
"initial_value"].checkpoint_position.restore_uid
|
|
kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
|
|
elif isinstance(kwargs["initial_value"],
|
|
trackable.CheckpointInitialValueCallable):
|
|
checkpoint_restore_uid = kwargs[
|
|
"initial_value"].checkpoint_position.restore_uid
|
|
elif (isinstance(kwargs["initial_value"], functools.partial) and
|
|
isinstance(kwargs["initial_value"].func,
|
|
trackable.CheckpointInitialValueCallable)):
|
|
# Some libraries (e.g., Keras) create partial function out of
|
|
# initializer to bind shape/dtype, for example:
|
|
# initial_val = functools.partial(initializer, shape, dtype=dtype)
|
|
# Therefore to get the restore_uid we need to examine the "func" of
|
|
# the partial function.
|
|
checkpoint_restore_uid = kwargs[
|
|
"initial_value"].func.checkpoint_position.restore_uid
|
|
else:
|
|
checkpoint_restore_uid = None
|
|
|
|
created = self._create_variable(next_creator, **kwargs)
|
|
|
|
if checkpoint_restore_uid is not None:
|
|
# pylint: disable=protected-access
|
|
# Let the checkpointing infrastructure know that the variable was
|
|
# already restored so it doesn't waste memory loading the value again.
|
|
# In this case of CheckpointInitialValueCallable this may already be
|
|
# done by the final variable creator, but it doesn't hurt to do it
|
|
# again.
|
|
created._maybe_initialize_trackable()
|
|
created._update_uid = checkpoint_restore_uid
|
|
# pylint: enable=protected-access
|
|
return created
|
|
|
|
def distributed_getter(getter, *args, **kwargs):
|
|
if not self._allow_variable_partition():
|
|
if kwargs.pop("partitioner", None) is not None:
|
|
tf_logging.log_first_n(
|
|
tf_logging.WARN, "Partitioned variables are disabled when using "
|
|
"current tf.distribute.Strategy.", 1)
|
|
return getter(*args, **kwargs)
|
|
|
|
return _CurrentDistributionContext(
|
|
strategy,
|
|
variable_scope.variable_creator_scope(creator_with_resource_vars),
|
|
variable_scope.variable_scope(
|
|
variable_scope.get_variable_scope(),
|
|
custom_getter=distributed_getter,
|
|
),
|
|
strategy.extended._resource_creator_scope(), # pylint: disable=protected-access
|
|
self._default_device_scope(),
|
|
)
|
|
|
|
def _allow_variable_partition(self):
|
|
return False
|
|
|
|
def _create_variable(self, next_creator, **kwargs):
|
|
# Note: should support "colocate_with" argument.
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
def variable_created_in_scope(self, v):
|
|
"""Tests whether `v` was created while this strategy scope was active.
|
|
|
|
Variables created inside the strategy scope are "owned" by it:
|
|
|
|
>>> strategy = tf.distribute.MirroredStrategy()
|
|
>>> with strategy.scope():
|
|
... v = tf.Variable(1.)
|
|
>>> strategy.extended.variable_created_in_scope(v)
|
|
True
|
|
|
|
Variables created outside the strategy are not owned by it:
|
|
|
|
>>> strategy = tf.distribute.MirroredStrategy()
|
|
>>> v = tf.Variable(1.)
|
|
>>> strategy.extended.variable_created_in_scope(v)
|
|
False
|
|
|
|
Args:
|
|
v: A `tf.Variable` instance.
|
|
|
|
Returns:
|
|
True if `v` was created inside the scope, False if not.
|
|
"""
|
|
return v._distribute_strategy == self._container_strategy_weakref() # pylint: disable=protected-access
|
|
|
|
def colocate_vars_with(self, colocate_with_variable):
|
|
"""Scope that controls which devices variables will be created on.
|
|
|
|
No operations should be added to the graph inside this scope; it
|
|
should only be used when creating variables (some implementations
|
|
work by changing variable creation, others work by using a
|
|
tf.compat.v1.colocate_with() scope).
|
|
|
|
This may only be used inside `self.scope()`.
|
|
|
|
Example usage:
|
|
|
|
```
|
|
with strategy.scope():
|
|
var1 = tf.Variable(...)
|
|
with strategy.extended.colocate_vars_with(var1):
|
|
# var2 and var3 will be created on the same device(s) as var1
|
|
var2 = tf.Variable(...)
|
|
var3 = tf.Variable(...)
|
|
|
|
def fn(v1, v2, v3):
|
|
# operates on v1 from var1, v2 from var2, and v3 from var3
|
|
|
|
# `fn` runs on every device `var1` is on, `var2` and `var3` will be there
|
|
# too.
|
|
strategy.extended.update(var1, fn, args=(var2, var3))
|
|
```
|
|
|
|
Args:
|
|
colocate_with_variable: A variable created in this strategy's `scope()`.
|
|
Variables created while in the returned context manager will be on the
|
|
same set of devices as `colocate_with_variable`.
|
|
|
|
Returns:
|
|
A context manager.
|
|
"""
|
|
|
|
def create_colocated_variable(next_creator, **kwargs):
|
|
_require_strategy_scope_extended(self)
|
|
kwargs["use_resource"] = True
|
|
kwargs["colocate_with"] = colocate_with_variable
|
|
return next_creator(**kwargs)
|
|
|
|
_require_strategy_scope_extended(self)
|
|
self._validate_colocate_with_variable(colocate_with_variable)
|
|
return variable_scope.variable_creator_scope(create_colocated_variable)
|
|
|
|
def _validate_colocate_with_variable(self, colocate_with_variable):
|
|
"""Validate `colocate_with_variable` argument to `colocate_vars_with`."""
|
|
pass
|
|
|
|
def _make_dataset_iterator(self, dataset):
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
def _make_input_fn_iterator(self, input_fn, replication_mode):
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
def _experimental_distribute_dataset(self, dataset, options):
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
def _distribute_datasets_from_function(self, dataset_fn, options):
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
def _experimental_distribute_values_from_function(self, value_fn):
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
def _reduce(self, reduce_op, value):
|
|
# Default implementation until we have an implementation for each strategy.
|
|
dst = device_util.current() or self._default_device or "/device:CPU:0"
|
|
return self._local_results(self.reduce_to(reduce_op, value, dst))[0]
|
|
|
|
def reduce_to(self, reduce_op, value, destinations, options=None):
|
|
"""Combine (via e.g. sum or mean) values across replicas.
|
|
|
|
`reduce_to` aggregates `tf.distribute.DistributedValues` and distributed
|
|
variables. It supports both dense values and `tf.IndexedSlices`.
|
|
|
|
This API currently can only be called in cross-replica context. Other
|
|
variants to reduce values across replicas are:
|
|
* `tf.distribute.StrategyExtended.batch_reduce_to`: the batch version of
|
|
this API.
|
|
* `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API
|
|
in replica context. It supports both batched and non-batched all-reduce.
|
|
* `tf.distribute.Strategy.reduce`: a more convenient method to reduce
|
|
to the host in cross-replica context.
|
|
|
|
`destinations` specifies where to reduce the value to, e.g. "GPU:0". You can
|
|
also pass in a `Tensor`, and the destinations will be the device of that
|
|
tensor. For all-reduce, pass the same to `value` and `destinations`.
|
|
|
|
It can be used in `tf.distribute.ReplicaContext.merge_call` to write code
|
|
that works for all `tf.distribute.Strategy`.
|
|
|
|
@tf.function
|
|
def step_fn(var):
|
|
|
|
def merge_fn(strategy, value, var):
|
|
# All-reduce the value. Note that `value` here is a
|
|
# `tf.distribute.DistributedValues`.
|
|
reduced = strategy.extended.reduce_to(tf.distribute.ReduceOp.SUM,
|
|
value, destinations=var)
|
|
strategy.extended.update(var, lambda var, value: var.assign(value),
|
|
args=(reduced,))
|
|
|
|
value = tf.identity(1.)
|
|
tf.distribute.get_replica_context().merge_call(merge_fn,
|
|
args=(value, var))
|
|
|
|
def run(strategy):
|
|
with strategy.scope():
|
|
v = tf.Variable(0.)
|
|
strategy.run(step_fn, args=(v,))
|
|
return v
|
|
|
|
run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]))
|
|
MirroredVariable:{
|
|
0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>,
|
|
1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0>
|
|
}
|
|
run(tf.distribute.experimental.CentralStorageStrategy(
|
|
compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0"))
|
|
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>
|
|
run(tf.distribute.OneDeviceStrategy("GPU:0"))
|
|
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
|
|
|
|
Args:
|
|
reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
|
|
be combined. Allows using string representation of the enum such as
|
|
"SUM", "MEAN".
|
|
value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` like object.
|
|
destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
|
|
`tf.Tensor` alike object, or a device string. It specifies the devices
|
|
to reduce to. To perform an all-reduce, pass the same to `value` and
|
|
`destinations`. Note that if it's a `tf.Variable`, the value is reduced
|
|
to the devices of that variable, and this method doesn't update the
|
|
variable.
|
|
options: a `tf.distribute.experimental.CommunicationOptions`. Options to
|
|
perform collective operations. This overrides the default options if the
|
|
`tf.distribute.Strategy` takes one in the constructor. See
|
|
`tf.distribute.experimental.CommunicationOptions` for details of the
|
|
options.
|
|
|
|
Returns:
|
|
A tensor or value reduced to `destinations`.
|
|
"""
|
|
with monitoring.MonitoredTimer(
|
|
distributed_api_time_counter.get_cell(self.__class__.__name__, "Reduce_to_eagerly")
|
|
) if not ops.inside_function() else contextlib.nullcontext():
|
|
if options is None:
|
|
options = collective_util.Options()
|
|
_require_cross_replica_or_default_context_extended(self)
|
|
assert not isinstance(destinations, (list, tuple))
|
|
assert not isinstance(reduce_op, variable_scope.VariableAggregation)
|
|
if isinstance(reduce_op, six.string_types):
|
|
reduce_op = reduce_util.ReduceOp(reduce_op.upper())
|
|
assert (reduce_op == reduce_util.ReduceOp.SUM or
|
|
reduce_op == reduce_util.ReduceOp.MEAN)
|
|
return self._reduce_to(reduce_op, value, destinations, options)
|
|
|
|
def _reduce_to(self, reduce_op, value, destinations, options):
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
def batch_reduce_to(self, reduce_op, value_destination_pairs, options=None):
|
|
"""Combine multiple `reduce_to` calls into one for faster execution.
|
|
|
|
Similar to `reduce_to`, but accepts a list of (value, destinations) pairs.
|
|
It's more efficient than reducing each value separately.
|
|
|
|
This API currently can only be called in cross-replica context. Other
|
|
variants to reduce values across replicas are:
|
|
* `tf.distribute.StrategyExtended.reduce_to`: the non-batch version of
|
|
this API.
|
|
* `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API
|
|
in replica context. It supports both batched and non-batched all-reduce.
|
|
* `tf.distribute.Strategy.reduce`: a more convenient method to reduce
|
|
to the host in cross-replica context.
|
|
|
|
See `reduce_to` for more information.
|
|
|
|
@tf.function
|
|
def step_fn(var):
|
|
|
|
def merge_fn(strategy, value, var):
|
|
# All-reduce the value. Note that `value` here is a
|
|
# `tf.distribute.DistributedValues`.
|
|
reduced = strategy.extended.batch_reduce_to(
|
|
tf.distribute.ReduceOp.SUM, [(value, var)])[0]
|
|
strategy.extended.update(var, lambda var, value: var.assign(value),
|
|
args=(reduced,))
|
|
|
|
value = tf.identity(1.)
|
|
tf.distribute.get_replica_context().merge_call(merge_fn,
|
|
args=(value, var))
|
|
|
|
def run(strategy):
|
|
with strategy.scope():
|
|
v = tf.Variable(0.)
|
|
strategy.run(step_fn, args=(v,))
|
|
return v
|
|
|
|
run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]))
|
|
MirroredVariable:{
|
|
0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>,
|
|
1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0>
|
|
}
|
|
run(tf.distribute.experimental.CentralStorageStrategy(
|
|
compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0"))
|
|
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>
|
|
run(tf.distribute.OneDeviceStrategy("GPU:0"))
|
|
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
|
|
|
|
Args:
|
|
reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
|
|
be combined. Allows using string representation of the enum such as
|
|
"SUM", "MEAN".
|
|
value_destination_pairs: a sequence of (value, destinations) pairs. See
|
|
`tf.distribute.Strategy.reduce_to` for descriptions.
|
|
options: a `tf.distribute.experimental.CommunicationOptions`. Options to
|
|
perform collective operations. This overrides the default options if the
|
|
`tf.distribute.Strategy` takes one in the constructor. See
|
|
`tf.distribute.experimental.CommunicationOptions` for details of the
|
|
options.
|
|
|
|
Returns:
|
|
A list of reduced values, one per pair in `value_destination_pairs`.
|
|
"""
|
|
with monitoring.MonitoredTimer(
|
|
distributed_api_time_counter.get_cell(self.__class__.__name__, "Batch_reduce_to_eagerly")
|
|
) if not ops.inside_function() else contextlib.nullcontext():
|
|
if options is None:
|
|
options = collective_util.Options()
|
|
_require_cross_replica_or_default_context_extended(self)
|
|
assert not isinstance(reduce_op, variable_scope.VariableAggregation)
|
|
if isinstance(reduce_op, six.string_types):
|
|
reduce_op = reduce_util.ReduceOp(reduce_op.upper())
|
|
return self._batch_reduce_to(reduce_op, value_destination_pairs, options)
|
|
|
|
def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
|
|
return [
|
|
self.reduce_to(reduce_op, t, destinations=v, options=options)
|
|
for t, v in value_destination_pairs
|
|
]
|
|
|
|
def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
|
|
"""All-reduce `value` across all replicas so that all get the final result.
|
|
|
|
If `value` is a nested structure of tensors, all-reduces of these tensors
|
|
will be batched when possible. `options` can be set to hint the batching
|
|
behavior.
|
|
|
|
This API must be called in a replica context.
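For illustration, a minimal sketch of the public replica-context counterpart,
`tf.distribute.ReplicaContext.all_reduce`, with a nested structure (assumes
two local GPU replicas; not taken from the official examples):

```python
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

def step_fn():
  ctx = tf.distribute.get_replica_context()
  # A nested structure of tensors: the flattened tensors are all-reduced in
  # a batch and the results are repacked into the same structure.
  value = {"loss": tf.identity(1.0), "count": tf.identity(2.0)}
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value)

# Each replica receives {"loss": 2.0, "count": 4.0}.
per_replica_result = strategy.run(step_fn)
```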
|
|
|
|
Args:
|
|
reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
|
|
be combined.
|
|
value: Value to be reduced. A tensor or a nested structure of tensors.
|
|
options: A `tf.distribute.experimental.CommunicationOptions`. Options to
|
|
perform collective operations. This overrides the default options if the
|
|
`tf.distribute.Strategy` takes one in the constructor.
|
|
|
|
Returns:
|
|
A tensor or a nested structure of tensors with the reduced values. The
|
|
structure is the same as `value`.
|
|
"""
|
|
if options is None:
|
|
options = collective_util.Options()
|
|
replica_context = get_replica_context()
|
|
assert replica_context, (
|
|
"`StrategyExtended._replica_ctx_all_reduce` must be called in"
|
|
" a replica context")
|
|
|
|
def merge_fn(_, flat_value):
|
|
return self.batch_reduce_to(reduce_op, [(v, v) for v in flat_value],
|
|
options)
|
|
|
|
reduced = replica_context.merge_call(merge_fn, args=(nest.flatten(value),))
|
|
return nest.pack_sequence_as(value, reduced)
|
|
|
|
def _replica_ctx_update(self, var, fn, args=(), kwargs=None, group=True):
|
|
"""Run `fn` with `args` and `kwargs` to update `var`."""
|
|
# This method is called by ReplicaContext.update. Strategies who'd like to
|
|
# remove merge_call in this path should override this method.
|
|
replica_context = get_replica_context()
|
|
if not replica_context:
|
|
raise ValueError("`StrategyExtended._replica_ctx_update` must be called "
|
|
"in a replica context.")
|
|
|
|
def merge_fn(_, *merged_args, **merged_kwargs):
|
|
return self.update(var, fn, merged_args, merged_kwargs, group=group)
|
|
|
|
return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)
|
|
|
|
def _gather_to(self, value, destinations, axis, options=None):
|
|
"""Gather `value` across replicas along axis-th dimension to `destinations`.
|
|
|
|
`gather_to` gathers `tf.distribute.DistributedValues` or `tf.Tensor`-like
|
|
object, along `axis`-th dimension. It supports only dense tensors but NOT
|
|
sparse tensors. This API can only be called in the cross-replica context.
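For illustration, a minimal sketch using the public counterpart
`tf.distribute.Strategy.gather` (assumes two replicas; the value function
below is only an example):

```python
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

def value_fn(ctx):
  # Each replica contributes a length-1 vector holding its replica id.
  return tf.constant([float(ctx.replica_id_in_sync_group)])

distributed_values = (
    strategy.experimental_distribute_values_from_function(value_fn))
# Gathers the per-replica vectors along axis 0 onto the host: [0., 1.].
gathered = strategy.gather(distributed_values, axis=0)
```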
|
|
|
|
Args:
|
|
value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` like object.
|
|
destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
|
|
`tf.Tensor` alike object, or a device string. It specifies the devices
|
|
to gather to. To perform an all-gather, pass the same to `value` and
`destinations`. Note that if it's a `tf.Variable`, the value is gathered
|
|
to the devices of that variable, and this method doesn't update the
|
|
variable.
|
|
axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
|
|
range [0, rank(value)).
|
|
options: a `tf.distribute.experimental.CommunicationOptions`. Options to
|
|
perform collective operations. This overrides the default options if the
|
|
`tf.distribute.Strategy` takes one in the constructor. See
|
|
`tf.distribute.experimental.CommunicationOptions` for details of the
|
|
options.
|
|
|
|
Returns:
|
|
A tensor or value gathered to `destinations`.
|
|
"""
|
|
_require_cross_replica_or_default_context_extended(self)
|
|
assert not isinstance(destinations, (list, tuple))
|
|
if options is None:
|
|
options = collective_util.Options()
|
|
return self._gather_to_implementation(value, destinations, axis, options)
|
|
|
|
def _gather_to_implementation(self, value, destinations, axis, options):
|
|
raise NotImplementedError("_gather_to must be implemented in descendants")
|
|
|
|
def _batch_gather_to(self, value_destination_pairs, axis, options=None):
|
|
_require_cross_replica_or_default_context_extended(self)
|
|
if options is None:
|
|
options = collective_util.Options()
|
|
return [
|
|
self._gather_to(t, destinations=v, axis=axis, options=options)
|
|
for t, v in value_destination_pairs
|
|
]
|
|
|
|
def update(self, var, fn, args=(), kwargs=None, group=True):
|
|
"""Run `fn` to update `var` using inputs mirrored to the same devices.
|
|
|
|
`tf.distribute.StrategyExtended.update` takes a distributed variable `var`
|
|
to be updated, an update function `fn`, and `args` and `kwargs` for `fn`. It
|
|
applies `fn` to each component variable of `var` and passes corresponding
|
|
values from `args` and `kwargs`. Neither `args` nor `kwargs` may contain
|
|
per-replica values. If they contain mirrored values, they will be unwrapped
|
|
before calling `fn`. For example, `fn` can be `assign_add` and `args` can be
|
|
a mirrored DistributedValues where each component contains the value to be
|
|
added to this mirrored variable `var`. Calling `update` will call
|
|
`assign_add` on each component variable of `var` with the corresponding
|
|
tensor value on that device.
|
|
|
|
Example usage:
|
|
|
|
```python
|
|
strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])  # With 2 devices
|
|
with strategy.scope():
|
|
v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM)
|
|
def update_fn(v):
|
|
return v.assign(1.0)
|
|
result = strategy.extended.update(v, update_fn)
|
|
# result is
|
|
# Mirrored:{
|
|
# 0: tf.Tensor(1.0, shape=(), dtype=float32),
|
|
# 1: tf.Tensor(1.0, shape=(), dtype=float32)
|
|
# }
|
|
```
|
|
|
|
If `var` is mirrored across multiple devices, then this method implements
|
|
the following logic:
|
|
|
|
```python
|
|
results = {}
|
|
for device, v in var:
|
|
with tf.device(device):
|
|
# args and kwargs will be unwrapped if they are mirrored.
|
|
results[device] = fn(v, *args, **kwargs)
|
|
return merged(results)
|
|
```
|
|
|
|
Otherwise, this method returns `fn(var, *args, **kwargs)` colocated with
|
|
`var`.
|
|
|
|
Args:
|
|
var: Variable, possibly mirrored to multiple devices, to operate on.
|
|
fn: Function to call. Should take the variable as the first argument.
|
|
args: Tuple or list. Additional positional arguments to pass to `fn()`.
|
|
kwargs: Dict with keyword arguments to pass to `fn()`.
|
|
group: Boolean. Defaults to True. If False, the return value will be
|
|
unwrapped.
|
|
|
|
Returns:
|
|
By default, the merged return value of `fn` across all replicas. The
|
|
merged result has dependencies to make sure that if it is evaluated at
|
|
all, the side effects (updates) will happen on every replica. If instead
|
|
"group=False" is specified, this function will return a nest of lists
|
|
where each list has an element per replica, and the caller is responsible
|
|
for ensuring all elements are executed.
|
|
"""
|
|
# TODO(b/178944108): Update the documentation to reflect the fact that
|
|
# `update` can be called in a replica context.
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
replica_context = get_replica_context()
|
|
# pylint: disable=protected-access
|
|
if (replica_context is None or replica_context is
|
|
_get_default_replica_context()):
|
|
fn = autograph.tf_convert(
|
|
fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
|
|
with self._container_strategy().scope():
|
|
return self._update(var, fn, args, kwargs, group)
|
|
else:
|
|
return self._replica_ctx_update(
|
|
var, fn, args=args, kwargs=kwargs, group=group)
|
|
|
|
def _update(self, var, fn, args, kwargs, group):
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
def _local_results(self, val):
|
|
"""Returns local results per replica as a tuple."""
|
|
if isinstance(val, ds_types.DistributedValues):
|
|
return val._values # pylint: disable=protected-access
|
|
|
|
if nest.is_nested(val):
|
|
replica_values = []
|
|
|
|
def get_values(x, index):
|
|
if isinstance(x, ds_types.DistributedValues):
|
|
return x._values[index] # pylint: disable=protected-access
|
|
return x
|
|
|
|
for i in range(len(self.worker_devices)):
|
|
replica_values.append(
|
|
nest.map_structure(
|
|
lambda x: get_values(x, i), # pylint: disable=cell-var-from-loop
|
|
val))
|
|
return tuple(replica_values)
|
|
return (val,)
|
|
|
|
def value_container(self, value):
|
|
"""Returns the container that this per-replica `value` belongs to.
|
|
|
|
Args:
|
|
value: A value returned by `run()` or a variable created in `scope()`.
|
|
|
|
Returns:
|
|
A container that `value` belongs to.
|
|
If value does not belong to any container (including the case of
|
|
container having been destroyed), returns the value itself.
|
|
`value in experimental_local_results(value_container(value))` will
|
|
always be true.
|
|
"""
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
def _group(self, value, name=None):
|
|
"""Implementation of `group`."""
|
|
value = nest.flatten(self._local_results(value))
|
|
|
|
if len(value) != 1 or name is not None:
|
|
return control_flow_ops.group(value, name=name)
|
|
# Special handling for the common case of one op.
|
|
v, = value
|
|
if hasattr(v, "op"):
|
|
v = v.op
|
|
return v
|
|
|
|
@property
|
|
def experimental_require_static_shapes(self):
|
|
"""Returns `True` if static shape is required; `False` otherwise."""
|
|
return self._require_static_shapes
|
|
|
|
@property
|
|
def _num_replicas_in_sync(self):
|
|
"""Returns number of replicas over which gradients are aggregated."""
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
@property
|
|
def worker_devices(self):
|
|
"""Returns the tuple of all devices used to for compute replica execution.
|
|
"""
|
|
# TODO(josh11b): More docstring
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
@property
|
|
def parameter_devices(self):
|
|
"""Returns the tuple of all devices used to place variables."""
|
|
# TODO(josh11b): More docstring
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
def _configure(self,
|
|
session_config=None,
|
|
cluster_spec=None,
|
|
task_type=None,
|
|
task_id=None):
|
|
"""Configures the strategy class."""
|
|
del session_config, cluster_spec, task_type, task_id
|
|
|
|
def _update_config_proto(self, config_proto):
|
|
return copy.deepcopy(config_proto)
|
|
|
|
def _in_multi_worker_mode(self):
|
|
"""Whether this strategy indicates working in multi-worker settings.
|
|
|
|
Multi-worker training refers to the setup where the training is
|
|
distributed across multiple workers, as opposed to the case where
|
|
only a local process performs the training. This function is
|
|
used by higher-level APIs such as Keras' `model.fit()` to infer
|
|
for example whether or not a distribute coordinator should be run,
|
|
and thus TensorFlow servers should be started for communication
|
|
with other servers in the cluster, or whether or not saving/restoring
|
|
checkpoints is relevant for preemption fault tolerance.
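As an illustrative sketch, a caller might branch on this flag roughly as
follows (the branch bodies here are placeholders, not TensorFlow APIs):

```python
strategy = tf.distribute.get_strategy()
if strategy.extended._in_multi_worker_mode():  # pylint: disable=protected-access
  # A real caller would e.g. start a TF server and enable fault-tolerant
  # checkpointing here.
  needs_coordinator = True
else:
  needs_coordinator = False
```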
|
|
|
|
Subclasses should override this to indicate whether the strategy is
currently in a multi-worker setup.
|
|
|
|
Experimental. Signature and implementation are subject to change.
|
|
"""
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
|
|
@tf_export(v1=["distribute.StrategyExtended"]) # pylint: disable=missing-docstring
|
|
class StrategyExtendedV1(StrategyExtendedV2):
|
|
|
|
__doc__ = StrategyExtendedV2.__doc__
|
|
|
|
def experimental_make_numpy_dataset(self, numpy_input, session=None):
|
|
"""Makes a dataset for input provided via a numpy array.
|
|
|
|
This avoids adding `numpy_input` as a large constant in the graph,
|
|
and copies the data to the machine or machines that will be processing
|
|
the input.
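For illustration, a minimal sketch under TF 1.x graph execution (the array
contents and devices are arbitrary):

```python
import numpy as np

tf.compat.v1.disable_eager_execution()
strategy = tf.compat.v1.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
features = np.arange(16, dtype=np.float32).reshape(8, 2)
with tf.compat.v1.Session() as sess:
  dataset = strategy.extended.experimental_make_numpy_dataset(
      features, session=sess)
  dataset = dataset.batch(4)
```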
|
|
|
|
Args:
|
|
numpy_input: A nest of NumPy input arrays that will be distributed evenly
|
|
across all replicas. Note that lists of NumPy arrays are stacked, as
|
|
that is normal `tf.data.Dataset` behavior.
|
|
session: (TensorFlow v1.x graph execution only) A session used for
|
|
initialization.
|
|
|
|
Returns:
|
|
A `tf.data.Dataset` representing `numpy_input`.
|
|
"""
|
|
_require_cross_replica_or_default_context_extended(self)
|
|
return self._experimental_make_numpy_dataset(numpy_input, session=session)
|
|
|
|
def _experimental_make_numpy_dataset(self, numpy_input, session):
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
def broadcast_to(self, tensor, destinations):
|
|
"""Mirror a tensor on one device to all worker devices.
|
|
|
|
Args:
|
|
tensor: A Tensor value to broadcast.
|
|
destinations: A mirrored variable or device string specifying the
|
|
destination devices to copy `tensor` to.
|
|
|
|
Returns:
|
|
A value mirrored to `destinations` devices.
|
|
"""
|
|
assert destinations is not None # from old strategy.broadcast()
|
|
# TODO(josh11b): More docstring
|
|
_require_cross_replica_or_default_context_extended(self)
|
|
assert not isinstance(destinations, (list, tuple))
|
|
return self._broadcast_to(tensor, destinations)
|
|
|
|
def _broadcast_to(self, tensor, destinations):
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
@deprecated(None, "please use `run` instead.")
|
|
def experimental_run_steps_on_iterator(self,
|
|
fn,
|
|
iterator,
|
|
iterations=1,
|
|
initial_loop_values=None):
|
|
"""DEPRECATED: please use `run` instead.
|
|
|
|
Run `fn` with input from `iterator` for `iterations` times.
|
|
|
|
This method can be used to run a step function for training a number of
|
|
times using input from a dataset.
|
|
|
|
Args:
|
|
fn: function to run using this distribution strategy. The function must
|
|
have the following signature: `def fn(context, inputs)`. `context` is an
|
|
instance of `MultiStepContext` that will be passed when `fn` is run.
|
|
`context` can be used to specify the outputs to be returned from `fn`
|
|
by calling `context.set_last_step_output`. It can also be used to
|
|
capture non-tensor outputs via `context.set_non_tensor_output`. See
|
|
`MultiStepContext` documentation for more information. `inputs` will
|
|
have same type/structure as `iterator.get_next()`. Typically, `fn`
|
|
will use `call_for_each_replica` method of the strategy to distribute
|
|
the computation over multiple replicas.
|
|
iterator: Iterator of a dataset that represents the input for `fn`. The
|
|
caller is responsible for initializing the iterator as needed.
|
|
iterations: (Optional) Number of iterations that `fn` should be run.
|
|
Defaults to 1.
|
|
initial_loop_values: (Optional) Initial values to be passed into the
|
|
loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove
|
|
initial_loop_values argument when we have a mechanism to infer the
|
|
outputs of `fn`.
|
|
|
|
Returns:
|
|
Returns the `MultiStepContext` object which has the following properties,
|
|
among other things:
|
|
- run_op: An op that runs `fn` `iterations` times.
|
|
- last_step_outputs: A dictionary containing tensors set using
|
|
`context.set_last_step_output`. Evaluating this returns the value of
|
|
the tensors after the last iteration.
|
|
- non_tensor_outputs: A dictionary containing anything that was set by
|
|
`fn` by calling `context.set_non_tensor_output`.
|
|
"""
|
|
_require_cross_replica_or_default_context_extended(self)
|
|
with self._container_strategy().scope():
|
|
return self._experimental_run_steps_on_iterator(fn, iterator, iterations,
|
|
initial_loop_values)
|
|
|
|
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
|
|
initial_loop_values):
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
def call_for_each_replica(self, fn, args=(), kwargs=None):
|
|
"""Run `fn` once per replica.
|
|
|
|
`fn` may call `tf.get_replica_context()` to access methods such as
|
|
`replica_id_in_sync_group` and `merge_call()`.
|
|
|
|
`merge_call()` is used to communicate between the replicas and
|
|
re-enter the cross-replica context. All replicas pause their execution
when they encounter a `merge_call()` call. After that, the
`merge_fn` function is executed. Its results are then unwrapped and
|
|
given back to each replica call. After that execution resumes until
|
|
`fn` is complete or encounters another `merge_call()`. Example:
|
|
|
|
```python
|
|
# Called once in "cross-replica" context.
|
|
def merge_fn(distribution, three_plus_replica_id):
|
|
# sum the values across replicas
|
|
return sum(distribution.experimental_local_results(three_plus_replica_id))
|
|
|
|
# Called once per replica in `distribution`, in a "replica" context.
|
|
def fn(three):
|
|
replica_ctx = tf.get_replica_context()
|
|
v = three + replica_ctx.replica_id_in_sync_group
|
|
# Computes the sum of the `v` values across all replicas.
|
|
s = replica_ctx.merge_call(merge_fn, args=(v,))
|
|
return s + v
|
|
|
|
with distribution.scope():
|
|
# in "cross-replica" context
|
|
...
|
|
merged_results = distribution.run(fn, args=[3])
|
|
# merged_results has the values from every replica execution of `fn`.
|
|
# This statement prints a list:
|
|
print(distribution.experimental_local_results(merged_results))
|
|
```
|
|
|
|
Args:
|
|
fn: function to run (will be run once per replica).
|
|
args: Tuple or list with positional arguments for `fn`.
|
|
kwargs: Dict with keyword arguments for `fn`.
|
|
|
|
Returns:
|
|
Merged return value of `fn` across all replicas.
|
|
"""
|
|
_require_cross_replica_or_default_context_extended(self)
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
with self._container_strategy().scope():
|
|
return self._call_for_each_replica(fn, args, kwargs)
|
|
|
|
def _call_for_each_replica(self, fn, args, kwargs):
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
def read_var(self, v):
|
|
"""Reads the value of a variable.
|
|
|
|
Returns the aggregate value of a replica-local variable, or the
|
|
(read-only) value of any other variable.
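For illustration, a minimal sketch with a sync-on-read variable (assumes a
`tf.compat.v1` mirrored strategy over two GPUs):

```python
strategy = tf.compat.v1.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  v = tf.Variable(
      0.0,
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.SUM)

# Each replica adds 1.0 to its own local copy of `v`.
strategy.extended.call_for_each_replica(lambda: v.assign_add(1.0))
# Reading aggregates the per-replica copies: 1.0 + 1.0 == 2.0.
total = strategy.extended.read_var(v)
```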
|
|
|
|
Args:
|
|
v: A variable allocated within the scope of this `tf.distribute.Strategy`.
|
|
|
|
Returns:
|
|
A tensor representing the value of `v`, aggregated across replicas if
|
|
necessary.
|
|
"""
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
def update_non_slot(
|
|
self, colocate_with, fn, args=(), kwargs=None, group=True):
|
|
"""Runs `fn(*args, **kwargs)` on `colocate_with` devices.
|
|
|
|
Used to update non-slot variables.
|
|
|
|
DEPRECATED: TF 1.x ONLY.
|
|
|
|
Args:
|
|
colocate_with: Devices returned by `non_slot_devices()`.
|
|
fn: Function to execute.
|
|
args: Tuple or list. Positional arguments to pass to `fn()`.
|
|
kwargs: Dict with keyword arguments to pass to `fn()`.
|
|
group: Boolean. Defaults to True. If False, the return value will be
|
|
unwrapped.
|
|
|
|
Returns:
|
|
Return value of `fn`, possibly merged across devices.
|
|
"""
|
|
_require_cross_replica_or_default_context_extended(self)
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
fn = autograph.tf_convert(
|
|
fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
|
|
with self._container_strategy().scope():
|
|
return self._update_non_slot(colocate_with, fn, args, kwargs, group)
|
|
|
|
def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
def non_slot_devices(self, var_list):
|
|
"""Device(s) for non-slot variables.
|
|
|
|
DEPRECATED: TF 1.x ONLY.
|
|
|
|
This method returns non-slot devices where non-slot variables are placed.
|
|
Users can create non-slot variables on these devices by using a block:
|
|
|
|
```python
|
|
with tf.distribute.StrategyExtended.colocate_vars_with(tf.distribute.StrategyExtended.non_slot_devices(...)):
|
|
...
|
|
```
|
|
|
|
Args:
|
|
var_list: The list of variables being optimized, needed with the
|
|
default `tf.distribute.Strategy`.
|
|
Returns:
|
|
A sequence of devices for non-slot variables.
|
|
"""
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
def _use_merge_call(self):
|
|
"""Whether to use merge-calls inside the distributed strategy."""
|
|
return True
|
|
|
|
@property
|
|
def experimental_between_graph(self):
|
|
"""Whether the strategy uses between-graph replication or not.
|
|
|
|
This is expected to return a constant value that will not be changed
|
|
throughout its life cycle.
|
|
"""
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
@property
|
|
def experimental_should_init(self):
|
|
"""Whether initialization is needed."""
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
@property
|
|
def should_checkpoint(self):
|
|
"""Whether checkpointing is needed."""
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
@property
|
|
def should_save_summary(self):
|
|
"""Whether saving summaries is needed."""
|
|
raise NotImplementedError("must be implemented in descendants")
|
|
|
|
|
|
# A note about the difference between the context managers
|
|
# `ReplicaContext` (defined here) and `_CurrentDistributionContext`
|
|
# (defined above) used by `tf.distribute.Strategy.scope()`:
|
|
#
|
|
# * a ReplicaContext is only present during a `run()`
|
|
# call (except during a `merge_run` call) and in such a scope it
|
|
# will be returned by calls to `get_replica_context()`. Implementers of new
|
|
# Strategy descendants will frequently also need to
|
|
# define a descendant of ReplicaContext, and are responsible for
|
|
# entering and exiting this context.
|
|
#
|
|
# * Strategy.scope() sets up a variable_creator scope that
|
|
# changes variable creation calls (e.g. to make mirrored
|
|
# variables). This is intended as an outer scope that users enter once
|
|
# around their model creation and graph definition. There is no
|
|
# anticipated need to define descendants of _CurrentDistributionContext.
|
|
# It sets the current Strategy for purposes of
|
|
# `get_strategy()` and `has_strategy()`
|
|
# and switches the thread mode to a "cross-replica context".
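#
# A small illustrative sketch of the two contexts (assumes two local GPUs;
# not part of the library itself):
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   with strategy.scope():
#     # Cross-replica context: no ReplicaContext is active here.
#     assert tf.distribute.get_replica_context() is None
#     assert tf.distribute.get_strategy() is strategy
#
#   def replica_fn():
#     # Replica context: a ReplicaContext is available inside `run()`.
#     return tf.distribute.get_replica_context().replica_id_in_sync_group
#
#   per_replica_ids = strategy.run(replica_fn)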
|
|
class ReplicaContextBase(object):
|
|
"""A class with a collection of APIs that can be called in a replica context.
|
|
|
|
You can use `tf.distribute.get_replica_context` to get an instance of
|
|
`ReplicaContext`, which can only be called inside the function passed to
|
|
`tf.distribute.Strategy.run`.
|
|
|
|
>>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])
|
|
>>> def func():
|
|
... replica_context = tf.distribute.get_replica_context()
|
|
... return replica_context.replica_id_in_sync_group
|
|
>>> strategy.run(func)
|
|
PerReplica:{
|
|
0: <tf.Tensor: shape=(), dtype=int32, numpy=0>,
|
|
1: <tf.Tensor: shape=(), dtype=int32, numpy=1>
|
|
}
|
|
"""
|
|
|
|
def __init__(self, strategy, replica_id_in_sync_group):
|
|
"""Creates a ReplicaContext.
|
|
|
|
Args:
|
|
strategy: A `tf.distribute.Strategy`.
|
|
replica_id_in_sync_group: An integer, a `Tensor` or None. Prefer an
|
|
integer whenever possible to avoid issues with nested `tf.function`. It
|
|
accepts a `Tensor` only to be compatible with `tpu.replicate`.
|
|
"""
|
|
self._strategy = strategy
|
|
self._thread_context = _InReplicaThreadMode( # pylint: disable=protected-access
|
|
self)
|
|
if not (replica_id_in_sync_group is None or
|
|
tensor_util.is_tf_type(replica_id_in_sync_group) or
|
|
isinstance(replica_id_in_sync_group, int)):
|
|
raise ValueError(
|
|
"replica_id_in_sync_group can only be an integer, a Tensor or None.")
|
|
self._replica_id_in_sync_group = replica_id_in_sync_group
|
|
# We need this check because TPUContext extends from ReplicaContext and
|
|
# does not pass a strategy object since it is used by TPUEstimator.
|
|
if strategy:
|
|
self._local_replica_id = strategy.extended._get_local_replica_id(
|
|
replica_id_in_sync_group)
|
|
self._summary_recording_distribution_strategy = None
|
|
|
|
@doc_controls.do_not_generate_docs
|
|
def __enter__(self):
|
|
_push_per_thread_mode(self._thread_context)
|
|
|
|
def replica_id_is_zero():
|
|
return math_ops.equal(self.replica_id_in_sync_group,
|
|
constant_op.constant(0))
|
|
|
|
summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access
|
|
self._summary_recording_distribution_strategy = (
|
|
summary_state.is_recording_distribution_strategy)
|
|
summary_state.is_recording_distribution_strategy = replica_id_is_zero
|
|
|
|
@doc_controls.do_not_generate_docs
|
|
def __exit__(self, exception_type, exception_value, traceback):
|
|
summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access
|
|
summary_state.is_recording_distribution_strategy = (
|
|
self._summary_recording_distribution_strategy)
|
|
_pop_per_thread_mode()
|
|
|
|
def merge_call(self, merge_fn, args=(), kwargs=None):
|
|
"""Merge args across replicas and run `merge_fn` in a cross-replica context.
|
|
|
|
This allows communication and coordination when there are multiple calls
|
|
to the step_fn triggered by a call to `strategy.run(step_fn, ...)`.
|
|
|
|
See `tf.distribute.Strategy.run` for an explanation.
|
|
|
|
If not inside a distributed scope, this is equivalent to:
|
|
|
|
```
|
|
strategy = tf.distribute.get_strategy()
|
|
with cross-replica-context(strategy):
|
|
return merge_fn(strategy, *args, **kwargs)
|
|
```
|
|
|
|
Args:
|
|
merge_fn: Function that joins arguments from threads that are given as
|
|
PerReplica. It accepts `tf.distribute.Strategy` object as
|
|
the first argument.
|
|
args: List or tuple with positional per-thread arguments for `merge_fn`.
|
|
kwargs: Dict with keyword per-thread arguments for `merge_fn`.
|
|
|
|
Returns:
|
|
The return value of `merge_fn`, except for `PerReplica` values which are
|
|
unpacked.
|
|
"""
|
|
require_replica_context(self)
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
|
|
merge_fn = autograph.tf_convert(
|
|
merge_fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
|
|
return self._merge_call(merge_fn, args, kwargs)
|
|
|
|
def _merge_call(self, merge_fn, args, kwargs):
|
|
"""Default implementation for single replica."""
|
|
_push_per_thread_mode( # thread-local, so not needed with multiple threads
|
|
_CrossReplicaThreadMode(self._strategy)) # pylint: disable=protected-access
|
|
try:
|
|
return merge_fn(self._strategy, *args, **kwargs)
|
|
finally:
|
|
_pop_per_thread_mode()
|
|
|
|
@property
|
|
def num_replicas_in_sync(self):
|
|
"""Returns number of replicas that are kept in sync."""
|
|
return self._strategy.num_replicas_in_sync
|
|
|
|
@property
|
|
def replica_id_in_sync_group(self):
|
|
"""Returns the id of the replica.
|
|
|
|
This identifies the replica among all replicas that are kept in sync. The
|
|
value of the replica id can range from 0 to
|
|
`tf.distribute.ReplicaContext.num_replicas_in_sync` - 1.
|
|
|
|
NOTE: This is not guaranteed to be the same ID as the XLA replica ID used
|
|
for low-level operations such as collective_permute.
|
|
|
|
Returns:
|
|
a `Tensor`.
|
|
"""
|
|
# It's important to prefer making the Tensor at call time whenever possible.
|
|
# Keeping Tensors in global states doesn't work well with nested
|
|
# tf.function, since it's possible that the tensor is generated in one func
|
|
# graph, and gets captured by another, which will result in a subtle "An op
|
|
# outside of the function building code is being passed a Graph tensor"
|
|
# error. Making the tensor at call time ensures it is in the same graph where
|
|
# it's used. However to be compatible with tpu.replicate(),
|
|
# self._replica_id_in_sync_group can also be a Tensor.
|
|
if tensor_util.is_tf_type(self._replica_id_in_sync_group):
|
|
return self._replica_id_in_sync_group
|
|
return constant_op.constant(
|
|
self._replica_id_in_sync_group,
|
|
dtypes.int32,
|
|
name="replica_id_in_sync_group")
|
|
|
|
@property
|
|
def _replica_id(self):
|
|
"""This is the local replica id in a given sync group."""
|
|
return self._local_replica_id
|
|
|
|
@property
|
|
def strategy(self):
|
|
"""The current `tf.distribute.Strategy` object."""
|
|
return self._strategy
|
|
|
|
@property
|
|
@deprecation.deprecated(None, "Please avoid relying on devices property.")
|
|
def devices(self):
|
|
"""Returns the devices this replica is to be executed on, as a tuple of strings.
|
|
|
|
NOTE: For `tf.distribute.MirroredStrategy` and
|
|
`tf.distribute.experimental.MultiWorkerMirroredStrategy`, this returns a
|
|
nested
|
|
list of device strings, e.g., [["GPU:0"]].
|
|
"""
|
|
require_replica_context(self)
|
|
return (device_util.current(),)
|
|
|
|
def all_reduce(self, reduce_op, value, options=None):
|
|
"""All-reduces `value` across all replicas.
|
|
|
|
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
|
|
>>> def step_fn():
|
|
... ctx = tf.distribute.get_replica_context()
|
|
... value = tf.identity(1.)
|
|
... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value)
|
|
>>> strategy.experimental_local_results(strategy.run(step_fn))
|
|
(<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
|
|
<tf.Tensor: shape=(), dtype=float32, numpy=2.0>)
|
|
|
|
It supports batched operations. You can pass a list of values and it
|
|
attempts to batch them when possible. You can also specify `options`
|
|
to indicate the desired batching behavior, e.g. batch the values into
|
|
multiple packs so that they can better overlap with computations.
|
|
|
|
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
|
|
>>> def step_fn():
|
|
... ctx = tf.distribute.get_replica_context()
|
|
... value1 = tf.identity(1.)
|
|
... value2 = tf.identity(2.)
|
|
... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value1, value2])
|
|
>>> strategy.experimental_local_results(strategy.run(step_fn))
|
|
([<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
|
|
<tf.Tensor: shape=(), dtype=float32, numpy=4.0>],
|
|
[<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
|
|
<tf.Tensor: shape=(), dtype=float32, numpy=4.0>])
|
|
|
|
Note that all replicas need to participate in the all-reduce, otherwise this
|
|
operation hangs. Note that if there are multiple all-reduces, they need to
|
|
execute in the same order on all replicas. Dispatching all-reduce based on
|
|
conditions is usually error-prone.
|
|
|
|
Known limitation: if `value` contains `tf.IndexedSlices`, attempting to
|
|
compute gradient w.r.t `value` would result in an error.
|
|
|
|
This API currently can only be called in the replica context. Other
|
|
variants to reduce values across replicas are:
|
|
* `tf.distribute.StrategyExtended.reduce_to`: the reduce and all-reduce API
|
|
in the cross-replica context.
|
|
* `tf.distribute.StrategyExtended.batch_reduce_to`: the batched reduce and
|
|
all-reduce API in the cross-replica context.
|
|
* `tf.distribute.Strategy.reduce`: a more convenient method to reduce
|
|
to the host in cross-replica context.
|
|
|
|
Args:
|
|
reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
|
|
be combined. Allows using string representation of the enum such as
|
|
"SUM", "MEAN".
|
|
value: a potentially nested structure of `tf.Tensor` or `tf.IndexedSlices` which
|
|
`tf.nest.flatten` accepts. The structure and the shapes of `value` need to be
|
|
the same on all replicas.
|
|
options: a `tf.distribute.experimental.CommunicationOptions`. Options to
|
|
perform collective operations. This overrides the default options if the
|
|
`tf.distribute.Strategy` takes one in the constructor. See
|
|
`tf.distribute.experimental.CommunicationOptions` for details of the
|
|
options.
|
|
|
|
Returns:
|
|
A nested structure of `tf.Tensor` with the reduced values. The structure
|
|
is the same as `value`.
|
|
"""
|
|
flattened_value = nest.flatten(value)
|
|
has_indexed_slices = False
|
|
|
|
for v in flattened_value:
|
|
if isinstance(v, indexed_slices.IndexedSlices):
|
|
has_indexed_slices = True
|
|
|
|
if isinstance(reduce_op, six.string_types):
|
|
reduce_op = reduce_util.ReduceOp(reduce_op.upper())
|
|
if options is None:
|
|
options = collective_util.Options()
|
|
|
|
def batch_all_reduce(strategy, *value_flat):
|
|
return strategy.extended.batch_reduce_to(
|
|
reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat],
|
|
options)
|
|
|
|
# Due to the use of `capture_call_time_value` in collective ops, we have
|
|
# to maintain two branches: one w/ merge_call and one w/o. Details can be
|
|
# found in b/184009754.
|
|
if self._strategy.extended._use_merge_call(): # pylint: disable=protected-access
|
|
# TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad.
|
|
if has_indexed_slices:
|
|
return nest.pack_sequence_as(
|
|
value,
|
|
self.merge_call(batch_all_reduce, args=flattened_value))
|
|
|
|
@custom_gradient.custom_gradient
|
|
def grad_wrapper(*xs):
|
|
ys = self.merge_call(batch_all_reduce, args=xs)
|
|
# The gradient of an all-sum is itself an all-sum (all-mean, likewise).
|
|
return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s)
|
|
return nest.pack_sequence_as(value, grad_wrapper(*flattened_value))
|
|
else:
|
|
if has_indexed_slices:
|
|
return nest.pack_sequence_as(
|
|
value,
|
|
self._strategy.extended._replica_ctx_all_reduce( # pylint: disable=protected-access
|
|
reduce_op, flattened_value, options))
|
|
|
|
@custom_gradient.custom_gradient
|
|
def grad_wrapper(*xs):
|
|
ys = self._strategy.extended._replica_ctx_all_reduce( # pylint: disable=protected-access
|
|
reduce_op, xs, options)
|
|
# The gradient of an all-sum is itself an all-sum (all-mean, likewise).
|
|
return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s)
|
|
|
|
return nest.pack_sequence_as(value, grad_wrapper(*flattened_value))
|
|
|
|
# TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient
|
|
# all-reduce. It would return a function returning the result of reducing `t`
|
|
# across all replicas. The caller would wait to call this function until they
|
|
# needed the reduce result, allowing an efficient implementation:
|
|
# * With eager execution, the reduction could be performed asynchronously
|
|
# in the background, not blocking until the result was needed.
|
|
# * When constructing a graph, it could batch up all reduction requests up
|
|
# to that point that the first result is needed. Most likely this can be
|
|
# implemented in terms of `merge_call()` and `batch_reduce_to()`.
|
|
|
|
|
|
@tf_export("distribute.ReplicaContext", v1=[])
|
|
class ReplicaContext(ReplicaContextBase):
|
|
|
|
__doc__ = ReplicaContextBase.__doc__
|
|
|
|
def all_gather(self, value, axis, options=None):
|
|
"""All-gathers `value` across all replicas along `axis`.
|
|
|
|
Note: An `all_gather` method can only be called in replica context. For
|
|
a cross-replica context counterpart, see `tf.distribute.Strategy.gather`.
|
|
All replicas need to participate in the all-gather, otherwise this
|
|
operation hangs. So if `all_gather` is called in any replica, it must be
|
|
called in all replicas.
|
|
|
|
Note: If there are multiple `all_gather` calls, they need to be executed in
|
|
the same order on all replicas. Dispatching `all_gather` based on conditions
|
|
is usually error-prone.
|
|
|
|
For all strategies except `tf.distribute.TPUStrategy`, the input
|
|
`value` on different replicas must have the same rank, and their shapes must
|
|
be the same in all dimensions except the `axis`-th dimension. In other
|
|
words, their shapes cannot be different in a dimension `d` where `d` does
|
|
not equal the `axis` argument. For example, given a
|
|
`tf.distribute.DistributedValues` with component tensors of shape
|
|
`(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call
|
|
`all_gather(..., axis=1, ...)` on it, but not `all_gather(..., axis=0, ...)`
|
|
or `all_gather(..., axis=2, ...)`. However, with
|
|
`tf.distribute.TPUStrategy`, all tensors must have exactly the same rank and
|
|
same shape.
|
|
|
|
Note: The input `value` must have a non-zero rank. Otherwise, consider using
|
|
`tf.expand_dims` before gathering it.
|
|
|
|
You can pass in a single tensor to all-gather:
|
|
|
|
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
|
|
>>> @tf.function
|
|
... def gather_value():
|
|
... ctx = tf.distribute.get_replica_context()
|
|
... local_value = tf.constant([1, 2, 3])
|
|
... return ctx.all_gather(local_value, axis=0)
|
|
>>> result = strategy.run(gather_value)
|
|
>>> result
|
|
PerReplica:{
|
|
0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
|
|
1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
|
|
}
|
|
>>> strategy.experimental_local_results(result)
|
|
(<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
|
|
dtype=int32)>,
|
|
<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
|
|
dtype=int32)>)
|
|
|
|
|
|
You can also pass in a nested structure of tensors to all-gather, say, a
|
|
list:
|
|
|
|
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
|
|
>>> @tf.function
|
|
... def gather_nest():
|
|
... ctx = tf.distribute.get_replica_context()
|
|
... value_1 = tf.constant([1, 2, 3])
|
|
... value_2 = tf.constant([[1, 2], [3, 4]])
|
|
... # all_gather a nest of `tf.distribute.DistributedValues`
|
|
... return ctx.all_gather([value_1, value_2], axis=0)
|
|
>>> result = strategy.run(gather_nest)
|
|
>>> result
|
|
[PerReplica:{
|
|
0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
|
|
1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
|
|
}, PerReplica:{
|
|
0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
|
|
array([[1, 2],
|
|
[3, 4],
|
|
[1, 2],
|
|
[3, 4]], dtype=int32)>,
|
|
1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
|
|
array([[1, 2],
|
|
[3, 4],
|
|
[1, 2],
|
|
[3, 4]], dtype=int32)>
|
|
}]
|
|
>>> strategy.experimental_local_results(result)
|
|
([<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
|
|
<tf.Tensor: shape=(4, 2), dtype=int32, numpy=
|
|
array([[1, 2],
|
|
[3, 4],
|
|
[1, 2],
|
|
[3, 4]], dtype=int32)>],
|
|
[<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
|
|
<tf.Tensor: shape=(4, 2), dtype=int32, numpy=
|
|
array([[1, 2],
|
|
[3, 4],
|
|
[1, 2],
|
|
[3, 4]], dtype=int32)>])
|
|
|
|
|
|
What if you are all-gathering tensors with different shapes on different
|
|
replicas? Consider the following example with two replicas, where you have
|
|
`value` as a nested structure consisting of two items to all-gather, `a` and
|
|
`b`.
|
|
|
|
* On Replica 0, `value` is `{'a': [0], 'b': [[0, 1]]}`.
|
|
* On Replica 1, `value` is `{'a': [1], 'b': [[2, 3], [4, 5]]}`.
|
|
* Result for `all_gather` with `axis=0` (on each of the replicas) is:
|
|
|
|
```
|
|
{'a': [0, 1], 'b': [[0, 1], [2, 3], [4, 5]]}
|
|
```
|
|
|
|
Args:
|
|
value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts,
|
|
or a `tf.distribute.DistributedValues` instance. The structure of the
|
|
`tf.Tensor` needs to be the same on all replicas. The underlying tensor
|
|
constructs can only be dense tensors with non-zero rank, NOT
|
|
`tf.IndexedSlices`.
|
|
axis: 0-D int32 Tensor. Dimension along which to gather.
|
|
options: a `tf.distribute.experimental.CommunicationOptions`. Options to
|
|
perform collective operations. This overrides the default options if the
|
|
`tf.distribute.Strategy` takes one in the constructor. See
|
|
`tf.distribute.experimental.CommunicationOptions` for details of the
|
|
options.
|
|
|
|
Returns:
|
|
A nested structure of `tf.Tensor` with the gathered values. The structure
|
|
is the same as `value`.
|
|
"""
|
|
for v in nest.flatten(value):
|
|
if isinstance(v, indexed_slices.IndexedSlices):
|
|
raise NotImplementedError("all_gather does not support IndexedSlices")
|
|
|
|
if options is None:
|
|
options = collective_util.Options()
|
|
|
|
def batch_all_gather(strategy, *value_flat):
|
|
return strategy.extended._batch_gather_to( # pylint: disable=protected-access
|
|
[(v, _batch_reduce_destination(v)) for v in value_flat], axis,
|
|
options)
|
|
|
|
@custom_gradient.custom_gradient
|
|
def grad_wrapper(*xs):
|
|
ys = self.merge_call(batch_all_gather, args=xs)
|
|
|
|
def grad(*dy_s):
|
|
grads = self.all_reduce(reduce_util.ReduceOp.SUM, dy_s)
|
|
new_grads = []
|
|
for i, grad in enumerate(grads):
|
|
input_shape = array_ops.shape(xs[i])
|
|
axis_dim = array_ops.reshape(input_shape[axis], [1])
|
|
with ops.control_dependencies([array_ops.identity(grads)]):
|
|
d = self.all_gather(axis_dim, axis=0)
|
|
begin_dim = math_ops.reduce_sum(d[:self.replica_id_in_sync_group])
|
|
end_dim = begin_dim + array_ops.shape(xs[i])[axis]
|
|
new_grad = array_ops.gather(
|
|
grad, axis=axis, indices=math_ops.range(begin_dim, end_dim))
|
|
new_grads.append(new_grad)
|
|
return new_grads
|
|
|
|
return ys, grad
|
|
|
|
return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))
|
|
|
|
def _update(self, var, fn, args=(), kwargs=None, group=True):
|
|
"""Run `fn` to update `var` with `args` and `kwargs` in replica context.
|
|
|
|
`tf.distribute.ReplicaContext.update` takes a (distributed) variable `var`
|
|
to be updated, an update function `fn`, and `args` and `kwargs` for `fn`.
|
|
`fn` applies to each component variable of `var` with corresponding input
|
|
values from `args` and `kwargs`.
|
|
|
|
Example usage:
|
|
|
|
>>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # 2 replicas
|
|
>>> with strategy.scope():
|
|
... distributed_variable = tf.Variable(5.0)
|
|
>>> distributed_variable
|
|
MirroredVariable:{
|
|
0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>,
|
|
1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=5.0>
|
|
}
|
|
>>> def replica_fn(v):
|
|
... value = tf.identity(1.0)
|
|
... replica_context = tf.distribute.get_replica_context()
|
|
... update_fn = lambda var, value: var.assign(value)
|
|
... replica_context._update(v, update_fn, args=(value,))
|
|
>>> strategy.run(replica_fn, args=(distributed_variable,))
|
|
>>> distributed_variable
|
|
MirroredVariable:{
|
|
0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
|
|
1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
|
|
}
|
|
|
|
This API must be called in a replica context.
|
|
|
|
Note that if `var` is a MirroredVariable (i.e., the type of variable created
|
|
under the scope of a synchronous strategy, and is synchronized on-write, see
|
|
`tf.VariableSynchronization` for more information) and `args`/`kwargs`
|
|
contains different values for different replicas, `var` will be dangerously
|
|
out of synchronization. Thus we recommend using `variable.assign(value)` as
|
|
long as you can, which under the hood aggregates the updates and guarantees
|
|
the synchronization. The case where you actually want this API instead of
|
|
`variable.assign(value)` is that before assigning `value` to the `variable`,
|
|
you'd like to conduct some pre-`assign` computation colocated with the
|
|
variable devices (i.e. where variables reside, for MirroredStrategy they are
|
|
the same as the compute device, for ParameterServerStrategy they refer to
|
|
parameter servers). E.g.,
|
|
|
|
```python
|
|
strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # 2 replicas
|
|
with strategy.scope():
|
|
v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM)
|
|
def replica_fn(inputs):
|
|
value = computation(inputs)
|
|
replica_context = tf.distribute.get_replica_context()
|
|
reduced_value = replica_context.all_reduce(value)
|
|
|
|
def update_fn(var, value):
|
|
# this computation will colocate with `var`'s device
|
|
updated_value = post_reduce_pre_update_computation(value)
|
|
var.assign(updated_value)
|
|
|
|
replica_context._update(v, update_fn, args=(reduced_value,))
|
|
|
|
strategy.run(replica_fn, args=(inputs,))
|
|
```
|
|
|
|
This code snippet is consistent across all strategies. If you directly
|
|
compute and use `assign` in the replica context instead of wrapping it with
|
|
`update`, for strategies with fewer variable devices than compute devices
|
|
(e.g., parameter server strategy, usually), the
|
|
`post_reduce_pre_update_computation` will happen
|
|
N == number_of_compute_devices times, which is less performant.
|
|
|
|
|
|
Args:
|
|
var: Variable, possibly distributed to multiple devices, to operate on.
|
|
fn: Function to call. Should take the variable as the first argument.
|
|
args: Tuple or list. Additional positional arguments to pass to `fn()`.
|
|
kwargs: Dict with keyword arguments to pass to `fn()`.
|
|
group: Boolean. Defaults to True. Most strategies enter a merge_call to
|
|
conduct update in cross-replica context, and group=True guarantees updates
|
|
on all replicas are executed.
|
|
|
|
Returns:
|
|
The return value of `fn` for the local replica.
|
|
"""
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
return self._strategy.extended._replica_ctx_update(var, fn, args=args, kwargs=kwargs, group=group) # pylint: disable=protected-access
|
|
|
|
|
|
@tf_export(v1=["distribute.ReplicaContext"])
|
|
class ReplicaContextV1(ReplicaContextBase):
|
|
__doc__ = ReplicaContextBase.__doc__
|
|
|
|
|
|
def _batch_reduce_destination(x):
|
|
"""Returns the destinations for batch all-reduce."""
|
|
if isinstance(x, tensor_lib.Tensor):
|
|
# If this is a one device strategy.
|
|
return x.device
|
|
else:
|
|
return x
|
|
# ------------------------------------------------------------------------------
|
|
|
|
|
|
class _DefaultDistributionStrategyV1(StrategyV1):
|
|
"""Default `tf.distribute.Strategy` if none is explicitly selected."""
|
|
|
|
def __init__(self):
|
|
if not _creating_default_strategy_singleton:
|
|
raise RuntimeError("Should only create a single instance of "
|
|
"_DefaultDistributionStrategy")
|
|
super(_DefaultDistributionStrategyV1,
|
|
self).__init__(_DefaultDistributionExtended(self))
|
|
|
|
def __deepcopy__(self, memo):
|
|
del memo
|
|
raise RuntimeError("Should only create a single instance of "
|
|
"_DefaultDistributionStrategy")
|
|
|
|
|
|
class _DefaultDistributionStrategy(Strategy):
|
|
"""Default `tf.distribute.Strategy` if none is explicitly selected."""
|
|
|
|
def __init__(self):
|
|
if not _creating_default_strategy_singleton:
|
|
raise RuntimeError("Should only create a single instance of "
|
|
"_DefaultDistributionStrategy")
|
|
super(_DefaultDistributionStrategy, self).__init__(
|
|
_DefaultDistributionExtended(self))
|
|
|
|
def __deepcopy__(self, memo):
|
|
del memo
|
|
raise RuntimeError("Should only create a single instance of "
|
|
"_DefaultDistributionStrategy")
|
|
|
|
|
|
class _DefaultDistributionContext(object):
|
|
"""Context manager setting the default `tf.distribute.Strategy`."""
|
|
|
|
__slots__ = ["_var_creator_scope", "_strategy", "_nested_count"]
|
|
|
|
def __init__(self, strategy):
|
|
|
|
def creator(next_creator, **kwargs):
|
|
_require_strategy_scope_strategy(strategy)
|
|
return next_creator(**kwargs)
|
|
|
|
self._var_creator_scope = variable_scope.variable_creator_scope(creator)
|
|
self._strategy = strategy
|
|
self._nested_count = 0
|
|
|
|
def __enter__(self):
|
|
# Allow this scope to be entered if this strategy is already in scope.
|
|
if has_strategy():
|
|
raise RuntimeError("Must not nest tf.distribute.Strategy scopes.")
|
|
if self._nested_count == 0:
|
|
self._var_creator_scope.__enter__()
|
|
self._nested_count += 1
|
|
return self._strategy
|
|
|
|
def __exit__(self, exception_type, exception_value, traceback):
|
|
self._nested_count -= 1
|
|
if self._nested_count == 0:
|
|
try:
|
|
self._var_creator_scope.__exit__(
|
|
exception_type, exception_value, traceback)
|
|
except RuntimeError as e:
|
|
six.raise_from(
|
|
RuntimeError("Variable creator scope nesting error: move call to "
|
|
"tf.distribute.set_strategy() out of `with` scope."),
|
|
e)
|
|
|
|
|
|
class _DefaultDistributionExtended(StrategyExtendedV1):
|
|
"""Implementation of _DefaultDistributionStrategy."""
|
|
|
|
def __init__(self, container_strategy):
|
|
super(_DefaultDistributionExtended, self).__init__(container_strategy)
|
|
self._retrace_functions_for_each_device = False
|
|
|
|
def _scope(self, strategy):
|
|
"""Context manager setting a variable creator and `self` as current."""
|
|
return _DefaultDistributionContext(strategy)
|
|
|
|
def colocate_vars_with(self, colocate_with_variable):
|
|
"""Does not require `self.scope`."""
|
|
_require_strategy_scope_extended(self)
|
|
return ops.colocate_with(colocate_with_variable)
|
|
|
|
def variable_created_in_scope(self, v):
|
|
return v._distribute_strategy is None # pylint: disable=protected-access
|
|
|
|
def _experimental_distribute_dataset(self, dataset, options):
|
|
return dataset
|
|
|
|
def _distribute_datasets_from_function(self, dataset_fn, options):
|
|
return dataset_fn(InputContext())
|
|
|
|
def _experimental_distribute_values_from_function(self, value_fn):
|
|
return value_fn(ValueContext())
|
|
|
|
def _make_dataset_iterator(self, dataset):
|
|
return _DefaultDistributionExtended.DefaultInputIterator(dataset)
|
|
|
|
def _make_input_fn_iterator(self,
|
|
input_fn,
|
|
replication_mode=InputReplicationMode.PER_WORKER):
|
|
dataset = input_fn(InputContext())
|
|
return _DefaultDistributionExtended.DefaultInputIterator(dataset)
  def _experimental_make_numpy_dataset(self, numpy_input, session):
    numpy_flat = nest.flatten(numpy_input)
    vars_flat = tuple(
        variable_v1.VariableV1(array_ops.zeros(i.shape, i.dtype),
                               trainable=False, use_resource=True)
        for i in numpy_flat
    )
    for v, i in zip(vars_flat, numpy_flat):
      numpy_dataset.init_var_from_numpy(v, i, session)
    vars_nested = nest.pack_sequence_as(numpy_input, vars_flat)
    return dataset_ops.Dataset.from_tensor_slices(vars_nested)

  def _broadcast_to(self, tensor, destinations):
    if destinations is None:
      return tensor
    else:
      raise NotImplementedError("TODO")
  def _call_for_each_replica(self, fn, args, kwargs):
    with ReplicaContext(self._container_strategy(), replica_id_in_sync_group=0):
      return fn(*args, **kwargs)
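
  # With only one replica there is nothing to reduce or gather across, so the
  # two methods below return `value` unchanged.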
  def _reduce_to(self, reduce_op, value, destinations, options):
    # TODO(josh11b): Use destinations?
    del reduce_op, destinations, options
    return value

  def _gather_to_implementation(self, value, destinations, axis, options):
    del destinations, axis, options
    return value

  def _update(self, var, fn, args, kwargs, group):
    # The implementations of _update() and _update_non_slot() are identical
    # except _update() passes `var` as the first argument to `fn()`.
    return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)

  def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group):
    # TODO(josh11b): Figure out what we should be passing to UpdateContext()
    # once that value is used for something.
    with UpdateContext(colocate_with):
      result = fn(*args, **kwargs)
      if should_group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def read_var(self, replica_local_var):
    return array_ops.identity(replica_local_var)

  def _local_results(self, distributed_value):
    return (distributed_value,)

  def value_container(self, value):
    return value

  @property
  def _num_replicas_in_sync(self):
    return 1

  @property
  def worker_devices(self):
    raise RuntimeError("worker_devices() method unsupported by default "
                       "tf.distribute.Strategy.")

  @property
  def parameter_devices(self):
    raise RuntimeError("parameter_devices() method unsupported by default "
                       "tf.distribute.Strategy.")
  def non_slot_devices(self, var_list):
    return min(var_list, key=lambda x: x.name)

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    # Default strategy doesn't indicate multi-worker training.
    return False

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group

  def _get_replica_id_in_sync_group(self, replica_id):
    return replica_id
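
  # `DefaultInputIterator` below simply wraps a plain `tf.data` iterator: a
  # one-shot iterator when executing eagerly, and an initializable iterator
  # when building a graph.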
  # TODO(priyag): This should inherit from `InputIterator`, once dependency
  # issues have been resolved.
  class DefaultInputIterator(object):
    """Default implementation of `InputIterator` for default strategy."""

    def __init__(self, dataset):
      self._dataset = dataset
      if eager_context.executing_eagerly():
        self._iterator = dataset_ops.make_one_shot_iterator(dataset)
      else:
        self._iterator = dataset_ops.make_initializable_iterator(dataset)

    def get_next(self):
      return self._iterator.get_next()

    def get_next_as_optional(self):
      return self._iterator.get_next_as_optional()

    @deprecated(None, "Use the iterator's `initializer` property instead.")
    def initialize(self):
      """Initialize underlying iterators.

      Returns:
        A list of any initializer ops that should be run.
      """
      if eager_context.executing_eagerly():
        self._iterator = self._dataset.make_one_shot_iterator()
        return []
      else:
        return [self._iterator.initializer]

    @property
    def initializer(self):
      """Returns a list of ops that initialize the iterator."""
      return self.initialize()

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """Global and per-replica batching are equivalent for this strategy."""
    return True


class _DefaultReplicaContext(ReplicaContext):
  """ReplicaContext for _DefaultDistributionStrategy."""

  @property
  def replica_id_in_sync_group(self):
    # Return 0 instead of a constant tensor to avoid creating a new node for
    # users who don't use distribution strategy.
    return 0


# ------------------------------------------------------------------------------
# We haven't yet implemented deserialization for DistributedVariables.
# So here we catch any attempts to deserialize variables
# when using distribution strategies.
# pylint: disable=protected-access
_original_from_proto = ref_variable._from_proto_fn


def _from_proto_fn(v, import_scope=None):
  if has_strategy():
    raise NotImplementedError(
        "Deserialization of variables is not yet supported when using a "
        "tf.distribute.Strategy.")
  else:
    return _original_from_proto(v, import_scope=import_scope)


ref_variable._from_proto_fn = _from_proto_fn
# pylint: enable=protected-access
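

# Resolves a watched variable to the values the gradient tape should actually
# track: inside a replica context this is the variable's value container;
# otherwise it is the strategy's local (per-replica) results.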
def get_local_results_or_value_container(variable):
  strategy, context = get_strategy_and_replica_context()
  if context:
    return [strategy.extended.value_container(variable)]
  else:
    return strategy.experimental_local_results(variable)


tape.register_watched_variable_resolver(get_local_results_or_value_container)


# ------------------------------------------------------------------------------
# Metrics to track which distribution strategy is being called
distribution_strategy_gauge = monitoring.StringGauge(
    "/tensorflow/api/distribution_strategy",
    "Gauge to track the type of distribution strategy used.", "TFVersion")
distribution_strategy_replica_gauge = monitoring.IntGauge(
    "/tensorflow/api/distribution_strategy/replica",
    "Gauge to track the number of replicas each distribution strategy used.",
    "CountType")
distribution_strategy_input_api_counter = monitoring.Counter(
    "/tensorflow/api/distribution_strategy/input_api",
    "Counter to track the usage of the input APIs.", "strategy", "api")
distributed_variable_creation_time_counter = monitoring.Counter(
    "/tensorflow/api/distribution_strategy/distributed_variable_creation_time_usecs",
    "Time to create distributed variables (us).", "strategy", "if_graph_building")
distributed_api_time_counter = monitoring.Counter(
    "/tensorflow/api/distribution_strategy/distributed_variable_api_time_usecs",
    "Time spent on an API (us).", "strategy", "api")
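
# A sketch (for illustration only; the actual call sites live in the concrete
# strategy implementations) of how these metrics are typically recorded,
# assuming the `monitoring` cell API:
#
#   distribution_strategy_gauge.get_cell("V2").set("MirroredStrategy")
#   distribution_strategy_replica_gauge.get_cell("num_replicas").set(2)
#   distribution_strategy_input_api_counter.get_cell(
#       "MirroredStrategy", "distribute_dataset").increase_by(1)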