Intelegentny_Pszczelarz/.venv/Lib/site-packages/tensorflow/core/function/trace_type/serialization.py
2023-06-19 00:49:18 +02:00

101 lines
3.6 KiB
Python

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for serializing and deserializing TraceTypes."""
import abc
from typing import Type
from google.protobuf import message
from tensorflow.core.function.trace_type import serialization_pb2
# Short alias for the generated proto message used as the serialized container.
SerializedTraceType = serialization_pb2.SerializedTraceType
# Registry keyed by proto message class; each value is the Python class that
# register_serializable() associated with it. deserialize() scans this mapping.
PROTO_CLASS_TO_PY_CLASS = {}
class Serializable(metaclass=abc.ABCMeta):
  """Additional interface that makes a TraceType portable.

  An implementing class is tied to one proto message type and supplies the
  conversions between its instances and messages of that type.
  """

  @classmethod
  @abc.abstractmethod
  def experimental_type_proto(cls) -> Type[message.Message]:
    """Returns the proto message class uniquely associated with this class."""
    raise NotImplementedError

  @abc.abstractmethod
  def experimental_as_proto(self) -> message.Message:
    """Returns a proto message that represents this instance."""
    raise NotImplementedError

  @classmethod
  @abc.abstractmethod
  def experimental_from_proto(cls, proto: message.Message) -> "Serializable":
    """Returns an instance constructed from the given proto message.

    Args:
      proto: Proto message to build the instance from.
    """
    raise NotImplementedError
def register_serializable(cls: Type[Serializable]):
  """Registers a Python class to support serialization.

  Only register standard TF types. Custom types should NOT be registered.

  Args:
    cls: Python class to register.

  Raises:
    ValueError: If a Python class is already registered for the proto type
      returned by `cls.experimental_type_proto()`.
  """
  # Resolve the proto type exactly once so that the key checked for a
  # collision is guaranteed to be the key that gets stored.
  proto_class = cls.experimental_type_proto()

  if proto_class in PROTO_CLASS_TO_PY_CLASS:
    existing_cls = PROTO_CLASS_TO_PY_CLASS[proto_class]
    raise ValueError(
        f"Existing Python class {existing_cls.__name__} already has "
        f"{proto_class.__name__} as its associated proto representation. "
        f"Please ensure {cls.__name__} has a unique proto representation.")

  PROTO_CLASS_TO_PY_CLASS[proto_class] = cls
def serialize(to_serialize: Serializable) -> SerializedTraceType:
  """Packs a Serializable object into a SerializedTraceType proto.

  Args:
    to_serialize: Object to convert; must be an instance of Serializable.

  Returns:
    A SerializedTraceType whose `representation` field has the object's proto
    packed into it.

  Raises:
    ValueError: If `to_serialize` is not a Serializable, or if the proto it
      produces is not an instance of its `experimental_type_proto()` class.
  """
  type_name = type(to_serialize).__name__

  if not isinstance(to_serialize, Serializable):
    # `!s` forces str() so the text matches plain string conversion.
    raise ValueError(
        f"Can not serialize {type_name} since it is not Serializable. "
        f"For object {to_serialize!s}")

  payload = to_serialize.experimental_as_proto()
  declared_proto_class = to_serialize.experimental_type_proto()
  if not isinstance(payload, declared_proto_class):
    raise ValueError(
        f"{type_name} returned different type of proto than specified by "
        "experimental_type_proto()")

  container = SerializedTraceType()
  container.representation.Pack(payload)
  return container
def deserialize(proto: SerializedTraceType) -> Serializable:
  """Converts a proto SerializedTraceType to instance of Serializable.

  Args:
    proto: SerializedTraceType whose `representation` field holds the packed
      proto of a registered class.

  Returns:
    The object produced by `experimental_from_proto` of the Python class
    registered for the packed proto's type.

  Raises:
    ValueError: If no registered proto class matches the packed message.
  """
  for proto_class, py_class in PROTO_CLASS_TO_PY_CLASS.items():
    if proto.representation.Is(proto_class.DESCRIPTOR):
      actual_proto = proto_class()
      proto.representation.Unpack(actual_proto)
      return py_class.experimental_from_proto(actual_proto)

  # Build one message string: passing several positional arguments to
  # ValueError would make the error render as a tuple instead of a sentence.
  raise ValueError(
      f"Can not deserialize proto of url: {proto.representation.type_url} "
      "since no matching Python class could be found. For value "
      f"{proto.representation.value!r}")