2023-06-19 00:49:18 +02:00

669 lines
24 KiB

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TraceType implementations for common Python types."""
import collections
from typing import Any, Hashable, Optional, Sequence, Type
from typing import Dict as PythonDict
from typing import Tuple as PythonTuple
import weakref
from tensorflow.core.function.trace_type import default_types_pb2
from tensorflow.core.function.trace_type import serialization
from tensorflow.core.function.trace_type import util
from tensorflow.python.types import trace
class Literal(trace.TraceType, serialization.Serializable):
  """Represents a Literal type like bool, int or string.

  A Literal is equal to, a subtype of, and shares a common supertype with only
  other Literals whose wrapped values compare equal.

  Attributes:
    value: The wrapped Python value. It must be hashable because its hash is
      computed once in the constructor.
  """

  def __init__(self, value: Any):
    """Wraps `value` and caches its hash."""
    self.value = value
    self._value_hash = hash(value)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """Returns True only if `other` is equal to this Literal."""
    return self == other

  def most_specific_common_supertype(
      self, types: Sequence[trace.TraceType]) -> Optional["Literal"]:
    """Returns `self` if every type in `types` equals it, otherwise None."""
    return self if all(self == other for other in types) else None

  @classmethod
  def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedLiteral]:
    """Returns the proto message class used to serialize Literals."""
    return default_types_pb2.SerializedLiteral

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedLiteral) -> "Literal":
    """Builds a Literal from whichever value field is set on `proto`.

    Args:
      proto: A SerializedLiteral message.

    Returns:
      A Literal wrapping the bool, int, float, str or None stored in `proto`.

    Raises:
      ValueError: If none of the recognised value fields is set.
    """
    if proto.HasField("bool_value"):
      return Literal(proto.bool_value)

    if proto.HasField("int_value"):
      return Literal(proto.int_value)

    if proto.HasField("float_value"):
      return Literal(proto.float_value)

    if proto.HasField("str_value"):
      return Literal(proto.str_value)

    if proto.HasField("none_value"):
      return Literal(None)

    raise ValueError("Malformed Literal proto can not be deserialized")

  def experimental_as_proto(self) -> default_types_pb2.SerializedLiteral:
    """Serializes the wrapped value into a SerializedLiteral message.

    Returns:
      A SerializedLiteral with the field matching the value's type set.

    Raises:
      ValueError: If the value is not a bool, int, float, str or None.
    """
    # bool is a subclass of int, so it has to be tested first.
    if isinstance(self.value, bool):
      return default_types_pb2.SerializedLiteral(bool_value=self.value)

    if isinstance(self.value, int):
      return default_types_pb2.SerializedLiteral(int_value=self.value)

    if isinstance(self.value, float):
      return default_types_pb2.SerializedLiteral(float_value=self.value)

    if isinstance(self.value, str):
      return default_types_pb2.SerializedLiteral(str_value=self.value)

    if self.value is None:
      # NOTE(review): this statement was truncated in the reviewed copy; the
      # `NoneValue` message name is reconstructed -- confirm against
      # default_types.proto.
      return default_types_pb2.SerializedLiteral(
          none_value=default_types_pb2.SerializedLiteral.NoneValue())

    raise ValueError("Can not serialize Literal of type " +
                     type(self.value).__name__)

  def placeholder_value(self, placeholder_context=None) -> Any:
    """Returns the wrapped value (a `range` is returned as a list)."""
    # TODO(b/263505796): Remove this check when a range's placeholder output
    # is expected to be a range and not a list.
    if isinstance(self.value, range):
      return list(self.value)

    return self.value

  def _to_tensors(self, value: Any):
    """Returns an empty list; `value` is not inspected."""
    return []

  def __eq__(self, other) -> bool:
    # Defer to the other operand when it is not a TraceType at all.
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    return isinstance(other, Literal) and self.value == other.value

  def __hash__(self) -> int:
    return self._value_hash

  def __repr__(self):
    return f"{self.__class__.__name__}(value={self.value!r})"
class Weakref(trace.TraceType):
  """Represents weakref of an arbitrary Python object.

  When a function argument is a custom class, instead of making a copy of it
  just for the sake of function cache, a weakref is instead kept to save memory.
  """

  def __init__(self, ref: weakref.ReferenceType):
    """Stores `ref` and caches its hash.

    Args:
      ref: A weak reference to the object being represented.
    """
    self._ref = ref
    # Cached so `__hash__` returns the same number however the referent's
    # liveness changes later.
    self._ref_hash = hash(ref)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """Returns True only if `other` is equal to this Weakref."""
    return self == other

  def most_specific_common_supertype(
      self, types: Sequence[trace.TraceType]) -> Optional["Weakref"]:
    """Returns `self` if every type in `types` equals it, otherwise None."""
    return self if all(self == other for other in types) else None

  def placeholder_value(self, placeholder_context=None) -> Any:
    """Returns the referent, or None if it is no longer alive."""
    return self._ref()

  def _to_tensors(self, value: Any) -> Any:
    """Returns an empty list; `value` is not inspected."""
    return []

  def __eq__(self, other):
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, Weakref):
      return False

    # Dereference each weakref exactly once and hold the results, so neither
    # referent can be collected between the liveness check and the comparison.
    self_referent = self._ref()
    other_referent = other._ref()

    # A Weakref whose referent is gone never compares equal, not even to
    # itself.
    if self_referent is None or other_referent is None:
      return False

    if self_referent is other_referent:
      return True

    return self._ref == other._ref

  def __hash__(self):
    return self._ref_hash

  def __repr__(self):
    return f"{self.__class__.__name__}(ref={self._ref!r})"
class Tuple(trace.TraceType, serialization.Serializable):
  """Represents a tuple of TraceType objects.

  Attributes:
    components: The TraceTypes of the tuple's elements, in order.
  """

  def __init__(self, *components: trace.TraceType):
    """Stores the given component types as a Python tuple."""
    self.components = components

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """Returns True if `other` is a same-length Tuple of component supertypes."""
    if (not isinstance(other, Tuple) or
        len(self.components) != len(other.components)):
      return False

    return all(
        self_component.is_subtype_of(other_component)
        for self_component, other_component in zip(self.components,
                                                   other.components))

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["Tuple"]:
    """See base class.

    Args:
      others: The types to find a common supertype with.

    Returns:
      A Tuple of the per-position common supertypes, or None if any element of
      `others` is not a Tuple of the same length or any position has no common
      supertype.
    """
    if not all(
        isinstance(other, Tuple) and
        len(self.components) == len(other.components) for other in others):
      return None

    supertyped_components = []
    for i, component in enumerate(self.components):
      supertyped_component = component.most_specific_common_supertype(
          [other.components[i] for other in others])
      if supertyped_component is None:
        return None
      # Without this append the result would always be an empty Tuple.
      supertyped_components.append(supertyped_component)

    return Tuple(*supertyped_components)

  @classmethod
  def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedTuple]:
    """Returns the proto message class used to serialize Tuples."""
    return default_types_pb2.SerializedTuple

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedTuple) -> "Tuple":
    """Builds a Tuple by deserializing every entry of `proto.components`."""
    return Tuple(*[serialization.deserialize(c) for c in proto.components])

  def experimental_as_proto(self) -> default_types_pb2.SerializedTuple:
    """Serializes every component into a SerializedTuple message."""
    return default_types_pb2.SerializedTuple(
        components=[serialization.serialize(c) for c in self.components])

  def placeholder_value(self, placeholder_context) -> Any:
    """Returns a tuple holding each component's placeholder value."""
    components = [
        component.placeholder_value(placeholder_context)
        for component in self.components
    ]
    return tuple(components)

  def _to_tensors(self, value) -> Any:
    """Concatenates the `_to_tensors` results of each element of `value`."""
    assert isinstance(value, tuple)
    flattened_values = []
    for comp_value, comp_type in zip(value, self.components):
      flattened_values.extend(comp_type._to_tensors(comp_value))  # pylint: disable=protected-access
    return flattened_values

  def _cast(self, value: Any, casting_context) -> Any:
    """Casts each element of `value` with the matching component type."""
    assert isinstance(value, tuple), f"Cannot cast {value!r} to tuple type."
    return tuple(component._cast(  # pylint: disable=protected-access
        v, casting_context) for v, component in zip(value, self.components))

  def __eq__(self, other: Any) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, Tuple):
      return False

    return self.components == other.components

  def __hash__(self) -> int:
    return hash(self.components)

  def __repr__(self):
    return f"Tuple(components={self.components!r})"
class List(trace.TraceType, serialization.Serializable):
  """Represents a list of TraceType objects.

  Attributes:
    components_tuple: A `Tuple` TraceType holding the element types; the
      subtype, supertype, placeholder and tensor-flattening logic is delegated
      to it.
  """

  def __init__(self, *components: trace.TraceType):
    """Wraps the given component types in a `Tuple`."""
    self.components_tuple = Tuple(*components)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """Returns True if `other` is a List whose components are supertypes."""
    if not isinstance(other, List):
      return False

    return self.components_tuple.is_subtype_of(other.components_tuple)

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["List"]:
    """See base class.

    Args:
      others: The types to find a common supertype with.

    Returns:
      A List of the per-position common supertypes, or None if any element of
      `others` is not a List or the underlying Tuples have no common supertype.
    """
    if not all(isinstance(other, List) for other in others):
      return None

    supertyped_components_tuple = (
        self.components_tuple.most_specific_common_supertype(
            [other.components_tuple for other in others]))

    if supertyped_components_tuple is None:
      return None

    return List(*supertyped_components_tuple.components)

  @classmethod
  def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedList]:
    """Returns the proto message class used to serialize Lists."""
    return default_types_pb2.SerializedList

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedList) -> "List":
    """Builds a List from a SerializedList message."""
    # NOTE(review): statement was truncated in the reviewed copy; the
    # `components_tuple` field name is reconstructed -- confirm against
    # default_types.proto.
    return List(
        *Tuple.experimental_from_proto(proto.components_tuple).components)

  def experimental_as_proto(self) -> default_types_pb2.SerializedList:
    """Serializes this List into a SerializedList message."""
    # NOTE(review): statement was truncated in the reviewed copy; the
    # `components_tuple` field name is reconstructed -- confirm against
    # default_types.proto.
    return default_types_pb2.SerializedList(
        components_tuple=self.components_tuple.experimental_as_proto())

  def placeholder_value(self, placeholder_context) -> Any:
    """Returns a list holding each component's placeholder value."""
    return list(self.components_tuple.placeholder_value(placeholder_context))

  def _to_tensors(self, value):
    """Flattens the list `value` via the underlying Tuple type."""
    assert isinstance(value, list)
    return self.components_tuple._to_tensors(tuple(value))  # pylint: disable=protected-access

  def _cast(self, value: Any, casting_context) -> Any:
    """Casts each element of `value` with the matching component type."""
    assert isinstance(value, list), f"Cannot cast {value!r} to list type."
    return [component._cast(v, casting_context) for v, component in zip(  # pylint: disable=protected-access
        value, self.components_tuple.components)]

  def __eq__(self, other: Any) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, List):
      return False

    return self.components_tuple == other.components_tuple

  def __hash__(self) -> int:
    return hash(self.components_tuple)

  def __repr__(self):
    return f"List(components={self.components_tuple.components!r})"
class NamedTuple(trace.TraceType, serialization.Serializable):
  """Represents a NamedTuple of TraceType objects.

  Attributes:
    type_name: Name of the namedtuple class being represented.
    attribute_names: Field names of the namedtuple, in order.
    attributes: A `Tuple` TraceType holding the field types, in the same order
      as `attribute_names`.
  """

  def __init__(self,
               type_name: str,
               attribute_names: PythonTuple[str],
               attributes: PythonTuple[trace.TraceType],
               placeholder_type: Optional[Type[Any]] = None):
    """Initializes the NamedTuple type.

    Args:
      type_name: Name of the namedtuple class.
      attribute_names: Field names, in order.
      attributes: TraceTypes of the fields, in the same order.
      placeholder_type: The concrete class used to build placeholder and cast
        values. It is not part of equality or hashing and is not serialized.
    """
    self.type_name = type_name
    self.attribute_names = attribute_names
    self.attributes = Tuple(*attributes)
    self._placeholder_type = placeholder_type

  @classmethod
  def from_type_and_attributes(
      cls, named_tuple_type: Any,
      attributes: PythonTuple[trace.TraceType]) -> "NamedTuple":
    """Builds a NamedTuple type from a namedtuple class and its field types."""
    return NamedTuple(named_tuple_type.__name__, named_tuple_type._fields,
                      attributes, named_tuple_type)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """Returns True if names match and attributes are subtypes of `other`'s."""
    if not isinstance(other, NamedTuple):
      return False

    return (self.type_name == other.type_name and
            self.attribute_names == other.attribute_names and
            self.attributes.is_subtype_of(other.attributes))

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["NamedTuple"]:
    """See base class.

    Args:
      others: The types to find a common supertype with.

    Returns:
      A NamedTuple with the same names and the per-field common supertypes, or
      None if any element of `others` differs in kind, type name or field
      names, or the attributes have no common supertype.
    """
    if not all(
        isinstance(other, NamedTuple) and self.type_name == other.type_name and
        self.attribute_names == other.attribute_names for other in others):
      return None

    supertyped_attributes = self.attributes.most_specific_common_supertype(
        [other.attributes for other in others])

    if supertyped_attributes is None:
      return None

    return NamedTuple(self.type_name, self.attribute_names,
                      supertyped_attributes.components, self._placeholder_type)

  @classmethod
  def experimental_type_proto(
      cls) -> Type[default_types_pb2.SerializedNamedTuple]:
    """Returns the proto message class used to serialize NamedTuples."""
    return default_types_pb2.SerializedNamedTuple

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedNamedTuple) -> "NamedTuple":
    """Builds a NamedTuple (without a placeholder type) from `proto`."""
    # NOTE(review): the last argument was truncated in the reviewed copy; the
    # `attributes` field name is reconstructed -- confirm against
    # default_types.proto.
    return NamedTuple(
        proto.type_name, tuple(proto.attribute_names),
        Tuple.experimental_from_proto(proto.attributes).components)

  def experimental_as_proto(self) -> default_types_pb2.SerializedNamedTuple:
    """Serializes the names and attribute types; placeholder type is dropped."""
    # NOTE(review): the keyword arguments were truncated in the reviewed copy
    # and are reconstructed to mirror `experimental_from_proto` -- confirm
    # against default_types.proto.
    return default_types_pb2.SerializedNamedTuple(
        type_name=self.type_name,
        attribute_names=list(self.attribute_names),
        attributes=self.attributes.experimental_as_proto())

  def placeholder_value(self, placeholder_context) -> Any:
    """Returns an instance of the placeholder type built from placeholders.

    Raises:
      ValueError: If no placeholder type was supplied at construction.
    """
    if self._placeholder_type is None:
      # We don't need to trace after serialization so it is not needed but we
      # can generate a placeholder type using the description if ever needed.
      raise ValueError("Can not generate placeholder value for NamedTuple with"
                       " unspecified placeholder_type. Note: placeholder_type "
                       "is lost during serialization.")
    attribute_placeholders = [
        attribute.placeholder_value(placeholder_context)
        for attribute in self.attributes.components
    ]
    return self._placeholder_type(*attribute_placeholders)

  def _to_tensors(self, value: Any):
    """Concatenates the `_to_tensors` results of each field of `value`."""
    assert util.is_namedtuple(value)
    flattened_values = []
    for attribute_name, attribute_type in zip(
        self.attribute_names, self.attributes.components):
      attribute_value = getattr(value, attribute_name)
      flattened_values.extend(attribute_type._to_tensors(attribute_value))  # pylint: disable=protected-access
    return flattened_values

  def _cast(self, value: Any, casting_context) -> Any:
    """Casts each field of `value` and rebuilds the placeholder type."""
    # Value must have same attributes with the TraceType
    assert (
        isinstance(value, self._placeholder_type)  # pylint: disable=unidiomatic-typecheck
    ), f"Cannot cast {value!r} to type {self._placeholder_type!r}."
    cast_value = {}
    value_dict = value._asdict()
    assert set(value_dict.keys()) == set(
        self.attribute_names
    ), f"{value!r} has different attributes with the TraceType {self!r}"

    for k, v in zip(self.attribute_names, self.attributes.components):
      cast_value[k] = v._cast(getattr(value, k), casting_context)  # pylint: disable=protected-access
    return self._placeholder_type(**cast_value)

  def __hash__(self) -> int:
    return hash((self.type_name, self.attribute_names, self.attributes))

  def __eq__(self, other: Any) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, NamedTuple):
      return False

    return (self.type_name == other.type_name and
            self.attribute_names == other.attribute_names and
            self.attributes == other.attributes)

  def __repr__(self):
    return (f"NamedTuple(type_name={self.type_name}, "
            f"attribute_names={self.attribute_names}, "
            f"attributes={self.attributes.components})")
class Attrs(trace.TraceType):
  """Represents a class annotated by attr.s.

  Attributes:
    named_attributes: A `NamedTuple` TraceType holding the class name, the
      attribute names and the attribute types; comparison, hashing and
      supertype logic is delegated to it.
  """

  def __init__(self,
               type_name: str,
               attribute_names: PythonTuple[str],
               attributes: PythonTuple[trace.TraceType],
               placeholder_type: Optional[Type[Any]] = None):
    """Initializes the Attrs type.

    Args:
      type_name: Name of the attrs class.
      attribute_names: Attribute names, in order.
      attributes: TraceTypes of the attributes, in the same order.
      placeholder_type: The concrete class used to build placeholder and cast
        values. It is not part of equality or hashing.
    """
    self.named_attributes = NamedTuple(type_name, attribute_names, attributes)
    self._placeholder_type = placeholder_type

  @classmethod
  def from_type_and_attributes(
      cls, attrs_type: Any,
      attributes: PythonTuple[trace.TraceType]) -> "Attrs":
    """Builds an Attrs type from an attrs class and its attribute types."""
    # NOTE(review): the generator element was lost in the reviewed copy;
    # `attr.name` is reconstructed from the attrs `__attrs_attrs__` API.
    return Attrs(attrs_type.__name__,
                 tuple(attr.name for attr in attrs_type.__attrs_attrs__),
                 attributes, attrs_type)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """Returns True if `other` is an Attrs whose named attributes are supertypes."""
    if not isinstance(other, Attrs):
      return False

    return self.named_attributes.is_subtype_of(other.named_attributes)

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["Attrs"]:
    """See base class.

    Args:
      others: The types to find a common supertype with.

    Returns:
      An Attrs built from the common supertype of the named attributes, or None
      if any element of `others` is not an Attrs or no common supertype exists.
    """
    if not all(isinstance(other, Attrs) for other in others):
      return None

    supertyped_attributes = (
        self.named_attributes.most_specific_common_supertype(
            [other.named_attributes for other in others]))

    if supertyped_attributes is None:
      return None

    # NOTE(review): the trailing arguments were truncated in the reviewed copy
    # and are reconstructed from the constructor signature.
    return Attrs(self.named_attributes.type_name,
                 self.named_attributes.attribute_names,
                 supertyped_attributes.attributes.components,
                 self._placeholder_type)

  @classmethod
  def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedAttrs]:
    """Returns the proto message class used to serialize Attrs."""
    return default_types_pb2.SerializedAttrs

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedAttrs) -> "Attrs":
    """Builds an Attrs (without a placeholder type) from `proto`."""
    # NOTE(review): the arguments were truncated in the reviewed copy; the
    # `named_attributes` field and its sub-fields are reconstructed -- confirm
    # against default_types.proto.
    return Attrs(
        proto.named_attributes.type_name,
        tuple(proto.named_attributes.attribute_names),
        Tuple.experimental_from_proto(
            proto.named_attributes.attributes).components)

  def experimental_as_proto(self) -> default_types_pb2.SerializedAttrs:
    """Serializes the named attributes; the placeholder type is dropped."""
    # NOTE(review): the keyword argument was truncated in the reviewed copy;
    # `named_attributes` is reconstructed -- confirm against
    # default_types.proto.
    return default_types_pb2.SerializedAttrs(
        named_attributes=self.named_attributes.experimental_as_proto())

  def placeholder_value(self, placeholder_context) -> Any:
    """Returns an instance of the placeholder type built from placeholders.

    Raises:
      ValueError: If no placeholder type was supplied at construction.
    """
    if self._placeholder_type is None:
      # We don't need to trace after serialization so it is not needed but we
      # can generate a placeholder type using the description if ever needed.
      raise ValueError("Can not generate placeholder value for Attrs with"
                       " unspecified placeholder_type. Note: placeholder_type "
                       "is lost during serialization.")
    attribute_placeholders = [
        attribute.placeholder_value(placeholder_context)
        for attribute in self.named_attributes.attributes.components
    ]
    return self._placeholder_type(*attribute_placeholders)

  def _to_tensors(self, value: Any):
    """Concatenates the `_to_tensors` results of each attribute of `value`."""
    assert util.is_attrs(value)
    flattened_values = []
    for attribute_name, attribute_type in zip(
        self.named_attributes.attribute_names,
        self.named_attributes.attributes.components):
      attribute_value = getattr(value, attribute_name)
      flattened_values.extend(attribute_type._to_tensors(attribute_value))  # pylint: disable=protected-access
    return flattened_values

  def _cast(self, value: Any, casting_context) -> Any:
    """Casts each attribute of `value` and rebuilds the placeholder type."""
    assert util.is_attrs(value)

    value_cast = {}
    for attribute_name, attribute_type in zip(
        self.named_attributes.attribute_names,
        self.named_attributes.attributes.components):
      attribute_value = getattr(value, attribute_name)
      value_cast[attribute_name] = attribute_type._cast(  # pylint: disable=protected-access
          attribute_value, casting_context)

    return self._placeholder_type(**value_cast)

  def __hash__(self) -> int:
    return hash(self.named_attributes)

  def __eq__(self, other: Any) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, Attrs):
      return False

    return self.named_attributes == other.named_attributes

  def __repr__(self):
    return (f"Attrs(type_name={self.named_attributes.type_name}, "
            f"attribute_names={self.named_attributes.attribute_names}, "
            f"attributes={self.named_attributes.attributes.components})")
class Dict(trace.TraceType, serialization.Serializable):
  """Represents a dictionary of TraceType objects.

  Attributes:
    mapping: A mapping from keys to corresponding TraceTypes of the dict values.
  """

  def __init__(self,
               mapping: PythonDict[Hashable, trace.TraceType],
               placeholder_type: Optional[Type[Any]] = None):
    """Initializes the Dict type.

    Args:
      mapping: Maps each key to the TraceType of its value.
      placeholder_type: The concrete mapping class used to build placeholder
        and cast values. It is not part of equality or hashing and is not
        serialized.
    """
    self.mapping = mapping
    self._placeholder_type = placeholder_type

  def _has_same_structure(self, other):
    """Returns True if `other` is a Dict with exactly the same key set."""
    if not isinstance(other, Dict):
      return False

    return self.mapping.keys() == other.mapping.keys()

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """See base class."""
    if not self._has_same_structure(other):
      return False

    # We need all keys to be present because there can be logic relying on
    # their existence or lack thereof and hence can not guarantee subtype based
    # on a subset or superset of keys.
    # Only the tracing code can explicitly check for key dependencies and inform
    # that decision.
    return all(self.mapping[key].is_subtype_of(other.mapping[key])
               for key in self.mapping)

  def most_specific_common_supertype(
      self, types: Sequence[trace.TraceType]) -> Optional["Dict"]:
    """See base class.

    Args:
      types: The types to find a common supertype with.

    Returns:
      A Dict mapping each key to the common supertype of its values, or None if
      any element of `types` is not a Dict with the same keys or any key has no
      common supertype.
    """
    if not all(self._has_same_structure(other) for other in types):
      return None

    new_mapping = {}
    for key in self.mapping.keys():
      common = self.mapping[key].most_specific_common_supertype(
          [other.mapping[key] for other in types])
      if common is None:
        return None
      new_mapping[key] = common

    return Dict(new_mapping, self._placeholder_type)

  @classmethod
  def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedDict]:
    """Returns the proto message class used to serialize Dicts."""
    return default_types_pb2.SerializedDict

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedDict) -> "Dict":
    """Builds a Dict (without a placeholder type) from `proto`.

    Keys are decoded as Literals and values via `serialization.deserialize`.
    """
    return Dict({
        Literal.experimental_from_proto(k).value: serialization.deserialize(v)
        for k, v in zip(proto.keys, proto.values)
    })

  def experimental_as_proto(self) -> default_types_pb2.SerializedDict:
    """Serializes keys as Literals and values as TraceTypes."""
    return default_types_pb2.SerializedDict(
        keys=[Literal(k).experimental_as_proto() for k in self.mapping.keys()],
        values=[serialization.serialize(v) for v in self.mapping.values()])

  def placeholder_value(self, placeholder_context) -> Any:
    """Returns a mapping from each key to its value type's placeholder.

    Raises:
      ValueError: If no placeholder type was supplied at construction.
    """
    if self._placeholder_type is None:
      raise ValueError("Can not generate placeholder value for Dict with"
                       " unspecified placeholder_type. Note: placeholder_type "
                       "is lost during serialization.")
    attribute_placeholders = [
        (key, value.placeholder_value(placeholder_context))
        for key, value in self.mapping.items()
    ]
    # defaultdict takes a default factory as its first positional argument, so
    # it cannot be built from the list of pairs; a plain dict is returned.
    if self._placeholder_type is collections.defaultdict:
      return dict(attribute_placeholders)
    return self._placeholder_type(attribute_placeholders)

  def _to_tensors(self, value: Any):
    """Concatenates the `_to_tensors` results of each value, in sorted key order."""
    # NOTE(review): the second isinstance argument was truncated in the
    # reviewed copy; `collections.abc.Mapping` is reconstructed -- confirm.
    assert isinstance(value, collections.abc.Mapping)
    flattened_values = []
    for key in sorted(self.mapping.keys()):
      comp_value, comp_type = value[key], self.mapping[key]
      flattened_values.extend(comp_type._to_tensors(comp_value))  # pylint: disable=protected-access
    return flattened_values

  def _cast(self, value: Any, casting_context) -> Any:
    """Casts each value of the dict `value` with the matching value type."""
    # Value must have same keys with the TraceType
    assert isinstance(
        value, dict
    ), f"Cannot cast {value!r} to Python dict type."
    assert set(value.keys()) == set(
        self.mapping.keys()
    ), f"{value!r} has different keys with the TraceType {self!r}."

    cast_value = {}
    for k in value:
      assert k in self.mapping, f"Key {k} does not exist in TraceType {self!r}."
      cast_value[k] = self.mapping[k]._cast(value[k], casting_context)  # pylint: disable=protected-access

    if self._placeholder_type is None:
      return cast_value
    return self._placeholder_type(**cast_value)

  def __eq__(self, other) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, Dict):
      return False

    return self.mapping == other.mapping

  def __hash__(self) -> int:
    # Only the key set is hashed; value types do not affect the hash.
    return hash(frozenset(self.mapping.keys()))

  def __repr__(self):
    return f"{self.__class__.__name__}(mapping={self.mapping!r})"