3RNN/Lib/site-packages/tensorflow/python/ops/gen_functional_ops.py
2024-05-26 19:49:15 +02:00

1307 lines
55 KiB
Python

"""Python wrappers around TensorFlow ops.
This file is MACHINE GENERATED! Do not edit.
"""
import collections
from tensorflow.python import pywrap_tfe as pywrap_tfe
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.security.fuzzing.py import annotation_types as _atypes
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.deprecation import deprecated_endpoints
from tensorflow.python.util import dispatch as _dispatch
from tensorflow.python.util.tf_export import tf_export
from typing import TypeVar, List, Any
from typing_extensions import Annotated
def case(
    branch_index: Annotated[Any, _atypes.Int32],
    input,
    Tout,
    branches,
    output_shapes=[],
    name=None,
):
    r"""Runs exactly one branch function, selected by `branch_index`.

    The op behaves like an n-way switch statement:

      ```
      switch (branch_index) {
        case 0:
          output = branches[0](input);
          break;
        case 1:
          output = branches[1](input);
          break;
        ...
        case [[nbranches-1]]:
        default:
          output = branches[nbranches-1](input);
          break;
      }
      ```

    Args:
      branch_index: A `Tensor` of type `int32`; the branch selector.
      input: A list of `Tensor` objects passed to the branch function.
      Tout: A list of `tf.DTypes` giving the output types.
      branches: A list (length `>= 1`) of functions decorated with @Defun.
        Each takes 'inputs' and returns a list of tensors whose types are the
        same as what every other branch returns.
      output_shapes: An optional list of shapes (each a `tf.TensorShape` or
        list of `ints`). Defaults to `[]`.
      name: A name for the operation (optional).

    Returns:
      A list of `Tensor` objects of type `Tout`. When building a graph and the
      created op has no outputs, the op itself is returned instead.

    Raises:
      TypeError: If `Tout`, `branches` or `output_shapes` is not a list or
        tuple (checked on the graph path and in the eager fallback).
    """
    _ctx = _context._context or _context.context()
    if _ctx._thread_local_data.is_eager:
        # Eager mode: attempt the fast-path executor first.
        try:
            return pywrap_tfe.TFE_Py_FastPathExecute(
                _ctx, "Case", name, branch_index, input, "Tout", Tout,
                "branches", branches, "output_shapes", output_shapes)
        except _core._NotOkStatusException as status_error:
            _ops.raise_from_not_ok_status(status_error, name)
        except _core._FallbackException:
            pass
        # Fast path asked for a fallback: use the Python eager implementation.
        try:
            return case_eager_fallback(
                branch_index, input, Tout=Tout, branches=branches,
                output_shapes=output_shapes, name=name, ctx=_ctx)
        except _core._SymbolicException:
            pass  # Add nodes to the TensorFlow graph.

    # Add nodes to the TensorFlow graph.
    if not isinstance(Tout, (list, tuple)):
        raise TypeError(
            "Expected list for 'Tout' argument to "
            "'case' Op, not %r." % Tout)
    Tout = [_execute.make_type(out_type, "Tout") for out_type in Tout]
    if not isinstance(branches, (list, tuple)):
        raise TypeError(
            "Expected list for 'branches' argument to "
            "'case' Op, not %r." % branches)
    output_shapes = [] if output_shapes is None else output_shapes
    if not isinstance(output_shapes, (list, tuple)):
        raise TypeError(
            "Expected list for 'output_shapes' argument to "
            "'case' Op, not %r." % output_shapes)
    output_shapes = [
        _execute.make_shape(shape, "output_shapes") for shape in output_shapes
    ]
    _, _, graph_op, op_outputs = _op_def_library._apply_op_helper(
        "Case", branch_index=branch_index, input=input, Tout=Tout,
        branches=branches, output_shapes=output_shapes, name=name)
    results = op_outputs[:]
    if not results:
        # No outputs to hand back, so return the op itself.
        return graph_op
    if _execute.must_record_gradient():
        recorded_attrs = (
            "Tin", graph_op.get_attr("Tin"),
            "Tout", graph_op.get_attr("Tout"),
            "branches", graph_op.get_attr("branches"),
            "output_shapes", graph_op.get_attr("output_shapes"),
        )
        flat_inputs = graph_op.inputs
        _execute.record_gradient("Case", flat_inputs, recorded_attrs, results)
    return results


Case = tf_export("raw_ops.Case")(_ops.to_raw_op(case))
def case_eager_fallback(
    branch_index: Annotated[Any, _atypes.Int32],
    input,
    Tout,
    branches,
    output_shapes,
    name,
    ctx,
):
    """Executes the `Case` op eagerly via `_execute.execute`.

    `case` calls this when the fast-path executor raises `_FallbackException`.
    The arguments mirror those of `case`; `ctx` is the eager context supplied
    by the caller.

    Returns:
      The value returned by `_execute.execute`, requested with `len(Tout)`
      outputs.

    Raises:
      TypeError: If `Tout`, `branches` or `output_shapes` is not a list or
        tuple.
    """
    if not isinstance(Tout, (list, tuple)):
        raise TypeError(
            "Expected list for 'Tout' argument to "
            "'case' Op, not %r." % Tout)
    Tout = [_execute.make_type(out_type, "Tout") for out_type in Tout]
    if not isinstance(branches, (list, tuple)):
        raise TypeError(
            "Expected list for 'branches' argument to "
            "'case' Op, not %r." % branches)
    output_shapes = [] if output_shapes is None else output_shapes
    if not isinstance(output_shapes, (list, tuple)):
        raise TypeError(
            "Expected list for 'output_shapes' argument to "
            "'case' Op, not %r." % output_shapes)
    output_shapes = [
        _execute.make_shape(shape, "output_shapes") for shape in output_shapes
    ]
    in_types, in_tensors = _execute.convert_to_mixed_eager_tensors(input, ctx)
    branch_index = _ops.convert_to_tensor(branch_index, _dtypes.int32)
    # The selector goes first, followed by the branch inputs.
    flat_inputs = [branch_index, *in_tensors]
    op_attrs = (
        "Tin", in_types,
        "Tout", Tout,
        "branches", branches,
        "output_shapes", output_shapes,
    )
    outputs = _execute.execute(
        b"Case", len(Tout), inputs=flat_inputs, attrs=op_attrs, ctx=ctx,
        name=name)
    if _execute.must_record_gradient():
        _execute.record_gradient("Case", flat_inputs, op_attrs, outputs)
    return outputs
def device_index(device_names, name=None) -> Annotated[Any, _atypes.Int32]:
    r"""Returns the index of the device this op runs on.

    Given a list of device names, the op yields the position of the device it
    runs on. The length of the list is returned in two cases:
    (1) the device does not exist in the given device list;
    (2) the op is in XLA compilation.

    Args:
      device_names: A list of `strings`.
      name: A name for the operation (optional).

    Returns:
      A `Tensor` of type `int32`.

    Raises:
      TypeError: If `device_names` is not a list or tuple (checked on the
        graph path and in the eager fallback).
    """
    _ctx = _context._context or _context.context()
    if _ctx._thread_local_data.is_eager:
        # Eager mode: attempt the fast-path executor first.
        try:
            return pywrap_tfe.TFE_Py_FastPathExecute(
                _ctx, "DeviceIndex", name, "device_names", device_names)
        except _core._NotOkStatusException as status_error:
            _ops.raise_from_not_ok_status(status_error, name)
        except _core._FallbackException:
            pass
        # Fast path asked for a fallback: use the Python eager implementation.
        try:
            return device_index_eager_fallback(
                device_names=device_names, name=name, ctx=_ctx)
        except _core._SymbolicException:
            pass  # Add nodes to the TensorFlow graph.

    # Add nodes to the TensorFlow graph.
    if not isinstance(device_names, (list, tuple)):
        raise TypeError(
            "Expected list for 'device_names' argument to "
            "'device_index' Op, not %r." % device_names)
    device_names = [
        _execute.make_str(device, "device_names") for device in device_names
    ]
    _, _, graph_op, op_outputs = _op_def_library._apply_op_helper(
        "DeviceIndex", device_names=device_names, name=name)
    results = op_outputs[:]
    if _execute.must_record_gradient():
        recorded_attrs = ("device_names", graph_op.get_attr("device_names"))
        flat_inputs = graph_op.inputs
        _execute.record_gradient(
            "DeviceIndex", flat_inputs, recorded_attrs, results)
    # The op has exactly one output; unwrap it.
    (index_tensor,) = results
    return index_tensor


DeviceIndex = tf_export("raw_ops.DeviceIndex")(_ops.to_raw_op(device_index))
def device_index_eager_fallback(
    device_names, name, ctx
) -> Annotated[Any, _atypes.Int32]:
    """Executes the `DeviceIndex` op eagerly via `_execute.execute`.

    `device_index` calls this when the fast-path executor raises
    `_FallbackException`. `ctx` is the eager context supplied by the caller.

    Returns:
      The single output of the executed op.

    Raises:
      TypeError: If `device_names` is not a list or tuple.
    """
    if not isinstance(device_names, (list, tuple)):
        raise TypeError(
            "Expected list for 'device_names' argument to "
            "'device_index' Op, not %r." % device_names)
    device_names = [
        _execute.make_str(device, "device_names") for device in device_names
    ]
    flat_inputs = []  # This op takes no tensor inputs.
    op_attrs = ("device_names", device_names)
    outputs = _execute.execute(
        b"DeviceIndex", 1, inputs=flat_inputs, attrs=op_attrs, ctx=ctx,
        name=name)
    if _execute.must_record_gradient():
        _execute.record_gradient("DeviceIndex", flat_inputs, op_attrs, outputs)
    (index_tensor,) = outputs
    return index_tensor
# Type variable constrained to the dtypes accepted for FakeParam's `dtype`.
TV_FakeParam_dtype = TypeVar(
    "TV_FakeParam_dtype",
    _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64,
    _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn,
    _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32,
    _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16,
    _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8,
    _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32,
    _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant,
)
def fake_param(
    dtype: TV_FakeParam_dtype, shape, name=None
) -> Annotated[Any, TV_FakeParam_dtype]:
    r"""Placeholder op for use inside If branch functions.

    The op does not provide a valid output when run, so it must either be
    removed (e.g. replaced with a function input) or guaranteed not to be used
    (e.g. if mirroring an intermediate output needed for the gradient
    computation of the other branch).

    Args:
      dtype: A `tf.DType`. The type of the output.
      shape: A `tf.TensorShape` or list of `ints`.
        The purported shape of the output. This is only used for shape
        inference; the output will not necessarily have this shape. Can be a
        partial shape.
      name: A name for the operation (optional).

    Returns:
      A `Tensor` of type `dtype`.
    """
    _ctx = _context._context or _context.context()
    if _ctx._thread_local_data.is_eager:
        # Eager mode: attempt the fast-path executor first.
        try:
            return pywrap_tfe.TFE_Py_FastPathExecute(
                _ctx, "FakeParam", name, "dtype", dtype, "shape", shape)
        except _core._NotOkStatusException as status_error:
            _ops.raise_from_not_ok_status(status_error, name)
        except _core._FallbackException:
            pass
        # Fast path asked for a fallback: use the Python eager implementation.
        try:
            return fake_param_eager_fallback(
                dtype=dtype, shape=shape, name=name, ctx=_ctx)
        except _core._SymbolicException:
            pass  # Add nodes to the TensorFlow graph.

    # Add nodes to the TensorFlow graph.
    dtype = _execute.make_type(dtype, "dtype")
    shape = _execute.make_shape(shape, "shape")
    _, _, graph_op, op_outputs = _op_def_library._apply_op_helper(
        "FakeParam", dtype=dtype, shape=shape, name=name)
    results = op_outputs[:]
    if _execute.must_record_gradient():
        recorded_attrs = (
            "dtype", graph_op._get_attr_type("dtype"),
            "shape", graph_op.get_attr("shape"),
        )
        flat_inputs = graph_op.inputs
        _execute.record_gradient(
            "FakeParam", flat_inputs, recorded_attrs, results)
    # The op has exactly one output; unwrap it.
    (placeholder,) = results
    return placeholder


FakeParam = tf_export("raw_ops.FakeParam")(_ops.to_raw_op(fake_param))
def fake_param_eager_fallback(
    dtype: TV_FakeParam_dtype, shape, name, ctx
) -> Annotated[Any, TV_FakeParam_dtype]:
    """Executes the `FakeParam` op eagerly via `_execute.execute`.

    `fake_param` calls this when the fast-path executor raises
    `_FallbackException`. `ctx` is the eager context supplied by the caller.

    Returns:
      The single output of the executed op.
    """
    dtype = _execute.make_type(dtype, "dtype")
    shape = _execute.make_shape(shape, "shape")
    flat_inputs = []  # This op takes no tensor inputs.
    op_attrs = ("dtype", dtype, "shape", shape)
    outputs = _execute.execute(
        b"FakeParam", 1, inputs=flat_inputs, attrs=op_attrs, ctx=ctx,
        name=name)
    if _execute.must_record_gradient():
        _execute.record_gradient("FakeParam", flat_inputs, op_attrs, outputs)
    (placeholder,) = outputs
    return placeholder
def _for(
    start: Annotated[Any, _atypes.Int32],
    limit: Annotated[Any, _atypes.Int32],
    delta: Annotated[Any, _atypes.Int32],
    input,
    body,
    name=None,
):
    r"""Applies a for loop.

      ```python
       output = input;
       for i in range(start, limit, delta)
         output = body(i, output);
      ```

    Args:
      start: A `Tensor` of type `int32`. The lower bound.
      limit: A `Tensor` of type `int32`. The upper bound.
      delta: A `Tensor` of type `int32`. The increment.
      input: A list of `Tensor` objects whose types are T.
      body: A function decorated with @Defun. It takes a list of tensors
        (int32, T) and returns another list of tensors (T).
      name: A name for the operation (optional).

    Returns:
      A list of `Tensor` objects. Has the same type as `input`.
    """
    _ctx = _context._context or _context.context()
    if _ctx._thread_local_data.is_eager:
        # Eager mode: attempt the fast-path executor first.
        try:
            return pywrap_tfe.TFE_Py_FastPathExecute(
                _ctx, "For", name, start, limit, delta, input, "body", body)
        except _core._NotOkStatusException as status_error:
            _ops.raise_from_not_ok_status(status_error, name)
        except _core._FallbackException:
            pass
        # Fast path asked for a fallback: use the Python eager implementation.
        try:
            return _for_eager_fallback(
                start, limit, delta, input, body=body, name=name, ctx=_ctx)
        except _core._SymbolicException:
            pass  # Add nodes to the TensorFlow graph.

    # Add nodes to the TensorFlow graph.
    _, _, graph_op, op_outputs = _op_def_library._apply_op_helper(
        "For", start=start, limit=limit, delta=delta, input=input, body=body,
        name=name)
    results = op_outputs[:]
    if _execute.must_record_gradient():
        recorded_attrs = (
            "T", graph_op.get_attr("T"),
            "body", graph_op.get_attr("body"),
        )
        flat_inputs = graph_op.inputs
        _execute.record_gradient("For", flat_inputs, recorded_attrs, results)
    return results


For = tf_export("raw_ops.For")(_ops.to_raw_op(_for))
def _for_eager_fallback(
    start: Annotated[Any, _atypes.Int32],
    limit: Annotated[Any, _atypes.Int32],
    delta: Annotated[Any, _atypes.Int32],
    input,
    body,
    name,
    ctx,
):
    """Executes the `For` op eagerly via `_execute.execute`.

    `_for` calls this when the fast-path executor raises `_FallbackException`.
    The arguments mirror those of `_for`; `ctx` is the eager context supplied
    by the caller.

    Returns:
      The value returned by `_execute.execute`, requested with one output per
      converted input tensor.
    """
    in_types, in_tensors = _execute.convert_to_mixed_eager_tensors(input, ctx)
    start = _ops.convert_to_tensor(start, _dtypes.int32)
    limit = _ops.convert_to_tensor(limit, _dtypes.int32)
    delta = _ops.convert_to_tensor(delta, _dtypes.int32)
    # Loop bounds come first, followed by the loop-carried inputs.
    flat_inputs = [start, limit, delta, *in_tensors]
    op_attrs = ("T", in_types, "body", body)
    outputs = _execute.execute(
        b"For", len(in_tensors), inputs=flat_inputs, attrs=op_attrs, ctx=ctx,
        name=name)
    if _execute.must_record_gradient():
        _execute.record_gradient("For", flat_inputs, op_attrs, outputs)
    return outputs
# Type variable constrained to the dtypes accepted for the If op's `cond`.
TV_If_Tcond = TypeVar(
    "TV_If_Tcond",
    _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64,
    _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn,
    _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32,
    _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16,
    _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8,
    _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32,
    _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant,
)
def _if(
    cond: Annotated[Any, TV_If_Tcond],
    input,
    Tout,
    then_branch,
    else_branch,
    output_shapes=[],
    name=None,
):
    r"""output = cond ? then_branch(input) : else_branch(input)

    Args:
      cond: A `Tensor`.
        If the tensor is a scalar of non-boolean type, the scalar is converted
        to a boolean according to the following rule: if the scalar is a
        numerical value, non-zero means `True` and zero means False; if the
        scalar is a string, non-empty means `True` and empty means `False`.
        If the tensor is not a scalar, being empty means False and being
        non-empty means True.
      input: A list of `Tensor` objects. A list of input tensors.
      Tout: A list of `tf.DTypes`. A list of output types.
      then_branch: A function decorated with @Defun. It takes 'inputs' and
        returns a list of tensors, whose types are the same as what
        else_branch returns.
      else_branch: A function decorated with @Defun. It takes 'inputs' and
        returns a list of tensors, whose types are the same as what
        then_branch returns.
      output_shapes: An optional list of shapes (each a `tf.TensorShape` or
        list of `ints`). Defaults to `[]`.
      name: A name for the operation (optional).

    Returns:
      A list of `Tensor` objects of type `Tout`. When building a graph and the
      created op has no outputs, the op itself is returned instead.

    Raises:
      TypeError: If `Tout` or `output_shapes` is not a list or tuple (checked
        on the graph path and in the eager fallback).
    """
    _ctx = _context._context or _context.context()
    if _ctx._thread_local_data.is_eager:
        # Eager mode: attempt the fast-path executor first.
        try:
            return pywrap_tfe.TFE_Py_FastPathExecute(
                _ctx, "If", name, cond, input, "Tout", Tout, "then_branch",
                then_branch, "else_branch", else_branch, "output_shapes",
                output_shapes)
        except _core._NotOkStatusException as status_error:
            _ops.raise_from_not_ok_status(status_error, name)
        except _core._FallbackException:
            pass
        # Fast path asked for a fallback: use the Python eager implementation.
        try:
            return _if_eager_fallback(
                cond, input, Tout=Tout, then_branch=then_branch,
                else_branch=else_branch, output_shapes=output_shapes,
                name=name, ctx=_ctx)
        except _core._SymbolicException:
            pass  # Add nodes to the TensorFlow graph.

    # Add nodes to the TensorFlow graph.
    if not isinstance(Tout, (list, tuple)):
        raise TypeError(
            "Expected list for 'Tout' argument to "
            "'if' Op, not %r." % Tout)
    Tout = [_execute.make_type(out_type, "Tout") for out_type in Tout]
    output_shapes = [] if output_shapes is None else output_shapes
    if not isinstance(output_shapes, (list, tuple)):
        raise TypeError(
            "Expected list for 'output_shapes' argument to "
            "'if' Op, not %r." % output_shapes)
    output_shapes = [
        _execute.make_shape(shape, "output_shapes") for shape in output_shapes
    ]
    _, _, graph_op, op_outputs = _op_def_library._apply_op_helper(
        "If", cond=cond, input=input, Tout=Tout, then_branch=then_branch,
        else_branch=else_branch, output_shapes=output_shapes, name=name)
    results = op_outputs[:]
    if not results:
        # No outputs to hand back, so return the op itself.
        return graph_op
    if _execute.must_record_gradient():
        recorded_attrs = (
            "Tcond", graph_op._get_attr_type("Tcond"),
            "Tin", graph_op.get_attr("Tin"),
            "Tout", graph_op.get_attr("Tout"),
            "then_branch", graph_op.get_attr("then_branch"),
            "else_branch", graph_op.get_attr("else_branch"),
            "output_shapes", graph_op.get_attr("output_shapes"),
        )
        flat_inputs = graph_op.inputs
        _execute.record_gradient("If", flat_inputs, recorded_attrs, results)
    return results


If = tf_export("raw_ops.If")(_ops.to_raw_op(_if))
def _if_eager_fallback(
    cond: Annotated[Any, TV_If_Tcond],
    input,
    Tout,
    then_branch,
    else_branch,
    output_shapes,
    name,
    ctx,
):
    """Executes the `If` op eagerly via `_execute.execute`.

    `_if` calls this when the fast-path executor raises `_FallbackException`.
    The arguments mirror those of `_if`; `ctx` is the eager context supplied
    by the caller.

    Returns:
      The value returned by `_execute.execute`, requested with `len(Tout)`
      outputs.

    Raises:
      TypeError: If `Tout` or `output_shapes` is not a list or tuple.
    """
    if not isinstance(Tout, (list, tuple)):
        raise TypeError(
            "Expected list for 'Tout' argument to "
            "'if' Op, not %r." % Tout)
    Tout = [_execute.make_type(out_type, "Tout") for out_type in Tout]
    output_shapes = [] if output_shapes is None else output_shapes
    if not isinstance(output_shapes, (list, tuple)):
        raise TypeError(
            "Expected list for 'output_shapes' argument to "
            "'if' Op, not %r." % output_shapes)
    output_shapes = [
        _execute.make_shape(shape, "output_shapes") for shape in output_shapes
    ]
    cond_type, (cond,) = _execute.args_to_matching_eager([cond], ctx, [])
    in_types, in_tensors = _execute.convert_to_mixed_eager_tensors(input, ctx)
    # The condition goes first, followed by the branch inputs.
    flat_inputs = [cond, *in_tensors]
    op_attrs = (
        "Tcond", cond_type,
        "Tin", in_types,
        "Tout", Tout,
        "then_branch", then_branch,
        "else_branch", else_branch,
        "output_shapes", output_shapes,
    )
    outputs = _execute.execute(
        b"If", len(Tout), inputs=flat_inputs, attrs=op_attrs, ctx=ctx,
        name=name)
    if _execute.must_record_gradient():
        _execute.record_gradient("If", flat_inputs, op_attrs, outputs)
    return outputs
def partitioned_call(
    args,
    Tout,
    f,
    config: str = "",
    config_proto: str = "",
    executor_type: str = "",
    name=None,
):
    r"""returns `f(inputs)`, where `f`'s body is placed and partitioned.

    Asynchronously executes a function, potentially across multiple devices
    but within a single process. The kernel places and partitions a given
    function's underlying graph, and executes each of the partitioned
    subgraphs as a function.

    Args:
      args: A list of `Tensor` objects. A list of input tensors.
      Tout: A list of `tf.DTypes`. A list of output types.
      f: A function decorated with @Defun.
        A function that takes 'args', a list of tensors, and returns 'output',
        another list of tensors. Input and output types are specified by
        'Tin' and 'Tout'. The function body of f will be placed and
        partitioned across devices, setting this op apart from the regular
        Call op.
      config: An optional `string`. Defaults to `""`.
      config_proto: An optional `string`. Defaults to `""`.
      executor_type: An optional `string`. Defaults to `""`.
      name: A name for the operation (optional).

    Returns:
      A list of `Tensor` objects of type `Tout`.

    Raises:
      TypeError: If `Tout` is not a list or tuple (checked on the graph path
        and in the eager fallback).
    """
    _ctx = _context._context or _context.context()
    if _ctx._thread_local_data.is_eager:
        # Eager mode: attempt the fast-path executor first.
        try:
            return pywrap_tfe.TFE_Py_FastPathExecute(
                _ctx, "PartitionedCall", name, args, "Tout", Tout, "f", f,
                "config", config, "config_proto", config_proto,
                "executor_type", executor_type)
        except _core._NotOkStatusException as status_error:
            _ops.raise_from_not_ok_status(status_error, name)
        except _core._FallbackException:
            pass
        # Fast path asked for a fallback: use the Python eager implementation.
        try:
            return partitioned_call_eager_fallback(
                args, Tout=Tout, f=f, config=config,
                config_proto=config_proto, executor_type=executor_type,
                name=name, ctx=_ctx)
        except _core._SymbolicException:
            pass  # Add nodes to the TensorFlow graph.

    # Add nodes to the TensorFlow graph.
    if not isinstance(Tout, (list, tuple)):
        raise TypeError(
            "Expected list for 'Tout' argument to "
            "'partitioned_call' Op, not %r." % Tout)
    Tout = [_execute.make_type(out_type, "Tout") for out_type in Tout]
    # Each optional string attr treats None as the empty string.
    config = _execute.make_str("" if config is None else config, "config")
    config_proto = _execute.make_str(
        "" if config_proto is None else config_proto, "config_proto")
    executor_type = _execute.make_str(
        "" if executor_type is None else executor_type, "executor_type")
    _, _, graph_op, op_outputs = _op_def_library._apply_op_helper(
        "PartitionedCall", args=args, Tout=Tout, f=f, config=config,
        config_proto=config_proto, executor_type=executor_type, name=name)
    results = op_outputs[:]
    if _execute.must_record_gradient():
        recorded_attrs = (
            "Tin", graph_op.get_attr("Tin"),
            "Tout", graph_op.get_attr("Tout"),
            "f", graph_op.get_attr("f"),
            "config", graph_op.get_attr("config"),
            "config_proto", graph_op.get_attr("config_proto"),
            "executor_type", graph_op.get_attr("executor_type"),
        )
        flat_inputs = graph_op.inputs
        _execute.record_gradient(
            "PartitionedCall", flat_inputs, recorded_attrs, results)
    return results


PartitionedCall = tf_export("raw_ops.PartitionedCall")(_ops.to_raw_op(partitioned_call))
def partitioned_call_eager_fallback(
    args,
    Tout,
    f,
    config: str,
    config_proto: str,
    executor_type: str,
    name,
    ctx,
):
    """Executes the `PartitionedCall` op eagerly via `_execute.execute`.

    `partitioned_call` calls this when the fast-path executor raises
    `_FallbackException`. The arguments mirror those of `partitioned_call`;
    `ctx` is the eager context supplied by the caller.

    Returns:
      The value returned by `_execute.execute`, requested with `len(Tout)`
      outputs.

    Raises:
      TypeError: If `Tout` is not a list or tuple.
    """
    if not isinstance(Tout, (list, tuple)):
        raise TypeError(
            "Expected list for 'Tout' argument to "
            "'partitioned_call' Op, not %r." % Tout)
    Tout = [_execute.make_type(out_type, "Tout") for out_type in Tout]
    # Each optional string attr treats None as the empty string.
    config = _execute.make_str("" if config is None else config, "config")
    config_proto = _execute.make_str(
        "" if config_proto is None else config_proto, "config_proto")
    executor_type = _execute.make_str(
        "" if executor_type is None else executor_type, "executor_type")
    in_types, in_tensors = _execute.convert_to_mixed_eager_tensors(args, ctx)
    flat_inputs = list(in_tensors)
    op_attrs = (
        "Tin", in_types,
        "Tout", Tout,
        "f", f,
        "config", config,
        "config_proto", config_proto,
        "executor_type", executor_type,
    )
    outputs = _execute.execute(
        b"PartitionedCall", len(Tout), inputs=flat_inputs, attrs=op_attrs,
        ctx=ctx, name=name)
    if _execute.must_record_gradient():
        _execute.record_gradient(
            "PartitionedCall", flat_inputs, op_attrs, outputs)
    return outputs
def remote_call(
    target: Annotated[Any, _atypes.String], args, Tout, f, name=None
):
    r"""Runs function `f` on a remote device indicated by `target`.

    Args:
      target: A `Tensor` of type `string`.
        A fully specified device name where we want to run the function.
      args: A list of `Tensor` objects. A list of arguments for the function.
      Tout: A list of `tf.DTypes` that has length `>= 1`.
        The type list for the return values.
      f: A function decorated with @Defun. The function to run remotely.
      name: A name for the operation (optional).

    Returns:
      A list of `Tensor` objects of type `Tout`. When building a graph and the
      created op has no outputs, the op itself is returned instead.

    Raises:
      TypeError: If `Tout` is not a list or tuple (checked on the graph path
        and in the eager fallback).
    """
    _ctx = _context._context or _context.context()
    if _ctx._thread_local_data.is_eager:
        # Eager mode: attempt the fast-path executor first.
        try:
            return pywrap_tfe.TFE_Py_FastPathExecute(
                _ctx, "RemoteCall", name, target, args, "Tout", Tout, "f", f)
        except _core._NotOkStatusException as status_error:
            _ops.raise_from_not_ok_status(status_error, name)
        except _core._FallbackException:
            pass
        # Fast path asked for a fallback: use the Python eager implementation.
        try:
            return remote_call_eager_fallback(
                target, args, Tout=Tout, f=f, name=name, ctx=_ctx)
        except _core._SymbolicException:
            pass  # Add nodes to the TensorFlow graph.

    # Add nodes to the TensorFlow graph.
    if not isinstance(Tout, (list, tuple)):
        raise TypeError(
            "Expected list for 'Tout' argument to "
            "'remote_call' Op, not %r." % Tout)
    Tout = [_execute.make_type(out_type, "Tout") for out_type in Tout]
    _, _, graph_op, op_outputs = _op_def_library._apply_op_helper(
        "RemoteCall", target=target, args=args, Tout=Tout, f=f, name=name)
    results = op_outputs[:]
    if not results:
        # No outputs to hand back, so return the op itself.
        return graph_op
    if _execute.must_record_gradient():
        recorded_attrs = (
            "Tin", graph_op.get_attr("Tin"),
            "Tout", graph_op.get_attr("Tout"),
            "f", graph_op.get_attr("f"),
        )
        flat_inputs = graph_op.inputs
        _execute.record_gradient(
            "RemoteCall", flat_inputs, recorded_attrs, results)
    return results


RemoteCall = tf_export("raw_ops.RemoteCall")(_ops.to_raw_op(remote_call))
def remote_call_eager_fallback(
    target: Annotated[Any, _atypes.String], args, Tout, f, name, ctx
):
    """Executes the `RemoteCall` op eagerly via `_execute.execute`.

    `remote_call` calls this when the fast-path executor raises
    `_FallbackException`. The arguments mirror those of `remote_call`; `ctx`
    is the eager context supplied by the caller.

    Returns:
      The value returned by `_execute.execute`, requested with `len(Tout)`
      outputs.

    Raises:
      TypeError: If `Tout` is not a list or tuple.
    """
    if not isinstance(Tout, (list, tuple)):
        raise TypeError(
            "Expected list for 'Tout' argument to "
            "'remote_call' Op, not %r." % Tout)
    Tout = [_execute.make_type(out_type, "Tout") for out_type in Tout]
    in_types, in_tensors = _execute.convert_to_mixed_eager_tensors(args, ctx)
    target = _ops.convert_to_tensor(target, _dtypes.string)
    # The target device goes first, followed by the function arguments.
    flat_inputs = [target, *in_tensors]
    op_attrs = ("Tin", in_types, "Tout", Tout, "f", f)
    outputs = _execute.execute(
        b"RemoteCall", len(Tout), inputs=flat_inputs, attrs=op_attrs, ctx=ctx,
        name=name)
    if _execute.must_record_gradient():
        _execute.record_gradient("RemoteCall", flat_inputs, op_attrs, outputs)
    return outputs
def stateful_partitioned_call(
    args,
    Tout,
    f,
    config: str = "",
    config_proto: str = "",
    executor_type: str = "",
    name=None,
):
    r"""returns `f(inputs)`, where `f`'s body is placed and partitioned.

    Args:
      args: A list of `Tensor` objects. A list of input tensors.
      Tout: A list of `tf.DTypes`. A list of output types.
      f: A function decorated with @Defun.
        A function that takes 'args', a list of tensors, and returns 'output',
        another list of tensors. Input and output types are specified by
        'Tin' and 'Tout'. The function body of f will be placed and
        partitioned across devices, setting this op apart from the regular
        Call op. This op is stateful.
      config: An optional `string`. Defaults to `""`.
      config_proto: An optional `string`. Defaults to `""`.
      executor_type: An optional `string`. Defaults to `""`.
      name: A name for the operation (optional).

    Returns:
      A list of `Tensor` objects of type `Tout`. When building a graph and the
      created op has no outputs, the op itself is returned instead.

    Raises:
      TypeError: If `Tout` is not a list or tuple (checked on the graph path
        and in the eager fallback).
    """
    _ctx = _context._context or _context.context()
    if _ctx._thread_local_data.is_eager:
        # Eager mode: attempt the fast-path executor first.
        try:
            return pywrap_tfe.TFE_Py_FastPathExecute(
                _ctx, "StatefulPartitionedCall", name, args, "Tout", Tout,
                "f", f, "config", config, "config_proto", config_proto,
                "executor_type", executor_type)
        except _core._NotOkStatusException as status_error:
            _ops.raise_from_not_ok_status(status_error, name)
        except _core._FallbackException:
            pass
        # Fast path asked for a fallback: use the Python eager implementation.
        try:
            return stateful_partitioned_call_eager_fallback(
                args, Tout=Tout, f=f, config=config,
                config_proto=config_proto, executor_type=executor_type,
                name=name, ctx=_ctx)
        except _core._SymbolicException:
            pass  # Add nodes to the TensorFlow graph.

    # Add nodes to the TensorFlow graph.
    if not isinstance(Tout, (list, tuple)):
        raise TypeError(
            "Expected list for 'Tout' argument to "
            "'stateful_partitioned_call' Op, not %r." % Tout)
    Tout = [_execute.make_type(out_type, "Tout") for out_type in Tout]
    # Each optional string attr treats None as the empty string.
    config = _execute.make_str("" if config is None else config, "config")
    config_proto = _execute.make_str(
        "" if config_proto is None else config_proto, "config_proto")
    executor_type = _execute.make_str(
        "" if executor_type is None else executor_type, "executor_type")
    _, _, graph_op, op_outputs = _op_def_library._apply_op_helper(
        "StatefulPartitionedCall", args=args, Tout=Tout, f=f, config=config,
        config_proto=config_proto, executor_type=executor_type, name=name)
    results = op_outputs[:]
    if not results:
        # No outputs to hand back, so return the op itself.
        return graph_op
    if _execute.must_record_gradient():
        recorded_attrs = (
            "Tin", graph_op.get_attr("Tin"),
            "Tout", graph_op.get_attr("Tout"),
            "f", graph_op.get_attr("f"),
            "config", graph_op.get_attr("config"),
            "config_proto", graph_op.get_attr("config_proto"),
            "executor_type", graph_op.get_attr("executor_type"),
        )
        flat_inputs = graph_op.inputs
        _execute.record_gradient(
            "StatefulPartitionedCall", flat_inputs, recorded_attrs, results)
    return results


StatefulPartitionedCall = tf_export("raw_ops.StatefulPartitionedCall")(_ops.to_raw_op(stateful_partitioned_call))
def stateful_partitioned_call_eager_fallback(
    args,
    Tout,
    f,
    config: str,
    config_proto: str,
    executor_type: str,
    name,
    ctx,
):
    """Executes the `StatefulPartitionedCall` op eagerly via `_execute.execute`.

    `stateful_partitioned_call` calls this when the fast-path executor raises
    `_FallbackException`. The arguments mirror those of
    `stateful_partitioned_call`; `ctx` is the eager context supplied by the
    caller.

    Returns:
      The value returned by `_execute.execute`, requested with `len(Tout)`
      outputs.

    Raises:
      TypeError: If `Tout` is not a list or tuple.
    """
    if not isinstance(Tout, (list, tuple)):
        raise TypeError(
            "Expected list for 'Tout' argument to "
            "'stateful_partitioned_call' Op, not %r." % Tout)
    Tout = [_execute.make_type(out_type, "Tout") for out_type in Tout]
    # Each optional string attr treats None as the empty string.
    config = _execute.make_str("" if config is None else config, "config")
    config_proto = _execute.make_str(
        "" if config_proto is None else config_proto, "config_proto")
    executor_type = _execute.make_str(
        "" if executor_type is None else executor_type, "executor_type")
    in_types, in_tensors = _execute.convert_to_mixed_eager_tensors(args, ctx)
    flat_inputs = list(in_tensors)
    op_attrs = (
        "Tin", in_types,
        "Tout", Tout,
        "f", f,
        "config", config,
        "config_proto", config_proto,
        "executor_type", executor_type,
    )
    outputs = _execute.execute(
        b"StatefulPartitionedCall", len(Tout), inputs=flat_inputs,
        attrs=op_attrs, ctx=ctx, name=name)
    if _execute.must_record_gradient():
        _execute.record_gradient(
            "StatefulPartitionedCall", flat_inputs, op_attrs, outputs)
    return outputs
def stateless_case(
    branch_index: Annotated[Any, _atypes.Int32],
    input,
    Tout,
    branches,
    output_shapes=[],
    name=None,
):
    r"""Runs exactly one branch function, selected by `branch_index`.

    The op behaves like an n-way switch statement:

      ```
      switch (branch_index) {
        case 0:
          output = branches[0](input);
          break;
        case 1:
          output = branches[1](input);
          break;
        ...
        case [[nbranches-1]]:
        default:
          output = branches[nbranches-1](input);
          break;
      }
      ```

    This should only be used when none of the branches has stateful ops.

    Args:
      branch_index: A `Tensor` of type `int32`; the branch selector.
      input: A list of `Tensor` objects passed to the branch function.
      Tout: A list of `tf.DTypes` giving the output types.
      branches: A list (length `>= 1`) of functions decorated with @Defun.
        Each takes 'inputs' and returns a list of tensors whose types are the
        same as what every other branch returns.
      output_shapes: An optional list of shapes (each a `tf.TensorShape` or
        list of `ints`). Defaults to `[]`.
      name: A name for the operation (optional).

    Returns:
      A list of `Tensor` objects of type `Tout`.

    Raises:
      TypeError: If `Tout`, `branches` or `output_shapes` is not a list or
        tuple (checked on the graph path and in the eager fallback).
    """
    _ctx = _context._context or _context.context()
    if _ctx._thread_local_data.is_eager:
        # Eager mode: attempt the fast-path executor first.
        try:
            return pywrap_tfe.TFE_Py_FastPathExecute(
                _ctx, "StatelessCase", name, branch_index, input, "Tout",
                Tout, "branches", branches, "output_shapes", output_shapes)
        except _core._NotOkStatusException as status_error:
            _ops.raise_from_not_ok_status(status_error, name)
        except _core._FallbackException:
            pass
        # Fast path asked for a fallback: use the Python eager implementation.
        try:
            return stateless_case_eager_fallback(
                branch_index, input, Tout=Tout, branches=branches,
                output_shapes=output_shapes, name=name, ctx=_ctx)
        except _core._SymbolicException:
            pass  # Add nodes to the TensorFlow graph.

    # Add nodes to the TensorFlow graph.
    if not isinstance(Tout, (list, tuple)):
        raise TypeError(
            "Expected list for 'Tout' argument to "
            "'stateless_case' Op, not %r." % Tout)
    Tout = [_execute.make_type(out_type, "Tout") for out_type in Tout]
    if not isinstance(branches, (list, tuple)):
        raise TypeError(
            "Expected list for 'branches' argument to "
            "'stateless_case' Op, not %r." % branches)
    output_shapes = [] if output_shapes is None else output_shapes
    if not isinstance(output_shapes, (list, tuple)):
        raise TypeError(
            "Expected list for 'output_shapes' argument to "
            "'stateless_case' Op, not %r." % output_shapes)
    output_shapes = [
        _execute.make_shape(shape, "output_shapes") for shape in output_shapes
    ]
    _, _, graph_op, op_outputs = _op_def_library._apply_op_helper(
        "StatelessCase", branch_index=branch_index, input=input, Tout=Tout,
        branches=branches, output_shapes=output_shapes, name=name)
    results = op_outputs[:]
    if _execute.must_record_gradient():
        recorded_attrs = (
            "Tin", graph_op.get_attr("Tin"),
            "Tout", graph_op.get_attr("Tout"),
            "branches", graph_op.get_attr("branches"),
            "output_shapes", graph_op.get_attr("output_shapes"),
        )
        flat_inputs = graph_op.inputs
        _execute.record_gradient(
            "StatelessCase", flat_inputs, recorded_attrs, results)
    return results


StatelessCase = tf_export("raw_ops.StatelessCase")(_ops.to_raw_op(stateless_case))
def stateless_case_eager_fallback(
    branch_index: Annotated[Any, _atypes.Int32],
    input,
    Tout,
    branches,
    output_shapes,
    name,
    ctx,
):
    """Executes the `StatelessCase` op eagerly via `_execute.execute`.

    `stateless_case` calls this when the fast-path executor raises
    `_FallbackException`. The arguments mirror those of `stateless_case`;
    `ctx` is the eager context supplied by the caller.

    Returns:
      The value returned by `_execute.execute`, requested with `len(Tout)`
      outputs.

    Raises:
      TypeError: If `Tout`, `branches` or `output_shapes` is not a list or
        tuple.
    """
    if not isinstance(Tout, (list, tuple)):
        raise TypeError(
            "Expected list for 'Tout' argument to "
            "'stateless_case' Op, not %r." % Tout)
    Tout = [_execute.make_type(out_type, "Tout") for out_type in Tout]
    if not isinstance(branches, (list, tuple)):
        raise TypeError(
            "Expected list for 'branches' argument to "
            "'stateless_case' Op, not %r." % branches)
    output_shapes = [] if output_shapes is None else output_shapes
    if not isinstance(output_shapes, (list, tuple)):
        raise TypeError(
            "Expected list for 'output_shapes' argument to "
            "'stateless_case' Op, not %r." % output_shapes)
    output_shapes = [
        _execute.make_shape(shape, "output_shapes") for shape in output_shapes
    ]
    in_types, in_tensors = _execute.convert_to_mixed_eager_tensors(input, ctx)
    branch_index = _ops.convert_to_tensor(branch_index, _dtypes.int32)
    # The selector goes first, followed by the branch inputs.
    flat_inputs = [branch_index, *in_tensors]
    op_attrs = (
        "Tin", in_types,
        "Tout", Tout,
        "branches", branches,
        "output_shapes", output_shapes,
    )
    outputs = _execute.execute(
        b"StatelessCase", len(Tout), inputs=flat_inputs, attrs=op_attrs,
        ctx=ctx, name=name)
    if _execute.must_record_gradient():
        _execute.record_gradient(
            "StatelessCase", flat_inputs, op_attrs, outputs)
    return outputs
# Type variable constrained to the dtypes accepted for StatelessIf's `cond`.
TV_StatelessIf_Tcond = TypeVar(
    "TV_StatelessIf_Tcond",
    _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64,
    _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn,
    _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32,
    _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16,
    _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8,
    _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32,
    _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant,
)
def stateless_if(cond: Annotated[Any, TV_StatelessIf_Tcond], input, Tout, then_branch, else_branch, output_shapes=[], name=None):
  r"""Computes `cond ? then_branch(input) : else_branch(input)`.

  Args:
    cond: A `Tensor`.
      The predicate. When it is a scalar of non-boolean type it is converted
      to a boolean as follows: a numerical scalar is `True` when non-zero and
      `False` when zero; a string scalar is `True` when non-empty and `False`
      when empty. When it is not a scalar, an empty tensor means `False` and
      a non-empty tensor means `True`.

      This should only be used when the if then/else body functions do not
      have stateful ops.
    input: A list of `Tensor` objects. A list of input tensors.
    Tout: A list of `tf.DTypes`. A list of output types.
    then_branch: A function decorated with @Defun.
      Takes 'inputs' and returns a list of tensors whose types match what
      else_branch returns.
    else_branch: A function decorated with @Defun.
      Takes 'inputs' and returns a list of tensors whose types match what
      then_branch returns.
    output_shapes: An optional list of shapes (each a `tf.TensorShape` or list of `ints`). Defaults to `[]`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`.
  """
  active_ctx = _context._context or _context.context()
  if active_ctx._thread_local_data.is_eager:
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(
          active_ctx, "StatelessIf", name, cond, input,
          "Tout", Tout,
          "then_branch", then_branch,
          "else_branch", else_branch,
          "output_shapes", output_shapes)
    except _core._NotOkStatusException as status_error:
      _ops.raise_from_not_ok_status(status_error, name)
    except _core._FallbackException:
      pass  # Continue with the Python eager fallback below.
    try:
      return stateless_if_eager_fallback(
          cond, input, Tout=Tout, then_branch=then_branch,
          else_branch=else_branch, output_shapes=output_shapes, name=name,
          ctx=active_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.

  # Add nodes to the TensorFlow graph.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'stateless_if' Op, not %r." % Tout)
  out_types = [_execute.make_type(dtype, "Tout") for dtype in Tout]

  shapes = [] if output_shapes is None else output_shapes
  if not isinstance(shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'stateless_if' Op, not %r." % shapes)
  shapes = [_execute.make_shape(shape, "output_shapes") for shape in shapes]

  _, _, graph_op, op_outputs = _op_def_library._apply_op_helper(
      "StatelessIf", cond=cond, input=input, Tout=out_types,
      then_branch=then_branch, else_branch=else_branch,
      output_shapes=shapes, name=name)
  results = op_outputs[:]
  if not _execute.must_record_gradient():
    return results

  # Read the attrs back from the created op before recording the gradient.
  attrs = (
      "Tcond", graph_op._get_attr_type("Tcond"),
      "Tin", graph_op.get_attr("Tin"),
      "Tout", graph_op.get_attr("Tout"),
      "then_branch", graph_op.get_attr("then_branch"),
      "else_branch", graph_op.get_attr("else_branch"),
      "output_shapes", graph_op.get_attr("output_shapes"),
  )
  flat_inputs = graph_op.inputs
  _execute.record_gradient("StatelessIf", flat_inputs, attrs, results)
  return results
# Wrap `stateless_if` with `_ops.to_raw_op` and export the wrapped function
# through `tf_export` under the name `raw_ops.StatelessIf`.
StatelessIf = tf_export("raw_ops.StatelessIf")(_ops.to_raw_op(stateless_if))
def stateless_if_eager_fallback(cond: Annotated[Any, TV_StatelessIf_Tcond], input, Tout, then_branch, else_branch, output_shapes, name, ctx):
  """Runs the `StatelessIf` op eagerly through `_execute.execute`.

  `stateless_if` calls this after its fast path raises
  `_core._FallbackException`. The attributes are validated and normalized,
  the inputs are converted to eager tensors, the op is executed, and the
  gradient is recorded when `_execute.must_record_gradient()` is true.

  Args:
    cond: The predicate; converted with `_execute.args_to_matching_eager`.
    input: Values converted with `_execute.convert_to_mixed_eager_tensors`.
    Tout: A list or tuple of output types; each entry is passed through
      `_execute.make_type`.
    then_branch: Forwarded unchanged as the `then_branch` attr.
    else_branch: Forwarded unchanged as the `else_branch` attr.
    output_shapes: A list or tuple of shapes, or `None` (treated as `[]`);
      each entry is passed through `_execute.make_shape`.
    name: Forwarded to `_execute.execute`.
    ctx: The context used for tensor conversion and execution.

  Returns:
    The value returned by `_execute.execute`, requested with `len(Tout)`
    outputs.

  Raises:
    TypeError: If `Tout` or `output_shapes` is not a list or tuple.
  """
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'stateless_if' Op, not %r." % Tout)
  out_types = [_execute.make_type(dtype, "Tout") for dtype in Tout]

  shapes = [] if output_shapes is None else output_shapes
  if not isinstance(shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'stateless_if' Op, not %r." % shapes)
  shapes = [_execute.make_shape(shape, "output_shapes") for shape in shapes]

  cond_type, matched = _execute.args_to_matching_eager([cond], ctx, [])
  (cond_tensor,) = matched
  input_types, eager_inputs = _execute.convert_to_mixed_eager_tensors(
      input, ctx)

  flat_inputs = [cond_tensor, *eager_inputs]
  attrs = (
      "Tcond", cond_type,
      "Tin", input_types,
      "Tout", out_types,
      "then_branch", then_branch,
      "else_branch", else_branch,
      "output_shapes", shapes,
  )
  results = _execute.execute(
      b"StatelessIf", len(out_types), inputs=flat_inputs, attrs=attrs,
      ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("StatelessIf", flat_inputs, attrs, results)
  return results
def stateless_while(input, cond, body, output_shapes=[], parallel_iterations:int=10, name=None):
  r"""Computes `output = input; While (Cond(output)) { output = Body(output) }`.

  Args:
    input: A list of `Tensor` objects.
      A list of input tensors whose types are T.
    cond: A function decorated with @Defun.
      Takes 'input' and returns a tensor. A scalar of non-boolean type is
      converted to a boolean as follows: a numerical scalar is True when
      non-zero and False when zero; a string scalar is True when non-empty
      and False when empty. A non-scalar tensor is True when it is non-empty
      and False otherwise.

      This should only be used when the while condition and body functions
      do not have stateful ops.
    body: A function decorated with @Defun.
      Takes a list of tensors and returns another list of tensors. Both
      lists have the same types as specified by T.
    output_shapes: An optional list of shapes (each a `tf.TensorShape` or list of `ints`). Defaults to `[]`.
    parallel_iterations: An optional `int`. Defaults to `10`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
  """
  active_ctx = _context._context or _context.context()
  if active_ctx._thread_local_data.is_eager:
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(
          active_ctx, "StatelessWhile", name, input,
          "cond", cond,
          "body", body,
          "output_shapes", output_shapes,
          "parallel_iterations", parallel_iterations)
    except _core._NotOkStatusException as status_error:
      _ops.raise_from_not_ok_status(status_error, name)
    except _core._FallbackException:
      pass  # Continue with the Python eager fallback below.
    try:
      return stateless_while_eager_fallback(
          input, cond=cond, body=body, output_shapes=output_shapes,
          parallel_iterations=parallel_iterations, name=name, ctx=active_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.

  # Add nodes to the TensorFlow graph.
  shapes = [] if output_shapes is None else output_shapes
  if not isinstance(shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'stateless_while' Op, not %r." % shapes)
  shapes = [_execute.make_shape(shape, "output_shapes") for shape in shapes]

  iterations = 10 if parallel_iterations is None else parallel_iterations
  iterations = _execute.make_int(iterations, "parallel_iterations")

  _, _, graph_op, op_outputs = _op_def_library._apply_op_helper(
      "StatelessWhile", input=input, cond=cond, body=body,
      output_shapes=shapes, parallel_iterations=iterations, name=name)
  results = op_outputs[:]
  if not _execute.must_record_gradient():
    return results

  # Read the attrs back from the created op before recording the gradient.
  attrs = (
      "T", graph_op.get_attr("T"),
      "cond", graph_op.get_attr("cond"),
      "body", graph_op.get_attr("body"),
      "output_shapes", graph_op.get_attr("output_shapes"),
      "parallel_iterations", graph_op._get_attr_int("parallel_iterations"),
  )
  flat_inputs = graph_op.inputs
  _execute.record_gradient("StatelessWhile", flat_inputs, attrs, results)
  return results
# Wrap `stateless_while` with `_ops.to_raw_op` and export the wrapped function
# through `tf_export` under the name `raw_ops.StatelessWhile`.
StatelessWhile = tf_export("raw_ops.StatelessWhile")(_ops.to_raw_op(stateless_while))
def stateless_while_eager_fallback(input, cond, body, output_shapes, parallel_iterations: int, name, ctx):
  """Runs the `StatelessWhile` op eagerly through `_execute.execute`.

  `stateless_while` calls this after its fast path raises
  `_core._FallbackException`. The attributes are validated and normalized,
  the inputs are converted to eager tensors, the op is executed, and the
  gradient is recorded when `_execute.must_record_gradient()` is true.

  Args:
    input: Values converted with `_execute.convert_to_mixed_eager_tensors`.
    cond: Forwarded unchanged as the `cond` attr.
    body: Forwarded unchanged as the `body` attr.
    output_shapes: A list or tuple of shapes, or `None` (treated as `[]`);
      each entry is passed through `_execute.make_shape`.
    parallel_iterations: An int, or `None` (treated as `10`); passed through
      `_execute.make_int`.
    name: Forwarded to `_execute.execute`.
    ctx: The context used for tensor conversion and execution.

  Returns:
    The value returned by `_execute.execute`, requested with one output per
    converted input.

  Raises:
    TypeError: If `output_shapes` is not a list or tuple.
  """
  shapes = [] if output_shapes is None else output_shapes
  if not isinstance(shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'stateless_while' Op, not %r." % shapes)
  shapes = [_execute.make_shape(shape, "output_shapes") for shape in shapes]

  iterations = 10 if parallel_iterations is None else parallel_iterations
  iterations = _execute.make_int(iterations, "parallel_iterations")

  input_types, eager_inputs = _execute.convert_to_mixed_eager_tensors(
      input, ctx)
  flat_inputs = list(eager_inputs)
  attrs = (
      "T", input_types,
      "cond", cond,
      "body", body,
      "output_shapes", shapes,
      "parallel_iterations", iterations,
  )
  results = _execute.execute(
      b"StatelessWhile", len(eager_inputs), inputs=flat_inputs, attrs=attrs,
      ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("StatelessWhile", flat_inputs, attrs, results)
  return results
def symbolic_gradient(input, Tout, f, name=None):
  r"""Computes the gradient function for function f via backpropagation.

  Args:
    input: A list of `Tensor` objects. a list of input tensors of size N + M;
    Tout: A list of `tf.DTypes` that has length `>= 1`.
      the type list for the input list.
    f: A function decorated with @Defun.
      The function we want to compute the gradient for.

      'f' must be a numerical function taking N inputs and producing M
      outputs. Its gradient function 'g', computed by this SymbolicGradient
      op, takes N + M inputs and produces N outputs.

      I.e. if we have
         (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
      then, g is
         (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
                                           dL/dy1, dL/dy2, ..., dL/dy_M),

      where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the
      loss function). dL/dx_i is the partial derivative of L with respect
      to x_i.

      (Needs some math expert to say the comment above better.)
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`.
  """
  active_ctx = _context._context or _context.context()
  if active_ctx._thread_local_data.is_eager:
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(
          active_ctx, "SymbolicGradient", name, input,
          "Tout", Tout,
          "f", f)
    except _core._NotOkStatusException as status_error:
      _ops.raise_from_not_ok_status(status_error, name)
    except _core._FallbackException:
      pass  # Continue with the Python eager fallback below.
    try:
      return symbolic_gradient_eager_fallback(
          input, Tout=Tout, f=f, name=name, ctx=active_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.

  # Add nodes to the TensorFlow graph.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'symbolic_gradient' Op, not %r." % Tout)
  out_types = [_execute.make_type(dtype, "Tout") for dtype in Tout]

  _, _, graph_op, op_outputs = _op_def_library._apply_op_helper(
      "SymbolicGradient", input=input, Tout=out_types, f=f, name=name)
  results = op_outputs[:]
  if not _execute.must_record_gradient():
    return results

  # Read the attrs back from the created op before recording the gradient.
  attrs = (
      "Tin", graph_op.get_attr("Tin"),
      "Tout", graph_op.get_attr("Tout"),
      "f", graph_op.get_attr("f"),
  )
  flat_inputs = graph_op.inputs
  _execute.record_gradient("SymbolicGradient", flat_inputs, attrs, results)
  return results
# Wrap `symbolic_gradient` with `_ops.to_raw_op` and export the wrapped
# function through `tf_export` under the name `raw_ops.SymbolicGradient`.
SymbolicGradient = tf_export("raw_ops.SymbolicGradient")(_ops.to_raw_op(symbolic_gradient))
def symbolic_gradient_eager_fallback(input, Tout, f, name, ctx):
  """Runs the `SymbolicGradient` op eagerly through `_execute.execute`.

  `symbolic_gradient` calls this after its fast path raises
  `_core._FallbackException`. `Tout` is validated and normalized, the inputs
  are converted to eager tensors, the op is executed, and the gradient is
  recorded when `_execute.must_record_gradient()` is true.

  Args:
    input: Values converted with `_execute.convert_to_mixed_eager_tensors`.
    Tout: A list or tuple of output types; each entry is passed through
      `_execute.make_type`.
    f: Forwarded unchanged as the `f` attr.
    name: Forwarded to `_execute.execute`.
    ctx: The context used for tensor conversion and execution.

  Returns:
    The value returned by `_execute.execute`, requested with `len(Tout)`
    outputs.

  Raises:
    TypeError: If `Tout` is not a list or tuple.
  """
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'symbolic_gradient' Op, not %r." % Tout)
  out_types = [_execute.make_type(dtype, "Tout") for dtype in Tout]

  input_types, eager_inputs = _execute.convert_to_mixed_eager_tensors(
      input, ctx)
  flat_inputs = list(eager_inputs)
  attrs = (
      "Tin", input_types,
      "Tout", out_types,
      "f", f,
  )
  results = _execute.execute(
      b"SymbolicGradient", len(out_types), inputs=flat_inputs, attrs=attrs,
      ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("SymbolicGradient", flat_inputs, attrs, results)
  return results
# Constrained TypeVar enumerating the `_atypes` dtype markers used in the
# `Annotated[...]` annotation of the `input` argument of `to_bool` and
# `to_bool_eager_fallback`.
TV_ToBool_T = TypeVar("TV_ToBool_T", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant)
def to_bool(input: Annotated[Any, TV_ToBool_T], name=None) -> Annotated[Any, _atypes.Bool]:
  r"""Converts a tensor to a scalar predicate.

  The conversion follows these rules:

  - For 0D tensors, truthiness is determined by comparing against a "zero"
    value. For numerical types it is the obvious zero. For strings it is the
    empty string.

  - For >0D tensors, truthiness is determined by the number of elements: a
    tensor with zero elements is false, any other tensor is true.

  This matches the behavior of If and While for determining if a tensor counts
  as true/false for a branch condition.

  Args:
    input: A `Tensor`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `bool`.
  """
  active_ctx = _context._context or _context.context()
  if active_ctx._thread_local_data.is_eager:
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(
          active_ctx, "ToBool", name, input)
    except _core._NotOkStatusException as status_error:
      _ops.raise_from_not_ok_status(status_error, name)
    except _core._FallbackException:
      pass  # Continue with the Python eager fallback below.
    try:
      return to_bool_eager_fallback(input, name=name, ctx=active_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.

  # Add nodes to the TensorFlow graph.
  _, _, graph_op, op_outputs = _op_def_library._apply_op_helper(
      "ToBool", input=input, name=name)
  results = op_outputs[:]
  if _execute.must_record_gradient():
    attrs = ("T", graph_op._get_attr_type("T"))
    flat_inputs = graph_op.inputs
    _execute.record_gradient("ToBool", flat_inputs, attrs, results)
  # The op has exactly one output; unwrap it from the result list.
  (predicate,) = results
  return predicate
# Wrap `to_bool` with `_ops.to_raw_op` and export the wrapped function
# through `tf_export` under the name `raw_ops.ToBool`.
ToBool = tf_export("raw_ops.ToBool")(_ops.to_raw_op(to_bool))
def to_bool_eager_fallback(input: Annotated[Any, TV_ToBool_T], name, ctx) -> Annotated[Any, _atypes.Bool]:
  """Runs the `ToBool` op eagerly through `_execute.execute`.

  `to_bool` calls this after its fast path raises
  `_core._FallbackException`. The input is converted with
  `_execute.args_to_matching_eager`, the op is executed with one requested
  output, and the gradient is recorded when
  `_execute.must_record_gradient()` is true.

  Args:
    input: The value to convert and feed to the op.
    name: Forwarded to `_execute.execute`.
    ctx: The context used for tensor conversion and execution.

  Returns:
    The single element of the result returned by `_execute.execute`.
  """
  input_type, matched = _execute.args_to_matching_eager([input], ctx, [])
  (input_tensor,) = matched
  flat_inputs = [input_tensor]
  attrs = ("T", input_type)
  results = _execute.execute(
      b"ToBool", 1, inputs=flat_inputs, attrs=attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("ToBool", flat_inputs, attrs, results)
  # Exactly one output was requested; unwrap it from the result list.
  (predicate,) = results
  return predicate
def _while(input, cond, body, output_shapes=[], parallel_iterations:int=10, name=None):
  r"""Computes `output = input; While (Cond(output)) { output = Body(output) }`.

  Args:
    input: A list of `Tensor` objects.
      A list of input tensors whose types are T.
    cond: A function decorated with @Defun.
      Takes 'input' and returns a tensor. A scalar of non-boolean type is
      converted to a boolean as follows: a numerical scalar is True when
      non-zero and False when zero; a string scalar is True when non-empty
      and False when empty. A non-scalar tensor is True when it is non-empty
      and False otherwise.
    body: A function decorated with @Defun.
      Takes a list of tensors and returns another list of tensors. Both
      lists have the same types as specified by T.
    output_shapes: An optional list of shapes (each a `tf.TensorShape` or list of `ints`). Defaults to `[]`.
    parallel_iterations: An optional `int`. Defaults to `10`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
  """
  active_ctx = _context._context or _context.context()
  if active_ctx._thread_local_data.is_eager:
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(
          active_ctx, "While", name, input,
          "cond", cond,
          "body", body,
          "output_shapes", output_shapes,
          "parallel_iterations", parallel_iterations)
    except _core._NotOkStatusException as status_error:
      _ops.raise_from_not_ok_status(status_error, name)
    except _core._FallbackException:
      pass  # Continue with the Python eager fallback below.
    try:
      return _while_eager_fallback(
          input, cond=cond, body=body, output_shapes=output_shapes,
          parallel_iterations=parallel_iterations, name=name, ctx=active_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.

  # Add nodes to the TensorFlow graph.
  shapes = [] if output_shapes is None else output_shapes
  if not isinstance(shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'while' Op, not %r." % shapes)
  shapes = [_execute.make_shape(shape, "output_shapes") for shape in shapes]

  iterations = 10 if parallel_iterations is None else parallel_iterations
  iterations = _execute.make_int(iterations, "parallel_iterations")

  _, _, graph_op, op_outputs = _op_def_library._apply_op_helper(
      "While", input=input, cond=cond, body=body,
      output_shapes=shapes, parallel_iterations=iterations, name=name)
  results = op_outputs[:]
  if not results:
    # With no outputs, the created op itself is returned.
    return graph_op
  if not _execute.must_record_gradient():
    return results

  # Read the attrs back from the created op before recording the gradient.
  attrs = (
      "T", graph_op.get_attr("T"),
      "cond", graph_op.get_attr("cond"),
      "body", graph_op.get_attr("body"),
      "output_shapes", graph_op.get_attr("output_shapes"),
      "parallel_iterations", graph_op._get_attr_int("parallel_iterations"),
  )
  flat_inputs = graph_op.inputs
  _execute.record_gradient("While", flat_inputs, attrs, results)
  return results
# Wrap `_while` with `_ops.to_raw_op` and export the wrapped function
# through `tf_export` under the name `raw_ops.While`.
While = tf_export("raw_ops.While")(_ops.to_raw_op(_while))
def _while_eager_fallback(input, cond, body, output_shapes, parallel_iterations: int, name, ctx):
if output_shapes is None:
output_shapes = []
if not isinstance(output_shapes, (list, tuple)):
raise TypeError(
"Expected list for 'output_shapes' argument to "
"'while' Op, not %r." % output_shapes)
output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
if parallel_iterations is None:
parallel_iterations = 10
parallel_iterations = _execute.make_int(parallel_iterations, "parallel_iterations")
_attr_T, input = _execute.convert_to_mixed_eager_tensors(input, ctx)
_inputs_flat = list(input)
_attrs = ("T", _attr_T, "cond", cond, "body", body, "output_shapes",
output_shapes, "parallel_iterations", parallel_iterations)
_result = _execute.execute(b"While", len(input), inputs=_inputs_flat,
attrs=_attrs, ctx=ctx, name=name)
if _execute.must_record_gradient():
_execute.record_gradient(
"While", _inputs_flat, _attrs, _result)
return _result