669 lines
24 KiB
Python
669 lines
24 KiB
Python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""TraceType implementations for common Python types."""
|
|
|
|
import collections
|
|
from typing import Any, Hashable, Optional, Sequence, Type
|
|
from typing import Dict as PythonDict
|
|
from typing import Tuple as PythonTuple
|
|
import weakref
|
|
|
|
from tensorflow.core.function.trace_type import default_types_pb2
|
|
from tensorflow.core.function.trace_type import serialization
|
|
from tensorflow.core.function.trace_type import util
|
|
from tensorflow.python.types import trace
|
|
|
|
|
|
class Literal(trace.TraceType, serialization.Serializable):
  """TraceType wrapping a single Python literal value.

  Two Literals are equal exactly when their wrapped values compare equal. Only
  bool, int, float, str and None values can be serialized to a proto.
  """

  def __init__(self, value: Any):
    self.value = value
    # Hash once up front so __hash__ never has to re-hash the wrapped value.
    self._value_hash = hash(value)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """Returns True only when `other` is equal to this Literal."""
    return self == other

  def most_specific_common_supertype(
      self, types: Sequence[trace.TraceType]) -> Optional["Literal"]:
    """Returns `self` if every entry of `types` equals it, otherwise None."""
    for candidate in types:
      if not self == candidate:
        return None
    return self

  @classmethod
  def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedLiteral]:
    """Returns the proto class used to serialize Literals."""
    return default_types_pb2.SerializedLiteral

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedLiteral) -> "Literal":
    """Rebuilds a Literal from whichever value field is set on `proto`.

    Raises:
      ValueError: If none of the known value fields is set.
    """
    # Fields are probed in the same order as the original if-chain.
    for field_name in ("bool_value", "int_value", "float_value", "str_value"):
      if proto.HasField(field_name):
        return Literal(getattr(proto, field_name))

    if proto.HasField("none_value"):
      return Literal(None)

    raise ValueError("Malformed Literal proto can not be deserialized")

  def experimental_as_proto(self) -> default_types_pb2.SerializedLiteral:
    """Serializes the wrapped value into a SerializedLiteral proto.

    Raises:
      ValueError: If the wrapped value is not a bool, int, float, str or None.
    """
    wrapped = self.value
    # bool has to be tested before int because bool is a subclass of int.
    field_by_type = (
        (bool, "bool_value"),
        (int, "int_value"),
        (float, "float_value"),
        (str, "str_value"),
    )
    for python_type, field_name in field_by_type:
      if isinstance(wrapped, python_type):
        return default_types_pb2.SerializedLiteral(**{field_name: wrapped})

    if wrapped is None:
      none_marker = default_types_pb2.SerializedLiteral.NoneValue()
      return default_types_pb2.SerializedLiteral(none_value=none_marker)

    raise ValueError("Can not serialize Literal of type " +
                     type(wrapped).__name__)

  def placeholder_value(self, placeholder_context=None) -> Any:
    """Returns the wrapped value (a range is expanded into a list)."""
    # TODO(b/263505796): Remove this check when a range's placeholder output
    # is expected to be a range and not a list.
    return list(self.value) if isinstance(self.value, range) else self.value

  def _to_tensors(self, value: Any):
    """A literal contributes no tensors."""
    return []

  def __eq__(self, other) -> bool:
    if isinstance(other, Literal):
      return self.value == other.value
    # Other TraceTypes are simply unequal; foreign objects get a chance to
    # handle the comparison themselves.
    return False if isinstance(other, trace.TraceType) else NotImplemented

  def __hash__(self) -> int:
    return self._value_hash

  def __repr__(self):
    return f"{self.__class__.__name__}(value={self.value!r})"
|
|
|
|
|
|
class Weakref(trace.TraceType):
  """Represents weakref of an arbitrary Python object.

  When a function argument is a custom class, instead of making a copy of it
  just for the sake of function cache, a weakref is instead kept to save memory.

  A Weakref whose referent has been garbage collected never compares equal to
  any other Weakref (including itself).
  """

  def __init__(self, ref: weakref.ReferenceType):
    self._ref = ref
    # Computed while the referent is alive; hashing a dead weakref that was
    # never hashed before raises TypeError.
    self._ref_hash = hash(ref)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """Returns True only when `other` is equal to this Weakref."""
    return self == other

  def most_specific_common_supertype(
      self, types: Sequence[trace.TraceType]) -> Optional["Weakref"]:
    """Returns `self` if every entry of `types` equals it, otherwise None."""
    return self if all(self == other for other in types) else None

  def placeholder_value(self, placeholder_context=None) -> Any:
    """Returns the referent, or None if it has already been collected."""
    return self._ref()

  def _to_tensors(self, value: Any) -> Any:
    """A weakly referenced object contributes no tensors."""
    return []

  def __eq__(self, other):
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, Weakref):
      return False

    # Dereference each weakref exactly once and keep strong references for the
    # rest of the comparison. Re-calling the weakrefs for every check would
    # allow the referents to die in between, in which case the identity check
    # would see `None is None` and report two dead references as equal.
    self_referent = self._ref()
    other_referent = other._ref()  # pylint: disable=protected-access

    if self_referent is None or other_referent is None:
      return False

    if self_referent is other_referent:
      return True

    # Both referents are alive (and pinned above), so weakref equality compares
    # the referents themselves.
    return self._ref == other._ref  # pylint: disable=protected-access

  def __hash__(self):
    return self._ref_hash

  def __repr__(self):
    return f"{self.__class__.__name__}(ref={self._ref!r})"
|
|
|
|
|
|
class Tuple(trace.TraceType, serialization.Serializable):
  """TraceType describing a fixed-length tuple of component TraceTypes."""

  def __init__(self, *components: trace.TraceType):
    self.components = components

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """True if `other` is a same-length Tuple and subtyping holds pairwise."""
    if not isinstance(other, Tuple):
      return False
    if len(self.components) != len(other.components):
      return False

    for own_component, other_component in zip(self.components,
                                              other.components):
      if not own_component.is_subtype_of(other_component):
        return False
    return True

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["Tuple"]:
    """See base class."""
    arity = len(self.components)
    for other in others:
      if not (isinstance(other, Tuple) and arity == len(other.components)):
        return None

    # Resolve each position independently; a single unresolvable position
    # means the tuples have no common supertype.
    merged = []
    for position, component in enumerate(self.components):
      peers = [other.components[position] for other in others]
      merged_component = component.most_specific_common_supertype(peers)
      if merged_component is None:
        return None
      merged.append(merged_component)

    return Tuple(*merged)

  @classmethod
  def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedTuple]:
    """Returns the proto class used to serialize Tuples."""
    return default_types_pb2.SerializedTuple

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedTuple) -> "Tuple":
    """Rebuilds a Tuple by deserializing every component of `proto`."""
    restored = [
        serialization.deserialize(component) for component in proto.components
    ]
    return Tuple(*restored)

  def experimental_as_proto(self) -> default_types_pb2.SerializedTuple:
    """Serializes every component into a SerializedTuple proto."""
    serialized = [
        serialization.serialize(component) for component in self.components
    ]
    return default_types_pb2.SerializedTuple(components=serialized)

  def placeholder_value(self, placeholder_context) -> Any:
    """Returns a tuple holding the placeholder value of each component."""
    return tuple(
        component.placeholder_value(placeholder_context)
        for component in self.components)

  def _to_tensors(self, value) -> Any:
    """Concatenates the tensors of each element of `value`, in order."""
    assert isinstance(value, tuple)
    tensors = []
    for comp_type, comp_value in zip(self.components, value):
      tensors += comp_type._to_tensors(comp_value)  # pylint: disable=protected-access
    return tensors

  def _cast(self, value: Any, casting_context) -> Any:
    """Casts each element of `value` with the component at its position."""
    assert isinstance(value, tuple), f"Cannot cast {value!r} to tuple type."
    cast_items = [
        component._cast(item, casting_context)  # pylint: disable=protected-access
        for item, component in zip(value, self.components)
    ]
    return tuple(cast_items)

  def __eq__(self, other: Any) -> bool:
    if isinstance(other, Tuple):
      return self.components == other.components
    # Other TraceTypes are simply unequal; foreign objects get a chance to
    # handle the comparison themselves.
    return False if isinstance(other, trace.TraceType) else NotImplemented

  def __hash__(self) -> int:
    return hash(self.components)

  def __repr__(self):
    return f"Tuple(components={self.components!r})"
|
|
|
|
|
|
class List(trace.TraceType, serialization.Serializable):
  """Represents a list of TraceType objects.

  The element types are stored in a `Tuple` (`components_tuple`); subtype
  checks, supertype computation, serialization, equality and hashing are all
  delegated to that Tuple.
  """

  def __init__(self, *components: trace.TraceType):
    self.components_tuple = Tuple(*components)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """True if `other` is a List whose components are supertypes of ours."""
    if not isinstance(other, List):
      return False

    return self.components_tuple.is_subtype_of(other.components_tuple)

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["List"]:
    """See base class.

    Returns:
      A new List built from the common supertype of the component tuples, or
      None if any of `others` is not a List or the components have no common
      supertype.
    """
    if not all(isinstance(other, List) for other in others):
      return None

    supertyped_components_tuple = (
        self.components_tuple.most_specific_common_supertype(
            [other.components_tuple for other in others]
        )
    )

    if supertyped_components_tuple is None:
      return None

    return List(*supertyped_components_tuple.components)

  @classmethod
  def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedList]:
    """Returns the proto class used to serialize Lists."""
    return default_types_pb2.SerializedList

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedList) -> "List":
    """Rebuilds a List from the Tuple stored in `proto.components_tuple`."""
    return List(
        *Tuple.experimental_from_proto(proto.components_tuple).components)

  def experimental_as_proto(self) -> default_types_pb2.SerializedList:
    """Serializes the component Tuple into a SerializedList proto."""
    return default_types_pb2.SerializedList(
        components_tuple=self.components_tuple.experimental_as_proto())

  def placeholder_value(self, placeholder_context) -> Any:
    """Returns a list holding the placeholder value of each component."""
    return list(self.components_tuple.placeholder_value(placeholder_context))

  def _to_tensors(self, value):
    """Flattens `value` (a list) by delegating to the component Tuple."""
    assert isinstance(value, list)
    return self.components_tuple._to_tensors(tuple(value))  # pylint: disable=protected-access

  def _cast(self, value: Any, casting_context) -> Any:
    """Casts each element of `value` with the component at its position."""
    assert isinstance(value, list), f"Cannot cast {value!r} to list type."
    return [component._cast(v, casting_context) for v, component in zip(  # pylint: disable=protected-access
        value, self.components_tuple.components)]

  def __eq__(self, other: Any) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, List):
      return False

    return self.components_tuple == other.components_tuple

  def __hash__(self) -> int:
    return hash(self.components_tuple)

  def __repr__(self):
    return f"List(components={self.components_tuple.components!r})"
|
|
|
|
|
|
class NamedTuple(trace.TraceType, serialization.Serializable):
  """TraceType describing a namedtuple: its name, field names and field types.

  `placeholder_type` is the concrete namedtuple class. It is not part of the
  serialized form, so instances rebuilt from a proto cannot produce placeholder
  values.
  """

  def __init__(self,
               type_name: str,
               attribute_names: PythonTuple[str],
               attributes: PythonTuple[trace.TraceType],
               placeholder_type: Optional[Type[Any]] = None):
    self._placeholder_type = placeholder_type
    self.type_name = type_name
    self.attribute_names = attribute_names
    # Field types are kept as a Tuple so its comparison logic can be reused.
    self.attributes = Tuple(*attributes)

  @classmethod
  def from_type_and_attributes(
      cls, named_tuple_type: Any,
      attributes: PythonTuple[trace.TraceType]) -> "NamedTuple":
    """Builds a NamedTuple from a namedtuple class and its field TraceTypes."""
    return NamedTuple(
        type_name=named_tuple_type.__name__,
        attribute_names=named_tuple_type._fields,
        attributes=attributes,
        placeholder_type=named_tuple_type)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """True if names match `other` and the field types are subtypes."""
    if not isinstance(other, NamedTuple):
      return False
    if self.type_name != other.type_name:
      return False
    if self.attribute_names != other.attribute_names:
      return False
    return self.attributes.is_subtype_of(other.attributes)

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["NamedTuple"]:
    """See base class."""
    for other in others:
      compatible = (
          isinstance(other, NamedTuple) and
          self.type_name == other.type_name and
          self.attribute_names == other.attribute_names)
      if not compatible:
        return None

    peer_attributes = [other.attributes for other in others]
    merged = self.attributes.most_specific_common_supertype(peer_attributes)
    if merged is None:
      return None

    return NamedTuple(self.type_name, self.attribute_names, merged.components,
                      self._placeholder_type)

  @classmethod
  def experimental_type_proto(
      cls) -> Type[default_types_pb2.SerializedNamedTuple]:
    """Returns the proto class used to serialize NamedTuples."""
    return default_types_pb2.SerializedNamedTuple

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedNamedTuple) -> "NamedTuple":
    """Rebuilds a NamedTuple from `proto` (without a placeholder type)."""
    field_types = Tuple.experimental_from_proto(proto.attributes).components
    return NamedTuple(proto.type_name, tuple(proto.attribute_names),
                      field_types)

  def experimental_as_proto(self) -> default_types_pb2.SerializedNamedTuple:
    """Serializes the name, field names and field types into a proto."""
    serialized_attributes = self.attributes.experimental_as_proto()
    return default_types_pb2.SerializedNamedTuple(
        type_name=self.type_name,
        attribute_names=list(self.attribute_names),
        attributes=serialized_attributes)

  def placeholder_value(self, placeholder_context) -> Any:
    """Returns a `placeholder_type` instance filled with field placeholders.

    Raises:
      ValueError: If no placeholder type was supplied at construction.
    """
    if self._placeholder_type is None:
      # We don't need to trace after serialization so it is not needed but we
      # can generate a placeholder type using the description if ever needed.
      raise ValueError("Can not generate placeholder value for NamedTuple with"
                       " unspecified placeholder_type. Note: placeholder_type "
                       "is lost during serialization.")

    return self._placeholder_type(*(
        attribute.placeholder_value(placeholder_context)
        for attribute in self.attributes.components))

  def _to_tensors(self, value: Any):
    """Concatenates the tensors of each field of `value`, in field order."""
    assert util.is_namedtuple(value)
    tensors = []
    fields = zip(self.attribute_names, self.attributes.components)
    for field_name, field_type in fields:
      tensors += field_type._to_tensors(getattr(value, field_name))  # pylint: disable=protected-access
    return tensors

  def _cast(self, value: Any, casting_context) -> Any:
    """Casts every field of `value` and rebuilds the namedtuple."""
    # Value must have same attributes with the TraceType
    assert (
        isinstance(value, self._placeholder_type)  # pylint: disable=unidiomatic-typecheck
    ), f"Cannot cast {value!r} to type {self._placeholder_type!r}."
    assert set(value._asdict().keys()) == set(
        self.attribute_names
    ), f"{value!r} has different attributes with the TraceType {self!r}"

    cast_fields = {
        field_name: field_type._cast(  # pylint: disable=protected-access
            getattr(value, field_name), casting_context)
        for field_name, field_type in zip(self.attribute_names,
                                          self.attributes.components)
    }
    return self._placeholder_type(**cast_fields)

  def __hash__(self) -> int:
    return hash((self.type_name, self.attribute_names, self.attributes))

  def __eq__(self, other: Any) -> bool:
    if isinstance(other, NamedTuple):
      return (self.type_name == other.type_name and
              self.attribute_names == other.attribute_names and
              self.attributes == other.attributes)
    # Other TraceTypes are simply unequal; foreign objects get a chance to
    # handle the comparison themselves.
    return False if isinstance(other, trace.TraceType) else NotImplemented

  def __repr__(self):
    return (f"NamedTuple(type_name={self.type_name}, "
            f"attribute_names={self.attribute_names}, "
            f"attributes={self.attributes.components})")
|
|
|
|
|
|
class Attrs(trace.TraceType, serialization.Serializable):
  """Represents a class annotated by attr.s.

  The type name, attribute names and attribute types are stored in a
  `NamedTuple` (`named_attributes`), to which comparisons and serialization
  are delegated. `placeholder_type` is the concrete attrs class; it is not
  part of the serialized form.

  This class implements the `experimental_*` serialization methods and is
  registered with `serialization.register_serializable`, so it declares
  `serialization.Serializable` like the other registered types in this module.
  """

  def __init__(self,
               type_name: str,
               attribute_names: PythonTuple[str],
               attributes: PythonTuple[trace.TraceType],
               placeholder_type: Optional[Type[Any]] = None):
    # The inner NamedTuple is only used as a description; the placeholder type
    # is tracked on this object instead.
    self.named_attributes = NamedTuple(type_name, attribute_names, attributes)
    self._placeholder_type = placeholder_type

  @classmethod
  def from_type_and_attributes(
      cls, attrs_type: Any,
      attributes: PythonTuple[trace.TraceType]) -> "Attrs":
    """Builds an Attrs from an attrs class and its attribute TraceTypes."""
    return Attrs(attrs_type.__name__,
                 tuple(attr.name for attr in attrs_type.__attrs_attrs__),
                 attributes, attrs_type)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """True if `other` is an Attrs whose description is a supertype of ours."""
    if not isinstance(other, Attrs):
      return False

    return self.named_attributes.is_subtype_of(other.named_attributes)

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["Attrs"]:
    """See base class."""
    if not all(isinstance(other, Attrs) for other in others):
      return None

    supertyped_attributes = (
        self.named_attributes.most_specific_common_supertype(
            [other.named_attributes for other in others]
        )
    )

    if supertyped_attributes is None:
      return None

    return Attrs(self.named_attributes.type_name,
                 self.named_attributes.attribute_names,
                 supertyped_attributes.attributes.components,
                 self._placeholder_type)

  @classmethod
  def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedAttrs]:
    """Returns the proto class used to serialize Attrs."""
    return default_types_pb2.SerializedAttrs

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedAttrs) -> "Attrs":
    """Rebuilds an Attrs from `proto` (without a placeholder type)."""
    return Attrs(
        proto.named_attributes.type_name,
        tuple(proto.named_attributes.attribute_names),
        Tuple.experimental_from_proto(
            proto.named_attributes.attributes).components)

  def experimental_as_proto(self) -> default_types_pb2.SerializedAttrs:
    """Serializes the attribute description into a SerializedAttrs proto."""
    return default_types_pb2.SerializedAttrs(
        named_attributes=self.named_attributes.experimental_as_proto())

  def placeholder_value(self, placeholder_context) -> Any:
    """Returns a `placeholder_type` instance filled with placeholders.

    Raises:
      ValueError: If no placeholder type was supplied at construction.
    """
    if self._placeholder_type is None:
      # We don't need to trace after serialization so it is not needed but we
      # can generate a placeholder type using the description if ever needed.
      raise ValueError("Can not generate placeholder value for Attrs with"
                       " unspecified placeholder_type. Note: placeholder_type "
                       "is lost during serialization.")
    attribute_placeholders = [
        attribute.placeholder_value(placeholder_context)
        for attribute in self.named_attributes.attributes.components
    ]
    return self._placeholder_type(*attribute_placeholders)

  def _to_tensors(self, value: Any):
    """Concatenates the tensors of each attribute of `value`, in order."""
    assert util.is_attrs(value)
    flattened_values = []
    for attribute_name, attribute_type in zip(
        self.named_attributes.attribute_names,
        self.named_attributes.attributes.components):
      attribute_value = getattr(value, attribute_name)
      flattened_values.extend(attribute_type._to_tensors(attribute_value))  # pylint: disable=protected-access
    return flattened_values

  def _cast(self, value: Any, casting_context) -> Any:
    """Casts every attribute of `value` and rebuilds the attrs object."""
    assert util.is_attrs(value)
    value_cast = {}
    for attribute_name, attribute_type in zip(
        self.named_attributes.attribute_names,
        self.named_attributes.attributes.components):
      attribute_value = getattr(value, attribute_name)
      value_cast[attribute_name] = attribute_type._cast(  # pylint: disable=protected-access
          attribute_value, casting_context)

    return self._placeholder_type(**value_cast)

  def __hash__(self) -> int:
    return hash(self.named_attributes)

  def __eq__(self, other: Any) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, Attrs):
      return False

    return self.named_attributes == other.named_attributes

  def __repr__(self):
    return (f"Attrs(type_name={self.named_attributes.type_name}, "
            f"attribute_names={self.named_attributes.attribute_names}, "
            f"attributes={self.named_attributes.attributes.components})")
|
|
|
|
|
|
class Dict(trace.TraceType, serialization.Serializable):
  """Represents a dictionary of TraceType objects.

  `placeholder_type` is the concrete mapping class used to rebuild values; it
  is not part of the serialized form.

  Attributes:
    mapping: A mapping from keys to corresponding TraceTypes of the dict values.
  """

  def __init__(self,
               mapping: PythonDict[Hashable, trace.TraceType],
               placeholder_type: Optional[Type[Any]] = None):
    self.mapping = mapping
    self._placeholder_type = placeholder_type

  def _has_same_structure(self, other):
    """Returns True if `other` is a Dict with exactly the same set of keys."""
    if not isinstance(other, Dict):
      return False

    return self.mapping.keys() == other.mapping.keys()

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """See base class."""
    if not self._has_same_structure(other):
      return False

    # We need all keys to be present because there can be logic relying on
    # their existence or lack thereof and hence can not guarantee subtype based
    # on a subset or superset of keys.
    # Only the tracing code can explicitly check for key dependencies and inform
    # that decision.
    return all(self.mapping[key].is_subtype_of(other.mapping[key])
               for key in self.mapping)

  def most_specific_common_supertype(
      self, types: Sequence[trace.TraceType]) -> Optional["Dict"]:
    """See base class."""
    if not all(self._has_same_structure(other) for other in types):
      return None

    new_mapping = {}
    for key in self.mapping.keys():
      common = self.mapping[key].most_specific_common_supertype(
          [other.mapping[key] for other in types])
      if common is None:
        return None
      else:
        new_mapping[key] = common

    return Dict(new_mapping, self._placeholder_type)

  @classmethod
  def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedDict]:
    """Returns the proto class used to serialize Dicts."""
    return default_types_pb2.SerializedDict

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedDict) -> "Dict":
    """Rebuilds a Dict from `proto` (without a placeholder type)."""
    return Dict({
        Literal.experimental_from_proto(k).value: serialization.deserialize(v)
        for k, v in zip(proto.keys, proto.values)
    })

  def experimental_as_proto(self) -> default_types_pb2.SerializedDict:
    """Serializes keys (as Literals) and value types into a proto."""
    return default_types_pb2.SerializedDict(
        keys=[Literal(k).experimental_as_proto() for k in self.mapping.keys()],
        values=[serialization.serialize(v) for v in self.mapping.values()])

  def placeholder_value(self, placeholder_context) -> Any:
    """Returns a mapping of each key to its value type's placeholder.

    Raises:
      ValueError: If no placeholder type was supplied at construction.
    """
    if self._placeholder_type is None:
      raise ValueError("Can not generate placeholder value for Dict with"
                       " unspecified placeholder_type. Note: placeholder_type "
                       "is lost during serialization.")
    attribute_placeholders = [
        (key, value.placeholder_value(placeholder_context))
        for key, value in self.mapping.items()
    ]
    # defaultdict takes a factory as its first argument, so it cannot be built
    # from the item list alone; a plain dict is returned instead.
    if self._placeholder_type is collections.defaultdict:
      return dict(attribute_placeholders)
    return self._placeholder_type(attribute_placeholders)

  def _to_tensors(self, value: Any):
    """Concatenates the tensors of each value of `value`, in sorted key order."""
    assert isinstance(value, collections.abc.Mapping)
    flattened_values = []
    for key in sorted(self.mapping.keys()):
      comp_value, comp_type = value[key], self.mapping[key]
      flattened_values.extend(comp_type._to_tensors(comp_value))  # pylint: disable=protected-access
    return flattened_values

  def _cast(self, value: Any, casting_context) -> Any:
    """Casts every value of `value` with the TraceType stored for its key.

    Returns:
      A plain dict when no placeholder type is set, otherwise an instance of
      the placeholder type holding the cast values.
    """
    # Value must have same keys with the TraceType
    assert isinstance(
        value, dict
    ), f"Cannot cast {value!r} to Python dict type."
    assert set(value.keys()) == set(
        self.mapping.keys()
    ), f"{value!r} has different keys with the TraceType {self!r}."

    cast_value = {
        key: self.mapping[key]._cast(item, casting_context)  # pylint: disable=protected-access
        for key, item in value.items()
    }

    if self._placeholder_type is None:
      return cast_value

    # Build the result from the mapping positionally rather than through
    # `**cast_value`: keyword expansion raises TypeError for any key that is
    # not a string, while keys here are only required to be hashable.
    if self._placeholder_type is collections.defaultdict:
      # defaultdict's first positional argument is the default factory.
      return collections.defaultdict(None, cast_value)
    return self._placeholder_type(cast_value)

  def __eq__(self, other) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, Dict):
      return False

    return self.mapping == other.mapping

  def __hash__(self) -> int:
    # Only the keys are hashed; equal Dicts necessarily have equal key sets, so
    # this stays consistent with __eq__.
    return hash(frozenset(self.mapping.keys()))

  def __repr__(self):
    return f"{self.__class__.__name__}(mapping={self.mapping!r})"
|
|
|
|
# Hand every TraceType of this module that implements the serialization
# methods to serialization.register_serializable, in the original order.
for _trace_type_cls in (Literal, Tuple, List, NamedTuple, Attrs, Dict):
  serialization.register_serializable(_trace_type_cls)
del _trace_type_cls  # Keep the loop variable out of the module namespace.
|