118 lines
4.7 KiB
Python
118 lines
4.7 KiB
Python
![]() |
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# =============================================================================
|
||
|
"""Operations for ExtensionTypes (aka Composite Tensors)."""
|
||
|
|
||
|
from tensorflow.core.protobuf import composite_tensor_variant_pb2
|
||
|
from tensorflow.python.framework import composite_tensor
|
||
|
from tensorflow.python.framework import dtypes
|
||
|
from tensorflow.python.framework import ops
|
||
|
from tensorflow.python.ops import gen_composite_tensor_ops
|
||
|
from tensorflow.python.saved_model import nested_structure_coder
|
||
|
from tensorflow.python.util import nest
|
||
|
|
||
|
|
||
|
def composite_tensor_to_variants(value, type_spec=None, name=None):
  """Packs a composite tensor into a single scalar variant tensor.

  The value is flattened into its component tensors, and the serialized
  `type_spec` is attached as metadata so that it travels with the encoding.

  Args:
    value: The `ExtensionType` (composite tensor) value to encode.
    type_spec: Type information to embed in the encoding. When `None`, the
      spec stored on `value` itself is used.
    name: Optional name for the operation.

  Returns:
    A Tensor with shape=`()` and dtype=`tf.variant`.

  Raises:
    TypeError: If `value` is not a `CompositeTensor`.
    ValueError: If `type_spec` is not compatible with `value`.
  """
  if not isinstance(value, composite_tensor.CompositeTensor):
    raise TypeError("Expected `value` to be a CompositeTensor. "
                    f"Received {type(value)}.")

  # Fall back to the value's own spec when the caller did not supply one.
  if type_spec is None:
    type_spec = value._type_spec  # pylint: disable=protected-access
  if not type_spec.is_compatible_with(value):
    raise ValueError(f"`type_spec` {type_spec} is not compatible with `value` "
                     f"{value!r}.")

  # The metadata proto carries the encoded TypeSpec alongside the components.
  encoded_spec = nested_structure_coder.encode_structure(type_spec)
  metadata = composite_tensor_variant_pb2.CompositeTensorVariantMetadata()
  metadata.type_spec_proto.CopyFrom(encoded_spec.type_spec_value)

  flat_components = nest.flatten(value, expand_composites=True)
  return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
      components=flat_components,
      metadata=metadata.SerializeToString(),
      name=name)
|
||
|
|
||
|
|
||
|
def composite_tensor_from_variant(encoded, type_spec, name=None):
  """Rebuilds the `ExtensionType` value stored in a scalar variant tensor.

  Args:
    encoded: A Tensor returned by `composite_tensor_to_variants`.
    type_spec: The `TypeSpec` of the original value. It determines the number
      and dtypes of the component tensors that make up the decoded value, and
      must be compatible with the `TypeSpec` serialized in `encoded`.
    name: Optional name for the operation.

  Returns:
    An `ExtensionType` value that is compatible with `type_spec`.

  Raises:
    TypeError: If `encoded` is not a Tensor with dtype=variant.
    InvalidArgumentError: If `encoded` is not compatible with `type_spec`.
  """
  if not isinstance(encoded, ops.Tensor):
    raise TypeError(f"Expected `encoded` to be a Tensor, got {encoded!r}.")
  if encoded.dtype != dtypes.variant:
    raise TypeError("Expected `encoded` to have dtype=variant, got "
                    f"{encoded!r}.")
  # Only a scalar encoding is accepted; the shape assertion enforces this.
  encoded.shape.assert_is_compatible_with(())

  # Serialize the expected TypeSpec so the op can compare it with the one
  # that was stored when the value was encoded.
  encoded_spec = nested_structure_coder.encode_structure(type_spec)
  metadata = composite_tensor_variant_pb2.CompositeTensorVariantMetadata()
  metadata.type_spec_proto.CopyFrom(encoded_spec.type_spec_value)

  # One output dtype per flattened component of the spec.
  flat_specs = nest.flatten(type_spec, expand_composites=True)
  component_dtypes = [spec.dtype for spec in flat_specs]

  components = gen_composite_tensor_ops.CompositeTensorVariantToComponents(
      encoded=encoded,
      metadata=metadata.SerializeToString(),
      Tcomponents=component_dtypes,
      name=name)
  return nest.pack_sequence_as(type_spec, components, expand_composites=True)
|
||
|
|
||
|
|
||
|
@ops.RegisterGradient("CompositeTensorVariantFromComponents")
def _composite_tensor_to_variants_grad(op, grad):
  """Gradient function for `CompositeTensorVariantFromComponents`.

  Decodes the incoming variant gradient into component tensors, reusing the
  `metadata` and `Tcomponents` attrs recorded on the forward op.

  Args:
    op: The forward `CompositeTensorVariantFromComponents` operation.
    grad: The gradient with respect to the op's variant output.

  Returns:
    The result of `CompositeTensorVariantToComponents` applied to `grad`.
  """
  forward_metadata = op.get_attr("metadata")
  forward_dtypes = op.get_attr("Tcomponents")
  return gen_composite_tensor_ops.CompositeTensorVariantToComponents(
      encoded=grad, metadata=forward_metadata, Tcomponents=forward_dtypes)
|
||
|
|
||
|
|
||
|
@ops.RegisterGradient("CompositeTensorVariantToComponents")
def _composite_tensor_from_variant_grad(op, *grad):
  """Gradient function for `CompositeTensorVariantToComponents`.

  Re-encodes the per-component gradients as a single variant tensor, using
  the `metadata` attr recorded on the forward op.

  Args:
    op: The forward `CompositeTensorVariantToComponents` operation.
    *grad: One gradient per output of `op`; an entry may be `None`.

  Returns:
    The result of `CompositeTensorVariantFromComponents` applied to the
    merged component list.
  """
  assert len(grad) == len(op.outputs)
  # Start from the forward outputs and substitute the supplied gradient
  # wherever one exists; a `None` entry keeps the corresponding forward output.
  components = [
      op.outputs[index] if component_grad is None else component_grad
      for index, component_grad in enumerate(grad)
  ]
  forward_metadata = op.get_attr("metadata")
  return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
      components=components, metadata=forward_metadata)
|