368 lines
12 KiB
Python
368 lines
12 KiB
Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
"""Utility functions for control flow.
|
|
|
|
This file is necessary to avoid cyclic dependencies between ops.py and
|
|
control_flow_ops.py.
|
|
"""
|
|
|
|
import os
|
|
import traceback
|
|
|
|
from tensorflow.python import tf2
|
|
from tensorflow.python.platform import tf_logging as logging
|
|
|
|
# Control flow v2 is on when TF2 behavior is enabled and
# TF_ENABLE_CONTROL_FLOW_V2 is not explicitly "0", or when any of the opt-in
# environment variables below is set to something other than "0".
ENABLE_CONTROL_FLOW_V2 = (
    (tf2.enabled() and os.getenv("TF_ENABLE_CONTROL_FLOW_V2") != "0") or
    any(os.getenv(env_var, "0") != "0"
        for env_var in ("TF_ENABLE_CONTROL_FLOW_V2",
                        "TF_ENABLE_COND_V2",
                        "TF_ENABLE_WHILE_V2",
                        "TF_ENABLE_TENSOR_ARRAY_V2")))
|
|
|
|
|
|
# TODO(b/137793122): Remove this.
def enable_control_flow_v2():  # pylint: disable=invalid-name
  """Turns on control flow v2 by setting the module-level flag.

  Do not use this symbol. This will be removed.
  """
  global ENABLE_CONTROL_FLOW_V2
  # EnableControlFlowV2() reads this flag, so subsequent calls return True.
  ENABLE_CONTROL_FLOW_V2 = True
|
|
|
|
|
|
def EnableControlFlowV2(graph):
  """Returns whether control flow v2 should be used in `graph`.

  Args:
    graph: the graph being built.

  Returns:
    The module flag `ENABLE_CONTROL_FLOW_V2` when it is truthy; otherwise
    truthy only if `graph.building_function` is truthy and `graph` has no
    `_captured` attribute.
  """
  if ENABLE_CONTROL_FLOW_V2:
    return ENABLE_CONTROL_FLOW_V2
  # Enable new control flow in FuncGraphs (but not legacy _FuncGraphs).
  # TODO(skyewm): do something better than hasattr without messing up imports.
  return graph.building_function and not hasattr(graph, "_captured")
|
|
|
|
|
|
def IsInXLAContext(op):
  """Returns true if `op` is marked for XLA compilation or is in an XLA context.

  Args:
    op: Operation to inspect.

  Returns:
    True if `op`'s `_XlaCompile` attribute is truthy, or if its control flow
    context is, or is nested in, an XLA context.
  """
  try:
    if op.get_attr("_XlaCompile"):
      return True
  except ValueError:
    # get_attr raised; presumably the op has no `_XlaCompile` attribute.
    pass
  xla_ctxt = GetContainingXLAContext(op._get_control_flow_context())  # pylint: disable=protected-access
  return xla_ctxt is not None
|
|
|
|
|
|
def InXlaContext(graph):
  """Returns true if `graph`'s current control flow context is in XLA.

  Args:
    graph: graph whose current control flow context is inspected.

  Returns:
    True if that context is, or is nested in, an XLA context.
  """
  current_ctxt = graph._get_control_flow_context()  # pylint: disable=protected-access
  xla_ctxt = GetContainingXLAContext(current_ctxt)
  return xla_ctxt is not None
|
|
|
|
|
|
def GraphOrParentsInXlaContext(graph):
  """Returns true if `graph` or any graph reachable via `outer_graph` is in XLA.

  Args:
    graph: starting graph. The walk stops at the first graph that has no
      `outer_graph` attribute.

  Returns:
    True as soon as InXlaContext() holds for a graph on the chain, otherwise
    False.
  """
  current = graph
  while not InXlaContext(current):
    try:
      current = current.outer_graph
    except AttributeError:
      # Reached a graph with no parent.
      return False
  return True
|
|
|
|
|
|
def IsInWhileLoop(op):
  """Returns true if `op`'s control flow context is, or is inside, a while loop."""
  while_ctxt = GetContainingWhileContext(op._get_control_flow_context())  # pylint: disable=protected-access
  return while_ctxt is not None
|
|
|
|
|
|
def IsInCond(op):
  """Returns true if `op`'s control flow context is, or is inside, a cond."""
  cond_ctxt = GetContainingCondContext(op._get_control_flow_context())  # pylint: disable=protected-access
  return cond_ctxt is not None
|
|
|
|
|
|
def IsSwitch(op):
  """Return true if `op` is a Switch (including its RefSwitch variant)."""
  return op.type in ("Switch", "RefSwitch")
|
|
|
|
|
|
def IsMerge(op):
  """Return true if `op` is a Merge (including its RefMerge variant)."""
  return op.type in ("Merge", "RefMerge")
|
|
|
|
|
|
def IsLoopEnter(op):
  """Return true if `op` is an Enter (including its RefEnter variant)."""
  return op.type in ("Enter", "RefEnter")
|
|
|
|
|
|
def IsLoopExit(op):
  """Return true if `op` is an Exit (including its RefExit variant)."""
  return op.type in ("Exit", "RefExit")
|
|
|
|
|
|
def IsCondSwitch(op):
  """Return true if `op` is the Switch for a conditional.

  Args:
    op: Operation to inspect.

  Returns:
    True iff `op` is a Switch/RefSwitch with at least one output and every
    consumer of its outputs is in a cond context. For Enter consumers, the
    context enclosing the Enter's own context is checked instead. Consumers
    with no control flow context make the result False.
  """
  if not IsSwitch(op):
    return False
  if not op.outputs:
    return False
  # Switch nodes are not part of the cond control flow context that they
  # represent, so consider the consumers of its outputs to determine if it is a
  # cond switch or not. A switch is a cond switch iff all its consumers are in
  # cond contexts.
  for output in op.outputs:
    for consumer in output.consumers():
      ctxt = consumer._get_control_flow_context()  # pylint: disable=protected-access
      # For Enter consumers, look at the enclosing context. Ops may have no
      # control flow context at all (e.g. after import_graph_def), so guard
      # against None before dereferencing `outer_context`.
      if ctxt is not None and IsLoopEnter(consumer):
        ctxt = ctxt.outer_context
      if ctxt is None or not ctxt.IsCondContext():
        return False
  return True
|
|
|
|
|
|
def IsCondMerge(op):
  """Return true if `op` is the Merge for a conditional.

  Args:
    op: Operation to inspect.

  Returns:
    True iff `op` is a Merge/RefMerge with at least one input and the output
    context (see GetOutputContext) of every input's producing op is a cond
    context.
  """
  if not IsMerge(op):
    return False
  if not op.inputs:
    return False
  # Merge nodes are not part of the cond control flow context that they
  # represent, so consider the inputs to the merge to determine if it is a cond
  # merge or not: a merge is a cond merge iff all its inputs are in cond
  # contexts. Stop at the first input that is not.
  for i in op.inputs:
    ctxt = GetOutputContext(i.op)
    if ctxt is None or not ctxt.IsCondContext():
      return False
  return True
|
|
|
|
|
|
def IsLoopSwitch(op):
  """Return true if `op` is a Switch in a while context that is not a cond Switch.

  Args:
    op: Operation to inspect.

  Returns:
    False unless `op` is a Switch whose own control flow context is a
    WhileContext and IsCondSwitch(op) is false.
  """
  if not IsSwitch(op):
    return False
  ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  if ctxt is None or not ctxt.IsWhileContext():
    return False
  return not IsCondSwitch(op)
|
|
|
|
|
|
def IsLoopMerge(op):
  """Return true if `op` is a Merge in a while context that is not a cond Merge.

  Args:
    op: Operation to inspect.

  Returns:
    False unless `op` is a Merge whose own control flow context is a
    WhileContext and IsCondMerge(op) is false.
  """
  if not IsMerge(op):
    return False
  ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  if ctxt is None or not ctxt.IsWhileContext():
    return False
  return not IsCondMerge(op)
|
|
|
|
|
|
def IsLoopConstantEnter(op):
  """Return true iff op is an Enter whose `is_constant` attribute is set."""
  if not IsLoopEnter(op):
    return False
  return op.get_attr("is_constant")
|
|
|
|
|
|
def GetLoopConstantEnter(value):
  """Return the enter op if we can infer `value` to be a loop invariant.

  Follows first inputs back through Switch/Identity ops (and their Ref
  variants) from `value.op`, then checks whether the op reached is a constant
  Enter.

  Args:
    value: a tensor-like object exposing `.op`.

  Returns:
    The constant Enter op, or None.
  """
  producer = value.op
  while producer.type in ("Switch", "RefSwitch", "Identity", "RefIdentity"):
    producer = producer.inputs[0].op
  if IsLoopConstantEnter(producer):
    return producer
  return None
|
|
|
|
|
|
def GetOutputContext(op):
  """Return the control flow context that `op`'s outputs belong to.

  For an Exit op that has a context, this is the context enclosing it. For any
  other op, or an op without a context, it is the op's own context.

  Args:
    op: Operation to inspect.

  Returns:
    A control flow context, or None.
  """
  ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  # Exit nodes normally carry a context, but nodes imported via
  # import_graph_def have none, so only step outward when one exists.
  if ctxt is None or not IsLoopExit(op):
    return ctxt
  return ctxt.outer_context
|
|
|
|
|
|
def GetContainingWhileContext(ctxt, stop_ctxt=None):
  """Walks outward from `ctxt` to the nearest WhileContext.

  Args:
    ctxt: ControlFlowContext to start from; may be None.
    stop_ctxt: ControlFlowContext, optional. If the walk reaches a context
      equal to it, that context is returned and the walk ends.

  Returns:
    The first context on the `outer_context` chain (starting with `ctxt`
    itself) that is a WhileContext or equals `stop_ctxt`, or None if the chain
    runs out first.
  """
  node = ctxt
  while node:
    found = node.IsWhileContext() or node == stop_ctxt
    if found:
      return node
    node = node.outer_context
  return None
|
|
|
|
|
|
def GetContainingXLAContext(ctxt):
  """Returns the first ancestor XLAContext of `ctxt`.

  Returns `ctxt` if `ctxt` is an XLAContext, or None if `ctxt` is not in an
  XLAContext.

  Args:
    ctxt: ControlFlowContext

  Returns:
    `ctxt` if `ctxt` is an XLAContext, the most nested XLAContext containing
    `ctxt`, or None if no context on `ctxt`'s `outer_context` chain is an
    XLAContext.
  """
  while ctxt:
    if ctxt.IsXLAContext():
      return ctxt
    ctxt = ctxt.outer_context
  return None
|
|
|
|
|
|
def GetContainingCondContext(ctxt):
  """Walks outward from `ctxt` to the nearest CondContext.

  Args:
    ctxt: ControlFlowContext to start from; may be None.

  Returns:
    `ctxt` itself when it is a CondContext, otherwise the innermost enclosing
    CondContext, or None when no context on the `outer_context` chain is a
    CondContext.
  """
  node = ctxt
  while node:
    if node.IsCondContext():
      return node
    node = node.outer_context
  return None
|
|
|
|
|
|
def IsContainingContext(ctxt, maybe_containing_ctxt):
  """Returns true if `maybe_containing_ctxt` is `ctxt` or one of its ancestors.

  Identity is used for the comparison. A `maybe_containing_ctxt` of None
  contains everything, because every `outer_context` chain ends in None.
  """
  while True:
    if ctxt is maybe_containing_ctxt:
      return True
    if ctxt is None:
      return False
    ctxt = ctxt.outer_context
|
|
|
|
|
|
def OpInContext(op, ctxt):
  """Returns true if `op`'s control flow context is `ctxt` or nested in it."""
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return IsContainingContext(op_ctxt, ctxt)
|
|
|
|
|
|
def TensorInContext(tensor, ctxt):
  """Returns true if the op producing `tensor` is in `ctxt` or nested in it."""
  producer = tensor.op
  return OpInContext(producer, ctxt)
|
|
|
|
|
|
def CheckInputFromValidContext(op, input_op):
  """Checks that `input_op` can be used as an input to `op` given their contexts.

  Conceptually, only inputs from op's while context or any ancestor while
  context (including outside of any context) are valid. In practice, there are
  many other edge cases as well.

  Args:
    op: Operation
    input_op: Operation

  Returns:
    None if the input is valid.

  Raises:
    ValueError: if input_op is from an invalid context. The full diagnostic,
      including both ops' tracebacks, is written to the info log first.
  """
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  input_ctxt = GetOutputContext(input_op)
  valid = False

  if not input_ctxt:
    # input_op isn't in a control flow context.
    valid = True
  elif op_ctxt is input_ctxt:
    # input_op is in the same context as op.
    valid = True
  else:
    while_ctxt = GetContainingWhileContext(op_ctxt)
    input_while_ctxt = GetContainingWhileContext(input_ctxt)

    if while_ctxt is None:
      if input_while_ctxt is None:
        # Neither op nor input_op is in a while loop, but one or both are in
        # conds. We allow this, although execution will fail if the branch
        # corresponding to input_op's cond context isn't taken.
        valid = True
      # Invalid if op isn't in a while loop and input_op is. Unless...
      if IsLoopEnter(op):
        # WhileContext._BuildLoop clears context for Enter nodes.
        valid = True
      if IsSwitch(op):
        # CondContext.AddValue clears context for Switch nodes.
        valid = True
    elif IsContainingContext(while_ctxt, input_while_ctxt):
      # input_op is in a while loop which contains op's while loop (or not in a
      # while loop at all).
      valid = True
    elif (while_ctxt.grad_state and
          IsContainingContext(while_ctxt.grad_state.forward_context,
                              input_while_ctxt)):
      # op is in a gradient context and input_op is in the associated forward
      # pass context or an ancestor thereof. This case is needed to build while
      # loop gradients.
      # NOTE(skyewm): we theoretically also need this case for custom gradient
      # functions that close over tensors from ancestor contexts, but I haven't
      # verified this.
      valid = True
    elif (while_ctxt.grad_state and
          while_ctxt.grad_state.forward_context is
          input_while_ctxt._outer_context):  # pylint: disable=protected-access
      # op is in a gradient context and input_op is in a child of the associated
      # forward pass context. This case is needed for the gradients of while
      # loops with conds.
      valid = True
    elif (input_while_ctxt.grad_state and
          input_while_ctxt.grad_state.forward_context is while_ctxt):
      # input_op is in the gradient context of op's context. This case is needed
      # when the gradient of a while loop gradient is requested (this will
      # eventually fail unless there is a stop_gradient() or similar).
      valid = True
    elif (input_while_ctxt.grad_state and
          input_while_ctxt.grad_state.forward_context.grad_state and
          (input_while_ctxt.grad_state.forward_context.grad_state
           .forward_context is while_ctxt)):
      # input_op is in the grad grad context of op's context. This case is
      # needed when the gradient of a while loop gradient is requested (this
      # will eventually fail unless there is a stop_gradient() or similar).
      # Use input_while_ctxt throughout, matching the guard above:
      # input_ctxt may be a non-while context whose own grad_state is None,
      # and dereferencing that would raise AttributeError.
      valid = True

  if not valid:
    if while_ctxt:
      error_msg = (
          f"Cannot use '{input_op.name}' as input to '{op.name}' because they "
          "are in different while loops.")
    else:
      error_msg = (
          f"Cannot use '{input_op.name}' as input to '{op.name}' because "
          f"'{input_op.name}' is in a while loop.")

    # Log the error message plus the relevant stack traces. The stacks may be
    # useful for debugging this error, but we don't want to raise an
    # unreadable exception.
    log_msg = error_msg
    log_msg += "\n\n%s while context: %s" % (op.name, while_ctxt)
    log_msg += "\n%s while context: %s" % (input_op.name, input_while_ctxt)
    log_msg += "\n\nTraceback for %s:\n%s\nTraceback for %s:\n%s\n" % (
        op.name, "".join(traceback.format_list(op.traceback)),
        input_op.name, "".join(traceback.format_list(input_op.traceback)))
    logging.info(log_msg)
    raise ValueError(error_msg + " See info log for more details.")
|
|
|
|
|
|
def GetWhileContext(op):
  """Get the WhileContext to which this op belongs.

  Returns:
    The result of `GetWhileContext()` on `op`'s control flow context, or that
    (falsy) context itself when `op` has none.
  """
  ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return ctxt.GetWhileContext() if ctxt else ctxt
|