# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""User-defined ExtensionType classes."""
import abc
import typing
import warnings
import typing_extensions
from tensorflow.core.protobuf import struct_pb2
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import extension_type_field
from tensorflow.python.framework import immutable_dict
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.framework import type_spec_registry
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import composite_tensor_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
# Attribute used to keep track of when we're inside a user-defined constructor
# (in which case the fields of `self` may be modified).
_IN_CONSTRUCTOR = '_tf_extension_type_in_constructor'
_MUTABLE_KERAS_PROPERTIES = [
# Keras uses _keras_mask property to pass the mask around
'_keras_mask',
]
# ==============================================================================
# Utility functions
# ==============================================================================
def _create_object_from_type_and_dict(cls, obj_dict):
"""Creates an object, bypassing the constructor.
Creates an object of type `cls`, whose `__dict__` is updated to contain
`obj_dict`.
Args:
cls: The type of the new object.
obj_dict: A `Mapping` that should be used to initialize the new object's
`__dict__`.
Returns:
An object of type `cls`.
"""
value = object.__new__(cls)
value.__dict__.update(obj_dict)
return value
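# A minimal sketch of how this helper is typically used within this module
# (`SomeSpec` and its field are hypothetical): because `object.__new__` is
# used, `__init__` and any user-defined validation are intentionally skipped.
#
#   fields = {'values': tensor_spec.TensorSpec([None], dtypes.int32)}
#   spec = _create_object_from_type_and_dict(SomeSpec, fields)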
# ==============================================================================
# Metaclass for tf.ExtensionType
# ==============================================================================
class ExtensionTypeMetaclass(abc.ABCMeta):
"""Metaclass for tf.ExtensionType types."""
def __init__(cls, name, bases, namespace):
# Don't transform base classes that are part of the framework -- only
# transform user classes. We identify classes that are part of the
# framework by setting '_tf_extension_type_do_not_transform_this_class=True'
# in the class definition. (Note: we check for this in the class namespace,
# so it is *not* inherited.)
if not namespace.get('_tf_extension_type_do_not_transform_this_class',
False):
_check_field_annotations(cls)
_add_extension_type_constructor(cls)
_add_type_spec(cls)
super(ExtensionTypeMetaclass, cls).__init__(name, bases, namespace)
# ==============================================================================
# Base class for user-defined types
# ==============================================================================
@tf_export('experimental.ExtensionType')
class ExtensionType(
composite_tensor.CompositeTensor, metaclass=ExtensionTypeMetaclass):
"""Base class for TensorFlow `ExtensionType` classes.
TensorFlow `ExtensionType` classes are specialized Python classes that can be
used transparently with TensorFlow -- e.g., they can be used with ops
such as `tf.cond` or `tf.while_loop` and used as inputs or outputs for
`tf.function` and Keras layers.
New `ExtensionType` classes are defined by creating a subclass of
`tf.ExtensionType` that contains type annotations for all instance variables.
The following type annotations are supported:
Type | Example
-------------------- | --------------------------------------------
Python integers | `i: int`
Python floats | `f: float`
Python strings | `s: str`
Python booleans | `b: bool`
Python None | `n: None`
Tensors | `t: tf.Tensor`
Composite Tensors | `rt: tf.RaggedTensor`
Extension Types | `m: MyMaskedTensor`
Tensor shapes | `shape: tf.TensorShape`
Tensor dtypes | `dtype: tf.DType`
Type unions | `length: typing.Union[int, float]`
Tuples | `params: typing.Tuple[int, float, int, int]`
Tuples w/ Ellipsis | `lengths: typing.Tuple[int, ...]`
Mappings | `tags: typing.Mapping[str, str]`
Fields annotated with `typing.Mapping` will be stored using an immutable
mapping type.
ExtensionType values are immutable -- i.e., once constructed, you cannot
modify or delete any of their instance members.
### Examples
>>> class MaskedTensor(ExtensionType):
... values: tf.Tensor
... mask: tf.Tensor
>>> class Toy(ExtensionType):
... name: str
... price: tf.Tensor
... features: typing.Mapping[str, tf.Tensor]
>>> class ToyStore(ExtensionType):
... name: str
... toys: typing.Tuple[Toy, ...]
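A usage sketch (assuming eager execution): values are built with the
generated constructor, and fields can be read back as attributes.
>>> mt = MaskedTensor(values=tf.constant([1, 2, 3]),
... mask=tf.constant([True, False, True]))
>>> mt.values.numpy()
array([1, 2, 3], dtype=int32)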
"""
# Let the metaclass know that it should *not* transform this class (since
# this class is part of the ExtensionType framework, and not a user class).
_tf_extension_type_do_not_transform_this_class = True
def __init__(self, *args, **kwargs):
if type(self) is ExtensionType: # pylint: disable=unidiomatic-typecheck
raise AssertionError('Cannot create an instance of ExtensionType '
'because ExtensionType is an abstract base class.')
# This class variable is used to cache the return value for
# _tf_extension_type_fields.
_tf_extension_type_cached_fields = None
@classmethod
def _tf_extension_type_fields(cls): # pylint: disable=no-self-argument
"""An ordered list describing the fields of this ExtensionType.
Returns:
A list of `ExtensionTypeField` objects. Forward references are resolved
if possible, or left unresolved otherwise.
"""
if '_tf_extension_type_cached_fields' in cls.__dict__: # do not inherit.
return cls._tf_extension_type_cached_fields
try:
# Using include_extras=False will replace all Annotated[T, ...] with T.
# The typing_extensions module is used since `include_extras` is only
# supported by `typing.get_type_hints` starting in Python 3.9.
type_hints = typing_extensions.get_type_hints(cls, include_extras=False)
ok_to_cache = True # all forward references have been resolved.
except (NameError, AttributeError):
# Unresolved forward reference -- gather type hints manually.
# * NameError comes from an annotation like `Foo` where class
# `Foo` hasn't been defined yet.
# * AttributeError comes from an annotation like `foo.Bar`, where
# the module `foo` exists but `Bar` hasn't been defined yet.
# Note: If a user attempts to instantiate an `ExtensionType` type that
# still has unresolved forward references (e.g., because of a typo or a
# missing import), then the constructor will raise an exception.
type_hints = {}
for base in reversed(cls.__mro__):
type_hints.update(base.__dict__.get('__annotations__', {}))
ok_to_cache = False
fields = []
for (name, value_type) in type_hints.items():
default = getattr(cls, name,
extension_type_field.ExtensionTypeField.NO_DEFAULT)
fields.append(
extension_type_field.ExtensionTypeField(name, value_type, default))
fields = tuple(fields)
if ok_to_cache:
cls._tf_extension_type_cached_fields = fields
return fields
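# Example of an unresolved forward reference handled above (`Tree` is a
# hypothetical user class):
#
#   class Tree(ExtensionType):
#     value: tf.Tensor
#     children: typing.Tuple['Tree', ...]
#
# While `Tree` is still being defined, the string annotation 'Tree' cannot
# be resolved, so the fields are gathered from `__annotations__` and the
# result is not cached.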
@classmethod
def _tf_extension_type_has_field(cls, name):
return any(name == field.name for field in cls._tf_extension_type_fields())
def _tf_extension_type_convert_fields(self):
extension_type_field.convert_fields(self._tf_extension_type_fields(),
self.__dict__)
def __repr__(self):
fields = ', '.join([
f'{field.name}={getattr(self, field.name)!r}'
for field in self._tf_extension_type_fields()
])
return f'{type(self).__qualname__}({fields})'
def __setattr__(self, name, value):
if (name in _MUTABLE_KERAS_PROPERTIES or
(hasattr(self, _IN_CONSTRUCTOR) and
self._tf_extension_type_has_field(name))):
self.__dict__[name] = value
else:
raise AttributeError(f'Cannot mutate attribute `{name}` '
f'outside the custom constructor of ExtensionType.')
def __delattr__(self, name):
if (name in _MUTABLE_KERAS_PROPERTIES or
(hasattr(self, _IN_CONSTRUCTOR) and
self._tf_extension_type_has_field(name))):
del self.__dict__[name]
else:
raise AttributeError(f'Cannot mutate attribute `{name}` '
f'outside the custom constructor of ExtensionType.')
def __getattr__(self, name):
if name in _MUTABLE_KERAS_PROPERTIES:
return object.__getattribute__(self, name)
if '_tf_extension_type_packed_variant' in self.__dict__:
# Note: it's *not* ok to cache the results of unpack() here. In
# particular, it would be nice if we could do something like
# `self.__dict__.update(unpack(self).__dict__)`, but that (potentially)
# violates an invariant required by the `cond` operation. E.g., if we had
# `tf.cond(lambda: x.foo, lambda: x.bar)`, then tensor `x.bar` used in the
# "else" branch would be created by an op in the "then" branch (when
# looking up `x.foo`); and that's not allowed.
return getattr(unpack(self), name)
raise AttributeError(
f'{type(self).__name__!r} object has no attribute {name!r}')
def __eq__(self, other):
if type(self) is not type(other):
return False
if self._type_spec != other._type_spec:
return False
self_tensors = nest.flatten(self, expand_composites=True)
other_tensors = nest.flatten(other, expand_composites=True)
if len(self_tensors) != len(other_tensors):
return False
conditions = []
for t1, t2 in zip(self_tensors, other_tensors):
conditions.append(
math_ops.reduce_all(
gen_math_ops.equal(
array_ops.shape(t1),
array_ops.shape(t2),
incompatible_shape_error=False)))
# Explicitly check shape (values that have different shapes but broadcast
# to the same value are considered non-equal).
conditions.append(
math_ops.reduce_all(
gen_math_ops.equal(t1, t2, incompatible_shape_error=False)))
return math_ops.reduce_all(array_ops.stack(conditions))
def __ne__(self, other):
eq = self.__eq__(other)
if isinstance(eq, ops.Tensor):
return math_ops.logical_not(eq)
else:
return not eq
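# Note: since `__eq__` is built from TF ops, comparing two ExtensionType
# values eagerly yields a scalar boolean Tensor rather than a Python bool.
# E.g., with the hypothetical MaskedTensor from the class docstring:
#   bool(MaskedTensor([1], [True]) == MaskedTensor([1], [True]))  # True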
def __validate__(self):
"""Perform post-construction validation."""
# This instance variable is used to cache the value for the _type_spec
# property.
_tf_extension_type_cached_type_spec = None
@property
def _type_spec(self): # CompositeTensor API.
# Note: the TypeSpec contains all static (non-tensor) data from `self`.
if self._tf_extension_type_cached_type_spec is None:
assert not is_packed(self) # Packed version always caches TypeSpec.
self.__dict__[
'_tf_extension_type_cached_type_spec'] = self.Spec.from_value(self)
return self._tf_extension_type_cached_type_spec
@tf_export('experimental.extension_type.as_dict')
def as_dict(value):
"""Extracts the attributes of `value` and their values to a dict format.
Unlike `dataclasses.asdict()`, this function is not recursive and in case of
nested `ExtensionType` objects, only the top level object is converted to a
dict.
Args:
value: An `ExtensionType` object.
Returns:
A dict that contains the attributes of `value` and their values.
"""
return {
field.name: getattr(value, field.name)
for field in value._tf_extension_type_fields() # pylint: disable=protected-access
}
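# Example sketch (`MaskedTensor` is the hypothetical class from the
# ExtensionType docstring):
#
#   as_dict(MaskedTensor(values=[1, 2], mask=[True, False]))
#   # => {'values': <tf.Tensor ...>, 'mask': <tf.Tensor ...>}
#
# A nested ExtensionType field would appear in the dict unconverted.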
def pack(value):
"""Returns a copy of `value` with fields packed in a single Variant.
Args:
value: An `ExtensionType` object.
Returns:
An `ExtensionType` object.
"""
if is_packed(value):
return value
spec = value._type_spec._tf_extension_type_with_packed(True) # pylint: disable=protected-access
try:
variant = composite_tensor_ops.composite_tensor_to_variants(value)
except nested_structure_coder.NotEncodableError as e:
# Note: the only time `_TypeSpecCodec.can_encode` returns False is if the
# named type is not registered. The default error message would simply
# tell the user that there is no encoder for the object, so we provide
# a more useful message letting them know how to register the type.
raise ValueError('ExtensionTypes must have a __name__ field in order '
'to be packed.') from e
return _create_object_from_type_and_dict(
type(value), {
'_tf_extension_type_cached_type_spec': spec,
'_tf_extension_type_packed_variant': variant,
})
def unpack(value):
"""Returns a copy of `value` with individual fields stored in __dict__.
Args:
value: An `ExtensionType` object.
Returns:
An `ExtensionType` object.
"""
if not is_packed(value):
return value
# pylint: disable=protected-access
variant = value._tf_extension_type_packed_variant
spec = value._tf_extension_type_cached_type_spec
spec = spec._tf_extension_type_with_packed(False)
return composite_tensor_ops.composite_tensor_from_variant(variant, spec)
def is_packed(value):
"""Returns true if `value`'s fields are packed in a single Variant."""
if not isinstance(value, ExtensionType):
raise ValueError(f'Expected `value` to be an object of type ExtensionType, '
f'got an instance of {type(value)}.')
return '_tf_extension_type_packed_variant' in value.__dict__
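# Sketch of the pack/unpack round trip (assumes `mt` is an ExtensionType
# value whose class defines `__name__`, so its TypeSpec is registered):
#
#   packed = pack(mt)        # fields encoded in a single variant tensor
#   is_packed(packed)        # => True
#   restored = unpack(packed)
#   is_packed(restored)      # => False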
# ==============================================================================
# Base class for the tf.ExtensionType TypeSpecs
# ==============================================================================
class ExtensionTypeSpec(type_spec.TypeSpec):
"""Base class for tf.ExtensionType TypeSpec."""
def _serialize(self): # TypeSpec API.
# Use a tuple of (name, value) pairs, to ensure we preserve field ordering.
fields = [f.name for f in self._tf_extension_type_fields()]
if self._tf_extension_type_is_packed:
fields.append('_tf_extension_type_is_packed')
return tuple(
(f, _change_nested_mappings_to(self.__dict__[f], dict)) for f in fields)
@classmethod
def _deserialize(cls, state): # TypeSpec API.
state = _change_nested_mappings_to(state, immutable_dict.ImmutableDict)
return _create_object_from_type_and_dict(cls, state)
def __reduce__(self):
# Use value_type instead of spec_type, as spec_type is a nested class.
# Pickle support of nested class requires Pickle protocol version 4, which
# is not enabled by default until py 3.8.
#
# https://www.python.org/dev/peps/pep-3154/#serializing-more-lookupable-objects
# https://docs.python.org/3/library/pickle.html#pickle.DEFAULT_PROTOCOL
return _deserialize_for_reduce, (self.value_type, self._serialize())
def _to_components(self, value): # TypeSpec API.
if self._tf_extension_type_is_packed:
return value._tf_extension_type_packed_variant # pylint: disable=protected-access
tensor_or_composite = (ops.Tensor, composite_tensor.CompositeTensor)
# Retrieve fields in the order of the spec dict to preserve field ordering. This
# is needed as nest.flatten would sort dictionary entries by key.
value_tuple = tuple(value.__dict__[key] for key in self.__dict__)
return tuple(
x for x in nest.flatten(value_tuple)
if isinstance(x, tensor_or_composite))
def _from_components(self, components): # TypeSpec API.
if self._tf_extension_type_is_packed:
return _create_object_from_type_and_dict(
self.value_type, {
'_tf_extension_type_cached_type_spec': self,
'_tf_extension_type_packed_variant': components
})
spec_tuple = tuple(self.__dict__.values())
components_iter = iter(components)
flat = [
next(components_iter) if isinstance(x, type_spec.TypeSpec) else x
for x in nest.flatten(spec_tuple)
]
if list(components_iter):
raise ValueError(
'Cannot build an ExtensionType instance from components '
'because more components are provided than the number expected '
'by the type spec.')
value_tuple = nest.pack_sequence_as(spec_tuple, flat)
fields = dict(zip(self.__dict__.keys(), value_tuple))
# Build the new value. Bypass the constructor (__init__), in case the user
# who defined the ExtensionType used a custom constructor.
return _create_object_from_type_and_dict(self.value_type, fields)
@property
def _component_specs(self): # TypeSpec API.
if self._tf_extension_type_is_packed:
return tensor_spec.TensorSpec((), dtypes.variant)
components = []
def push_if_type_spec(x):
if isinstance(x, type_spec.TypeSpec):
components.append(x)
nest.map_structure(push_if_type_spec, tuple(self.__dict__.values()))
return tuple(components)
@classmethod
def from_value(cls, value):
cached_spec = getattr(value, '_tf_extension_type_cached_type_spec', None)
if cached_spec is not None:
return cached_spec
value_fields = value.__dict__
spec_fields = nest.map_structure(_replace_tensor_with_spec, value_fields)
spec_fields.pop('_tf_extension_type_cached_fields', None)
return _create_object_from_type_and_dict(cls, spec_fields)
def __setattr__(self, name, value):
if (hasattr(self, _IN_CONSTRUCTOR) and
self._tf_extension_type_has_field(name)):
self.__dict__[name] = value
else:
raise AttributeError(
f'Cannot mutate attribute `{name}` '
f'outside the custom constructor of ExtensionTypeSpec.')
def __delattr__(self, name):
if (hasattr(self, _IN_CONSTRUCTOR) and
self._tf_extension_type_has_field(name)):
del self.__dict__[name]
else:
raise AttributeError(
f'Cannot mutate attribute `{name}` '
f'outside the custom constructor of ExtensionTypeSpec.')
def __validate__(self):
"""Perform post-construction validation."""
@classmethod
def _tf_extension_type_fields(cls):
return cls.value_type._tf_extension_type_fields() # pylint: disable=protected-access
@classmethod
def _tf_extension_type_has_field(cls, name):
return any(name == field.name for field in cls._tf_extension_type_fields())
def _tf_extension_type_convert_fields(self):
extension_type_field.convert_fields_for_spec(
self._tf_extension_type_fields(), self.__dict__)
def __repr__(self):
fields = ', '.join([f'{k}={v!r}' for (k, v) in self._serialize()])
return f'{type(self).__qualname__}({fields})'
_tf_extension_type_is_packed = False
def _tf_extension_type_with_packed(self, value):
"""Returns a copy of this `TypeSpec` with `packed=value`.
Args:
value: A boolean value.
Returns:
A copy of `self` with `_tf_extension_type_is_packed=value`.
"""
copy = _create_object_from_type_and_dict(type(self), self.__dict__)
copy.__dict__['_tf_extension_type_is_packed'] = value
return copy
class _ExtensionTypeSpecCodec:
"""Codec for `tf.ExtensionTypeSpec`."""
def can_encode(self, pyobj):
"""Returns true if `pyobj` can be encoded as an ExtensionTypeSpec."""
if isinstance(pyobj, ExtensionTypeSpec):
try:
type_spec_registry.get_name(type(pyobj))
return True
except ValueError:
return False
return False
def do_encode(self, extension_type_spec_value, encode_fn):
"""Returns an encoded proto for the given `tf.ExtensionTypeSpec`."""
type_spec_class_name = type_spec_registry.get_name(
type(extension_type_spec_value))
type_state = extension_type_spec_value._serialize() # pylint: disable=protected-access
num_flat_components = len(
nest.flatten(
extension_type_spec_value._component_specs, expand_composites=True)) # pylint: disable=protected-access
encoded_type_spec = struct_pb2.StructuredValue()
encoded_type_spec.type_spec_value.CopyFrom(
struct_pb2.TypeSpecProto(
type_spec_class=struct_pb2.TypeSpecProto.EXTENSION_TYPE_SPEC,
type_state=encode_fn(type_state),
type_spec_class_name=type_spec_class_name,
num_flat_components=num_flat_components))
return encoded_type_spec
def can_decode(self, value):
"""Returns true if `value` can be decoded into a `tf.ExtensionTypeSpec`."""
if value.HasField('type_spec_value'):
type_spec_class_enum = value.type_spec_value.type_spec_class
return (
type_spec_class_enum == struct_pb2.TypeSpecProto.EXTENSION_TYPE_SPEC)
return False
def do_decode(self, value, decode_fn):
"""Returns the `tf.TypeSpec` encoded by the proto `value`."""
type_spec_proto = value.type_spec_value
class_name = type_spec_proto.type_spec_class_name
try:
type_spec_class = type_spec_registry.lookup(class_name)
except ValueError:
type_spec_class = AnonymousExtensionTypeSpec
warnings.warn(
f"The type '{class_name}' has not been registered. "
'Falling back to using AnonymousExtensionTypeSpec '
'instead.'
)
# pylint: disable=protected-access
return type_spec_class._deserialize(decode_fn(type_spec_proto.type_state))
nested_structure_coder.register_codec(_ExtensionTypeSpecCodec())
@tf_export('experimental.ExtensionTypeBatchEncoder')
class ExtensionTypeBatchEncoder(type_spec.TypeSpecBatchEncoder):
"""Class used to encode and decode extension type values for batching.
In order to be batched and unbatched by APIs such as `tf.data.Dataset`,
`tf.keras`, and `tf.map_fn`, extension type values must be encoded as a list
of `tf.Tensor`s, where stacking, unstacking, or concatenating these encoded
tensors and then decoding the result must be equivalent to stacking,
unstacking, or concatenating the original values. `ExtensionTypeBatchEncoder`s
are responsible for implementing this encoding.
The default `ExtensionTypeBatchEncoder` that is used by
`BatchableExtensionType` assumes that extension type values can be stacked,
unstacked, or concatenated by simply stacking, unstacking, or concatenating
every nested `Tensor`, `ExtensionType`, `CompositeTensor`, and `TensorShape`
field.
Extension types where this is not the case will need to override
`__batch_encoder__` with a custom encoder that overrides the `batch`,
`unbatch`, `encode`, and `decode` methods. E.g.:
>>> class CustomBatchEncoder(ExtensionTypeBatchEncoder):
... pass # Override batch(), unbatch(), encode(), and decode().
>>> class CustomType(BatchableExtensionType):
... x: tf.Tensor
... y: tf.Tensor
... shape: tf.TensorShape
... __batch_encoder__ = CustomBatchEncoder()
For example, `tf.RaggedTensor` and `tf.SparseTensor` both use custom batch
encodings which define ops to "box" and "unbox" individual values into
`tf.variant` tensors.
"""
def batch(self, spec, batch_size):
"""Returns the TypeSpec representing a batch of values described by `spec`.
The default definition returns a `TypeSpec` that is equal to `spec`, except
that an outer axis with size `batch_size` is added to every nested
`TypeSpec` and `TensorShape` field. Subclasses may override this default
definition, when necessary.
Args:
spec: The `TypeSpec` for an individual value.
batch_size: An `int` indicating the number of values that are batched
together, or `None` if the batch size is not known.
Returns:
A `TypeSpec` for a batch of values.
"""
def batch_field(f):
if isinstance(f, type_spec.BatchableTypeSpec):
return f.__batch_encoder__.batch(f, batch_size)
elif isinstance(f, tensor_shape.TensorShape):
return [batch_size] + f
else:
return f
fields = tuple(spec.__dict__.items())
batched_fields = nest.map_structure(batch_field, fields)
return _create_object_from_type_and_dict(type(spec), batched_fields)
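# For instance, batching a spec whose fields include
# TensorSpec([3], tf.int32) and TensorShape([3]) with batch_size=8 yields
# TensorSpec([8, 3], tf.int32) and TensorShape([8, 3]); with batch_size=None
# the added leading dimension is None (unknown).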
def unbatch(self, spec):
"""Returns the TypeSpec for a single unbatched element in `spec`.
The default definition returns a `TypeSpec` that is equal to `spec`, except
that the outermost axis is removed from every nested `TypeSpec` and
`TensorShape` field. Subclasses may override this default definition, when
necessary.
Args:
spec: The `TypeSpec` for a batch of values.
Returns:
A `TypeSpec` for an individual value.
"""
def unbatch_field(f):
if isinstance(f, type_spec.BatchableTypeSpec):
return f.__batch_encoder__.unbatch(f)
elif isinstance(f, tensor_shape.TensorShape):
return f[1:]
else:
return f
fields = tuple(spec.__dict__.items())
unbatched_fields = nest.map_structure(unbatch_field, fields)
return _create_object_from_type_and_dict(type(spec), unbatched_fields)
def encode(self, spec, value, minimum_rank=0):
"""Encodes `value` as a nest of batchable Tensors or CompositeTensors.
The default definition returns a flat tuple of all the `Tensor`s,
`CompositeTensor`s, and `ExtensionType`s from a depth-first traversal of
`value`'s fields. Subclasses may override this default definition, when
necessary.
Args:
spec: The TypeSpec of the value to encode.
value: A value compatible with `spec`.
minimum_rank: The minimum rank for the returned Tensors, CompositeTensors,
and ExtensionType values. This can be used to ensure that the encoded
values can be unbatched this number of times. If `minimum_rank>0`,
then `t.shape[:minimum_rank]` must be compatible for all values `t`
returned by `encode`.
Returns:
A nest (as defined by `tf.nest`) of `tf.Tensor`s, batchable
`tf.CompositeTensor`s, or `tf.ExtensionType`s. Stacking, unstacking, or
concatenating these encoded values and then decoding the result must be
equivalent to stacking, unstacking, or concatenating the original values.
"""
return spec._to_components(value) # pylint: disable=protected-access
def decode(self, spec, encoded_value):
"""Decodes `value` from a batchable tensor encoding.
See `encode` for a description of the default encoding. Subclasses may
override this default definition, when necessary.
Args:
spec: The TypeSpec for the result value. If encoded values with spec `s`
were batched, then `spec` should be `s.batch(batch_size)`; or if encoded
values with spec `s` were unbatched, then `spec` should be
`s.unbatch()`.
encoded_value: A nest of values returned by `encode`; or a nest of
values that was formed by stacking, unstacking, or concatenating the
corresponding elements of values returned by `encode`.
Returns:
A value compatible with `spec`.
"""
return spec._from_components(encoded_value) # pylint: disable=protected-access
def encoding_specs(self, spec):
"""Returns a list of `TensorSpec`(s) describing the encoding for `spec`.
See `encode` for a description of the default encoding. Subclasses may
override this default definition, when necessary.
Args:
spec: The TypeSpec whose encoding should be described.
Returns:
A nest (as defined by `tf.nest`) of `tf.TypeSpec`, describing the values
that are returned by `self.encode(spec, ...)`. All TypeSpecs in this
nest must be batchable.
"""
return spec._component_specs # pylint: disable=protected-access
class BatchableExtensionTypeSpec(ExtensionTypeSpec,
type_spec.BatchableTypeSpec):
"""Base class for TypeSpecs for BatchableExtensionTypes."""
__batch_encoder__ = ExtensionTypeBatchEncoder()
def _batch(self, batch_size):
return self.__batch_encoder__.batch(self, batch_size)
def _unbatch(self):
return self.__batch_encoder__.unbatch(self)
def _to_tensor_list(self, value):
return type_spec.batchable_to_tensor_list(self, value)
def _to_batched_tensor_list(self, value):
return type_spec.batchable_to_tensor_list(self, value, minimum_rank=1)
def _from_compatible_tensor_list(self, tensor_list):
return type_spec.batchable_from_tensor_list(self, tensor_list)
@property
def _flat_tensor_specs(self):
return type_spec.get_batchable_flat_tensor_specs(self)
@tf_export('experimental.BatchableExtensionType')
class BatchableExtensionType(ExtensionType):
"""An ExtensionType that can be batched and unbatched.
`BatchableExtensionType`s can be used with APIs that require batching or
unbatching, including `Keras`, `tf.data.Dataset`, and `tf.map_fn`. E.g.:
>>> class Vehicle(tf.experimental.BatchableExtensionType):
... top_speed: tf.Tensor
... mpg: tf.Tensor
>>> batch = Vehicle([120, 150, 80], [30, 40, 12])
>>> tf.map_fn(lambda vehicle: vehicle.top_speed * vehicle.mpg, batch,
... fn_output_signature=tf.int32).numpy()
array([3600, 6000, 960], dtype=int32)
An `ExtensionTypeBatchEncoder` is used by these APIs to encode `ExtensionType`
values. The default encoder assumes that values can be stacked, unstacked, or
concatenated by simply stacking, unstacking, or concatenating every nested
`Tensor`, `ExtensionType`, `CompositeTensor`, or `TensorShape` field.
Extension types where this is not the case will need to override
`__batch_encoder__` with a custom `ExtensionTypeBatchEncoder`. See
`tf.experimental.ExtensionTypeBatchEncoder` for more details.
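Batched values can also be sliced back into individual elements, e.g. by
`tf.data.Dataset` (a sketch continuing the example above):
>>> dataset = tf.data.Dataset.from_tensor_slices(batch)
>>> [v.top_speed.numpy() for v in dataset]
[120, 150, 80]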
"""
# Let the metaclass know that it should *not* transform this class (since
# this class is part of the ExtensionType framework, and not a user class).
_tf_extension_type_do_not_transform_this_class = True
# For Pickle __reduce__ protocol:
def _deserialize_for_reduce(value_type, serialization):
return value_type.Spec._deserialize(serialization) # pylint: disable=protected-access
def _replace_tensor_with_spec(value):
if isinstance(value, ops.Tensor):
# Note: we intentionally exclude `value.name` from the `TensorSpec`.
return tensor_spec.TensorSpec(value.shape, value.dtype)
if hasattr(value, '_type_spec'):
return value._type_spec # pylint: disable=protected-access
return value
def _change_nested_mappings_to(value, new_type):
"""Recursively replace mappings with `new_type`."""
if isinstance(value, (dict, immutable_dict.ImmutableDict)):
return new_type([(k, _change_nested_mappings_to(v, new_type))
for (k, v) in value.items()])
elif isinstance(value, tuple):
return tuple(_change_nested_mappings_to(elt, new_type) for elt in value)
else:
return value
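# Example:
#
#   _change_nested_mappings_to({'a': {'b': 1}}, immutable_dict.ImmutableDict)
#   # => ImmutableDict({'a': ImmutableDict({'b': 1})})
#
# Tuples are traversed recursively; other containers (e.g. lists) are
# returned unchanged.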
# ==============================================================================
# Helper methods for tf.ExtensionTypeMetaclass
# ==============================================================================
def _check_field_annotations(cls):
"""Validates the field annotations for tf.ExtensionType subclass `cls`."""
annotations = getattr(cls, '__annotations__', {})
# Check that no fields use reserved names.
for name, value in cls.__dict__.items():
if name == 'Spec':
if not isinstance(value, type):
raise ValueError(f'{cls.__qualname__}.Spec must be a nested class; '
f'got {value}.')
if (value.__bases__ != (type_spec.TypeSpec,) and value.__bases__ !=
(object,)):
raise ValueError(f'{cls.__qualname__}.Spec must be directly subclassed '
'from tf.TypeSpec.')
elif extension_type_field.ExtensionTypeField.is_reserved_name(name):
raise ValueError(f'The field annotations for {cls.__name__} are '
f"invalid. Field '{name}' is reserved.")
for name in annotations:
if extension_type_field.ExtensionTypeField.is_reserved_name(name):
raise ValueError(f'The field annotations for {cls.__name__} are '
f"invalid. Field '{name}' is reserved.")
# Check that all fields have type annotations.
for (key, value) in cls.__dict__.items():
if not (key in annotations or callable(value) or key.startswith('_abc_') or
key == '_tf_extension_type_fields' or
key.startswith('__') and key.endswith('__') or
isinstance(value, (property, classmethod, staticmethod))):
raise ValueError(f'The field annotations for {cls.__name__} are '
f'invalid. Field {key} is missing a type annotation.')
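# Example of an annotation error caught above (`Broken` is hypothetical):
#
#   class Broken(ExtensionType):
#     values = []   # ValueError: Field values is missing a type annotation.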
def _add_extension_type_constructor(cls):
"""Creates a constructor for a ExtensionType or ExtensionTypeSpec subclass."""
if '__init__' in cls.__dict__:
_wrap_user_constructor(cls)
else:
_build_extension_type_constructor(cls)
def _wrap_user_constructor(cls):
"""Wraps a user-defined constructor for tf.ExtensionType subclass `cls`."""
user_constructor = cls.__init__
def wrapped_init(self, *args, **kwargs):
self.__dict__[_IN_CONSTRUCTOR] = True
user_constructor(self, *args, **kwargs)
del self.__dict__[_IN_CONSTRUCTOR]
self._tf_extension_type_convert_fields() # pylint: disable=protected-access
self.__validate__()
cls.__init__ = tf_decorator.make_decorator(user_constructor, wrapped_init)
_NO_DEFAULT = extension_type_field.ExtensionTypeField.NO_DEFAULT
def _build_extension_type_constructor(cls):
"""Builds a constructor for tf.ExtensionType subclass `cls`."""
fields = cls._tf_extension_type_fields() # pylint: disable=protected-access
# Mark any no-default fields that follow default fields as keyword_only.
got_default = False
keyword_only_start = len(fields)
for i in range(len(fields)):
if got_default:
if fields[i].default is _NO_DEFAULT:
keyword_only_start = i
break
elif fields[i].default is not _NO_DEFAULT:
got_default = True
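# For example, fields (a: int, b: int = 0, c: int) make `c` keyword-only
# (it follows a field with a default), so the generated signature is
# __init__(self, a, b=0, *, c).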
params = []
for i, field in enumerate(fields):
if i < keyword_only_start:
kind = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
else:
kind = tf_inspect.Parameter.KEYWORD_ONLY
if field.default is _NO_DEFAULT:
default = tf_inspect.Parameter.empty
else:
default = field.default
params.append(
tf_inspect.Parameter(
field.name, kind, default=default, annotation=field.value_type))
signature = tf_inspect.Signature(params, return_annotation=cls.__name__)
def __init__(self, *args, **kwargs): # pylint: disable=invalid-name
bound_args = signature.bind(*args, **kwargs)
bound_args.apply_defaults()
self.__dict__.update(bound_args.arguments)
self._tf_extension_type_convert_fields() # pylint: disable=protected-access
self.__validate__()
# __signature__ is supported by some inspection/documentation tools
# (but note: typing.get_type_hints does not respect __signature__).
__init__.__signature__ = tf_inspect.Signature(
[
tf_inspect.Parameter('self',
tf_inspect.Parameter.POSITIONAL_OR_KEYWORD)
] + params,
return_annotation=cls)
cls.__init__ = __init__
def _build_spec_constructor(cls):
"""Builds a constructor for ExtensionTypeSpec subclass `cls`."""
params = []
kind = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
for field in cls._tf_extension_type_fields(): # pylint: disable=protected-access
params.append(tf_inspect.Parameter(field.name, kind))
signature = tf_inspect.Signature(params, return_annotation=cls.__name__)
def __init__(self, *args, **kwargs): # pylint: disable=invalid-name
bound_args = signature.bind(*args, **kwargs)
bound_args.apply_defaults()
self.__dict__.update(bound_args.arguments)
self._tf_extension_type_convert_fields() # pylint: disable=protected-access
self.__validate__()
# __signature__ is supported by some inspection/documentation tools.
__init__.__signature__ = tf_inspect.Signature(
[
tf_inspect.Parameter('self',
tf_inspect.Parameter.POSITIONAL_OR_KEYWORD)
] + params,
return_annotation=cls)
cls.__init__ = __init__
def _add_type_spec(cls):
"""Creates a nested TypeSpec class for tf.ExtensionType subclass `cls`."""
spec_name = cls.__name__ + '.Spec'
spec_qualname = cls.__qualname__ + '.Spec'
# Set __module__ explicitly, as a dynamically created class has module='abc'
# by default.
spec_dict = {'value_type': cls, '__module__': cls.__module__}
# Copy user-supplied customizations into the TypeSpec.
user_spec = cls.__dict__.get('Spec', None)
if user_spec is not None:
for (name, value) in user_spec.__dict__.items():
if extension_type_field.ExtensionTypeField.is_reserved_name(name):
raise ValueError(f'TypeSpec {spec_qualname} uses reserved '
f"name '{name}'.")
if cls._tf_extension_type_has_field(name): # pylint: disable=protected-access
raise ValueError(f"TypeSpec {spec_qualname} defines a variable '{name}'"
f' which shadows a field in {cls.__qualname__}')
if name in ('__module__', '__dict__', '__weakref__'):
continue
spec_dict[name] = value
if issubclass(cls, BatchableExtensionType):
type_spec_base = BatchableExtensionTypeSpec
if hasattr(cls,
'__batch_encoder__') and '__batch_encoder__' not in spec_dict:
spec_dict['__batch_encoder__'] = cls.__batch_encoder__
else:
type_spec_base = ExtensionTypeSpec
if hasattr(cls, '__batch_encoder__') or '__batch_encoder__' in spec_dict:
raise ValueError('__batch_encoder__ should only be defined for '
'BatchableExtensionType classes.')
# Build the TypeSpec and store it as a nested class inside `cls`.
spec = type(spec_name, (type_spec_base,), spec_dict)
spec.__qualname__ = spec_qualname
setattr(cls, 'Spec', spec)
# Build a constructor for the TypeSpec class.
if '__init__' in spec.__dict__:
_wrap_user_constructor(spec)
else:
_build_spec_constructor(spec)
cls.__abstractmethods__ -= {'_type_spec'}
# If the user included an explicit `__name__` attribute, then use that to
# register the TypeSpec (so it can be used in SavedModel signatures).
if '__name__' in cls.__dict__:
type_spec_registry.register(cls.__dict__['__name__'] + '.Spec')(spec)
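# Sketch of the registration path above (`MaskedTensor` is a hypothetical
# user class):
#
#   class MaskedTensor(ExtensionType):
#     __name__ = 'my_project.MaskedTensor'
#     values: tf.Tensor
#     mask: tf.Tensor
#
# registers MaskedTensor.Spec under 'my_project.MaskedTensor.Spec', so it can
# be resolved when a SavedModel signature that uses it is loaded.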
# ==============================================================================
# Anonymous ExtensionType
# ==============================================================================
class AnonymousExtensionType(ExtensionType):
"""Fallback used to decode `tf.ExtensionType` when the original type is unavailable.
When a SavedModel is serialized, the signatures of any functions in the
SavedModel can include `tf.ExtensionType` subclasses. These subclasses are
usually registered, so they can be restored when the SavedModel is loaded.
However, if a SavedModel is loaded without first registering the ExtensionType
types in its signature, then the SavedModel will fall back to using the
`AnonymousExtensionType` type instead.
If necessary, `AnonymousExtensionType` objects can be converted to a concrete
`tf.ExtensionType` subclass (and vice versa) using `reinterpret`.
"""
# Let the metaclass know that it should *not* transform this class (since
# this class is part of the ExtensionType framework, and not a user class).
_tf_extension_type_do_not_transform_this_class = True
def __init__(self, **fields):
for name in fields:
if (extension_type_field.ExtensionTypeField.is_reserved_name(name) or
(name.startswith('__') and name.endswith('__'))):
raise ValueError(
f'Reserved field name {name} was encountered '
f'when trying to instantiate an AnonymousExtensionType.')
fields = [(k, _convert_anonymous_fields(v)) for (k, v) in fields.items()]
self.__dict__.update(fields)
self._tf_extension_type_convert_fields()
super().__init__()
@classmethod
def _tf_extension_type_fields(cls):
return [
extension_type_field.ExtensionTypeField(name, None)
for name in cls.__dict__
if not extension_type_field.ExtensionTypeField.is_reserved_name(name)
]
def __setattr__(self, name, value):
raise AttributeError(f'Cannot set attribute `{name}`. '
f'AnonymousExtensionType instances are immutable.')
def __delattr__(self, name):
raise AttributeError(f'Cannot delete attribute `{name}`. '
f'AnonymousExtensionType instances are immutable.')
def _tf_extension_type_convert_fields(self):
fields = [(k, _convert_anonymous_fields(v))
for (k, v) in self.__dict__.items()
if not extension_type_field.ExtensionTypeField.is_reserved_name(k)
]
self.__dict__.update(fields)
def __repr__(self):
fields = [
f'{k}={v!r}' for (k, v) in self.__dict__.items()
if not extension_type_field.ExtensionTypeField.is_reserved_name(k)
]
return f'AnonymousExtensionType({", ".join(fields)})'
_tf_extension_type_cached_type_spec = None
@property
def _type_spec(self): # CompositeTensor API.
# Note: the TypeSpec contains all static (non-tensor) data from `self`.
if self._tf_extension_type_cached_type_spec is None:
spec = AnonymousExtensionTypeSpec.from_value(self)
self.__dict__['_tf_extension_type_cached_type_spec'] = spec
return self._tf_extension_type_cached_type_spec
@type_spec_registry.register('tf.AnonymousExtensionType.Spec')
class AnonymousExtensionTypeSpec(ExtensionTypeSpec):
"""TypeSpec for AnonymousExtensionType."""
def __init__(self, **fields):
for name in fields:
if (extension_type_field.ExtensionTypeField.is_reserved_name(name) or
(name.startswith('__') and name.endswith('__'))):
raise ValueError(
f'Reserved field name {name} was encountered '
f'when trying to instantiate an AnonymousExtensionTypeSpec.')
fields = [(k, _convert_anonymous_fields(v, for_spec=True))
for (k, v) in fields.items()]
self.__dict__.update(fields)
super().__init__()
value_type = AnonymousExtensionType # TypeSpec API.
def _serialize(self): # TypeSpec API.
return tuple(
(name, _change_nested_mappings_to(value, dict))
for (name, value) in self.__dict__.items()
if not extension_type_field.ExtensionTypeField.is_reserved_name(name))
def __setattr__(self, name, value):
raise AttributeError(f'Cannot set attribute `{name}`. '
f'AnonymousExtensionTypeSpec instances are immutable.')
def __delattr__(self, name):
raise AttributeError(f'Cannot delete attribute `{name}`. '
f'AnonymousExtensionTypeSpec instances are immutable.')
def _convert_anonymous_fields(value, for_spec=False):
"""Type-checks and converts `value` for inclusion in an AnonymousExtensionType."""
if isinstance(value, (int, float, bool, str, bytes, type(None), dtypes.DType,
tensor_shape.TensorShape)):
return value
if isinstance(value, tuple):
return tuple(_convert_anonymous_fields(v, for_spec) for v in value)
if isinstance(value, typing.Mapping):
return immutable_dict.ImmutableDict([
(_convert_anonymous_fields(k, for_spec),
_convert_anonymous_fields(v, for_spec)) for (k, v) in value.items()
])
if (isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)) and
not for_spec):
return value
if isinstance(value, type_spec.TypeSpec) and for_spec:
return value
raise ValueError(f'Cannot convert anonymous fields from '
f'an unsupported `value` argument: {value!r}.')
# ==============================================================================
# reinterpret
# ==============================================================================
def reinterpret(value, new_type):
"""Converts a given `ExtensionType` to a new type with compatible fields.
In particular, this can be used to convert a concrete subclass of
`ExtensionType` to an `AnonymousExtensionType`, or vice versa. When
converting to a non-anonymous ExtensionType, field values are type-checked to
ensure they are consistent with `new_type`'s type annotations, and validated
with `new_type.__validate__`.
Args:
value: An instance of a subclass of `tf.ExtensionType`
new_type: A subclass of `tf.ExtensionType`
Returns:
An instance of `new_type`, whose fields are copied from `value`.
"""
if not isinstance(value, ExtensionType):
raise ValueError(
f'reinterpret expects `value` to be a tf.ExtensionType instance; '
f'got {value!r}')
if not (isinstance(new_type, type) and issubclass(new_type, ExtensionType)):
raise ValueError(
f'reinterpret expects `new_type` to be a subclass of tf.ExtensionType; '
f'got {new_type!r}')
fields = [
item for item in value.__dict__.items()
if not extension_type_field.ExtensionTypeField.is_reserved_name(item[0])
]
new_value = _create_object_from_type_and_dict(new_type, fields)
new_value._tf_extension_type_convert_fields() # pylint: disable=protected-access
new_value.__validate__()
return new_value
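# Example round trip (`MaskedTensor` is a hypothetical user class):
#
#   mt = MaskedTensor(values=[1, 2, 3], mask=[True, True, False])
#   anon = reinterpret(mt, AnonymousExtensionType)
#   mt2 = reinterpret(anon, MaskedTensor)  # fields re-checked against
#                                          # MaskedTensor's annotations and
#                                          # validated with __validate__.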