3RNN/Lib/site-packages/tensorflow/compiler/jit/ops/xla_ops.py
2024-05-26 19:49:15 +02:00

353 lines
14 KiB
Python

"""Python wrappers around TensorFlow ops.
This file is MACHINE GENERATED! Do not edit.
"""
import collections
from tensorflow.python import pywrap_tfe as pywrap_tfe
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.security.fuzzing.py import annotation_types as _atypes
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.deprecation import deprecated_endpoints
from tensorflow.python.util import dispatch as _dispatch
from tensorflow.python.util.tf_export import tf_export
from typing import TypeVar, List, Any
from typing_extensions import Annotated
TV_XlaClusterOutput_T = TypeVar("TV_XlaClusterOutput_T", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant)
@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_cluster_output')
def xla_cluster_output(input: Annotated[Any, TV_XlaClusterOutput_T], name=None) -> Annotated[Any, TV_XlaClusterOutput_T]:
r"""Operator that connects the output of an XLA computation to other consumer graph nodes.
Args:
input: A `Tensor`.
name: A name for the operation (optional).
Returns:
A `Tensor`. Has the same type as `input`.
"""
_ctx = _context._context or _context.context()
tld = _ctx._thread_local_data
if tld.is_eager:
try:
_result = pywrap_tfe.TFE_Py_FastPathExecute(
_ctx, "XlaClusterOutput", name, input)
return _result
except _core._NotOkStatusException as e:
_ops.raise_from_not_ok_status(e, name)
except _core._FallbackException:
pass
try:
_result = _dispatcher_for_xla_cluster_output(
(input, name,), None)
if _result is not NotImplemented:
return _result
return xla_cluster_output_eager_fallback(
input, name=name, ctx=_ctx)
except _core._SymbolicException:
pass # Add nodes to the TensorFlow graph.
except (TypeError, ValueError):
_result = _dispatch.dispatch(
xla_cluster_output, (), dict(input=input, name=name)
)
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
return _result
raise
else:
_result = _dispatcher_for_xla_cluster_output(
(input, name,), None)
if _result is not NotImplemented:
return _result
# Add nodes to the TensorFlow graph.
try:
_, _, _op, _outputs = _op_def_library._apply_op_helper(
"XlaClusterOutput", input=input, name=name)
except (TypeError, ValueError):
_result = _dispatch.dispatch(
xla_cluster_output, (), dict(input=input, name=name)
)
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
return _result
raise
_result = _outputs[:]
if _execute.must_record_gradient():
_attrs = ("T", _op._get_attr_type("T"))
_inputs_flat = _op.inputs
_execute.record_gradient(
"XlaClusterOutput", _inputs_flat, _attrs, _result)
_result, = _result
return _result
XlaClusterOutput = tf_export("raw_ops.XlaClusterOutput")(_ops.to_raw_op(xla_cluster_output))
_dispatcher_for_xla_cluster_output = xla_cluster_output._tf_type_based_dispatcher.Dispatch
def xla_cluster_output_eager_fallback(input: Annotated[Any, TV_XlaClusterOutput_T], name, ctx) -> Annotated[Any, TV_XlaClusterOutput_T]:
_attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [])
_inputs_flat = [input]
_attrs = ("T", _attr_T)
_result = _execute.execute(b"XlaClusterOutput", 1, inputs=_inputs_flat,
attrs=_attrs, ctx=ctx, name=name)
if _execute.must_record_gradient():
_execute.record_gradient(
"XlaClusterOutput", _inputs_flat, _attrs, _result)
_result, = _result
return _result
@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_launch')
def xla_launch(constants, args, resources: Annotated[List[Any], _atypes.Resource], Tresults, function, name=None):
r"""XLA Launch Op. For use by the XLA JIT only.
Args:
constants: A list of `Tensor` objects.
args: A list of `Tensor` objects.
resources: A list of `Tensor` objects with type `resource`.
Tresults: A list of `tf.DTypes`.
function: A function decorated with @Defun.
name: A name for the operation (optional).
Returns:
A list of `Tensor` objects of type `Tresults`.
"""
_ctx = _context._context or _context.context()
tld = _ctx._thread_local_data
if tld.is_eager:
try:
_result = pywrap_tfe.TFE_Py_FastPathExecute(
_ctx, "XlaLaunch", name, constants, args, resources, "Tresults",
Tresults, "function", function)
return _result
except _core._NotOkStatusException as e:
_ops.raise_from_not_ok_status(e, name)
except _core._FallbackException:
pass
try:
_result = _dispatcher_for_xla_launch(
(constants, args, resources, Tresults, function, name,), None)
if _result is not NotImplemented:
return _result
return xla_launch_eager_fallback(
constants, args, resources, Tresults=Tresults, function=function,
name=name, ctx=_ctx)
except _core._SymbolicException:
pass # Add nodes to the TensorFlow graph.
except (TypeError, ValueError):
_result = _dispatch.dispatch(
xla_launch, (), dict(constants=constants, args=args,
resources=resources, Tresults=Tresults,
function=function, name=name)
)
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
return _result
raise
else:
_result = _dispatcher_for_xla_launch(
(constants, args, resources, Tresults, function, name,), None)
if _result is not NotImplemented:
return _result
# Add nodes to the TensorFlow graph.
if not isinstance(resources, (list, tuple)):
raise TypeError(
"Expected list for 'resources' argument to "
"'xla_launch' Op, not %r." % resources)
_attr_Nresources = len(resources)
if not isinstance(Tresults, (list, tuple)):
raise TypeError(
"Expected list for 'Tresults' argument to "
"'xla_launch' Op, not %r." % Tresults)
Tresults = [_execute.make_type(_t, "Tresults") for _t in Tresults]
try:
_, _, _op, _outputs = _op_def_library._apply_op_helper(
"XlaLaunch", constants=constants, args=args, resources=resources,
Tresults=Tresults, function=function, name=name)
except (TypeError, ValueError):
_result = _dispatch.dispatch(
xla_launch, (), dict(constants=constants, args=args,
resources=resources, Tresults=Tresults,
function=function, name=name)
)
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
return _result
raise
_result = _outputs[:]
if not _result:
return _op
if _execute.must_record_gradient():
_attrs = ("Tconstants", _op.get_attr("Tconstants"), "Targs",
_op.get_attr("Targs"), "Nresources",
_op._get_attr_int("Nresources"), "Tresults",
_op.get_attr("Tresults"), "function", _op.get_attr("function"))
_inputs_flat = _op.inputs
_execute.record_gradient(
"XlaLaunch", _inputs_flat, _attrs, _result)
return _result
XlaLaunch = tf_export("raw_ops.XlaLaunch")(_ops.to_raw_op(xla_launch))
_dispatcher_for_xla_launch = xla_launch._tf_type_based_dispatcher.Dispatch
def xla_launch_eager_fallback(constants, args, resources: Annotated[List[Any], _atypes.Resource], Tresults, function, name, ctx):
if not isinstance(resources, (list, tuple)):
raise TypeError(
"Expected list for 'resources' argument to "
"'xla_launch' Op, not %r." % resources)
_attr_Nresources = len(resources)
if not isinstance(Tresults, (list, tuple)):
raise TypeError(
"Expected list for 'Tresults' argument to "
"'xla_launch' Op, not %r." % Tresults)
Tresults = [_execute.make_type(_t, "Tresults") for _t in Tresults]
_attr_Tconstants, constants = _execute.convert_to_mixed_eager_tensors(constants, ctx)
_attr_Targs, args = _execute.convert_to_mixed_eager_tensors(args, ctx)
resources = _ops.convert_n_to_tensor(resources, _dtypes.resource)
_inputs_flat = list(constants) + list(args) + list(resources)
_attrs = ("Tconstants", _attr_Tconstants, "Targs", _attr_Targs,
"Nresources", _attr_Nresources, "Tresults", Tresults, "function", function)
_result = _execute.execute(b"XlaLaunch", len(Tresults), inputs=_inputs_flat,
attrs=_attrs, ctx=ctx, name=name)
if _execute.must_record_gradient():
_execute.record_gradient(
"XlaLaunch", _inputs_flat, _attrs, _result)
return _result
@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_launch_v2')
def xla_launch_v2(args, Tresults, constants, resources, function, name=None):
r"""XLA Launch Op. For use by the XLA JIT only.
Args:
args: A list of `Tensor` objects.
Tresults: A list of `tf.DTypes`.
constants: A list of `ints`.
resources: A list of `ints`.
function: A function decorated with @Defun.
name: A name for the operation (optional).
Returns:
A list of `Tensor` objects of type `Tresults`.
"""
_ctx = _context._context or _context.context()
tld = _ctx._thread_local_data
if tld.is_eager:
try:
_result = pywrap_tfe.TFE_Py_FastPathExecute(
_ctx, "XlaLaunchV2", name, args, "Tresults", Tresults, "constants",
constants, "resources", resources, "function", function)
return _result
except _core._NotOkStatusException as e:
_ops.raise_from_not_ok_status(e, name)
except _core._FallbackException:
pass
try:
_result = _dispatcher_for_xla_launch_v2(
(args, Tresults, constants, resources, function, name,), None)
if _result is not NotImplemented:
return _result
return xla_launch_v2_eager_fallback(
args, Tresults=Tresults, constants=constants, resources=resources,
function=function, name=name, ctx=_ctx)
except _core._SymbolicException:
pass # Add nodes to the TensorFlow graph.
except (TypeError, ValueError):
_result = _dispatch.dispatch(
xla_launch_v2, (), dict(args=args, Tresults=Tresults,
constants=constants, resources=resources,
function=function, name=name)
)
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
return _result
raise
else:
_result = _dispatcher_for_xla_launch_v2(
(args, Tresults, constants, resources, function, name,), None)
if _result is not NotImplemented:
return _result
# Add nodes to the TensorFlow graph.
if not isinstance(Tresults, (list, tuple)):
raise TypeError(
"Expected list for 'Tresults' argument to "
"'xla_launch_v2' Op, not %r." % Tresults)
Tresults = [_execute.make_type(_t, "Tresults") for _t in Tresults]
if not isinstance(constants, (list, tuple)):
raise TypeError(
"Expected list for 'constants' argument to "
"'xla_launch_v2' Op, not %r." % constants)
constants = [_execute.make_int(_i, "constants") for _i in constants]
if not isinstance(resources, (list, tuple)):
raise TypeError(
"Expected list for 'resources' argument to "
"'xla_launch_v2' Op, not %r." % resources)
resources = [_execute.make_int(_i, "resources") for _i in resources]
try:
_, _, _op, _outputs = _op_def_library._apply_op_helper(
"XlaLaunchV2", args=args, Tresults=Tresults, constants=constants,
resources=resources, function=function, name=name)
except (TypeError, ValueError):
_result = _dispatch.dispatch(
xla_launch_v2, (), dict(args=args, Tresults=Tresults,
constants=constants, resources=resources,
function=function, name=name)
)
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
return _result
raise
_result = _outputs[:]
if _execute.must_record_gradient():
_attrs = ("Targs", _op.get_attr("Targs"), "Tresults",
_op.get_attr("Tresults"), "constants",
_op.get_attr("constants"), "resources",
_op.get_attr("resources"), "function", _op.get_attr("function"))
_inputs_flat = _op.inputs
_execute.record_gradient(
"XlaLaunchV2", _inputs_flat, _attrs, _result)
return _result
XlaLaunchV2 = tf_export("raw_ops.XlaLaunchV2")(_ops.to_raw_op(xla_launch_v2))
_dispatcher_for_xla_launch_v2 = xla_launch_v2._tf_type_based_dispatcher.Dispatch
def xla_launch_v2_eager_fallback(args, Tresults, constants, resources, function, name, ctx):
if not isinstance(Tresults, (list, tuple)):
raise TypeError(
"Expected list for 'Tresults' argument to "
"'xla_launch_v2' Op, not %r." % Tresults)
Tresults = [_execute.make_type(_t, "Tresults") for _t in Tresults]
if not isinstance(constants, (list, tuple)):
raise TypeError(
"Expected list for 'constants' argument to "
"'xla_launch_v2' Op, not %r." % constants)
constants = [_execute.make_int(_i, "constants") for _i in constants]
if not isinstance(resources, (list, tuple)):
raise TypeError(
"Expected list for 'resources' argument to "
"'xla_launch_v2' Op, not %r." % resources)
resources = [_execute.make_int(_i, "resources") for _i in resources]
_attr_Targs, args = _execute.convert_to_mixed_eager_tensors(args, ctx)
_inputs_flat = list(args)
_attrs = ("Targs", _attr_Targs, "Tresults", Tresults, "constants",
constants, "resources", resources, "function", function)
_result = _execute.execute(b"XlaLaunchV2", len(Tresults),
inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
name=name)
if _execute.must_record_gradient():
_execute.record_gradient(
"XlaLaunchV2", _inputs_flat, _attrs, _result)
return _result