# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and functions implementing Layer SavedModel serialization."""

from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.saving.saved_model import base_serialization
from tensorflow.python.keras.saving.saved_model import constants
from tensorflow.python.keras.saving.saved_model import save_impl
from tensorflow.python.keras.saving.saved_model import serialized_attributes
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.trackable import data_structures
from tensorflow.python.util import nest


class LayerSavedModelSaver(base_serialization.SavedModelSaver):
  """SavedModel saver for generic Keras layers.

  Produces the layer's Python-side metadata and the objects/functions that
  are attached to the SavedModel for it.
  """

  @property
  def object_identifier(self):
    """Identifier recorded for generic Keras layers."""
    return constants.LAYER_IDENTIFIER

  @property
  def python_properties(self):
    """Metadata dict built by `_python_properties_internal`."""
    # TODO(kathywu): Add python property validator
    return self._python_properties_internal()

  def _python_properties_internal(self):
    """Builds the metadata dict that describes the wrapped layer.

    Returns:
      A dict with the layer's basic attributes, merged with the result of
      `get_serialized(self.obj)`. It also has `input_spec`,
      `activity_regularizer` and `build_input_shape` entries when those are
      available on the layer.
    """
    # TODO(kathywu): Add support for metrics serialization.
    # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
    # the python config serialization has caught up.
    layer = self.obj

    # Key order (and therefore evaluation order) mirrors the historical
    # keyword-argument order of this dict.
    # pylint: disable=protected-access
    metadata = {
        'name': layer.name,
        'trainable': layer.trainable,
        'expects_training_arg': layer._expects_training_arg,
        'dtype': policy.serialize(layer._dtype_policy),
        'batch_input_shape': getattr(layer, '_batch_input_shape', None),
        'stateful': layer.stateful,
        'must_restore_from_config': layer._must_restore_from_config,
    }
    # pylint: enable=protected-access
    metadata.update(get_serialized(layer))

    spec = layer.input_spec
    if spec is not None:
      # The input_spec property setter has already type-checked this value.
      def _serialize_or_none(entry):
        return generic_utils.serialize_keras_object(entry) if entry else None

      metadata['input_spec'] = nest.map_structure(_serialize_or_none, spec)

    regularizer = layer.activity_regularizer
    # Only regularizers that expose `get_config` are recorded.
    if regularizer is not None and hasattr(regularizer, 'get_config'):
      metadata['activity_regularizer'] = (
          generic_utils.serialize_keras_object(regularizer))

    build_shape = layer._build_input_shape  # pylint: disable=protected-access
    if build_shape is not None:
      metadata['build_input_shape'] = build_shape

    return metadata

  def objects_to_serialize(self, serialization_cache):
    """Returns the trackable objects to attach for this layer."""
    attrs = self._get_serialized_attributes(serialization_cache)
    return attrs.objects_to_serialize

  def functions_to_serialize(self, serialization_cache):
    """Returns the functions to attach for this layer."""
    attrs = self._get_serialized_attributes(serialization_cache)
    return attrs.functions_to_serialize

  def _get_serialized_attributes(self, serialization_cache):
    """Returns the layer's `SerializedAttributes`, building them at most once.

    Results are memoized per layer in the Keras sub-dict of
    `serialization_cache`, under `constants.KERAS_CACHE_KEY`.
    """
    keras_cache = serialization_cache.setdefault(constants.KERAS_CACHE_KEY, {})
    layer = self.obj
    if layer in keras_cache:
      return keras_cache[layer]

    # The (still empty) attributes are registered before they are populated.
    # NOTE(review): presumably so that a re-entrant lookup for this layer
    # during the wrapping below hits the cache instead of recursing.
    attrs = serialized_attributes.SerializedAttributes.new(layer)
    keras_cache[layer] = attrs

    skip = (save_impl.should_skip_serialization(layer) or
            layer._must_restore_from_config)  # pylint: disable=protected-access
    if skip:
      return attrs

    object_dict, function_dict = self._get_serialized_attributes_internal(
        serialization_cache)
    attrs.set_and_validate_objects(object_dict)
    attrs.set_and_validate_functions(function_dict)
    return attrs

  def _get_serialized_attributes_internal(self, serialization_cache):
    """Wraps the layer's objects and functions for serialization.

    Returns:
      An `(objects, functions)` pair of dicts, as produced by
      `save_impl.wrap_layer_objects` and `save_impl.wrap_layer_functions`.
    """
    layer_objects = save_impl.wrap_layer_objects(self.obj, serialization_cache)
    layer_functions = save_impl.wrap_layer_functions(
        self.obj, serialization_cache)
    # The attribute validator requires a '_default_save_signature' entry in
    # the function dict, even when its value is None.
    layer_functions['_default_save_signature'] = None
    return layer_objects, layer_functions


# TODO(kathywu): Move serialization utils (and related utils from
# generic_utils.py) to a separate file.
def get_serialized(obj):
  """Returns the Keras serialization of `obj`.

  The call runs inside `generic_utils.skip_failed_serialization()`.
  NOTE(review): presumably this makes a failing config non-fatal; confirm
  against generic_utils.

  The stored config may be used when reviving the object. Loading first tries
  to revive the object from this config and, if that fails, revives it from
  the SavedModel instead.
  """
  with generic_utils.skip_failed_serialization():
    return generic_utils.serialize_keras_object(obj)


class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
  """SavedModel saver for `InputLayer`s; records metadata only."""

  @property
  def object_identifier(self):
    """Identifier recorded for input layers."""
    return constants.INPUT_LAYER_IDENTIFIER

  @property
  def python_properties(self):
    """Metadata dict describing the input layer and its config."""
    layer = self.obj
    return {
        'class_name': type(layer).__name__,
        'name': layer.name,
        'dtype': layer.dtype,
        'sparse': layer.sparse,
        'ragged': layer.ragged,
        'batch_input_shape': layer._batch_input_shape,  # pylint: disable=protected-access
        'config': layer.get_config(),
    }

  def objects_to_serialize(self, serialization_cache):
    """Input layers attach no extra objects."""
    return {}

  def functions_to_serialize(self, serialization_cache):
    """Input layers attach no functions."""
    return {}


class RNNSavedModelSaver(LayerSavedModelSaver):
  """SavedModel saver for RNN layers; additionally saves `states`."""

  @property
  def object_identifier(self):
    """Identifier recorded for RNN layers."""
    return constants.RNN_LAYER_IDENTIFIER

  def _get_serialized_attributes_internal(self, serialization_cache):
    """Extends the base objects dict with the layer's `states`."""
    objects, functions = super()._get_serialized_attributes_internal(
        serialization_cache)

    states = data_structures.wrap_or_unwrap(self.obj.states)
    # SavedModel requires every saved object to be Trackable. A plain tuple
    # that survives wrap_or_unwrap holds no trackable item (e.g. an empty
    # tuple, or (None, None) for a stateless ConvLSTM2D). Converting it to a
    # list lets wrap_or_unwrap turn it into a Trackable. When loaded,
    # ConvLSTM2D handles the tuple/list conversion.
    if isinstance(states, tuple):
      states = data_structures.wrap_or_unwrap(list(states))

    objects['states'] = states
    return objects, functions


class IndexLookupLayerSavedModelSaver(LayerSavedModelSaver):
  """SavedModel saver for index lookup layers."""

  @property
  def python_properties(self):
    """Layer metadata; `vocabulary` is cleared for static-table layers."""
    # TODO(kathywu): Add python property validator
    metadata = self._python_properties_internal()
    config = metadata['config']
    if config.get('has_static_table', False):
      config['vocabulary'] = None
    return metadata