# Intelegentny_Pszczelarz/.venv/Lib/site-packages/tensorflow/python/ops/gen_control_flow_ops.py
# 2023-06-19 00:49:18 +02:00
#
# 861 lines
# 31 KiB
# Python
#
"""Python wrappers around TensorFlow ops.
This file is MACHINE GENERATED! Do not edit.
"""
import collections
from tensorflow.python import pywrap_tfe as pywrap_tfe
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.deprecation import deprecated_endpoints
from tensorflow.python.util import dispatch as _dispatch
from tensorflow.python.util.tf_export import tf_export
from typing import TypeVar
def abort(error_msg="", exit_without_error=False, name=None):
  r"""Raise an exception to abort the process when called.

  If exit_without_error is true, the process will exit normally,
  otherwise it will exit with a SIGABRT signal.

  The op produces no output tensors; its only effect is raising the exception.

  Args:
    error_msg: An optional `string`. Defaults to `""`.
      A string which is the message associated with the exception.
    exit_without_error: An optional `bool`. Defaults to `False`.
    name: A name for the operation (optional).

  Returns:
    The created Operation (graph mode). The eager fallback path returns None.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C++ fast path, then the Python eager fallback.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "Abort", name, "error_msg", error_msg, "exit_without_error",
        exit_without_error)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return abort_eager_fallback(
          error_msg=error_msg, exit_without_error=exit_without_error,
          name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  # None means "use the op's default" for each optional attr.
  if error_msg is None:
    error_msg = ""
  error_msg = _execute.make_str(error_msg, "error_msg")
  if exit_without_error is None:
    exit_without_error = False
  exit_without_error = _execute.make_bool(exit_without_error, "exit_without_error")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "Abort", error_msg=error_msg, exit_without_error=exit_without_error,
                 name=name)
  return _op
Abort = tf_export("raw_ops.Abort")(_ops.to_raw_op(abort))
def abort_eager_fallback(error_msg, exit_without_error, name, ctx):
  """Runs `Abort` through the generic eager executor.

  Optional attrs given as None fall back to the op defaults ("" and False).
  The op has no outputs, so this always returns None.
  """
  message = _execute.make_str(
      "" if error_msg is None else error_msg, "error_msg")
  exit_cleanly = _execute.make_bool(
      False if exit_without_error is None else exit_without_error,
      "exit_without_error")
  op_attrs = ("error_msg", message, "exit_without_error", exit_cleanly)
  _execute.execute(b"Abort", 0, inputs=[], attrs=op_attrs,
                   ctx=ctx, name=name)
  return None
def control_trigger(name=None):
  r"""Does nothing. Serves as a control trigger for scheduling.

  Only useful as a placeholder for control edges.

  Args:
    name: A name for the operation (optional).

  Returns:
    The created Operation.
  """
  ctx = _context._context or _context.context()
  if ctx._thread_local_data.is_eager:
    # Prefer the C++ fast path; on _FallbackException use the Python path.
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(ctx, "ControlTrigger", name)
    except _core._NotOkStatusException as status_error:
      _ops.raise_from_not_ok_status(status_error, name)
    except _core._FallbackException:
      pass
    try:
      return control_trigger_eager_fallback(name=name, ctx=ctx)
    except _core._SymbolicException:
      pass  # Symbolic context: build a graph node below instead.
  # Graph construction path.
  _, _, op, _ = _op_def_library._apply_op_helper("ControlTrigger", name=name)
  return op
ControlTrigger = tf_export("raw_ops.ControlTrigger")(_ops.to_raw_op(control_trigger))
def control_trigger_eager_fallback(name, ctx):
  """Runs `ControlTrigger` via the generic eager executor; returns None."""
  _execute.execute(b"ControlTrigger", 0, inputs=[], attrs=None,
                   ctx=ctx, name=name)
  return None
def enter(data, frame_name, is_constant=False, parallel_iterations=10, name=None):
  r"""Creates or finds a child frame, and makes `data` available to the child frame.

  This op is used together with `Exit` to create loops in the graph.
  The unique `frame_name` is used by the `Executor` to identify frames. If
  `is_constant` is true, `output` is a constant in the child frame; otherwise
  it may be changed in the child frame. At most `parallel_iterations` iterations
  are run in parallel in the child frame.

  Args:
    data: A `Tensor`. The tensor to be made available to the child frame.
    frame_name: A `string`. The name of the child frame.
    is_constant: An optional `bool`. Defaults to `False`.
      If true, the output is constant within the child frame.
    parallel_iterations: An optional `int`. Defaults to `10`.
      The number of iterations allowed to run in parallel.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `data`.
  """
  ctx = _context._context or _context.context()
  if ctx._thread_local_data.is_eager:
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(
          ctx, "Enter", name, data, "frame_name", frame_name, "is_constant",
          is_constant, "parallel_iterations", parallel_iterations)
    except _core._NotOkStatusException as status_error:
      _ops.raise_from_not_ok_status(status_error, name)
    except _core._FallbackException:
      pass
    try:
      return enter_eager_fallback(
          data, frame_name=frame_name, is_constant=is_constant,
          parallel_iterations=parallel_iterations, name=name, ctx=ctx)
    except _core._SymbolicException:
      pass  # Symbolic input: build a graph node below instead.

  # Graph construction: normalize attrs (None selects the op default).
  frame = _execute.make_str(frame_name, "frame_name")
  constant_flag = _execute.make_bool(
      False if is_constant is None else is_constant, "is_constant")
  max_parallel = _execute.make_int(
      10 if parallel_iterations is None else parallel_iterations,
      "parallel_iterations")
  _, _, op, outputs = _op_def_library._apply_op_helper(
      "Enter", data=data, frame_name=frame, is_constant=constant_flag,
      parallel_iterations=max_parallel, name=name)
  results = outputs[:]
  if _execute.must_record_gradient():
    grad_attrs = ("T", op._get_attr_type("T"),
                  "frame_name", op.get_attr("frame_name"),
                  "is_constant", op._get_attr_bool("is_constant"),
                  "parallel_iterations", op._get_attr_int("parallel_iterations"))
    _execute.record_gradient("Enter", op.inputs, grad_attrs, results)
  (output,) = results
  return output
Enter = tf_export("raw_ops.Enter")(_ops.to_raw_op(enter))
def enter_eager_fallback(data, frame_name, is_constant, parallel_iterations, name, ctx):
  """Runs `Enter` through the generic eager executor and returns its output."""
  frame = _execute.make_str(frame_name, "frame_name")
  constant_flag = _execute.make_bool(
      False if is_constant is None else is_constant, "is_constant")
  max_parallel = _execute.make_int(
      10 if parallel_iterations is None else parallel_iterations,
      "parallel_iterations")
  dtype_attr, (data_tensor,) = _execute.args_to_matching_eager([data], ctx, [])
  flat_inputs = [data_tensor]
  op_attrs = ("T", dtype_attr, "frame_name", frame, "is_constant",
              constant_flag, "parallel_iterations", max_parallel)
  outputs = _execute.execute(b"Enter", 1, inputs=flat_inputs, attrs=op_attrs,
                             ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("Enter", flat_inputs, op_attrs, outputs)
  (output,) = outputs
  return output
def _exit(data, name=None):
  r"""Exits the current frame to its parent frame.

  Exit makes its input `data` available to the parent frame.

  Args:
    data: A `Tensor`. The tensor to be made available to the parent frame.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `data`.
  """
  ctx = _context._context or _context.context()
  if ctx._thread_local_data.is_eager:
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(ctx, "Exit", name, data)
    except _core._NotOkStatusException as status_error:
      _ops.raise_from_not_ok_status(status_error, name)
    except _core._FallbackException:
      pass
    try:
      return _exit_eager_fallback(data, name=name, ctx=ctx)
    except _core._SymbolicException:
      pass  # Symbolic input: build a graph node below instead.
  # Graph construction path.
  _, _, op, outputs = _op_def_library._apply_op_helper(
      "Exit", data=data, name=name)
  results = outputs[:]
  if _execute.must_record_gradient():
    grad_attrs = ("T", op._get_attr_type("T"))
    _execute.record_gradient("Exit", op.inputs, grad_attrs, results)
  (output,) = results
  return output
Exit = tf_export("raw_ops.Exit")(_ops.to_raw_op(_exit))
def _exit_eager_fallback(data, name, ctx):
  """Runs `Exit` through the generic eager executor and returns its output."""
  dtype_attr, (data_tensor,) = _execute.args_to_matching_eager([data], ctx, [])
  flat_inputs = [data_tensor]
  op_attrs = ("T", dtype_attr)
  outputs = _execute.execute(b"Exit", 1, inputs=flat_inputs, attrs=op_attrs,
                             ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("Exit", flat_inputs, op_attrs, outputs)
  (output,) = outputs
  return output
def loop_cond(input, name=None):
  r"""Forwards the input to the output.

  This operator represents the loop termination condition used by the
  "pivot" switches of a loop.

  Args:
    input: A `Tensor` of type `bool`.
      A boolean scalar, representing the branch predicate of the Switch op.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `bool`.
  """
  ctx = _context._context or _context.context()
  if ctx._thread_local_data.is_eager:
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(ctx, "LoopCond", name, input)
    except _core._NotOkStatusException as status_error:
      _ops.raise_from_not_ok_status(status_error, name)
    except _core._FallbackException:
      pass
    try:
      return loop_cond_eager_fallback(input, name=name, ctx=ctx)
    except _core._SymbolicException:
      pass  # Symbolic input: build a graph node below instead.
  # Graph construction path. LoopCond has no attrs to record.
  _, _, op, outputs = _op_def_library._apply_op_helper(
      "LoopCond", input=input, name=name)
  results = outputs[:]
  if _execute.must_record_gradient():
    _execute.record_gradient("LoopCond", op.inputs, (), results)
  (output,) = results
  return output
LoopCond = tf_export("raw_ops.LoopCond")(_ops.to_raw_op(loop_cond))
def loop_cond_eager_fallback(input, name, ctx):
  """Runs `LoopCond` eagerly after converting `input` to a bool tensor."""
  predicate = _ops.convert_to_tensor(input, _dtypes.bool)
  flat_inputs = [predicate]
  outputs = _execute.execute(b"LoopCond", 1, inputs=flat_inputs,
                             attrs=None, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("LoopCond", flat_inputs, None, outputs)
  (output,) = outputs
  return output
# Named result of `merge`: (output, value_index).
_MergeOutput = collections.namedtuple(
    "Merge",
    ["output", "value_index"])
def merge(inputs, name=None):
  r"""Forwards the value of an available tensor from `inputs` to `output`.

  `Merge` waits for at least one of the tensors in `inputs` to become available.
  It is usually combined with `Switch` to implement branching.

  `Merge` forwards the first tensor to become available to `output`, and sets
  `value_index` to its index in `inputs`.

  Args:
    inputs: A list of at least 1 `Tensor` objects with the same type.
      The input tensors, exactly one of which will become available.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (output, value_index).

    output: A `Tensor`. Has the same type as `inputs`.
    value_index: A `Tensor` of type `int32`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "Merge", name, inputs)
      _result = _MergeOutput._make(_result)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return merge_eager_fallback(
          inputs, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(inputs, (list, tuple)):
    raise TypeError(
        "Expected list for 'inputs' argument to "
        "'merge' Op, not %r." % inputs)
  # No local `N` count is needed: _apply_op_helper sets the op's `N` attr,
  # which is read back from the created op below.
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "Merge", inputs=inputs, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "N", _op._get_attr_int("N"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "Merge", _inputs_flat, _attrs, _result)
  _result = _MergeOutput._make(_result)
  return _result
Merge = tf_export("raw_ops.Merge")(_ops.to_raw_op(merge))
def merge_eager_fallback(inputs, name, ctx):
  """Runs `Merge` eagerly; returns a `_MergeOutput(output, value_index)`.

  Raises:
    TypeError: if `inputs` is not a list or tuple.
  """
  if not isinstance(inputs, (list, tuple)):
    raise TypeError(
        "Expected list for 'inputs' argument to "
        "'merge' Op, not %r." % inputs)
  input_count = len(inputs)
  dtype_attr, converted = _execute.args_to_matching_eager(list(inputs), ctx, [])
  flat_inputs = list(converted)
  op_attrs = ("T", dtype_attr, "N", input_count)
  outputs = _execute.execute(b"Merge", 2, inputs=flat_inputs, attrs=op_attrs,
                             ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("Merge", flat_inputs, op_attrs, outputs)
  return _MergeOutput._make(outputs)
def next_iteration(data, name=None):
  r"""Makes its input available to the next iteration.

  Args:
    data: A `Tensor`. The tensor to be made available to the next iteration.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `data`.
  """
  ctx = _context._context or _context.context()
  if ctx._thread_local_data.is_eager:
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(ctx, "NextIteration", name, data)
    except _core._NotOkStatusException as status_error:
      _ops.raise_from_not_ok_status(status_error, name)
    except _core._FallbackException:
      pass
    try:
      return next_iteration_eager_fallback(data, name=name, ctx=ctx)
    except _core._SymbolicException:
      pass  # Symbolic input: build a graph node below instead.
  # Graph construction path.
  _, _, op, outputs = _op_def_library._apply_op_helper(
      "NextIteration", data=data, name=name)
  results = outputs[:]
  if _execute.must_record_gradient():
    grad_attrs = ("T", op._get_attr_type("T"))
    _execute.record_gradient("NextIteration", op.inputs, grad_attrs, results)
  (output,) = results
  return output
NextIteration = tf_export("raw_ops.NextIteration")(_ops.to_raw_op(next_iteration))
def next_iteration_eager_fallback(data, name, ctx):
  """Runs `NextIteration` via the generic eager executor; returns its output."""
  dtype_attr, (data_tensor,) = _execute.args_to_matching_eager([data], ctx, [])
  flat_inputs = [data_tensor]
  op_attrs = ("T", dtype_attr)
  outputs = _execute.execute(b"NextIteration", 1, inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("NextIteration", flat_inputs, op_attrs, outputs)
  (output,) = outputs
  return output
@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('no_op')
def no_op(name=None):
  r"""Does nothing. Only useful as a placeholder for control edges.

  Args:
    name: A name for the operation (optional).

  Returns:
    The created Operation.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C++ fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "NoOp", name)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      # Offer the call to the type-based dispatcher before the slow eager
      # path; NotImplemented means no dispatcher claimed it.
      _result = _dispatcher_for_no_op(
          (name,), None)
      if _result is not NotImplemented:
        return _result
      return no_op_eager_fallback(
          name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Retry via the fallback dispatch list; re-raise the original error
      # when no dispatcher supports this call.
      _result = _dispatch.dispatch(
            no_op, (), dict(name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: the type-based dispatcher also gets first refusal here.
    _result = _dispatcher_for_no_op(
        (name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "NoOp", name=name)
  except (TypeError, ValueError):
    # Same fallback-dispatch retry as the eager path above.
    _result = _dispatch.dispatch(
          no_op, (), dict(name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  return _op
NoOp = tf_export("raw_ops.NoOp")(_ops.to_raw_op(no_op))
# Bound after decoration; no_op's body looks this name up at call time.
_dispatcher_for_no_op = no_op._tf_type_based_dispatcher.Dispatch
def no_op_eager_fallback(name, ctx):
  """Runs `NoOp` through the generic eager executor; returns None."""
  _execute.execute(b"NoOp", 0, inputs=[], attrs=None,
                   ctx=ctx, name=name)
  return None
def ref_enter(data, frame_name, is_constant=False, parallel_iterations=10, name=None):
  r"""Creates or finds a child frame, and makes `data` available to the child frame.

  The unique `frame_name` is used by the `Executor` to identify frames. If
  `is_constant` is true, `output` is a constant in the child frame; otherwise
  it may be changed in the child frame. At most `parallel_iterations` iterations
  are run in parallel in the child frame.

  Args:
    data: A mutable `Tensor`.
      The tensor to be made available to the child frame.
    frame_name: A `string`. The name of the child frame.
    is_constant: An optional `bool`. Defaults to `False`.
      If true, the output is constant within the child frame.
    parallel_iterations: An optional `int`. Defaults to `10`.
      The number of iterations allowed to run in parallel.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `data`.

  Raises:
    RuntimeError: when called in eager mode (the output is a ref).
  """
  ctx = _context._context or _context.context()
  if ctx._thread_local_data.is_eager:
    raise RuntimeError("ref_enter op does not support eager execution. Arg 'output' is a ref.")
  # Graph construction: normalize attrs (None selects the op default).
  frame = _execute.make_str(frame_name, "frame_name")
  constant_flag = _execute.make_bool(
      False if is_constant is None else is_constant, "is_constant")
  max_parallel = _execute.make_int(
      10 if parallel_iterations is None else parallel_iterations,
      "parallel_iterations")
  _, _, op, outputs = _op_def_library._apply_op_helper(
      "RefEnter", data=data, frame_name=frame, is_constant=constant_flag,
      parallel_iterations=max_parallel, name=name)
  results = outputs[:]
  if _execute.must_record_gradient():
    grad_attrs = ("T", op._get_attr_type("T"),
                  "frame_name", op.get_attr("frame_name"),
                  "is_constant", op._get_attr_bool("is_constant"),
                  "parallel_iterations", op._get_attr_int("parallel_iterations"))
    _execute.record_gradient("RefEnter", op.inputs, grad_attrs, results)
  (output,) = results
  return output
RefEnter = tf_export("raw_ops.RefEnter")(_ops.to_raw_op(ref_enter))
def ref_enter_eager_fallback(data, frame_name, is_constant, parallel_iterations, name, ctx):
  """Always raises: RefEnter's `output` is a ref, unsupported in eager mode."""
  raise RuntimeError("ref_enter op does not support eager execution. Arg 'output' is a ref.")
def ref_exit(data, name=None):
  r"""Exits the current frame to its parent frame.

  Exit makes its input `data` available to the parent frame.

  Args:
    data: A mutable `Tensor`.
      The tensor to be made available to the parent frame.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `data`.

  Raises:
    RuntimeError: when called in eager mode (the output is a ref).
  """
  ctx = _context._context or _context.context()
  if ctx._thread_local_data.is_eager:
    raise RuntimeError("ref_exit op does not support eager execution. Arg 'output' is a ref.")
  # Graph construction path.
  _, _, op, outputs = _op_def_library._apply_op_helper(
      "RefExit", data=data, name=name)
  results = outputs[:]
  if _execute.must_record_gradient():
    grad_attrs = ("T", op._get_attr_type("T"))
    _execute.record_gradient("RefExit", op.inputs, grad_attrs, results)
  (output,) = results
  return output
RefExit = tf_export("raw_ops.RefExit")(_ops.to_raw_op(ref_exit))
def ref_exit_eager_fallback(data, name, ctx):
  """Always raises: RefExit's `output` is a ref, unsupported in eager mode."""
  raise RuntimeError("ref_exit op does not support eager execution. Arg 'output' is a ref.")
# Named result of `ref_merge`: (output, value_index).
_RefMergeOutput = collections.namedtuple(
    "RefMerge",
    ["output", "value_index"])
def ref_merge(inputs, name=None):
  r"""Forwards the value of an available tensor from `inputs` to `output`.

  `Merge` waits for at least one of the tensors in `inputs` to become available.
  It is usually combined with `Switch` to implement branching.

  `Merge` forwards the first tensor to become available to `output`, and sets
  `value_index` to its index in `inputs`.

  Args:
    inputs: A list of at least 1 mutable `Tensor` objects with the same type.
      The input tensors, exactly one of which will become available.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (output, value_index).

    output: A mutable `Tensor`. Has the same type as `inputs`.
    value_index: A `Tensor` of type `int32`.

  Raises:
    RuntimeError: when called in eager mode (the output is a ref).
    TypeError: if `inputs` is not a list or tuple.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    raise RuntimeError("ref_merge op does not support eager execution. Arg 'output' is a ref.")
  # Add nodes to the TensorFlow graph.
  if not isinstance(inputs, (list, tuple)):
    raise TypeError(
        "Expected list for 'inputs' argument to "
        "'ref_merge' Op, not %r." % inputs)
  # No local `N` count is needed: _apply_op_helper sets the op's `N` attr,
  # which is read back from the created op below.
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "RefMerge", inputs=inputs, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "N", _op._get_attr_int("N"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "RefMerge", _inputs_flat, _attrs, _result)
  _result = _RefMergeOutput._make(_result)
  return _result
RefMerge = tf_export("raw_ops.RefMerge")(_ops.to_raw_op(ref_merge))
def ref_merge_eager_fallback(inputs, name, ctx):
  """Always raises: RefMerge's `output` is a ref, unsupported in eager mode."""
  raise RuntimeError("ref_merge op does not support eager execution. Arg 'output' is a ref.")
def ref_next_iteration(data, name=None):
  r"""Makes its input available to the next iteration.

  Args:
    data: A mutable `Tensor`.
      The tensor to be made available to the next iteration.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `data`.

  Raises:
    RuntimeError: when called in eager mode (the output is a ref).
  """
  ctx = _context._context or _context.context()
  if ctx._thread_local_data.is_eager:
    raise RuntimeError("ref_next_iteration op does not support eager execution. Arg 'output' is a ref.")
  # Graph construction path.
  _, _, op, outputs = _op_def_library._apply_op_helper(
      "RefNextIteration", data=data, name=name)
  results = outputs[:]
  if _execute.must_record_gradient():
    grad_attrs = ("T", op._get_attr_type("T"))
    _execute.record_gradient("RefNextIteration", op.inputs, grad_attrs, results)
  (output,) = results
  return output
RefNextIteration = tf_export("raw_ops.RefNextIteration")(_ops.to_raw_op(ref_next_iteration))
def ref_next_iteration_eager_fallback(data, name, ctx):
  """Always raises: RefNextIteration's `output` is a ref, unsupported eagerly."""
  raise RuntimeError("ref_next_iteration op does not support eager execution. Arg 'output' is a ref.")
def ref_select(index, inputs, name=None):
  r"""Forwards the `index`th element of `inputs` to `output`.

  Args:
    index: A `Tensor` of type `int32`.
      A scalar that determines the input that gets selected.
    inputs: A list of at least 1 mutable `Tensor` objects with the same type.
      A list of ref tensors, one of which will be forwarded to `output`.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `inputs`.

  Raises:
    RuntimeError: when called in eager mode (the output is a ref).
    TypeError: if `inputs` is not a list or tuple.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    raise RuntimeError("ref_select op does not support eager execution. Arg 'output' is a ref.")
  # Add nodes to the TensorFlow graph.
  if not isinstance(inputs, (list, tuple)):
    raise TypeError(
        "Expected list for 'inputs' argument to "
        "'ref_select' Op, not %r." % inputs)
  # No local `N` count is needed: _apply_op_helper sets the op's `N` attr,
  # which is read back from the created op below.
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "RefSelect", index=index, inputs=inputs, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "N", _op._get_attr_int("N"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "RefSelect", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result
RefSelect = tf_export("raw_ops.RefSelect")(_ops.to_raw_op(ref_select))
def ref_select_eager_fallback(index, inputs, name, ctx):
  """Always raises: RefSelect's `output` is a ref, unsupported in eager mode."""
  raise RuntimeError("ref_select op does not support eager execution. Arg 'output' is a ref.")
# Named result of `ref_switch`: (output_false, output_true).
_RefSwitchOutput = collections.namedtuple(
    "RefSwitch", "output_false output_true")
def ref_switch(data, pred, name=None):
  r"""Forwards the ref tensor `data` to the output port determined by `pred`.

  If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise,
  the data goes to `output_false`.

  See also `Switch` and `Merge`.

  Args:
    data: A mutable `Tensor`.
      The ref tensor to be forwarded to the appropriate output.
    pred: A `Tensor` of type `bool`.
      A scalar that specifies which output port will receive data.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (output_false, output_true).

    output_false: A mutable `Tensor`. Has the same type as `data`.
    output_true: A mutable `Tensor`. Has the same type as `data`.

  Raises:
    RuntimeError: when called in eager mode (the outputs are refs).
  """
  ctx = _context._context or _context.context()
  if ctx._thread_local_data.is_eager:
    raise RuntimeError("ref_switch op does not support eager execution. Arg 'output_true' is a ref.")
  # Graph construction path.
  _, _, op, outputs = _op_def_library._apply_op_helper(
      "RefSwitch", data=data, pred=pred, name=name)
  results = outputs[:]
  if _execute.must_record_gradient():
    grad_attrs = ("T", op._get_attr_type("T"))
    _execute.record_gradient("RefSwitch", op.inputs, grad_attrs, results)
  return _RefSwitchOutput._make(results)
RefSwitch = tf_export("raw_ops.RefSwitch")(_ops.to_raw_op(ref_switch))
def ref_switch_eager_fallback(data, pred, name, ctx):
  """Always raises: RefSwitch's outputs are refs, unsupported in eager mode."""
  raise RuntimeError("ref_switch op does not support eager execution. Arg 'output_true' is a ref.")
# Named result of `switch`: (output_false, output_true).
_SwitchOutput = collections.namedtuple(
    "Switch", "output_false output_true")
def switch(data, pred, name=None):
  r"""Forwards `data` to the output port determined by `pred`.

  If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise,
  the data goes to `output_false`.

  See also `RefSwitch` and `Merge`.

  Args:
    data: A `Tensor`. The tensor to be forwarded to the appropriate output.
    pred: A `Tensor` of type `bool`.
      A scalar that specifies which output port will receive data.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (output_false, output_true).

    output_false: A `Tensor`. Has the same type as `data`.
    output_true: A `Tensor`. Has the same type as `data`.
  """
  ctx = _context._context or _context.context()
  if ctx._thread_local_data.is_eager:
    try:
      return _SwitchOutput._make(
          pywrap_tfe.TFE_Py_FastPathExecute(ctx, "Switch", name, data, pred))
    except _core._NotOkStatusException as status_error:
      _ops.raise_from_not_ok_status(status_error, name)
    except _core._FallbackException:
      pass
    try:
      return switch_eager_fallback(data, pred, name=name, ctx=ctx)
    except _core._SymbolicException:
      pass  # Symbolic input: build a graph node below instead.
  # Graph construction path.
  _, _, op, outputs = _op_def_library._apply_op_helper(
      "Switch", data=data, pred=pred, name=name)
  results = outputs[:]
  if _execute.must_record_gradient():
    grad_attrs = ("T", op._get_attr_type("T"))
    _execute.record_gradient("Switch", op.inputs, grad_attrs, results)
  return _SwitchOutput._make(results)
Switch = tf_export("raw_ops.Switch")(_ops.to_raw_op(switch))
def switch_eager_fallback(data, pred, name, ctx):
  """Runs `Switch` eagerly; returns a `_SwitchOutput(output_false, output_true)`."""
  dtype_attr, (data_tensor,) = _execute.args_to_matching_eager([data], ctx, [])
  predicate = _ops.convert_to_tensor(pred, _dtypes.bool)
  flat_inputs = [data_tensor, predicate]
  op_attrs = ("T", dtype_attr)
  outputs = _execute.execute(b"Switch", 2, inputs=flat_inputs, attrs=op_attrs,
                             ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("Switch", flat_inputs, op_attrs, outputs)
  return _SwitchOutput._make(outputs)