1189 lines
45 KiB
Python
1189 lines
45 KiB
Python
![]() |
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""User-defined ExtensionType classes."""
|
||
|
|
||
|
import abc
|
||
|
import typing
|
||
|
import warnings
|
||
|
import typing_extensions
|
||
|
|
||
|
from tensorflow.core.protobuf import struct_pb2
|
||
|
from tensorflow.python.framework import composite_tensor
|
||
|
from tensorflow.python.framework import dtypes
|
||
|
from tensorflow.python.framework import extension_type_field
|
||
|
from tensorflow.python.framework import immutable_dict
|
||
|
from tensorflow.python.framework import ops
|
||
|
from tensorflow.python.framework import tensor_shape
|
||
|
from tensorflow.python.framework import tensor_spec
|
||
|
from tensorflow.python.framework import type_spec
|
||
|
from tensorflow.python.framework import type_spec_registry
|
||
|
from tensorflow.python.ops import array_ops
|
||
|
from tensorflow.python.ops import composite_tensor_ops
|
||
|
from tensorflow.python.ops import gen_math_ops
|
||
|
from tensorflow.python.ops import math_ops
|
||
|
from tensorflow.python.saved_model import nested_structure_coder
|
||
|
from tensorflow.python.util import nest
|
||
|
from tensorflow.python.util import tf_decorator
|
||
|
from tensorflow.python.util import tf_inspect
|
||
|
from tensorflow.python.util.tf_export import tf_export
|
||
|
|
||
|
# Attribute used to keep track of when we're inside a user-defined constructor
# (in which case the fields of `self` may be modified).
_IN_CONSTRUCTOR = '_tf_extension_type_in_constructor'

# Attribute names that may always be set or deleted on an `ExtensionType`
# instance, even outside of a constructor (checked by
# `ExtensionType.__setattr__` and `ExtensionType.__delattr__`).
_MUTABLE_KERAS_PROPERTIES = [
    # Keras uses _keras_mask property to pass the mask around
    '_keras_mask',
]
|
||
|
|
||
|
|
||
|
# ==============================================================================
|
||
|
# Utility functions
|
||
|
# ==============================================================================
|
||
|
def _create_object_from_type_and_dict(cls, obj_dict):
|
||
|
"""Creates an object, bypassing the constructor.
|
||
|
|
||
|
Creates an object of type `cls`, whose `__dict__` is updated to contain
|
||
|
`obj_dict`.
|
||
|
|
||
|
Args:
|
||
|
cls: The type of the new object.
|
||
|
obj_dict: A `Mapping` that should be used to initialize the new object's
|
||
|
`__dict__`.
|
||
|
|
||
|
Returns:
|
||
|
An object of type `cls`.
|
||
|
"""
|
||
|
value = object.__new__(cls)
|
||
|
value.__dict__.update(obj_dict)
|
||
|
return value
|
||
|
|
||
|
|
||
|
# ==============================================================================
|
||
|
# Metaclass for tf.ExtensionType
|
||
|
# ==============================================================================
|
||
|
class ExtensionTypeMetaclass(abc.ABCMeta):
  """Metaclass for tf.ExtensionType types."""

  def __init__(cls, name, bases, namespace):
    # Framework base classes opt out of the transformation by defining
    # '_tf_extension_type_do_not_transform_this_class=True' directly in their
    # class body.  The flag is looked up in the class namespace only, so it is
    # *not* inherited: user subclasses of a framework class are transformed.
    is_framework_class = namespace.get(
        '_tf_extension_type_do_not_transform_this_class', False)
    if not is_framework_class:
      _check_field_annotations(cls)
      _add_extension_type_constructor(cls)
      _add_type_spec(cls)
    super().__init__(name, bases, namespace)
|
||
|
|
||
|
|
||
|
# ==============================================================================
|
||
|
# Base class for user-defined types
|
||
|
# ==============================================================================
|
||
|
@tf_export('experimental.ExtensionType')
class ExtensionType(
    composite_tensor.CompositeTensor, metaclass=ExtensionTypeMetaclass):
  """Base class for TensorFlow `ExtensionType` classes.

  Tensorflow `ExtensionType` classes are specialized Python classes that can be
  used transparently with TensorFlow -- e.g., they can be used with ops
  such as `tf.cond` or `tf.while_loop` and used as inputs or outputs for
  `tf.function` and Keras layers.

  New `ExtensionType` classes are defined by creating a subclass of
  `tf.ExtensionType` that
  contains type annotations for all instance variables.  The following type
  annotations are supported:

  Type                 | Example
  -------------------- | --------------------------------------------
  Python integers      | `i: int`
  Python floats        | `f: float`
  Python strings       | `s: str`
  Python booleans      | `b: bool`
  Python None          | `n: None`
  Tensors              | `t: tf.Tensor`
  Composite Tensors    | `rt: tf.RaggedTensor`
  Extension Types      | `m: MyMaskedTensor`
  Tensor shapes        | `shape: tf.TensorShape`
  Tensor dtypes        | `dtype: tf.DType`
  Type unions          | `length: typing.Union[int, float]`
  Tuples               | `params: typing.Tuple[int, float, int, int]`
  Tuples w/ Ellipsis   | `lengths: typing.Tuple[int, ...]`
  Mappings             | `tags: typing.Mapping[str, str]`

  Fields annotated with `typing.Mapping` will be stored using an immutable
  mapping type.

  ExtensionType values are immutable -- i.e., once constructed, you can not
  modify or delete any of their instance members.

  ### Examples

  >>> class MaskedTensor(ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor

  >>> class Toy(ExtensionType):
  ...   name: str
  ...   price: ops.Tensor
  ...   features: typing.Mapping[str, tf.Tensor]

  >>> class ToyStore(ExtensionType):
  ...   name: str
  ...   toys: typing.Tuple[Toy, ...]
  """

  # Let the metaclass know that it should *not* transform this class (since
  # this class is part of the ExtensionType framework, and not a user class).
  _tf_extension_type_do_not_transform_this_class = True

  def __init__(self, *args, **kwargs):
    """Rejects direct instantiation of the abstract `ExtensionType` base.

    Raises:
      AssertionError: If `type(self)` is exactly `ExtensionType`.
    """
    if type(self) is ExtensionType:  # pylint: disable=unidiomatic-typecheck
      raise AssertionError('Cannot create an instance of ExtensionType '
                           'because ExtensionType is an abstract base class.')

  # This class variable is used to cache the return value for
  # _tf_extension_type_fields.  `None` means "nothing cached yet".
  _tf_extension_type_cached_fields = None

  @classmethod
  def _tf_extension_type_fields(cls):  # pylint: disable=no-self-argument
    """An ordered list describing the fields of this ExtensionType.

    Returns:
      A tuple of `ExtensionTypeField` objects.  Forward references are resolved
      if possible, or left unresolved otherwise.
    """
    # Only consult the cache stored directly on `cls` (do not inherit a parent
    # class's cache).  A `None` entry is the "not cached yet" placeholder that
    # `ExtensionType` itself declares, so it must not be returned as a result.
    cached_fields = cls.__dict__.get('_tf_extension_type_cached_fields')
    if cached_fields is not None:
      return cached_fields

    try:
      # Using include_extras=False will replace all Annotated[T, ...] with T.
      # The typing_extensions module is used since `include_extras` is only
      # supported by `typing.get_type_hints` in Python 3.9+.
      type_hints = typing_extensions.get_type_hints(cls, include_extras=False)
      ok_to_cache = True  # all forward references have been resolved.
    except (NameError, AttributeError):
      # Unresolved forward reference -- gather type hints manually.
      # * NameError comes from an annotation like `Foo` where class
      #   `Foo` hasn't been defined yet.
      # * AttributeError comes from an annotation like `foo.Bar`, where
      #   the module `foo` exists but `Bar` hasn't been defined yet.
      # Note: If a user attempts to instantiate a `ExtensionType` type that
      # still has unresolved forward references (e.g., because of a typo or a
      # missing import), then the constructor will raise an exception.
      type_hints = {}
      for base in reversed(cls.__mro__):
        type_hints.update(base.__dict__.get('__annotations__', {}))
      ok_to_cache = False

    fields = []
    for (name, value_type) in type_hints.items():
      default = getattr(cls, name,
                        extension_type_field.ExtensionTypeField.NO_DEFAULT)
      fields.append(
          extension_type_field.ExtensionTypeField(name, value_type, default))
    fields = tuple(fields)

    if ok_to_cache:
      cls._tf_extension_type_cached_fields = fields

    return fields

  @classmethod
  def _tf_extension_type_has_field(cls, name):
    """Returns true if `name` is the name of one of this type's fields."""
    return any(name == field.name for field in cls._tf_extension_type_fields())

  def _tf_extension_type_convert_fields(self):
    """Runs `extension_type_field.convert_fields` on this value's `__dict__`."""
    extension_type_field.convert_fields(self._tf_extension_type_fields(),
                                        self.__dict__)

  def __repr__(self):
    fields = ', '.join([
        f'{field.name}={getattr(self, field.name)!r}'
        for field in self._tf_extension_type_fields()
    ])
    return f'{type(self).__qualname__}({fields})'

  def __setattr__(self, name, value):
    """Sets `name`, if it is a Keras property or we are inside a constructor.

    Raises:
      AttributeError: If `name` is not listed in `_MUTABLE_KERAS_PROPERTIES`
        and either `self` has no `_IN_CONSTRUCTOR` attribute or `name` is not
        a field of this type.
    """
    if (name in _MUTABLE_KERAS_PROPERTIES or
        (hasattr(self, _IN_CONSTRUCTOR) and
         self._tf_extension_type_has_field(name))):
      self.__dict__[name] = value
    else:
      raise AttributeError(f'Cannot mutate attribute `{name}` '
                           'outside the custom constructor of ExtensionType.')

  def __delattr__(self, name):
    """Deletes `name`, under the same conditions as `__setattr__`.

    Raises:
      AttributeError: If `name` is not listed in `_MUTABLE_KERAS_PROPERTIES`
        and either `self` has no `_IN_CONSTRUCTOR` attribute or `name` is not
        a field of this type.
    """
    if (name in _MUTABLE_KERAS_PROPERTIES or
        (hasattr(self, _IN_CONSTRUCTOR) and
         self._tf_extension_type_has_field(name))):
      del self.__dict__[name]
    else:
      raise AttributeError(f'Cannot mutate attribute `{name}` '
                           'outside the custom constructor of ExtensionType.')

  def __getattr__(self, name):
    """Fallback attribute lookup; reads through to the unpacked value.

    Python only calls `__getattr__` after normal attribute lookup has failed.

    Raises:
      AttributeError: If the attribute can not be found.
    """
    if name in _MUTABLE_KERAS_PROPERTIES:
      return object.__getattribute__(self, name)
    if '_tf_extension_type_packed_variant' in self.__dict__:
      # Note: it's *not* ok to cache the results of unpack() here.  In
      # particular, it would be nice if we could do something like
      # `self.__dict__.update(unpack(self).__dict__)`, but that (potentially)
      # violates an invariant required by the `cond` operation.  E.g., if we had
      # `tf.cond(lambda: x.foo, lambda: x.bar)`, then tensor `x.bar` used in the
      # "else" branch would be created by an op in the "then" branch (when
      # looking up `x.foo`); and that's not allowed.
      return getattr(unpack(self), name)

    raise AttributeError(
        f'{type(self).__name__!r} object has no attribute {name!r}')

  def __eq__(self, other):
    """Compares `self` and `other` by type, type spec, and tensor values.

    Returns:
      The Python value `False` if the types, the type specs, or the number of
      flattened tensors differ; otherwise the result of `reduce_all` over the
      per-tensor shape and value comparisons.
    """
    if type(self) is not type(other):
      return False

    if self._type_spec != other._type_spec:
      return False

    self_tensors = nest.flatten(self, expand_composites=True)
    other_tensors = nest.flatten(other, expand_composites=True)
    if len(self_tensors) != len(other_tensors):
      return False
    conditions = []
    for t1, t2 in zip(self_tensors, other_tensors):
      # Explicitly check shape (values that have different shapes but broadcast
      # to the same value are considered non-equal).
      conditions.append(
          math_ops.reduce_all(
              gen_math_ops.equal(
                  array_ops.shape(t1),
                  array_ops.shape(t2),
                  incompatible_shape_error=False)))
      conditions.append(
          math_ops.reduce_all(
              gen_math_ops.equal(t1, t2, incompatible_shape_error=False)))
    return math_ops.reduce_all(array_ops.stack(conditions))

  def __ne__(self, other):
    """Returns the negation of `self.__eq__(other)`."""
    eq = self.__eq__(other)
    if isinstance(eq, ops.Tensor):
      return math_ops.logical_not(eq)
    else:
      return not eq

  def __validate__(self):
    """Perform post-construction validation."""

  # This instance variable is used to cache the value for the _type_spec
  # property.
  _tf_extension_type_cached_type_spec = None

  @property
  def _type_spec(self):  # CompositeTensor API.
    # Note: the TypeSpec contains all static (non-tensor) data from `self`.
    if self._tf_extension_type_cached_type_spec is None:
      assert not is_packed(self)  # Packed version always caches TypeSpec.
      self.__dict__[
          '_tf_extension_type_cached_type_spec'] = self.Spec.from_value(self)
    return self._tf_extension_type_cached_type_spec
|
||
|
|
||
|
|
||
|
@tf_export('experimental.extension_type.as_dict')
def as_dict(value):
  """Returns a dict mapping each field name of `value` to that field's value.

  The conversion is shallow: unlike `dataclasses.asdict()`, nested
  `ExtensionType` objects are left as-is and only the top level object is
  turned into a dict.

  Args:
    value: An `ExtensionType` object.

  Returns:
    A dict that contains the attributes of `value` and their values.
  """
  attributes = {}
  for field in value._tf_extension_type_fields():  # pylint: disable=protected-access
    attributes[field.name] = getattr(value, field.name)
  return attributes
|
||
|
|
||
|
|
||
|
def pack(value):
  """Returns a copy of `value` with fields packed in a single Variant.

  A value that is already packed is returned unchanged.

  Args:
    value: An `ExtensionType` object.

  Returns:
    An `ExtensionType` object.

  Raises:
    ValueError: If encoding `value` as a variant raises `NotEncodableError`.
  """
  if is_packed(value):
    return value

  packed_spec = value._type_spec._tf_extension_type_with_packed(True)  # pylint: disable=protected-access
  try:
    packed_variant = composite_tensor_ops.composite_tensor_to_variants(value)
  except nested_structure_coder.NotEncodableError as e:
    # Note: the only time `_TypeSpecCodec.can_encode` returns False is if the
    # named type is not registered.  The default error message would simply
    # tell the user that there is no encoder for the object, so we provide
    # a more useful message letting them know how to register the type.
    raise ValueError('ExtensionTypes must have a __name__ field in order '
                     'to be packed.') from e

  # A packed value stores only its (packed) type spec and the variant tensor.
  packed_state = {
      '_tf_extension_type_cached_type_spec': packed_spec,
      '_tf_extension_type_packed_variant': packed_variant,
  }
  return _create_object_from_type_and_dict(type(value), packed_state)
|
||
|
|
||
|
|
||
|
def unpack(value):
  """Returns a copy of `value` with individual fields stored in __dict__.

  A value that is not packed is returned unchanged.

  Args:
    value: An `ExtensionType` object.

  Returns:
    An `ExtensionType` object.
  """
  if is_packed(value):
    # pylint: disable=protected-access
    packed_spec = value._tf_extension_type_cached_type_spec
    unpacked_spec = packed_spec._tf_extension_type_with_packed(False)
    return composite_tensor_ops.composite_tensor_from_variant(
        value._tf_extension_type_packed_variant, unpacked_spec)
  return value
|
||
|
|
||
|
|
||
|
def is_packed(value):
  """Returns true if `value`'s fields are packed in a single Variant.

  Args:
    value: An `ExtensionType` object.

  Returns:
    `True` if `value.__dict__` contains the packed variant entry; `False`
    otherwise.

  Raises:
    ValueError: If `value` is not an instance of `ExtensionType`.
  """
  if not isinstance(value, ExtensionType):
    # Note the space after the comma: the two literals are concatenated.
    raise ValueError('Expected `value` to be an object of type ExtensionType, '
                     f'got an instance of {type(value)}.')
  return '_tf_extension_type_packed_variant' in value.__dict__
|
||
|
|
||
|
|
||
|
# ==============================================================================
|
||
|
# Base class for the tf.ExtensionType TypeSpecs
|
||
|
# ==============================================================================
|
||
|
|
||
|
|
||
|
class ExtensionTypeSpec(type_spec.TypeSpec):
  """Base class for tf.ExtensionType TypeSpec."""

  def _serialize(self):  # TypeSpec API.
    # Use a tuple of (name, value) pairs, to ensure we preserve field ordering.
    fields = [f.name for f in self._tf_extension_type_fields()]
    if self._tf_extension_type_is_packed:
      fields.append('_tf_extension_type_is_packed')
    return tuple(
        (f, _change_nested_mappings_to(self.__dict__[f], dict)) for f in fields)

  @classmethod
  def _deserialize(cls, state):  # TypeSpec API.
    state = _change_nested_mappings_to(state, immutable_dict.ImmutableDict)
    return _create_object_from_type_and_dict(cls, state)

  def __reduce__(self):
    # Use value_type instead of spec_type, as spec_type is a nested class.
    # Pickle support of nested class requires Pickle protocol version 4, which
    # is not enabled by default until py 3.8.
    #
    # https://www.python.org/dev/peps/pep-3154/#serializing-more-lookupable-objects
    # https://docs.python.org/3/library/pickle.html#pickle.DEFAULT_PROTOCOL
    return _deserialize_for_reduce, (self.value_type, self._serialize())

  def _to_components(self, value):  # TypeSpec API.
    if self._tf_extension_type_is_packed:
      return value._tf_extension_type_packed_variant  # pylint: disable=protected-access

    tensor_or_composite = (ops.Tensor, composite_tensor.CompositeTensor)
    # Retrieve fields by the order of spec dict to preserve field ordering. This
    # is needed as nest.flatten would sort dictionary entries by key.
    value_tuple = tuple(value.__dict__[key] for key in self.__dict__)
    return tuple(
        x for x in nest.flatten(value_tuple)
        if isinstance(x, tensor_or_composite))

  def _from_components(self, components):  # TypeSpec API.
    """Rebuilds a value of `self.value_type` from `components`.

    Args:
      components: For a packed spec, the packed variant.  Otherwise an iterable
        with one component for every `TypeSpec` found when flattening this
        spec's fields.

    Returns:
      An instance of `self.value_type`, created without calling `__init__`.

    Raises:
      ValueError: If the number of components does not match the number of
        `TypeSpec`s in this spec (either too few or too many).
    """
    if self._tf_extension_type_is_packed:
      return _create_object_from_type_and_dict(
          self.value_type, {
              '_tf_extension_type_cached_type_spec': self,
              '_tf_extension_type_packed_variant': components
          })

    spec_tuple = tuple(self.__dict__.values())
    components_iter = iter(components)
    flat = []
    for x in nest.flatten(spec_tuple):
      if isinstance(x, type_spec.TypeSpec):
        # Each TypeSpec leaf consumes the next component.  Running out of
        # components is reported explicitly instead of leaking StopIteration.
        try:
          flat.append(next(components_iter))
        except StopIteration:
          raise ValueError(
              'Cannot build an ExtensionType instance from components '
              'because fewer components are provided than the number expected '
              'by the type spec.') from None
      else:
        # Static (non-tensor) values are carried over from the spec itself.
        flat.append(x)
    if list(components_iter):
      raise ValueError(
          'Cannot build an ExtensionType instance from components '
          'because more components are provided than the number expected '
          'by the type spec.')
    value_tuple = nest.pack_sequence_as(spec_tuple, flat)
    fields = dict(zip(self.__dict__.keys(), value_tuple))

    # Build the new value.  Bypass the constructor (__init__), in case the user
    # who defined the ExtensionType used a custom constructor.
    return _create_object_from_type_and_dict(self.value_type, fields)

  @property
  def _component_specs(self):  # TypeSpec API.
    if self._tf_extension_type_is_packed:
      return tensor_spec.TensorSpec((), dtypes.variant)

    components = []

    def push_if_type_spec(x):
      if isinstance(x, type_spec.TypeSpec):
        components.append(x)

    nest.map_structure(push_if_type_spec, tuple(self.__dict__.values()))
    return tuple(components)

  @classmethod
  def from_value(cls, value):
    """Returns the spec for `value`, reusing `value`'s cached spec if set."""
    cached_spec = getattr(value, '_tf_extension_type_cached_type_spec', None)
    if cached_spec is not None:
      return cached_spec

    value_fields = value.__dict__
    spec_fields = nest.map_structure(_replace_tensor_with_spec, value_fields)
    spec_fields.pop('_tf_extension_type_cached_fields', None)
    return _create_object_from_type_and_dict(cls, spec_fields)

  def __setattr__(self, name, value):
    """Sets field `name`; only allowed while inside a constructor.

    Raises:
      AttributeError: If `self` has no `_IN_CONSTRUCTOR` attribute or `name`
        is not a field of this spec.
    """
    if (hasattr(self, _IN_CONSTRUCTOR) and
        self._tf_extension_type_has_field(name)):
      self.__dict__[name] = value
    else:
      raise AttributeError(
          f'Cannot mutate attribute `{name}` '
          'outside the custom constructor of ExtensionTypeSpec.')

  def __delattr__(self, name):
    """Deletes field `name`; only allowed while inside a constructor.

    Raises:
      AttributeError: If `self` has no `_IN_CONSTRUCTOR` attribute or `name`
        is not a field of this spec.
    """
    if (hasattr(self, _IN_CONSTRUCTOR) and
        self._tf_extension_type_has_field(name)):
      del self.__dict__[name]
    else:
      raise AttributeError(
          f'Cannot mutate attribute `{name}` '
          'outside the custom constructor of ExtensionTypeSpec.')

  def __validate__(self):
    """Perform post-construction validation."""

  @classmethod
  def _tf_extension_type_fields(cls):
    """Returns the fields of `cls.value_type`."""
    return cls.value_type._tf_extension_type_fields()  # pylint: disable=protected-access

  @classmethod
  def _tf_extension_type_has_field(cls, name):
    """Returns true if `name` is the name of one of this spec's fields."""
    return any(name == field.name for field in cls._tf_extension_type_fields())

  def _tf_extension_type_convert_fields(self):
    """Runs `convert_fields_for_spec` on this spec's `__dict__`."""
    extension_type_field.convert_fields_for_spec(
        self._tf_extension_type_fields(), self.__dict__)

  def __repr__(self):
    fields = ', '.join([f'{k}={v!r}' for (k, v) in self._serialize()])
    return f'{type(self).__qualname__}({fields})'

  # Class-level default; `_tf_extension_type_with_packed` overrides it per
  # instance through `__dict__`.
  _tf_extension_type_is_packed = False

  def _tf_extension_type_with_packed(self, value):
    """Returns a copy of this `TypeSpec` with `packed=value`.

    Args:
      value: A boolean value.

    Returns:
      A copy of `self` with `_tf_extension_type_is_packed=value`.
    """
    copy = _create_object_from_type_and_dict(type(self), self.__dict__)
    copy.__dict__['_tf_extension_type_is_packed'] = value
    return copy
|
||
|
|
||
|
|
||
|
class _ExtensionTypeSpecCodec:
  """Codec for `tf.ExtensionTypeSpec`."""

  def can_encode(self, pyobj):
    """Returns true if `pyobj` can be encoded as an ExtensionTypeSpec."""
    if not isinstance(pyobj, ExtensionTypeSpec):
      return False
    # The spec is only encodable when its class has a registered name.
    try:
      type_spec_registry.get_name(type(pyobj))
    except ValueError:
      return False
    return True

  def do_encode(self, extension_type_spec_value, encode_fn):
    """Returns an encoded proto for the given `tf.ExtensionTypeSpec`."""
    # pylint: disable=protected-access
    registered_name = type_spec_registry.get_name(
        type(extension_type_spec_value))
    serialized_state = extension_type_spec_value._serialize()
    flat_component_specs = nest.flatten(
        extension_type_spec_value._component_specs, expand_composites=True)
    # pylint: enable=protected-access

    spec_proto = struct_pb2.TypeSpecProto(
        type_spec_class=struct_pb2.TypeSpecProto.EXTENSION_TYPE_SPEC,
        type_state=encode_fn(serialized_state),
        type_spec_class_name=registered_name,
        num_flat_components=len(flat_component_specs))
    encoded = struct_pb2.StructuredValue()
    encoded.type_spec_value.CopyFrom(spec_proto)
    return encoded

  def can_decode(self, value):
    """Returns true if `value` can be decoded into a `tf.ExtensionTypeSpec`."""
    if not value.HasField('type_spec_value'):
      return False
    spec_class_enum = value.type_spec_value.type_spec_class
    return spec_class_enum == struct_pb2.TypeSpecProto.EXTENSION_TYPE_SPEC

  def do_decode(self, value, decode_fn):
    """Returns the `tf.TypeSpec` encoded by the proto `value`."""
    spec_proto = value.type_spec_value
    class_name = spec_proto.type_spec_class_name

    try:
      spec_class = type_spec_registry.lookup(class_name)
    except ValueError:
      # Unregistered name: decode with the anonymous spec class and warn.
      spec_class = AnonymousExtensionTypeSpec
      warnings.warn(
          f"The type '{class_name}' has not been registered.  "
          'Falling back to using AnonymousExtensionTypeSpec '
          'instead.'
      )

    decoded_state = decode_fn(spec_proto.type_state)
    return spec_class._deserialize(decoded_state)  # pylint: disable=protected-access
|
||
|
|
||
|
|
||
|
# Register a codec instance with `nested_structure_coder` when this module is
# imported.
nested_structure_coder.register_codec(_ExtensionTypeSpecCodec())
|
||
|
|
||
|
|
||
|
@tf_export('experimental.ExtensionTypeBatchEncoder')
class ExtensionTypeBatchEncoder(type_spec.TypeSpecBatchEncoder):
  """Class used to encode and decode extension type values for batching.

  In order to be batched and unbatched by APIs such as `tf.data.Dataset`,
  `tf.keras`, and `tf.map_fn`, extension type values must be encoded as a list
  of `tf.Tensor`s, where stacking, unstacking, or concatenating these encoded
  tensors and then decoding the result must be equivalent to stacking,
  unstacking, or concatenating the original values. `ExtensionTypeBatchEncoder`s
  are responsible for implementing this encoding.

  The default `ExtensionTypeBatchEncoder` that is used by
  `BatchableExtensionType` assumes that extension type values can be stacked,
  unstacked, or concatenated by simply stacking, unstacking, or concatenating
  every nested `Tensor`, `ExtensionType`, `CompositeTensor`, and `TensorShape`
  field.

  Extension types where this is not the case will need to override
  `__batch_encoder__` with a custom encoder that overrides the `batch`,
  `unbatch`, `encode`, and `decode` methods. E.g.:

  >>> class CustomBatchEncoder(ExtensionTypeBatchEncoder):
  ...   pass # Override batch(), unbatch(), encode(), and decode().

  >>> class CustomType(BatchableExtensionType):
  ...   x: tf.Tensor
  ...   y: tf.Tensor
  ...   shape: tf.TensorShape
  ...   __batch_encoder__ = CustomBatchEncoder()

  For example, `tf.RaggedTensor` and `tf.SparseTensor` both use custom batch
  encodings which define ops to "box" and "unbox" individual values into
  `tf.variant` tensors.
  """

  def batch(self, spec, batch_size):
    """Returns the TypeSpec representing a batch of values described by `spec`.

    By default the result equals `spec`, except that every nested
    `TypeSpec` and `TensorShape` field gains an outer axis of size
    `batch_size`.  Subclasses may override this default definition, when
    necessary.

    Args:
      spec: The `TypeSpec` for an individual value.
      batch_size: An `int` indicating the number of values that are batched
        together, or `None` if the batch size is not known.

    Returns:
      A `TypeSpec` for a batch of values.
    """

    def add_batch_axis(field):
      if isinstance(field, type_spec.BatchableTypeSpec):
        return field.__batch_encoder__.batch(field, batch_size)
      if isinstance(field, tensor_shape.TensorShape):
        return [batch_size] + field
      return field

    spec_items = tuple(spec.__dict__.items())
    batched_items = nest.map_structure(add_batch_axis, spec_items)
    return _create_object_from_type_and_dict(type(spec), batched_items)

  def unbatch(self, spec):
    """Returns the TypeSpec for a single unbatched element in `spec`.

    By default the result equals `spec`, except that the outermost axis is
    dropped from every nested `TypeSpec` and `TensorShape` field.  Subclasses
    may override this default definition, when necessary.

    Args:
      spec: The `TypeSpec` for a batch of values.

    Returns:
      A `TypeSpec` for an individual value.
    """

    def drop_batch_axis(field):
      if isinstance(field, type_spec.BatchableTypeSpec):
        return field.__batch_encoder__.unbatch(field)
      if isinstance(field, tensor_shape.TensorShape):
        return field[1:]
      return field

    spec_items = tuple(spec.__dict__.items())
    unbatched_items = nest.map_structure(drop_batch_axis, spec_items)
    return _create_object_from_type_and_dict(type(spec), unbatched_items)

  def encode(self, spec, value, minimum_rank=0):
    """Encodes `value` as a nest of batchable Tensors or CompositeTensors.

    The default definition returns a flat tuple of all the `Tensor`s,
    `CompositeTensor`s, and `ExtensionType`s from a depth-first traversal of
    `value`'s fields. Subclasses may override this default definition, when
    necessary.

    Args:
      spec: The TypeSpec of the value to encode.
      value: A value compatible with `spec`.
      minimum_rank: The minimum rank for the returned Tensors, CompositeTensors,
        and ExtensionType values.  This can be used to ensure that the encoded
        values can be unbatched this number of times.   If `minimum_rank>0`,
        then `t.shape[:minimum_rank]` must be compatible for all values `t`
        returned by `encode`.

    Returns:
      A nest (as defined by `tf.nest`) of `tf.Tensor`s, batchable
      `tf.CompositeTensor`s, or `tf.ExtensionType`s.  Stacking, unstacking, or
      concatenating these encoded values and then decoding the result must be
      equivalent to stacking, unstacking, or concatenating the original values.
    """
    # The default encoding delegates to the spec and does not use
    # `minimum_rank`.
    components = spec._to_components(value)  # pylint: disable=protected-access
    return components

  def decode(self, spec, encoded_value):
    """Decodes `value` from a batchable tensor encoding.

    See `encode` for a description of the default encoding.  Subclasses may
    override this default definition, when necessary.

    Args:
      spec: The TypeSpec for the result value.  If encoded values with spec `s`
        were batched, then `spec` should be `s.batch(batch_size)`; or if encoded
        values with spec `s` were unbatched, then `spec` should be
        `s.unbatch()`.
      encoded_value: A nest of values returned by `encode`; or a nest of
        values that was formed by stacking, unstacking, or concatenating the
        corresponding elements of values returned by `encode`.

    Returns:
      A value compatible with `spec`.
    """
    decoded = spec._from_components(encoded_value)  # pylint: disable=protected-access
    return decoded

  def encoding_specs(self, spec):
    """Returns a list of `TensorSpec`(s) describing the encoding for `spec`.

    See `encode` for a description of the default encoding.  Subclasses may
    override this default definition, when necessary.

    Args:
      spec: The TypeSpec whose encoding should be described.

    Returns:
      A nest (as defined by `tf.nest`) of `tf.TypeSpec`, describing the values
      that are returned by `self.encode(spec, ...)`.  All TypeSpecs in this
      nest must be batchable.
    """
    component_specs = spec._component_specs  # pylint: disable=protected-access
    return component_specs
|
||
|
|
||
|
|
||
|
class BatchableExtensionTypeSpec(ExtensionTypeSpec,
                                 type_spec.BatchableTypeSpec):
  """Base class for TypeSpecs for BatchableExtensionTypes."""

  # Default encoder; `_batch` and `_unbatch` below delegate to it.
  __batch_encoder__ = ExtensionTypeBatchEncoder()

  def _batch(self, batch_size):
    """Delegates to `__batch_encoder__.batch`."""
    encoder = self.__batch_encoder__
    return encoder.batch(self, batch_size)

  def _unbatch(self):
    """Delegates to `__batch_encoder__.unbatch`."""
    encoder = self.__batch_encoder__
    return encoder.unbatch(self)

  def _to_tensor_list(self, value):
    """Delegates to `type_spec.batchable_to_tensor_list`."""
    tensor_list = type_spec.batchable_to_tensor_list(self, value)
    return tensor_list

  def _to_batched_tensor_list(self, value):
    """Like `_to_tensor_list`, but passes `minimum_rank=1`."""
    tensor_list = type_spec.batchable_to_tensor_list(
        self, value, minimum_rank=1)
    return tensor_list

  def _from_compatible_tensor_list(self, tensor_list):
    """Delegates to `type_spec.batchable_from_tensor_list`."""
    decoded = type_spec.batchable_from_tensor_list(self, tensor_list)
    return decoded

  @property
  def _flat_tensor_specs(self):
    """Delegates to `type_spec.get_batchable_flat_tensor_specs`."""
    flat_specs = type_spec.get_batchable_flat_tensor_specs(self)
    return flat_specs
|
||
|
|
||
|
|
||
|
@tf_export('experimental.BatchableExtensionType')
class BatchableExtensionType(ExtensionType):
  """An ExtensionType that can be batched and unbatched.

  `BatchableExtensionType`s can be used with APIs that require batching or
  unbatching, including `Keras`, `tf.data.Dataset`, and `tf.map_fn`.  E.g.:

  >>> class Vehicle(tf.experimental.BatchableExtensionType):
  ...   top_speed: tf.Tensor
  ...   mpg: tf.Tensor
  >>> batch = Vehicle([120, 150, 80], [30, 40, 12])
  >>> tf.map_fn(lambda vehicle: vehicle.top_speed * vehicle.mpg, batch,
  ...           fn_output_signature=tf.int32).numpy()
  array([3600, 6000, 960], dtype=int32)

  These APIs use an `ExtensionTypeBatchEncoder` to encode `ExtensionType`
  values.  The default encoder assumes that values can be stacked, unstacked,
  or concatenated by simply stacking, unstacking, or concatenating every
  nested `Tensor`, `ExtensionType`, `CompositeTensor`, or `TensorShape` field.
  Extension types where this is not the case will need to override
  `__batch_encoder__` with a custom `ExtensionTypeBatchEncoder`.  See
  `tf.experimental.ExtensionTypeBatchEncoder` for more details.
  """
  # This class belongs to the ExtensionType framework itself rather than to
  # user code, so the metaclass is told *not* to transform it.
  _tf_extension_type_do_not_transform_this_class = True
|
||
|
|
||
|
|
||
|
# For Pickle __reduce__ protocol:
|
||
|
def _deserialize_for_reduce(value_type, serialization):
|
||
|
return value_type.Spec._deserialize(serialization) # pylint: disable=protected-access
|
||
|
|
||
|
|
||
|
def _replace_tensor_with_spec(value):
  """Returns a spec standing in for `value`, or `value` itself.

  Args:
    value: Any object.

  Returns:
    A `TensorSpec` built from the shape and dtype if `value` is a `Tensor`;
    `value._type_spec` if `value` has that attribute; otherwise `value`
    unchanged.
  """
  if isinstance(value, ops.Tensor):
    # Note: we intentionally exclude `value.name` from the `TensorSpec`.
    replacement = tensor_spec.TensorSpec(value.shape, value.dtype)
  elif hasattr(value, '_type_spec'):
    replacement = value._type_spec  # pylint: disable=protected-access
  else:
    replacement = value
  return replacement
|
||
|
|
||
|
|
||
|
def _change_nested_mappings_to(value, new_type):
  """Recursively replace mappings with `new_type`.

  Args:
    value: The value to convert.  Only `dict`, `ImmutableDict` and `tuple`
      values are descended into; anything else is returned as-is.
    new_type: Callable invoked with a list of `(key, value)` pairs to build
      each replacement mapping.

  Returns:
    `value` with every nested mapping rebuilt via `new_type`.
  """
  if isinstance(value, (dict, immutable_dict.ImmutableDict)):
    converted_items = []
    for key, item in value.items():
      converted_items.append((key, _change_nested_mappings_to(item, new_type)))
    return new_type(converted_items)

  if isinstance(value, tuple):
    return tuple(_change_nested_mappings_to(item, new_type) for item in value)

  return value
|
||
|
|
||
|
|
||
|
# ==============================================================================
|
||
|
# Helper methods for tf.ExtensionTypeMetaclass
|
||
|
# ==============================================================================
|
||
|
|
||
|
|
||
|
def _check_field_annotations(cls):
  """Validates the field annotations for tf.ExtensionType subclass `cls`.

  Args:
    cls: The class whose `__dict__` and `__annotations__` are checked.

  Raises:
    ValueError: If `cls.Spec` is defined but is not a class whose only base is
      `tf.TypeSpec` or `object`; if a class attribute or an annotation uses a
      reserved name; or if a class attribute that is treated as a field has
      no type annotation.
  """
  annotations = getattr(cls, '__annotations__', {})
  is_reserved_name = extension_type_field.ExtensionTypeField.is_reserved_name

  def reserved_name_error(field_name):
    return ValueError(f'The field annotations for {cls.__name__} are '
                      f"invalid. Field '{field_name}' is reserved.")

  # Reserved names may appear neither as class attributes...
  for attr_name, attr_value in cls.__dict__.items():
    if attr_name != 'Spec':
      if is_reserved_name(attr_name):
        raise reserved_name_error(attr_name)
      continue
    # `Spec` is the one special name: it must be a nested class deriving
    # directly from TypeSpec (or from nothing at all).
    if not isinstance(attr_value, type):
      raise ValueError(f'{cls.__qualname__}.Spec must be a nested class; '
                       f'got {attr_value}.')
    if attr_value.__bases__ not in ((type_spec.TypeSpec,), (object,)):
      raise ValueError(f'{cls.__qualname__}.Spec must be directly subclassed '
                       'from tf.TypeSpec.')

  # ...nor as annotated fields.
  for annotated_name in annotations:
    if is_reserved_name(annotated_name):
      raise reserved_name_error(annotated_name)

  # Check that all fields have type annotations.  Callables, descriptors,
  # dunder names and framework-internal entries are not fields.
  for attr_name, attr_value in cls.__dict__.items():
    if attr_name in annotations or callable(attr_value):
      continue
    if (attr_name.startswith('_abc_') or
        attr_name == '_tf_extension_type_fields'):
      continue
    if attr_name.startswith('__') and attr_name.endswith('__'):
      continue
    if isinstance(attr_value, (property, classmethod, staticmethod)):
      continue
    raise ValueError(
        f'The field annotations for {cls.__name__} are '
        f'invalid. Field {attr_name} is missing a type annotation.')
|
||
|
|
||
|
|
||
|
def _add_extension_type_constructor(cls):
  """Creates a constructor for a ExtensionType or ExtensionTypeSpec subclass.

  Args:
    cls: The class to install an `__init__` on.  If `cls` defines its own
      `__init__` it is wrapped; otherwise a constructor is built for it.
  """
  has_user_init = '__init__' in cls.__dict__
  install_constructor = (
      _wrap_user_constructor
      if has_user_init else _build_extension_type_constructor)
  install_constructor(cls)
|
||
|
|
||
|
|
||
|
def _wrap_user_constructor(cls):
  """Wraps a user-defined constructor for tf.ExtensionType subclass `cls`.

  The wrapper sets the `_IN_CONSTRUCTOR` marker in the instance `__dict__`
  while the user's `__init__` runs, removes it afterwards, and then converts
  and validates the fields.

  Args:
    cls: The class whose `__init__` is replaced in place by the wrapper.
  """
  user_constructor = cls.__init__

  def wrapped_init(self, *args, **kwargs):
    # NOTE(review): the marker presumably tells attribute-setting code that
    # the instance is still being built -- confirm against `__setattr__`.
    self.__dict__[_IN_CONSTRUCTOR] = True
    try:
      user_constructor(self, *args, **kwargs)
    finally:
      # Always clear the marker, even when the user constructor raises, so
      # that an instance that escaped a failed constructor is not left
      # flagged as "under construction".  `pop` with a default keeps this
      # cleanup from masking the original exception.
      self.__dict__.pop(_IN_CONSTRUCTOR, None)

    self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  cls.__init__ = tf_decorator.make_decorator(user_constructor, wrapped_init)
|
||
|
|
||
|
|
||
|
# Module-level alias for `ExtensionTypeField.NO_DEFAULT`.  Field defaults are
# compared against it by identity (`is`) to tell whether a field has a default.
_NO_DEFAULT = extension_type_field.ExtensionTypeField.NO_DEFAULT
|
||
|
|
||
|
|
||
|
def _build_extension_type_constructor(cls):
  """Builds a constructor for tf.ExtensionType subclass `cls`.

  The generated `__init__` takes one parameter per field.  Binding arguments
  against the generated signature fills in defaults; the bound values are
  stored in the instance `__dict__`, then converted and validated.

  Args:
    cls: The class that receives the generated `__init__`.
  """
  fields = cls._tf_extension_type_fields()  # pylint: disable=protected-access

  # Mark any no-default fields that follow default fields as keyword_only:
  # find the first field without a default that comes after one with a
  # default.
  keyword_only_start = len(fields)
  seen_default = False
  for index, field in enumerate(fields):
    lacks_default = field.default is _NO_DEFAULT
    if seen_default and lacks_default:
      keyword_only_start = index
      break
    if not lacks_default:
      seen_default = True

  params = []
  for index, field in enumerate(fields):
    kind = (
        tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
        if index < keyword_only_start else tf_inspect.Parameter.KEYWORD_ONLY)
    default = (
        tf_inspect.Parameter.empty
        if field.default is _NO_DEFAULT else field.default)
    params.append(
        tf_inspect.Parameter(
            field.name, kind, default=default, annotation=field.value_type))

  signature = tf_inspect.Signature(params, return_annotation=cls.__name__)

  def __init__(self, *args, **kwargs):  # pylint: disable=invalid-name
    bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()
    self.__dict__.update(bound.arguments)
    self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  # __signature__ is supported by some inspection/documentation tools
  # (but note: typing.get_type_hints does not respect __signature__).
  self_param = tf_inspect.Parameter(
      'self', tf_inspect.Parameter.POSITIONAL_OR_KEYWORD)
  __init__.__signature__ = tf_inspect.Signature(
      [self_param] + params, return_annotation=cls)

  cls.__init__ = __init__
|
||
|
|
||
|
|
||
|
def _build_spec_constructor(cls):
  """Builds a constructor for ExtensionTypeSpec subclass `cls`.

  The generated `__init__` takes one positional-or-keyword parameter per
  field, none with a default.  Bound values are stored in the instance
  `__dict__`, then converted and validated.

  Args:
    cls: The class that receives the generated `__init__`.
  """
  positional_or_keyword = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
  spec_fields = cls._tf_extension_type_fields()  # pylint: disable=protected-access
  params = [
      tf_inspect.Parameter(field.name, positional_or_keyword)
      for field in spec_fields
  ]

  signature = tf_inspect.Signature(params, return_annotation=cls.__name__)

  def __init__(self, *args, **kwargs):  # pylint: disable=invalid-name
    bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()
    self.__dict__.update(bound.arguments)
    self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  # __signature__ is supported by some inspection/documentation tools.
  self_param = tf_inspect.Parameter(
      'self', tf_inspect.Parameter.POSITIONAL_OR_KEYWORD)
  __init__.__signature__ = tf_inspect.Signature(
      [self_param] + params, return_annotation=cls)

  cls.__init__ = __init__
|
||
|
|
||
|
|
||
|
def _add_type_spec(cls):
  """Creates a nested TypeSpec class for tf.ExtensionType subclass `cls`.

  The new class is stored as `cls.Spec`.  Attributes of a user-written
  `cls.Spec` (if any) are copied into it, and it is given a constructor.

  Args:
    cls: The class that receives the generated `Spec`.

  Raises:
    ValueError: If the user-supplied `Spec` uses a reserved name or shadows a
      field of `cls`, or if `__batch_encoder__` is defined although `cls` is
      not a `BatchableExtensionType`.
  """
  spec_name = f'{cls.__name__}.Spec'
  spec_qualname = f'{cls.__qualname__}.Spec'

  # Set __module__ explicitly as a dynamic created class has module='abc'
  # by default.
  spec_dict = {'value_type': cls, '__module__': cls.__module__}

  # Copy user-supplied customizations into the TypeSpec.
  user_spec = cls.__dict__.get('Spec', None)
  if user_spec is not None:
    is_reserved_name = extension_type_field.ExtensionTypeField.is_reserved_name
    for attr_name, attr_value in user_spec.__dict__.items():
      if is_reserved_name(attr_name):
        raise ValueError(f'TypeSpec {spec_qualname} uses reserved '
                         f"name '{attr_name}'.")
      if cls._tf_extension_type_has_field(attr_name):  # pylint: disable=protected-access
        raise ValueError(
            f"TypeSpec {spec_qualname} defines a variable '{attr_name}'"
            f' which shadows a field in {cls.__qualname__}')
      # These entries belong to the user's class object and are not copied.
      if attr_name not in ('__module__', '__dict__', '__weakref__'):
        spec_dict[attr_name] = attr_value

  is_batchable = issubclass(cls, BatchableExtensionType)
  cls_has_encoder = hasattr(cls, '__batch_encoder__')
  spec_has_encoder = '__batch_encoder__' in spec_dict
  if is_batchable:
    type_spec_base = BatchableExtensionTypeSpec
    # An encoder set on the value class is used unless the Spec has its own.
    if cls_has_encoder and not spec_has_encoder:
      spec_dict['__batch_encoder__'] = cls.__batch_encoder__
  else:
    type_spec_base = ExtensionTypeSpec
    if cls_has_encoder or spec_has_encoder:
      raise ValueError('__batch_encoder__ should only be defined for '
                       'BatchableExtensionType classes.')

  # Build the TypeSpec and store it as a nested class inside `cls`.
  spec = type(spec_name, (type_spec_base,), spec_dict)
  spec.__qualname__ = spec_qualname
  cls.Spec = spec

  # Build a constructor for the TypeSpec class.
  if '__init__' not in spec.__dict__:
    _build_spec_constructor(spec)
  else:
    _wrap_user_constructor(spec)

  cls.__abstractmethods__ -= {'_type_spec'}

  # If the user included an explicit `__name__` attribute, then use that to
  # register the TypeSpec (so it can be used in SavedModel signatures).
  if '__name__' in cls.__dict__:
    registered_name = cls.__dict__['__name__'] + '.Spec'
    type_spec_registry.register(registered_name)(spec)
|
||
|
|
||
|
|
||
|
# ==============================================================================
|
||
|
# Anonymous ExtensionType
|
||
|
# ==============================================================================
|
||
|
class AnonymousExtensionType(ExtensionType):
  """Fallback used to decode `tf.ExtensionType` when the original type is unavailable.

  When a SavedModel is serialized, the signatures of any functions in the
  SavedModel can include `tf.ExtensionType` subclasses.  These subclasses are
  usually registered, so they can be restored when the SavedModel is loaded.
  However, if a SavedModel is loaded without first registering the
  ExtensionType types in its signature, then the SavedModel will fall back to
  using the `AnonymousExtensionType` type instead.

  If necessary, `AnonymousExtensionType` objects can be converted to a concrete
  `tf.ExtensionType` subclass (and vice versa) using `reinterpret`.
  """

  # This class belongs to the ExtensionType framework itself rather than to
  # user code, so the metaclass is told *not* to transform it.
  _tf_extension_type_do_not_transform_this_class = True

  def __init__(self, **fields):
    """Stores the converted `fields` directly in the instance `__dict__`.

    Args:
      **fields: Field names mapped to their values.

    Raises:
      ValueError: If a field name is reserved or is a dunder name, or if a
        value cannot be converted by `_convert_anonymous_fields`.
    """
    for name in fields:
      is_dunder = name.startswith('__') and name.endswith('__')
      if (extension_type_field.ExtensionTypeField.is_reserved_name(name) or
          is_dunder):
        raise ValueError(
            f'Reserved field name {name} was encountered '
            f'when trying to instantiate an AnonymousExtensionType.')
    converted = {
        name: _convert_anonymous_fields(field_value)
        for name, field_value in fields.items()
    }
    self.__dict__.update(converted)
    self._tf_extension_type_convert_fields()
    super().__init__()

  @classmethod
  def _tf_extension_type_fields(cls):
    """Returns an `ExtensionTypeField` per non-reserved name in `cls.__dict__`."""
    is_reserved_name = extension_type_field.ExtensionTypeField.is_reserved_name
    field_names = [name for name in cls.__dict__ if not is_reserved_name(name)]
    return [
        extension_type_field.ExtensionTypeField(name, None)
        for name in field_names
    ]

  def __setattr__(self, name, value):
    """Always raises: instances are immutable."""
    raise AttributeError(f'Cannot set attribute `{name}`. '
                         'AnonymousExtensionType instances are immutable.')

  def __delattr__(self, name):
    """Always raises: instances are immutable."""
    raise AttributeError(f'Cannot delete attribute `{name}`. '
                         'AnonymousExtensionType instances are immutable.')

  def _tf_extension_type_convert_fields(self):
    """Re-converts every non-reserved entry of the instance `__dict__`."""
    is_reserved_name = extension_type_field.ExtensionTypeField.is_reserved_name
    converted = {}
    for name, field_value in self.__dict__.items():
      if is_reserved_name(name):
        continue
      converted[name] = _convert_anonymous_fields(field_value)
    self.__dict__.update(converted)

  def __repr__(self):
    is_reserved_name = extension_type_field.ExtensionTypeField.is_reserved_name
    shown = ', '.join(f'{name}={field_value!r}'
                      for name, field_value in self.__dict__.items()
                      if not is_reserved_name(name))
    return f'AnonymousExtensionType({shown})'

  # Class-level default; the computed spec is cached per instance below.
  _tf_extension_type_cached_type_spec = None

  @property
  def _type_spec(self):  # CompositeTensor API.
    # Note: the TypeSpec contains all static (non-tensor) data from `self`.
    cached_spec = self._tf_extension_type_cached_type_spec
    if cached_spec is None:
      cached_spec = AnonymousExtensionTypeSpec.from_value(self)
      # Written through `__dict__` because `__setattr__` always raises.
      self.__dict__['_tf_extension_type_cached_type_spec'] = cached_spec
    return cached_spec
|
||
|
|
||
|
|
||
|
@type_spec_registry.register('tf.AnonymousExtensionType.Spec')
class AnonymousExtensionTypeSpec(ExtensionTypeSpec):
  """TypeSpec for AnonymousExtensionType."""

  def __init__(self, **fields):
    """Stores the converted `fields` directly in the instance `__dict__`.

    Args:
      **fields: Field names mapped to their spec-side values.

    Raises:
      ValueError: If a field name is reserved or is a dunder name, or if a
        value cannot be converted by `_convert_anonymous_fields`.
    """
    for name in fields:
      is_dunder = name.startswith('__') and name.endswith('__')
      if (extension_type_field.ExtensionTypeField.is_reserved_name(name) or
          is_dunder):
        raise ValueError(
            f'Reserved field name {name} was encountered '
            f'when trying to instantiate an AnonymousExtensionTypeSpec.')
    converted = {
        name: _convert_anonymous_fields(field_value, for_spec=True)
        for name, field_value in fields.items()
    }
    self.__dict__.update(converted)
    super().__init__()

  value_type = AnonymousExtensionType  # TypeSpec API.

  def _serialize(self):  # TypeSpec API.
    """Returns `(name, value)` pairs, with nested mappings turned into `dict`s."""
    is_reserved_name = extension_type_field.ExtensionTypeField.is_reserved_name
    pairs = []
    for name, field_value in self.__dict__.items():
      if not is_reserved_name(name):
        pairs.append((name, _change_nested_mappings_to(field_value, dict)))
    return tuple(pairs)

  def __setattr__(self, name, value):
    """Always raises: instances are immutable."""
    raise AttributeError(
        f'Cannot set attribute `{name}`. '
        'AnonymousExtensionTypeSpec instances are immutable.')

  def __delattr__(self, name):
    """Always raises: instances are immutable."""
    raise AttributeError(
        f'Cannot delete attribute `{name}`. '
        'AnonymousExtensionTypeSpec instances are immutable.')
|
||
|
|
||
|
|
||
|
def _convert_anonymous_fields(value, for_spec=False):
  """Type-checks and converts `value` for inclusion in an AnonymousExtensionType.

  Args:
    value: The field value to check.  Tuples and mappings are converted
      recursively; mappings come back as `ImmutableDict`.
    for_spec: If true, `TypeSpec` leaves are accepted; otherwise `Tensor` and
      `CompositeTensor` leaves are accepted.

  Returns:
    The converted value.

  Raises:
    ValueError: If `value` (or something nested in it) has an unsupported type.
  """
  passthrough_types = (int, float, bool, str, bytes, type(None), dtypes.DType,
                       tensor_shape.TensorShape)
  if isinstance(value, passthrough_types):
    return value

  if isinstance(value, tuple):
    return tuple(_convert_anonymous_fields(item, for_spec) for item in value)

  if isinstance(value, typing.Mapping):
    converted_items = [(_convert_anonymous_fields(key, for_spec),
                        _convert_anonymous_fields(item, for_spec))
                       for key, item in value.items()]
    return immutable_dict.ImmutableDict(converted_items)

  # Which leaf types are acceptable depends on whether a value or a spec is
  # being built.
  if for_spec:
    if isinstance(value, type_spec.TypeSpec):
      return value
  elif isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)):
    return value

  raise ValueError(f'Cannot convert anonymous fields from '
                   f'an unsupported `value` argument: {value!r}.')
|
||
|
|
||
|
|
||
|
# ==============================================================================
|
||
|
# reinterpret
|
||
|
# ==============================================================================
|
||
|
def reinterpret(value, new_type):
  """Converts a given `ExtensionType` to a new type with compatible fields.

  In particular, this can be used to convert a concrete subclass of
  `ExtensionType` to an `AnonymousExtensionType`, or vice versa.  When
  converting to a non-anonymous ExtensionType, field values are type-checked to
  ensure they are consistent with `new_type`'s type annotations, and validated
  with `new_type.__validate__`.

  Args:
    value: An instance of a subclass of `tf.ExtensionType`
    new_type: A subclass of `tf.ExtensionType`

  Returns:
    An instance of `new_type`, whose fields are copied from `value`.

  Raises:
    ValueError: If `value` is not a `tf.ExtensionType` instance, or if
      `new_type` is not a subclass of `tf.ExtensionType`.
  """
  if not isinstance(value, ExtensionType):
    raise ValueError(
        f'reinterpret expects `value` to be a tf.ExtensionType instance; '
        f'got {value!r}')

  is_extension_class = (
      isinstance(new_type, type) and issubclass(new_type, ExtensionType))
  if not is_extension_class:
    raise ValueError(
        f'reinterpret expects `new_type` to be a subclass of tf.ExtensionType; '
        f'got {new_type!r}')

  # Copy every non-reserved entry of the source instance's `__dict__`.
  is_reserved_name = extension_type_field.ExtensionTypeField.is_reserved_name
  fields = [(name, field_value)
            for name, field_value in value.__dict__.items()
            if not is_reserved_name(name)]

  result = _create_object_from_type_and_dict(new_type, fields)
  result._tf_extension_type_convert_fields()  # pylint: disable=protected-access
  result.__validate__()
  return result
|