420 lines
16 KiB
Python
420 lines
16 KiB
Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Critical Section object and execution logic."""
|
|
|
|
import collections
|
|
import contextlib
|
|
import threading
|
|
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import control_flow_ops
|
|
from tensorflow.python.ops import gen_resource_variable_ops
|
|
from tensorflow.python.ops import tensor_array_ops
|
|
from tensorflow.python.util import nest
|
|
from tensorflow.python.util import object_identity
|
|
from tensorflow.python.util.tf_export import tf_export
|
|
|
|
|
|
__all__ = ["CriticalSection"]
|
|
|
|
|
|
# Graph Keys
|
|
CRITICAL_SECTIONS = "critical_sections"
|
|
CRITICAL_SECTION_EXECUTIONS = "critical_section_executions"
|
|
|
|
|
|
class _ExecutionSignature(
|
|
collections.namedtuple("_ExecutionSignature",
|
|
("op", "handle",
|
|
"resources", "exclusive_resource_access"))):
|
|
"""A class storing an `ExecuteInCriticalResource` op and associated attrs."""
|
|
pass
|
|
|
|
|
|
def _identity(x):
|
|
"""Identity op that recognizes `TensorArray`, `Operation`, and `Tensor`."""
|
|
if isinstance(x, tensor_array_ops.TensorArray):
|
|
return x.identity()
|
|
elif isinstance(x, ops.Operation):
|
|
return control_flow_ops.group(x)
|
|
elif context.executing_eagerly() and x is None:
|
|
return None
|
|
else:
|
|
return array_ops.identity(x)
|
|
|
|
|
|
def _get_device_or_colocation(op):
|
|
return op.device or _get_colocation(op)
|
|
|
|
|
|
def _get_colocation(op):
|
|
"""Get colocation symbol from op, if any."""
|
|
try:
|
|
return op.get_attr("_class")
|
|
except (ValueError, AttributeError):
|
|
return None
|
|
|
|
|
|
_CRITICAL_SECTION_STACK = threading.local()
|
|
|
|
|
|
def _get_critical_section_stack():
|
|
try:
|
|
return _CRITICAL_SECTION_STACK.value
|
|
except AttributeError:
|
|
_CRITICAL_SECTION_STACK.value = []
|
|
return _CRITICAL_SECTION_STACK.value
|
|
|
|
|
|
@contextlib.contextmanager
|
|
def _push_critical_section_stack(signature):
|
|
"""Push a CriticalSection._signature to the thread-local stack.
|
|
|
|
If the signature is already on the stack, raise an error because it means
|
|
we're trying to execute inside the same locked CriticalSection, which
|
|
will create a deadlock.
|
|
|
|
Args:
|
|
signature: Tuple of the type `CriticalSection._signature`. Uniquely
|
|
identifies a CriticalSection by its `shared_name`, `container`,
|
|
and device.
|
|
|
|
Yields:
|
|
An empty value. The context is guaranteed to run without deadlock.
|
|
|
|
Raises:
|
|
ValueError: If the signature is already on the stack.
|
|
RuntimeError: If another thread or function modifies the current stack
|
|
entry during the yield.
|
|
"""
|
|
stack = _get_critical_section_stack()
|
|
if signature in stack:
|
|
raise ValueError(
|
|
f"Attempting to lock a CriticalSection (signature={signature}) in which"
|
|
" we are already running. This is illegal and may cause deadlocks.")
|
|
stack.append(signature)
|
|
try:
|
|
yield
|
|
finally:
|
|
received_signature = stack.pop()
|
|
if received_signature != signature:
|
|
raise RuntimeError(
|
|
"CriticalSection stack inconsistency: expected signature "
|
|
f"{signature} but received {received_signature}")
|
|
|
|
|
|
@tf_export("CriticalSection")
|
|
class CriticalSection:
|
|
"""Critical section.
|
|
|
|
A `CriticalSection` object is a resource in the graph which executes subgraphs
|
|
in **serial** order. A common example of a subgraph one may wish to run
|
|
exclusively is the one given by the following function:
|
|
|
|
```python
|
|
v = resource_variable_ops.ResourceVariable(0.0, name="v")
|
|
|
|
def count():
|
|
value = v.read_value()
|
|
with tf.control_dependencies([value]):
|
|
with tf.control_dependencies([v.assign_add(1)]):
|
|
return tf.identity(value)
|
|
```
|
|
|
|
Here, a snapshot of `v` is captured in `value`; and then `v` is updated.
|
|
The snapshot value is returned.
|
|
|
|
If multiple workers or threads all execute `count` in parallel, there is no
|
|
guarantee that access to the variable `v` is atomic at any point within
|
|
any thread's calculation of `count`. In fact, even implementing an atomic
|
|
counter that guarantees that the user will see each value `0, 1, ...,` is
|
|
currently impossible.
|
|
|
|
The solution is to ensure any access to the underlying resource `v` is
|
|
only processed through a critical section:
|
|
|
|
```python
|
|
cs = CriticalSection()
|
|
f1 = cs.execute(count)
|
|
f2 = cs.execute(count)
|
|
output = f1 + f2
|
|
session.run(output)
|
|
```
|
|
The functions `f1` and `f2` will be executed serially, and updates to `v`
|
|
will be atomic.
|
|
|
|
**NOTES**
|
|
|
|
All resource objects, including the critical section and any captured
|
|
variables of functions executed on that critical section, will be
|
|
colocated to the same device (host and cpu/gpu).
|
|
|
|
When using multiple critical sections on the same resources, there is no
|
|
guarantee of exclusive access to those resources. This behavior is disallowed
|
|
by default (but see the kwarg `exclusive_resource_access`).
|
|
|
|
For example, running the same function in two separate critical sections
|
|
will not ensure serial execution:
|
|
|
|
```python
|
|
v = tf.compat.v1.get_variable("v", initializer=0.0, use_resource=True)
|
|
def accumulate(up):
|
|
x = v.read_value()
|
|
with tf.control_dependencies([x]):
|
|
with tf.control_dependencies([v.assign_add(up)]):
|
|
return tf.identity(x)
|
|
ex1 = CriticalSection().execute(
|
|
accumulate, 1.0, exclusive_resource_access=False)
|
|
ex2 = CriticalSection().execute(
|
|
accumulate, 1.0, exclusive_resource_access=False)
|
|
bad_sum = ex1 + ex2
|
|
sess.run(v.initializer)
|
|
sess.run(bad_sum) # May return 0.0
|
|
```
|
|
"""
|
|
|
|
def __init__(self, name=None, shared_name=None,
|
|
critical_section_def=None, import_scope=None):
|
|
"""Creates a critical section."""
|
|
context.ensure_initialized()
|
|
if critical_section_def and name is not None:
|
|
raise ValueError(f"Arguments critical_section_def={critical_section_def} "
|
|
f"and shared_name={shared_name} are mutually exclusive. "
|
|
"Please only specify one of them.")
|
|
if critical_section_def:
|
|
raise ValueError("Argument `critical_section_def` is not supported.")
|
|
else:
|
|
self._init_from_args(name, shared_name)
|
|
|
|
def _init_from_args(self, name, shared_name): # pylint: disable=invalid-name
|
|
"""Initialize the CriticalSection from constructor arguments."""
|
|
with ops.name_scope(name, "CriticalSection", []) as name:
|
|
with ops.init_scope():
|
|
# pylint: disable=protected-access
|
|
container = ops.get_default_graph()._container
|
|
# pylint: enable=protected-access
|
|
if shared_name is None:
|
|
shared_name = name
|
|
if container is None:
|
|
container = ""
|
|
self._handle = gen_resource_variable_ops.mutex_v2(
|
|
shared_name=shared_name, container=container, name=name)
|
|
# Get a uniquely identifying signature for the handle.
|
|
self._signature = (
|
|
container,
|
|
# If shared_name is empty, a unique CriticalSection is created.
|
|
shared_name or id(self._handle),
|
|
_get_device_or_colocation(self._handle))
|
|
|
|
if not context.executing_eagerly():
|
|
ops.add_to_collections(CRITICAL_SECTIONS, self)
|
|
|
|
@property
|
|
def name(self):
|
|
return self._handle.op.name
|
|
|
|
def execute(self, fn, exclusive_resource_access=True, name=None):
|
|
"""Execute function `fn()` inside the critical section.
|
|
|
|
`fn` should not accept any arguments. To add extra arguments to when
|
|
calling `fn` in the critical section, create a lambda:
|
|
|
|
```python
|
|
critical_section.execute(lambda: fn(*my_args, **my_kwargs))
|
|
```
|
|
|
|
Args:
|
|
fn: The function to execute. Must return at least one tensor.
|
|
exclusive_resource_access: Whether the resources required by
|
|
`fn` should be exclusive to this `CriticalSection`. Default: `True`.
|
|
You may want to set this to `False` if you will be accessing a
|
|
resource in read-only mode in two different CriticalSections.
|
|
name: The name to use when creating the execute operation.
|
|
|
|
Returns:
|
|
The tensors returned from `fn()`.
|
|
|
|
Raises:
|
|
ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
|
|
or lazy way that may cause a deadlock.
|
|
ValueError: If `exclusive_resource_access == True` and
|
|
another `CriticalSection` has an execution requesting the same
|
|
resources as `fn``. Note, even if `exclusive_resource_access` is
|
|
`True`, if another execution in another `CriticalSection` was created
|
|
without `exclusive_resource_access=True`, a `ValueError` will be raised.
|
|
"""
|
|
with ops.name_scope(name, "critical_section_execute", []):
|
|
# Ensure that mutex locking only happens *after* all args and
|
|
# kwargs have been executed. This avoids certain types of deadlocks.
|
|
with _push_critical_section_stack(self._signature):
|
|
lock = gen_resource_variable_ops.mutex_lock(self._handle)
|
|
|
|
if not context.executing_eagerly():
|
|
# NOTE(ebrevdo): This is to ensure we don't pick up spurious
|
|
# Operations created by other threads.
|
|
with ops.get_default_graph()._lock: # pylint: disable=protected-access
|
|
existing_ops = ops.get_default_graph().get_operations()
|
|
with ops.control_dependencies([lock]):
|
|
r = fn()
|
|
# TODO(ebrevdo): If creating critical sections in a python loop,
|
|
# this makes graph creation time quadratic. Revisit if this
|
|
# becomes a problem.
|
|
created_ops = (set(ops.get_default_graph().get_operations())
|
|
.difference(existing_ops))
|
|
else:
|
|
with ops.control_dependencies([lock]):
|
|
r = fn()
|
|
|
|
if not context.executing_eagerly():
|
|
self._add_control_dependencies_to_lock(created_ops, lock.op)
|
|
|
|
# captured_resources is a list of resources that are directly
|
|
# accessed only by ops created during fn(), not by any
|
|
# ancestors of those ops in the graph.
|
|
captured_resources = object_identity.ObjectIdentitySet([
|
|
input_ for op in created_ops
|
|
for input_ in op.inputs
|
|
if input_.dtype == dtypes.resource
|
|
])
|
|
|
|
# NOTE(ebrevdo): The only time self._is_self_handle() is True
|
|
# in this call is if one of the recently created ops, within
|
|
# the execute(), themselves attempt to access the
|
|
# CriticalSection. This will cause a deadlock.
|
|
if any(self._is_self_handle(x) for x in captured_resources):
|
|
raise ValueError(
|
|
"Attempting to lock a CriticalSection in which we are "
|
|
f"already running (signature={self._signature}). This is illegal "
|
|
"and may cause deadlocks.")
|
|
|
|
self._check_multiple_access_to_resources(
|
|
captured_resources, exclusive_resource_access)
|
|
|
|
r_flat = [_identity(x) for x in nest.flatten(r)]
|
|
|
|
with ops.control_dependencies(r_flat):
|
|
# The identity must run on the same machine as self._handle
|
|
with ops.colocate_with(self._handle):
|
|
# Do not use array_ops.identity as there are special
|
|
# optimizations within TensorFlow which seem to elide it
|
|
# even when optimizations are disabled(!).
|
|
ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
|
|
lock)
|
|
|
|
# Make sure that if any element of r is accessed, all of
|
|
# them are executed together.
|
|
r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r)))
|
|
|
|
with ops.control_dependencies([ensure_lock_exists]):
|
|
outputs = nest.map_structure(_identity, r)
|
|
|
|
if not context.executing_eagerly():
|
|
signature = _ExecutionSignature(
|
|
op=lock.op,
|
|
handle=self._handle,
|
|
resources=list(captured_resources),
|
|
exclusive_resource_access=exclusive_resource_access)
|
|
ops.add_to_collections(
|
|
CRITICAL_SECTION_EXECUTIONS, signature)
|
|
|
|
return outputs
|
|
|
|
def _add_control_dependencies_to_lock(self, created_ops, lock_op):
|
|
"""To avoid deadlocks, all args must be executed before lock_op."""
|
|
# Get all arguments (explicit and captured) of all ops created by fn().
|
|
all_args = set([input_.op for op in created_ops for input_ in op.inputs])
|
|
all_args.update(
|
|
input_op for op in created_ops for input_op in op.control_inputs)
|
|
# Unfortunately, we can't use sets throughout because TF seems to
|
|
# create new Operation objects for the same op sometimes; and we
|
|
# can't rely on id(op).
|
|
|
|
# pylint: disable=protected-access
|
|
all_args_dict = dict((op._id, op) for op in all_args)
|
|
|
|
# Remove ops created within fn, or that lock_op already has a
|
|
# control dependency on. Also remove a possible self-loop.
|
|
for op in created_ops:
|
|
all_args_dict.pop(op._id, None)
|
|
for op in lock_op.control_inputs:
|
|
all_args_dict.pop(op._id, None)
|
|
for input_ in lock_op.inputs:
|
|
all_args_dict.pop(input_.op._id, None)
|
|
all_args_dict.pop(lock_op._id, None)
|
|
|
|
all_args = all_args_dict.values()
|
|
|
|
if not all_args:
|
|
# No control dependencies to add; return early.
|
|
return
|
|
|
|
# This group is important: it ensures that any ops in all_args
|
|
# outside the control context of the lock_op (and this fn, which
|
|
# runs in the same context) are added to this context before
|
|
# being added to the control dependencies of lock_op.
|
|
all_args = control_flow_ops.group(*all_args)
|
|
|
|
lock_op._add_control_input(all_args)
|
|
# pylint: enable=protected-access
|
|
|
|
def _is_self_handle(self, x):
|
|
"""Check if the tensor `x` is the same Mutex as `self._handle`."""
|
|
if isinstance(x, ops.EagerTensor):
|
|
return x is self._handle
|
|
return (x.op.type == "MutexV2"
|
|
# blank shared_name means the op will create a unique one.
|
|
and x.op.get_attr("shared_name")
|
|
and (x.op.get_attr("shared_name") ==
|
|
self._handle.op.get_attr("shared_name"))
|
|
and (x.op.device == self._handle.op.device
|
|
or _get_colocation(x.op) == _get_colocation(self._handle.op)))
|
|
|
|
def _check_multiple_access_to_resources(
|
|
self, captured_resources, exclusive_resource_access):
|
|
"""Raise if captured_resources are accessed by another CriticalSection.
|
|
|
|
Args:
|
|
captured_resources: Set of tensors of type resource.
|
|
exclusive_resource_access: Whether this execution requires exclusive
|
|
resource access.
|
|
|
|
Raises:
|
|
ValueError: If any tensors in `captured_resources` are also accessed
|
|
by another `CriticalSection`, and at least one of them requires
|
|
exclusive resource access.
|
|
"""
|
|
# Collections and op introspection does not work in eager
|
|
# mode. This is generally ok; since eager mode (as of
|
|
# writing) executes sequentially anyway.
|
|
for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
|
|
if self._is_self_handle(sg.handle):
|
|
# Other executions in the same critical section are allowed.
|
|
continue
|
|
if not (exclusive_resource_access or sg.exclusive_resource_access):
|
|
# Neither execution requested exclusive access.
|
|
continue
|
|
resource_intersection = captured_resources.intersection(sg.resources)
|
|
if resource_intersection:
|
|
raise ValueError(
|
|
"This execution would access resources: "
|
|
f"{list(resource_intersection)}. Either this lock "
|
|
f"(CriticalSection: {self._handle}) or lock '{sg}' "
|
|
f"(CriticalSection: {sg.handle}) requested exclusive resource "
|
|
"access of this resource. Did you mean to call execute with "
|
|
"keyword argument exclusive_resource_access=False?")
|