# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
|
|
"""Utility functions for control flow.
|
|
|
|
This file is copied from tensorflow/python/ops/control_flow_util.py.
|
|
"""
|
|
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import smart_cond as smart_module
|
|
from tensorflow.python.framework import tensor_util
|
|
from tensorflow.python.ops import control_flow_ops
|
|
from tensorflow.python.ops import variables
|
|
|
|
|
|
def InXlaContext(graph):
  """Reports whether `graph`'s current control flow context is inside XLA.

  Args:
    graph: A graph object exposing `_get_control_flow_context()`.

  Returns:
    True if the graph's current control flow context is an XLA context or is
    nested inside one; False otherwise.
  """
  # pylint: disable=protected-access
  current_context = graph._get_control_flow_context()
  # pylint: enable=protected-access
  xla_context = GetContainingXLAContext(current_context)
  return xla_context is not None
|
|
|
|
|
|
def GraphOrParentsInXlaContext(graph):
  """Reports whether `graph` or any graph enclosing it is in an XLA context.

  Starting at `graph`, each graph is tested with `InXlaContext`; if it is not
  in an XLA context the walk moves outward to its `outer_graph`.

  Args:
    graph: The innermost graph to start the search from.

  Returns:
    True as soon as a graph along the chain is in an XLA context; False once
    a graph without an `outer_graph` attribute is reached.
  """
  current = graph
  while not InXlaContext(current):
    # Only the attribute fetch is guarded: a graph with no `outer_graph`
    # marks the end of the chain.
    try:
      current = current.outer_graph
    except AttributeError:
      return False
  return True
|
|
|
|
|
|
def IsInWhileLoop(op):
  """Reports whether `op` lives inside a while loop.

  Args:
    op: An object exposing `_get_control_flow_context()` (presumably an
      `ops.Operation` — confirm against callers).

  Returns:
    True if the op's control flow context is a WhileContext or is nested
    inside one; False otherwise.
  """
  op_context = op._get_control_flow_context()  # pylint: disable=protected-access
  enclosing_loop = GetContainingWhileContext(op_context)
  return enclosing_loop is not None
|
|
|
|
|
|
def GetContainingWhileContext(ctxt, stop_ctxt=None):
  """Returns the innermost WhileContext that is `ctxt` or one of its ancestors.

  The search starts at `ctxt` and follows `outer_context` links outward until
  a WhileContext (or a context equal to `stop_ctxt`) is found or the chain
  runs out.

  Args:
    ctxt: ControlFlowContext to start from; may be None.
    stop_ctxt: ControlFlowContext, optional. If provided, the search ends at
      the first context in the chain that compares equal to it.

  Returns:
    `ctxt` if it is a WhileContext, otherwise the most nested WhileContext
    containing `ctxt`, or None if there is none. If `stop_ctxt` is not `None`
    and is met before any WhileContext, the context equal to `stop_ctxt` is
    returned instead.
  """
  current = ctxt
  while current:
    found = current.IsWhileContext() or current == stop_ctxt
    if found:
      return current
    current = current.outer_context
  return None
|
|
|
|
|
|
def GetContainingXLAContext(ctxt):
  """Returns the first ancestor XLAContext of `ctxt`.

  Returns `ctxt` if `ctxt` is an XLAContext, or None if `ctxt` is not inside
  an XLA context. The search follows `outer_context` links outward.

  Args:
    ctxt: ControlFlowContext to start the search from; may be None.

  Returns:
    `ctxt` if `ctxt` is an XLAContext, the most nested XLAContext containing
    `ctxt`, or None if `ctxt` is not inside an XLA context.
  """
  while ctxt:
    if ctxt.IsXLAContext():
      return ctxt
    ctxt = ctxt.outer_context
  return None
|
|
|
|
|
|
def smart_cond(pred, true_fn=None, false_fn=None, name=None):  # pylint: disable=invalid-name
  """Returns `true_fn()` when `pred` is true and `false_fn()` otherwise.

  When `pred` is a bool or has a constant value, the matching branch is
  returned directly; otherwise `tf.cond` routes to both branches dynamically.
  A `variables.Variable` predicate always goes through
  `control_flow_ops.cond`; every other predicate is delegated to
  `smart_module.smart_cond`.

  Args:
    pred: A scalar deciding whether the result of `true_fn` or `false_fn` is
      returned.
    true_fn: Callable invoked when `pred` is true.
    false_fn: Callable invoked when `pred` is false.
    name: Optional name prefix used when a `tf.cond` is built.

  Returns:
    The tensors returned by whichever of `true_fn` / `false_fn` was selected.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable.
  """
  cond_kwargs = dict(true_fn=true_fn, false_fn=false_fn, name=name)
  if not isinstance(pred, variables.Variable):
    return smart_module.smart_cond(pred, **cond_kwargs)
  return control_flow_ops.cond(pred, **cond_kwargs)
|
|
|
|
|
|
def constant_value(pred):  # pylint: disable=invalid-name
  """Return the bool value for `pred`, or None if `pred` had a dynamic value.

  Args:
    pred: A scalar, either a Python bool or a TensorFlow boolean variable
      or tensor, or the Python integer 1 or 0.

  Returns:
    True or False if `pred` has a constant boolean value, None otherwise.
    For a Tensor, the result of `tensor_util.constant_value` is returned
    unchanged. A Variable always yields None.

  Raises:
    TypeError: If `pred` is not a Variable, Tensor or bool, or Python
      integer 1 or 0.
  """
  if isinstance(pred, ops.Tensor):
    return tensor_util.constant_value(pred)
  # Variables must be recognized before the set-membership test below:
  # `pred in {0, 1}` hashes `pred`, and TF2 variables raise TypeError from
  # `__hash__` (they must be wrapped with `.ref()` to be hashed), which made
  # this branch unreachable when it came after the membership test.
  if isinstance(pred, variables.Variable):
    return None
  if isinstance(pred, bool):
    return pred
  try:
    is_zero_or_one = pred in {0, 1}  # Accept 1/0 as valid boolean values
  except TypeError:
    # Unhashable values (e.g. lists, arrays) cannot be 1 or 0; report them
    # with the descriptive error below rather than "unhashable type".
    is_zero_or_one = False
  if is_zero_or_one:
    return bool(pred)
  raise TypeError("`pred` must be a Tensor, or a Python bool, or 1 or 0. "
                  "Found instead: %s" % type(pred))
|