"""Python wrappers around TensorFlow ops.

This file is MACHINE GENERATED! Do not edit.
"""
import collections

from tensorflow.python import pywrap_tfe as pywrap_tfe
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.security.fuzzing.py import annotation_types as _atypes

from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.deprecation import deprecated_endpoints
from tensorflow.python.util import dispatch as _dispatch
from tensorflow.python.util.tf_export import tf_export

from typing import TypeVar, List, Any
from typing_extensions import Annotated

def case(branch_index: Annotated[Any, _atypes.Int32], input, Tout, branches, output_shapes=[], name=None):
  r"""An n-way switch statement which calls a single branch function.

  An n-way switch statement, implementing the following:
  ```
  switch (branch_index) {
    case 0:
      output = branches[0](input);
      break;
    case 1:
      output = branches[1](input);
      break;
    ...
    case [[nbranches-1]]:
    default:
      output = branches[nbranches-1](input);
      break;
  }
  ```

  Args:
    branch_index: A `Tensor` of type `int32`.
      The branch selector, an int32 Tensor.
    input: A list of `Tensor` objects.
      A list of input tensors passed to the branch function.
    Tout: A list of `tf.DTypes`. A list of output types.
    branches: A list of functions decorated with @Defun that has length `>= 1`.
      A list of functions each of which takes 'inputs' and returns a list of
      tensors, whose types are the same as what every other branch returns.
    output_shapes: An optional list of shapes (each a `tf.TensorShape` or list of `ints`). Defaults to `[]`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C fast path first; it bypasses Python-level
    # attribute processing entirely.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "Case", name, branch_index, input, "Tout", Tout, "branches",
        branches, "output_shapes", output_shapes)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the call; fall through to the slower
      # Python-level eager execution below.
      pass
    try:
      return case_eager_fallback(
          branch_index, input, Tout=Tout, branches=branches,
          output_shapes=output_shapes, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'case' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if not isinstance(branches, (list, tuple)):
    raise TypeError(
        "Expected list for 'branches' argument to "
        "'case' Op, not %r." % branches)
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'case' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  # Graph mode: build a Case node through the op-def library.
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "Case", branch_index=branch_index, input=input, Tout=Tout,
                branches=branches, output_shapes=output_shapes, name=name)
  _result = _outputs[:]
  if not _result:
    return _op
  if _execute.must_record_gradient():
    # Record inputs/attrs so the tape can compute gradients for this op.
    _attrs = ("Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"),
              "branches", _op.get_attr("branches"), "output_shapes",
              _op.get_attr("output_shapes"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "Case", _inputs_flat, _attrs, _result)
  return _result
# Expose the op as tf.raw_ops.Case; to_raw_op wraps `case` as a
# keyword-arguments-only raw-op endpoint.
Case = tf_export("raw_ops.Case")(_ops.to_raw_op(case))
def case_eager_fallback(branch_index: Annotated[Any, _atypes.Int32], input, Tout, branches, output_shapes, name, ctx):
  """Slow-path eager execution of Case, used when the C fast path bails out."""
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'case' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if not isinstance(branches, (list, tuple)):
    raise TypeError(
        "Expected list for 'branches' argument to "
        "'case' Op, not %r." % branches)
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'case' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  # Convert inputs to eager tensors and capture their dtypes as the Tin attr.
  _attr_Tin, input = _execute.convert_to_mixed_eager_tensors(input, ctx)
  branch_index = _ops.convert_to_tensor(branch_index, _dtypes.int32)
  _inputs_flat = [branch_index] + list(input)
  _attrs = ("Tin", _attr_Tin, "Tout", Tout, "branches", branches,
  "output_shapes", output_shapes)
  _result = _execute.execute(b"Case", len(Tout), inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "Case", _inputs_flat, _attrs, _result)
  return _result
def device_index(device_names, name=None) -> Annotated[Any, _atypes.Int32]:
  r"""Return the index of device the op runs.

  Given a list of device names, this operation returns the index of the device
  this op runs. The length of the list is returned in two cases:
  (1) Device does not exist in the given device list.
  (2) It is in XLA compilation.

  Args:
    device_names: A list of `strings`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `int32`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "DeviceIndex", name, "device_names", device_names)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the call; use the Python eager fallback below.
      pass
    try:
      return device_index_eager_fallback(
          device_names=device_names, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(device_names, (list, tuple)):
    raise TypeError(
        "Expected list for 'device_names' argument to "
        "'device_index' Op, not %r." % device_names)
  device_names = [_execute.make_str(_s, "device_names") for _s in device_names]
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "DeviceIndex", device_names=device_names, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Record inputs/attrs so the tape can compute gradients for this op.
    _attrs = ("device_names", _op.get_attr("device_names"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "DeviceIndex", _inputs_flat, _attrs, _result)
  # Unpack the op's single output tensor.
  _result, = _result
  return _result
# Expose the op as tf.raw_ops.DeviceIndex (keyword-arguments-only endpoint).
DeviceIndex = tf_export("raw_ops.DeviceIndex")(_ops.to_raw_op(device_index))
def device_index_eager_fallback(device_names, name, ctx) -> Annotated[Any, _atypes.Int32]:
  """Slow-path eager execution of DeviceIndex."""
  if not isinstance(device_names, (list, tuple)):
    raise TypeError(
        "Expected list for 'device_names' argument to "
        "'device_index' Op, not %r." % device_names)
  device_names = [_execute.make_str(_s, "device_names") for _s in device_names]
  # DeviceIndex has no tensor inputs; everything is carried in attrs.
  _inputs_flat = []
  _attrs = ("device_names", device_names)
  _result = _execute.execute(b"DeviceIndex", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "DeviceIndex", _inputs_flat, _attrs, _result)
  # Unpack the op's single output tensor.
  _result, = _result
  return _result
# TypeVar enumerating the dtypes accepted by FakeParam's `dtype` attr.
TV_FakeParam_dtype = TypeVar("TV_FakeParam_dtype", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant)
def fake_param(dtype: TV_FakeParam_dtype, shape, name=None) -> Annotated[Any, TV_FakeParam_dtype]:
  r""" This op is used as a placeholder in If branch functions. It doesn't provide a
  valid output when run, so must either be removed (e.g. replaced with a
  function input) or guaranteed not to be used (e.g. if mirroring an
  intermediate output needed for the gradient computation of the other branch).

  Args:
    dtype: A `tf.DType`. The type of the output.
    shape: A `tf.TensorShape` or list of `ints`.
      The purported shape of the output. This is only used for shape inference;
      the output will not necessarily have this shape. Can be a partial shape.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `dtype`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "FakeParam", name, "dtype", dtype, "shape", shape)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the call; use the Python eager fallback below.
      pass
    try:
      return fake_param_eager_fallback(
          dtype=dtype, shape=shape, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  dtype = _execute.make_type(dtype, "dtype")
  shape = _execute.make_shape(shape, "shape")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "FakeParam", dtype=dtype, shape=shape, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Record inputs/attrs so the tape can compute gradients for this op.
    _attrs = ("dtype", _op._get_attr_type("dtype"), "shape",
              _op.get_attr("shape"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "FakeParam", _inputs_flat, _attrs, _result)
  # Unpack the op's single output tensor.
  _result, = _result
  return _result
# Expose the op as tf.raw_ops.FakeParam (keyword-arguments-only endpoint).
FakeParam = tf_export("raw_ops.FakeParam")(_ops.to_raw_op(fake_param))
def fake_param_eager_fallback(dtype: TV_FakeParam_dtype, shape, name, ctx) -> Annotated[Any, TV_FakeParam_dtype]:
  """Slow-path eager execution of FakeParam."""
  dtype = _execute.make_type(dtype, "dtype")
  shape = _execute.make_shape(shape, "shape")
  # FakeParam has no tensor inputs; dtype and shape are attrs.
  _inputs_flat = []
  _attrs = ("dtype", dtype, "shape", shape)
  _result = _execute.execute(b"FakeParam", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "FakeParam", _inputs_flat, _attrs, _result)
  # Unpack the op's single output tensor.
  _result, = _result
  return _result
def _for(start: Annotated[Any, _atypes.Int32], limit: Annotated[Any, _atypes.Int32], delta: Annotated[Any, _atypes.Int32], input, body, name=None):
  r"""Applies a for loop.

  ```python
   output = input;
   for i in range(start, limit, delta)
     output = body(i, output);
  ```

  Args:
    start: A `Tensor` of type `int32`. The lower bound. An int32
    limit: A `Tensor` of type `int32`. The upper bound. An int32
    delta: A `Tensor` of type `int32`. The increment. An int32
    input: A list of `Tensor` objects.
      A list of input tensors whose types are T.
    body: A function decorated with @Defun.
      A function that takes a list of tensors (int32, T) and returns another
      list of tensors (T).
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "For", name, start, limit, delta, input, "body", body)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the call; use the Python eager fallback below.
      pass
    try:
      return _for_eager_fallback(
          start, limit, delta, input, body=body, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "For", start=start, limit=limit, delta=delta, input=input, body=body,
               name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Record inputs/attrs so the tape can compute gradients for this op.
    _attrs = ("T", _op.get_attr("T"), "body", _op.get_attr("body"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "For", _inputs_flat, _attrs, _result)
  return _result
# Expose the op as tf.raw_ops.For. The wrapper is named `_for` because `for`
# is a Python keyword.
For = tf_export("raw_ops.For")(_ops.to_raw_op(_for))
def _for_eager_fallback(start: Annotated[Any, _atypes.Int32], limit: Annotated[Any, _atypes.Int32], delta: Annotated[Any, _atypes.Int32], input, body, name, ctx):
  """Slow-path eager execution of For."""
  # Convert loop-carried inputs to eager tensors; their dtypes form the T attr.
  _attr_T, input = _execute.convert_to_mixed_eager_tensors(input, ctx)
  start = _ops.convert_to_tensor(start, _dtypes.int32)
  limit = _ops.convert_to_tensor(limit, _dtypes.int32)
  delta = _ops.convert_to_tensor(delta, _dtypes.int32)
  _inputs_flat = [start, limit, delta] + list(input)
  _attrs = ("T", _attr_T, "body", body)
  # The op returns one tensor per loop-carried input.
  _result = _execute.execute(b"For", len(input), inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "For", _inputs_flat, _attrs, _result)
  return _result
# TypeVar enumerating the dtypes accepted for If's `cond` input (Tcond attr).
TV_If_Tcond = TypeVar("TV_If_Tcond", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant)
def _if(cond: Annotated[Any, TV_If_Tcond], input, Tout, then_branch, else_branch, output_shapes=[], name=None):
  r"""output = cond ? then_branch(input) : else_branch(input)

  Args:
    cond: A `Tensor`.
      A Tensor. If the tensor is a scalar of non-boolean type, the
      scalar is converted to a boolean according to the
      following rule: if the scalar is a numerical value, non-zero means
      `True` and zero means False; if the scalar is a string, non-empty
      means `True` and empty means `False`. If the tensor is not a scalar,
      being empty means False and being non-empty means True.
    input: A list of `Tensor` objects. A list of input tensors.
    Tout: A list of `tf.DTypes`. A list of output types.
    then_branch: A function decorated with @Defun.
      A function that takes 'inputs' and returns a list of tensors, whose
      types are the same as what else_branch returns.
    else_branch: A function decorated with @Defun.
      A function that takes 'inputs' and returns a list of tensors, whose
      types are the same as what then_branch returns.
    output_shapes: An optional list of shapes (each a `tf.TensorShape` or list of `ints`). Defaults to `[]`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "If", name, cond, input, "Tout", Tout, "then_branch",
        then_branch, "else_branch", else_branch, "output_shapes",
        output_shapes)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the call; use the Python eager fallback below.
      pass
    try:
      return _if_eager_fallback(
          cond, input, Tout=Tout, then_branch=then_branch,
          else_branch=else_branch, output_shapes=output_shapes, name=name,
          ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'if' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'if' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "If", cond=cond, input=input, Tout=Tout, then_branch=then_branch,
              else_branch=else_branch, output_shapes=output_shapes, name=name)
  _result = _outputs[:]
  if not _result:
    return _op
  if _execute.must_record_gradient():
    # Record inputs/attrs so the tape can compute gradients for this op.
    _attrs = ("Tcond", _op._get_attr_type("Tcond"), "Tin",
              _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"),
              "then_branch", _op.get_attr("then_branch"), "else_branch",
              _op.get_attr("else_branch"), "output_shapes",
              _op.get_attr("output_shapes"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "If", _inputs_flat, _attrs, _result)
  return _result
# Expose the op as tf.raw_ops.If. The wrapper is named `_if` because `if`
# is a Python keyword.
If = tf_export("raw_ops.If")(_ops.to_raw_op(_if))
def _if_eager_fallback(cond: Annotated[Any, TV_If_Tcond], input, Tout, then_branch, else_branch, output_shapes, name, ctx):
  """Slow-path eager execution of If."""
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'if' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'if' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  # Infer Tcond from the condition tensor and Tin from the branch inputs.
  _attr_Tcond, (cond,) = _execute.args_to_matching_eager([cond], ctx, [])
  _attr_Tin, input = _execute.convert_to_mixed_eager_tensors(input, ctx)
  _inputs_flat = [cond] + list(input)
  _attrs = ("Tcond", _attr_Tcond, "Tin", _attr_Tin, "Tout", Tout,
  "then_branch", then_branch, "else_branch", else_branch, "output_shapes",
  output_shapes)
  _result = _execute.execute(b"If", len(Tout), inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "If", _inputs_flat, _attrs, _result)
  return _result
def partitioned_call(args, Tout, f, config:str="", config_proto:str="", executor_type:str="", name=None):
  r"""returns `f(inputs)`, where `f`'s body is placed and partitioned.

  Asynchronously executes a function, potentially across multiple devices but
  within a single process. The kernel places and partitions a given function's
  underlying graph, and executes each of the partitioned subgraphs as a function.

  Args:
    args: A list of `Tensor` objects. A list of input tensors.
    Tout: A list of `tf.DTypes`. A list of output types.
    f: A function decorated with @Defun.
      A function that takes 'args', a list of tensors, and returns 'output',
      another list of tensors. Input and output types are specified by 'Tin'
      and 'Tout'. The function body of f will be placed and partitioned across
      devices, setting this op apart from the regular Call op.
    config: An optional `string`. Defaults to `""`.
    config_proto: An optional `string`. Defaults to `""`.
    executor_type: An optional `string`. Defaults to `""`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "PartitionedCall", name, args, "Tout", Tout, "f", f, "config",
        config, "config_proto", config_proto, "executor_type", executor_type)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the call; use the Python eager fallback below.
      pass
    try:
      return partitioned_call_eager_fallback(
          args, Tout=Tout, f=f, config=config, config_proto=config_proto,
          executor_type=executor_type, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'partitioned_call' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if config is None:
    config = ""
  config = _execute.make_str(config, "config")
  if config_proto is None:
    config_proto = ""
  config_proto = _execute.make_str(config_proto, "config_proto")
  if executor_type is None:
    executor_type = ""
  executor_type = _execute.make_str(executor_type, "executor_type")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "PartitionedCall", args=args, Tout=Tout, f=f, config=config,
                           config_proto=config_proto,
                           executor_type=executor_type, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Record inputs/attrs so the tape can compute gradients for this op.
    _attrs = ("Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "f",
              _op.get_attr("f"), "config", _op.get_attr("config"),
              "config_proto", _op.get_attr("config_proto"), "executor_type",
              _op.get_attr("executor_type"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "PartitionedCall", _inputs_flat, _attrs, _result)
  return _result
# Expose the op as tf.raw_ops.PartitionedCall (keyword-arguments-only endpoint).
PartitionedCall = tf_export("raw_ops.PartitionedCall")(_ops.to_raw_op(partitioned_call))
def partitioned_call_eager_fallback(args, Tout, f, config: str, config_proto: str, executor_type: str, name, ctx):
  """Slow-path eager execution of PartitionedCall."""
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'partitioned_call' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if config is None:
    config = ""
  config = _execute.make_str(config, "config")
  if config_proto is None:
    config_proto = ""
  config_proto = _execute.make_str(config_proto, "config_proto")
  if executor_type is None:
    executor_type = ""
  executor_type = _execute.make_str(executor_type, "executor_type")
  # Convert inputs to eager tensors; their dtypes form the Tin attr.
  _attr_Tin, args = _execute.convert_to_mixed_eager_tensors(args, ctx)
  _inputs_flat = list(args)
  _attrs = ("Tin", _attr_Tin, "Tout", Tout, "f", f, "config", config,
  "config_proto", config_proto, "executor_type", executor_type)
  _result = _execute.execute(b"PartitionedCall", len(Tout),
                             inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "PartitionedCall", _inputs_flat, _attrs, _result)
  return _result
def remote_call(target: Annotated[Any, _atypes.String], args, Tout, f, name=None):
  r"""Runs function `f` on a remote device indicated by `target`.

  Args:
    target: A `Tensor` of type `string`.
      A fully specified device name where we want to run the function.
    args: A list of `Tensor` objects. A list of arguments for the function.
    Tout: A list of `tf.DTypes` that has length `>= 1`.
      The type list for the return values.
    f: A function decorated with @Defun. The function to run remotely.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "RemoteCall", name, target, args, "Tout", Tout, "f", f)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the call; use the Python eager fallback below.
      pass
    try:
      return remote_call_eager_fallback(
          target, args, Tout=Tout, f=f, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'remote_call' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "RemoteCall", target=target, args=args, Tout=Tout, f=f, name=name)
  _result = _outputs[:]
  if not _result:
    return _op
  if _execute.must_record_gradient():
    # Record inputs/attrs so the tape can compute gradients for this op.
    _attrs = ("Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "f",
              _op.get_attr("f"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "RemoteCall", _inputs_flat, _attrs, _result)
  return _result
# Expose the op as tf.raw_ops.RemoteCall (keyword-arguments-only endpoint).
RemoteCall = tf_export("raw_ops.RemoteCall")(_ops.to_raw_op(remote_call))
def remote_call_eager_fallback(target: Annotated[Any, _atypes.String], args, Tout, f, name, ctx):
  """Slow-path eager execution of RemoteCall."""
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'remote_call' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  # Convert inputs to eager tensors; their dtypes form the Tin attr.
  _attr_Tin, args = _execute.convert_to_mixed_eager_tensors(args, ctx)
  target = _ops.convert_to_tensor(target, _dtypes.string)
  _inputs_flat = [target] + list(args)
  _attrs = ("Tin", _attr_Tin, "Tout", Tout, "f", f)
  _result = _execute.execute(b"RemoteCall", len(Tout), inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "RemoteCall", _inputs_flat, _attrs, _result)
  return _result
def stateful_partitioned_call(args, Tout, f, config:str="", config_proto:str="", executor_type:str="", name=None):
  r"""returns `f(inputs)`, where `f`'s body is placed and partitioned.

  Args:
    args: A list of `Tensor` objects. A list of input tensors.
    Tout: A list of `tf.DTypes`. A list of output types.
    f: A function decorated with @Defun.
      A function that takes 'args', a list of tensors, and returns 'output',
      another list of tensors. Input and output types are specified by 'Tin'
      and 'Tout'. The function body of f will be placed and partitioned across
      devices, setting this op apart from the regular Call op. This op is
      stateful.
    config: An optional `string`. Defaults to `""`.
    config_proto: An optional `string`. Defaults to `""`.
    executor_type: An optional `string`. Defaults to `""`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "StatefulPartitionedCall", name, args, "Tout", Tout, "f", f,
        "config", config, "config_proto", config_proto, "executor_type",
        executor_type)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the call; use the Python eager fallback below.
      pass
    try:
      return stateful_partitioned_call_eager_fallback(
          args, Tout=Tout, f=f, config=config, config_proto=config_proto,
          executor_type=executor_type, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'stateful_partitioned_call' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if config is None:
    config = ""
  config = _execute.make_str(config, "config")
  if config_proto is None:
    config_proto = ""
  config_proto = _execute.make_str(config_proto, "config_proto")
  if executor_type is None:
    executor_type = ""
  executor_type = _execute.make_str(executor_type, "executor_type")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "StatefulPartitionedCall", args=args, Tout=Tout, f=f, config=config,
                                   config_proto=config_proto,
                                   executor_type=executor_type, name=name)
  _result = _outputs[:]
  if not _result:
    return _op
  if _execute.must_record_gradient():
    # Record inputs/attrs so the tape can compute gradients for this op.
    _attrs = ("Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "f",
              _op.get_attr("f"), "config", _op.get_attr("config"),
              "config_proto", _op.get_attr("config_proto"), "executor_type",
              _op.get_attr("executor_type"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "StatefulPartitionedCall", _inputs_flat, _attrs, _result)
  return _result
# Expose the op as tf.raw_ops.StatefulPartitionedCall
# (keyword-arguments-only endpoint).
StatefulPartitionedCall = tf_export("raw_ops.StatefulPartitionedCall")(_ops.to_raw_op(stateful_partitioned_call))
def stateful_partitioned_call_eager_fallback(args, Tout, f, config: str, config_proto: str, executor_type: str, name, ctx):
  """Slow-path eager execution of StatefulPartitionedCall."""
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'stateful_partitioned_call' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if config is None:
    config = ""
  config = _execute.make_str(config, "config")
  if config_proto is None:
    config_proto = ""
  config_proto = _execute.make_str(config_proto, "config_proto")
  if executor_type is None:
    executor_type = ""
  executor_type = _execute.make_str(executor_type, "executor_type")
  # Convert inputs to eager tensors; their dtypes form the Tin attr.
  _attr_Tin, args = _execute.convert_to_mixed_eager_tensors(args, ctx)
  _inputs_flat = list(args)
  _attrs = ("Tin", _attr_Tin, "Tout", Tout, "f", f, "config", config,
  "config_proto", config_proto, "executor_type", executor_type)
  _result = _execute.execute(b"StatefulPartitionedCall", len(Tout),
                             inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "StatefulPartitionedCall", _inputs_flat, _attrs, _result)
  return _result
def stateless_case(branch_index: Annotated[Any, _atypes.Int32], input, Tout, branches, output_shapes=[], name=None):
  r"""An n-way switch statement which calls a single branch function.

  An n-way switch statement, implementing the following:
  ```
  switch (branch_index) {
    case 0:
      output = branches[0](input);
      break;
    case 1:
      output = branches[1](input);
      break;
    ...
    case [[nbranches-1]]:
    default:
      output = branches[nbranches-1](input);
      break;
  }
  ```

  This should only be used when the none of branches has stateful ops.

  Args:
    branch_index: A `Tensor` of type `int32`.
      The branch selector, an int32 Tensor.
    input: A list of `Tensor` objects.
      A list of input tensors passed to the branch function.
    Tout: A list of `tf.DTypes`. A list of output types.
    branches: A list of functions decorated with @Defun that has length `>= 1`.
      A list of functions each of which takes 'inputs' and returns a list of
      tensors, whose types are the same as what every other branch returns.
    output_shapes: An optional list of shapes (each a `tf.TensorShape` or list of `ints`). Defaults to `[]`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager fast path: hand everything to the C++ runtime directly.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "StatelessCase", name, branch_index, input, "Tout", Tout,
        "branches", branches, "output_shapes", output_shapes)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the call; retry through the Python fallback below.
      pass
    try:
      return stateless_case_eager_fallback(
          branch_index, input, Tout=Tout, branches=branches,
          output_shapes=output_shapes, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Graph mode (or symbolic tensors in eager): validate attrs, build an op node.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'stateless_case' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if not isinstance(branches, (list, tuple)):
    raise TypeError(
        "Expected list for 'branches' argument to "
        "'stateless_case' Op, not %r." % branches)
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'stateless_case' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "StatelessCase", branch_index=branch_index, input=input, Tout=Tout,
                         branches=branches, output_shapes=output_shapes,
                         name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Read the finalized attrs back off the created op for the tape.
    _attrs = ("Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"),
              "branches", _op.get_attr("branches"), "output_shapes",
              _op.get_attr("output_shapes"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "StatelessCase", _inputs_flat, _attrs, _result)
  return _result
|
|
|
|
# Exported raw-op endpoint: tf.raw_ops.StatelessCase.
StatelessCase = tf_export("raw_ops.StatelessCase")(_ops.to_raw_op(stateless_case))
|
|
|
|
|
|
def stateless_case_eager_fallback(branch_index: Annotated[Any, _atypes.Int32], input, Tout, branches, output_shapes, name, ctx):
  """Slow-path eager execution of the StatelessCase op.

  Validates list-valued attributes, converts inputs to eager tensors and
  runs the op via `_execute.execute`, recording a gradient if required.
  """
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'stateless_case' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if not isinstance(branches, (list, tuple)):
    raise TypeError(
        "Expected list for 'branches' argument to "
        "'stateless_case' Op, not %r." % branches)
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'stateless_case' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes")
                   for _s in output_shapes]
  _attr_Tin, input = _execute.convert_to_mixed_eager_tensors(input, ctx)
  # The branch selector is always an int32 tensor and goes first.
  selector = _ops.convert_to_tensor(branch_index, _dtypes.int32)
  flat_inputs = [selector] + list(input)
  op_attrs = ("Tin", _attr_Tin, "Tout", Tout, "branches", branches,
              "output_shapes", output_shapes)
  outputs = _execute.execute(b"StatelessCase", len(Tout), inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("StatelessCase", flat_inputs, op_attrs, outputs)
  return outputs
|
|
|
|
|
|
# Type variable listing every dtype accepted for the StatelessIf `cond` input.
TV_StatelessIf_Tcond = TypeVar("TV_StatelessIf_Tcond", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant)
|
|
|
|
def stateless_if(cond: Annotated[Any, TV_StatelessIf_Tcond], input, Tout, then_branch, else_branch, output_shapes=[], name=None):
  r"""output = cond ? then_branch(input) : else_branch(input)

  Args:
    cond: A `Tensor`.
      A Tensor. If the tensor is a scalar of non-boolean type, the
      scalar is converted to a boolean according to the
      following rule: if the scalar is a numerical value, non-zero means
      `True` and zero means False; if the scalar is a string, non-empty
      means `True` and empty means `False`. If the tensor is not a scalar,
      being empty means False and being non-empty means True.

      This should only be used when the if then/else body functions do not
      have stateful ops.
    input: A list of `Tensor` objects. A list of input tensors.
    Tout: A list of `tf.DTypes`. A list of output types.
    then_branch: A function decorated with @Defun.
      A function that takes 'inputs' and returns a list of tensors, whose
      types are the same as what else_branch returns.
    else_branch: A function decorated with @Defun.
      A function that takes 'inputs' and returns a list of tensors, whose
      types are the same as what then_branch returns.
    output_shapes: An optional list of shapes (each a `tf.TensorShape` or list of `ints`). Defaults to `[]`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager fast path: hand everything to the C++ runtime directly.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "StatelessIf", name, cond, input, "Tout", Tout, "then_branch",
        then_branch, "else_branch", else_branch, "output_shapes",
        output_shapes)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the call; retry through the Python fallback below.
      pass
    try:
      return stateless_if_eager_fallback(
          cond, input, Tout=Tout, then_branch=then_branch,
          else_branch=else_branch, output_shapes=output_shapes, name=name,
          ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Graph mode (or symbolic tensors in eager): validate attrs, build an op node.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'stateless_if' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'stateless_if' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "StatelessIf", cond=cond, input=input, Tout=Tout,
                       then_branch=then_branch, else_branch=else_branch,
                       output_shapes=output_shapes, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Read the finalized attrs back off the created op for the tape.
    _attrs = ("Tcond", _op._get_attr_type("Tcond"), "Tin",
              _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"),
              "then_branch", _op.get_attr("then_branch"), "else_branch",
              _op.get_attr("else_branch"), "output_shapes",
              _op.get_attr("output_shapes"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "StatelessIf", _inputs_flat, _attrs, _result)
  return _result
|
|
|
|
# Exported raw-op endpoint: tf.raw_ops.StatelessIf.
StatelessIf = tf_export("raw_ops.StatelessIf")(_ops.to_raw_op(stateless_if))
|
|
|
|
|
|
def stateless_if_eager_fallback(cond: Annotated[Any, TV_StatelessIf_Tcond], input, Tout, then_branch, else_branch, output_shapes, name, ctx):
  """Slow-path eager execution of the StatelessIf op.

  Converts `cond` and `input` to eager tensors, assembles the attribute
  tuple and executes the op, recording a gradient when required.
  """
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'stateless_if' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'stateless_if' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes")
                   for _s in output_shapes]
  # Infer Tcond from the condition tensor, Tin from the input list.
  _attr_Tcond, (cond,) = _execute.args_to_matching_eager([cond], ctx, [])
  _attr_Tin, input = _execute.convert_to_mixed_eager_tensors(input, ctx)
  flat_inputs = [cond] + list(input)
  op_attrs = ("Tcond", _attr_Tcond, "Tin", _attr_Tin, "Tout", Tout,
              "then_branch", then_branch, "else_branch", else_branch,
              "output_shapes", output_shapes)
  outputs = _execute.execute(b"StatelessIf", len(Tout), inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("StatelessIf", flat_inputs, op_attrs, outputs)
  return outputs
|
|
|
|
|
|
def stateless_while(input, cond, body, output_shapes=[], parallel_iterations:int=10, name=None):
  r"""output = input; While (Cond(output)) { output = Body(output) }

  Args:
    input: A list of `Tensor` objects.
      A list of input tensors whose types are T.
    cond: A function decorated with @Defun.
      A function takes 'input' and returns a tensor. If the tensor is
      a scalar of non-boolean, the scalar is converted to a boolean
      according to the following rule: if the scalar is a numerical
      value, non-zero means True and zero means False; if the scalar is
      a string, non-empty means True and empty means False. If the
      tensor is not a scalar, non-emptiness means True and False
      otherwise.

      This should only be used when the while condition and body functions
      do not have stateful ops.
    body: A function decorated with @Defun.
      A function that takes a list of tensors and returns another
      list of tensors. Both lists have the same types as specified
      by T.
    output_shapes: An optional list of shapes (each a `tf.TensorShape` or list of `ints`). Defaults to `[]`.
    parallel_iterations: An optional `int`. Defaults to `10`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager fast path: hand everything to the C++ runtime directly.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "StatelessWhile", name, input, "cond", cond, "body", body,
        "output_shapes", output_shapes, "parallel_iterations",
        parallel_iterations)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the call; retry through the Python fallback below.
      pass
    try:
      return stateless_while_eager_fallback(
          input, cond=cond, body=body, output_shapes=output_shapes,
          parallel_iterations=parallel_iterations, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Graph mode (or symbolic tensors in eager): validate attrs, build an op node.
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'stateless_while' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  if parallel_iterations is None:
    parallel_iterations = 10
  parallel_iterations = _execute.make_int(parallel_iterations, "parallel_iterations")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "StatelessWhile", input=input, cond=cond, body=body,
                          output_shapes=output_shapes,
                          parallel_iterations=parallel_iterations, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Read the finalized attrs back off the created op for the tape.
    _attrs = ("T", _op.get_attr("T"), "cond", _op.get_attr("cond"), "body",
              _op.get_attr("body"), "output_shapes",
              _op.get_attr("output_shapes"), "parallel_iterations",
              _op._get_attr_int("parallel_iterations"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "StatelessWhile", _inputs_flat, _attrs, _result)
  return _result
|
|
|
|
# Exported raw-op endpoint: tf.raw_ops.StatelessWhile.
StatelessWhile = tf_export("raw_ops.StatelessWhile")(_ops.to_raw_op(stateless_while))
|
|
|
|
|
|
def stateless_while_eager_fallback(input, cond, body, output_shapes, parallel_iterations: int, name, ctx):
  """Slow-path eager execution of the StatelessWhile op.

  Normalizes the attributes, converts the loop-carried `input` tensors to
  eager tensors and executes the op, recording a gradient when required.
  """
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'stateless_while' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes")
                   for _s in output_shapes]
  parallel_iterations = _execute.make_int(
      10 if parallel_iterations is None else parallel_iterations,
      "parallel_iterations")
  _attr_T, input = _execute.convert_to_mixed_eager_tensors(input, ctx)
  flat_inputs = list(input)
  op_attrs = ("T", _attr_T, "cond", cond, "body", body, "output_shapes",
              output_shapes, "parallel_iterations", parallel_iterations)
  # A While op produces one output per loop-carried input.
  outputs = _execute.execute(b"StatelessWhile", len(input),
                             inputs=flat_inputs, attrs=op_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("StatelessWhile", flat_inputs, op_attrs, outputs)
  return outputs
|
|
|
|
|
|
def symbolic_gradient(input, Tout, f, name=None):
  r"""Computes the gradient function for function f via backpropagation.

  Args:
    input: A list of `Tensor` objects. a list of input tensors of size N + M;
    Tout: A list of `tf.DTypes` that has length `>= 1`.
      the type list for the input list.
    f: A function decorated with @Defun.
      The function we want to compute the gradient for.

      The function 'f' must be a numerical function which takes N inputs and
      produces M outputs. Its gradient function 'g', which is computed by
      this SymbolicGradient op is a function taking N + M inputs and
      produces N outputs.

      I.e. if we have
         (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
      then, g is
         (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
                                            dL/dy1, dL/dy2, ..., dL/dy_M),

      where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the
      loss function). dL/dx_i is the partial derivative of L with respect
      to x_i.

      (Needs some math expert to say the comment above better.)
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager fast path: hand everything to the C++ runtime directly.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "SymbolicGradient", name, input, "Tout", Tout, "f", f)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the call; retry through the Python fallback below.
      pass
    try:
      return symbolic_gradient_eager_fallback(
          input, Tout=Tout, f=f, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Graph mode (or symbolic tensors in eager): validate attrs, build an op node.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'symbolic_gradient' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "SymbolicGradient", input=input, Tout=Tout, f=f, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Read the finalized attrs back off the created op for the tape.
    _attrs = ("Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "f",
              _op.get_attr("f"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "SymbolicGradient", _inputs_flat, _attrs, _result)
  return _result
|
|
|
|
# Exported raw-op endpoint: tf.raw_ops.SymbolicGradient.
SymbolicGradient = tf_export("raw_ops.SymbolicGradient")(_ops.to_raw_op(symbolic_gradient))
|
|
|
|
|
|
def symbolic_gradient_eager_fallback(input, Tout, f, name, ctx):
  """Slow-path eager execution of the SymbolicGradient op.

  Converts `input` to eager tensors, executes the op with `len(Tout)`
  outputs, and records a gradient when required.
  """
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'symbolic_gradient' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  _attr_Tin, input = _execute.convert_to_mixed_eager_tensors(input, ctx)
  flat_inputs = list(input)
  op_attrs = ("Tin", _attr_Tin, "Tout", Tout, "f", f)
  outputs = _execute.execute(b"SymbolicGradient", len(Tout),
                             inputs=flat_inputs, attrs=op_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("SymbolicGradient", flat_inputs, op_attrs,
                             outputs)
  return outputs
|
|
|
|
|
|
# Type variable listing every dtype accepted for the ToBool `input` tensor.
TV_ToBool_T = TypeVar("TV_ToBool_T", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant)
|
|
|
|
def to_bool(input: Annotated[Any, TV_ToBool_T], name=None) -> Annotated[Any, _atypes.Bool]:
  r"""Converts a tensor to a scalar predicate.

  Converts a tensor to a scalar predicate with the following rules:

  - For 0D tensors, truthiness is determined by comparing against a "zero"
    value. For numerical types it is the obvious zero. For strings it is the
    empty string.

  - For >0D tensors, truthiness is determined by looking at the number of
    elements. If has zero elements, then the result is false. Otherwise the
    result is true.

  This matches the behavior of If and While for determining if a tensor counts
  as true/false for a branch condition.

  Args:
    input: A `Tensor`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `bool`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager fast path: hand everything to the C++ runtime directly.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "ToBool", name, input)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the call; retry through the Python fallback below.
      pass
    try:
      return to_bool_eager_fallback(
          input, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Graph mode (or symbolic tensors in eager): build an op node.
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "ToBool", input=input, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "ToBool", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element output list.
  _result, = _result
  return _result
|
|
|
|
# Exported raw-op endpoint: tf.raw_ops.ToBool.
ToBool = tf_export("raw_ops.ToBool")(_ops.to_raw_op(to_bool))
|
|
|
|
|
|
def to_bool_eager_fallback(input: Annotated[Any, TV_ToBool_T], name, ctx) -> Annotated[Any, _atypes.Bool]:
  """Slow-path eager execution of the ToBool op.

  Converts `input` to an eager tensor, executes the single-output op and
  records a gradient when required.
  """
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [])
  flat_inputs = [input]
  op_attrs = ("T", _attr_T)
  outputs = _execute.execute(b"ToBool", 1, inputs=flat_inputs, attrs=op_attrs,
                             ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("ToBool", flat_inputs, op_attrs, outputs)
  # Unwrap the single output tensor.
  result, = outputs
  return result
|
|
|
|
|
|
def _while(input, cond, body, output_shapes=[], parallel_iterations:int=10, name=None):
  r"""output = input; While (Cond(output)) { output = Body(output) }

  Args:
    input: A list of `Tensor` objects.
      A list of input tensors whose types are T.
    cond: A function decorated with @Defun.
      A function takes 'input' and returns a tensor. If the tensor is
      a scalar of non-boolean, the scalar is converted to a boolean
      according to the following rule: if the scalar is a numerical
      value, non-zero means True and zero means False; if the scalar is
      a string, non-empty means True and empty means False. If the
      tensor is not a scalar, non-emptiness means True and False
      otherwise.
    body: A function decorated with @Defun.
      A function that takes a list of tensors and returns another
      list of tensors. Both lists have the same types as specified
      by T.
    output_shapes: An optional list of shapes (each a `tf.TensorShape` or list of `ints`). Defaults to `[]`.
    parallel_iterations: An optional `int`. Defaults to `10`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager fast path: hand everything to the C++ runtime directly.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "While", name, input, "cond", cond, "body", body,
        "output_shapes", output_shapes, "parallel_iterations",
        parallel_iterations)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the call; retry through the Python fallback below.
      pass
    try:
      return _while_eager_fallback(
          input, cond=cond, body=body, output_shapes=output_shapes,
          parallel_iterations=parallel_iterations, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Graph mode (or symbolic tensors in eager): validate attrs, build an op node.
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'while' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  if parallel_iterations is None:
    parallel_iterations = 10
  parallel_iterations = _execute.make_int(parallel_iterations, "parallel_iterations")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "While", input=input, cond=cond, body=body,
                 output_shapes=output_shapes,
                 parallel_iterations=parallel_iterations, name=name)
  _result = _outputs[:]
  # A While with zero loop-carried values has no outputs; return the op itself.
  if not _result:
    return _op
  if _execute.must_record_gradient():
    # Read the finalized attrs back off the created op for the tape.
    _attrs = ("T", _op.get_attr("T"), "cond", _op.get_attr("cond"), "body",
              _op.get_attr("body"), "output_shapes",
              _op.get_attr("output_shapes"), "parallel_iterations",
              _op._get_attr_int("parallel_iterations"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "While", _inputs_flat, _attrs, _result)
  return _result
|
|
|
|
# Exported raw-op endpoint: tf.raw_ops.While (Python-level name is _while).
While = tf_export("raw_ops.While")(_ops.to_raw_op(_while))
|
|
|
|
|
|
def _while_eager_fallback(input, cond, body, output_shapes, parallel_iterations: int, name, ctx):
  """Slow-path eager execution of the While op.

  Normalizes the attributes, converts the loop-carried `input` tensors to
  eager tensors and executes the op, recording a gradient when required.
  """
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'while' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes")
                   for _s in output_shapes]
  parallel_iterations = _execute.make_int(
      10 if parallel_iterations is None else parallel_iterations,
      "parallel_iterations")
  _attr_T, input = _execute.convert_to_mixed_eager_tensors(input, ctx)
  flat_inputs = list(input)
  op_attrs = ("T", _attr_T, "cond", cond, "body", body, "output_shapes",
              output_shapes, "parallel_iterations", parallel_iterations)
  # A While op produces one output per loop-carried input.
  outputs = _execute.execute(b"While", len(input), inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("While", flat_inputs, op_attrs, outputs)
  return outputs
|
|
|