78 lines
2.6 KiB
Python
78 lines
2.6 KiB
Python
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Executor for eager execution."""
|
||
|
|
||
|
from tensorflow.python import pywrap_tfe
|
||
|
|
||
|
|
||
|
class Executor(object):
  """A class for handling eager execution.

  The default behavior for asynchronous execution is to serialize all ops on
  a single thread. Having different `Executor` objects in different threads
  enables executing ops asynchronously in parallel:

  ```python
  def thread_function():
    async_executor = executor.new_executor(enable_async=True)
    context.set_executor(async_executor)

  a = threading.Thread(target=thread_function)
  a.start()
  b = threading.Thread(target=thread_function)
  b.start()
  ```

  An `Executor` wraps a native executor handle (normally obtained through
  `new_executor`) and deletes that handle when the object is finalized.
  """

  __slots__ = ["_handle"]

  def __init__(self, handle):
    """Wraps an existing native executor handle.

    Args:
      handle: The native handle to wrap, e.g. the value returned by
        `pywrap_tfe.TFE_NewExecutor`. This object deletes it in `__del__`,
        so the caller should not delete it separately.
    """
    self._handle = handle

  def __del__(self):
    try:
      try:
        # Drain any pending work before releasing the native executor.
        self.wait()
      finally:
        # Release the handle even if wait() raised (presumably to surface an
        # error from a pending op -- TODO confirm), so it is not leaked.
        pywrap_tfe.TFE_DeleteExecutor(self._handle)
    except TypeError:
      # Suppress some exceptions, mainly for the case when we're running on
      # module deletion. Things that can go wrong include the pywrap module
      # already being unloaded, self._handle no longer being valid, and so
      # on. Printing warnings in these cases is silly (exceptions raised from
      # __del__ are printed as warnings to stderr). The typical TypeError is
      # "'NoneType' object is not callable", seen when the pywrap module has
      # been partially unloaded.
      pass

  def is_async(self):
    """Reports whether this executor is asynchronous.

    Returns:
      The result of `pywrap_tfe.TFE_ExecutorIsAsync` for the wrapped handle.
    """
    return pywrap_tfe.TFE_ExecutorIsAsync(self._handle)

  def handle(self):
    """Returns the wrapped native executor handle."""
    return self._handle

  def wait(self):
    """Waits for ops dispatched in this executor to finish."""
    pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)

  def clear_error(self):
    """Clears errors raised in this executor during execution."""
    pywrap_tfe.TFE_ExecutorClearError(self._handle)
|
||
|
|
||
|
|
||
|
def new_executor(enable_async,
                 enable_streaming_enqueue=True,
                 in_flight_nodes_limit=0):
  """Allocates a native executor and wraps it in an `Executor`.

  All three arguments are forwarded unchanged, in order, to
  `pywrap_tfe.TFE_NewExecutor`.

  Args:
    enable_async: Presumably selects asynchronous execution; the resulting
      mode can be queried with `Executor.is_async`.
    enable_streaming_enqueue: Forwarded to `TFE_NewExecutor` as-is (semantics
      defined by the native API -- TODO confirm).
    in_flight_nodes_limit: Forwarded to `TFE_NewExecutor` as-is. Presumably
      the default 0 means "no limit"; verify against the C API.

  Returns:
    A new `Executor` that owns the freshly created native handle.
  """
  native_handle = pywrap_tfe.TFE_NewExecutor(
      enable_async,
      enable_streaming_enqueue,
      in_flight_nodes_limit,
  )
  wrapped = Executor(native_handle)
  return wrapped
|