# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= """Operations for ExtensionTypes (aka Composite Tensors).""" from tensorflow.core.protobuf import composite_tensor_variant_pb2 from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_composite_tensor_ops from tensorflow.python.saved_model import nested_structure_coder from tensorflow.python.util import nest def composite_tensor_to_variants(value, type_spec=None, name=None): """Encodes `value` as a scalar variant tensor. Args: value: The `ExtensionType` value to encode. type_spec: Information about the value's type that should be included in the encoding. name: Optional name for the operation. Returns: A Tensor with shape=`()` and dtype=`tf.variant`. Raises: ValueError: If `type_spec` is not compatible with `value`. """ if not isinstance(value, composite_tensor.CompositeTensor): raise TypeError("Expected `value` to be a CompositeTensor. " f"Received {type(value)}.") if type_spec is None: type_spec = value._type_spec # pylint: disable=protected-access if not type_spec.is_compatible_with(value): raise ValueError(f"`type_spec` {type_spec} is not compatible with `value` " f"{value!r}.") metadata = composite_tensor_variant_pb2.CompositeTensorVariantMetadata() metadata.type_spec_proto.CopyFrom( nested_structure_coder.encode_structure(type_spec).type_spec_value) return gen_composite_tensor_ops.CompositeTensorVariantFromComponents( components=nest.flatten(value, expand_composites=True), metadata=metadata.SerializeToString(), name=name) def composite_tensor_from_variant(encoded, type_spec, name=None): """Returns the `ExtensionType` value encoded by a variant scalar tensor. Args: encoded: A Tensor returned by `composite_tensor_to_variants`. type_spec: The `TypeSpec` of the original value. This is used to determine the number and types of the component tensors that comprise the decoded value. Must be compatible with the `TypeSpec` serilized in `encoded`. name: Optional name for the operation. Returns: An `ExtensionType` value that is compatible with `TypeSpec`. Raises: TypeError: If `encoded` is not a Tensor with dtype=variant. InvalidArgumentError: If `encoded` is not compatible with `type_spec`. """ if not isinstance(encoded, ops.Tensor): raise TypeError(f"Expected `encoded` to be a Tensor, got {encoded!r}.") if encoded.dtype != dtypes.variant: raise TypeError("Expected `encoded` to have dtype=variant, got " f"{encoded!r}.") encoded.shape.assert_is_compatible_with(()) metadata = composite_tensor_variant_pb2.CompositeTensorVariantMetadata() metadata.type_spec_proto.CopyFrom( nested_structure_coder.encode_structure(type_spec).type_spec_value) component_dtypes = [ t.dtype for t in nest.flatten(type_spec, expand_composites=True) ] components = gen_composite_tensor_ops.CompositeTensorVariantToComponents( encoded=encoded, metadata=metadata.SerializeToString(), Tcomponents=component_dtypes, name=name) return nest.pack_sequence_as(type_spec, components, expand_composites=True) @ops.RegisterGradient("CompositeTensorVariantFromComponents") def _composite_tensor_to_variants_grad(op, grad): return gen_composite_tensor_ops.CompositeTensorVariantToComponents( encoded=grad, metadata=op.get_attr("metadata"), Tcomponents=op.get_attr("Tcomponents")) @ops.RegisterGradient("CompositeTensorVariantToComponents") def _composite_tensor_from_variant_grad(op, *grad): assert len(grad) == len(op.outputs) # `components` is `op.outputs`, but with any tensors for which we're # taking the gradient replaced by the corresponding value from `grad`. components = [ op.outputs[i] if grad[i] is None else grad[i] for i in range(len(grad)) ] return gen_composite_tensor_ops.CompositeTensorVariantFromComponents( components=components, metadata=op.get_attr("metadata"))