1303 lines
49 KiB
Python
1303 lines
49 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Type-based dispatch for TensorFlow's Python APIs.
|
|
|
|
"Python APIs" refers to Python functions that have been exported with
|
|
`tf_export`, such as `tf.add` and `tf.linalg.matmul`; they are sometimes also
|
|
referred to as "ops".
|
|
|
|
There are currently two dispatch systems for TensorFlow:
|
|
|
|
* The "fallback dispatch" system calls an API's standard implementation first,
|
|
and only tries to perform dispatch if that standard implementation raises a
|
|
TypeError (or ValueError) exception.
|
|
|
|
* The "type-based dispatch" system checks the types of the parameters passed
|
|
to an API, and performs dispatch if those types match any signatures that
|
|
have been registered for dispatch.
|
|
|
|
The fallback dispatch system was the original dispatch system, but it was
|
|
somewhat brittle and had limitations, such as an inability to support dispatch
|
|
for some operations (like convert_to_tensor). We plan to remove the fallback
|
|
dispatch system in favor of the type-based dispatch system, once all users have
|
|
been switched over to use it.
|
|
|
|
### Fallback Dispatch
|
|
|
|
The fallback dispatch system is based on "operation dispatchers", which can be
|
|
used to override the behavior for TensorFlow ops when they are called with
|
|
otherwise unsupported argument types. In particular, when an operation is
|
|
called with arguments that would cause it to raise a TypeError, it falls back on
|
|
its registered operation dispatchers. If any registered dispatchers can handle
|
|
the arguments, then its result is returned. Otherwise, the original TypeError is
|
|
raised.
|
|
|
|
### Type-based Dispatch
|
|
|
|
The main interface for the type-based dispatch system is the `dispatch_for_api`
|
|
decorator, which overrides the default implementation for a TensorFlow API.
|
|
The decorated function (known as the "dispatch target") will override the
|
|
default implementation for the API when the API is called with parameters that
|
|
match a specified type signature.
|
|
|
|
### Dispatch Support
|
|
|
|
By default, dispatch support is added to the generated op wrappers for any
visible ops. APIs/ops that are implemented in Python can opt in to
dispatch support using the `add_dispatch_support` decorator.
|
|
"""
|
|
|
|
import collections
|
|
import itertools
|
|
import typing # pylint: disable=unused-import (used in doctests)
|
|
|
|
from tensorflow.python.framework import _pywrap_python_api_dispatcher as _api_dispatcher
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.util import tf_decorator
|
|
from tensorflow.python.util import tf_export as tf_export_lib
|
|
from tensorflow.python.util import tf_inspect
|
|
from tensorflow.python.util import traceback_utils
|
|
from tensorflow.python.util import type_annotations
|
|
from tensorflow.python.util.tf_export import tf_export
|
|
|
|
|
|
# Names of the private attributes under which dispatchers are stored on
# TensorFlow API functions.
FALLBACK_DISPATCH_ATTR = "_tf_fallback_dispatchers"
TYPE_BASED_DISPATCH_ATTR = "_tf_type_based_dispatcher"

# OpDispatchers which should be used for all operations (populated by
# `GlobalOpDispatcher.register`).
_GLOBAL_DISPATCHERS = list()
|
|
|
|
|
|
################################################################################
|
|
# Fallback Dispatch
|
|
################################################################################
|
|
|
|
|
|
@tf_export("__internal__.dispatch.OpDispatcher", v1=[])
class OpDispatcher(object):
  """Abstract base class for TensorFlow operator dispatchers.

  Each operation dispatcher acts as an override handler for a single
  TensorFlow operation, and its results are used when the handler indicates
  that it can handle the operation's arguments (by returning any value other
  than `OpDispatcher.NOT_SUPPORTED`).
  """

  # Sentinel value that can be returned to indicate that an operation
  # dispatcher does not support a given set of arguments.
  NOT_SUPPORTED = object()

  def handle(self, args, kwargs):  # pylint: disable=unused-argument
    """Handle this dispatcher's operation with the specified arguments.

    If this operation dispatcher can handle the given arguments, then
    return an appropriate value (or raise an appropriate exception).

    Args:
      args: The arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `OpDispatcher.NOT_SUPPORTED` if this
      dispatcher can not handle the given arguments.
    """
    # The base implementation handles nothing; subclasses override this.
    return self.NOT_SUPPORTED

  def register(self, op):
    """Register this dispatcher as a handler for `op`.

    Args:
      op: Python function: the TensorFlow operation that should be handled. Must
        have a dispatch list (which is added automatically for generated ops,
        and can be added to Python ops using the `add_dispatch_support`
        decorator).

    Raises:
      AssertionError: If `op` does not have a dispatch list.
    """
    if not hasattr(op, FALLBACK_DISPATCH_ATTR):
      raise AssertionError(f"Dispatching not enabled for {op}")
    getattr(op, FALLBACK_DISPATCH_ATTR).append(self)
|
|
|
|
|
|
@tf_export("__internal__.dispatch.GlobalOpDispatcher", v1=[])
class GlobalOpDispatcher(object):
  """Abstract base class for TensorFlow global operator dispatchers.

  Global dispatchers are consulted by `dispatch` for every operation, after
  any dispatchers registered for that specific operation.
  """

  NOT_SUPPORTED = OpDispatcher.NOT_SUPPORTED

  def handle(self, op, args, kwargs):  # pylint: disable=unused-argument
    """Handle the specified operation with the specified arguments.

    Args:
      op: The operation being dispatched.
      args: The positional arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `NOT_SUPPORTED` if this dispatcher can
      not handle the given arguments.
    """
    # Return the sentinel rather than implicitly returning `None`: `dispatch`
    # treats any value other than `NOT_SUPPORTED` as a handled result, so a
    # registered subclass that does not override this must not claim every op.
    return self.NOT_SUPPORTED

  def register(self):
    """Register this dispatcher as a handler for all ops."""
    _GLOBAL_DISPATCHERS.append(self)
|
|
|
|
|
|
def dispatch(op, args, kwargs):
  """Returns the result from the first successful dispatcher for a given op.

  The `OpDispatcher`s registered for `op` are tried first, in registration
  order, followed by the global dispatchers.  The first handler whose result
  is not `OpDispatcher.NOT_SUPPORTED` wins.

  Args:
    op: Python function: the operation to dispatch for.
    args: The positional arguments to the operation.
    kwargs: The keyword arguments to the operation.

  Returns:
    The result of the operation, or `NOT_SUPPORTED` if no registered
    dispatcher can handle the given arguments.
  """
  not_supported = OpDispatcher.NOT_SUPPORTED

  # Operation-specific dispatchers take priority over the global ones.
  for op_dispatcher in getattr(op, FALLBACK_DISPATCH_ATTR):
    outcome = op_dispatcher.handle(args, kwargs)
    if outcome is not not_supported:
      return outcome

  # Global dispatchers also receive the op itself.
  for global_dispatcher in _GLOBAL_DISPATCHERS:
    outcome = global_dispatcher.handle(op, args, kwargs)
    if outcome is not not_supported:
      return outcome

  return not_supported
|
|
|
|
|
|
class _TypeBasedDispatcher(OpDispatcher):
  """Fallback dispatcher that fires when any argument has one of `types`.

  Positional and keyword argument values are inspected, as are the elements
  of any argument that is a `list` or `tuple` (one level deep).  If any
  inspected value is an instance of the configured type(s), the call is
  delegated to the override function.
  """

  def __init__(self, override_func, types):
    self._override_func = override_func
    self._types = types

  def _handles(self, args, kwargs):
    """Returns True if any argument (or list/tuple element) matches."""
    types = self._types

    def matches(value):
      if isinstance(value, types):
        return True
      if isinstance(value, (list, tuple)):
        return any(isinstance(item, types) for item in value)
      return False

    all_values = itertools.chain(args, kwargs.values())
    return any(matches(value) for value in all_values)

  def handle(self, args, kwargs):
    if not self._handles(args, kwargs):
      return self.NOT_SUPPORTED
    return self._override_func(*args, **kwargs)
|
|
|
|
|
|
def _remove_annotation(sig):
  """Returns a copy of `sig` with all parameter and return annotations dropped."""
  bare_params = tuple(
      param.replace(annotation=param.empty)
      for param in sig.parameters.values())
  return sig.replace(parameters=bare_params, return_annotation=sig.empty)
|
|
|
|
|
|
def _get_required_param_names(sig):
  """Returns the names of parameters in `sig` that have no default value.

  `*args` and `**kwargs` parameters are never considered required.
  """
  return [
      param.name
      for param in sig.parameters.values()
      if param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
      and param.default is param.empty
  ]
|
|
|
|
|
|
def get_compatible_func(op, func):
  """Returns a compatible function.

  Args:
    op: a callable with whose signature the returned function is compatible.
    func: a callable which is called by the returned function.

  Returns:
    a compatible function, which conducts the actions of `func` but can
    be called like `op`, given that:

      - the list of required arguments in `func` and `op` are the same.
      - there is no override of the default arguments of `op` that are not
        supported by `func`.

  Raises:
    AssertionError: If the required (non-default) parameters of `func` and
      `op` differ, or if `op` has a required parameter that `func` lacks.
      The returned function also raises `AssertionError` when it is called
      with a non-default value for a parameter of `op` that `func` lacks.
  """
  op_signature = _remove_annotation(tf_inspect.signature(op))
  func_signature = _remove_annotation(tf_inspect.signature(func))

  # Identical signatures, no need to apply compatibility fixes.
  if op_signature == func_signature:
    return func

  # When calling func:
  # - Positional args without default must be in the same order.
  # - Ignore missing optional arguments from op

  op_pos_names = _get_required_param_names(op_signature)
  func_pos_names = _get_required_param_names(func_signature)

  if op_pos_names != func_pos_names:
    raise AssertionError(
        "The decorated function's non-default arguments must be identical"
        " to that of the overridden op."
        f" func has {func_pos_names}. op has {op_pos_names}."
    )

  # Parameters of `op` that `func` does not accept, collected in `op`'s
  # parameter order.  (Iterating the signature, rather than a set difference,
  # makes the reported argument deterministic when several are missing.)
  func_missing_params = {}
  for name, p in op_signature.parameters.items():
    if name in func_signature.parameters:
      continue
    if p.default is p.empty:
      raise AssertionError(
          "The decorated function's signature must implement all of the"
          f" non-default arguments of the overridden op. Argument `{name}` is"
          " unimplemented."
      )
    func_missing_params[name] = p

  def compatible_func(*args, **kwargs):
    bound = op_signature.bind(*args, **kwargs)
    for name, param in func_missing_params.items():
      if name not in bound.arguments:
        continue  # Not passed explicitly, so there is nothing to drop.
      value = bound.arguments.pop(name)
      # Only the default object itself is accepted (identity, not equality);
      # presumably to avoid invoking `__eq__` on tensor-like values -- confirm.
      if value is not param.default:
        raise AssertionError(
            f"Dispatched op is called with argument `{name}` set to a"
            " non-default value, which is not supported by the decorated"
            " function"
        )
    # NOTE(review): leading arguments are forwarded positionally in `op`'s
    # parameter order, so `func`'s optional parameters are assumed to appear
    # in the same relative order as in `op`.
    return func(*bound.args, **bound.kwargs)

  return compatible_func
|
|
|
|
|
|
# pylint: disable=g-doc-return-or-yield
def dispatch_for_types(op, *types):
  """Decorator declaring that a Python function overrides `op` for some types.

  The decorated function replaces `op` whenever any positional or keyword
  argument (or any element of a list or tuple argument) is an instance of one
  of the given types.

  Example:

  ```python
  @dispatch_for_types(math_ops.add, RaggedTensor, RaggedTensorValue)
  def ragged_add(x, y, name=None): ...
  ```

  Args:
    op: Python function: the operation that should be overridden.
    *types: The argument types for which this function should be used.
  """

  def decorator(func):
    # Adapt `func` to `op`'s calling convention, then install it as a
    # fallback dispatcher for `op`.
    compatible_func = get_compatible_func(op, func)
    type_dispatcher = _TypeBasedDispatcher(compatible_func, types)
    type_dispatcher.register(op)
    # The undecorated function is returned, so it remains directly callable.
    return func

  return decorator


# pylint: enable=g-doc-return-or-yield
|
|
|
|
|
|
def add_fallback_dispatch_list(target):
  """Decorator that attaches an empty fallback dispatch list to `target`.

  Args:
    target: The op (Python function) to decorate.

  Returns:
    `target`, with an empty list stored under `FALLBACK_DISPATCH_ATTR`.

  Raises:
    AssertionError: If `target` already has a fallback dispatch list.
  """
  already_has_list = hasattr(target, FALLBACK_DISPATCH_ATTR)
  if already_has_list:
    raise AssertionError("%s already has a dispatch list" % target)
  dispatch_list = []
  setattr(target, FALLBACK_DISPATCH_ATTR, dispatch_list)
  return target


# Alias for backwards-compatibility.
add_dispatch_list = add_fallback_dispatch_list
|
|
|
|
|
|
################################################################################
|
|
# Type-based Dispatch
|
|
################################################################################
|
|
|
|
|
|
@tf_export("experimental.dispatch_for_api")
def dispatch_for_api(api, *signatures):
  """Decorator that overrides the default implementation for a TensorFlow API.

  The decorated function (known as the "dispatch target") will override the
  default implementation for the API when the API is called with parameters that
  match a specified type signature.  Signatures are specified using dictionaries
  that map parameter names to type annotations.  E.g., in the following example,
  `masked_add` will be called for `tf.add` if both `x` and `y` are
  `MaskedTensor`s:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor

  >>> @dispatch_for_api(tf.math.add, {'x': MaskedTensor, 'y': MaskedTensor})
  ... def masked_add(x, y, name=None):
  ...   return MaskedTensor(x.values + y.values, x.mask & y.mask)

  >>> mt = tf.add(MaskedTensor([1, 2], [True, False]), MaskedTensor(10, True))
  >>> print(f"values={mt.values.numpy()}, mask={mt.mask.numpy()}")
  values=[11 12], mask=[ True False]

  If multiple type signatures are specified, then the dispatch target will be
  called if any of the signatures match.  For example, the following code
  registers `masked_add` to be called if `x` is a `MaskedTensor` *or* `y` is
  a `MaskedTensor`.

  >>> @dispatch_for_api(tf.math.add, {'x': MaskedTensor}, {'y':MaskedTensor})
  ... def masked_add(x, y):
  ...   x_values = x.values if isinstance(x, MaskedTensor) else x
  ...   x_mask = x.mask if isinstance(x, MaskedTensor) else True
  ...   y_values = y.values if isinstance(y, MaskedTensor) else y
  ...   y_mask = y.mask if isinstance(y, MaskedTensor) else True
  ...   return MaskedTensor(x_values + y_values, x_mask & y_mask)

  The type annotations in type signatures may be type objects (e.g.,
  `MaskedTensor`), `typing.List` values, or `typing.Union` values.  For
  example, the following will register `masked_concat` to be called if `values`
  is a list of `MaskedTensor` values:

  >>> @dispatch_for_api(tf.concat, {'values': typing.List[MaskedTensor]})
  ... def masked_concat(values, axis):
  ...   return MaskedTensor(tf.concat([v.values for v in values], axis),
  ...                       tf.concat([v.mask for v in values], axis))

  Each type signature must contain at least one subclass of `tf.CompositeTensor`
  (which includes subclasses of `tf.ExtensionType`), and dispatch will only be
  triggered if at least one type-annotated parameter contains a
  `CompositeTensor` value.  This rule avoids invoking dispatch in degenerate
  cases, such as the following examples:

  * `@dispatch_for_api(tf.concat, {'values': List[MaskedTensor]})`: Will not
    dispatch to the decorated dispatch target when the user calls
    `tf.concat([])`.

  * `@dispatch_for_api(tf.add, {'x': Union[MaskedTensor, Tensor], 'y':
    Union[MaskedTensor, Tensor]})`: Will not dispatch to the decorated dispatch
    target when the user calls `tf.add(tf.constant(1), tf.constant(2))`.

  The dispatch target's signature must match the signature of the API that is
  being overridden.  In particular, parameters must have the same names, and
  must occur in the same order.  The dispatch target may optionally elide the
  "name" parameter, in which case it will be wrapped with a call to
  `tf.name_scope` when appropriate.

  Args:
    api: The TensorFlow API to override.
    *signatures: Dictionaries mapping parameter names or indices to type
      annotations, specifying when the dispatch target should be called.  In
      particular, the dispatch target will be called if any signature matches;
      and a signature matches if all of the specified parameters have types that
      match with the indicated type annotations.  If no signatures are
      specified, then a signature will be read from the dispatch target
      function's type annotations.

  Returns:
    A decorator that overrides the default implementation for `api`.

  #### Registered APIs

  The TensorFlow APIs that may be overridden by `@dispatch_for_api` are:

  <<API_LIST>>
  """
  dispatcher = getattr(api, TYPE_BASED_DISPATCH_ATTR, None)
  if dispatcher is None:
    raise ValueError(f"{api} does not support dispatch.")

  api_signature = tf_inspect.signature(api)
  # Explicit signatures are compiled eagerly, so malformed signatures are
  # reported when `dispatch_for_api` itself is called.
  signature_checkers = [
      _make_signature_checker(api_signature, signature)
      for signature in signatures
  ]

  def decorator(dispatch_target):
    """Decorator that registers the given dispatch target."""
    if not callable(dispatch_target):
      raise TypeError("Expected dispatch_target to be callable; "
                      f"got {dispatch_target!r}")
    # Add the "name" parameter (if needed) before checking the signature, so
    # the check sees the wrapped function's signature.
    dispatch_target = _add_name_scope_wrapper(dispatch_target, api_signature)
    _check_signature(api_signature, dispatch_target)

    for signature_checker in signature_checkers:
      dispatcher.Register(signature_checker, dispatch_target)
    _TYPE_BASED_DISPATCH_SIGNATURES[api][dispatch_target].extend(signatures)

    if not signature_checkers:
      # No explicit signatures: derive one from the target's annotations.
      signature = _signature_from_annotations(dispatch_target)
      checker = _make_signature_checker(api_signature, signature)
      dispatcher.Register(checker, dispatch_target)
      _TYPE_BASED_DISPATCH_SIGNATURES[api][dispatch_target].append(signature)

    return dispatch_target

  return decorator
|
|
|
|
|
|
# Registry of type-based dispatch registrations, as a nested mapping
# `api_func` -> `dispatch_target` -> `List[signature]`.  It can be used for
# documentation generation and for improved error messages when APIs are
# called with unsupported types.
_TYPE_BASED_DISPATCH_SIGNATURES = dict()
|
|
|
|
|
|
def apis_with_type_based_dispatch():
  """Returns a list of TensorFlow APIs that support type-based dispatch.

  The APIs are sorted by their fully qualified name (`module.name`).
  """

  def qualified_name(api):
    return f"{api.__module__}.{api.__name__}"

  return sorted(_TYPE_BASED_DISPATCH_SIGNATURES, key=qualified_name)
|
|
|
|
|
|
def type_based_dispatch_signatures_for(cls):
  """Returns dispatch signatures that have been registered for a given class.

  This function is intended for documentation-generation purposes.

  Args:
    cls: The class to search for.  Type signatures are searched recursively,
      so e.g., if `cls=RaggedTensor`, then information will be returned for
      all dispatch targets that have `RaggedTensor` anywhere in their type
      annotations (including nested in `typing.Union` or `typing.List`.)

  Returns:
    A `dict` mapping `api` -> `signatures`, where `api` is a TensorFlow API
    function; and `signatures` is a list of dispatch signatures for `api`
    that include `cls`.  (Each signature is a dict mapping argument names to
    type annotations; see `dispatch_for_api` for more info.)
  """

  def contains_cls(x):
    """Returns true if `x` contains `cls`."""
    if isinstance(x, dict):
      return any(contains_cls(v) for v in x.values())
    if x is cls:
      return True
    if (type_annotations.is_generic_list(x) or
        type_annotations.is_generic_union(x)):
      return any(contains_cls(arg)
                 for arg in type_annotations.get_generic_type_args(x))
    return False

  result = {}
  for api, target_to_signatures in _TYPE_BASED_DISPATCH_SIGNATURES.items():
    for signatures in target_to_signatures.values():
      matching = [sig for sig in signatures if contains_cls(sig)]
      if matching:
        result.setdefault(api, []).extend(matching)
  return result
|
|
|
|
|
|
# TODO(edloper): Consider using a mechanism like this to automatically add
# the `name` argument to all TensorFlow APIs that are implemented in Python
# (so each Python function doesn't need to do it manually).
def _add_name_scope_wrapper(func, api_signature):
  """Wraps `func` to expect a "name" arg, and use it to call `ops.name_scope`.

  If `func` already expects a "name" arg, or if `api_signature` does not
  expect a "name" arg, then returns `func` as-is.

  Args:
    func: The function to wrap.  Signature must match `api_signature` (except
      the "name" parameter may be missing).
    api_signature: The signature of the original API (used to find the index
      for the "name" parameter).

  Returns:
    The wrapped function (or the original function if no wrapping is needed).
  """
  if "name" not in api_signature.parameters:
    return func  # No wrapping needed (API has no name parameter).

  func_signature = tf_inspect.signature(func)
  func_argspec = tf_inspect.getargspec(func)
  if "name" in func_signature.parameters or func_argspec.keywords is not None:
    return func  # No wrapping needed (already has name parameter).

  name_index = list(api_signature.parameters).index("name")

  def wrapped_func(*args, **kwargs):
    # `name` is consumed here rather than forwarded, since `func` does not
    # accept it.  A positional `name` sits at its index in the API signature.
    if name_index < len(args):
      name = args[name_index]
      args = args[:name_index] + args[name_index + 1:]
    else:
      name = kwargs.pop("name", None)
    if name is None:
      return func(*args, **kwargs)
    with ops.name_scope(name):
      return func(*args, **kwargs)

  wrapped_func = tf_decorator.make_decorator(func, wrapped_func)
  # Advertise "name" at the same position it has in the API signature, which
  # is also where `wrapped_func` reads it from positionally.  When "name" is
  # the API's last parameter (the usual case) this is the same as appending.
  wrapped_params = list(func_signature.parameters.values())
  wrapped_params.insert(name_index, api_signature.parameters["name"])
  wrapped_func.__signature__ = func_signature.replace(parameters=wrapped_params)
  # NOTE(review): presumably removed so TFDecorator-aware inspection reports
  # `__signature__` instead of unwrapping back to `func` -- confirm.
  del wrapped_func._tf_decorator  # pylint: disable=protected-access
  return wrapped_func
|
|
|
|
|
|
@tf_export("experimental.unregister_dispatch_for")
def unregister_dispatch_for(dispatch_target):
  """Unregisters a function that was registered with `@dispatch_for_*`.

  This is primarily intended for testing purposes.

  Example:

  >>> # Define a type and register a dispatcher to override `tf.abs`:
  >>> class MyTensor(tf.experimental.ExtensionType):
  ...   value: tf.Tensor
  >>> @tf.experimental.dispatch_for_api(tf.abs)
  ... def my_abs(x: MyTensor):
  ...   return MyTensor(tf.abs(x.value))
  >>> tf.abs(MyTensor(5))
  MyTensor(value=<tf.Tensor: shape=(), dtype=int32, numpy=5>)

  >>> # Unregister the dispatcher, so `tf.abs` no longer calls `my_abs`.
  >>> unregister_dispatch_for(my_abs)
  >>> tf.abs(MyTensor(5))
  Traceback (most recent call last):
  ...
  ValueError: Attempt to convert a value ... to a Tensor.

  Args:
    dispatch_target: The function to unregister.

  Raises:
    ValueError: If `dispatch_target` was not registered using
      `@dispatch_for_api`, `@dispatch_for_unary_elementwise_apis`, or
      `@dispatch_for_binary_elementwise_apis`.
  """
  found = False

  # Check if dispatch_target registered by `@dispatch_for_api`
  for api, target_to_signatures in _TYPE_BASED_DISPATCH_SIGNATURES.items():
    if dispatch_target in target_to_signatures:
      dispatcher = getattr(api, TYPE_BASED_DISPATCH_ATTR)
      dispatcher.Unregister(dispatch_target)
      del target_to_signatures[dispatch_target]
      found = True

  # Check if dispatch_target registered by `@dispatch_for_*_elementwise_apis`
  elementwise_keys_to_delete = [
      key for (key, handler) in _ELEMENTWISE_API_HANDLERS.items()
      if handler is dispatch_target
  ]
  # The keys come from a dict, so they are already unique; iterating the list
  # (rather than a set) keeps the unregistration order deterministic.
  for key in elementwise_keys_to_delete:
    # Unregister every `(api, dispatch_target)` pair recorded for this key.
    for _, target in _ELEMENTWISE_API_TARGETS[key]:
      unregister_dispatch_for(target)
    del _ELEMENTWISE_API_HANDLERS[key]
    del _ELEMENTWISE_API_TARGETS[key]
    found = True

  if not found:
    raise ValueError(f"Function {dispatch_target} was not registered using "
                     "a `@dispatch_for_*` decorator.")
|
|
|
|
|
|
def register_dispatchable_type(cls):
  """Class decorator registering `cls` as a type usable with type-based dispatch.

  Do *not* apply this to subclasses of `CompositeTensor` or `ExtensionType`;
  those are registered automatically.

  Note: this function exists to support internal legacy use cases (such as
  RaggedTensorValue), and will probably not be exposed as a public API.

  Args:
    cls: The class to register.

  Returns:
    `cls`, unchanged.
  """
  register = _api_dispatcher.register_dispatchable_type
  register(cls)
  return cls
|
|
|
|
|
|
def add_type_based_api_dispatcher(target):
  """Adds a PythonAPIDispatcher to the given TensorFlow API function.

  APIs whose unwrapped implementation takes `*args` or `**kwargs` are returned
  unchanged, without a dispatcher.

  Args:
    target: The TensorFlow API function.

  Returns:
    `target`.

  Raises:
    ValueError: If `target` already has a type-based API dispatcher.
  """
  if hasattr(target, TYPE_BASED_DISPATCH_ATTR):
    raise ValueError(f"{target} already has a type-based API dispatcher.")

  _, unwrapped = tf_decorator.unwrap(target)
  argspec = tf_inspect.getargspec(unwrapped)
  if argspec.varargs or argspec.keywords:
    # @TODO(b/194903203) Add v2 dispatch support for APIs that take varargs
    # and keywords.  Examples of APIs that take varargs and kwargs: meshgrid,
    # einsum, map_values, map_flat_values.
    return target

  api_dispatcher = _api_dispatcher.PythonAPIDispatcher(
      unwrapped.__name__, argspec.args, argspec.defaults)
  setattr(target, TYPE_BASED_DISPATCH_ATTR, api_dispatcher)
  # Start an empty per-target signature registry for this API.
  _TYPE_BASED_DISPATCH_SIGNATURES[target] = collections.defaultdict(list)
  return target
|
|
|
|
|
|
def _check_signature(api_signature, func):
  """Checks that a dispatch target's signature is compatible with an API.

  Args:
    api_signature: The signature of the TensorFlow API.
    func: The dispatch target.

  Raises:
    ValueError: if the signatures are incompatible.  Two signatures are
      compatible when they have the same number of parameters and every pair
      of corresponding parameters agrees on `name` and `kind`.  (Defaults and
      annotations may differ.)
  """
  # Special case: a target whose signature is exactly (*args, **kwargs) is
  # accepted as-is.
  argspec = tf_inspect.getargspec(func)
  if (argspec.varargs is not None and argspec.keywords is not None and
      not argspec.args):
    return

  func_signature = tf_inspect.signature(func)
  api_params = list(api_signature.parameters.values())
  func_params = list(func_signature.parameters.values())
  compatible = len(api_params) == len(func_params) and all(
      api_param.name == func_param.name and api_param.kind == func_param.kind
      for api_param, func_param in zip(api_params, func_params))
  if not compatible:
    raise ValueError(f"Dispatch function's signature {func_signature} does "
                     f"not match API's signature {api_signature}.")
|
|
|
|
|
|
def _make_signature_checker(api_signature, signature):
  """Builds a PySignatureChecker for the given type signature.

  Args:
    api_signature: The `inspect.Signature` of the API whose signature is
      being checked.
    signature: Dictionary mapping parameter names (or positional indices) to
      type annotations.

  Returns:
    A `PySignatureChecker`.

  Raises:
    TypeError: If `signature` is not a dict keyed by `str` or `int`.
    ValueError: If `signature` names an unknown parameter, or annotates a
      parameter that is not positional.
  """
  is_valid_signature = isinstance(signature, dict) and all(
      isinstance(key, (str, int)) for key in signature)
  if not is_valid_signature:
    raise TypeError("signatures must be dictionaries mapping parameter names "
                    "to type annotations.")

  param_names = list(api_signature.parameters)
  checkers = []
  for param_name, param_type in signature.items():
    # Convert positional parameters to named parameters.
    if isinstance(param_name, int) and param_name < len(param_names):
      param_name = param_names[param_name]

    # Check that the parameter exists, and has an appropriate kind.
    param = api_signature.parameters.get(param_name)
    if param is None:
      raise ValueError("signature includes annotation for unknown "
                       f"parameter {param_name!r}.")
    if param.kind not in (tf_inspect.Parameter.POSITIONAL_ONLY,
                          tf_inspect.Parameter.POSITIONAL_OR_KEYWORD):
      raise ValueError("Dispatch currently only supports type annotations "
                       "for positional parameters; can't handle annotation "
                       f"for {param.kind!r} parameter {param_name}.")

    checkers.append(
        (param_names.index(param_name), make_type_checker(param_type)))

  return _api_dispatcher.PySignatureChecker(checkers)
|
|
|
|
|
|
# Cache of InstanceTypeChecker objects, keyed by a type or a tuple of types.
# Only one checker is created per key, because each checker keeps an internal
# cache that avoids repeated calls back into Python's isinstance.
_is_instance_checker_cache = dict()
|
|
|
|
|
|
def make_type_checker(annotation):
  """Builds a PyTypeChecker for the given type annotation.

  Args:
    annotation: A type object, `None`, or a `Union[...]` / `List[...]` whose
      type arguments are themselves supported annotations.

  Returns:
    A type checker built with `_api_dispatcher`.

  Raises:
    AssertionError: If a `List[...]` annotation does not have exactly one
      type parameter.
    ValueError: If `annotation` is not a supported annotation.
  """
  if type_annotations.is_generic_union(annotation):
    type_args = type_annotations.get_generic_type_args(annotation)

    # If the union contains two or more simple types, then use a single
    # InstanceChecker to check them.  Sorting by `id` makes the cache key
    # independent of the order in which the union lists its types.
    simple_types = [t for t in type_args if isinstance(t, type)]
    simple_types = tuple(sorted(simple_types, key=id))
    if len(simple_types) > 1:
      if simple_types not in _is_instance_checker_cache:
        checker = _api_dispatcher.MakeInstanceChecker(*simple_types)
        _is_instance_checker_cache[simple_types] = checker
      options = ([_is_instance_checker_cache[simple_types]] +
                 [make_type_checker(t) for t in type_args
                  if not isinstance(t, type)])
      return _api_dispatcher.MakeUnionChecker(options)

    options = [make_type_checker(t) for t in type_args]
    return _api_dispatcher.MakeUnionChecker(options)

  elif type_annotations.is_generic_list(annotation):
    type_args = type_annotations.get_generic_type_args(annotation)
    if len(type_args) != 1:
      raise AssertionError("Expected List[...] to have a single type parameter")
    elt_type = make_type_checker(type_args[0])
    return _api_dispatcher.MakeListChecker(elt_type)

  elif isinstance(annotation, type):
    if annotation not in _is_instance_checker_cache:
      checker = _api_dispatcher.MakeInstanceChecker(annotation)
      _is_instance_checker_cache[annotation] = checker
    return _is_instance_checker_cache[annotation]

  elif annotation is None:
    # `None` as an annotation means "the value is None".
    return make_type_checker(type(None))

  else:
    raise ValueError(f"Type annotation {annotation} is not currently supported"
                     " by dispatch. Supported annotations: type objects, "
                     "List[...], and Union[...]")
|
|
|
|
|
|
def _signature_from_annotations(func):
  """Builds a dict mapping from parameter names to type annotations.

  Args:
    func: The dispatch target whose parameter annotations should be read.

  Returns:
    A dict mapping each annotated parameter's name to its annotation.

  Raises:
    ValueError: If none of `func`'s parameters has a type annotation.
  """
  func_signature = tf_inspect.signature(func)

  # `Parameter.empty` is a sentinel, so it is compared by identity.
  signature = {
      name: param.annotation
      for name, param in func_signature.parameters.items()
      if param.annotation is not tf_inspect.Parameter.empty
  }
  if not signature:
    raise ValueError("The dispatch_for_api decorator must be called with at "
                     "least one signature, or applied to a function that "
                     "has type annotations on its parameters.")
  return signature
|
|
|
|
|
|
# Registries for elementwise APIs and API handlers.
#
# _*_ELEMENTWISE_APIS: Lists of TensorFlow APIs that have been registered
# as elementwise operations using the `register_*_elementwise_api`
# decorators.
#
# _ELEMENTWISE_API_HANDLERS: Dict mapping from argument type(s) to API
# handlers that have been registered with the `dispatch_for_*_elementwise_apis`
# decorators.
#
# _ELEMENTWISE_API_TARGETS: Dict mapping from argument type(s) to lists of
# `(api, dispatch_target)` pairs.  Used by `unregister_dispatch_for` to
# unregister every dispatch target created for an elementwise API handler.
_UNARY_ELEMENTWISE_APIS = []
_BINARY_ELEMENTWISE_APIS = []
_BINARY_ELEMENTWISE_ASSERT_APIS = []
_ELEMENTWISE_API_HANDLERS = {}
_ELEMENTWISE_API_TARGETS = {}

# NOTE(review): presumably a marker for binary elementwise *assert* APIs (see
# `_BINARY_ELEMENTWISE_ASSERT_APIS`); its uses are outside this view -- confirm.
_ASSERT_API_TAG = "ASSERT_API_TAG"
|
|
|
|
|
|
@tf_export("experimental.dispatch_for_unary_elementwise_apis")
def dispatch_for_unary_elementwise_apis(x_type):
  """Decorator to override default implementation for unary elementwise APIs.

  The decorated function (known as the "elementwise api handler") overrides
  the default implementation for any unary elementwise API whenever the value
  for the first argument (typically named `x`) matches the type annotation
  `x_type`. The elementwise api handler is called with two arguments:

  `elementwise_api_handler(api_func, x)`

  Where `api_func` is a function that takes a single parameter and performs the
  elementwise operation (e.g., `tf.abs`), and `x` is the first argument to the
  elementwise api.

  The following example shows how this decorator can be used to update all
  unary elementwise operations to handle a `MaskedTensor` type:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_unary_elementwise_apis(MaskedTensor)
  ... def unary_elementwise_api_handler(api_func, x):
  ...   return MaskedTensor(api_func(x.values), x.mask)
  >>> mt = MaskedTensor([1, -2, -3], [True, False, True])
  >>> abs_mt = tf.abs(mt)
  >>> print(f"values={abs_mt.values.numpy()}, mask={abs_mt.mask.numpy()}")
  values=[1 2 3], mask=[ True False True]

  For unary elementwise operations that take extra arguments beyond `x`, those
  arguments are *not* passed to the elementwise api handler, but are
  automatically added when `api_func` is called. E.g., in the following
  example, the `dtype` parameter is not passed to
  `unary_elementwise_api_handler`, but is added by `api_func`.

  >>> ones_mt = tf.ones_like(mt, dtype=tf.float32)
  >>> print(f"values={ones_mt.values.numpy()}, mask={ones_mt.mask.numpy()}")
  values=[1.0 1.0 1.0], mask=[ True False True]

  Args:
    x_type: A type annotation indicating when the api handler should be called.
      See `dispatch_for_api` for a list of supported annotation types.

  Returns:
    A decorator.

  #### Registered APIs

  The unary elementwise APIs are:

  <<API_LIST>>
  """

  def register_handler(handler):
    # Unary handlers are keyed by a 1-tuple of the argument type.
    handler_key = (x_type,)
    if handler_key in _ELEMENTWISE_API_HANDLERS:
      previous_handler = _ELEMENTWISE_API_HANDLERS[handler_key]
      raise ValueError("A unary elementwise dispatch handler "
                       f"({previous_handler}) "
                       f"has already been registered for {x_type}.")
    _ELEMENTWISE_API_HANDLERS[handler_key] = handler

    # Attach the handler to every unary elementwise API registered so far;
    # APIs registered later pick it up in `register_unary_elementwise_api`.
    for unary_api in _UNARY_ELEMENTWISE_APIS:
      _add_dispatch_for_unary_elementwise_api(unary_api, x_type, handler)
    return handler

  return register_handler
|
|
|
|
|
|
@tf_export("experimental.dispatch_for_binary_elementwise_apis")
def dispatch_for_binary_elementwise_apis(x_type, y_type):
  """Decorator to override default implementation for binary elementwise APIs.

  The decorated function (known as the "elementwise api handler") overrides
  the default implementation for any binary elementwise API whenever the value
  for the first two arguments (typically named `x` and `y`) match the specified
  type annotations. The elementwise api handler is called with three arguments:

  `elementwise_api_handler(api_func, x, y)`

  Where `x` and `y` are the first two arguments to the elementwise api, and
  `api_func` is a TensorFlow function that takes two parameters and performs the
  elementwise operation (e.g., `tf.add`).

  The following example shows how this decorator can be used to update all
  binary elementwise operations to handle a `MaskedTensor` type:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
  ... def binary_elementwise_api_handler(api_func, x, y):
  ...   return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
  >>> a = MaskedTensor([1, 2, 3, 4, 5], [True, True, True, True, False])
  >>> b = MaskedTensor([2, 4, 6, 8, 0], [True, True, True, False, True])
  >>> c = tf.add(a, b)
  >>> print(f"values={c.values.numpy()}, mask={c.mask.numpy()}")
  values=[ 3 6 9 12 5], mask=[ True True True False False]

  Args:
    x_type: A type annotation that the first argument (`x`) must match for the
      api handler to be called.
    y_type: A type annotation that the second argument (`y`) must match for the
      api handler to be called.

  Returns:
    A decorator.

  #### Registered APIs

  The binary elementwise APIs are:

  <<API_LIST>>
  """

  def decorator(handler):
    # Binary handlers are keyed by the `(x_type, y_type)` pair.
    handler_key = (x_type, y_type)
    if handler_key in _ELEMENTWISE_API_HANDLERS:
      raise ValueError("A binary elementwise dispatch handler "
                       f"({_ELEMENTWISE_API_HANDLERS[handler_key]}) "
                       f"has already been registered for ({x_type}, {y_type}).")
    _ELEMENTWISE_API_HANDLERS[handler_key] = handler

    # Attach the handler to every binary elementwise API registered so far;
    # APIs registered later pick it up in `register_binary_elementwise_api`.
    for api in _BINARY_ELEMENTWISE_APIS:
      _add_dispatch_for_binary_elementwise_api(api, x_type, y_type, handler)

    return handler

  return decorator
|
|
|
|
|
|
@tf_export("experimental.dispatch_for_binary_elementwise_assert_apis")
def dispatch_for_binary_elementwise_assert_apis(x_type, y_type):
  """Decorator to override default implementation for binary elementwise assert APIs.

  The decorated function (known as the "elementwise assert handler")
  overrides the default implementation for any binary elementwise assert API
  whenever the value for the first two arguments (typically named `x` and `y`)
  match the specified type annotations. The handler is called with three
  arguments:

  `elementwise_assert_handler(assert_func, x, y)`

  Where `x` and `y` are the first two arguments to the binary elementwise assert
  operation, and `assert_func` is a TensorFlow function that takes two
  parameters and performs the elementwise assert operation (e.g.,
  `tf.debugging.assert_equal`).

  The following example shows how this decorator can be used to update all
  binary elementwise assert operations to handle a `MaskedTensor` type:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_binary_elementwise_assert_apis(MaskedTensor, MaskedTensor)
  ... def binary_elementwise_assert_api_handler(assert_func, x, y):
  ...   merged_mask = tf.logical_and(x.mask, y.mask)
  ...   selected_x_values = tf.boolean_mask(x.values, merged_mask)
  ...   selected_y_values = tf.boolean_mask(y.values, merged_mask)
  ...   assert_func(selected_x_values, selected_y_values)
  >>> a = MaskedTensor([1, 1, 0, 1, 1], [False, False, True, True, True])
  >>> b = MaskedTensor([2, 2, 0, 2, 2], [True, True, True, False, False])
  >>> tf.debugging.assert_equal(a, b) # assert passed; no exception was thrown

  >>> a = MaskedTensor([1, 1, 1, 1, 1], [True, True, True, True, True])
  >>> b = MaskedTensor([0, 0, 0, 0, 2], [True, True, True, True, True])
  >>> tf.debugging.assert_greater(a, b)
  Traceback (most recent call last):
  ...
  InvalidArgumentError: Condition x > y did not hold.

  Args:
    x_type: A type annotation that the first argument (`x`) must match for the
      assert handler to be called.
    y_type: A type annotation that the second argument (`y`) must match for the
      assert handler to be called.

  Returns:
    A decorator.

  #### Registered APIs

  The binary elementwise assert APIs are:

  <<API_LIST>>
  """

  def decorator(handler):
    # The trailing tag keeps assert handlers apart from plain binary handlers
    # registered for the same `(x_type, y_type)` pair.
    api_handler_key = (x_type, y_type, _ASSERT_API_TAG)
    if api_handler_key in _ELEMENTWISE_API_HANDLERS:
      raise ValueError("A binary elementwise assert dispatch handler "
                       f"({_ELEMENTWISE_API_HANDLERS[api_handler_key]}) "
                       f"has already been registered for ({x_type}, {y_type}).")
    _ELEMENTWISE_API_HANDLERS[api_handler_key] = handler

    # Attach the handler to every assert API registered so far; APIs
    # registered later pick it up in `register_binary_elementwise_assert_api`.
    for api in _BINARY_ELEMENTWISE_ASSERT_APIS:
      _add_dispatch_for_binary_elementwise_api(api, x_type, y_type, handler)

    return handler

  return decorator
|
|
|
|
|
|
def register_unary_elementwise_api(func):
  """Decorator that registers a TensorFlow op as a unary elementwise API.

  The op is appended to the unary API registry, and every unary handler that
  was already registered via `dispatch_for_unary_elementwise_apis` is attached
  to it right away.

  Args:
    func: The unary elementwise API to register.

  Returns:
    `func`, unchanged.
  """
  _UNARY_ELEMENTWISE_APIS.append(func)
  # Unary handlers are the ones keyed by a 1-tuple `(x_type,)`.
  existing_handlers = [(key[0], handler)
                       for key, handler in _ELEMENTWISE_API_HANDLERS.items()
                       if len(key) == 1]
  for x_type, handler in existing_handlers:
    _add_dispatch_for_unary_elementwise_api(func, x_type, handler)
  return func
|
|
|
|
|
|
def register_binary_elementwise_api(func):
  """Decorator that registers a TensorFlow op as a binary elementwise API.

  The op is appended to the binary API registry, and every binary handler that
  was already registered via `dispatch_for_binary_elementwise_apis` is attached
  to it right away.

  Args:
    func: The binary elementwise API to register.

  Returns:
    `func`, unchanged.
  """
  _BINARY_ELEMENTWISE_APIS.append(func)
  # Binary handlers are the ones keyed by a 2-tuple `(x_type, y_type)`.
  existing_handlers = [(key[0], key[1], handler)
                       for key, handler in _ELEMENTWISE_API_HANDLERS.items()
                       if len(key) == 2]
  for x_type, y_type, handler in existing_handlers:
    _add_dispatch_for_binary_elementwise_api(func, x_type, y_type, handler)
  return func
|
|
|
|
|
|
def register_binary_elementwise_assert_api(func):
  """Decorator that registers a TensorFlow op as a binary elementwise assert API.

  Different from `register_binary_elementwise_api`, this decorator is used
  for assert apis, such as assert_equal, assert_none_equal, etc, which return
  None in eager mode and an op in graph mode.

  Only handlers registered with `dispatch_for_binary_elementwise_assert_apis`
  (keyed by `(x_type, y_type, _ASSERT_API_TAG)`) are attached to `func`; plain
  binary elementwise handlers are not.

  Args:
    func: The function that implements the binary elementwise assert API.

  Returns:
    `func`
  """
  _BINARY_ELEMENTWISE_ASSERT_APIS.append(func)
  for args, handler in _ELEMENTWISE_API_HANDLERS.items():
    # Compare the tag by value rather than by object identity.
    if len(args) == 3 and args[2] == _ASSERT_API_TAG:
      _add_dispatch_for_binary_elementwise_api(func, args[0], args[1], handler)
  return func
|
|
|
|
|
|
def unary_elementwise_apis():
  """Returns a tuple of APIs that have been registered as unary elementwise.

  A tuple snapshot is returned, so callers cannot mutate the registry.
  """
  return tuple(_UNARY_ELEMENTWISE_APIS)
|
|
|
|
|
|
def binary_elementwise_apis():
  """Returns a tuple of APIs that have been registered as binary elementwise.

  A tuple snapshot is returned, so callers cannot mutate the registry.
  """
  return tuple(_BINARY_ELEMENTWISE_APIS)
|
|
|
|
|
|
def _add_dispatch_for_unary_elementwise_api(api, x_type,
                                            elementwise_api_handler):
  """Registers a unary elementwise handler as a dispatcher for a given API.

  Creates a `dispatch_target` that is registered with `dispatch_for_api` for
  `api` whenever its first parameter matches `x_type`. The target strips out
  the `name` argument (if any), separates `x` from the remaining arguments,
  and calls `elementwise_api_handler(tensor_api, x)`, where `tensor_api`
  re-applies the remaining arguments. The created target is also recorded in
  `_ELEMENTWISE_API_TARGETS[(x_type,)]`.

  Args:
    api: The unary elementwise API to add dispatch to.
    x_type: Type annotation the first argument of `api` must match.
    elementwise_api_handler: Handler called as `handler(tensor_api, x)`.
  """
  api_signature = tf_inspect.signature(api)
  # The first parameter of the API is treated as `x`, whatever it is named.
  x_name = list(api_signature.parameters)[0]
  name_index = _find_name_index(api_signature)

  # `api` can be handed to the handler directly only when it takes nothing but
  # `x` and `name`; otherwise the extra arguments must be re-bound via a lambda.
  need_to_bind_api_args = (
      len(api_signature.parameters) > 2 or
      "name" not in api_signature.parameters)

  @dispatch_for_api(api, {x_name: x_type})
  def dispatch_target(*args, **kwargs):
    args, kwargs, name = _extract_name_arg(args, kwargs, name_index)
    if args:
      x, args = args[0], args[1:]
    else:
      # NOTE(review): raises KeyError if `x` was not passed at all; presumably
      # unreachable because dispatch only fires when `x` matched `x_type`.
      x = kwargs.pop(x_name)

    if need_to_bind_api_args:
      # Closes over this call's remaining `args`/`kwargs`, so the handler only
      # needs to supply the (transformed) `x`.
      tensor_api = lambda v: api(v, *args, **kwargs)
    else:
      tensor_api = api

    # `name` is not forwarded to `api`; it is applied as a name scope instead.
    if name is None:
      return elementwise_api_handler(tensor_api, x)
    else:
      with ops.name_scope(name, None, [x]):
        return elementwise_api_handler(tensor_api, x)

  dispatch_target.__name__ = "elementwise_dispatch_target_for_" + api.__name__
  dispatch_target.__qualname__ = dispatch_target.__name__
  # Keep track of what targets we've registered (so we can unregister them).
  target_list = _ELEMENTWISE_API_TARGETS.setdefault((x_type,), [])
  target_list.append((api, dispatch_target))
|
|
|
|
|
|
def _add_dispatch_for_binary_elementwise_api(api, x_type, y_type,
                                             elementwise_api_handler):
  """Registers a binary elementwise handler as a dispatcher for a given API.

  Creates a `dispatch_target` that is registered with `dispatch_for_api` for
  `api` whenever its first two parameters match `x_type` and `y_type`. The
  target strips out the `name` argument (if any), separates `x` and `y` from
  the remaining arguments, and calls `elementwise_api_handler(tensor_api, x,
  y)`, where `tensor_api` re-applies the remaining arguments. The created
  target is recorded in `_ELEMENTWISE_API_TARGETS[(x_type, y_type)]`. This
  helper is used both for binary elementwise APIs and for binary elementwise
  assert APIs.

  Args:
    api: The binary elementwise (or assert) API to add dispatch to.
    x_type: Type annotation the first argument of `api` must match.
    y_type: Type annotation the second argument of `api` must match.
    elementwise_api_handler: Handler called as `handler(tensor_api, x, y)`.
  """
  api_signature = tf_inspect.signature(api)
  # The first two parameters of the API are treated as `x` and `y`.
  x_name, y_name = list(api_signature.parameters)[:2]
  name_index = _find_name_index(api_signature)

  # `api` can be handed to the handler directly only when it takes nothing but
  # `x`, `y` and `name`; otherwise the extra arguments must be re-bound.
  need_to_bind_api_args = (len(api_signature.parameters) > 3 or
                           "name" not in api_signature.parameters)

  @dispatch_for_api(api, {x_name: x_type, y_name: y_type})
  def dispatch_target(*args, **kwargs):
    args, kwargs, name = _extract_name_arg(args, kwargs, name_index)
    # `x` and `y` may each arrive positionally or as keywords; a missing one
    # defaults to None.
    if len(args) > 1:
      x, y, args = args[0], args[1], args[2:]
    elif args:
      x, args = args[0], args[1:]
      y = kwargs.pop(y_name, None)
    else:
      x = kwargs.pop(x_name, None)
      y = kwargs.pop(y_name, None)

    if need_to_bind_api_args:
      # Closes over this call's remaining `args`/`kwargs`, so the handler only
      # needs to supply the (transformed) `x` and `y`.
      tensor_api = lambda v1, v2: api(v1, v2, *args, **kwargs)
    else:
      tensor_api = api

    # `name` is not forwarded to `api`; it is applied as a name scope instead.
    if name is None:
      return elementwise_api_handler(tensor_api, x, y)
    else:
      with ops.name_scope(name, None, [x, y]):
        return elementwise_api_handler(tensor_api, x, y)

  dispatch_target.__name__ = "elementwise_dispatch_target_for_" + api.__name__
  dispatch_target.__qualname__ = dispatch_target.__name__
  # Keep track of what targets we've registered (so we can unregister them).
  target_list = _ELEMENTWISE_API_TARGETS.setdefault((x_type, y_type), [])
  target_list.append((api, dispatch_target))
|
|
|
|
|
|
def _find_name_index(signature):
  """Locates the `name` parameter in `signature`.

  Args:
    signature: A signature object exposing an ordered `parameters` mapping.

  Returns:
    The zero-based position of the parameter called `name`, or -1 when the
    signature has no such parameter.
  """
  for position, param_name in enumerate(signature.parameters):
    if param_name == "name":
      return position
  return -1
|
|
|
|
|
|
def _extract_name_arg(args, kwargs, name_index):
  """Pulls the `name` argument out of a call's arguments.

  Args:
    args: The positional arguments (a sliceable sequence).
    kwargs: The keyword arguments. If `name` is taken from here, it is popped
      from this dict in place.
    name_index: Position of the `name` parameter, or a negative value if the
      API has no `name` parameter.

  Returns:
    A tuple `(args, kwargs, name_value)`, where `name` has been removed from
    `args`/`kwargs` and `name_value` is its value (None if not supplied).
  """
  if name_index < 0:
    # The API has no `name` parameter; leave the arguments untouched.
    return args, kwargs, None
  if name_index >= len(args):
    # `name` was not passed positionally; look for it as a keyword.
    return args, kwargs, kwargs.pop("name", None)
  name_value = args[name_index]
  remaining_args = args[:name_index] + args[name_index + 1:]
  return remaining_args, kwargs, name_value
|
|
|
|
|
|
def update_docstrings_with_api_lists():
  """Updates the docstrings of dispatch decorators with API lists.

  Updates docstrings for `dispatch_for_api`,
  `dispatch_for_unary_elementwise_apis`,
  `dispatch_for_binary_elementwise_apis`, and
  `dispatch_for_binary_elementwise_assert_apis`, by replacing the string
  '<<API_LIST>>' with a list of APIs that have been registered for that
  decorator.
  """
  decorators_and_apis = (
      (dispatch_for_unary_elementwise_apis, _UNARY_ELEMENTWISE_APIS),
      (dispatch_for_binary_elementwise_apis, _BINARY_ELEMENTWISE_APIS),
      (dispatch_for_binary_elementwise_assert_apis,
       _BINARY_ELEMENTWISE_ASSERT_APIS),
      (dispatch_for_api, _TYPE_BASED_DISPATCH_SIGNATURES),
  )
  for decorator, api_list in decorators_and_apis:
    _update_docstring_with_api_list(decorator, api_list)
|
|
|
|
|
|
def _update_docstring_with_api_list(target, api_list):
  """Replaces `<<API_LIST>>` in target.__doc__ with the given list of APIs.

  Each API in `api_list` that has a canonical exported name is rendered as a
  bullet of the form "* `tf.<name>(<param names>)`"; the bullets are sorted
  and substituted for the `<<API_LIST>>` placeholder.

  Args:
    target: The function whose docstring should be updated.
    api_list: An iterable of API functions to list. APIs without a canonical
      exported name are skipped.
  """
  if target.__doc__ is None:
    # Docstrings are stripped when Python runs with `-OO`; there is nothing to
    # update, and calling `.replace` on None would raise AttributeError.
    return
  lines = []
  for func in api_list:
    name = tf_export_lib.get_canonical_name_for_symbol(
        func, add_prefix_to_v1_names=True)
    if name is not None:
      params = tf_inspect.signature(func).parameters.keys()
      lines.append(f"  * `tf.{name}({', '.join(params)})`")
  lines.sort()
  target.__doc__ = target.__doc__.replace(" <<API_LIST>>", "\n".join(lines))
|
|
|
|
|
|
################################################################################
|
|
# Dispatch Support
|
|
################################################################################
|
|
@tf_export("__internal__.dispatch.add_dispatch_support", v1=[])
def add_dispatch_support(target=None, iterable_parameters=None):
  """Decorator that adds a dispatch handling wrapper to a TensorFlow Python API.

  This wrapper adds the decorated function as an API that can be overridden
  using the `@dispatch_for_api` decorator. In the following example, we first
  define a new API (`double`) that supports dispatch, then define a custom type
  (`MaskedTensor`) and finally use `dispatch_for_api` to override the default
  implementation of `double` when called with `MaskedTensor` values:

  >>> @add_dispatch_support
  ... def double(x):
  ...   return x * 2
  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_api(double, {'x': MaskedTensor})
  ... def masked_double(x):
  ...   return MaskedTensor(x.values * 2, x.mask)

  The optional `iterable_parameters` argument can be used to mark parameters
  that can take arbitrary iterable values (such as generator expressions).
  These need to be handled specially during dispatch, since just iterating over
  an iterable uses up its values. In the following example, we define a new API
  whose second argument can be an iterable value; and then override the default
  implementation of that API when the iterable contains MaskedTensors:

  >>> @add_dispatch_support(iterable_parameters=['ys'])
  ... def add_tensor_to_list_of_tensors(x, ys):
  ...   return [x + y for y in ys]
  >>> @dispatch_for_api(add_tensor_to_list_of_tensors,
  ...                   {'ys': typing.List[MaskedTensor]})
  ... def masked_add_tensor_to_list_of_tensors(x, ys):
  ...   return [MaskedTensor(x+y.values, y.mask) for y in ys]

  (Note: the only TensorFlow API that currently supports iterables is `add_n`.)

  Args:
    target: The TensorFlow API that should support dispatch.
    iterable_parameters: Optional list of parameter names that may be called
      with iterables (such as the `inputs` parameter for `tf.add_n`).

  Returns:
    If `target` is None, a decorator; otherwise `target` wrapped with dispatch
    support.

  Raises:
    TypeError: If `iterable_parameters` is neither None nor a list or tuple of
      strings.
    ValueError: (When the decorator is applied.) If a name in
      `iterable_parameters` is not a positional argument of the decorated
      function (raised by `list.index`).
  """

  if not (iterable_parameters is None or
          (isinstance(iterable_parameters, (list, tuple)) and
           all(isinstance(p, str) for p in iterable_parameters))):
    raise TypeError("iterable_parameters should be a list or tuple of string.")

  def decorator(dispatch_target):

    # Get the name & index for each iterable parameter.
    if iterable_parameters is None:
      iterable_params = None
    else:
      arg_names = tf_inspect.getargspec(dispatch_target).args
      iterable_params = [
          (name, arg_names.index(name)) for name in iterable_parameters
      ]

    @traceback_utils.filter_traceback
    def op_dispatch_handler(*args, **kwargs):
      """Calls `dispatch_target`, performing dispatch when appropriate."""

      # Type-based dispatch system (dispatch v2):
      # `api_dispatcher` is a closure variable assigned further below, after
      # this function is wrapped; it is looked up each time the handler runs.
      if api_dispatcher is not None:
        if iterable_params is not None:
          # Materialize iterables into lists so that dispatch can inspect them
          # without consuming the values `dispatch_target` still needs.
          args, kwargs = replace_iterable_params(args, kwargs, iterable_params)
        result = api_dispatcher.Dispatch(args, kwargs)
        if result is not NotImplemented:
          return result

      # Fallback dispatch system (dispatch v1):
      try:
        return dispatch_target(*args, **kwargs)
      except (TypeError, ValueError):
        # Note: convert_to_eager_tensor currently raises a ValueError, not a
        # TypeError, when given unexpected types. So we need to catch both.
        # `op_dispatch_handler` here resolves at call time to the wrapper that
        # `tf_decorator.make_decorator` rebinds it to below.
        result = dispatch(op_dispatch_handler, args, kwargs)
        if result is not OpDispatcher.NOT_SUPPORTED:
          return result
        else:
          # No fallback dispatcher handled the call: re-raise the original
          # error.
          raise

    # Called before the rebinding below, so this receives the inner function
    # rather than the `make_decorator` wrapper.
    add_fallback_dispatch_list(op_dispatch_handler)
    op_dispatch_handler = tf_decorator.make_decorator(dispatch_target,
                                                      op_dispatch_handler)
    add_type_based_api_dispatcher(op_dispatch_handler)
    # Read back the type-based dispatcher attached to the wrapper (None if the
    # attribute is absent); this sets the closure variable used above.
    api_dispatcher = getattr(op_dispatch_handler, TYPE_BASED_DISPATCH_ATTR,
                             None)
    return op_dispatch_handler

  if target is None:
    return decorator
  else:
    return decorator(target)
|
|
|
|
|
|
def replace_iterable_params(args, kwargs, iterable_params):
  """Returns (args, kwargs) with any iterable parameters converted to lists.

  Iterables such as generators can only be consumed once, so they are
  materialized into lists; the returned values can then be inspected any
  number of times. The inputs are not modified: a new args tuple and a new
  kwargs dict are returned.

  Args:
    args: Positional arguments to a function.
    kwargs: Keyword arguments to a function. This dict is not modified.
    iterable_params: A list of (name, index) tuples for iterable parameters.

  Returns:
    A tuple (args, kwargs), where any positional or keyword parameters in
    `iterable_params` have their value converted to a `list`.
  """
  args = list(args)
  # Copy so that the caller's kwargs dict is not mutated as a side effect.
  kwargs = dict(kwargs)
  for name, index in iterable_params:
    if index < len(args):
      args[index] = list(args[index])
    elif name in kwargs:
      kwargs[name] = list(kwargs[name])
  return tuple(args), kwargs
|