# Extracted from: 3RNN/Lib/site-packages/tensorflow/python/ops/critical_section_ops.py
# (file-viewer metadata: 420 lines, 16 KiB, Python; last change
# 2024-05-26 19:49:15 +02:00)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Critical Section object and execution logic."""
import collections
import contextlib
import threading
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export
# Public names of this module.
__all__ = ["CriticalSection"]
# Graph Keys
# Collection to which each `CriticalSection` adds itself when it is created
# outside of eager execution.
CRITICAL_SECTIONS = "critical_sections"
# Collection to which `CriticalSection.execute()` adds one
# `_ExecutionSignature` per call made outside of eager execution.
CRITICAL_SECTION_EXECUTIONS = "critical_section_executions"
class _ExecutionSignature(
    collections.namedtuple("_ExecutionSignature",
                           ("op", "handle",
                            "resources", "exclusive_resource_access"))):
  """Record of one graph-mode `CriticalSection.execute()` call.

  Fields:
    op: The `Operation` that acquires the mutex lock for the execution.
    handle: Handle of the `CriticalSection` that performed the execution.
    resources: List of resource tensors consumed directly by the ops that
      were created while running the executed function.
    exclusive_resource_access: Whether the execution requested exclusive
      access to `resources`.
  """
def _identity(x):
  """Identity op that recognizes `TensorArray`, `Operation`, and `Tensor`.

  Args:
    x: A `TensorArray`, an `Operation`, a value accepted by
      `array_ops.identity`, or `None` (only handled when executing eagerly).

  Returns:
    `x.identity()` for a `TensorArray`; `control_flow_ops.group(x)` for an
    `Operation`; `None` when executing eagerly and `x` is `None`; otherwise
    `array_ops.identity(x)`.
  """
  if isinstance(x, tensor_array_ops.TensorArray):
    return x.identity()
  if isinstance(x, ops.Operation):
    # An Operation has no value to forward, so wrap it in a group op.
    return control_flow_ops.group(x)
  if context.executing_eagerly() and x is None:
    return None
  return array_ops.identity(x)
def _get_device_or_colocation(op):
  """Return `op.device` when it is truthy, else the colocation attr of `op`.

  Args:
    op: Object exposing a `device` attribute.

  Returns:
    `op.device` if it is non-empty; otherwise whatever `_get_colocation(op)`
    returns, which may be `None`.
  """
  return op.device or _get_colocation(op)
def _get_colocation(op):
  """Get colocation symbol from op, if any.

  Args:
    op: Object queried through `op.get_attr("_class")`.

  Returns:
    The value of the `_class` attr, or `None` when the lookup raises
    `ValueError` or `AttributeError` (for example because `op` has no
    `get_attr` method).
  """
  try:
    return op.get_attr("_class")
  except (ValueError, AttributeError):
    return None
# Thread-local holder; each thread lazily gets its own list stored on the
# `value` attribute (see `_get_critical_section_stack`).
_CRITICAL_SECTION_STACK = threading.local()


def _get_critical_section_stack():
  """Return the calling thread's signature stack, creating it on first use.

  Returns:
    The mutable list stored in `_CRITICAL_SECTION_STACK.value` for the
    current thread. Repeated calls from one thread return the same list
    object; a thread that has not called this before starts with `[]`.
  """
  try:
    return _CRITICAL_SECTION_STACK.value
  except AttributeError:
    # First access from this thread: no `value` has been set yet.
    _CRITICAL_SECTION_STACK.value = []
    return _CRITICAL_SECTION_STACK.value
@contextlib.contextmanager
def _push_critical_section_stack(signature):
  """Push a CriticalSection._signature to the thread-local stack.

  If the signature is already on the stack, raise an error because it means
  we're trying to execute inside the same locked CriticalSection, which
  will create a deadlock.

  Args:
    signature: Tuple of the type `CriticalSection._signature`.  Uniquely
      identifies a CriticalSection by its `shared_name`, `container`,
      and device.

  Yields:
    An empty value.  The context is guaranteed to run without deadlock.

  Raises:
    ValueError: If the signature is already on the stack.
    RuntimeError: If another thread or function modifies the current stack
      entry during the yield.
  """
  stack = _get_critical_section_stack()
  if signature in stack:
    raise ValueError(
        f"Attempting to lock a CriticalSection (signature={signature}) in which"
        " we are already running. This is illegal and may cause deadlocks.")
  stack.append(signature)
  try:
    yield
  finally:
    # Always pop, even when the body raised, then verify that the entry
    # removed is the one pushed above.
    received_signature = stack.pop()
    if received_signature != signature:
      raise RuntimeError(
          "CriticalSection stack inconsistency: expected signature "
          f"{signature} but received {received_signature}")
@tf_export("CriticalSection")
class CriticalSection:
  """Critical section.

  A `CriticalSection` object is a resource in the graph which executes subgraphs
  in **serial** order.  A common example of a subgraph one may wish to run
  exclusively is the one given by the following function:

  ```python
  v = resource_variable_ops.ResourceVariable(0.0, name="v")

  def count():
    value = v.read_value()
    with tf.control_dependencies([value]):
      with tf.control_dependencies([v.assign_add(1)]):
        return tf.identity(value)
  ```

  Here, a snapshot of `v` is captured in `value`; and then `v` is updated.
  The snapshot value is returned.

  If multiple workers or threads all execute `count` in parallel, there is no
  guarantee that access to the variable `v` is atomic at any point within
  any thread's calculation of `count`.  In fact, even implementing an atomic
  counter that guarantees that the user will see each value `0, 1, ...,` is
  currently impossible.

  The solution is to ensure any access to the underlying resource `v` is
  only processed through a critical section:

  ```python
  cs = CriticalSection()
  f1 = cs.execute(count)
  f2 = cs.execute(count)
  output = f1 + f2
  session.run(output)
  ```

  The functions `f1` and `f2` will be executed serially, and updates to `v`
  will be atomic.

  **NOTES**

  All resource objects, including the critical section and any captured
  variables of functions executed on that critical section, will be
  colocated to the same device (host and cpu/gpu).

  When using multiple critical sections on the same resources, there is no
  guarantee of exclusive access to those resources.  This behavior is disallowed
  by default (but see the kwarg `exclusive_resource_access`).

  For example, running the same function in two separate critical sections
  will not ensure serial execution:

  ```python
  v = tf.compat.v1.get_variable("v", initializer=0.0, use_resource=True)

  def accumulate(up):
    x = v.read_value()
    with tf.control_dependencies([x]):
      with tf.control_dependencies([v.assign_add(up)]):
        return tf.identity(x)

  ex1 = CriticalSection().execute(
      lambda: accumulate(1.0), exclusive_resource_access=False)
  ex2 = CriticalSection().execute(
      lambda: accumulate(1.0), exclusive_resource_access=False)
  bad_sum = ex1 + ex2
  sess.run(v.initializer)
  sess.run(bad_sum)  # May return 0.0
  ```
  """

  def __init__(self, name=None, shared_name=None,
               critical_section_def=None, import_scope=None):
    """Creates a critical section.

    Args:
      name: Optional name scope for the underlying mutex op.
      shared_name: Optional shared name of the mutex.  Defaults to the
        resolved name scope when `None`.
      critical_section_def: Not supported; any truthy value raises
        `ValueError`.
      import_scope: Accepted for signature compatibility; not used by this
        constructor.

    Raises:
      ValueError: If `critical_section_def` is truthy (with a dedicated
        message when `name` is also given).
    """
    context.ensure_initialized()
    if critical_section_def and name is not None:
      # Report the argument that is actually tested by the condition.
      raise ValueError(f"Arguments critical_section_def={critical_section_def} "
                       f"and name={name} are mutually exclusive. "
                       "Please only specify one of them.")
    if critical_section_def:
      raise ValueError("Argument `critical_section_def` is not supported.")
    self._init_from_args(name, shared_name)

  def _init_from_args(self, name, shared_name):  # pylint: disable=invalid-name
    """Initialize the CriticalSection from constructor arguments."""
    with ops.name_scope(name, "CriticalSection", []) as name:
      with ops.init_scope():
        # pylint: disable=protected-access
        container = ops.get_default_graph()._container
        # pylint: enable=protected-access
      if shared_name is None:
        shared_name = name
      if container is None:
        container = ""
      self._handle = gen_resource_variable_ops.mutex_v2(
          shared_name=shared_name, container=container, name=name)
      # Get a uniquely identifying signature for the handle.
      self._signature = (
          container,
          # If shared_name is empty, a unique CriticalSection is created.
          shared_name or id(self._handle),
          _get_device_or_colocation(self._handle))

    if not context.executing_eagerly():
      ops.add_to_collections(CRITICAL_SECTIONS, self)

  @property
  def name(self):
    """Name of the op that produces this critical section's handle."""
    return self._handle.op.name

  def execute(self, fn, exclusive_resource_access=True, name=None):
    """Execute function `fn()` inside the critical section.

    `fn` should not accept any arguments.  To add extra arguments to when
    calling `fn` in the critical section, create a lambda:

    ```python
    critical_section.execute(lambda: fn(*my_args, **my_kwargs))
    ```

    Args:
      fn: The function to execute.  Must return at least one tensor.
      exclusive_resource_access: Whether the resources required by
        `fn` should be exclusive to this `CriticalSection`.  Default: `True`.
        You may want to set this to `False` if you will be accessing a
        resource in read-only mode in two different CriticalSections.
      name: The name to use when creating the execute operation.

    Returns:
      The tensors returned from `fn()`.

    Raises:
      ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
        or lazy way that may cause a deadlock.
      ValueError: If `exclusive_resource_access == True` and
        another `CriticalSection` has an execution requesting the same
        resources as `fn`.  Note, even if `exclusive_resource_access` is
        `True`, if another execution in another `CriticalSection` was created
        without `exclusive_resource_access=True`, a `ValueError` will be raised.
    """
    with ops.name_scope(name, "critical_section_execute", []):
      # Ensure that mutex locking only happens *after* all args and
      # kwargs have been executed.  This avoids certain types of deadlocks.
      with _push_critical_section_stack(self._signature):
        lock = gen_resource_variable_ops.mutex_lock(self._handle)

        if not context.executing_eagerly():
          # NOTE(ebrevdo): This is to ensure we don't pick up spurious
          # Operations created by other threads.
          with ops.get_default_graph()._lock:  # pylint: disable=protected-access
            existing_ops = ops.get_default_graph().get_operations()
            with ops.control_dependencies([lock]):
              r = fn()
            # TODO(ebrevdo): If creating critical sections in a python loop,
            # this makes graph creation time quadratic.  Revisit if this
            # becomes a problem.
            created_ops = (set(ops.get_default_graph().get_operations())
                           .difference(existing_ops))
        else:
          with ops.control_dependencies([lock]):
            r = fn()

      if not context.executing_eagerly():
        self._add_control_dependencies_to_lock(created_ops, lock.op)

        # captured_resources is a list of resources that are directly
        # accessed only by ops created during fn(), not by any
        # ancestors of those ops in the graph.
        captured_resources = object_identity.ObjectIdentitySet([
            input_ for op in created_ops
            for input_ in op.inputs
            if input_.dtype == dtypes.resource
        ])

        # NOTE(ebrevdo): The only time self._is_self_handle() is True
        # in this call is if one of the recently created ops, within
        # the execute(), themselves attempt to access the
        # CriticalSection.  This will cause a deadlock.
        if any(self._is_self_handle(x) for x in captured_resources):
          raise ValueError(
              "Attempting to lock a CriticalSection in which we are "
              f"already running (signature={self._signature}). This is illegal "
              "and may cause deadlocks.")

        self._check_multiple_access_to_resources(
            captured_resources, exclusive_resource_access)

      r_flat = [_identity(x) for x in nest.flatten(r)]

      with ops.control_dependencies(r_flat):
        # The identity must run on the same machine as self._handle
        with ops.colocate_with(self._handle):
          # Do not use array_ops.identity as there are special
          # optimizations within TensorFlow which seem to elide it
          # even when optimizations are disabled(!).
          ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
              lock)

        # Make sure that if any element of r is accessed, all of
        # them are executed together.
        r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r)))

      with ops.control_dependencies([ensure_lock_exists]):
        outputs = nest.map_structure(_identity, r)

      if not context.executing_eagerly():
        # Record this execution so that later executions in other
        # critical sections can be checked against its resources.
        signature = _ExecutionSignature(
            op=lock.op,
            handle=self._handle,
            resources=list(captured_resources),
            exclusive_resource_access=exclusive_resource_access)
        ops.add_to_collections(
            CRITICAL_SECTION_EXECUTIONS, signature)

      return outputs

  def _add_control_dependencies_to_lock(self, created_ops, lock_op):
    """To avoid deadlocks, all args must be executed before lock_op.

    Args:
      created_ops: Collection of operations created while running `fn()`.
      lock_op: The lock operation that must run after every external input
        of `created_ops`.
    """
    # Get all arguments (explicit and captured) of all ops created by fn().
    all_args = {input_.op for op in created_ops for input_ in op.inputs}
    all_args.update(
        input_op for op in created_ops for input_op in op.control_inputs)
    # Unfortunately, we can't use sets throughout because TF seems to
    # create new Operation objects for the same op sometimes; and we
    # can't rely on id(op).

    # pylint: disable=protected-access
    all_args_dict = {op._id: op for op in all_args}

    # Remove ops created within fn, or that lock_op already has a
    # control dependency on.  Also remove a possible self-loop.
    for op in created_ops:
      all_args_dict.pop(op._id, None)
    for op in lock_op.control_inputs:
      all_args_dict.pop(op._id, None)
    for input_ in lock_op.inputs:
      all_args_dict.pop(input_.op._id, None)
    all_args_dict.pop(lock_op._id, None)

    all_args = all_args_dict.values()

    if not all_args:
      # No control dependencies to add; return early.
      return

    # This group is important: it ensures that any ops in all_args
    # outside the control context of the lock_op (and this fn, which
    # runs in the same context) are added to this context before
    # being added to the control dependencies of lock_op.
    all_args = control_flow_ops.group(*all_args)

    lock_op._add_control_input(all_args)
    # pylint: enable=protected-access

  def _is_self_handle(self, x):
    """Check if the tensor `x` is the same Mutex as `self._handle`.

    For an `EagerTensor` the check is object identity.  Otherwise the
    result is truthy only if `x` is produced by a `MutexV2` op with a
    non-blank `shared_name` equal to this handle's, and either the devices
    or the colocation attrs of the two ops match.
    """
    if isinstance(x, ops.EagerTensor):
      return x is self._handle
    return (x.op.type == "MutexV2"
            # blank shared_name means the op will create a unique one.
            and x.op.get_attr("shared_name")
            and (x.op.get_attr("shared_name") ==
                 self._handle.op.get_attr("shared_name"))
            and (x.op.device == self._handle.op.device
                 or _get_colocation(x.op) == _get_colocation(self._handle.op)))

  def _check_multiple_access_to_resources(
      self, captured_resources, exclusive_resource_access):
    """Raise if captured_resources are accessed by another CriticalSection.

    Args:
      captured_resources: Set of tensors of type resource.
      exclusive_resource_access: Whether this execution requires exclusive
        resource access.

    Raises:
      ValueError: If any tensors in `captured_resources` are also accessed
        by another `CriticalSection`, and at least one of them requires
        exclusive resource access.
    """
    # Collections and op introspection does not work in eager
    # mode.  This is generally ok; since eager mode (as of
    # writing) executes sequentially anyway.
    for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
      if self._is_self_handle(sg.handle):
        # Other executions in the same critical section are allowed.
        continue
      if not (exclusive_resource_access or sg.exclusive_resource_access):
        # Neither execution requested exclusive access.
        continue
      resource_intersection = captured_resources.intersection(sg.resources)
      if resource_intersection:
        raise ValueError(
            "This execution would access resources: "
            f"{list(resource_intersection)}. Either this lock "
            f"(CriticalSection: {self._handle}) or lock '{sg}' "
            f"(CriticalSection: {sg.handle}) requested exclusive resource "
            "access of this resource. Did you mean to call execute with "
            "keyword argument exclusive_resource_access=False?")