# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Helper classes that list&validate all attributes to serialize to SavedModel. """ from tensorflow.python.eager import def_function from tensorflow.python.keras.saving.saved_model import constants from tensorflow.python.keras.saving.saved_model import save_impl from tensorflow.python.keras.utils.generic_utils import LazyLoader from tensorflow.python.trackable import base as trackable from tensorflow.python.trackable.autotrackable import AutoTrackable # TODO(b/134426265): Switch back to single-quotes to match the rest of the file # once the issue with copybara is fixed. # pylint:disable=g-inconsistent-quotes base_layer = LazyLoader( "base_layer", globals(), "tensorflow.python.keras.engine.base_layer") training_lib = LazyLoader( "training_lib", globals(), "tensorflow.python.keras.engine.training") metrics = LazyLoader("metrics", globals(), "tensorflow.python.keras.metrics") recurrent = LazyLoader( "recurrent", globals(), "tensorflow.python.keras.layers.recurrent") # pylint:enable=g-inconsistent-quotes class SerializedAttributes(object): """Class that tracks and validates all serialization attributes. Keras models contain many Python-defined components. For example, the trainable_variable property lists the model's trainable variables by recursively retrieving the trainable variables from each of the child layers. Another example is model.call, a python function that calls child layers and adds ops to the backend graph. Only Tensorflow checkpointable objects and functions can be serialized to SavedModel. Serializing a Keras model as-is results in a checkpointable object that does not resemble a Keras model at all. Thus, extra checkpointable objects and functions must be created during serialization. **Defining new serialized attributes** Child classes should be defined using: SerializedAttributes.with_attributes( 'name', checkpointable_objects=[...], functions=[...], copy_from=[...]) This class is used to cache generated checkpointable objects and functions, ensuring that new objects and functions are generated a single time. **Usage during serialization** Each Layer/Model object should have a corresponding instance of SerializedAttributes. Create a new instance by calling `SerializedAttributes.new(obj)`. Objects and functions may be saved using `.set_and_validate_checkpointable_objects`/`.set_and_and_validate_functions`. The properties `.checkpointable_objects` and `.functions` returns the cached values. **Adding/changing attributes to save to SavedModel** 1. Change the call to `SerializedAttributes.with_attributes` in the correct class: - CommonEndpoints: Base attributes to be added during serialization. If these attributes are present in a Trackable object, it can be deserialized to a Keras Model. - LayerAttributes: Attributes to serialize for Layer objects. - ModelAttributes: Attributes to serialize for Model objects. 2. Update class docstring 3. Update arguments to any calls to `set_and_validate_*`. For example, if `call_raw_tensors` is added to the ModelAttributes function list, then a `call_raw_tensors` function should be passed to `set_and_validate_functions`. **Common endpoints vs other attributes** Only common endpoints are attached directly to the root object. Keras-specific attributes are saved to a separate trackable object with the name "keras_api". The number of objects attached to the root is limited because any naming conflicts will cause user code to break. Another reason is that this will only affect users who call `tf.saved_model.load` instead of `tf.keras.models.load_model`. These are advanced users who are likely to have defined their own tf.functions and trackable objects. The added Keras-specific attributes are kept out of the way in the "keras_api" namespace. Properties defined in this class may be used to filter out keras-specific attributes: - `functions_to_serialize`: Returns dict of functions to attach to the root object. - `checkpointable_objects_to_serialize`: Returns dict of objects to attach to the root object (including separate trackable object containing keras-specific attributes) All changes to the serialized attributes must be backwards-compatible, so attributes should not be removed or modified without sufficient justification. """ @staticmethod def with_attributes( name, checkpointable_objects=None, functions=None, copy_from=None): """Creates a subclass with all attributes as specified in the arguments. Args: name: Name of subclass checkpointable_objects: List of checkpointable objects to be serialized in the SavedModel. functions: List of functions to be serialized in the SavedModel. copy_from: List of other SerializedAttributes subclasses. The returned class will copy checkpoint objects/functions from each subclass. Returns: Child class with attributes as defined in the `checkpointable_objects` and `functions` lists. """ checkpointable_objects = checkpointable_objects or [] functions = functions or [] if copy_from is not None: for cls in copy_from: checkpointable_objects.extend(cls.all_checkpointable_objects) functions.extend(cls.all_functions) classdict = { 'all_checkpointable_objects': set(checkpointable_objects), 'all_functions': set(functions)} return type(name, (SerializedAttributes,), classdict) @staticmethod def new(obj): """Returns a new SerializedAttribute object.""" if isinstance(obj, training_lib.Model): return ModelAttributes() elif isinstance(obj, metrics.Metric): return MetricAttributes() elif isinstance(obj, recurrent.RNN): return RNNAttributes() elif isinstance(obj, base_layer.Layer): return LayerAttributes() else: raise TypeError('Internal error during serialization: Expected Keras ' 'Layer object, got {} of type {}'.format(obj, type(obj))) def __init__(self): self._object_dict = {} self._function_dict = {} self._keras_trackable = AutoTrackable() @property def functions(self): """Returns dictionary of all functions.""" return {key: value for key, value in self._function_dict.items() if value is not None} @property def checkpointable_objects(self): """Returns dictionary of all checkpointable objects.""" return {key: value for key, value in self._object_dict.items() if value is not None} @property def functions_to_serialize(self): """Returns functions to attach to the root object during serialization.""" functions = {} for key, v in self.functions.items(): if key in CommonEndpoints.all_functions: functions[key] = (v.wrapped_call if isinstance(v, save_impl.LayerCall) else v) return functions @property def objects_to_serialize(self): """Returns objects to attach to the root object during serialization.""" objects = {key: value for key, value in self.checkpointable_objects.items() if key in CommonEndpoints.all_checkpointable_objects} objects[constants.KERAS_ATTR] = self._keras_trackable return objects def set_and_validate_functions(self, function_dict): """Saves function dictionary, and validates dictionary values.""" for key in self.all_functions: if key in function_dict: if (function_dict[key] is not None and # Not all functions are required not isinstance(function_dict[key], (def_function.Function, save_impl.LayerCall))): raise ValueError( 'Function dictionary contained a non-function object: {} (for key' ' {})'.format(function_dict[key], key)) fn = function_dict[key] self._function_dict[key] = fn # Extract TensorFlow `Function` from LayerCall. tf_fn = fn.wrapped_call if isinstance(fn, save_impl.LayerCall) else fn setattr(self._keras_trackable, key, tf_fn) else: raise ValueError('Function {} missing from serialized function dict.' .format(key)) return self.functions def set_and_validate_objects(self, object_dict): """Saves objects to a dictionary, and validates the values.""" for key in self.all_checkpointable_objects: if key in object_dict: if not isinstance(object_dict[key], trackable.Trackable): raise ValueError( 'Object dictionary contained a non-trackable object: {} (for key' ' {})'.format(object_dict[key], key)) self._object_dict[key] = object_dict[key] setattr(self._keras_trackable, key, object_dict[key]) else: raise ValueError( 'Object {} missing from serialized object dict.'.format(key)) return self.checkpointable_objects class CommonEndpoints(SerializedAttributes.with_attributes( 'CommonEndpoints', checkpointable_objects=['variables', 'trainable_variables', 'regularization_losses'], functions=['__call__', 'call_and_return_all_conditional_losses', '_default_save_signature'])): """Common endpoints shared by all models loadable by Keras. List of all attributes: variables: List of all variables in the model and its sublayers. trainable_variables: List of all trainable variables in the model and its sublayers. regularization_losses: List of all unconditional losses (losses not dependent on the inputs) in the model and its sublayers. __call__: Function that takes inputs and returns the outputs of the model call function. call_and_return_all_conditional_losses: Function that returns a tuple of (call function outputs, list of all losses that depend on the inputs). _default_save_signature: Traced model call function. This is only included if the top level exported object is a Keras model. """ class LayerAttributes(SerializedAttributes.with_attributes( 'LayerAttributes', checkpointable_objects=['non_trainable_variables', 'layers', 'metrics', 'layer_regularization_losses', 'layer_metrics'], functions=['call_and_return_conditional_losses', 'activity_regularizer_fn'], copy_from=[CommonEndpoints] )): """Layer checkpointable objects + functions that are saved to the SavedModel. List of all attributes: All attributes from CommonEndpoints non_trainable_variables: List of non-trainable variables in the layer and its sublayers. layers: List of all sublayers. metrics: List of all metrics in the layer and its sublayers. call_and_return_conditional_losses: Function that takes inputs and returns a tuple of (outputs of the call function, list of input-dependent losses). The list of losses excludes the activity regularizer function, which is separate to allow the deserialized Layer object to define a different activity regularizer. activity_regularizer_fn: Callable that returns the activity regularizer loss layer_regularization_losses: List of losses owned only by this layer. layer_metrics: List of metrics owned by this layer. """ class ModelAttributes(SerializedAttributes.with_attributes( 'ModelAttributes', copy_from=[LayerAttributes])): """Model checkpointable objects + functions that are saved to the SavedModel. List of all attributes: All attributes from LayerAttributes (including CommonEndpoints) """ # TODO(kathywu): Add attributes `compile_losses` and `compile_metrics`, which # list all losses and metrics defined by `model.compile`. class MetricAttributes( SerializedAttributes.with_attributes( 'MetricAttributes', checkpointable_objects=['variables'], functions=[], )): """Attributes that are added to Metric objects when saved to SavedModel. List of all attributes: variables: list of all variables """ pass class RNNAttributes(SerializedAttributes.with_attributes( 'RNNAttributes', checkpointable_objects=['states'], copy_from=[LayerAttributes])): """RNN checkpointable objects + functions that are saved to the SavedModel. List of all attributes: All attributes from LayerAttributes (including CommonEndpoints) states: List of state variables """