106 lines
3.7 KiB
Python
106 lines
3.7 KiB
Python
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
"""TensorBoard helper routine for TF op evaluator.
|
||
|
|
||
|
Requires TensorFlow.
|
||
|
"""
|
||
|
|
||
|
|
||
|
import threading
|
||
|
|
||
|
|
||
|
class PersistentOpEvaluator:
    """Evaluate a fixed TensorFlow graph repeatedly, safely, efficiently.

    Extend this class to create a particular kind of op evaluator, like an
    image encoder. In `initialize_graph`, create an appropriate TensorFlow
    graph with placeholder inputs. In `run`, evaluate this graph and
    return its result. This class will manage a singleton graph and
    session to preserve memory usage, and will ensure that this graph and
    session do not interfere with other concurrent sessions.

    A subclass of this class offers a threadsafe, highly parallel Python
    entry point for evaluating a particular TensorFlow graph.

    Example usage:

        class FluxCapacitanceEvaluator(PersistentOpEvaluator):
          \"\"\"Compute the flux capacitance required for a system.

          Arguments:
            x: Available power input, as a `float`, in jigawatts.

          Returns:
            A `float`, in nanofarads.
          \"\"\"

          def initialize_graph(self):
            self._placeholder = tf.placeholder(some_dtype)
            self._op = some_op(self._placeholder)

          def run(self, x):
            return self._op.eval(feed_dict={self._placeholder: x})

        evaluate_flux_capacitance = FluxCapacitanceEvaluator()

        for x in xs:
          evaluate_flux_capacitance(x)
    """

    def __init__(self):
        """Set up an evaluator whose graph and session are not yet built."""
        super().__init__()
        # `None` until the first call; afterwards the session that every
        # subsequent call evaluates under.
        self._session = None
        # Serializes the one-time graph/session construction.
        self._initialization_lock = threading.Lock()

    def _lazily_initialize(self) -> None:
        """Initialize the graph and session, if this has not yet been done.

        The first caller builds a fresh graph by invoking
        `initialize_graph` with that graph as the default, then creates a
        session bound to it. Later callers return as soon as they see the
        existing session. If `initialize_graph` raises, no session is
        stored, so the next call attempts initialization again.
        """
        with self._initialization_lock:
            if self._session is not None:
                return

            # Import only when there is real work to do, so that calls made
            # after initialization do not pay for an import statement.
            # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
            import tensorflow.compat.v1 as tf

            graph = tf.Graph()
            with graph.as_default():
                self.initialize_graph()
            # Don't reserve GPU because libpng can't run on GPU.
            config = tf.ConfigProto(device_count={"GPU": 0})
            self._session = tf.Session(graph=graph, config=config)

    def initialize_graph(self):
        """Create the TensorFlow graph needed to compute this operation.

        This should write ops to the default graph and return `None`.

        Raises:
          NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            'Subclasses must implement "initialize_graph".'
        )

    def run(self, *args, **kwargs):
        """Evaluate the ops with the given input.

        When this function is called, the default session will have the
        graph defined by a previous call to `initialize_graph`. This
        function should evaluate any ops necessary to compute the result
        of the query for the given *args and **kwargs, likely returning
        the result of a call to `some_op.eval(...)`.

        Raises:
          NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError('Subclasses must implement "run".')

    def __call__(self, *args, **kwargs):
        """Run the evaluator, building the graph and session on first use.

        Arguments:
          *args: Positional arguments forwarded unchanged to `run`.
          **kwargs: Keyword arguments forwarded unchanged to `run`.

        Returns:
          Whatever `run` returns.
        """
        self._lazily_initialize()
        with self._session.as_default():
            return self.run(*args, **kwargs)
|