427 lines
14 KiB
Python
427 lines
14 KiB
Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Tools for selecting ops in a graph."""
|
|
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import tensor as tensor_lib
|
|
from tensorflow.python.util import object_identity
|
|
|
|
|
|
def is_differentiable(op):
|
|
try:
|
|
return ops._gradient_registry.lookup(op.op_def.name) is not None # pylint: disable=protected-access
|
|
except LookupError:
|
|
return False
|
|
|
|
|
|
def is_iterable(obj):
|
|
"""Return true if the object is iterable."""
|
|
if isinstance(obj, tensor_lib.Tensor):
|
|
return False
|
|
try:
|
|
_ = iter(obj)
|
|
except Exception: # pylint: disable=broad-except
|
|
return False
|
|
return True
|
|
|
|
|
|
def concatenate_unique(la, lb):
|
|
"""Add all the elements of `lb` to `la` if they are not there already.
|
|
|
|
The elements added to `la` maintain ordering with respect to `lb`.
|
|
|
|
Args:
|
|
la: List of Python objects.
|
|
lb: List of Python objects.
|
|
Returns:
|
|
`la`: The list `la` with missing elements from `lb`.
|
|
"""
|
|
la_set = set(la)
|
|
for l in lb:
|
|
if l not in la_set:
|
|
la.append(l)
|
|
la_set.add(l)
|
|
return la
|
|
|
|
|
|
def get_tensors(graph):
|
|
"""get all the tensors which are input or output of an op in the graph.
|
|
|
|
Args:
|
|
graph: a `tf.Graph`.
|
|
Returns:
|
|
A list of `tf.Tensor`.
|
|
Raises:
|
|
TypeError: if graph is not a `tf.Graph`.
|
|
"""
|
|
if not isinstance(graph, ops.Graph):
|
|
raise TypeError("Expected a graph, got: {}".format(type(graph)))
|
|
ts = []
|
|
for op in graph.get_operations():
|
|
ts += op.outputs
|
|
return ts
|
|
|
|
|
|
def get_unique_graph(tops, check_types=None, none_if_empty=False):
|
|
"""Return the unique graph used by the all the elements in tops.
|
|
|
|
Args:
|
|
tops: iterable of elements to check (usually a list of tf.Operation and/or
|
|
tf.Tensor). Or a tf.Graph.
|
|
check_types: check that the element in tops are of given type(s). If None,
|
|
the types (tf.Operation, tf.Tensor) are used.
|
|
none_if_empty: don't raise an error if tops is an empty list, just return
|
|
None.
|
|
Returns:
|
|
The unique graph used by all the tops.
|
|
Raises:
|
|
TypeError: if tops is not a iterable of tf.Operation.
|
|
ValueError: if the graph is not unique.
|
|
"""
|
|
if isinstance(tops, ops.Graph):
|
|
return tops
|
|
if not is_iterable(tops):
|
|
raise TypeError("{} is not iterable".format(type(tops)))
|
|
if check_types is None:
|
|
check_types = (ops.Operation, tensor_lib.Tensor)
|
|
elif not is_iterable(check_types):
|
|
check_types = (check_types,)
|
|
g = None
|
|
for op in tops:
|
|
if not isinstance(op, check_types):
|
|
raise TypeError("Expected a type in ({}), got: {}".format(", ".join([str(
|
|
t) for t in check_types]), type(op)))
|
|
if g is None:
|
|
g = op.graph
|
|
elif g._graph_key != op.graph._graph_key: # pylint: disable=protected-access
|
|
raise ValueError("Operation {} does not belong to given graph".format(op))
|
|
if g is None and not none_if_empty:
|
|
raise ValueError("Can't find the unique graph of an empty list")
|
|
return g
|
|
|
|
|
|
def check_graphs(*args):
|
|
"""Check that all the element in args belong to the same graph.
|
|
|
|
Args:
|
|
*args: a list of object with a obj.graph property.
|
|
Raises:
|
|
ValueError: if all the elements do not belong to the same graph.
|
|
"""
|
|
graph = None
|
|
for i, sgv in enumerate(args):
|
|
if graph is None and sgv.graph is not None:
|
|
graph = sgv.graph
|
|
elif sgv.graph is not None and sgv.graph is not graph:
|
|
raise ValueError(f"args[{i}] does not belong to the same graph as "
|
|
"other arguments.")
|
|
|
|
|
|
def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
|
|
"""Convert ts to a list of `tf.Tensor`.
|
|
|
|
Args:
|
|
ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor.
|
|
check_graph: if `True` check if all the tensors belong to the same graph.
|
|
allow_graph: if `False` a `tf.Graph` cannot be converted.
|
|
ignore_ops: if `True`, silently ignore `tf.Operation`.
|
|
Returns:
|
|
A newly created list of `tf.Tensor`.
|
|
Raises:
|
|
TypeError: if `ts` cannot be converted to a list of `tf.Tensor` or,
|
|
if `check_graph` is `True`, if all the ops do not belong to the same graph.
|
|
"""
|
|
if isinstance(ts, ops.Graph):
|
|
if allow_graph:
|
|
return get_tensors(ts)
|
|
else:
|
|
raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
|
|
else:
|
|
if not is_iterable(ts):
|
|
ts = [ts]
|
|
if not ts:
|
|
return []
|
|
if check_graph:
|
|
check_types = None if ignore_ops else tensor_lib.Tensor
|
|
get_unique_graph(ts, check_types=check_types)
|
|
return [t for t in ts if isinstance(t, tensor_lib.Tensor)]
|
|
|
|
|
|
def get_generating_ops(ts):
|
|
"""Return all the generating ops of the tensors in `ts`.
|
|
|
|
Args:
|
|
ts: a list of `tf.Tensor`
|
|
Returns:
|
|
A list of all the generating `tf.Operation` of the tensors in `ts`.
|
|
Raises:
|
|
TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
|
|
"""
|
|
ts = make_list_of_t(ts, allow_graph=False)
|
|
return [t.op for t in ts]
|
|
|
|
|
|
def get_consuming_ops(ts):
|
|
"""Return all the consuming ops of the tensors in ts.
|
|
|
|
Args:
|
|
ts: a list of `tf.Tensor`
|
|
Returns:
|
|
A list of all the consuming `tf.Operation` of the tensors in `ts`.
|
|
Raises:
|
|
TypeError: if ts cannot be converted to a list of `tf.Tensor`.
|
|
"""
|
|
ts = make_list_of_t(ts, allow_graph=False)
|
|
tops = []
|
|
for t in ts:
|
|
for op in t.consumers():
|
|
if op not in tops:
|
|
tops.append(op)
|
|
return tops
|
|
|
|
|
|
def make_list_of_op(tops, check_graph=True, allow_graph=True, ignore_ts=False):
|
|
"""Convert ops to a list of `tf.Operation`.
|
|
|
|
Args:
|
|
tops: can be an iterable of `tf.Operation`, a `tf.Graph` or a single
|
|
operation.
|
|
check_graph: if `True` check if all the operations belong to the same graph.
|
|
allow_graph: if `False` a `tf.Graph` cannot be converted.
|
|
ignore_ts: if True, silently ignore `tf.Tensor`.
|
|
Returns:
|
|
A newly created list of `tf.Operation`.
|
|
Raises:
|
|
TypeError: if tops cannot be converted to a list of `tf.Operation` or,
|
|
if `check_graph` is `True`, if all the ops do not belong to the
|
|
same graph.
|
|
"""
|
|
if isinstance(tops, ops.Graph):
|
|
if allow_graph:
|
|
return tops.get_operations()
|
|
else:
|
|
raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
|
|
else:
|
|
if not is_iterable(tops):
|
|
tops = [tops]
|
|
if not tops:
|
|
return []
|
|
if check_graph:
|
|
check_types = None if ignore_ts else ops.Operation
|
|
get_unique_graph(tops, check_types=check_types)
|
|
return [op for op in tops if isinstance(op, ops.Operation)]
|
|
|
|
|
|
def _get_inputs(op, only_differentiable):
|
|
op_inputs = op.inputs
|
|
if only_differentiable:
|
|
return op_inputs if is_differentiable(op) else []
|
|
else:
|
|
return op_inputs
|
|
|
|
|
|
def get_backward_walk_ops(seed_ops,
|
|
inclusive=True,
|
|
within_ops=None,
|
|
within_ops_fn=None,
|
|
stop_at_ts=(),
|
|
control_inputs=False,
|
|
only_differentiable=False):
|
|
"""Do a backward graph walk and return all the visited ops.
|
|
|
|
Args:
|
|
seed_ops: an iterable of operations from which the backward graph
|
|
walk starts. If a list of tensors is given instead, the seed_ops are set
|
|
to be the generators of those tensors.
|
|
inclusive: if True the given seed_ops are also part of the resulting set.
|
|
within_ops: an iterable of `tf.Operation` within which the search is
|
|
restricted. If `within_ops` is `None`, the search is performed within
|
|
the whole graph.
|
|
within_ops_fn: if provided, a function on ops that should return True iff
|
|
the op is within the graph traversal. This can be used along within_ops,
|
|
in which case an op is within if it is also in within_ops.
|
|
stop_at_ts: an iterable of tensors at which the graph walk stops.
|
|
control_inputs: if True, control inputs will be used while moving backward.
|
|
only_differentiable: if True, only traverse ops which are differentiable.
|
|
This includes natively differentiable ops, or ops with custom gradients.
|
|
Returns:
|
|
A Python set of all the `tf.Operation` behind `seed_ops`.
|
|
Raises:
|
|
TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
|
|
`tf.Operation`.
|
|
"""
|
|
control_inputs = control_inputs and (not only_differentiable)
|
|
|
|
if not is_iterable(seed_ops):
|
|
seed_ops = [seed_ops]
|
|
|
|
try:
|
|
first_seed_op = next(iter(seed_ops))
|
|
except StopIteration:
|
|
# Empty iterable.
|
|
return []
|
|
|
|
if isinstance(first_seed_op, tensor_lib.Tensor):
|
|
ts = make_list_of_t(seed_ops, allow_graph=False)
|
|
seed_ops = get_generating_ops(ts)
|
|
else:
|
|
seed_ops = make_list_of_op(seed_ops, allow_graph=False)
|
|
|
|
stop_at_ts = object_identity.ObjectIdentitySet(make_list_of_t(stop_at_ts))
|
|
seed_ops = object_identity.ObjectIdentitySet(make_list_of_op(seed_ops))
|
|
if within_ops:
|
|
within_ops = make_list_of_op(within_ops, allow_graph=False)
|
|
within_ops = object_identity.ObjectIdentitySet(within_ops)
|
|
seed_ops &= within_ops
|
|
|
|
def is_within(op):
|
|
return (within_ops is None or op in within_ops) and (
|
|
within_ops_fn is None or within_ops_fn(op))
|
|
|
|
result = list(seed_ops)
|
|
wave = set(seed_ops)
|
|
while wave:
|
|
new_wave = set()
|
|
for op in wave:
|
|
for new_t in _get_inputs(op, only_differentiable=only_differentiable):
|
|
if new_t in stop_at_ts:
|
|
continue
|
|
if new_t.op not in result and is_within(new_t.op):
|
|
new_wave.add(new_t.op)
|
|
if control_inputs:
|
|
for new_op in op.control_inputs:
|
|
if new_op not in result and is_within(new_op):
|
|
new_wave.add(new_op)
|
|
concatenate_unique(result, new_wave)
|
|
wave = new_wave
|
|
if not inclusive:
|
|
result = [op for op in result if op not in seed_ops]
|
|
return result
|
|
|
|
|
|
class UnliftableError(Exception):
|
|
"""Raised if a Tensor cannot be lifted from the graph."""
|
|
|
|
# Prevent autograph from rewriting this error.
|
|
ag_pass_through = True
|
|
|
|
|
|
def _as_operation(op_or_tensor):
|
|
if isinstance(op_or_tensor, tensor_lib.Tensor):
|
|
return op_or_tensor.op
|
|
return op_or_tensor
|
|
|
|
|
|
def graph_inputs(op):
|
|
return [x.op for x in op.inputs] + list(op.control_inputs)
|
|
|
|
|
|
def show_path(from_op, tensors, sources):
|
|
"""Find one path from `from_op` to any of `tensors`, ignoring `sources`.
|
|
|
|
Args:
|
|
from_op: A `tf.Operation`.
|
|
tensors: A `tf.Operation`, a `tf.Tensor`, or a list thereof.
|
|
sources: A list of `tf.Tensor`.
|
|
|
|
Returns:
|
|
A python string containing the path, or "??" if none is found.
|
|
"""
|
|
if isinstance(from_op, tensor_lib.Tensor):
|
|
from_op = from_op.op
|
|
|
|
if not isinstance(tensors, list):
|
|
tensors = [tensors]
|
|
|
|
final_ops = [_as_operation(tensor) for tensor in tensors]
|
|
|
|
visited_ops = set(x.op for x in sources)
|
|
ops_to_visit = list(final_ops)
|
|
some_op_output = {}
|
|
while ops_to_visit:
|
|
op = ops_to_visit.pop()
|
|
if op in visited_ops:
|
|
continue
|
|
visited_ops.add(op)
|
|
if op == from_op:
|
|
path_op = op
|
|
path = [path_op]
|
|
while path_op not in final_ops:
|
|
path_op = some_op_output[path_op]
|
|
path.append(path_op)
|
|
return " <- ".join("%s (%s)" % (x.name, x.type) for x in reversed(path))
|
|
else:
|
|
for inp in graph_inputs(op):
|
|
if inp not in visited_ops and inp not in sources:
|
|
some_op_output[inp] = op
|
|
ops_to_visit.append(inp)
|
|
return "??"
|
|
|
|
|
|
# TODO(jmenick) - there is considerable duplication of functionality between
|
|
# this function and get_backward_walk_ops(). Need to deduplicate.
|
|
def map_subgraph(init_tensor, sources, disallowed_placeholders, visited_ops,
|
|
op_outputs, add_sources):
|
|
"""Walk a Graph and capture the subgraph between init_tensor and sources.
|
|
|
|
Note: This function mutates visited_ops and op_outputs.
|
|
|
|
Args:
|
|
init_tensor: A Tensor or Operation where the subgraph terminates.
|
|
sources: A set of Tensors where subgraph extraction should stop.
|
|
disallowed_placeholders: An optional set of ops which may not appear in the
|
|
lifted graph. Defaults to all placeholders.
|
|
visited_ops: A set of operations which were visited in a prior pass.
|
|
op_outputs: A defaultdict containing the outputs of an op which are to be
|
|
copied into the new subgraph.
|
|
add_sources: A boolean indicating whether placeholders which are not in
|
|
sources should be allowed.
|
|
|
|
Returns:
|
|
The set of placeholders upon which init_tensor depends and are not in
|
|
sources.
|
|
|
|
Raises:
|
|
UnliftableError: if init_tensor depends on a placeholder which is not in
|
|
sources and add_sources is False.
|
|
"""
|
|
ops_to_visit = [_as_operation(init_tensor)]
|
|
extra_sources = object_identity.ObjectIdentitySet()
|
|
while ops_to_visit:
|
|
op = ops_to_visit.pop()
|
|
if op in visited_ops:
|
|
continue
|
|
visited_ops.add(op)
|
|
|
|
should_raise = False
|
|
if disallowed_placeholders is not None and op in disallowed_placeholders:
|
|
should_raise = True
|
|
elif op.type == "Placeholder":
|
|
if disallowed_placeholders is None and not add_sources:
|
|
should_raise = True
|
|
extra_sources.update(op.outputs)
|
|
|
|
if should_raise:
|
|
raise UnliftableError(
|
|
"Unable to lift tensor %s because it depends transitively on "
|
|
"placeholder %s via at least one path, e.g.: %s" %
|
|
(repr(init_tensor), repr(op), show_path(op, init_tensor, sources)))
|
|
for inp in graph_inputs(op):
|
|
op_outputs[inp].add(op)
|
|
if inp not in visited_ops and inp not in (sources or extra_sources):
|
|
ops_to_visit.append(inp)
|
|
|
|
return extra_sources
|