# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Control flow statements: loops, conditionals, etc.

Note: most of these operators accept pairs of get_state/set_state functions, to
capture mutations that the corresponding code blocks might make. These
mutations only need to be captured when staging the control flow, and they just
work when reverting to Python behavior.

__Examples__

```
while cond:
  self.x += i
```

When the functionalized version is executed as a Python loop, it just works:

```
def loop_body():
  self.x += i     # works as expected for Python loops
```

But it won't work for TF loops:

```
def loop_body():
  self.x += i     # self.x has the wrong value!
```

get_state/set_state allow piping the mutations through the loop variables as
well, in effect changing the loop body:

```
def loop_body(self_x):
  self.x = self_x  # self.x now has the proper value
  self.x += i      # the original block
  self_x = self.x  # write self.x back into the loop vars
  return self_x

self_x = tf.while_loop(...)
self.x = self_x    # the result is now properly captured
```
"""

import functools
import sys
import traceback

import numpy as np

from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.autograph.operators import variables
from tensorflow.python.autograph.utils import ag_logging
from tensorflow.python.autograph.utils import misc
from tensorflow.python.autograph.utils import tensors
from tensorflow.python.autograph.utils import type_registry
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.types import distribute
from tensorflow.python.util import nest
from tensorflow.python.util import variable_utils

# Python loops running more iterations than this raise ValueError
# (enforced by _PythonLoopChecker._check_unroll_limits).
PYTHON_MAX_ITERATIONS = 100000000  # Fails in about one minute for empty loops.
# When True, Python loops are checked for ops being created on each iteration.
WARN_INEFFICIENT_UNROLL = True
# The op-count check only starts once a Python loop exceeds this many
# iterations.
INEFFICIENT_UNROLL_MIN_ITERATIONS = 50000
# Minimum number of new ops created in one iteration that triggers the warning.
INEFFICIENT_UNROLL_MIN_OPS = 1


# TODO(mdan): Use the custom operator pattern instead of type dispatch.
# An example of this pattern is found in the implementation of distributed
# datasets. Before it can be used though, we need to standardize the interface.
# Maps iterable types to custom for_stmt overloads; for_stmt consults it before
# falling back to its built-in type dispatch.
for_loop_registry = type_registry.TypeRegistry()


def _is_none_or_undef(value):
  """Tests whether a value is None or undefined.

  AutoGraph represents undefined symbols using special objects of type Undefined
  or UndefinedReturnValue.

  Args:
    value: value to test

  Returns:
    True if `value` is None, or an instance of `variables.Undefined` or
    `variables.UndefinedReturnValue`; False otherwise.
  """
  return ((value is None)
          or isinstance(value, variables.UndefinedReturnValue)
          or isinstance(value, variables.Undefined))


def _verify_tf_condition(cond, tag):
  """Ensures that the condition can be used in a TF control flow.

  Args:
    cond: the condition value; it is converted to a tensor.
    tag: description of the statement owning the condition, used in error
      messages.

  Returns:
    `cond` as a tensor, reshaped to a scalar when its rank is unknown or
    positive.

  Raises:
    ValueError: if `cond` is not of dtype `tf.bool`, or if its statically known
      dimensions hold more than one element.
  """
  extra_hint = 'to check for None, use `is not None`'
  cond = ops.convert_to_tensor_v2(cond)

  if cond.dtype != dtypes.bool:
    raise ValueError(
        'condition of {} expected to be `tf.bool` scalar, got {}'
        '; to use as boolean Tensor, use `tf.cast`'
        '; {}'.format(tag, cond, extra_hint))

  if cond.shape is None or cond.shape.ndims is None:
    # TODO(mdan): Consider an explicit size check, if not too slow.
    cond = array_ops.reshape(cond, ())

  elif cond.shape.ndims > 0:
    # Only static dims are checked; np.prod([]) == 1, so a condition whose dims
    # are all unknown passes here and relies on the reshape below (presumably
    # failing at runtime if the tensor really has more than one element).
    known_dims = [d for d in cond.shape.as_list() if d is not None]
    if np.prod(known_dims) > 1:
      raise ValueError(
          'condition of {} expected to be `tf.bool` scalar, got {}'
          '; {}'.format(tag, cond, extra_hint))
    else:
      cond = array_ops.reshape(cond, ())

  return cond


def verify_loop_init_vars(
    init_vars, symbol_names, first_iter_vars=None, extra_message=None
):
  """Ensures that all values in the state are valid to use in a TF loop.

  The init_vars may contain placeholder values derived from first_iter_vars.

  Args:
    init_vars: initial loop variables (as taken before entering the loop)
    symbol_names: corresponding names of the initial loop variables
    first_iter_vars: loop variables after one iteration of the loop
    extra_message: an extra string to append to the error message, in case of
      "undefined variable" errors (see variables.Undefined)

  Raises:
    ValueError: if a variable is None or Undefined before the loop, or if the
      function return value is undefined before the loop while the first
      iteration produced a (truthy) value for it.
    NotImplementedError: if the function return value is undefined before the
      loop and the first iteration did not produce a truthy value for it.
  """
  if not symbol_names:
    return
  if first_iter_vars is None:
    first_iter_vars = (None,) * len(symbol_names)

  assert len(symbol_names) == len(init_vars)
  assert len(symbol_names) == len(first_iter_vars)
  for name, val, fi_val in zip(symbol_names, init_vars, first_iter_vars):
    if isinstance(val, variables.UndefinedReturnValue):
      # NOTE(review): this is a truthiness test; if fi_val can be a Tensor here,
      # bool() on it may itself raise in graph mode - confirm this is intended.
      if fi_val:
        raise ValueError(
            'the return value from a TensorFlow loop may only be a {}; got {}'
            .format(LEGAL_LOOP_TYPES, type(fi_val)))
      else:
        # TODO(mdan): This can be handled by removing the return value.
        raise NotImplementedError(
            'a return statement cannot be placed inside this TensorFlow loop;'
            ' this may happen if a return statement depends on a'
            ' static Python condition such as a hyperparameter')

    error_msg = None
    if val is None:
      error_msg = "'{}' is not allowed to be None before the loop".format(name)
    elif isinstance(val, variables.Undefined):
      error_msg = "'{}' must be defined before the loop".format(name)
      # The extra context only applies to undefined symbols.
      if extra_message:
        error_msg += '\n' + extra_message

    if error_msg is not None:
      raise ValueError(error_msg)


def _is_subshape(left, right):
  """Returns True if left shape is at least as specific as right shape.

  Args:
    left: a TensorShape.
    right: a TensorShape. An unknown-rank `right` accepts any `left`; a known
      rank requires equal ranks and equal values wherever `right` has a known
      dimension.

  Returns:
    Boolean.
  """
  # TODO(mdan): This code should be in TensorShape.
  # Note: this is not the same as TensorShape.is_compatible_with, which is
  # symmetric.
  # This code also duplicates _ShapeLessThanOrEqual from control_flow_ops.py.
  if right.dims is None:
    return True
  if left.ndims != right.ndims:
    return False
  for ldim, rdim in zip(left.dims, right.dims):
    if rdim.value is not None and ldim.value != rdim.value:
      return False
  return True


# TODO(mdan): Remove these verifications once TF ops can properly report names.
def _verify_single_loop_var(
    name, check_shape, init, entry, exit_, shape_invariant):
  """Verifies whether the initial, entry and exit values are consistent.

  Python bools, ints, floats, strings and NumPy arrays are converted to tensors
  first. The dtype and shape checks only run when both `entry` and `exit_` are
  TF types that expose `dtype` and `shape` attributes.

  Args:
    name: name of the loop variable, used in error messages.
    check_shape: whether to run the shape checks.
    init: value of the variable before the loop.
    entry: value of the variable at the start of an iteration.
    exit_: value of the variable at the end of that iteration.
    shape_invariant: user-specified shape invariant, or None.

  Raises:
    ValueError: if `exit_` is None, or if the shapes are inconsistent with each
      other or with `shape_invariant`.
    TypeError: if `entry` and `exit_` have different dtypes.
  """
  assert entry is not None, "no TF op should set '{}' to None?".format(name)
  if exit_ is None:
    raise ValueError("'{}' is None at the end of the iteration.".format(name))

  if isinstance(init, (bool, int, float, str, np.ndarray)):
    init = ops.convert_to_tensor_v2(init)
  if isinstance(entry, (bool, int, float, str, np.ndarray)):
    entry = ops.convert_to_tensor_v2(entry)
  if isinstance(exit_, (bool, int, float, str, np.ndarray)):
    exit_ = ops.convert_to_tensor_v2(exit_)

  if (not tensor_util.is_tf_type(entry) or
      not tensor_util.is_tf_type(exit_)):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  if (not hasattr(entry, 'dtype') or
      not hasattr(exit_, 'dtype')):
    return
  if (not hasattr(entry, 'shape') or
      not hasattr(exit_, 'shape')):
    return

  if entry.dtype != exit_.dtype:
    raise TypeError(
        "'{}' has dtype {} before the loop, but dtype {} after one"
        ' iteration'.format(
            name,
            entry.dtype.name,
            exit_.dtype.name,
        ))
  if check_shape:
    exit_shape = exit_.shape
    if shape_invariant is None:
      # Without an invariant, the shape may only become more specific.
      entry_shape = entry.shape
      if not _is_subshape(exit_shape, entry_shape):
        raise ValueError(
            "'{}' has shape {} before the loop, but shape {} after one"
            ' iteration. Use tf.autograph.experimental.set_loop_options to set'
            ' shape invariants.'.format(name, entry_shape, exit_shape))
    else:
      # With an invariant, both the initial and the exit shapes must conform.
      init_shape = init.shape
      if not _is_subshape(init_shape, shape_invariant):
        raise ValueError(
            "'{}' has shape {} before the loop, which does not conform with"
            ' the shape invariant {}.'.format(name, init_shape,
                                              shape_invariant))
      if not _is_subshape(exit_shape, shape_invariant):
        raise ValueError(
            "'{}' has shape {} after one iteration, which does not conform with"
            ' the shape invariant {}.'.format(name, exit_shape, shape_invariant)
        )


def verify_tf_loop_vars(
    init_vars,
    iter_entry_vars,
    iter_exit_vars,
    symbol_names,
    opts,
    check_shapes=True,
):
  """Verifies loop variables for consistency.

  Args:
    init_vars: loop variables before entering the loop.
    iter_entry_vars: loop variables at the start of an iteration.
    iter_exit_vars: loop variables at the end of that iteration.
    symbol_names: names of the loop variables, positionally aligned with the
      three sequences above.
    opts: loop options. `opts['shape_invariants']`, when present, is indexed
      positionally and must be aligned with `symbol_names`.
    check_shapes: whether to check shapes. When False, any shape invariants in
      `opts` are ignored.

  Raises:
    TypeError: if a variable changes its nested structure across the
      iteration, or does not match the structure of its shape invariant.
    ValueError: see `_verify_single_loop_var` (which may also raise TypeError).
  """
  if check_shapes and 'shape_invariants' in opts:
    shape_invariants = opts['shape_invariants']
  else:
    # A structure of Nones mirroring each entry var, i.e. "no invariant".
    shape_invariants = nest.map_structure(lambda _: None, iter_entry_vars)

  assert len(symbol_names) == len(shape_invariants)
  assert len(symbol_names) == len(init_vars)
  assert len(symbol_names) == len(iter_entry_vars)
  assert len(symbol_names) == len(iter_exit_vars)

  for i in range(len(symbol_names)):
    name = symbol_names[i]
    init = init_vars[i]
    entry = iter_entry_vars[i]
    exit_ = iter_exit_vars[i]
    invariant = shape_invariants[i]

    try:
      nest.assert_same_structure(init, entry, expand_composites=True)
    except (ValueError, TypeError):
      # `Variable`s in `init` may be implicitly converted to `Tensor`s. Convert
      # `ResourceVariable`s to Tensors so tf.nest.assert_same_structure
      # won't break due to type spec mismatches between `ResourceVariable`s and
      # `Tensor`s.
      try:
        init_tensors = variable_utils.convert_variables_to_tensors(init)
        nest.assert_same_structure(init_tensors, entry, expand_composites=True)
      except (ValueError, TypeError) as e:
        raise TypeError("'{}' does not have the same nested structure after one"
                        ' iteration.\n\n{}'.format(name, e)) from e

    try:
      nest.assert_same_structure(entry, exit_, expand_composites=True)
    except (ValueError, TypeError) as e:
      raise TypeError("'{}' does not have the same nested structure after one"
                      ' iteration.\n\n{}'.format(name, e)) from e
    if invariant is not None:
      try:
        nest.assert_same_structure(init, invariant, expand_composites=False)
      except (ValueError, TypeError) as e:
        raise TypeError("'{}' does not have the same nested structure as its"
                        ' corresponding shape invariant.\n\n{}'.format(
                            name, e)) from e

    nest.map_structure(
        functools.partial(_verify_single_loop_var, name, check_shapes), init,
        entry, exit_, invariant)


def verify_single_cond_var(name, body_var, orelse_var):
  """Verifies whether body_var and orelse_var are consistent.

  Args:
    name: name of the variable, used in error messages.
    body_var: value of the variable at the end of the main branch.
    orelse_var: value of the variable at the end of the else branch.

  Raises:
    ValueError: if either value is None.
    TypeError: if both values are TF types with different dtypes.
  """
  if body_var is None:
    raise ValueError("'{}' is None at the end of the main branch.".format(name))
  if orelse_var is None:
    raise ValueError(
        "'{}' is None at the end of the else branch.".format(name))

  if isinstance(body_var, (bool, int, float, str, np.ndarray)):
    body_var = ops.convert_to_tensor_v2(body_var)

  if isinstance(orelse_var, (bool, int, float, str, np.ndarray)):
    orelse_var = ops.convert_to_tensor_v2(orelse_var)

  if (not tensor_util.is_tf_type(body_var) or
      not tensor_util.is_tf_type(orelse_var)):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  if (not hasattr(body_var, 'dtype') or
      not hasattr(orelse_var, 'dtype')):
    return

  if body_var.dtype != orelse_var.dtype:
    raise TypeError(
        "'{}' has dtype {} in the main branch, but dtype {} in the else"
        ' branch'.format(name, body_var.dtype.name,
                         orelse_var.dtype.name))


def _verify_tf_cond_branch_vars(vars_, symbol_names, branch_name):
  """Verifies variables output by a conditional branch for consistency.

  Raises:
    ValueError: if a variable is Undefined, or the return value is undefined,
      at the end of the branch named `branch_name`.
  """
  for name, var_ in zip(symbol_names, vars_):
    if isinstance(var_, variables.Undefined):
      raise ValueError(
          "'{}' must also be initialized in the {} branch".format(
              name, branch_name))
    if isinstance(var_, variables.UndefinedReturnValue):
      raise ValueError(
          'the {} branch must also have a return statement.'.format(
              branch_name))


def _verify_tf_cond_vars(body_vars, orelse_vars, symbol_names):
  """Verifies variables manipulated by a conditional for consistency.

  Raises:
    TypeError: if a variable has a different nested structure in the two
      branches (after converting variables to tensors), or different dtypes.
    ValueError: if a variable is None in either branch.
  """
  named_vars = zip(symbol_names, body_vars, orelse_vars)

  for name, body_var, orelse_var in named_vars:
    try:
      nest.assert_same_structure(body_var, orelse_var, expand_composites=True)
    except (ValueError, TypeError):
      # One branch of cond could be a `Tensor`, while the other branch could be
      # a `ResourceVariable`. Convert `ResourceVariable`s to `Tensor`s so
      # assert_same_structure won't fail.
      try:
        body_var_tensors = variable_utils.convert_variables_to_tensors(body_var)
        orelse_var_tensors = variable_utils.convert_variables_to_tensors(
            orelse_var)
        nest.assert_same_structure(body_var_tensors, orelse_var_tensors,
                                   expand_composites=True)
      except (ValueError, TypeError) as e:
        raise TypeError(
            "'{}' must have the same nested structure in the main and else"
            ' branches:\n\n{}'.format(name, str(e))) from e
    nest.map_structure(
        functools.partial(verify_single_cond_var, name), body_var, orelse_var)


def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Functional form of a for statement.

  The loop operates on a state, which includes all symbols that are
  variant across loop iterations, excluding the variables local to the loop.

  For example, given the loop below that calculates the geometric and
  arithmetic means of some numbers:

  ```
    geo_mean = 1
    arith_mean = 0
    for i in range(n):
      a = numbers[i]
      geo_mean *= a
      arith_mean += a
  ```

  The state is represented by the variables named geo_mean and arith_mean. The
  `extra_test`, `body`, `get_state` and `set_state` functions must bind to the
  original `geo_mean` and `arith_mean` symbols, using `nonlocal`.

  The inputs and outputs of the callables representing the loop blocks are not
  explicit - instead, these functions must use nonlocal/global for side effects.
  The inputs and outputs are instead controlled by the set_state/get_state
  functions.

  Dispatch: a type registered in `for_loop_registry` takes precedence.
  Otherwise, TF range tensors, ragged tensors, other TF tensors, distributed
  iterators and distributed iterables get staged overloads, and everything else
  runs as a plain Python loop.

  Args:
    iter_: The entity being iterated over.
    extra_test: Callable with boolean return type. An additional loop condition.
    body: Callable representing the actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which saves values captured by get_state
      back into the Python environment. This is only useful when staging the
      loop.
    symbol_names: Tuple containing names of the loop variables returned by
      get_state.
    opts: Optional dict of extra loop parameters.
  """
  try:
    for_fn = for_loop_registry.lookup(iter_)
  except LookupError:
    for_fn = _py_for_stmt

    # TODO(bwieder): Refactor isinstance(iter_, ragged_tensor.RaggedTensor) to use
    # the registry once python/autograph/utils does not depend on dataset_ops.
    if tensor_util.is_tf_type(iter_):
      if tensors.is_range_tensor(iter_):
        for_fn = _tf_range_for_stmt
      elif isinstance(iter_, ragged_tensor.RaggedTensor):
        for_fn = _tf_ragged_for_stmt
      else:
        for_fn = _known_len_tf_for_stmt
    elif isinstance(iter_, distribute.Iterator):
      for_fn = _tf_iterator_for_stmt
    elif isinstance(iter_, distribute.Iterable):
      # TODO(b/162250181): Use _tf_iterator_for_stmt(iter(iter_)...
      for_fn = _tf_distributed_iterable_for_stmt

  for_fn(iter_, extra_test, body, get_state, set_state, symbol_names, opts)


def _py_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts
):
  """Overload of for_stmt that executes a Python for loop.

  State is mutated directly by `body`, so the staging helpers and options are
  unused. In `__debug__` mode each iteration is wrapped with a
  _PythonLoopChecker.

  Raises:
    NotImplementedError: if `extra_test` returns a value that cannot be used as
      a Python bool in graph mode (a TF-dependent break/return).
  """
  del get_state, set_state, symbol_names, opts

  if __debug__:
    checker = _PythonLoopChecker()
    before_iteration = checker.before_iteration
    after_iteration = checker.after_iteration
    before_iteration()

    original_body = body
    def protected_body(protected_iter):
      original_body(protected_iter)
      after_iteration()
      before_iteration()
    body = protected_body

  if extra_test is not None:
    def guarded_extra_test():
      extra_test_result = extra_test()
      try:
        # Note: Using try/except and not tensor_util.is_tf_type to avoid
        # performance degradation.
        return bool(extra_test_result)
      except errors_impl.OperatorNotAllowedInGraphError as e:
        ag_logging.log(
            1,
            'Caught error while evaluating loop stop condition',
            exc_info=True)
        # TODO(mdan): We can pass the location of extra_test and show it here.
        raise NotImplementedError(
            'break and return statements which depend on a TF condition are not'
            ' supported in Python for loops. Did you intend to make it a TF'
            ' loop?\nSee '
            'https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
            'python/autograph/g3doc/reference/limitations.md'
            '#consistency-of-control-flow-types for more info.') from e

    # extra_test is checked before the first iteration and after each one.
    if guarded_extra_test():
      for target in iter_:
        body(target)
        if not guarded_extra_test():
          break

  else:
    for target in iter_:
      body(target)


def _add_max_iterations_hint(opts, n):
  """Sets `opts['maximum_iterations'] = n` when building in an XLA context."""
  # TODO(b/159186914): Remove the safeguard, and always set maximum_iterations.
  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    opts['maximum_iterations'] = n


def _known_len_tf_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF entities that admit a length.

  The loop is lowered to a while loop over an integer index, prepended to the
  loop state under the empty symbol name. Args are as for `for_stmt`.
  """
  n = py_builtins.len_(iter_)

  # TODO(b/117628877): Revisit performance once XLA has the necessary support.
  # Note: using a TensorArray creates an extra copy, but can calculate
  # gradients more efficiently than StridedSlice.
  ta = tensor_array_ops.TensorArray(iter_.dtype, size=n)
  iter_ = ta.unstack(iter_)

  iterate_index = 0

  def aug_get_state():
    return (iterate_index,) + get_state()

  def aug_set_state(aug_loop_vars):
    nonlocal iterate_index
    # TODO(b/171479293): Drop the lint override.
    iterate_index, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    # The iteration index is not "output" by the for loop. If the iteration index
    # is used outside the loop, it will appear in the loop vars separately.
    set_state(loop_vars)

  def aug_body():
    nonlocal iterate_index
    body(iter_.read(iterate_index))
    iterate_index += 1

  def aug_test():
    main_test = iterate_index < n
    if extra_test is not None:
      return control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  _add_max_iterations_hint(opts, n)

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      ('',) + symbol_names,
      opts,
  )


def _tf_ragged_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF ragged tensors.

  Iterates over the outer (row) dimension of `iter_`, so the trip count is the
  number of rows. Args are as for `for_stmt`.
  """
  init_vars = get_state()
  verify_loop_init_vars(init_vars, symbol_names)

  # TODO(mdan): Move this into len()? Requires eager support.
  if iter_.shape and iter_.shape[0] is not None:
    n = iter_.shape[0]
  else:
    # The number of rows is only known at runtime. Note: `row_lengths()[0]`
    # would be the length of the *first row*, not the number of rows.
    # nrows() returns the row_splits dtype (int64 by default); cast to int32 to
    # match the loop index, which starts as a Python int (presumably staged as
    # an int32 tensor by the while loop).
    n = math_ops.cast(iter_.nrows(), dtypes.int32)

  iterate_index = 0

  def aug_get_state():
    return (iterate_index,) + get_state()

  def aug_set_state(aug_loop_vars):
    nonlocal iterate_index
    # TODO(b/171479293): Drop the lint override.
    iterate_index, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    # The iteration index is not "output" by the for loop. If the iteration index
    # is used outside the loop, it will appear in the loop vars separately.
    set_state(loop_vars)

  def aug_body():
    nonlocal iterate_index
    body(iter_[iterate_index])
    iterate_index += 1

  def aug_test():
    main_test = iterate_index < n
    if extra_test is not None:
      return control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  _add_max_iterations_hint(opts, n)

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      ('',) + symbol_names,
      opts)


def _tf_range_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over a TF range (and elides it).

  The range op itself is not iterated; its (start, limit, delta) inputs drive a
  while loop over the iterate directly. Args are as for `for_stmt`;
  `opts['iterate_names']` must be set.
  """
  start, limit, delta = iter_.op.inputs

  iterate = start

  def _value_or(name, var, default):
    # If the loop target itself is still undefined, seed it with the iterate.
    if (name == opts['iterate_names'] and isinstance(var, variables.Undefined)):
      return default
    return var

  def aug_get_state():
    state_vars = get_state()
    state_vars = tuple(
        _value_or(name, var, iterate)
        for name, var in zip(symbol_names, state_vars))
    return (iterate,) + state_vars

  def aug_set_state(aug_loop_vars):
    nonlocal iterate
    # TODO(b/171479293): Drop the lint override.
    iterate, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    # The iteration index is not "output" by the for loop. If the iterate
    # is used outside the loop, it will appear in the loop vars separately.
    set_state(loop_vars)

  def aug_body():
    nonlocal iterate
    body(iterate)
    iterate += delta

  def aug_test():
    # TODO(b/159713842): Remove once constant folding works.
    const_delta = tensor_util.constant_value(delta)
    if const_delta is not None:
      if const_delta >= 0:
        main_test = iterate < limit
      else:
        main_test = iterate > limit
    else:
      # The direction of the range is only known at runtime.
      main_test = math_ops.logical_or(
          math_ops.logical_and(delta >= 0, iterate < limit),
          math_ops.logical_and(delta < 0, iterate > limit))

    if extra_test is not None:
      main_test = control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  _add_max_iterations_hint(
      opts,
      math_ops.cast(misc.get_range_len(start, limit, delta), dtypes.int32))

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      ('',) + symbol_names,
      opts)


def _tf_iterator_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF Iterators. See for_stmt.

  A `has_next` flag is prepended to the loop state under the empty symbol name;
  each iteration fetches an optional element and runs `body` only if it holds a
  value.
  """
  symbol_names = ('',) + symbol_names
  has_next = True

  def aug_get_state():
    return (has_next,) + get_state()

  def aug_set_state(aug_loop_vars):
    nonlocal has_next
    # TODO(b/171479293): Drop the lint override.
    has_next, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    set_state(loop_vars)

  init_vars = aug_get_state()
  verify_loop_init_vars(init_vars, symbol_names)

  def aug_body():
    """Main body passed to _tf_while_stmt."""
    nonlocal has_next
    opt_iterate = iter_.get_next_as_optional()
    has_next = opt_iterate.has_value()
    loop_vars = aug_get_state()  # updated by set_state() in _tf_while_stmt.

    def main_path():
      body(opt_iterate.get_value())
      new_loop_vars = aug_get_state()
      # Note: this verification duplicates the one performed in tf_while_stmt,
      # but needs to be done earlier to prevent the tf.cond from blowing up
      # first.
      verify_tf_loop_vars(
          init_vars, loop_vars, new_loop_vars, symbol_names, opts)
      return new_loop_vars

    def noop_path():
      return loop_vars

    # TODO(mdan): If tf.while_loop supported Optional, this could be avoided.
    # Calling set_state so that the get_state() call made by _tf_while_stmt
    # sees the conditional tensors.
    aug_set_state(
        control_flow_ops.cond(has_next, main_path, noop_path))

  def aug_test():
    # This value takes a complicated path to get here:
    #   prev_iteration_body -> get_state -> tf.while_loop (as loop var)
    #   -> current_iteration_body -> set_state -> has_next
    main_test = has_next
    if extra_test is not None:
      return control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      symbol_names,
      opts)


def _tf_distributed_iterable_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF distributed datasets.

  The loop is expressed as `iter_.reduce` over the loop state. If present,
  `opts['shape_invariants']` is rewritten in place into a positional tuple
  aligned with the initial state.

  Raises:
    NotImplementedError: if `extra_test` is set (break/return in the loop).
  """

  if extra_test is not None:
    raise NotImplementedError(
        'break and return statements are not yet supported in '
        'for ... in distributed input loops.')

  init_vars = get_state()
  verify_loop_init_vars(init_vars, symbol_names)

  if 'shape_invariants' in opts:
    opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
        opts['shape_invariants'], init_vars)

  def reduce_body(loop_vars, iterate):
    set_state(loop_vars)
    body(iterate)
    new_loop_vars = get_state()
    verify_tf_loop_vars(
        init_vars, loop_vars, new_loop_vars, symbol_names, opts)
    return new_loop_vars

  set_state(iter_.reduce(init_vars, reduce_body))


def while_stmt(test, body, get_state, set_state, symbol_names, opts):
  """Functional form of a while statement.

  The loop operates on a so-called state, which includes all symbols that are
  variant across loop iterations. In what follows we refer to state as either
  a tuple of entities that represent an actual state, or a list of arguments
  of the corresponding types.

  The inputs and outputs of the callables representing the loop blocks are not
  explicit - instead, these functions must use nonlocal/global for side effects.
  The inputs and outputs are instead controlled by the set_state/get_state
  functions.

  Args:
    test: Callable with boolean return type. The loop condition.
    body: Callable representing the actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which saves values captured by get_state
      back into the Python environment. This is only useful when staging the
      loop.
    symbol_names: Tuple containing the names of all loop variables.
    opts: Optional dict of extra loop parameters.

  Returns:
    None. The final state is delivered through `set_state` (or, for Python
    loops, through the side effects of `body`).
  """

  # Evaluate the initial test once in order to do the dispatch. The evaluation
  # is isolated to minimize unwanted side effects.
  # TODO(mdan): Do a full iteration - some state types might lower to Tensor.
  with func_graph.FuncGraph('tmp').as_default():
    init_test = test()

  # TensorFlow: Multiple evaluations are acceptable in this case, so we're fine
  # with the re-evaluation of `test` that `_tf_while_stmt` will make.
  if tensors.is_dense_tensor(init_test):
    _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts)
    return

  # Normal Python: We already consumed one evaluation of `test`; consistently,
  # unroll one iteration before dispatching to a normal loop.
  # TODO(mdan): Push the "init_test" value via opts into _py_while_stmt?
  if not init_test:
    return
  body()

  _py_while_stmt(test, body, get_state, set_state, opts)


class _PythonLoopChecker(object):
  """Verifies Python loops for TF-specific limits.

  Raises when a loop exceeds PYTHON_MAX_ITERATIONS, and logs a one-time warning
  when, past INEFFICIENT_UNROLL_MIN_ITERATIONS iterations, an iteration adds at
  least INEFFICIENT_UNROLL_MIN_OPS ops to the default graph.
  """

  __slots__ = (
      'iterations',
      'check_inefficient_unroll',
      'check_op_count_after_iteration',
      # Only assigned by before_iteration / _stop_checking_inefficient_unroll.
      'ops_before_iteration',
  )

  def __init__(self):
    self.iterations = 1
    self.check_inefficient_unroll = WARN_INEFFICIENT_UNROLL

    # Triggered when we decided to test the op counts.
    self.check_op_count_after_iteration = False

  def _get_ops(self):
    return ops.get_default_graph().get_operations()

  def _check_unroll_limits(self):
    if self.iterations > PYTHON_MAX_ITERATIONS:
      raise ValueError('iteration limit exceeded')

  def _stop_checking_inefficient_unroll(self):
    self.check_inefficient_unroll = False
    self.check_op_count_after_iteration = False
    self.ops_before_iteration = None

  def _verify_inefficient_unroll(self):
    """Checks for possibly-inefficient creation of ops in a Python loop.

    Returns:
      True if the warning was logged, False otherwise.
    """
    assert self.ops_before_iteration is not None
    ops_after_iteration = self._get_ops()
    new_ops = tuple(
        op for op in ops_after_iteration if op not in self.ops_before_iteration)

    if len(new_ops) < INEFFICIENT_UNROLL_MIN_OPS:
      return False

    ag_logging.warning(
        'Large unrolled loop detected. Did you mean to use a TF loop?'
        ' The following ops were created after iteration %s: %s'
        '\nSee'
        ' https://github.com/tensorflow/tensorflow/blob/master/'
        'tensorflow/python/autograph/g3doc/reference/common_errors.md'
        '#warning-large-unrolled-loop-detected'
        '\n'
        'Location:'
        '\n%s'
        '', self.iterations, new_ops, '\n'.join(traceback.format_stack()))
    return True

  def before_iteration(self):
    """Called before each iteration in a Python loop."""
    if (self.check_inefficient_unroll and
        self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS):
      self.ops_before_iteration = self._get_ops()
      self.check_op_count_after_iteration = True

  def after_iteration(self):
    """Called after each iteration in a Python loop."""
    self.iterations += 1

    self._check_unroll_limits()

    if self.check_op_count_after_iteration:
      did_warn = self._verify_inefficient_unroll()
      if did_warn:
        self._stop_checking_inefficient_unroll()  # Only warn once.
      elif self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS + 3:
        # Once deciding to check the op counts, only do it for a few iterations.
        self._stop_checking_inefficient_unroll()


def _py_while_stmt(test, body, get_state, set_state, opts):
  """Overload of while_stmt that executes a Python while loop.

  Raises:
    NotImplementedError: if `test` starts returning a value that cannot be used
      as a Python bool in graph mode.
  """
  del opts, get_state, set_state

  if __debug__:
    checker = _PythonLoopChecker()
    before_iteration = checker.before_iteration
    after_iteration = checker.after_iteration
    before_iteration()

    original_body = body
    def protected_body():
      original_body()
      after_iteration()
      before_iteration()
    body = protected_body

  def guarded_test():
    test_result = test()
    try:
      # Note: Using try/except and not tensor_util.is_tf_type to avoid
      # performance degradation.
      return bool(test_result)
    except errors_impl.OperatorNotAllowedInGraphError as e:
      ag_logging.log(
          1,
          'Caught error while evaluating while loop condition',
          exc_info=True)
      # TODO(mdan): distinguish between these two cases.
      raise NotImplementedError(
          'The condition of while loop started as non-Tensor, then changed to'
          ' Tensor. This may happen either because variables changed type, or'
          ' when a break or return statement inside the loop depends on a'
          ' Tensor condition. In both cases, changing to a TF loop should'
          ' remove the error.\nSee '
          'https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
          'python/autograph/g3doc/reference/limitations.md'
          '#consistency-of-control-flow-types for more info.') from e
  while guarded_test():
    body()


def _shape_invariants_mapping_to_positional_list(mapping, keys):
  """Converts (value, invariant) pairs into a tuple aligned with `keys`.

  Values are matched by identity. Keys without an invariant get a structure of
  Nones mirroring the key.
  """
  # The keys are not expected to be hashable.
  mapping = {id(k): (k, v) for k, v in mapping}
  result = []
  for k in keys:
    map_key, map_val = mapping.get(id(k), (None, None))
    # The `is` check guards against id() reuse by a different object.
    result.append(
        map_val if map_key is k else nest.map_structure(lambda _: None, k))
  return tuple(result)


# Textual description of what a legal TF loop variable is. This description
# summarizes types that _placeholder_value below can handle. Keep the two
# together and in sync.
LEGAL_LOOP_TYPES = 'Tensor, int, float, bool or a list, tuple or dict thereof'


def _placeholder_value(like, shape_invariant, original=None):
  """Constructs a (dummy) placeholder value for a loop-initialized variable.

  Args:
    like: Any object. The value created by the first iteration of the loop. If a
      Python scalar, the placeholder will be the zero value of that type. If a
      Tensor, the placeholder will be a zero tensor of matching shape and dtype.
      If a list, dict or tuple, the placeholder will be an identical structure
      of placeholders.
    shape_invariant: The shape invariant specified by the user (or None, if
      nothing was specified) for the respective variable.
    original: Any object. The value of the variable prior to entering the loop.
      Typically, this is one of the special "Undefined" values, because that's
      when a placeholder is needed.

  Returns:
    A pair (value, shape_invariant):
     * value is a zero value of structure, shape and dtype matching `like`, or
       `original` if `like` is None or undefined;
     * shape_invariant is the invariant the placeholder requires: the effective
       shape when a tensor has unknown dimensions or rank, otherwise None
       (nested in the same structure as `like` for containers).

  Raises:
    TypeError: if `like` (or a leaf of it) has an unsupported type.
  """
  if like is None:
    return original, None

  elif isinstance(like, (variables.Undefined, variables.UndefinedReturnValue)):
    return original, None

  elif isinstance(like, (int, float, bool)):
    return type(like)(0), None

  elif tensor_util.is_tf_type(like):

    like_shape = shape_invariant if shape_invariant is not None else like.shape
    if like_shape is None or like_shape.rank is None:
      return array_ops.zeros((), like.dtype), like_shape

    # If the shape contains dynamic values, set the corresponding starting
    # dimension to either zero or what the shape invariant specified.
    placeholder_shape = []
    has_dynamic_dims = False
    for s, i in zip(like.shape, like_shape):
      # like_dim: the size to use for this dim if it is dynamic in `like`.
      if i is None:
        like_dim = 0
      elif isinstance(i, tensor_shape.Dimension):
        if i.value is None:
          like_dim = 0
        else:
          like_dim = i.value
      else:
        like_dim = i

      if s is None:
        placeholder_shape.append(like_dim)
        has_dynamic_dims = True
      elif isinstance(s, tensor_shape.Dimension):
        if s.value is None:
          placeholder_shape.append(like_dim)
          has_dynamic_dims = True
        else:
          placeholder_shape.append(s.value)
      else:
        placeholder_shape.append(s)

    if has_dynamic_dims:
      invariant = like_shape
    else:
      invariant = None

    return array_ops.zeros(placeholder_shape, like.dtype), invariant

  elif isinstance(like, (list, tuple, dict)):
    if shape_invariant is None:
      zipped = nest.map_structure(lambda v: _placeholder_value(v, None),
                                  nest.flatten(like))
    else:
      zipped = nest.map_structure(_placeholder_value, nest.flatten(like),
                                  nest.flatten(shape_invariant))
    # NOTE(review): for an empty container `zipped` is empty and this unpacking
    # raises ValueError, which _try_handling_undefineds reports as a staging
    # failure - confirm whether empty containers should be supported.
    vals, invars = zip(*zipped)
    return (nest.pack_sequence_as(like,
                                  vals), nest.pack_sequence_as(like, invars))

  # This is to be caught by _try_handling_undefineds, to give more context.
  raise TypeError(
      "Found an unsupported type '{}' while creating placeholder for {}."
      ' Supported types include Tensor, int, float, bool, list, tuple or dict.'
      .format(type(like).__name__, like))


def _try_handling_undefineds(body, get_state, set_state, init_vars, nulls,
                             shape_invariants, symbol_names):
  """Makes a best-effort attempt to substitute undefineds with placeholders.

  Note: this substitution requires two things to happen:
   1. the types of loop variables could be inferred (usually by staging one
       iteration)
   2. these types could be replaced by placeholders (e.g. zero values, for
       tensors).

  Args:
    body: a function representing the loop body. See while_stmt.
    get_state: state getter for the loop statement. See while_stmt.
    set_state: state setter for the loop statement. See while_stmt.
    init_vars: loop variables before entering the loop. See while_stmt.
    nulls: list of boolean flags indicating whether the corresponding loop var
      is None or undefined.
    shape_invariants: user-specified shape invariant for each loop variable.
    symbol_names: list of loop variable names. See while_stmt.

  Returns:
    A tuple (success, new_init_vars, extra_shape_invariants):
     * success is a boolean flag indicating whether types could be successfully
       inferred (step 1 above)
     * new_init_vars contains the loop vars, with None or undefined values
       replaced by default values, where possible (step 2 above)
     * extra_shape_invariants contains shape invariants that would be needed
       by while_stmt, for instance if the placeholder values had a shape
       different from the corresponding loop outputs

  Raises:
    ValueError, NotImplementedError: from verify_loop_init_vars, when some
      loop variables remain None or undefined.
  """
  state_modified = False
  first_iter_vars = None
  failure_message = None

  try:
    # Stage an iteration of the loop body in a temporary graph.
    with func_graph.FuncGraph('tmp').as_default():
      # This call to set_state helps report nicer error messages when symbols
      # are inconsistently used.
      # Another complication is that non_tensor values will be autocast to
      # Tensor by while_loop, and their static value lost. So we need to account
      # that here.
      def autocast_to_tensor(v):
        if isinstance(
            v, (int, float, bool, str, list, tuple, np.ndarray, np.generic)):
          init_val = ops.convert_to_tensor_v2(v)
          return array_ops.placeholder(init_val.dtype, init_val.shape)
        return v
      autocast_init_vars = nest.map_structure(autocast_to_tensor, init_vars)
      set_state(autocast_init_vars)
      state_modified = True

      body()
      first_iter_vars = get_state()

    # Note: the actual placeholder value doesn't matter, because as the
    # staging proved, it will be replaced by an actual value before being
    # read.
    # The placeholders are built outside the temporary graph.
    inits_and_invariants = tuple(
        (_placeholder_value(iv, i, v) if n else (v, None))
        for v, n, iv, i in zip(init_vars, nulls, first_iter_vars,
                               shape_invariants))
    init_vars, extra_shape_invariants = zip(*inits_and_invariants)
    success = True

  except (UnboundLocalError, TypeError, ValueError, KeyError):
    ag_logging.log(1, 'Caught error while staging loop body', exc_info=True)
    # Fall back to the old functionality. It will likely result in an input
    # validation failure.
    exc = sys.exc_info()
    failure_message = (
        'Note: AutoGraph tried to define it automatically, but ran into a'
        ' {}: {}'.format(exc[0].__name__, exc[1]))

  finally:
    # Restore the Python state; on success this writes the placeholder-filled
    # init_vars, otherwise the original ones.
    if state_modified:
      set_state(init_vars)

  # This check runs regardless, in case we captured non-Tensor inputs.
  # Note: `success` and `extra_shape_invariants` are only bound on the success
  # path. On failure init_vars still holds the None/undefined values flagged in
  # `nulls`, so this call is expected to raise before the return below.
  verify_loop_init_vars(
      init_vars, symbol_names, first_iter_vars, extra_message=failure_message)

  return success, init_vars, extra_shape_invariants


def _runtime_zero_iterations_errmsg(symbol_names, nulls, init_vars):
  """Creates an error message asking for the loop to iterate at least once.

  Only variables flagged in `nulls` are listed; an undefined return value is
  reported as 'the function return value'.
  """
  var_names = []
  for sn, n, v in zip(symbol_names, nulls, init_vars):
    if not n:
      continue
    if isinstance(v, variables.UndefinedReturnValue):
      var_names.append('the function return value')
    else:
      var_names.append(sn)
  var_names = ', '.join(var_names)
  return 'loop must iterate at least once to initialize {}'.format(var_names)


def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts): """Overload of while_stmt that stages a TF while_stmt.""" init_vars = get_state() orig_init_vars = init_vars nulls = tuple(_is_none_or_undef(v) for v in init_vars) if any(nulls): shape_invars_by_init_vals = { id(v): i for v, i in opts.get('shape_invariants', ()) } shape_invariants = tuple( shape_invars_by_init_vals.get(id(v), None) for v in orig_init_vars) (require_one_iteration, init_vars, extra_shape_invariants) = _try_handling_undefineds(body, get_state, set_state, init_vars, nulls, shape_invariants, symbol_names) else: require_one_iteration = False if
require_one_iteration: merged_shape_invariants = dict(shape_invars_by_init_vals) # This has two roles: # 1. Shape invariants are remapped from the old init vars to the new ones. # 2. Any new shape invariants created by the init vars are kept, but only # if the user didn't already specify some. for v, nv, ni in zip(orig_init_vars, init_vars, extra_shape_invariants): merged_invariant = merged_shape_invariants.get(id(v), ni) if merged_invariant is not None: merged_shape_invariants[id(nv)] = merged_invariant merged_shape_invariants = tuple((nv, merged_shape_invariants[id(nv)]) for nv in init_vars if id(nv) in merged_shape_invariants) if merged_shape_invariants: opts = dict(**opts) opts['shape_invariants'] = merged_shape_invariants def aug_test(*loop_vars): if require_one_iteration: loop_vars = loop_vars[1:] set_state(loop_vars) return _verify_tf_condition(test(), 'while loop') def aug_body(*loop_vars): if require_one_iteration: loop_vars = loop_vars[1:] set_state(loop_vars) body() new_loop_vars = get_state() verify_tf_loop_vars( init_vars, loop_vars, new_loop_vars, symbol_names, opts) if require_one_iteration: new_loop_vars = (True,) + new_loop_vars return new_loop_vars if 'shape_invariants' in opts: opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list( opts['shape_invariants'], init_vars) while_loop_opts = dict(opts) while_loop_opts.pop('iterate_names', None) # Non-v2 while_loop unpacks the results when there is only one return value. # This enforces consistency across versions. 
while_loop_opts['return_same_structure'] = True if require_one_iteration: aug_init_vars = (False,) + init_vars if 'shape_invariants' in while_loop_opts: while_loop_opts['shape_invariants'] = ( (None,) + while_loop_opts['shape_invariants']) else: aug_init_vars = init_vars final_loop_vars = control_flow_ops.while_loop( aug_test, aug_body, aug_init_vars, **while_loop_opts) if require_one_iteration: with ops.control_dependencies([ control_flow_ops.Assert(final_loop_vars[0], [ _runtime_zero_iterations_errmsg(symbol_names, nulls, orig_init_vars) ]) ]): final_loop_vars = nest.map_structure( lambda v: (array_ops.identity(v) if tensor_util.is_tf_type(v) else v), final_loop_vars[1:], ) set_state(final_loop_vars) def if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts): """Functional form of an if statement. The conditional operates on a state, which includes all symbols whose values are a function of the branch taken. For example, given the code below that calculates the abs function: ``` x = 1 if x > 0: x = -x ``` The state is represented by the variable `x`. The `body, `orelse` and `set_state` functions must bind to the original `x` symbol, using `nonlocal`. The inputs and outputs of the callables representing the loop blocks are not explicit - instead, these functions must use nonlocal/global for side effects. The inputs and outputs are instead controlled by the set_state/get_state functions. Args: cond: Boolean. body: Callable representing the main block of the conditional. orelse: Callable representing the else block of the conditional. get_state: Function that returns a tuple containing the values of all composite symbols modified within the conditional. This allows access to state that branches may mutate through side effects. This function is not needed and should not be called when dispatching to code matching Python's default semantics. 
This is useful for checkpointing to avoid unintended side-effects when staging requires evaluating all code-paths. set_state: Function to set the values of all composite symbols modified within the conditional. This is the complement to get_state, used to restore checkpointed values. The single argument a tuple containing values for each composite symbol that may be modified in a branch of the conditional. The is usually the result of a call to get_state. symbol_names: Tuple containing basic loop var names. nouts: Number of variables output by the statement. Vars which are not outputs will not be passed through staged control flow such as tf.cond. This includes variables that are defined before the conditional, but are not used after it. """ # Note: tf.cond doesn't support SparseTensor. if tensors.is_dense_tensor(cond): _tf_if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts) else: _py_if_stmt(cond, body, orelse) def _tf_if_stmt( cond, body, orelse, get_state, set_state, symbol_names, nouts): """Overload of if_stmt that stages a TF cond.""" cond = _verify_tf_condition(cond, 'if statement') if not nouts: prev_get_state, prev_set_state = get_state, set_state # Control flow V1 wants at least one output. get_state = lambda: (0,) + prev_get_state() set_state = lambda v: prev_set_state(v[1:]) symbol_names += ('',) nouts = 1 init_vars = get_state() # TODO(mdan): Use nonlocal once we no longer need to support py2. 
new_body_vars_ = [None] new_orelse_vars_ = [None] def aug_body(): set_state(init_vars) body() new_body_vars = get_state() new_body_vars = new_body_vars[:nouts] new_body_vars_[0] = new_body_vars _verify_tf_cond_branch_vars(new_body_vars, symbol_names, 'main') if new_orelse_vars_[0] is not None: _verify_tf_cond_vars(new_body_vars, new_orelse_vars_[0], symbol_names) return new_body_vars def aug_orelse(): set_state(init_vars) orelse() new_orelse_vars = get_state() new_orelse_vars = new_orelse_vars[:nouts] new_orelse_vars_[0] = new_orelse_vars _verify_tf_cond_branch_vars(new_orelse_vars, symbol_names, 'else') if new_body_vars_[0] is not None: _verify_tf_cond_vars(new_body_vars_[0], new_orelse_vars, symbol_names) return new_orelse_vars final_cond_vars = control_flow_ops.cond( cond, aug_body, aug_orelse, strict=True) final_cond_vars = final_cond_vars + init_vars[nouts:] set_state(final_cond_vars) def _py_if_stmt(cond, body, orelse): """Overload of if_stmt that executes a Python if statement.""" return body() if cond else orelse()