# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Type-based dispatch for TensorFlow's Python APIs.

"Python APIs" refers to Python functions that have been exported with
`tf_export`, such as `tf.add` and `tf.linalg.matmul`; they are sometimes also
referred to as "ops".

There are currently two dispatch systems for TensorFlow:

  * The "fallback dispatch" system calls an API's standard implementation first,
    and only tries to perform dispatch if that standard implementation raises a
    TypeError (or ValueError) exception.

  * The "type-based dispatch" system checks the types of the parameters passed
    to an API, and performs dispatch if those types match any signatures that
    have been registered for dispatch.

The fallback dispatch system was the original dispatch system, but it was
somewhat brittle and had limitations, such as an inability to support dispatch
for some operations (like convert_to_tensor).  We plan to remove the fallback
dispatch system in favor of the type-based dispatch system, once all users have
been switched over to use it.

### Fallback Dispatch

The fallback dispatch system is based on "operation dispatchers", which can be
used to override the behavior for TensorFlow ops when they are called with
otherwise unsupported argument types.  In particular, when an operation is
called with arguments that would cause it to raise a TypeError, it falls back on
its registered operation dispatchers.  The dispatchers are tried in order, and
the result of the first one that can handle the arguments is returned.
Otherwise, the original TypeError is raised.

### Type-based Dispatch

The main interface for the type-based dispatch system is the `dispatch_for_api`
decorator, which overrides the default implementation for a TensorFlow API.
The decorated function (known as the "dispatch target") will override the
default implementation for the API when the API is called with parameters that
match a specified type signature.

### Dispatch Support

By default, dispatch support is added to the generated op wrappers for any
visible ops.  APIs/ops that are implemented in Python can opt in to dispatch
support using the `add_dispatch_support` decorator.
"""

import collections
import itertools
import typing  # pylint: disable=unused-import (used in doctests)

from tensorflow.python.framework import _pywrap_python_api_dispatcher as _api_dispatcher
from tensorflow.python.framework import ops
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export as tf_export_lib
from tensorflow.python.util import tf_inspect
from tensorflow.python.util import traceback_utils
from tensorflow.python.util import type_annotations
from tensorflow.python.util.tf_export import tf_export


# Private function attributes used to store dispatchers on TensorFlow APIs.
# `FALLBACK_DISPATCH_ATTR` holds a list of `OpDispatcher`s (installed by
# `add_fallback_dispatch_list`); `TYPE_BASED_DISPATCH_ATTR` holds a
# `PythonAPIDispatcher` (installed by `add_type_based_api_dispatcher`).
FALLBACK_DISPATCH_ATTR = "_tf_fallback_dispatchers"
TYPE_BASED_DISPATCH_ATTR = "_tf_type_based_dispatcher"

# OpDispatchers which should be used for all operations.
_GLOBAL_DISPATCHERS = []


################################################################################
# Fallback Dispatch
################################################################################


@tf_export("__internal__.dispatch.OpDispatcher", v1=[])
class OpDispatcher(object):
  """Abstract base class for TensorFlow operator dispatchers.

  Each operation dispatcher acts as an override handler for a single
  TensorFlow operation, and its results are used when the handler indicates
  that it can handle the operation's arguments (by returning any value other
  than `OpDispatcher.NOT_SUPPORTED`).
  """

  # Sentinel value that can be returned to indicate that an operation
  # dispatcher does not support a given set of arguments.
  NOT_SUPPORTED = object()

  def handle(self, args, kwargs):  # pylint: disable=unused-argument
    """Handle this dispatcher's operation with the specified arguments.

    If this operation dispatcher can handle the given arguments, then
    return an appropriate value (or raise an appropriate exception).

    Args:
      args: The arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `OpDispatcher.NOT_SUPPORTED` if this
      dispatcher can not handle the given arguments.
    """
    return self.NOT_SUPPORTED

  def register(self, op):
    """Register this dispatcher as a handler for `op`.

    Args:
      op: Python function: the TensorFlow operation that should be handled. Must
        have a dispatch list (which is added automatically for generated ops,
        and can be added to Python ops using the `add_dispatch_support`
        decorator).

    Raises:
      AssertionError: If `op` does not have a fallback dispatch list.
    """
    if not hasattr(op, FALLBACK_DISPATCH_ATTR):
      raise AssertionError("Dispatching not enabled for %s" % op)
    getattr(op, FALLBACK_DISPATCH_ATTR).append(self)


@tf_export("__internal__.dispatch.GlobalOpDispatcher", v1=[])
class GlobalOpDispatcher(object):
  """Abstract base class for TensorFlow global operator dispatchers."""

  NOT_SUPPORTED = OpDispatcher.NOT_SUPPORTED

  def handle(self, op, args, kwargs):  # pylint: disable=unused-argument
    """Handle the specified operation with the specified arguments.

    Subclasses should override this method.  The default implementation
    declines every operation.

    Args:
      op: The operation being dispatched.
      args: The arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `GlobalOpDispatcher.NOT_SUPPORTED` if this
      dispatcher can not handle the given arguments.
    """
    # Return the sentinel explicitly: `dispatch` treats any other value
    # (including an implicit `None`) as a successful result, so a base-class
    # instance would otherwise hijack every fallback dispatch.
    return self.NOT_SUPPORTED

  def register(self):
    """Register this dispatcher as a handler for all ops."""
    _GLOBAL_DISPATCHERS.append(self)


def dispatch(op, args, kwargs):
  """Returns the result from the first successful dispatcher for a given op.

  Calls the `handle` method of each `OpDispatcher` that has been registered
  to handle `op`, and returns the value from the first successful handler.
  If none succeeds, the global dispatchers (registered via
  `GlobalOpDispatcher.register`) are tried in the same way.

  Args:
    op: Python function: the operation to dispatch for.
    args: The arguments to the operation.
    kwargs: The keyword arguments to the operation.

  Returns:
    The result of the operation, or `NOT_SUPPORTED` if no registered
    dispatcher can handle the given arguments.
  """
  for dispatcher in getattr(op, FALLBACK_DISPATCH_ATTR):
    result = dispatcher.handle(args, kwargs)
    if result is not OpDispatcher.NOT_SUPPORTED:
      return result
  for dispatcher in _GLOBAL_DISPATCHERS:
    result = dispatcher.handle(op, args, kwargs)
    if result is not OpDispatcher.NOT_SUPPORTED:
      return result
  return OpDispatcher.NOT_SUPPORTED


class _TypeBasedDispatcher(OpDispatcher):
  """Dispatcher that handles op if any arguments have a specified type.

  Checks the types of the arguments and keyword arguments (including elements
  of lists or tuples), and if any argument values have the indicated type(s),
  then delegates to an override function.
  """

  def __init__(self, override_func, types):
    self._types = types
    self._override_func = override_func

  def _handles(self, args, kwargs):
    """Returns True if any (possibly list/tuple-nested) argument is a match."""
    for arg in itertools.chain(args, kwargs.values()):
      if (isinstance(arg, self._types) or
          (isinstance(arg, (list, tuple)) and
           any(isinstance(elt, self._types) for elt in arg))):
        return True
    return False

  def handle(self, args, kwargs):
    if self._handles(args, kwargs):
      return self._override_func(*args, **kwargs)
    else:
      return self.NOT_SUPPORTED


# pylint: disable=g-doc-return-or-yield
def dispatch_for_types(op, *types):
  """Decorator to declare that a Python function overrides an op for a type.

  The decorated function is used to override `op` if any of the arguments or
  keyword arguments (including elements of lists or tuples) have one of the
  specified types.

  Example:

  ```python
  @dispatch_for_types(math_ops.add, RaggedTensor, RaggedTensorValue)
  def ragged_add(x, y, name=None): ...
  ```

  Args:
    op: Python function: the operation that should be overridden.
    *types: The argument types for which this function should be used.

  Raises:
    AssertionError: (from the returned decorator) if the decorated function's
      argspec differs from `op`'s.
  """

  def decorator(func):
    if tf_inspect.getargspec(func) != tf_inspect.getargspec(op):
      raise AssertionError("The decorated function's signature must exactly "
                           "match the signature of the overridden op.")
    _TypeBasedDispatcher(func, types).register(op)
    return func

  return decorator


# pylint: enable=g-doc-return-or-yield


def add_fallback_dispatch_list(target):
  """Decorator that adds a dispatch_list attribute to an op.

  Raises:
    AssertionError: If `target` already has a fallback dispatch list.
  """
  if hasattr(target, FALLBACK_DISPATCH_ATTR):
    raise AssertionError("%s already has a dispatch list" % target)
  setattr(target, FALLBACK_DISPATCH_ATTR, [])
  return target


# Alias for backwards-compatibility.
add_dispatch_list = add_fallback_dispatch_list


################################################################################
# Type-based Dispatch
################################################################################


@tf_export("experimental.dispatch_for_api")
def dispatch_for_api(api, *signatures):
  """Decorator that overrides the default implementation for a TensorFlow API.

  The decorated function (known as the "dispatch target") will override the
  default implementation for the API when the API is called with parameters that
  match a specified type signature.  Signatures are specified using dictionaries
  that map parameter names to type annotations.  E.g., in the following example,
  `masked_add` will be called for `tf.add` if both `x` and `y` are
  `MaskedTensor`s:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor

  >>> @dispatch_for_api(tf.math.add, {'x': MaskedTensor, 'y': MaskedTensor})
  ... def masked_add(x, y, name=None):
  ...   return MaskedTensor(x.values + y.values, x.mask & y.mask)

  >>> mt = tf.add(MaskedTensor([1, 2], [True, False]), MaskedTensor(10, True))
  >>> print(f"values={mt.values.numpy()}, mask={mt.mask.numpy()}")
  values=[11 12], mask=[ True False]

  If multiple type signatures are specified, then the dispatch target will be
  called if any of the signatures match.  For example, the following code
  registers `masked_add` to be called if `x` is a `MaskedTensor` *or* `y` is
  a `MaskedTensor`.

  >>> @dispatch_for_api(tf.math.add, {'x': MaskedTensor}, {'y':MaskedTensor})
  ... def masked_add(x, y):
  ...   x_values = x.values if isinstance(x, MaskedTensor) else x
  ...   x_mask = x.mask if isinstance(x, MaskedTensor) else True
  ...   y_values = y.values if isinstance(y, MaskedTensor) else y
  ...   y_mask = y.mask if isinstance(y, MaskedTensor) else True
  ...   return MaskedTensor(x_values + y_values, x_mask & y_mask)

  The type annotations in type signatures may be type objects (e.g.,
  `MaskedTensor`), `typing.List` values, or `typing.Union` values.  For
  example, the following will register `masked_concat` to be called if `values`
  is a list of `MaskedTensor` values:

  >>> @dispatch_for_api(tf.concat, {'values': typing.List[MaskedTensor]})
  ... def masked_concat(values, axis):
  ...   return MaskedTensor(tf.concat([v.values for v in values], axis),
  ...                       tf.concat([v.mask for v in values], axis))

  Each type signature must contain at least one subclass of `tf.CompositeTensor`
  (which includes subclasses of `tf.ExtensionType`), and dispatch will only be
  triggered if at least one type-annotated parameter contains a
  `CompositeTensor` value.  This rule avoids invoking dispatch in degenerate
  cases, such as the following examples:

  * `@dispatch_for_api(tf.concat, {'values': List[MaskedTensor]})`: Will not
    dispatch to the decorated dispatch target when the user calls
    `tf.concat([])`.

  * `@dispatch_for_api(tf.add, {'x': Union[MaskedTensor, Tensor], 'y':
    Union[MaskedTensor, Tensor]})`: Will not dispatch to the decorated dispatch
    target when the user calls `tf.add(tf.constant(1), tf.constant(2))`.

  The dispatch target's signature must match the signature of the API that is
  being overridden.  In particular, parameters must have the same names, and
  must occur in the same order.  The dispatch target may optionally elide the
  "name" parameter, in which case it will be wrapped with a call to
  `tf.name_scope` when appropriate.

  Args:
    api: The TensorFlow API to override.
    *signatures: Dictionaries mapping parameter names or indices to type
      annotations, specifying when the dispatch target should be called.  In
      particular, the dispatch target will be called if any signature matches;
      and a signature matches if all of the specified parameters have types that
      match with the indicated type annotations.  If no signatures are
      specified, then a signature will be read from the dispatch target
      function's type annotations.

  Returns:
    A decorator that overrides the default implementation for `api`.

  Raises:
    ValueError: If `api` does not support type-based dispatch, or if a
      signature refers to an unknown or non-positional parameter.  The returned
      decorator raises `TypeError` if the dispatch target is not callable, and
      `ValueError` if its signature does not match `api`'s.

  #### Registered APIs

  The TensorFlow APIs that may be overridden by `@dispatch_for_api` are:

  <<API_LIST>>
  """
  dispatcher = getattr(api, TYPE_BASED_DISPATCH_ATTR, None)
  if dispatcher is None:
    raise ValueError(f"{api} does not support dispatch.")

  api_signature = tf_inspect.signature(api)
  # Validate and compile the explicit signatures eagerly, so that errors are
  # reported when `dispatch_for_api(...)` is called, not when it is applied.
  signature_checkers = [
      _make_signature_checker(api_signature, signature)
      for signature in signatures
  ]

  def decorator(dispatch_target):
    """Decorator that registers the given dispatch target."""
    if not callable(dispatch_target):
      raise TypeError("Expected dispatch_target to be callable; "
                      f"got {dispatch_target!r}")
    dispatch_target = _add_name_scope_wrapper(dispatch_target, api_signature)
    _check_signature(api_signature, dispatch_target)

    for signature_checker in signature_checkers:
      dispatcher.Register(signature_checker, dispatch_target)
    _TYPE_BASED_DISPATCH_SIGNATURES[api][dispatch_target].extend(signatures)

    # With no explicit signatures, derive one from the target's annotations.
    if not signature_checkers:
      signature = _signature_from_annotations(dispatch_target)
      checker = _make_signature_checker(api_signature, signature)
      dispatcher.Register(checker, dispatch_target)
      _TYPE_BASED_DISPATCH_SIGNATURES[api][dispatch_target].append(signature)

    return dispatch_target

  return decorator


# Nested dict mapping `api_func` -> `dispatch_target` -> `List[signature]`,
# which can be used for documentation generation and for improved error messages
# when APIs are called with unsupported types.
_TYPE_BASED_DISPATCH_SIGNATURES = {}


def apis_with_type_based_dispatch():
  """Returns a list of TensorFlow APIs that support type-based dispatch.

  The list is sorted by each API's fully-qualified `module.name`.
  """
  return sorted(
      _TYPE_BASED_DISPATCH_SIGNATURES,
      key=lambda api: f"{api.__module__}.{api.__name__}")


def type_based_dispatch_signatures_for(cls):
  """Returns dispatch signatures that have been registered for a given class.

  This function is intended for documentation-generation purposes.

  Args:
    cls: The class to search for.  Type signatures are searched recursively, so
      e.g., if `cls=RaggedTensor`, then information will be returned for all
      dispatch targets that have `RaggedTensor` anywhere in their type
      annotations (including nested in `typing.Union` or `typing.List`.)

  Returns:
    A `dict` mapping `api` -> `signatures`, where `api` is a TensorFlow API
    function; and `signatures` is a list of dispatch signatures for `api`
    that include `cls`.  (Each signature is a dict mapping argument names to
    type annotations; see `dispatch_for_api` for more info.)
  """

  def contains_cls(x):
    """Returns true if `x` contains `cls`."""
    if isinstance(x, dict):
      return any(contains_cls(v) for v in x.values())
    elif x is cls:
      return True
    elif (type_annotations.is_generic_list(x) or
          type_annotations.is_generic_union(x)):
      type_args = type_annotations.get_generic_type_args(x)
      return any(contains_cls(arg) for arg in type_args)
    else:
      return False

  result = {}
  for api, api_signatures in _TYPE_BASED_DISPATCH_SIGNATURES.items():
    for signatures in api_signatures.values():
      filtered = list(filter(contains_cls, signatures))
      if filtered:
        result.setdefault(api, []).extend(filtered)
  return result


# TODO(edloper): Consider using a mechanism like this to automatically add
# the `name` argument to all TensorFlow APIs that are implemented in Python
# (so each Python function doesn't need to do it manually).
def _add_name_scope_wrapper(func, api_signature):
  """Wraps `func` to expect a "name" arg, and use it to call `ops.name_scope`.

  If `func` already expects a "name" arg (or accepts `**kwargs`), or if
  `api_signature` does not expect a "name" arg, then returns `func` as-is.

  Args:
    func: The function to wrap.  Signature must match `api_signature` (except
      the "name" parameter may be missing).
    api_signature: The signature of the original API (used to find the index
      for the "name" parameter).

  Returns:
    The wrapped function (or the original function if no wrapping is needed).
  """
  if "name" not in api_signature.parameters:
    return func  # no wrapping needed (API has no name parameter).

  func_signature = tf_inspect.signature(func)
  func_argspec = tf_inspect.getargspec(func)
  if "name" in func_signature.parameters or func_argspec.keywords is not None:
    return func  # No wrapping needed (already has name parameter or **kwargs).

  name_index = list(api_signature.parameters).index("name")

  def wrapped_func(*args, **kwargs):
    # Strip `name` (positional or keyword) before forwarding to `func`.
    if name_index < len(args):
      name = args[name_index]
      args = args[:name_index] + args[name_index + 1:]
    else:
      name = kwargs.pop("name", None)
    if name is None:
      return func(*args, **kwargs)
    else:
      with ops.name_scope(name):
        return func(*args, **kwargs)

  wrapped_func = tf_decorator.make_decorator(func, wrapped_func)
  # Advertise the API's `name` parameter so `_check_signature` accepts it.
  wrapped_func.__signature__ = func_signature.replace(
      parameters=(list(func_signature.parameters.values()) +
                  [api_signature.parameters["name"]]))
  del wrapped_func._tf_decorator
  return wrapped_func


@tf_export("experimental.unregister_dispatch_for")
def unregister_dispatch_for(dispatch_target):
  """Unregisters a function that was registered with `@dispatch_for_*`.

  This is primarily intended for testing purposes.

  Example:

  >>> # Define a type and register a dispatcher to override `tf.abs`:
  >>> class MyTensor(tf.experimental.ExtensionType):
  ...   value: tf.Tensor
  >>> @tf.experimental.dispatch_for_api(tf.abs)
  ... def my_abs(x: MyTensor):
  ...   return MyTensor(tf.abs(x.value))
  >>> tf.abs(MyTensor(5))
  MyTensor(value=<tf.Tensor: shape=(), dtype=int32, numpy=5>)

  >>> # Unregister the dispatcher, so `tf.abs` no longer calls `my_abs`.
  >>> unregister_dispatch_for(my_abs)
  >>> tf.abs(MyTensor(5))
  Traceback (most recent call last):
  ...
  ValueError: Attempt to convert a value ... to a Tensor.

  Args:
    dispatch_target: The function to unregister.

  Raises:
    ValueError: If `dispatch_target` was not registered using
      `@dispatch_for_api`, `@dispatch_for_unary_elementwise_apis`,
      `@dispatch_for_binary_elementwise_apis`, or
      `@dispatch_for_binary_elementwise_assert_apis`.
  """
  found = False

  # Check if dispatch_target registered by `@dispatch_for_api`
  for api, signatures in _TYPE_BASED_DISPATCH_SIGNATURES.items():
    if dispatch_target in signatures:
      dispatcher = getattr(api, TYPE_BASED_DISPATCH_ATTR)
      dispatcher.Unregister(dispatch_target)
      del signatures[dispatch_target]
      found = True

  # Check if dispatch_target registered by `@dispatch_for_*_elementwise_apis`
  elementwise_keys_to_delete = [
      key for (key, handler) in _ELEMENTWISE_API_HANDLERS.items()
      if handler is dispatch_target
  ]
  for key in elementwise_keys_to_delete:
    # `_ELEMENTWISE_API_TARGETS` may have no entry for `key` (e.g. if no
    # matching API had been registered when the handler was added), so use
    # `pop` with a default rather than indexing, which would raise KeyError.
    for _, target in _ELEMENTWISE_API_TARGETS.pop(key, ()):
      unregister_dispatch_for(target)
    del _ELEMENTWISE_API_HANDLERS[key]
    found = True

  if not found:
    raise ValueError(f"Function {dispatch_target} was not registered using "
                     "a `@dispatch_for_*` decorator.")


def register_dispatchable_type(cls):
  """Class decorator that registers a type for use with type-based dispatch.

  Should *not* be used with subclasses of `CompositeTensor` or `ExtensionType`
  (which are automatically registered).

  Note: this function is intended to support internal legacy use cases (such
  as RaggedTensorValue), and will probably not be exposed as a public API.

  Args:
    cls: The class to register.

  Returns:
    `cls`.
  """
  _api_dispatcher.register_dispatchable_type(cls)
  return cls


def add_type_based_api_dispatcher(target):
  """Adds a PythonAPIDispatcher to the given TensorFlow API function.

  APIs whose underlying function takes `*args` or `**kwargs` are returned
  unchanged (no dispatcher is attached).

  Args:
    target: The TensorFlow API function.

  Returns:
    `target`.

  Raises:
    ValueError: If `target` already has a type-based API dispatcher.
  """
  if hasattr(target, TYPE_BASED_DISPATCH_ATTR):
    raise ValueError(f"{target} already has a type-based API dispatcher.")

  _, unwrapped = tf_decorator.unwrap(target)
  target_argspec = tf_inspect.getargspec(unwrapped)
  if target_argspec.varargs or target_argspec.keywords:
    # @TODO(b/194903203) Add v2 dispatch support for APIs that take varargs
    # and keywords.  Examples of APIs that take varargs and kwargs: meshgrid,
    # einsum, map_values, map_flat_values.
    return target

  setattr(
      target, TYPE_BASED_DISPATCH_ATTR,
      _api_dispatcher.PythonAPIDispatcher(unwrapped.__name__,
                                          target_argspec.args,
                                          target_argspec.defaults))
  _TYPE_BASED_DISPATCH_SIGNATURES[target] = collections.defaultdict(list)
  return target


def _check_signature(api_signature, func):
  """Checks that a dispatch target's signature is compatible with an API.

  Args:
    api_signature: The signature of the TensorFlow API.
    func: The dispatch target.

  Raises:
    ValueError: if the signatures are incompatible.  Two signatures are
      considered compatible if they have the same number of parameters, and all
      corresponding parameters have the same `name` and `kind`.  (Parameters
      are not required to have the same default value or the same annotation.)
  """
  # Special case: if func_signature is (*args, **kwargs), then assume it's ok.
  func_argspec = tf_inspect.getargspec(func)
  if (func_argspec.varargs is not None and func_argspec.keywords is not None
      and not func_argspec.args):
    return

  func_signature = tf_inspect.signature(func)
  ok = len(api_signature.parameters) == len(func_signature.parameters)
  if ok:
    for param_1, param_2 in zip(api_signature.parameters.values(),
                                func_signature.parameters.values()):
      if (param_1.name != param_2.name) or (param_1.kind != param_2.kind):
        ok = False
  if not ok:
    raise ValueError(f"Dispatch function's signature {func_signature} does "
                     f"not match API's signature {api_signature}.")


def _make_signature_checker(api_signature, signature):
  """Builds a PySignatureChecker for the given type signature.

  Args:
    api_signature: The `inspect.Signature` of the API whose signature is
      being checked.
    signature: Dictionary mapping parameter names (or positional indices) to
      type annotations.

  Returns:
    A `PySignatureChecker`.

  Raises:
    TypeError: If `signature` is not a dict keyed by `str` or `int`.
    ValueError: If `signature` names an unknown or non-positional parameter.
  """
  if not (isinstance(signature, dict) and
          all(isinstance(k, (str, int)) for k in signature)):
    raise TypeError("signatures must be dictionaries mapping parameter names "
                    "to type annotations.")
  checkers = []

  param_names = list(api_signature.parameters)
  for param_name, param_type in signature.items():
    # Convert positional parameters to named parameters.
    if (isinstance(param_name, int) and
        param_name < len(api_signature.parameters)):
      param_name = param_names[param_name]

    # Check that the parameter exists, and has an appropriate kind.
    param = api_signature.parameters.get(param_name, None)
    if param is None:
      raise ValueError("signature includes annotation for unknown "
                       f"parameter {param_name!r}.")
    if param.kind not in (tf_inspect.Parameter.POSITIONAL_ONLY,
                          tf_inspect.Parameter.POSITIONAL_OR_KEYWORD):
      raise ValueError("Dispatch currently only supports type annotations "
                       "for positional parameters; can't handle annotation "
                       f"for {param.kind!r} parameter {param_name}.")

    checker = make_type_checker(param_type)
    index = param_names.index(param_name)
    checkers.append((index, checker))

  return _api_dispatcher.PySignatureChecker(checkers)


# Cache for InstanceTypeChecker objects (we only want to create one
# InstanceTypeChecker for each type, since each one uses an internal cache
# to avoid repeated calls back into Python's isinstance).
_is_instance_checker_cache = {}


def make_type_checker(annotation):
  """Builds a PyTypeChecker for the given type annotation.

  Supported annotations are type objects, `None`, `List[...]`, and
  `Union[...]` (recursively).

  Raises:
    AssertionError: If a `List[...]` annotation does not have exactly one type
      parameter.
    ValueError: If the annotation is not supported.
  """
  if type_annotations.is_generic_union(annotation):
    type_args = type_annotations.get_generic_type_args(annotation)

    # If the union contains two or more simple types, then use a single
    # InstanceChecker to check them.
    simple_types = [t for t in type_args if isinstance(t, type)]
    # Sort by id so equivalent unions map to the same cache key.
    simple_types = tuple(sorted(simple_types, key=id))
    if len(simple_types) > 1:
      if simple_types not in _is_instance_checker_cache:
        checker = _api_dispatcher.MakeInstanceChecker(*simple_types)
        _is_instance_checker_cache[simple_types] = checker
      options = ([_is_instance_checker_cache[simple_types]] +
                 [make_type_checker(t) for t in type_args
                  if not isinstance(t, type)])
      return _api_dispatcher.MakeUnionChecker(options)

    options = [make_type_checker(t) for t in type_args]
    return _api_dispatcher.MakeUnionChecker(options)

  elif type_annotations.is_generic_list(annotation):
    type_args = type_annotations.get_generic_type_args(annotation)
    if len(type_args) != 1:
      raise AssertionError("Expected List[...] to have a single type parameter")
    elt_type = make_type_checker(type_args[0])
    return _api_dispatcher.MakeListChecker(elt_type)

  elif isinstance(annotation, type):
    if annotation not in _is_instance_checker_cache:
      checker = _api_dispatcher.MakeInstanceChecker(annotation)
      _is_instance_checker_cache[annotation] = checker
    return _is_instance_checker_cache[annotation]

  elif annotation is None:
    return make_type_checker(type(None))

  else:
    raise ValueError(f"Type annotation {annotation} is not currently supported"
                     " by dispatch.  Supported annotations: type objects, "
                     "List[...], and Union[...]")


def _signature_from_annotations(func):
  """Builds a dict mapping from parameter names to type annotations.

  Only annotated parameters of `func` are included.

  Raises:
    ValueError: If none of `func`'s parameters have type annotations.
  """
  func_signature = tf_inspect.signature(func)

  signature = {
      name: param.annotation
      for name, param in func_signature.parameters.items()
      if param.annotation is not tf_inspect.Parameter.empty
  }
  if not signature:
    raise ValueError("The dispatch_for_api decorator must be called with at "
                     "least one signature, or applied to a function that "
                     "has type annotations on its parameters.")
  return signature


# Registries for elementwise APIs and API handlers.
#
# _*_ELEMENTWISE_APIS: A list of TensorFlow APIs that have been registered
# as elementwise operations using the `register_*_elementwise_api`
# decorators.
#
# _ELEMENTWISE_API_HANDLERS: Dict mapping from argument type(s) to API
# handlers that have been registered with the `dispatch_for_*_elementwise_apis`
# decorators.  Keys are `(x_type,)` for unary handlers, `(x_type, y_type)` for
# binary handlers, and `(x_type, y_type, _ASSERT_API_TAG)` for binary assert
# handlers.
#
# _ELEMENTWISE_API_TARGETS: Dict mapping from argument type(s) to lists of
# `(api, dispatch_target)` pairs.  Used to implement
# `unregister_dispatch_for`.
_UNARY_ELEMENTWISE_APIS = []
_BINARY_ELEMENTWISE_APIS = []
_BINARY_ELEMENTWISE_ASSERT_APIS = []
_ELEMENTWISE_API_HANDLERS = {}
_ELEMENTWISE_API_TARGETS = {}

# Extra key component that distinguishes binary *assert* handlers from plain
# binary elementwise handlers in `_ELEMENTWISE_API_HANDLERS`.
_ASSERT_API_TAG = "ASSERT_API_TAG"


@tf_export("experimental.dispatch_for_unary_elementwise_apis")
def dispatch_for_unary_elementwise_apis(x_type):
  """Decorator to override default implementation for unary elementwise APIs.

  The decorated function (known as the "elementwise api handler") overrides
  the default implementation for any unary elementwise API whenever the value
  for the first argument (typically named `x`) matches the type annotation
  `x_type`. The elementwise api handler is called with two arguments:

    `elementwise_api_handler(api_func, x)`

  Where `api_func` is a function that takes a single parameter and performs the
  elementwise operation (e.g., `tf.abs`), and `x` is the first argument to the
  elementwise api.

  The following example shows how this decorator can be used to update all
  unary elementwise operations to handle a `MaskedTensor` type:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_unary_elementwise_apis(MaskedTensor)
  ... def unary_elementwise_api_handler(api_func, x):
  ...   return MaskedTensor(api_func(x.values), x.mask)
  >>> mt = MaskedTensor([1, -2, -3], [True, False, True])
  >>> abs_mt = tf.abs(mt)
  >>> print(f"values={abs_mt.values.numpy()}, mask={abs_mt.mask.numpy()}")
  values=[1 2 3], mask=[ True False True]

  For unary elementwise operations that take extra arguments beyond `x`, those
  arguments are *not* passed to the elementwise api handler, but are
  automatically added when `api_func` is called.  E.g., in the following
  example, the `dtype` parameter is not passed to
  `unary_elementwise_api_handler`, but is added by `api_func`.

  >>> ones_mt = tf.ones_like(mt, dtype=tf.float32)
  >>> print(f"values={ones_mt.values.numpy()}, mask={ones_mt.mask.numpy()}")
  values=[1.0 1.0 1.0], mask=[ True False True]

  Args:
    x_type: A type annotation indicating when the api handler should be called.
      See `dispatch_for_api` for a list of supported annotation types.

  Returns:
    A decorator.

  Raises:
    ValueError: (from the returned decorator) if a unary elementwise handler
      has already been registered for `x_type`.

  #### Registered APIs

  The unary elementwise APIs are:

  <<API_LIST>>
  """

  def decorator(handler):
    if (x_type,) in _ELEMENTWISE_API_HANDLERS:
      raise ValueError("A unary elementwise dispatch handler "
                       f"({_ELEMENTWISE_API_HANDLERS[(x_type,)]}) "
                       f"has already been registered for {x_type}.")
    _ELEMENTWISE_API_HANDLERS[(x_type,)] = handler
    # Attach the handler to every unary API registered so far; APIs registered
    # later pick it up in `register_unary_elementwise_api`.
    for api in _UNARY_ELEMENTWISE_APIS:
      _add_dispatch_for_unary_elementwise_api(api, x_type, handler)

    return handler

  return decorator


@tf_export("experimental.dispatch_for_binary_elementwise_apis")
def dispatch_for_binary_elementwise_apis(x_type, y_type):
  """Decorator to override default implementation for binary elementwise APIs.

  The decorated function (known as the "elementwise api handler") overrides
  the default implementation for any binary elementwise API whenever the value
  for the first two arguments (typically named `x` and `y`) match the specified
  type annotations.  The elementwise api handler is called with three
  arguments:

    `elementwise_api_handler(api_func, x, y)`

  Where `x` and `y` are the first two arguments to the elementwise api, and
  `api_func` is a TensorFlow function that takes two parameters and performs the
  elementwise operation (e.g., `tf.add`).

  The following example shows how this decorator can be used to update all
  binary elementwise operations to handle a `MaskedTensor` type:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
  ... def binary_elementwise_api_handler(api_func, x, y):
  ...   return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
  >>> a = MaskedTensor([1, 2, 3, 4, 5], [True, True, True, True, False])
  >>> b = MaskedTensor([2, 4, 6, 8, 0], [True, True, True, False, True])
  >>> c = tf.add(a, b)
  >>> print(f"values={c.values.numpy()}, mask={c.mask.numpy()}")
  values=[ 3 6 9 12 5], mask=[ True True True False False]

  Args:
    x_type: A type annotation indicating when the api handler should be called.
    y_type: A type annotation indicating when the api handler should be called.

  Returns:
    A decorator.

  Raises:
    ValueError: (from the returned decorator) if a binary elementwise handler
      has already been registered for `(x_type, y_type)`.

  #### Registered APIs

  The binary elementwise APIs are:

  <<API_LIST>>
  """

  def decorator(handler):
    if (x_type, y_type) in _ELEMENTWISE_API_HANDLERS:
      raise ValueError("A binary elementwise dispatch handler "
                       f"({_ELEMENTWISE_API_HANDLERS[x_type, y_type]}) "
                       f"has already been registered for ({x_type}, {y_type}).")
    _ELEMENTWISE_API_HANDLERS[x_type, y_type] = handler
    # Attach the handler to every binary API registered so far.
    for api in _BINARY_ELEMENTWISE_APIS:
      _add_dispatch_for_binary_elementwise_api(api, x_type, y_type, handler)

    return handler

  return decorator


@tf_export("experimental.dispatch_for_binary_elementwise_assert_apis")
def dispatch_for_binary_elementwise_assert_apis(x_type, y_type):
  """Decorator to override default implementation for binary elementwise assert APIs.

  The decorated function (known as the "elementwise assert handler")
  overrides the default implementation for any binary elementwise assert API
  whenever the value for the first two arguments (typically named `x` and `y`)
  match the specified type annotations.  The handler is called with three
  arguments:

    `elementwise_assert_handler(assert_func, x, y)`

  Where `x` and `y` are the first two arguments to the binary elementwise assert
  operation, and `assert_func` is a TensorFlow function that takes two
  parameters and performs the elementwise assert operation (e.g.,
  `tf.debugging.assert_equal`).

  The following example shows how this decorator can be used to update all
  binary elementwise assert operations to handle a `MaskedTensor` type:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_binary_elementwise_assert_apis(MaskedTensor, MaskedTensor)
  ... def binary_elementwise_assert_api_handler(assert_func, x, y):
  ...   merged_mask = tf.logical_and(x.mask, y.mask)
  ...   selected_x_values = tf.boolean_mask(x.values, merged_mask)
  ...   selected_y_values = tf.boolean_mask(y.values, merged_mask)
  ...   assert_func(selected_x_values, selected_y_values)
  >>> a = MaskedTensor([1, 1, 0, 1, 1], [False, False, True, True, True])
  >>> b = MaskedTensor([2, 2, 0, 2, 2], [True, True, True, False, False])
  >>> tf.debugging.assert_equal(a, b) # assert passed; no exception was thrown

  >>> a = MaskedTensor([1, 1, 1, 1, 1], [True, True, True, True, True])
  >>> b = MaskedTensor([0, 0, 0, 0, 2], [True, True, True, True, True])
  >>> tf.debugging.assert_greater(a, b)
  Traceback (most recent call last):
  ...
  InvalidArgumentError: Condition x > y did not hold.

  Args:
    x_type: A type annotation indicating when the api handler should be called.
    y_type: A type annotation indicating when the api handler should be called.

  Returns:
    A decorator.

  Raises:
    ValueError: (from the returned decorator) if a binary elementwise assert
      handler has already been registered for `(x_type, y_type)`.

  #### Registered APIs

  The binary elementwise assert APIs are:

  <<API_LIST>>
  """

  def decorator(handler):
    # The tag keeps assert handlers separate from plain binary handlers that
    # were registered for the same `(x_type, y_type)`.
    api_handler_key = (x_type, y_type, _ASSERT_API_TAG)
    if api_handler_key in _ELEMENTWISE_API_HANDLERS:
      raise ValueError("A binary elementwise assert dispatch handler "
                       f"({_ELEMENTWISE_API_HANDLERS[api_handler_key]}) "
                       f"has already been registered for ({x_type}, {y_type}).")
    _ELEMENTWISE_API_HANDLERS[api_handler_key] = handler
    for api in _BINARY_ELEMENTWISE_ASSERT_APIS:
      _add_dispatch_for_binary_elementwise_api(api, x_type, y_type, handler)

    return handler

  return decorator


def register_unary_elementwise_api(func):
  """Decorator that registers a TensorFlow op as a unary elementwise API.

  Any unary handlers that were registered earlier are attached to `func`.
  Returns `func`.
  """
  _UNARY_ELEMENTWISE_APIS.append(func)
  for args, handler in _ELEMENTWISE_API_HANDLERS.items():
    if len(args) == 1:
      _add_dispatch_for_unary_elementwise_api(func, args[0], handler)
  return func


def register_binary_elementwise_api(func):
  """Decorator that registers a TensorFlow op as a binary elementwise API.

  Any binary (non-assert) handlers that were registered earlier are attached to
  `func`.  Returns `func`.
  """
  _BINARY_ELEMENTWISE_APIS.append(func)
  for args, handler in _ELEMENTWISE_API_HANDLERS.items():
    # Assert handlers use 3-element keys, so `len == 2` selects only plain
    # binary handlers.
    if len(args) == 2:
      _add_dispatch_for_binary_elementwise_api(func, args[0], args[1], handler)
  return func


def register_binary_elementwise_assert_api(func):
  """Decorator that registers a TensorFlow op as a binary elementwise assert API.

  Unlike `register_binary_elementwise_api`, this decorator is used for assert
  apis, such as assert_equal, assert_none_equal, etc, which return None in eager
  mode and an op in graph mode.  Any binary assert handlers that were
  registered earlier are attached to `func`.

  Args:
    func: The function that implements the binary elementwise assert API.

  Returns:
    `func`
  """
  _BINARY_ELEMENTWISE_ASSERT_APIS.append(func)
  for args, handler in _ELEMENTWISE_API_HANDLERS.items():
    if len(args) == 3 and args[2] is _ASSERT_API_TAG:
      _add_dispatch_for_binary_elementwise_api(func, args[0], args[1], handler)
  return func


def unary_elementwise_apis():
  """Returns a tuple of APIs that have been registered as unary elementwise."""
  return tuple(_UNARY_ELEMENTWISE_APIS)


def binary_elementwise_apis():
  """Returns a tuple of APIs that have been registered as binary elementwise."""
  return tuple(_BINARY_ELEMENTWISE_APIS)


def _add_dispatch_for_unary_elementwise_api(api, x_type,
                                            elementwise_api_handler):
  """Registers a unary elementwise handler as a dispatcher for a given API."""
  api_signature = tf_inspect.signature(api)
  x_name = list(api_signature.parameters)[0]
  name_index = _find_name_index(api_signature)

  # Unless the API's parameters are exactly `(x, name)`, the extra arguments
  # are bound into the function that is passed to the handler.
  need_to_bind_api_args = (
      len(api_signature.parameters) > 2 or
      "name" not in api_signature.parameters)

  @dispatch_for_api(api, {x_name: x_type})
  def dispatch_target(*args, **kwargs):
    args, kwargs, name = _extract_name_arg(args, kwargs, name_index)
    if args:
      x, args = args[0], args[1:]
    else:
      x = kwargs.pop(x_name)

    if need_to_bind_api_args:
      tensor_api = lambda v: api(v, *args, **kwargs)
    else:
      tensor_api = api

    if name is None:
      return elementwise_api_handler(tensor_api, x)
    else:
      with ops.name_scope(name, None, [x]):
        return elementwise_api_handler(tensor_api, x)

  dispatch_target.__name__ = "elementwise_dispatch_target_for_" + api.__name__
  dispatch_target.__qualname__ = dispatch_target.__name__
  # Keep track of what targets we've registered (so we can unregister them).
target_list = _ELEMENTWISE_API_TARGETS.setdefault((x_type,), []) target_list.append((api, dispatch_target)) def _add_dispatch_for_binary_elementwise_api(api, x_type, y_type, elementwise_api_handler): """Registers a binary elementwise handler as a dispatcher for a given API.""" api_signature = tf_inspect.signature(api) x_name, y_name = list(api_signature.parameters)[:2] name_index = _find_name_index(api_signature) need_to_bind_api_args = (len(api_signature.parameters) > 3 or "name" not in api_signature.parameters) @dispatch_for_api(api, {x_name: x_type, y_name: y_type}) def dispatch_target(*args, **kwargs): args, kwargs, name = _extract_name_arg(args, kwargs, name_index) if len(args) > 1: x, y, args = args[0], args[1], args[2:] elif args: x, args = args[0], args[1:] y = kwargs.pop(y_name, None) else: x = kwargs.pop(x_name, None) y = kwargs.pop(y_name, None) if need_to_bind_api_args: tensor_api = lambda v1, v2: api(v1, v2, *args, **kwargs) else: tensor_api = api if name is None: return elementwise_api_handler(tensor_api, x, y) else: with ops.name_scope(name, None, [x, y]): return elementwise_api_handler(tensor_api, x, y) dispatch_target.__name__ = "elementwise_dispatch_target_for_" + api.__name__ dispatch_target.__qualname__ = dispatch_target.__name__ # Keep track of what targets we've registered (so we can unregister them). 
target_list = _ELEMENTWISE_API_TARGETS.setdefault((x_type, y_type), []) target_list.append((api, dispatch_target)) def _find_name_index(signature): """Returns the index of the `name` parameter, or -1 if it's not present.""" try: return list(signature.parameters).index("name") except ValueError: return -1 def _extract_name_arg(args, kwargs, name_index): """Extracts the parameter `name` and returns `(args, kwargs, name_value)`.""" if name_index < 0: name_value = None elif name_index < len(args): name_value = args[name_index] args = args[:name_index] + args[name_index + 1:] else: name_value = kwargs.pop("name", None) return args, kwargs, name_value def update_docstrings_with_api_lists(): """Updates the docstrings of dispatch decorators with API lists. Updates docstrings for `dispatch_for_api`, `dispatch_for_unary_elementwise_apis`, and `dispatch_for_binary_elementwise_apis`, by replacing the string '<>' with a list of APIs that have been registered for that decorator. """ _update_docstring_with_api_list(dispatch_for_unary_elementwise_apis, _UNARY_ELEMENTWISE_APIS) _update_docstring_with_api_list(dispatch_for_binary_elementwise_apis, _BINARY_ELEMENTWISE_APIS) _update_docstring_with_api_list(dispatch_for_binary_elementwise_assert_apis, _BINARY_ELEMENTWISE_ASSERT_APIS) _update_docstring_with_api_list(dispatch_for_api, _TYPE_BASED_DISPATCH_SIGNATURES) def _update_docstring_with_api_list(target, api_list): """Replaces `<>` in target.__doc__ with the given list of APIs.""" lines = [] for func in api_list: name = tf_export_lib.get_canonical_name_for_symbol( func, add_prefix_to_v1_names=True) if name is not None: params = tf_inspect.signature(func).parameters.keys() lines.append(f" * `tf.{name}({', '.join(params)})`") lines.sort() target.__doc__ = target.__doc__.replace(" <>", "\n".join(lines)) ################################################################################ # Dispatch Support ################################################################################ 
@tf_export("__internal__.dispatch.add_dispatch_support", v1=[])
def add_dispatch_support(target=None, iterable_parameters=None):
  """Decorator that adds a dispatch handling wrapper to a TensorFlow Python API.

  This wrapper adds the decorated function as an API that can be overridden
  using the `@dispatch_for_api` decorator.  In the following example, we first
  define a new API (`double`) that supports dispatch, then define a custom type
  (`MaskedTensor`) and finally use `dispatch_for_api` to override the default
  implementation of `double` when called with `MaskedTensor` values:

  >>> @add_dispatch_support
  ... def double(x):
  ...   return x * 2
  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_api(double, {'x': MaskedTensor})
  ... def masked_double(x):
  ...   return MaskedTensor(x.values * 2, x.mask)

  The optional `iterable_parameters` argument can be used to mark parameters
  that can take arbitrary iterable values (such as generator expressions).
  These need to be handled specially during dispatch, since just iterating over
  an iterable uses up its values.  In the following example, we define a new
  API whose second argument can be an iterable value; and then override the
  default implementation of that API when the iterable contains MaskedTensors:

  >>> @add_dispatch_support(iterable_parameters=['ys'])
  ... def add_tensor_to_list_of_tensors(x, ys):
  ...   return [x + y for y in ys]
  >>> @dispatch_for_api(add_tensor_to_list_of_tensors,
  ...               {'ys': typing.List[MaskedTensor]})
  ... def masked_add_tensor_to_list_of_tensors(x, ys):
  ...   return [MaskedTensor(x+y.values, y.mask) for y in ys]

  (Note: the only TensorFlow API that currently supports iterables is `add_n`.)

  Args:
    target: The TensorFlow API that should support dispatch.
    iterable_parameters: Optional list of parameter names that may be
      called with iterables (such as the `inputs` parameter for `tf.add_n`).

  Returns:
    If `target` is None, a decorator; otherwise the dispatch-enabled wrapper
    of `target`.

  Raises:
    TypeError: If `iterable_parameters` is not None or a list/tuple of strings.
    ValueError: If a name in `iterable_parameters` is not a parameter of the
      decorated function.
  """
  if not (iterable_parameters is None or
          (isinstance(iterable_parameters, (list, tuple)) and
           all(isinstance(p, str) for p in iterable_parameters))):
    raise TypeError("iterable_parameters should be a list or tuple of string.")

  def decorator(dispatch_target):

    # Get the name & index for each iterable parameter.
    if iterable_parameters is None:
      iterable_params = None
    else:
      arg_names = tf_inspect.getargspec(dispatch_target).args
      iterable_params = []
      for name in iterable_parameters:
        if name not in arg_names:
          # Previously surfaced as an opaque "x is not in list" error from
          # `list.index`; keep the ValueError type but explain the problem.
          target_name = getattr(dispatch_target, "__name__", dispatch_target)
          raise ValueError(
              f"iterable_parameters contains {name!r}, which is not a "
              f"positional parameter of {target_name}; expected one of "
              f"{arg_names}.")
        iterable_params.append((name, arg_names.index(name)))

    @traceback_utils.filter_traceback
    def op_dispatch_handler(*args, **kwargs):
      """Call `dispatch_target`, performing dispatch when appropriate."""

      # Type-based dispatch system (dispatch v2):
      # `api_dispatcher` is assigned after this wrapper is created (see below);
      # the closure reads it at call time.
      if api_dispatcher is not None:
        if iterable_params is not None:
          # Materialize iterables so that inspecting them for dispatch does
          # not use up values the real implementation still needs.
          args, kwargs = replace_iterable_params(args, kwargs, iterable_params)
        result = api_dispatcher.Dispatch(args, kwargs)
        if result is not NotImplemented:
          return result

      # Fallback dispatch system (dispatch v1):
      try:
        return dispatch_target(*args, **kwargs)
      except (TypeError, ValueError):
        # Note: convert_to_eager_tensor currently raises a ValueError, not a
        # TypeError, when given unexpected types.  So we need to catch both.
        result = dispatch(op_dispatch_handler, args, kwargs)
        if result is not OpDispatcher.NOT_SUPPORTED:
          return result
        else:
          raise

    add_fallback_dispatch_list(op_dispatch_handler)
    op_dispatch_handler = tf_decorator.make_decorator(dispatch_target,
                                                      op_dispatch_handler)
    add_type_based_api_dispatcher(op_dispatch_handler)
    api_dispatcher = getattr(op_dispatch_handler, TYPE_BASED_DISPATCH_ATTR,
                             None)
    return op_dispatch_handler

  if target is None:
    return decorator
  else:
    return decorator(target)


def replace_iterable_params(args, kwargs, iterable_params):
  """Returns (args, kwargs) with any iterable parameters converted to lists.

  Args:
    args: Positional arguments to a function.
    kwargs: Keyword arguments to a function.  Matching entries are replaced
      in place.
    iterable_params: A list of (name, index) tuples for iterable parameters.

  Returns:
    A tuple (args, kwargs), where any positional or keyword parameters in
    `iterable_params` have their value converted to a `list`.
  """
  args = list(args)
  for name, index in iterable_params:
    # A parameter is either passed positionally (its index is within `args`)
    # or by keyword; convert whichever form was used.
    if index < len(args):
      args[index] = list(args[index])
    elif name in kwargs:
      kwargs[name] = list(kwargs[name])
  return tuple(args), kwargs