
1244 lines
47 KiB
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Functional operations."""
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import auto_control_deps_utils as acd
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
# pylint: disable=unused-import
from tensorflow.python.ops.gen_functional_ops import remote_call
# pylint: enable=unused-import
from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
def foldl(fn,
"""foldl on the list of tensors unpacked from `elems` on dimension 0.
This foldl operator repeatedly applies the callable `fn` to a sequence
of elements from first to last. The elements are made of the tensors
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
arguments. The first argument is the accumulated value computed from the
preceding invocation of fn, and the second is the value at the current
position of `elems`. If `initializer` is None, `elems` must contain at least
one element, and its first element is used as the initializer.
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
of the result tensor is fn(initializer, values[0]).shape`.
This method also allows multi-arity `elems` and output of `fn`. If `elems`
is a (possibly nested) list or tuple of tensors, then each of these tensors
must have a matching first (unpack) dimension. The signature of `fn` may
match the structure of `elems`. That is, if `elems` is
`(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
`fn = lambda (t1, [t2, t3, [t4, t5]]):`.
fn: The callable to be performed.
elems: A tensor or (possibly nested) sequence of tensors, each of which will
be unpacked along their first dimension. The nested sequence of the
resulting slices will be the first argument to `fn`.
initializer: (optional) A tensor or (possibly nested) sequence of tensors,
as the initial value for the accumulator.
parallel_iterations: (optional) The number of iterations allowed to run in
back_prop: (optional) True enables support for back propagation.
swap_memory: (optional) True enables GPU-CPU memory swapping.
name: (optional) Name prefix for the returned tensors.
A tensor or (possibly nested) sequence of tensors, resulting from applying
`fn` consecutively to the list of tensors unpacked from `elems`, from first
to last.
TypeError: if `fn` is not callable.
elems = tf.constant([1, 2, 3, 4, 5, 6])
sum = foldl(lambda a, x: a + x, elems)
# sum == 21
if not callable(fn):
raise TypeError(
f"{fn.__name__} is not callable. Please provide a callable function.")
def create_ta(elem):
return tensor_array_ops.TensorArray(
dtype=elem.dtype, size=n, dynamic_size=False,
in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "foldl", [elems]):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager
if in_graph_mode:
# Any get_variable calls in fn will cache the first call locally
# and not issue repeated network I/O requests for each iteration.
varscope = vs.get_variable_scope()
varscope_caching_device_was_none = False
if varscope.caching_device is None:
# TODO(ebrevdo): Change to using colocate_with here and in other
# methods.
varscope.set_caching_device(lambda op: op.device)
varscope_caching_device_was_none = True
# Convert elems to tensor array. n may be known statically.
elems_flat = [
ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems)
n = (
tensor_shape.dimension_value(elems_flat[0].shape[0]) or
elems_ta = nest.map_structure(create_ta, elems)
if initializer is None:
a = nest.map_structure(lambda elem:, elems_ta)
i = constant_op.constant(1)
a = initializer
i = constant_op.constant(0)
def compute(i, a):
elem_i = nest.map_structure(lambda elem:, elems_ta)
a = fn(a, elem_i)
return [i + 1, a]
_, r_a = control_flow_ops.while_loop(
lambda i, a: i < n,
compute, [i, a],
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager
if in_graph_mode and varscope_caching_device_was_none:
return r_a
@tf_export("foldl", v1=[])
"""back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldl(fn, elems, back_prop=False)
results = tf.nest.map_structure(tf.stop_gradient, tf.foldl(fn, elems))""",
def foldl_v2(fn,
"""foldl on the list of tensors unpacked from `elems` on dimension 0.
This foldl operator repeatedly applies the callable `fn` to a sequence
of elements from first to last. The elements are made of the tensors
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
arguments. The first argument is the accumulated value computed from the
preceding invocation of fn, and the second is the value at the current
position of `elems`. If `initializer` is None, `elems` must contain at least
one element, and its first element is used as the initializer.
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
of the result tensor is fn(initializer, values[0]).shape`.
This method also allows multi-arity `elems` and output of `fn`. If `elems`
is a (possibly nested) list or tuple of tensors, then each of these tensors
must have a matching first (unpack) dimension. The signature of `fn` may
match the structure of `elems`. That is, if `elems` is
`(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
`fn = lambda (t1, [t2, t3, [t4, t5]]):`.
fn: The callable to be performed.
elems: A tensor or (possibly nested) sequence of tensors, each of which will
be unpacked along their first dimension. The nested sequence of the
resulting slices will be the first argument to `fn`.
initializer: (optional) A tensor or (possibly nested) sequence of tensors,
as the initial value for the accumulator.
parallel_iterations: (optional) The number of iterations allowed to run in
back_prop: (optional) Deprecated. False disables support for back
propagation. Prefer using `tf.stop_gradient` instead.
swap_memory: (optional) True enables GPU-CPU memory swapping.
name: (optional) Name prefix for the returned tensors.
A tensor or (possibly nested) sequence of tensors, resulting from applying
`fn` consecutively to the list of tensors unpacked from `elems`, from first
to last.
TypeError: if `fn` is not callable.
elems = tf.constant([1, 2, 3, 4, 5, 6])
sum = tf.foldl(lambda a, x: a + x, elems)
# sum == 21
return foldl(
def foldr(fn,
"""foldr on the list of tensors unpacked from `elems` on dimension 0.
This foldr operator repeatedly applies the callable `fn` to a sequence
of elements from last to first. The elements are made of the tensors
unpacked from `elems`. The callable fn takes two tensors as arguments.
The first argument is the accumulated value computed from the preceding
invocation of fn, and the second is the value at the current position of
`elems`. If `initializer` is None, `elems` must contain at least one element,
and its first element is used as the initializer.
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
of the result tensor is `fn(initializer, values[0]).shape`.
This method also allows multi-arity `elems` and output of `fn`. If `elems`
is a (possibly nested) list or tuple of tensors, then each of these tensors
must have a matching first (unpack) dimension. The signature of `fn` may
match the structure of `elems`. That is, if `elems` is
`(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
`fn = lambda (t1, [t2, t3, [t4, t5]]):`.
fn: The callable to be performed.
elems: A tensor or (possibly nested) sequence of tensors, each of which will
be unpacked along their first dimension. The nested sequence of the
resulting slices will be the first argument to `fn`.
initializer: (optional) A tensor or (possibly nested) sequence of tensors,
as the initial value for the accumulator.
parallel_iterations: (optional) The number of iterations allowed to run in
back_prop: (optional) True enables support for back propagation.
swap_memory: (optional) True enables GPU-CPU memory swapping.
name: (optional) Name prefix for the returned tensors.
A tensor or (possibly nested) sequence of tensors, resulting from applying
`fn` consecutively to the list of tensors unpacked from `elems`, from last
to first.
TypeError: if `fn` is not callable.
elems = [1, 2, 3, 4, 5, 6]
sum = foldr(lambda a, x: a + x, elems)
# sum == 21
if not callable(fn):
raise TypeError(
f"{fn.__name__} is not callable. Please provide a callable function.")
def create_ta(elem):
return tensor_array_ops.TensorArray(
dtype=elem.dtype, size=n, dynamic_size=False,
in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "foldr", [elems]):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager
if in_graph_mode:
# Any get_variable calls in fn will cache the first call locally and not
# issue repeated network I/O requests for each iteration.
varscope = vs.get_variable_scope()
varscope_caching_device_was_none = False
if varscope.caching_device is None:
# TODO(ebrevdo): Change to using colocate_with here and in other
# methods.
varscope.set_caching_device(lambda op: op.device)
varscope_caching_device_was_none = True
# Convert elems to tensor array. n may be known statically.
elems_flat = [
ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems)
n = (
tensor_shape.dimension_value(elems_flat[0].shape[0]) or
elems_ta = nest.map_structure(create_ta, elems)
if initializer is None:
i = n - 1
a = nest.map_structure(lambda elem:, elems_ta)
i = n
a = initializer
def compute(i, a):
i -= 1
elem = nest.map_structure(lambda elem:, elems_ta)
a_out = fn(a, elem)
return [i, a_out]
_, r_a = control_flow_ops.while_loop(
lambda i, a: i > 0,
compute, [i, a],
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager
if in_graph_mode and varscope_caching_device_was_none:
return r_a
@tf_export("foldr", v1=[])
"""back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))""",
def foldr_v2(fn,
"""foldr on the list of tensors unpacked from `elems` on dimension 0.
This foldr operator repeatedly applies the callable `fn` to a sequence
of elements from last to first. The elements are made of the tensors
unpacked from `elems`. The callable fn takes two tensors as arguments.
The first argument is the accumulated value computed from the preceding
invocation of fn, and the second is the value at the current position of
`elems`. If `initializer` is None, `elems` must contain at least one element,
and its first element is used as the initializer.
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
of the result tensor is `fn(initializer, values[0]).shape`.
This method also allows multi-arity `elems` and output of `fn`. If `elems`
is a (possibly nested) list or tuple of tensors, then each of these tensors
must have a matching first (unpack) dimension. The signature of `fn` may
match the structure of `elems`. That is, if `elems` is
`(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
`fn = lambda (t1, [t2, t3, [t4, t5]]):`.
fn: The callable to be performed.
elems: A tensor or (possibly nested) sequence of tensors, each of which will
be unpacked along their first dimension. The nested sequence of the
resulting slices will be the first argument to `fn`.
initializer: (optional) A tensor or (possibly nested) sequence of tensors,
as the initial value for the accumulator.
parallel_iterations: (optional) The number of iterations allowed to run in
back_prop: (optional) Deprecated. False disables support for back
propagation. Prefer using `tf.stop_gradient` instead.
swap_memory: (optional) True enables GPU-CPU memory swapping.
name: (optional) Name prefix for the returned tensors.
A tensor or (possibly nested) sequence of tensors, resulting from applying
`fn` consecutively to the list of tensors unpacked from `elems`, from last
to first.
TypeError: if `fn` is not callable.
elems = [1, 2, 3, 4, 5, 6]
sum = tf.foldr(lambda a, x: a + x, elems)
# sum == 21
return foldr(
def scan(fn,
"""scan on the list of tensors unpacked from `elems` on dimension 0.
See also `tf.map_fn`.
The simplest version of `scan` repeatedly applies the callable `fn` to a
sequence of elements from first to last. The elements are made of the tensors
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
arguments. The first argument is the accumulated value computed from the
preceding invocation of fn, and the second is the value at the current
position of `elems`. If `initializer` is None, `elems` must contain at least
one element, and its first element is used as the initializer.
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
If reverse=True, it's fn(initializer, values[-1]).shape.
This method also allows multi-arity `elems` and accumulator. If `elems`
is a (possibly nested) list or tuple of tensors, then each of these tensors
must have a matching first (unpack) dimension. The second argument of
`fn` must match the structure of `elems`.
If no `initializer` is provided, the output structure and dtypes of `fn`
are assumed to be the same as its input; and in this case, the first
argument of `fn` must match the structure of `elems`.
If an `initializer` is provided, then the output of `fn` must have the same
structure as `initializer`; and the first argument of `fn` must match
this structure.
For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
`[i1, i2]` then an appropriate signature for `fn` in `python2` is:
`fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
`[acc_n1, acc_n2]`. An alternative correct signature for `fn`, and the
one that works in `python3`, is:
`fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.
fn: The callable to be performed. It accepts two arguments. The first will
have the same structure as `initializer` if one is provided, otherwise it
will have the same structure as `elems`. The second will have the same
(possibly nested) structure as `elems`. Its output must have the same
structure as `initializer` if one is provided, otherwise it must have the
same structure as `elems`.
elems: A tensor or (possibly nested) sequence of tensors, each of which will
be unpacked along their first dimension. The nested sequence of the
resulting slices will be the first argument to `fn`.
initializer: (optional) A tensor or (possibly nested) sequence of tensors,
initial value for the accumulator, and the expected output type of `fn`.
parallel_iterations: (optional) The number of iterations allowed to run in
back_prop: (optional) True enables support for back propagation.
swap_memory: (optional) True enables GPU-CPU memory swapping.
infer_shape: (optional) False disables tests for consistent output shapes.
reverse: (optional) True scans the tensor last to first (instead of first to
name: (optional) Name prefix for the returned tensors.
A tensor or (possibly nested) sequence of tensors. Each tensor packs the
results of applying `fn` to tensors unpacked from `elems` along the first
dimension, and the previous accumulator value(s), from first to last (or
last to first, if `reverse=True`).
TypeError: if `fn` is not callable or the structure of the output of
`fn` and `initializer` do not match.
ValueError: if the lengths of the output of `fn` and `initializer`
do not match.
elems = np.array([1, 2, 3, 4, 5, 6])
sum = scan(lambda a, x: a + x, elems)
# sum == [1, 3, 6, 10, 15, 21]
sum = scan(lambda a, x: a + x, elems, reverse=True)
# sum == [21, 20, 18, 15, 11, 6]
elems = np.array([1, 2, 3, 4, 5, 6])
initializer = np.array(0)
sum_one = scan(
lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
# sum_one == [1, 2, 3, 4, 5, 6]
elems = np.array([1, 0, 0, 0, 0, 0])
initializer = (np.array(0), np.array(1))
fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
# fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
if not callable(fn):
raise TypeError(
f"{fn.__name__} is not callable. Please provide a callable function.")
input_is_sequence = nest.is_nested(elems)
input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]
def input_pack(x):
return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]
if initializer is None:
output_is_sequence = input_is_sequence
output_flatten = input_flatten
output_pack = input_pack
output_is_sequence = nest.is_nested(initializer)
output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x]
def output_pack(x):
return (nest.pack_sequence_as(initializer, x)
if output_is_sequence else x[0])
elems_flat = input_flatten(elems)
in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "scan", elems_flat):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager
if in_graph_mode:
# Any get_variable calls in fn will cache the first call locally
# and not issue repeated network I/O requests for each iteration.
varscope = vs.get_variable_scope()
varscope_caching_device_was_none = False
if varscope.caching_device is None:
# TODO(ebrevdo): Change to using colocate_with here and in other
# methods.
varscope.set_caching_device(lambda op: op.device)
varscope_caching_device_was_none = True
# Convert elems to tensor array.
elems_flat = [
ops.convert_to_tensor(elem, name="elem") for elem in elems_flat
# Convert elems to tensor array. n may be known statically.
n = tensor_shape.dimension_value(elems_flat[0].shape[0])
if n is None:
n = array_ops.shape(elems_flat[0])[0]
# TensorArrays are always flat
elems_ta = [
infer_shape=True) for elem in elems_flat
# Unpack elements
elems_ta = [
elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat)
if initializer is None:
a_flat = [ - 1 if reverse else 0) for elem in elems_ta]
i = 1
initializer_flat = output_flatten(initializer)
a_flat = [ops.convert_to_tensor(init) for init in initializer_flat]
i = 0
# Create a tensor array to store the intermediate values.
accs_ta = [
element_shape=init.shape if infer_shape else None,
infer_shape=infer_shape) for init in a_flat
if initializer is None:
accs_ta = [
acc_ta.write(n - 1 if reverse else 0, a)
for (acc_ta, a) in zip(accs_ta, a_flat)
def compute(i, a_flat, tas):
"""The loop body of scan.
i: the loop counter.
a_flat: the accumulator value(s), flattened.
tas: the output accumulator TensorArray(s), flattened.
[i + 1, a_flat, tas]: the updated counter + new accumulator values +
updated TensorArrays
TypeError: if initializer and fn() output structure do not match
ValueType: if initializer and fn() output lengths do not match
packed_elems = input_pack([ for elem_ta in elems_ta])
packed_a = output_pack(a_flat)
a_out = fn(packed_a, packed_elems)
nest.assert_same_structure(elems if initializer is None else initializer,
flat_a_out = output_flatten(a_out)
tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)]
if reverse:
next_i = i - 1
next_i = i + 1
return (next_i, flat_a_out, tas)
if reverse:
initial_i = n - 1 - i
condition = lambda i, _1, _2: i >= 0
initial_i = i
condition = lambda i, _1, _2: i < n
_, _, r_a = control_flow_ops.while_loop(
compute, (initial_i, a_flat, accs_ta),
results_flat = [r.stack() for r in r_a]
n_static = tensor_shape.Dimension(
for elem in elems_flat[1:]:
for r in results_flat:
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager
if in_graph_mode and varscope_caching_device_was_none:
return output_pack(results_flat)
@tf_export("scan", v1=[])
"""back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.scan(fn, elems, back_prop=False)
results = tf.nest.map_structure(tf.stop_gradient, tf.scan(fn, elems))""",
def scan_v2(fn,
"""scan on the list of tensors unpacked from `elems` on dimension 0.
The simplest version of `scan` repeatedly applies the callable `fn` to a
sequence of elements from first to last. The elements are made of the tensors
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
arguments. The first argument is the accumulated value computed from the
preceding invocation of fn, and the second is the value at the current
position of `elems`. If `initializer` is None, `elems` must contain at least
one element, and its first element is used as the initializer.
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
If reverse=True, it's fn(initializer, values[-1]).shape.
This method also allows multi-arity `elems` and accumulator. If `elems`
is a (possibly nested) list or tuple of tensors, then each of these tensors
must have a matching first (unpack) dimension. The second argument of
`fn` must match the structure of `elems`.
If no `initializer` is provided, the output structure and dtypes of `fn`
are assumed to be the same as its input; and in this case, the first
argument of `fn` must match the structure of `elems`.
If an `initializer` is provided, then the output of `fn` must have the same
structure as `initializer`; and the first argument of `fn` must match
this structure.
For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
`[i1, i2]` then an appropriate signature for `fn` in `python2` is:
`fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
`[acc_n1, acc_n2]`. An alternative correct signature for `fn`, and the
one that works in `python3`, is:
`fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.
fn: The callable to be performed. It accepts two arguments. The first will
have the same structure as `initializer` if one is provided, otherwise it
will have the same structure as `elems`. The second will have the same
(possibly nested) structure as `elems`. Its output must have the same
structure as `initializer` if one is provided, otherwise it must have the
same structure as `elems`.
elems: A tensor or (possibly nested) sequence of tensors, each of which will
be unpacked along their first dimension. The nested sequence of the
resulting slices will be the first argument to `fn`.
initializer: (optional) A tensor or (possibly nested) sequence of tensors,
initial value for the accumulator, and the expected output type of `fn`.
parallel_iterations: (optional) The number of iterations allowed to run in
back_prop: (optional) Deprecated. False disables support for back
propagation. Prefer using `tf.stop_gradient` instead.
swap_memory: (optional) True enables GPU-CPU memory swapping.
infer_shape: (optional) False disables tests for consistent output shapes.
reverse: (optional) True scans the tensor last to first (instead of first to
name: (optional) Name prefix for the returned tensors.
A tensor or (possibly nested) sequence of tensors. Each tensor packs the
results of applying `fn` to tensors unpacked from `elems` along the first
dimension, and the previous accumulator value(s), from first to last (or
last to first, if `reverse=True`).
TypeError: if `fn` is not callable or the structure of the output of
`fn` and `initializer` do not match.
ValueError: if the lengths of the output of `fn` and `initializer`
do not match.
elems = np.array([1, 2, 3, 4, 5, 6])
sum = scan(lambda a, x: a + x, elems)
# sum == [1, 3, 6, 10, 15, 21]
sum = scan(lambda a, x: a + x, elems, reverse=True)
# sum == [21, 20, 18, 15, 11, 6]
elems = np.array([1, 2, 3, 4, 5, 6])
initializer = np.array(0)
sum_one = scan(
lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
# sum_one == [1, 2, 3, 4, 5, 6]
elems = np.array([1, 0, 0, 0, 0, 0])
initializer = (np.array(0), np.array(1))
fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
# fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
return scan(
# pylint: disable=invalid-name
def If(cond, inputs, then_branch, else_branch, name=None):
  r"""output = Cond(inputs) ?

  then_branch(inputs) : else_branch(inputs).

  Args:
    cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
      converted to a boolean according to the following rule: if the scalar is
      a numerical value, non-zero means True and zero means False; if the
      scalar is a string, non-empty means True and empty means False.
    inputs: A list of input tensors.
    then_branch: A function that takes 'inputs' and returns a list of tensors
      whose types are the same as what else_branch returns.
    else_branch: A function that takes 'inputs' and returns a list of tensors
      whose types are the same as what then_branch returns.
    name: A name for the operation (optional).

  Returns:
    A list of tensors returned by either then_branch(inputs)
    or else_branch(inputs). When `then_branch` is a Defun the raw op outputs
    are returned; otherwise the outputs are re-packed into the structure of
    `then_branch.structured_outputs`.
  """
  # pylint: disable=protected-access
  # Handle the Defun case until users have transitioned to tf.function. Note
  # that composites may need to be re-packed by the caller.
  if isinstance(then_branch, function._DefinedFunction):
    # Output dtypes come straight from the Defun's declared signature.
    tlist = [arg.type for arg in then_branch.definition.signature.output_arg]
    return gen_functional_ops._if(
        cond, inputs, tlist, then_branch, else_branch, name=name)

  # We assume that `then_branch` is a ConcreteFunction here.
  then_out = then_branch.structured_outputs
  else_out = else_branch.structured_outputs

  # Ensure then/else are the same type of composites to avoid an invalid call
  # to pack_sequence_as later on.
  nest.assert_same_structure(then_out, else_out, expand_composites=True)

  tlist = nest.flatten(then_branch.output_dtypes)
  ret = gen_functional_ops._if(
      cond, inputs, tlist, then_branch, else_branch, name=name)

  # Re-pack the outputs to restore any CompositeTensors
  return nest.pack_sequence_as(then_out, ret, expand_composites=True)
def Gradient(inputs, f, name=None):
  r"""Computes the gradient function for function f via backpropagation.

  Args:
    inputs: A list of tensors of size N + M.
    f: The function we want to compute the gradient for. The function 'f' must
      be a numerical function which takes N inputs and produces M outputs. Its
      gradient function 'g' is a function taking N + M inputs and producing N
      outputs. I.e. if we have (y1, y2, ..., yM) = f(x1, x2, ..., xN), then g
      is (dL/dx1, dL/dx2, ..., dL/dxN) = g(x1, x2, ..., xN, dL/dy1, dL/dy2,
      ..., dL/dyM), where L is a scalar-value function of (x1, x2, ..., xN)
      (e.g., the loss function). dL/dxi is the partial derivative of L with
      respect to xi.
    name: A name for the operation (optional).

  Returns:
    A list of tensors of size N.
  """
  # TODO(zhifengc): Pretty-print the above spec in latex.
  # TODO(zhifengc): Needs some math expert to say the comment above better.
  # The gradient has one output per input of `f`, so the output dtypes are
  # taken from `f`'s declared input arguments.
  tlist = [arg.type for arg in f.definition.signature.input_arg]
  return symbolic_gradient(input=inputs, Tout=tlist, f=f, name=name)
def _GetInputDtypes(func):
  """Returns the dtypes of func's explicit (non-captured) inputs.

  Args:
    func: Either a `function._DefinedFunction`, or an object exposing `inputs`
      and `captured_inputs` (assumed to be a ConcreteFunction).

  Returns:
    `func.declared_input_types` for a Defun; otherwise a list with the dtype
    of each entry of `func.inputs` that precedes the captured inputs.
  """
  if isinstance(func, function._DefinedFunction):  # pylint: disable=protected-access
    return func.declared_input_types

  # We assume that `func` is a ConcreteFunction here, but we are not able to
  # verify since importing eager function library will cause cyclic dependence.
  #
  # ConcreteFunction.inputs includes captured inputs; the slice below keeps
  # only the leading explicit inputs.
  explicit_count = len(func.inputs) - len(func.captured_inputs)
  return [tensor.dtype for tensor in func.inputs[:explicit_count]]
def _LoopBodyCaptureWrapper(func):
"""Returns a wrapper for `func` that handles loop-carried captured inputs."""
@function.Defun(*_GetInputDtypes(func), func_name="%s_Wrapper" %
def Wrapper(*args):
"""A wrapper that handles loop-carried captured inputs."""
result = func(*args)
extra_args = tuple(function.get_extra_args())
# Nullary functions return an Operation. Normal functions can't do this
# because their return values are converted to Tensors.
if isinstance(result, ops.Operation):
return extra_args
# Unary functions return a single Tensor value.
elif not isinstance(result, (list, tuple)):
return (result,) + extra_args
# N-ary functions return a tuple of Tensors.
return result + type(result)(extra_args)
return Wrapper
# pylint: disable=invalid-name,protected-access
def While(input_, cond, body, name=None, hostmem=None):
r"""output = input; While (Cond(output)) { output = Body(output) }.
input_: A list of `Tensor` objects. A list of input tensors whose types are
cond: . A function takes 'input' and returns a tensor. If the tensor is a
scalar of non-boolean, the scalar is converted to a boolean
according to the following rule: if the scalar is a numerical value,
non-zero means True and zero means False; if the scalar is a string,
non-empty means True and empty means False. If the tensor is not a
scalar, non-emptiness means True and False otherwise.
body: . A function takes a list of tensors and returns another list tensors.
Both lists have the same types as specified by T.
name: A name for the operation (optional).
hostmem: A list of integer. If i is in the list, input[i] is a host memory
ValueError: if `cond` has implicitly captured inputs or if `cond` and `body`
have different signatures.
A list of `Tensor` objects. Has the same type as `input`.
A list of output tensors whose types are T.
if cond.captured_inputs:
raise ValueError(
"The 'cond' argument can not have implicitly captured inputs. Received "
f"captured_inputs: {cond.captured_inputs}")
cond_input_types = _GetInputDtypes(cond)
body_input_types = _GetInputDtypes(body)
if cond_input_types != body_input_types:
raise ValueError(
"The 'cond' and 'body' signatures do not match. Received: "
f"cond_input_types={cond_input_types}, body_input_types="
if body.captured_inputs:
cond_dtypes = list(body_input_types) + [
t.dtype for t in body.captured_inputs
@function.Defun(*cond_dtypes, func_name="%s_Wrapper" %
def CondWrapper(*args):
"""A wrapper that handles loop-carried captured inputs."""
return cond(*args[:len(body_input_types)])
ret = gen_functional_ops._while(
input_ + body.captured_inputs,
# Slice off the loop-carried captured inputs.
ret = ret[:-len(body.captured_inputs)]
ret = gen_functional_ops._while(input_, cond, body, name=name)
if hostmem:
input_attr = attr_value_pb2.AttrValue()
ret[0].op._set_attr("_input_hostmem", input_attr) # pylint: disable=protected-access
output_attr = attr_value_pb2.AttrValue()
ret[0].op._set_attr("_output_hostmem", output_attr) # pylint: disable=protected-access
return ret
# b/36459430
# Ideally, we would not need to rewrite a For loop into a While loop.
# However, today, if a While runs on GPU and the condition returns a
# boolean, the While kernel crashes. Even if we fix the crash, the
# bool needs to be copied between GPU and CPU. So, a For loop is much
# preferred when running on GPU.
# On the other hand, the For op has no direct XLA kernel. So, when we run
# a For loop, we need to rewrite it using a While op.
# It should be possible, and probably better, to write an XLA C++ kernel
# implementing the logic in _ForUsingWhile.
def _ForUsingWhile(start,
                   limit,
                   delta,
                   inputs,
                   forbody,
                   name=None,
                   hostmem=None):
  """Helper to implement a For loop using a While.

  The loop is rewritten as a While over a counter `i` in `range(n)`; the real
  For index handed to `forbody` is `start + i * delta`.

  Args:
    start: The first value of the For index (combined with int32 loop state).
    limit: The exclusive bound of the For index.
    delta: The step of the For index. May be negative.
    inputs: A list of loop-carried input tensors passed to `forbody` after the
      index.
    forbody: The For body function. Its `declared_input_types` (minus the
      leading index type) and its `name` are used to build the While cond and
      body functions.
    name: A name for the While operation (optional).
    hostmem: A list of integers indexing into `inputs`, or `None`. The indices
      are shifted by 4 to account for the prepended loop-control values, which
      are always placed in host memory.

  Returns:
    A list with the While results after the four loop-control values
    (i, n, start, delta) have been sliced off.
  """
  # To support negative delta (e.g., range(100, 0, -3)), we iterate
  # over the range(n) and use iter * delta + start as the real
  # iteration index. (e.g., for i in range(34): iter = i * (-3) +
  # 100).
  d = math_ops.abs(delta)
  # XLA on TPUs doesn't support integer division, so compute
  # ceil(|limit - start| / |delta|) in float32 and cast back to int32.
  n = math_ops.cast(
      math_ops.cast((math_ops.abs(limit - start) + d - 1), dtypes.float32) /
      math_ops.cast(d, dtypes.float32), dtypes.int32)

  # Carried loop variables ("extra_args") are implicitly added to the input list
  # of the WhileBody function. WhileCond does not call forbody, and so does not
  # depend on any of forbody's extra_args. Since WhileCond and WhileBody
  # must have identical inputs, we have to augment the cond signature to take
  # the same types as the carried loop variables.
  body_sig = [dtypes.int32] * 4 + list(forbody.declared_input_types)[1:]

  cond_name = "%s_Cond" % forbody.name

  @function.Defun(*body_sig, func_name=cond_name)
  def WhileCond(i, n, *args):
    del args
    return i < n

  body_name = "%s_Body" % forbody.name

  @function.Defun(*body_sig, func_name=body_name)
  def WhileBody(i, n, start, delta, *args):
    """A While wrapper for forbody that handles loop-carried captured inputs."""
    for_result = forbody(start + i * delta, *args)
    # Nullary functions return an Operation. Normal functions can't do this
    # because their return values are converted to Tensors.
    if isinstance(for_result, ops.Operation):
      for_result = ()
    # Unary functions return a single Tensor value.
    elif isinstance(for_result, ops.Tensor):
      for_result = (for_result,)
    return (i + 1, n, start, delta) + tuple(for_result)

  # The four loop-control values (i, n, start, delta) always live in host
  # memory; caller-supplied indices refer to `inputs` and must be shifted
  # past them.
  if hostmem is not None:
    hostmem = [0, 1, 2, 3] + [4 + index for index in hostmem]
  else:
    hostmem = [0, 1, 2, 3]

  results = While(
      input_=[0, n, start, delta] + inputs,
      cond=WhileCond,
      body=WhileBody,
      name=name,
      hostmem=hostmem)
  # Slice off the loop-carried captured inputs.
  return list(results[4:])
def For(start,
        limit,
        delta,
        inputs,
        body,
        name=None,
        hostmem=None,
        rewrite_with_while=None):
  r"""out = input; for i in range(start, limit, delta) out = body(i, out).

  Args:
    start: A `Tensor` of type `int32`.
    limit: A `Tensor` of type `int32`.
    delta: A `Tensor` of type `int32`.
    inputs: A list of `Tensor` objects. A list of input tensors whose types are
      T.
    body: A function takes a list of tensors and returns another list of
      tensors. Both lists have the same types as (int32, T...).
    name: A name for the operation (optional).
    hostmem: A list of integer. If i is in the list, inputs[i] is a host memory
      tensor. In other words, (i+1)-th argument of the body function is
      expecting a host memory.
    rewrite_with_while: If True, using While op to implement the For.

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
    A list of output tensors whose types are T.
  """
  if rewrite_with_while:
    return _ForUsingWhile(start, limit, delta, inputs, body, name, hostmem)
  if body.captured_inputs:
    # Captured tensors are threaded through the loop as extra loop-carried
    # inputs, so the body must be wrapped to accept and pass them along.
    ret = gen_functional_ops._for(
        start,
        limit,
        delta,
        inputs + body.captured_inputs,
        _LoopBodyCaptureWrapper(body),
        name=name)
    # Slice off the loop-carried captured inputs.
    ret = ret[:-len(body.captured_inputs)]
  else:
    ret = gen_functional_ops._for(start, limit, delta, inputs, body, name=name)
  if hostmem:
    num_for_params = 3  # start/limit/delta

    # Op inputs are (start, limit, delta, *inputs), so indices into `inputs`
    # are offset by the three leading loop parameters.
    input_attr = attr_value_pb2.AttrValue()
    input_attr.list.i.extend([num_for_params + i for i in hostmem])
    ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access

    # Op outputs line up with `inputs` directly, so no offset is needed.
    output_attr = attr_value_pb2.AttrValue()
    output_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
  return ret
# pylint: enable=invalid-name,protected-access
def partitioned_call(args,
                     f,
                     tout=None,
                     executing_eagerly=None,
                     config=None,
                     executor_type=None):
  """Executes a function while respecting device annotations.

  Currently, only those functions that execute within the same address space
  can be executed.

  Args:
    args: The arguments of the function, including captured inputs.
    f: The function to execute; an instance of `_DefinedFunction` or
      `_EagerDefinedFunction`.
    tout: a list containing the output dtypes enums; if `None`, inferred from
      the signature of `f`.
    executing_eagerly: (Optional) A boolean indicating whether the context is
      executing eagerly. If `None`, fetched from the global context.
    config: (Optional) A `tensorflow::ConfigProto` proto, serialized. If `None`,
      all optimizations are disabled. Currently only handled for eager defined
      functions.
    executor_type: (Optional) A string for the name of the executor to be used
      in the function call. If not set, or set to an empty string, the default
      tensorflow executor will be used.

  Returns:
    The list of `Tensor`s returned by invoking `f(args)`. If the function does
    not return anything, then returns `None` if eager execution is enabled, or
    the `Operation` if not.
  """

  if tout is None:
    tout = tuple(x.type for x in f.definition.signature.output_arg)

  if executing_eagerly is None:
    executing_eagerly = context.executing_eagerly()

  if config is None:
    config = function_utils.get_disabled_rewriter_config()

  if executor_type is None:
    executor_type = ""

  if executing_eagerly:
    # Same stateful/stateless choice as `op_name` in the graph-mode path below.
    eager_call = (
        gen_functional_ops.stateful_partitioned_call
        if f.stateful_ops else gen_functional_ops.partitioned_call)
    outputs = eager_call(
        args=args,
        Tout=tout,
        f=f,
        config_proto=config,
        executor_type=executor_type)
    return outputs if outputs else None

  # The generated binding returns an empty list for functions that don't
  # return any Tensors, hence the need to use `create_op` directly.
  args = [ops.convert_to_tensor(x) for x in args]
  tin_attr = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(
          type=[x.dtype.as_datatype_enum for x in args]))
  tout_attr = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(type=tout))
  func_attr = attr_value_pb2.AttrValue(
      func=attr_value_pb2.NameAttrList(name=f.name))
  executor_type_attr = attr_value_pb2.AttrValue(
      s=compat.as_bytes(executor_type))

  # When running in graph mode, the graph and function graphs are optimized
  # (i.e. run through grappler) per the session options, so we can disable any
  # eager-specific rewriting.
  config_proto = attr_value_pb2.AttrValue(s=config)

  graph = ops.get_default_graph()
  # The call op refers to `f` by name only, so the function definition has to
  # be registered in the graph that will hold the op.
  f.add_to_graph(graph)
  op_name = "StatefulPartitionedCall" if f.stateful_ops else "PartitionedCall"

  # Propagate the attribute indicating the need to compile from function to the
  # call itself.
  xla_compile_attr = "_XlaMustCompile"
  op_attrs = {
      "Tin": tin_attr,
      "Tout": tout_attr,
      "f": func_attr,
      "config_proto": config_proto,
      "executor_type": executor_type_attr,
  }
  if xla_compile_attr in f.definition.attr:
    op_attrs[xla_compile_attr] = f.definition.attr[xla_compile_attr]
  op = graph.create_op(op_name, args, tout, name=op_name, attrs=op_attrs)
  outputs = op.outputs
  if hasattr(f, "graph"):
    _set_read_only_resource_inputs_attr(op, f.graph)
    if hasattr(f.graph, "collective_manager_ids_used"):
      ops.set_int_list_attr(op, acd.COLLECTIVE_MANAGER_IDS,
                            f.graph.collective_manager_ids_used)
  return outputs if outputs else op
def _set_read_only_resource_inputs_attr(op, func_graph):
  """Sets the list of resource inputs which are read-only.

  This is used by AutomaticControlDependencies.

  Args:
    op: PartitionedCall Operation.
    func_graph: FuncGraph.
  """
  read_only_indices = acd.get_read_only_resource_input_indices_graph(func_graph)
  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
                        read_only_indices)