"""Class implementing utilities used by tf.distribute.Strategy."""
from collections import abc
import contextlib
import threading
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import tpu_values as tpu_values_lib
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.distribute.reduce_util import ReduceOp
from tensorflow.python.eager import context
from tensorflow.python.eager import record
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
def get_loss_reduction():
  """`tf.distribute.ReduceOp` corresponding to the last loss reduction.

  Returns:
    `tf.distribute.ReduceOp` corresponding to the last loss reduction for
    estimator and v1 optimizer use case. `tf.distribute.ReduceOp.SUM` otherwise.
  """
  current_strategy = distribute_lib.get_strategy()
  # Outside the Estimator loss-scaling path the optimizer does not rescale the
  # loss, so a plain SUM is always the right reduction.
  if not current_strategy._scale_loss_for_estimator:  # pylint: disable=protected-access
    return ReduceOp.SUM

  graph = ops.get_default_graph()
  last_reduction = graph._last_loss_reduction  # pylint: disable=protected-access
  # Accept both the v1 losses constant and the string form used by
  # tf.keras.losses.Reduction.SUM.
  reduces_by_sum = (last_reduction == losses_impl.Reduction.SUM or
                    last_reduction == "sum")
  return ReduceOp.SUM if reduces_by_sum else ReduceOp.MEAN
def regroup(values, wrap_class=values_lib.PerReplica, always_wrap=False):
  """Makes a nest per-replica into a nest of PerReplica/Mirrored values.

  Lists, tuples (including namedtuples) and mappings are regrouped
  recursively, position by position (or key by key), so the result has the
  structure of `values[0]` with the per-replica leaves combined.

  Args:
    values: Values to regroup: a sequence with one nested structure per
      replica. Every entry must share the structure of `values[0]`.
    wrap_class: Class that `values` be wrapped in.
    always_wrap: Always wrap the `values` in `wrap_class` even if the values
      are the same except for DistributeVariable.

  Returns:
    Wrapped `values`.
  """
  v0 = values[0]

  if isinstance(v0, list):
    for v in values[1:]:
      assert isinstance(v, list)
      assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
                                 (len(v), len(v0), v, v0))
    return [
        regroup(tuple(v[i] for v in values), wrap_class, always_wrap)
        for i in range(len(v0))
    ]

  if isinstance(v0, tuple):
    for v in values[1:]:
      assert isinstance(v, tuple)
      assert len(v) == len(v0), ("Values to regroup had different lengths: "
                                 f"len(v) == {len(v)}, len(v0) == {len(v0)}, "
                                 f"v: {v}, v0: {v0}")
    regrouped_tuple = tuple(
        regroup(tuple(v[i] for v in values), wrap_class, always_wrap)
        for i in range(len(v0)))
    if hasattr(v0, "_fields"):
      # This tuple is in fact a namedtuple! Create a new namedtuple instance
      # and initialize it with the regrouped values:
      assert hasattr(v0, "_make")
      return v0._make(regrouped_tuple)
    return regrouped_tuple

  if isinstance(v0, abc.Mapping):
    v0keys = v0.keys()
    v0keys_set = set(v0keys)  # Hoisted: compared against every replica.
    for v in values[1:]:
      assert isinstance(v, abc.Mapping), ("v[0]: %r v[i]: %r" % (v0, v))
      assert set(v.keys()) == v0keys_set, ("v[0].keys: %s v[i].keys: %s" %
                                           (v0keys_set, set(v.keys())))
    # Use the actual type in case it is a class inherited from a dict.
    return type(v0)({
        key: regroup(tuple(v[key] for v in values), wrap_class, always_wrap)
        for key in v0keys
    })

  # If exactly the same object across all devices, return it unwrapped.
  # `all` stops at the first differing replica.
  same_id = all(v is v0 for v in values[1:])

  # Consider three cases where same_id is true:
  # * If v0 is a DistributedVariable (a MirroredVariable or
  #   SyncOnReadVariable, and same_id means it is the same across all
  #   devices), we want to return it. We check DistributedVariable
  #   specifically since it can look like it has a
  #   _distributed_container member since its members do.
  if same_id and isinstance(v0, values_lib.DistributedVariable):
    return v0

  # Computed once; used by both remaining checks below.
  v0_container = value_container(v0)

  # * If v0 is a member of a distributed variable, in which case
  #   value_container(v0) is not v0 itself, we want to
  #   return the DistributedVariable that contains it using the
  #   _distributed_container logic below. This case can trigger
  #   same_id when there is only one device.
  # * In any other situation, same_id means we return v0 unless `always_wrap` is
  #   true.
  if same_id and not always_wrap and v0_container is v0:
    return v0

  # Detect the case where each device has a parallel component of the
  # same MirroredVariable (or SyncOnReadVariable). In this case we
  # want to return the containing MirroredVariable, after a bunch of
  # sanity checking. In particular, each component should have the
  # same container, and the devices of the variables should match the
  # keys of the per-replica dictionary. For _UnreadVariables, use the wrap_class
  # path, which calls tf.identity on them.
  # pylint: disable=protected-access
  if (not isinstance(v0, resource_variable_ops._UnreadVariable) and
      v0_container is not v0):
    assert not isinstance(v0, values_lib.MirroredVariable), (
        "ids = %s, values = %s" % ([id(v) for v in values], values))
    distributed_container = v0_container
    assert distributed_container is not None
    for v in values[1:]:
      assert distributed_container is value_container(v)
    return distributed_container
  # pylint: enable=protected-access

  return wrap_class(values)
def select_replica(replica_id, structured):
  """Specialize a nest of regular & per-replica values for one replica.

  Args:
    replica_id: Index into `.values` of each sliced `DistributedValues`.
    structured: A nest that may contain `DistributedValues` leaves.

  Returns:
    A nest shaped like `structured`. Every `DistributedValues` leaf that is not
    a `DistributedVariable` is replaced by `leaf.values[replica_id]`. All other
    leaves are returned unchanged.
  """

  def _pick_component(leaf):
    # A `DistributedVariable` can be handled directly in the replica context,
    # so it is passed through whole; other `DistributedValues` are sliced.
    if isinstance(leaf, values_lib.DistributedVariable):
      return leaf
    if isinstance(leaf, values_lib.DistributedValues):
      return leaf.values[replica_id]
    return leaf

  return nest.map_structure(_pick_component, structured)
def select_replica_mirrored(replica_id, structured):
  """Specialize a nest of regular & mirrored values for one replica.

  Currently this simply delegates to `select_replica`.
  """
  specialized = select_replica(replica_id, structured)
  return specialized
def assert_mirrored(structured):
  """Raises if the structured is not composed of mirrored or regular values.

  Raises:
    TypeError: if some leaf of `structured` is a `DistributedValues` for which
      `is_mirrored` returns False.
  """

  def _check_leaf(leaf):
    if not isinstance(leaf, values_lib.DistributedValues):
      return None
    if is_mirrored(leaf):
      return None
    raise TypeError(
        "Expected value to be mirrored across replicas: %s in %s." %
        (leaf, structured))

  nest.map_structure(_check_leaf, structured)
def update_regroup(extended, updates, group):
  """Regroup for an update, with dependencies to ensure all updates execute.

  Args:
    extended: Object providing a `_local_results` callable; it is used only
      when `group` is False.
    updates: Per-replica update results, in the form accepted by `regroup`.
    group: If True, group the per-replica results so that evaluating the
      returned value runs the update on every replica.

  Returns:
    If `group` is False, `extended._local_results` mapped over the result of
    regrouping `updates` into `Mirrored` values. Otherwise the regrouped
    structure. In that case each wrapped leaf is either a grouping op (when the
    replica results are not all tensors) or a `Mirrored` of tensors that carry
    a control dependency on that op.
  """
  if not group:
    regrouped = regroup(updates, values_lib.Mirrored)
    return nest.map_structure(extended._local_results, regrouped)  # pylint: disable=protected-access

  def _make_grouped_mirrored(values):
    """Convert per-replica list `values` into Mirrored type with grouping."""
    if len(values) == 1:
      return values_lib.Mirrored(values)

    # Make sure we run all updates. Without this, something like
    # session.run(extended.update(...)) may only update one replica.
    g = control_flow_ops.group(values)

    # If values is just ops, the grouping is enough. Everything in values
    # should have the same type, since we expect every replica to be performing
    # the same computation.
    if not all(tensor_util.is_tf_type(v) for v in values):
      return g

    # Otherwise we need tensors with the same values as `values`, but
    # that have a dependency on `g`. `identity` creates a new tensor on the
    # value's own device inside the control-dependency scope.
    with_dep = []
    for v in values:
      with ops.device(v.device), ops.control_dependencies([g]):
        with_dep.append(array_ops.identity(v))

    return values_lib.Mirrored(with_dep)

  return regroup(updates, _make_grouped_mirrored)
def value_container(val):
  """Returns the container that this per-replica `value` belongs to.

  Args:
    val: A value returned by `call_for_each_replica()` or a variable created in
      `scope()`.

  Returns:
    A container that `value` belongs to.
    If value does not belong to any container (including the case of
    container having been destroyed), returns the value itself.
  """
  # DistributedVariable has _distributed_container defined but we don't want to
  # return it: the variable is its own container.
  if isinstance(val, values_lib.DistributedVariable):
    return val

  if hasattr(val, "_distributed_container"):
    found = val._distributed_container()  # pylint: disable=protected-access
  elif (isinstance(val, composite_tensor.CompositeTensor) and
        hasattr(val, "handle") and
        hasattr(val.handle, "_distributed_container")):
    # For ResourceVariables, the _distributed_container attribute
    # is added to their handle tensors.
    found = val.handle._distributed_container()  # pylint: disable=protected-access
  else:
    found = None

  return val if found is None else found
def is_distributed_variable(v):
  """Determine if a variable is ds variable or TPU mirrored variable.

  Returns `v.is_distributed_variable` when that attribute exists, and False
  otherwise.
  """
  try:
    return v.is_distributed_variable
  except AttributeError:
    return False
def is_distributed_table(v):
  """Determine if an object is a DistributedTable.

  Returns `v.is_distributed_table` when that attribute exists, and False
  otherwise.
  """
  try:
    return v.is_distributed_table
  except AttributeError:
    return False
def _validate_colocate_extended(v, extended):
  """Checks that `v` was created under the strategy that owns `extended`.

  Raises:
    ValueError: if `v._distribute_strategy.extended` is not `extended`.
  """
  owning_strategy = v._distribute_strategy  # pylint: disable=protected-access
  if owning_strategy.extended is extended:
    return
  raise ValueError(
      "`colocate_vars_with` must only be passed a variable created in this "
      "tf.distribute.Strategy.scope(), not %s created in scope: %s" %
      (v, owning_strategy))
def validate_colocate_distributed_variable(v, extended):
  """Checks that `v` is a `DistributedVariable` created under `extended`.

  Raises:
    ValueError: if `v` is not a `values_lib.DistributedVariable`, or if it was
      created under a different strategy.
  """
  if isinstance(v, values_lib.DistributedVariable):
    _validate_colocate_extended(v, extended)
    return
  raise ValueError(
      "`colocate_vars_with` must only be passed a variable created in this "
      "tf.distribute.Strategy.scope(), not: %r" % (v,))
def validate_colocate(v, extended):
  """Checks that `v` is a strategy-created variable owned by `extended`.

  Raises:
    ValueError: if `v` has no `_distribute_strategy` attribute, or if it was
      created under a different strategy.
  """
  if hasattr(v, "_distribute_strategy"):
    _validate_colocate_extended(v, extended)
    return
  raise ValueError(
      "`colocate_vars_with` must only be passed a variable created in this "
      "tf.distribute.Strategy.scope(), not: %r" % (v,))
# Variable creation function for sync strategies.
def _validate_synchronization(kwargs):
  """Validate that given synchronization value is valid.

  Args:
    kwargs: Variable-creation keyword arguments. The optional
      "synchronization" and "name" entries are read; `kwargs` is not modified.

  Returns:
    The requested `VariableSynchronization`, with AUTO (also the default when
    absent) resolved to ON_WRITE.

  Raises:
    ValueError: if synchronization is NONE, or is not one of ON_READ, ON_WRITE
      or AUTO.
  """
  synchronization = kwargs.get("synchronization",
                               vs.VariableSynchronization.AUTO)
  # Use .get so that a missing "name" cannot turn the intended ValueError
  # into a KeyError.
  name = kwargs.get("name")
  if synchronization == vs.VariableSynchronization.NONE:
    raise ValueError(
        "`NONE` variable synchronization mode is not supported with "
        "tf.distribute strategy. Please change the `synchronization` for "
        "variable: " + str(name))
  if synchronization not in (vs.VariableSynchronization.ON_READ,
                             vs.VariableSynchronization.ON_WRITE,
                             vs.VariableSynchronization.AUTO):
    raise ValueError(
        "Invalid variable synchronization mode: %s for variable: %s" %
        (synchronization, name))
  if synchronization == vs.VariableSynchronization.AUTO:
    return vs.VariableSynchronization.ON_WRITE
  return synchronization
def _validate_aggregation(kwargs):
  """Validate and return the requested variable aggregation.

  Args:
    kwargs: Variable-creation keyword arguments. The optional "aggregation"
      (default `VariableAggregation.NONE`) and "name" entries are read.

  Returns:
    The requested `VariableAggregation`.

  Raises:
    ValueError: if the aggregation is not NONE, SUM, MEAN or
      ONLY_FIRST_REPLICA.
  """
  aggregation = kwargs.get("aggregation", vs.VariableAggregation.NONE)
  if aggregation not in (vs.VariableAggregation.NONE,
                         vs.VariableAggregation.SUM,
                         vs.VariableAggregation.MEAN,
                         vs.VariableAggregation.ONLY_FIRST_REPLICA):
    # .get keeps a missing "name" from masking this ValueError with KeyError.
    raise ValueError("Invalid variable aggregation mode: %s for variable: %s" %
                     (aggregation, kwargs.get("name")))
  return aggregation
def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping,
                             policy_mapping, **kwargs):
  """Create distributed variables with given synchronization and aggregation.

  Args:
    strategy: Strategy the new variable belongs to. The optional
      `strategy.extended._use_var_policy` flag selects policy-based classes.
    real_mirrored_creator: Callable invoked with the adjusted `kwargs`. The
      object it returns is iterated as the per-device component variables and
      passed to the distributed variable class.
    class_mapping: Mapping looked up with "VariableClass" or
      "LazyVariableClass" when variable policies are used, and with the
      resolved `synchronization` otherwise.
    policy_mapping: Mapping from `synchronization` to a policy class. It is
      used only when variable policies are enabled.
    **kwargs: Variable-creation keyword arguments. "collections",
      "caching_device" and "experimental_batch_initialization" are consumed
      here. "synchronization" is normalized before `real_mirrored_creator`
      is called.

  Returns:
    The distributed variable wrapping the component variables.

  Raises:
    ValueError: if the requested synchronization or aggregation is invalid.
  """
  # Lazy (batch-initialized) variables use a different wrapper class when
  # variable policies are enabled.
  if kwargs.pop("experimental_batch_initialization", None):
    variable_class_key = "LazyVariableClass"
  else:
    variable_class_key = "VariableClass"

  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  else:
    # Copy so that the append below never mutates the caller's list.
    var_collections = list(var_collections)
  kwargs["collections"] = []

  synchronization = _validate_synchronization(kwargs)
  # Update synchronization in kwargs in case it's AUTO, which is converted to
  # ON_WRITE.
  kwargs["synchronization"] = synchronization
  aggregation = _validate_aggregation(kwargs)
  use_var_policy = getattr(strategy.extended, "_use_var_policy", False)

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with record.stop_recording():
    value_list = real_mirrored_creator(**kwargs)
    # MirroredVariable is recreated during saved_model loading, and its
    # component variables (value_list) will have None initializer. We
    # set their initializers to no_op so that consumer like
    # `global_variables_initializer` wouldn't complain, as it groups all
    # variables' initializers thus all variables have to have initializers.
    for v in value_list:
      # pylint:disable=protected-access
      if hasattr(v, "_initializer_op") and v._initializer_op is None:
        v._initializer_op = control_flow_ops.no_op()
      # pylint:enable=protected-access
    if use_var_policy:
      var_policy_cls = policy_mapping.get(synchronization)
      var_policy = var_policy_cls(aggregation=aggregation)
      var_cls = class_mapping.get(variable_class_key)
      result = var_cls(strategy, value_list, aggregation, var_policy=var_policy)
    else:
      var_cls = class_mapping.get(synchronization)
      result = var_cls(strategy, value_list, aggregation)

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      # The "replace" half: register the wrapper itself as trainable.
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      trainable_collection = g.get_collection_ref(
          ops.GraphKeys.TRAINABLE_VARIABLES)
      for value in value_list:
        for i, trainable_variable in enumerate(trainable_collection):
          if value is trainable_variable:
            # Stop right after deleting: continuing to enumerate a list
            # mutated in place would skip the following element.
            del trainable_collection[i]
            break

    g.add_to_collections(var_collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in var_collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result
# Utility functions


# Return True if the Value is Mirrored or the Variable is replicated and kept in
# sync.
def is_mirrored(val):
  """Returns the result of `val._is_mirrored()`, or False if it is absent."""
  try:
    mirrored_check = val._is_mirrored  # pylint: disable=protected-access
  except AttributeError:
    return False
  return mirrored_check()


def is_sync_on_read(val):
  """Returns the logical negation of `is_mirrored(val)`."""
  mirrored = is_mirrored(val)
  return not mirrored
class CachingScopeLocal(threading.local):
  """Class for maintaining thread local state for caching scope.

  Each thread gets its own pair of counters, one for entered scopes and one
  for exited scopes. A caching scope is active while entries outnumber exits.
  """

  def __init__(self):
    super().__init__()
    self.new_cache_scope_count = 0
    self.cache_scope_exited_count = 0

  def enter_scope(self):
    self.new_cache_scope_count = self.new_cache_scope_count + 1

  def exit_scope(self):
    self.cache_scope_exited_count = self.cache_scope_exited_count + 1

  def in_caching_scope(self):
    return not self.new_cache_scope_count <= self.cache_scope_exited_count


# Module-wide instance. Its counters are per thread because the class
# subclasses threading.local.
caching_scope_local = CachingScopeLocal()
@contextlib.contextmanager
def cache_variable_reads():
  """Scope for caching variable reads for AggregatingVariable.

  The variable reads for AggregatingVariable inside this scope are cached. i.e.
  the first read of variable reads the value from possibly remote handle, but
  subsequent reads are returned using local cached value.

  For example:
  strategy = ParameterServerStrategy...
  with strategy.scope():
    # Variable v is of AggregatingVariable type with actual variable residing
    # on PS.
    v = tf.Variable(1.0)

  with distribute_utils.cache_variable_reads():
    v.read_value()  # Reads value 1.0
    v.assign(constant_op.constant(5.0))  # v changes to 5.0
    t1 = v.read_value()
    t2 = v.read_value()  # Both t1 & t2 return cached value 1.0 from local CPU.

  Notes about cache_variable_reads scope:
  1. Nesting of scope cache_variable_reads() is not supported
  2. And when caching scope is enabled, the thread enabling the cache and
    mirrored_run._MirroredReplicaThread threads spawned from it will have
    caching enabled.

  Yields:
    A context for caching variables.

  Raises:
    ValueError: if a cache_variable_reads scope is already active on the
      current thread.
  """
  if caching_scope_local.in_caching_scope():
    # There is nested cache scope, which is not supported.
    raise ValueError("cache_variable_reads scope cannot be nested")
  # The check happens before entering, so exit_scope() only balances a
  # successful enter_scope(). A rejected nested attempt therefore cannot
  # close the outer scope.
  caching_scope_local.enter_scope()
  try:
    yield
  finally:
    caching_scope_local.exit_scope()
# The following mapping indicates the policy that you must use for a given
# variable `synchronization` and `aggregation` pair.
# OnWritePolicy is used for:
# (synchronization=Auto, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# (synchronization=ON_WRITE, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# OnReadPolicy is used for:
# (synchronization=ON_READ, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# AUTO needs no key of its own: `create_mirrored_variable` resolves it to
# ON_WRITE before looking up these mappings.
VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: values_lib.OnWritePolicy,
    vs.VariableSynchronization.ON_READ: values_lib.OnReadPolicy,
}

# Distributed variable classes usable as the `class_mapping` argument of
# `create_mirrored_variable`. "VariableClass" is used when variable policies
# are enabled; otherwise the class is selected by `synchronization`.
VARIABLE_CLASS_MAPPING = {
    "VariableClass": values_lib.DistributedVariable,
    vs.VariableSynchronization.ON_WRITE: values_lib.MirroredVariable,
    vs.VariableSynchronization.ON_READ: values_lib.SyncOnReadVariable,
}

# TPU counterparts of the policy mapping above.
TPU_VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUOnWritePolicy,
    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUOnReadPolicy,
}

# TPU counterparts of the class mapping above. "LazyVariableClass" is selected
# when `experimental_batch_initialization` is requested and variable policies
# are enabled.
TPU_VARIABLE_CLASS_MAPPING = {
    "VariableClass": tpu_values_lib.TPUDistributedVariable,
    "LazyVariableClass": tpu_values_lib.TPULazyDistributedVariable,
    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUMirroredVariable,
    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUSyncOnReadVariable,
}