389 lines
14 KiB
Python
389 lines
14 KiB
Python
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Utility functions used by values.py and ps_values.py."""
|
|
|
|
from tensorflow.python.distribute import distribute_lib
|
|
from tensorflow.python.distribute import reduce_util
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import tensor_util
|
|
from tensorflow.python.ops import control_flow_ops
|
|
from tensorflow.python.ops import math_ops
|
|
from tensorflow.python.ops import variable_scope as vs
|
|
from tensorflow.python.saved_model import save_context
|
|
from tensorflow.python.saved_model import save_options
|
|
from tensorflow.python.training.saving import saveable_object
|
|
|
|
|
|
def write_object_proto(var, proto, options):
  """Update a SavedObject proto for the caller.

  If a DistributedVariable object supports this method, it will be called when
  saving with a pre-built `SavedObject` proto representing the object, plus an
  instance of `SaveOptions`. This method is then free to modify that proto
  instance.

  `DistributedVariable` with `AUTO` or `ON_WRITE` synchronization optionally
  write out information about their components to the
  `experimental_distributed_variable_components` field of a
  `SavedVariable` (depending on the `SaveOptions` variable policy).

  Args:
    var: The DistributedVariable object.
    proto: A pre-built `SavedObject` proto for this object. It is assumed this
      will be a `SavedVariable` instance.
    options: A `SaveOptions` instance.
  """
  if options.experimental_variable_policy._expand_distributed_variables(  # pylint: disable=protected-access
  ):
    # NOTE: the original loop reused the name `var` for each component,
    # shadowing the function argument; use a distinct name instead.
    for value in var.values:
      var_proto = (
          proto.variable.experimental_distributed_variable_components.add())
      # Strip the ":0" output suffix so only the op name is recorded.
      var_proto.name = value.name.split(":")[0]
      var_proto.device = value.device
|
|
|
|
|
|
def get_on_write_saveable(var, primary_var, name):
  """Return saveable spec for AUTO and ON_WRITE variables."""

  # Wrap the value in a callable so nothing is evaluated when we are
  # restoring rather than saving.
  def tensor():
    # A SaveSpec tensor value of `None` signals an uninitialized variable
    # (only detectable when executing eagerly).
    if context.executing_eagerly() and not primary_var.is_initialized():
      return None
    return var.distribute_strategy.extended.read_var(var)

  save_spec = saveable_object.SaveSpec(
      tensor=tensor,
      name=name,
      slice_spec="",
      device=primary_var.device,
      dtype=var.dtype)
  return tensor, [save_spec]
|
|
|
|
|
|
def get_on_write_restore_ops(var, tensor):
  """Return restore ops for AUTO and ON_WRITE variables."""
  packed = var._packed_variable  # pylint: disable=protected-access
  if packed is None:
    # One assignment per component variable, each pinned to its own device.
    assigns = tuple(
        assign_on_device(v.device, v, tensor) for v in var.values)
  else:
    # A packed variable is assigned once per device it spans.
    assigns = tuple(
        assign_on_device(d, packed, tensor) for d in packed.devices)
  return control_flow_ops.group(assigns)
|
|
|
|
|
|
def get_on_read_saveable(var, primary_var, name):
  """Return saveables for ON_READ variable."""

  # Defer the cross-replica read until a save actually happens; during a
  # restore this callable is never invoked.
  def tensor():
    return var._get_cross_replica()  # pylint: disable=protected-access

  save_spec = saveable_object.SaveSpec(
      tensor=tensor,
      name=name,
      slice_spec="",
      device=primary_var.device,
      dtype=var.dtype)
  return tensor, [save_spec]
|
|
|
|
|
|
def get_on_read_restore_ops(var, tensor, aggregation):
  """Return restore ops for ON_READ variables."""
  if aggregation == vs.VariableAggregation.SUM:
    # A SUM-aggregated variable was saved as the total over all replicas, so
    # divide the restored value evenly across devices to preserve that sum.
    num_sync = var.distribute_strategy.num_replicas_in_sync
    tensor = math_ops.cast(tensor / num_sync, var.dtype)
  return control_flow_ops.group(
      tuple(assign_on_device(v.device, v, tensor) for v in var.values))
|
|
|
|
|
|
def in_replica_update_context():
  """Whether we are inside an UpdateContext while running in a replica fn."""
  return distribute_lib.get_update_replica_id() is not None
|
|
|
|
|
|
def on_write_assign(var, value, use_locking=False, name=None, read_value=True):
  """Run `assign` on every component of an AUTO/ON_WRITE variable."""

  def _assign(v, *args, **kwargs):
    return v.assign(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_assign,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)
|
|
|
|
|
|
def on_write_assign_add(var, value, use_locking=False, name=None,
                        read_value=True):
  """Run `assign_add` on every component of an AUTO/ON_WRITE variable."""

  def _assign_add(v, *args, **kwargs):
    return v.assign_add(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_assign_add,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)
|
|
|
|
|
|
def on_write_assign_sub(var, value, use_locking=False, name=None,
                        read_value=True):
  """Run `assign_sub` on every component of an AUTO/ON_WRITE variable."""

  def _assign_sub(v, *args, **kwargs):
    return v.assign_sub(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_assign_sub,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)
|
|
|
|
|
|
def assign_on_each_device(var, assign_func, value, read_value):
  """Update the variable on each replica with the given assign_func and value."""
  packed = var._packed_variable  # pylint: disable=protected-access
  if packed is None:
    per_device = tuple(
        assign_func(v.device, v, value) for v in var._values)  # pylint: disable=protected-access
  else:
    per_device = tuple(
        assign_func(d, packed, value) for d in var._devices)  # pylint: disable=protected-access
  update = control_flow_ops.group(per_device)
  if not read_value:
    return update
  # Ensure the read observes every per-device assignment.
  with ops.control_dependencies([update] if update else []):
    return var.read_value()
|
|
|
|
|
|
def on_read_assign_sub_cross_replica(var, value, read_value=True):
  """Subtract `value` from each component of a SyncOnRead variable."""
  with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
    if distribute_lib.in_cross_replica_context():
      # A SUM-aggregated value has no well-defined per-replica decrement.
      if var.aggregation == vs.VariableAggregation.SUM:
        raise ValueError(
            "SyncOnReadVariable does not support `assign_sub` in "
            "cross-replica context when aggregation is set to "
            "`tf.VariableAggregation.SUM`.")
      return assign_on_each_device(var, assign_sub_on_device, value,
                                   read_value)
|
|
|
|
|
|
def on_read_assign_add_cross_replica(var, value, read_value=True):
  """Add `value` to each component of a SyncOnRead variable."""
  with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
    if distribute_lib.in_cross_replica_context():
      # A SUM-aggregated value has no well-defined per-replica increment.
      if var.aggregation == vs.VariableAggregation.SUM:
        raise ValueError(
            "SyncOnReadVariable does not support `assign_add` in "
            "cross-replica context when aggregation is set to "
            "`tf.VariableAggregation.SUM`.")
      return assign_on_each_device(var, assign_add_on_device, value,
                                   read_value)
|
|
|
|
|
|
def on_read_assign_cross_replica(var, value, read_value=True):
  """Return the value of the variable in cross replica context."""
  with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
    if distribute_lib.in_cross_replica_context():
      tensor = value
      if var.aggregation == vs.VariableAggregation.SUM:
        # To preserve the sum across save and restore, we have to divide the
        # total across all devices when restoring a variable that was summed
        # when saving.
        strategy = var._distribute_strategy  # pylint: disable=protected-access
        tensor = math_ops.cast(tensor / strategy.num_replicas_in_sync,
                               var.dtype)
      return assign_on_each_device(var, assign_on_device, tensor, read_value)
|
|
|
|
|
|
def scatter_sub(var, sparse_delta, use_locking=False, name=None):
  """Apply `scatter_sub` to every component of `var`."""

  def _scatter_sub(v, *args, **kwargs):
    return v.scatter_sub(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_sub,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
|
|
|
|
|
|
def scatter_add(var, sparse_delta, use_locking=False, name=None):
  """Apply `scatter_add` to every component of `var`."""

  def _scatter_add(v, *args, **kwargs):
    return v.scatter_add(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_add,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
|
|
|
|
|
|
def scatter_mul(var, sparse_delta, use_locking=False, name=None):
  """Apply `scatter_mul` to every component of `var`."""

  def _scatter_mul(v, *args, **kwargs):
    return v.scatter_mul(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_mul,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
|
|
|
|
|
|
def scatter_div(var, sparse_delta, use_locking=False, name=None):
  """Apply `scatter_div` to every component of `var`."""

  def _scatter_div(v, *args, **kwargs):
    return v.scatter_div(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_div,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
|
|
|
|
|
|
def scatter_min(var, sparse_delta, use_locking=False, name=None):
  """Apply `scatter_min` to every component of `var`."""

  def _scatter_min(v, *args, **kwargs):
    return v.scatter_min(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_min,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
|
|
|
|
|
|
def scatter_max(var, sparse_delta, use_locking=False, name=None):
  """Apply `scatter_max` to every component of `var`."""

  def _scatter_max(v, *args, **kwargs):
    return v.scatter_max(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_max,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
|
|
|
|
|
|
def scatter_update(var, sparse_delta, use_locking=False, name=None):
  """Apply `scatter_update` to every component of `var`."""

  def _scatter_update(v, *args, **kwargs):
    return v.scatter_update(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_update,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
|
|
|
|
|
|
def get_current_replica_id_as_int():
  """Returns the current replica ID as an integer, or `None`."""
  replica_context = distribute_lib.get_replica_context()
  if not replica_context:
    # Not in a replica context; fall back to the update-replica ID (which may
    # itself be None).
    return distribute_lib.get_update_replica_id()
  replica_id = replica_context._replica_id  # pylint: disable=protected-access
  if isinstance(replica_id, int):
    return replica_id
  # The ID is a tensor; fold it to a Python value when statically known.
  return tensor_util.constant_value(replica_id)
|
|
|
|
|
|
def assign_on_device(device, variable, tensor):
  # Run `variable.assign(tensor)` with the op explicitly placed on `device`.
  with ops.device(device):
    return variable.assign(tensor)
|
|
|
|
|
|
def assign_add_on_device(device, variable, tensor):
  # Run `variable.assign_add(tensor)` with the op explicitly placed on `device`.
  with ops.device(device):
    return variable.assign_add(tensor)
|
|
|
|
|
|
def assign_sub_on_device(device, variable, tensor):
  # Run `variable.assign_sub(tensor)` with the op explicitly placed on `device`.
  with ops.device(device):
    return variable.assign_sub(tensor)
|
|
|
|
|
|
def assert_replica_context(strategy):
  """Raise if we are not in a replica context belonging to `strategy`.

  Args:
    strategy: The `tf.distribute.Strategy` the current replica context is
      expected to belong to.

  Raises:
    RuntimeError: If called outside a replica context, or from a replica
      context of a different strategy.
  """
  replica_context = distribute_lib.get_replica_context()
  if not replica_context:
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context.")
  if replica_context.strategy is not strategy:
    # The original code raised the same "replica context" message here even
    # though this branch fires on a strategy mismatch; say what went wrong.
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context "
        "of the strategy that created them.")
|
|
|
|
|
|
def apply_aggregation(strategy, value, aggregation, destinations):
  """Aggregate `value` across replicas and place the result on `destinations`."""
  if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
    # No reduction needed: broadcast replica 0's value to all destinations.
    first_value = strategy.experimental_local_results(value)[0]
    return strategy.extended.broadcast_to(
        first_value, destinations=destinations)
  reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
  return strategy.extended.reduce_to(reduce_op, value, destinations)
|
|
|
|
|
|
# Error message template used when a synchronized {variable_type} is updated
# in replica context without an explicit `aggregation` argument.
# NOTE: the original concatenated the sentences with no separating spaces
# (e.g. "tf.Variable(..).e.g."); punctuation and spacing fixed below.
aggregation_error_msg = (
    "You must specify an aggregation method to update a "
    "{variable_type} in Replica Context. You can do so by passing "
    "an explicit value for argument `aggregation` to tf.Variable(..). "
    "e.g. `tf.Variable(..., aggregation=tf.VariableAggregation.SUM)`. "
    "`tf.VariableAggregation` lists the possible aggregation methods. "
    "This is required because {variable_type} should always be "
    "kept in sync. When updating them or assigning to them in a "
    "replica context, we automatically try to aggregate the values "
    "before updating the variable. For this aggregation, we need to "
    "know the aggregation method. "
    "Another alternative is to not try to update such "
    "{variable_type} in replica context, but in cross replica "
    "context. You can enter cross replica context by calling "
    "`tf.distribute.get_replica_context().merge_call(merge_fn, ..)`. "
    "Inside `merge_fn`, you can then update the {variable_type} "
    "using `tf.distribute.StrategyExtended.update()`.")
|
|
|
|
|
|
# Error message template for scatter_* ops on distributed variables with an
# unsupported aggregation; formatted with `op_name` and `aggregation`.
scatter_error_msg = ("{op_name} is only supported for mirrored "
                     "variable (variable created within certain "
                     "`tf.distribute.Strategy` scope) with NONE or "
                     "`ONLY_FIRST_REPLICA` aggregation, got: {aggregation}.")
|
|
|
|
|
|
def is_saving_non_distributed():
  """Returns whether we're saving a non-distributed version of the model.

  It returns True iff we are in saving context and are saving a
  non-distributed version of the model — that is, the
  `SaveOptions.experimental_variable_policy` does not ask for distributed
  variables to be expanded into their components.

  Returns:
    A boolean.
  """
  if not save_context.in_save_context():
    return False
  policy = save_context.get_save_options().experimental_variable_policy
  return policy != save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES
|
|
|
|
|
|
def mark_as_unsaveable():
  """Marks the function as unsaveable if not inside save context."""
  # NOTE(review): presumably a function traced outside a save context cannot
  # capture distributed variables in a saveable way — confirm with callers.
  # The message below is attached to the graph and surfaced at save time.
  if ops.inside_function() and not save_context.in_save_context():
    ops.get_default_graph().mark_as_unsaveable("""
ConcreteFunction that uses distributed variables in certain way cannot be saved.
If you're saving with

tf.saved_model.save(..., signatures=f.get_concrete_function())

do

@tf.function(input_signature=...)
def f_with_input_signature():
  ...

tf.saved_model.save(..., signatures=f_with_input_signature)`

instead.""")
|