176 lines
7.1 KiB
Python
176 lines
7.1 KiB
Python
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Classes and functions implementing Layer SavedModel serialization."""
|
||
|
|
||
|
from tensorflow.python.keras.mixed_precision import policy
|
||
|
from tensorflow.python.keras.saving.saved_model import base_serialization
|
||
|
from tensorflow.python.keras.saving.saved_model import constants
|
||
|
from tensorflow.python.keras.saving.saved_model import save_impl
|
||
|
from tensorflow.python.keras.saving.saved_model import serialized_attributes
|
||
|
from tensorflow.python.keras.utils import generic_utils
|
||
|
from tensorflow.python.trackable import data_structures
|
||
|
from tensorflow.python.util import nest
|
||
|
|
||
|
|
||
|
class LayerSavedModelSaver(base_serialization.SavedModelSaver):
  """SavedModel serialization logic for a Keras layer held in `self.obj`."""

  @property
  def object_identifier(self):
    return constants.LAYER_IDENTIFIER

  @property
  def python_properties(self):
    # TODO(kathywu): Add python property validator
    return self._python_properties_internal()

  def _python_properties_internal(self):
    """Builds the metadata dictionary of python properties for the layer.

    Returns:
      A dict holding the basic layer attributes, merged with the output of
      `get_serialized`, plus the optional `input_spec`,
      `activity_regularizer` and `build_input_shape` entries when the layer
      provides them.
    """
    # TODO(kathywu): Add support for metrics serialization.
    # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
    # the python config serialization has caught up.
    layer = self.obj

    # pylint: disable=protected-access
    metadata = {
        'name': layer.name,
        'trainable': layer.trainable,
        'expects_training_arg': layer._expects_training_arg,
        'dtype': policy.serialize(layer._dtype_policy),
        'batch_input_shape': getattr(layer, '_batch_input_shape', None),
        'stateful': layer.stateful,
        'must_restore_from_config': layer._must_restore_from_config,
    }
    # pylint: enable=protected-access

    metadata.update(get_serialized(layer))

    input_spec = layer.input_spec
    if input_spec is not None:

      def _serialize_spec(spec):
        # Falsy entries in the structure are recorded as None.
        if not spec:
          return None
        return generic_utils.serialize_keras_object(spec)

      # Layer's input_spec has already been type-checked in the property
      # setter.
      metadata['input_spec'] = nest.map_structure(_serialize_spec, input_spec)

    regularizer = layer.activity_regularizer
    # Only regularizers exposing `get_config` are written to the metadata.
    if regularizer is not None and hasattr(regularizer, 'get_config'):
      metadata['activity_regularizer'] = (
          generic_utils.serialize_keras_object(regularizer))

    build_input_shape = layer._build_input_shape  # pylint: disable=protected-access
    if build_input_shape is not None:
      metadata['build_input_shape'] = build_input_shape

    return metadata

  def objects_to_serialize(self, serialization_cache):
    attributes = self._get_serialized_attributes(serialization_cache)
    return attributes.objects_to_serialize

  def functions_to_serialize(self, serialization_cache):
    attributes = self._get_serialized_attributes(serialization_cache)
    return attributes.functions_to_serialize

  def _get_serialized_attributes(self, serialization_cache):
    """Returns the layer's serialized attributes, creating them if needed.

    Results are stored in the Keras sub-dictionary of `serialization_cache`,
    keyed by the layer object.
    """
    layer = self.obj
    keras_cache = serialization_cache.setdefault(constants.KERAS_CACHE_KEY, {})
    if layer in keras_cache:
      return keras_cache[layer]

    # The new entry is placed in the cache before it is populated.
    attributes = serialized_attributes.SerializedAttributes.new(layer)
    keras_cache[layer] = attributes

    skip = (save_impl.should_skip_serialization(layer) or
            layer._must_restore_from_config)  # pylint: disable=protected-access
    if not skip:
      object_dict, function_dict = self._get_serialized_attributes_internal(
          serialization_cache)
      attributes.set_and_validate_objects(object_dict)
      attributes.set_and_validate_functions(function_dict)

    return attributes

  def _get_serialized_attributes_internal(self, serialization_cache):
    """Returns an `(objects, functions)` pair of serialized attributes."""
    layer = self.obj
    wrapped_objects = save_impl.wrap_layer_objects(layer, serialization_cache)
    wrapped_functions = save_impl.wrap_layer_functions(
        layer, serialization_cache)
    # Attribute validator requires that the default save signature is added to
    # function dict, even if the value is None.
    wrapped_functions['_default_save_signature'] = None
    return wrapped_objects, wrapped_functions
|
||
|
|
||
|
|
||
|
# TODO(kathywu): Move serialization utils (and related utils from
|
||
|
# generic_utils.py) to a separate file.
|
||
|
def get_serialized(obj):
  """Serializes `obj` with `generic_utils.serialize_keras_object`.

  The call runs inside `generic_utils.skip_failed_serialization()`.

  Args:
    obj: The object to serialize.

  Returns:
    The value returned by `generic_utils.serialize_keras_object(obj)`.
  """
  with generic_utils.skip_failed_serialization():
    # Store the config dictionary, which may be used when reviving the object.
    # When loading, the program will attempt to revive the object from config,
    # and if that fails, the object will be revived from the SavedModel.
    serialized = generic_utils.serialize_keras_object(obj)
  return serialized
|
||
|
|
||
|
|
||
|
class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
  """SavedModel serialization logic for an InputLayer held in `self.obj`."""

  @property
  def object_identifier(self):
    return constants.INPUT_LAYER_IDENTIFIER

  @property
  def python_properties(self):
    """Returns the metadata dictionary describing the input layer."""
    layer = self.obj
    return {
        'class_name': type(layer).__name__,
        'name': layer.name,
        'dtype': layer.dtype,
        'sparse': layer.sparse,
        'ragged': layer.ragged,
        'batch_input_shape': layer._batch_input_shape,  # pylint: disable=protected-access
        'config': layer.get_config(),
    }

  def objects_to_serialize(self, serialization_cache):
    # This saver contributes no objects.
    return {}

  def functions_to_serialize(self, serialization_cache):
    # This saver contributes no functions.
    return {}
|
||
|
|
||
|
|
||
|
class RNNSavedModelSaver(LayerSavedModelSaver):
  """SavedModel serialization logic for RNN layers."""

  @property
  def object_identifier(self):
    return constants.RNN_LAYER_IDENTIFIER

  def _get_serialized_attributes_internal(self, serialization_cache):
    """Extends the base layer attributes with the layer's `states` object."""
    base_saver = super(RNNSavedModelSaver, self)
    objects, functions = base_saver._get_serialized_attributes_internal(
        serialization_cache)

    rnn_states = data_structures.wrap_or_unwrap(self.obj.states)
    # SavedModel requires every saved object to be Trackable. A result that is
    # still a tuple after wrap_or_unwrap holds no trackable items (e.g. an
    # empty tuple, or (None, None) for a stateless ConvLSTM2D), so it is turned
    # into a list and wrapped again to obtain a Trackable. When loaded,
    # ConvLSTM2D is able to handle the tuple/list conversion.
    if isinstance(rnn_states, tuple):
      rnn_states = data_structures.wrap_or_unwrap(list(rnn_states))

    objects['states'] = rnn_states
    return objects, functions
|
||
|
|
||
|
|
||
|
class IndexLookupLayerSavedModelSaver(LayerSavedModelSaver):
  """SavedModel serialization logic for index lookup layers."""

  @property
  def python_properties(self):
    """Returns layer metadata, with the vocabulary cleared for static tables."""
    # TODO(kathywu): Add python property validator
    metadata = self._python_properties_internal()
    layer_config = metadata['config']
    # When the config reports a static table, the vocabulary entry is reset
    # to None in the saved metadata.
    if layer_config.get('has_static_table', False):
      layer_config['vocabulary'] = None
    return metadata
|