# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and functions implementing Layer SavedModel serialization."""

import tensorflow.compat.v2 as tf

from keras.mixed_precision import policy
from keras.saving.legacy import serialization
from keras.saving.legacy.saved_model import base_serialization
from keras.saving.legacy.saved_model import constants
from keras.saving.legacy.saved_model import save_impl
from keras.saving.legacy.saved_model import serialized_attributes


class LayerSavedModelSaver(base_serialization.SavedModelSaver):
    """Implements Layer SavedModel serialization."""

    @property
    def object_identifier(self):
        """Identifier constant written for generic layer objects."""
        return constants.LAYER_IDENTIFIER

    @property
    def python_properties(self):
        """Python-side metadata dictionary stored for the wrapped layer."""
        # TODO(kathywu): Add python property validator
        return self._python_properties_internal()

    def _python_properties_internal(self):
        """Builds the dictionary of python properties of the wrapped layer.

        Returns:
          A dict holding the layer's basic attributes, the entries produced
          by `get_serialized`, and, when they are set on the layer, its
          serialized `input_spec`, `activity_regularizer` and
          `build_input_shape`.
        """
        # TODO(kathywu): Add support for metrics serialization.
        # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec)
        # once the python config serialization has caught up.
        layer = self.obj
        metadata = {
            "name": layer.name,
            "trainable": layer.trainable,
            "expects_training_arg": layer._expects_training_arg,
            "dtype": policy.serialize(layer._dtype_policy),
            "batch_input_shape": getattr(layer, "_batch_input_shape", None),
            "stateful": layer.stateful,
            "must_restore_from_config": layer._must_restore_from_config,
            "preserve_input_structure_in_config": (
                layer._preserve_input_structure_in_config
            ),
            "autocast": layer._autocast,
        }
        metadata.update(get_serialized(layer))

        input_spec = layer.input_spec
        if input_spec is not None:

            def _serialize_spec(spec):
                # Falsy entries of the spec structure are stored as None.
                if spec:
                    return serialization.serialize_keras_object(spec)
                return None

            # Layer's input_spec has already been type-checked in the
            # property setter.
            metadata["input_spec"] = tf.nest.map_structure(
                _serialize_spec, input_spec
            )

        regularizer = layer.activity_regularizer
        # Only regularizers that expose `get_config` are serialized.
        if regularizer is not None and hasattr(regularizer, "get_config"):
            metadata[
                "activity_regularizer"
            ] = serialization.serialize_keras_object(regularizer)

        build_input_shape = layer._build_input_shape
        if build_input_shape is not None:
            metadata["build_input_shape"] = build_input_shape
        return metadata

    def objects_to_serialize(self, serialization_cache):
        """Returns the `objects_to_serialize` of the serialized attributes."""
        attributes = self._get_serialized_attributes(serialization_cache)
        return attributes.objects_to_serialize

    def functions_to_serialize(self, serialization_cache):
        """Returns the `functions_to_serialize` of the serialized attributes."""
        attributes = self._get_serialized_attributes(serialization_cache)
        return attributes.functions_to_serialize

    def _get_serialized_attributes(self, serialization_cache):
        """Generates or retrieves serialized attributes from cache."""
        layer = self.obj
        keras_cache = serialization_cache.setdefault(
            constants.KERAS_CACHE_KEY, {}
        )
        if layer in keras_cache:
            return keras_cache[layer]

        attributes = serialized_attributes.SerializedAttributes.new(layer)
        # The entry is stored in the cache before the objects and functions
        # are generated below.
        keras_cache[layer] = attributes

        # Layers that should be skipped, or that must be restored from their
        # config, keep the freshly created (unpopulated) attributes.
        if (
            not save_impl.should_skip_serialization(layer)
            and not layer._must_restore_from_config
        ):
            (
                object_dict,
                function_dict,
            ) = self._get_serialized_attributes_internal(serialization_cache)
            attributes.set_and_validate_objects(object_dict)
            attributes.set_and_validate_functions(function_dict)
        return attributes

    def _get_serialized_attributes_internal(self, serialization_cache):
        """Returns dictionary of serialized attributes."""
        layer = self.obj
        objects = save_impl.wrap_layer_objects(layer, serialization_cache)
        functions = save_impl.wrap_layer_functions(layer, serialization_cache)
        # Attribute validator requires that the default save signature is
        # added to function dict, even if the value is None.
        functions["_default_save_signature"] = None
        return objects, functions


# TODO(kathywu): Move serialization utils (and related utils from
# generic_utils.py) to a separate file.
def get_serialized(obj):
    """Serializes `obj` with `serialize_keras_object`.

    The call runs inside `serialization.skip_failed_serialization()`.
    """
    with serialization.skip_failed_serialization():
        # Store the config dictionary, which may be used when reviving the
        # object. When loading, the program will attempt to revive the
        # object from config, and if that fails, the object will be revived
        # from the SavedModel.
        serialized = serialization.serialize_keras_object(obj)
        return serialized


class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
    """InputLayer serialization."""

    @property
    def object_identifier(self):
        """Identifier constant written for input layer objects."""
        return constants.INPUT_LAYER_IDENTIFIER

    @property
    def python_properties(self):
        """Python-side metadata dictionary stored for the input layer."""
        layer = self.obj
        return {
            "class_name": type(layer).__name__,
            "name": layer.name,
            "dtype": layer.dtype,
            "sparse": layer.sparse,
            "ragged": layer.ragged,
            "batch_input_shape": layer._batch_input_shape,
            "config": layer.get_config(),
        }

    def objects_to_serialize(self, serialization_cache):
        """Returns an empty dict: no objects are serialized here."""
        return dict()

    def functions_to_serialize(self, serialization_cache):
        """Returns an empty dict: no functions are serialized here."""
        return dict()


class RNNSavedModelSaver(LayerSavedModelSaver):
    """RNN layer serialization."""

    @property
    def object_identifier(self):
        """Identifier constant written for RNN layer objects."""
        return constants.RNN_LAYER_IDENTIFIER

    def _get_serialized_attributes_internal(self, serialization_cache):
        """Extends the parent's serialized objects with the layer `states`."""
        objects, functions = super()._get_serialized_attributes_internal(
            serialization_cache
        )
        wrapped_states = tf.__internal__.tracking.wrap(self.obj.states)
        # SavedModel requires all the objects to be Trackable when saving.
        # If the states are still a tuple after wrapping, it means the tuple
        # doesn't contain any trackable item, e.g. an empty tuple or
        # (None, None) for a stateless ConvLSTM2D. Convert it to a list so
        # that wrapping can make it Trackable again for saving. When loaded,
        # ConvLSTM2D is able to handle the tuple/list conversion.
        if isinstance(wrapped_states, tuple):
            wrapped_states = tf.__internal__.tracking.wrap(
                list(wrapped_states)
            )
        objects["states"] = wrapped_states
        return objects, functions


class VocabularySavedModelSaver(LayerSavedModelSaver):
    """Handles vocabulary layer serialization.

    This class is needed for StringLookup, IntegerLookup, and
    TextVectorization, which all have a vocabulary as part of the config.
    Currently, we keep this vocab as part of the config until saving, when we
    need to clear it to avoid initializing a StaticHashTable twice (once when
    restoring the config and once when restoring module resources). After
    clearing the vocab, we persist a property to the layer indicating it was
    constructed with a vocab.
    """

    @property
    def python_properties(self):
        """Layer metadata with the vocabulary removed from the config."""
        # TODO(kathywu): Add python property validator
        metadata = self._python_properties_internal()
        config = metadata["config"]
        # Clear the vocabulary from the config during saving.
        config["vocabulary"] = None
        # Persist a property to track that a vocabulary was passed on
        # construction.
        config["has_input_vocabulary"] = self.obj._has_input_vocabulary
        return metadata