508 lines
18 KiB
Python
508 lines
18 KiB
Python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Coordinator to help multiple threads stop when requested."""
|
|
import contextlib
|
|
import sys
|
|
import threading
|
|
import time
|
|
|
|
from tensorflow.python.framework import errors
|
|
from tensorflow.python.platform import tf_logging as logging
|
|
from tensorflow.python.util import compat
|
|
from tensorflow.python.util.tf_export import tf_export
|
|
|
|
|
|
@tf_export("train.Coordinator")
class Coordinator:
  """A coordinator for threads.

  This class implements a simple mechanism to coordinate the termination of a
  set of threads.

  #### Usage:

  ```python
  # Create a coordinator.
  coord = Coordinator()
  # Start a number of threads, passing the coordinator to each of them.
  ...start thread 1...(coord, ...)
  ...start thread N...(coord, ...)
  # Wait for all the threads to terminate.
  coord.join(threads)
  ```

  Any of the threads can call `coord.request_stop()` to ask for all the threads
  to stop.  To cooperate with the requests, each thread must check for
  `coord.should_stop()` on a regular basis.  `coord.should_stop()` returns
  `True` as soon as `coord.request_stop()` has been called.

  A typical thread running with a coordinator will do something like:

  ```python
  while not coord.should_stop():
    ...do some work...
  ```

  #### Exception handling:

  A thread can report an exception to the coordinator as part of the
  `request_stop()` call.  The exception will be re-raised from the
  `coord.join()` call.

  Thread code:

  ```python
  try:
    while not coord.should_stop():
      ...do some work...
  except Exception as e:
    coord.request_stop(e)
  ```

  Main code:

  ```python
  try:
    ...
    coord = Coordinator()
    # Start a number of threads, passing the coordinator to each of them.
    ...start thread 1...(coord, ...)
    ...start thread N...(coord, ...)
    # Wait for all the threads to terminate.
    coord.join(threads)
  except Exception as e:
    ...exception that was passed to coord.request_stop()
  ```

  To simplify the thread implementation, the Coordinator provides a
  context handler `stop_on_exception()` that automatically requests a stop if
  an exception is raised.  Using the context handler the thread code above
  can be written as:

  ```python
  with coord.stop_on_exception():
    while not coord.should_stop():
      ...do some work...
  ```

  #### Grace period for stopping:

  After a thread has called `coord.request_stop()` the other threads have a
  fixed time to stop, this is called the 'stop grace period' and defaults to 2
  minutes.  If any of the threads is still alive after the grace period expires
  `coord.join()` raises a RuntimeError reporting the laggards.

  ```python
  try:
    ...
    coord = Coordinator()
    # Start a number of threads, passing the coordinator to each of them.
    ...start thread 1...(coord, ...)
    ...start thread N...(coord, ...)
    # Wait for all the threads to terminate, give them 10s grace period
    coord.join(threads, stop_grace_period_secs=10)
  except RuntimeError:
    ...one of the threads took more than 10s to stop after request_stop()
    ...was called.
  except Exception:
    ...exception that was passed to coord.request_stop()
  ```
  """

  def __init__(self, clean_stop_exception_types=None):
    """Create a new Coordinator.

    Args:
      clean_stop_exception_types: Optional tuple of Exception types that should
        cause a clean stop of the coordinator. If an exception of one of these
        types is reported to `request_stop(ex)` the coordinator will behave as
        if `request_stop(None)` was called.  Defaults to
        `(tf.errors.OutOfRangeError,)` which is used by input queues to signal
        the end of input. When feeding training data from a Python iterator it
        is common to add `StopIteration` to this list.
    """
    if clean_stop_exception_types is None:
      clean_stop_exception_types = (errors.OutOfRangeError,)
    self._clean_stop_exception_types = tuple(clean_stop_exception_types)
    # Protects all attributes.
    self._lock = threading.Lock()
    # Event set when threads must stop.
    self._stop_event = threading.Event()
    # Python exc_info to report.
    # If not None, it should hold the returned value of sys.exc_info(), which is
    # a tuple containing exception (type, value, traceback).
    self._exc_info_to_raise = None
    # True if we have called join() already.
    self._joined = False
    # Set of threads registered for joining when join() is called.  These
    # threads will be joined in addition to the threads passed to the join()
    # call.  It's ok if threads are both registered and passed to the join()
    # call.
    self._registered_threads = set()

  def _filter_exception(self, ex):
    """Check if the exception indicated in 'ex' should be ignored.

    This method examines `ex` to check if it is an exception that should be
    reported to the users.  If yes, it returns `ex` as is, otherwise it returns
    None.

    The code returns None for exception types listed in
    `_clean_stop_exception_types`.

    Args:
      ex: None, an `Exception`, or a Python `exc_info` tuple as returned by
        `sys.exc_info()`.

    Returns:
      ex or None.
    """
    # For an exc_info tuple the exception instance is the second element.
    if isinstance(ex, tuple):
      ex2 = ex[1]
    else:
      ex2 = ex
    if isinstance(ex2, self._clean_stop_exception_types):
      # Ignore the exception.
      ex = None
    return ex

  def request_stop(self, ex=None):
    """Request that the threads stop.

    After this is called, calls to `should_stop()` will return `True`.

    Note: If an exception is being passed in, it must be in the context of
    handling the exception (i.e. `try: ... except Exception as ex: ...`) and not
    a newly created one.  Otherwise a `ValueError` describing the problem is
    recorded in its place and re-raised from `join()`.

    Args:
      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
        `sys.exc_info()`.  If this is the first call to `request_stop()` the
        corresponding exception is recorded and re-raised from `join()`.

    Raises:
      Exception: The exception described by `ex` is raised directly if the
        coordinator has already been joined, since `join()` can no longer
        report it.
    """
    with self._lock:
      ex = self._filter_exception(ex)
      # If we have already joined the coordinator the exception will not have a
      # chance to be reported, so just raise it normally.  This can happen if
      # you continue to use a session after having stopped and joined the
      # coordinator threads.
      if self._joined:
        if isinstance(ex, tuple):
          _, ex_instance, _ = ex
          raise ex_instance
        elif ex is not None:
          # Raise the reported instance itself rather than whatever
          # sys.exc_info() currently holds: outside of an `except` block that
          # would be None, and `raise None` fails with an unrelated TypeError.
          raise ex
      if not self._stop_event.is_set():
        if ex and self._exc_info_to_raise is None:
          if isinstance(ex, tuple):
            logging.info("Error reported to Coordinator: %s",
                         compat.as_str_any(ex[1]),
                         exc_info=ex)
            self._exc_info_to_raise = ex
          else:
            logging.info("Error reported to Coordinator: %s, %s",
                         type(ex),
                         compat.as_str_any(ex))
            self._exc_info_to_raise = sys.exc_info()
          # self._exc_info_to_raise should contain a tuple containing exception
          # (type, value, traceback)
          if (len(self._exc_info_to_raise) != 3 or
              not self._exc_info_to_raise[0] or
              not self._exc_info_to_raise[1]):
            # Raise, catch and record the exception here so that error happens
            # where expected.
            try:
              # The offending value is itself a tuple, so it must be wrapped in
              # a 1-tuple: passing it bare to `%` would raise TypeError ("not
              # all arguments converted"), which the handler below does not
              # catch, leaving the stop event unset.
              raise ValueError(
                  "ex must be a tuple or sys.exc_info must return the current "
                  "exception: %s"
                  % (self._exc_info_to_raise,))
            except ValueError:
              # Record this error so it kills the coordinator properly.
              # NOTE(touts): As above, this is bogus if request_stop() is not
              # called from the exception handler that raised ex.
              self._exc_info_to_raise = sys.exc_info()

        self._stop_event.set()

  def clear_stop(self):
    """Clears the stop flag.

    After this is called, calls to `should_stop()` will return `False`.  The
    joined state and any recorded exception are reset as well.
    """
    with self._lock:
      self._joined = False
      self._exc_info_to_raise = None
      if self._stop_event.is_set():
        self._stop_event.clear()

  def should_stop(self):
    """Check if stop was requested.

    Returns:
      True if a stop was requested.
    """
    return self._stop_event.is_set()

  @contextlib.contextmanager
  def stop_on_exception(self):
    """Context manager to request stop when an Exception is raised.

    Code that uses a coordinator must catch exceptions and pass
    them to the `request_stop()` method to stop the other threads
    managed by the coordinator.

    This context handler simplifies the exception handling.
    Use it as follows:

    ```python
    with coord.stop_on_exception():
      # Any exception raised in the body of the with
      # clause is reported to the coordinator before terminating
      # the execution of the body.
      ...body...
    ```

    This is completely equivalent to the slightly longer code:

    ```python
    try:
      ...body...
    except:
      coord.request_stop(sys.exc_info())
    ```

    Yields:
      nothing.
    """
    try:
      yield
    except:  # pylint: disable=bare-except
      # Deliberately bare: every exception raised by the body is handed to the
      # coordinator instead of propagating out of the `with` statement.
      self.request_stop(ex=sys.exc_info())

  def wait_for_stop(self, timeout=None):
    """Wait till the Coordinator is told to stop.

    Args:
      timeout: Float.  Sleep for up to that many seconds waiting for
        should_stop() to become True.

    Returns:
      True if the Coordinator is told to stop, False if the timeout expired.
    """
    return self._stop_event.wait(timeout)

  def register_thread(self, thread):
    """Register a thread to join.

    Args:
      thread: A Python thread to join.
    """
    with self._lock:
      self._registered_threads.add(thread)

  def join(self, threads=None, stop_grace_period_secs=120,
           ignore_live_threads=False):
    """Wait for threads to terminate.

    This call blocks until a set of threads have terminated.  The set of thread
    is the union of the threads passed in the `threads` argument and the list
    of threads that registered with the coordinator by calling
    `Coordinator.register_thread()`.

    After the threads stop, if an `exc_info` was passed to `request_stop`, that
    exception is re-raised.

    Grace period handling: When `request_stop()` is called, threads are given
    'stop_grace_period_secs' seconds to terminate.  If any of them is still
    alive after that period expires, a `RuntimeError` is raised.  Note that if
    an `exc_info` was passed to `request_stop()` then it is raised instead of
    that `RuntimeError`.

    Args:
      threads: List of `threading.Threads`. The started threads to join in
        addition to the registered threads.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `request_stop()` has been called.
      ignore_live_threads: If `False`, raises an error if any of the threads are
        still alive after `stop_grace_period_secs`.

    Raises:
      RuntimeError: If any thread is still alive after `request_stop()`
        is called and the grace period expires.
    """
    # Threads registered after this call will not be joined.
    with self._lock:
      if threads is None:
        threads = self._registered_threads
      else:
        threads = self._registered_threads.union(set(threads))
      # Copy the set into a list to avoid race conditions where a new thread
      # is added while we are waiting.
      threads = list(threads)

    # Wait for all threads to stop or for request_stop() to be called.
    while any(t.is_alive() for t in threads) and not self.wait_for_stop(1.0):
      pass

    # If any thread is still alive, wait for the grace period to expire.
    # By the time this check is executed, threads may still be shutting down,
    # so we add a sleep of increasing duration to give them a chance to shut
    # down without losing too many cycles.
    # The sleep duration is limited to the remaining grace duration.
    stop_wait_secs = 0.001
    while any(t.is_alive() for t in threads) and stop_grace_period_secs >= 0.0:
      time.sleep(stop_wait_secs)
      stop_grace_period_secs -= stop_wait_secs
      stop_wait_secs = 2 * stop_wait_secs
      # Keep the waiting period within sane bounds.
      # The minimum value is to avoid decreasing stop_wait_secs to a value
      # that could cause stop_grace_period_secs to remain unchanged.
      stop_wait_secs = max(min(stop_wait_secs, stop_grace_period_secs), 0.001)

    # List the threads still alive after the grace period.
    stragglers = [t.name for t in threads if t.is_alive()]

    # Terminate with an exception if appropriate.
    with self._lock:
      self._joined = True
      self._registered_threads = set()
      if self._exc_info_to_raise:
        # A reported exception takes precedence over the straggler error.
        _, ex_instance, _ = self._exc_info_to_raise
        raise ex_instance
      elif stragglers:
        if ignore_live_threads:
          logging.info("Coordinator stopped with threads still running: %s",
                       " ".join(stragglers))
        else:
          raise RuntimeError(
              "Coordinator stopped with threads still running: %s" %
              " ".join(stragglers))

  @property
  def joined(self):
    """True once `join()` has completed; reset by `clear_stop()`."""
    return self._joined

  def raise_requested_exception(self):
    """If an exception has been passed to `request_stop`, this raises it."""
    with self._lock:
      if self._exc_info_to_raise:
        _, ex_instance, _ = self._exc_info_to_raise
        raise ex_instance
|
|
|
|
|
|
# Threads for the standard services.
|
|
@tf_export(v1=["train.LooperThread"])
class LooperThread(threading.Thread):
  """A thread that runs code repeatedly, optionally on a timer.

  This thread class is intended to be used with a `Coordinator`.  It repeatedly
  runs code specified either as `target` and `args` or by the `run_loop()`
  method.

  Before each run the thread checks if the coordinator has requested stop.  In
  that case the looper thread terminates immediately.

  If the code being run raises an exception, that exception is reported to the
  coordinator and the thread terminates.  The coordinator will then request all
  the other threads it coordinates to stop.

  You typically pass looper threads to the supervisor `Join()` method.
  """

  def __init__(self, coord, timer_interval_secs, target=None, args=None,
               kwargs=None):
    """Create a LooperThread.

    The thread is created as a daemon thread and registers itself with
    `coord` so that `coord.join()` waits for it.

    Args:
      coord: A Coordinator.
      timer_interval_secs: Time boundaries at which to call Run(), or None
        if it should be called back to back.
      target: Optional callable object that will be executed in the thread.
      args: Optional arguments to pass to `target` when calling it.
      kwargs: Optional keyword arguments to pass to `target` when calling it.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if not isinstance(coord, Coordinator):
      # Wrap `coord` in a 1-tuple: if the caller mistakenly passes a tuple,
      # formatting it bare would raise TypeError instead of this ValueError.
      raise ValueError("'coord' argument must be a Coordinator: %s" % (coord,))
    super(LooperThread, self).__init__()
    self.daemon = True
    self._coord = coord
    self._timer_interval_secs = timer_interval_secs
    self._target = target
    if self._target:
      self._args = args or ()
      self._kwargs = kwargs or {}
    elif args or kwargs:
      raise ValueError("'args' and 'kwargs' argument require that you also "
                       "pass 'target'")
    self._coord.register_thread(self)

  @staticmethod
  def loop(coord, timer_interval_secs, target, args=None, kwargs=None):
    """Start a LooperThread that calls a function periodically.

    If `timer_interval_secs` is None the thread calls
    `target(*args, **kwargs)` repeatedly.  Otherwise `target(*args, **kwargs)`
    is called every `timer_interval_secs` seconds.  The thread terminates when
    a stop of the coordinator is requested.

    Args:
      coord: A Coordinator.
      timer_interval_secs: Number. Time boundaries at which to call `target`.
      target: A callable object.
      args: Optional arguments to pass to `target` when calling it.
      kwargs: Optional keyword arguments to pass to `target` when calling it.

    Returns:
      The started thread.
    """
    looper = LooperThread(coord, timer_interval_secs, target=target, args=args,
                          kwargs=kwargs)
    looper.start()
    return looper

  def run(self):
    """Thread body: run `run_loop()` until the coordinator requests a stop.

    Any exception raised by `start_loop()`, `run_loop()` or `stop_loop()` is
    reported to the coordinator through `stop_on_exception()`.  Note that
    `stop_loop()` is only reached when the loop exits without raising.
    """
    with self._coord.stop_on_exception():
      self.start_loop()
      if self._timer_interval_secs is None:
        # Call back-to-back.
        while not self._coord.should_stop():
          self.run_loop()
      else:
        # Next time at which to call run_loop(), starts as 'now'.
        next_timer_time = time.time()
        # wait_for_stop() doubles as the timer: it returns False when the
        # timeout elapses without a stop request, which triggers the next run.
        while not self._coord.wait_for_stop(next_timer_time - time.time()):
          next_timer_time += self._timer_interval_secs
          self.run_loop()
      self.stop_loop()

  def start_loop(self):
    """Called when the thread starts."""
    pass

  def stop_loop(self):
    """Called when the thread stops."""
    pass

  def run_loop(self):
    """Called at 'timer_interval_secs' boundaries."""
    if self._target:
      self._target(*self._args, **self._kwargs)
|