195 lines
8.2 KiB
Python
195 lines
8.2 KiB
Python
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Utility functions for types information, incuding full type information."""
|
|
|
|
from typing import List
|
|
|
|
from tensorflow.core.framework import full_type_pb2
|
|
from tensorflow.core.framework import types_pb2
|
|
from tensorflow.python.framework import type_spec
|
|
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorSpec
|
|
from tensorflow.python.ops.structured.structured_tensor import StructuredTensor
|
|
from tensorflow.python.platform import tf_logging as logging
|
|
from tensorflow.python.util import nest
|
|
|
|
# TODO(b/226455884) A python binding for DT_TO_FT or map_dtype_to_tensor() from
|
|
# tensorflow/core/framework/types.cc to avoid duplication here
|
|
_DT_TO_FT = {
|
|
types_pb2.DT_FLOAT: full_type_pb2.TFT_FLOAT,
|
|
types_pb2.DT_DOUBLE: full_type_pb2.TFT_DOUBLE,
|
|
types_pb2.DT_INT32: full_type_pb2.TFT_INT32,
|
|
types_pb2.DT_UINT8: full_type_pb2.TFT_UINT8,
|
|
types_pb2.DT_INT16: full_type_pb2.TFT_INT16,
|
|
types_pb2.DT_INT8: full_type_pb2.TFT_INT8,
|
|
types_pb2.DT_STRING: full_type_pb2.TFT_STRING,
|
|
types_pb2.DT_COMPLEX64: full_type_pb2.TFT_COMPLEX64,
|
|
types_pb2.DT_INT64: full_type_pb2.TFT_INT64,
|
|
types_pb2.DT_BOOL: full_type_pb2.TFT_BOOL,
|
|
types_pb2.DT_UINT16: full_type_pb2.TFT_UINT16,
|
|
types_pb2.DT_COMPLEX128: full_type_pb2.TFT_COMPLEX128,
|
|
types_pb2.DT_HALF: full_type_pb2.TFT_HALF,
|
|
types_pb2.DT_UINT32: full_type_pb2.TFT_UINT32,
|
|
types_pb2.DT_UINT64: full_type_pb2.TFT_UINT64,
|
|
types_pb2.DT_VARIANT: full_type_pb2.TFT_LEGACY_VARIANT,
|
|
}
|
|
|
|
|
|
def _translate_to_fulltype_for_flat_tensors(
|
|
spec: type_spec.TypeSpec) -> List[full_type_pb2.FullTypeDef]:
|
|
"""Convert a TypeSec to a list of FullTypeDef.
|
|
|
|
The FullTypeDef created corresponds to the encoding used with datasets
|
|
(and map_fn) that uses variants (and not FullTypeDef corresponding to the
|
|
default "component" encoding).
|
|
|
|
Currently, the only use of this is for information about the contents of
|
|
ragged tensors, so only ragged tensors return useful full type information
|
|
and other types return TFT_UNSET. While this could be improved in the future,
|
|
this function is intended for temporary use and expected to be removed
|
|
when type inference support is sufficient.
|
|
|
|
Args:
|
|
spec: A TypeSpec for one element of a dataset or map_fn.
|
|
|
|
Returns:
|
|
A list of FullTypeDef corresponding to SPEC. The length of this list
|
|
is always the same as the length of spec._flat_tensor_specs.
|
|
"""
|
|
if isinstance(spec, RaggedTensorSpec):
|
|
dt = spec.dtype
|
|
elem_t = _DT_TO_FT.get(dt)
|
|
if elem_t is None:
|
|
logging.vlog(1, "dtype %s that has no conversion to fulltype.", dt)
|
|
elif elem_t == full_type_pb2.TFT_LEGACY_VARIANT:
|
|
logging.vlog(1, "Ragged tensors containing variants are not supported.",
|
|
dt)
|
|
else:
|
|
assert len(spec._flat_tensor_specs) == 1 # pylint: disable=protected-access
|
|
return [
|
|
full_type_pb2.FullTypeDef(
|
|
type_id=full_type_pb2.TFT_RAGGED,
|
|
args=[full_type_pb2.FullTypeDef(type_id=elem_t)])
|
|
]
|
|
return [
|
|
full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_UNSET)
|
|
for t in spec._flat_tensor_specs # pylint: disable=protected-access
|
|
]
|
|
|
|
|
|
# LINT.IfChange(_specs_for_flat_tensors)
|
|
def _specs_for_flat_tensors(element_spec):
|
|
"""Return a flat list of type specs for element_spec.
|
|
|
|
Note that "flat" in this function and in `_flat_tensor_specs` is a nickname
|
|
for the "batchable tensor list" encoding used by datasets and map_fn
|
|
internally (in C++/graphs). The ability to batch, unbatch and change
|
|
batch size is one important characteristic of this encoding. A second
|
|
important characteristic is that it represets a ragged tensor or sparse
|
|
tensor as a single tensor of type variant (and this encoding uses special
|
|
ops to encode/decode to/from variants).
|
|
|
|
(In constrast, the more typical encoding, e.g. the C++/graph
|
|
representation when calling a tf.function, is "component encoding" which
|
|
represents sparse and ragged tensors as multiple dense tensors and does
|
|
not use variants or special ops for encoding/decoding.)
|
|
|
|
Args:
|
|
element_spec: A nest of TypeSpec describing the elements of a dataset (or
|
|
map_fn).
|
|
|
|
Returns:
|
|
A non-nested list of TypeSpec used by the encoding of tensors by
|
|
datasets and map_fn for ELEMENT_SPEC. The items
|
|
in this list correspond to the items in `_flat_tensor_specs`.
|
|
"""
|
|
if isinstance(element_spec, StructuredTensor.Spec):
|
|
specs = []
|
|
for _, field_spec in sorted(
|
|
element_spec._field_specs.items(), key=lambda t: t[0]): # pylint: disable=protected-access
|
|
specs.extend(_specs_for_flat_tensors(field_spec))
|
|
elif isinstance(element_spec, type_spec.BatchableTypeSpec) and (
|
|
element_spec.__class__._flat_tensor_specs is # pylint: disable=protected-access
|
|
type_spec.BatchableTypeSpec._flat_tensor_specs): # pylint: disable=protected-access
|
|
# Classes which use the default `_flat_tensor_specs` from
|
|
# `BatchableTypeSpec` case (i.e. a derived class does not override
|
|
# `_flat_tensor_specs`.) are encoded using `component_specs`.
|
|
specs = nest.flatten(
|
|
element_spec._component_specs, # pylint: disable=protected-access
|
|
expand_composites=False)
|
|
else:
|
|
# In addition flatting any nesting in Python,
|
|
# this default case covers things that are encoded by one tensor,
|
|
# such as dense tensors which are unchanged by encoding and
|
|
# ragged tensors and sparse tensors which are encoded by a variant tensor.
|
|
specs = nest.flatten(element_spec, expand_composites=False)
|
|
return specs
|
|
# LINT.ThenChange()
|
|
# Note that _specs_for_flat_tensors must correspond to _flat_tensor_specs
|
|
|
|
|
|
def fulltypes_for_flat_tensors(element_spec):
|
|
"""Convert the element_spec for a dataset to a list of FullType Def.
|
|
|
|
Note that "flat" in this function and in `_flat_tensor_specs` is a nickname
|
|
for the "batchable tensor list" encoding used by datasets and map_fn.
|
|
The FullTypeDef created corresponds to this encoding (e.g. that uses variants
|
|
and not the FullTypeDef corresponding to the default "component" encoding).
|
|
|
|
This is intended for temporary internal use and expected to be removed
|
|
when type inference support is sufficient. See limitations of
|
|
`_translate_to_fulltype_for_flat_tensors`.
|
|
|
|
Args:
|
|
element_spec: A nest of TypeSpec describing the elements of a dataset (or
|
|
map_fn).
|
|
|
|
Returns:
|
|
A list of FullTypeDef correspoinding to ELEMENT_SPEC. The items
|
|
in this list correspond to the items in `_flat_tensor_specs`.
|
|
"""
|
|
specs = _specs_for_flat_tensors(element_spec)
|
|
full_types_lists = [_translate_to_fulltype_for_flat_tensors(s) for s in specs]
|
|
rval = nest.flatten(full_types_lists) # flattens list-of-list to flat list.
|
|
return rval
|
|
|
|
|
|
def fulltype_list_to_product(fulltype_list):
|
|
"""Convert a list of FullType Def into a single FullType Def."""
|
|
return full_type_pb2.FullTypeDef(
|
|
type_id=full_type_pb2.TFT_PRODUCT, args=fulltype_list)
|
|
|
|
|
|
def iterator_full_type_from_spec(element_spec):
|
|
"""Returns a FullTypeDef for an iterator for the elements.
|
|
|
|
Args:
|
|
element_spec: A nested structure of `tf.TypeSpec` objects representing the
|
|
element type specification.
|
|
|
|
Returns:
|
|
A FullTypeDef for an iterator for the element tensor representation.
|
|
"""
|
|
args = fulltypes_for_flat_tensors(element_spec)
|
|
return full_type_pb2.FullTypeDef(
|
|
type_id=full_type_pb2.TFT_PRODUCT,
|
|
args=[
|
|
full_type_pb2.FullTypeDef(
|
|
type_id=full_type_pb2.TFT_ITERATOR,
|
|
args=[
|
|
full_type_pb2.FullTypeDef(
|
|
type_id=full_type_pb2.TFT_PRODUCT, args=args)
|
|
])
|
|
])
|