158 lines
6.2 KiB
Python
158 lines
6.2 KiB
Python
![]() |
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Keras legacy SavedModel saving."""
|
||
|
|
||
|
import os
|
||
|
|
||
|
import tensorflow.compat.v2 as tf
|
||
|
from absl import logging
|
||
|
|
||
|
from keras import backend
|
||
|
from keras.protobuf import saved_metadata_pb2
|
||
|
from keras.protobuf import versions_pb2
|
||
|
from keras.saving.legacy import saving_utils
|
||
|
from keras.saving.legacy import serialization
|
||
|
from keras.saving.legacy.saved_model import constants
|
||
|
from keras.saving.legacy.saved_model import save_impl
|
||
|
from keras.saving.legacy.saved_model import utils
|
||
|
from keras.utils.generic_utils import LazyLoader
|
||
|
from keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||
|
|
||
|
# isort: off
|
||
|
from tensorflow.python.saved_model import save as save_lib
|
||
|
|
||
|
# To avoid circular dependencies between keras/engine and keras/saving,
|
||
|
# code in keras/saving must delay imports.
|
||
|
|
||
|
base_layer = LazyLoader("base_layer", globals(), "keras.engine.base_layer")
|
||
|
training_lib = LazyLoader("training_lib", globals(), "keras.engine.training")
|
||
|
|
||
|
|
||
|
def save(
|
||
|
model,
|
||
|
filepath,
|
||
|
overwrite,
|
||
|
include_optimizer,
|
||
|
signatures=None,
|
||
|
options=None,
|
||
|
save_traces=True,
|
||
|
):
|
||
|
"""Saves a model as a SavedModel to the filepath.
|
||
|
|
||
|
Args:
|
||
|
model: Keras model instance to be saved.
|
||
|
filepath: String path to save the model.
|
||
|
overwrite: whether to overwrite the existing filepath.
|
||
|
include_optimizer: If True, save the model's optimizer state.
|
||
|
signatures: Signatures to save with the SavedModel. Applicable to the 'tf'
|
||
|
format only. Please see the `signatures` argument in
|
||
|
`tf.saved_model.save` for details.
|
||
|
options: (only applies to SavedModel format) `tf.saved_model.SaveOptions`
|
||
|
object that specifies options for saving to SavedModel.
|
||
|
save_traces: (only applies to SavedModel format) When enabled, the
|
||
|
SavedModel will store the function traces for each layer. This
|
||
|
can be disabled, so that only the configs of each layer are stored.
|
||
|
Defaults to `True`. Disabling this will decrease serialization time
|
||
|
and reduce file size, but it requires that all custom layers/models
|
||
|
implement a `get_config()` method.
|
||
|
|
||
|
Raises:
|
||
|
ValueError: if the model's inputs have not been defined.
|
||
|
"""
|
||
|
# If file exists and should not be overwritten.
|
||
|
if not overwrite and os.path.exists(filepath):
|
||
|
proceed = ask_to_proceed_with_overwrite(filepath)
|
||
|
if not proceed:
|
||
|
return
|
||
|
|
||
|
if save_traces:
|
||
|
if save_impl.should_skip_serialization(model):
|
||
|
saving_utils.raise_model_input_error(model)
|
||
|
|
||
|
if not include_optimizer:
|
||
|
orig_optimizer = model.optimizer
|
||
|
model.optimizer = None
|
||
|
# TODO(b/180760306) Change to del model.optimizer if Layer's __delattr__
|
||
|
# calls AutoTrackable's __delattr__.
|
||
|
model._delete_tracking("optimizer")
|
||
|
|
||
|
# Trace all functions and signatures with `training=0` instead of using an
|
||
|
# already-set learning phase placeholder.
|
||
|
# This is needed for compatibility reasons until learning phase setting
|
||
|
# is removed from the public apis.
|
||
|
with serialization.SharedObjectSavingScope():
|
||
|
with backend.deprecated_internal_learning_phase_scope(0):
|
||
|
with utils.keras_option_scope(save_traces):
|
||
|
saved_nodes, node_paths = save_lib.save_and_return_nodes(
|
||
|
model, filepath, signatures, options
|
||
|
)
|
||
|
|
||
|
# Save all metadata to a separate file in the SavedModel directory.
|
||
|
metadata = generate_keras_metadata(saved_nodes, node_paths)
|
||
|
|
||
|
with tf.io.gfile.GFile(
|
||
|
tf.io.gfile.join(filepath, constants.SAVED_METADATA_PATH), "wb"
|
||
|
) as w:
|
||
|
w.write(metadata.SerializeToString(deterministic=True))
|
||
|
|
||
|
if not include_optimizer:
|
||
|
model.optimizer = orig_optimizer
|
||
|
|
||
|
|
||
|
def generate_keras_metadata(saved_nodes, node_paths):
|
||
|
"""Constructs a KerasMetadata proto with the metadata of each object."""
|
||
|
metadata = saved_metadata_pb2.SavedMetadata()
|
||
|
for node_id, node in enumerate(saved_nodes):
|
||
|
if isinstance(node, base_layer.Layer):
|
||
|
path = node_paths[node]
|
||
|
if not path:
|
||
|
node_path = "root"
|
||
|
else:
|
||
|
node_path = f"root.{'.'.join([ref.name for ref in path])}"
|
||
|
|
||
|
metadata.nodes.add(
|
||
|
node_id=node_id,
|
||
|
node_path=node_path,
|
||
|
version=versions_pb2.VersionDef(
|
||
|
producer=2, min_consumer=1, bad_consumers=[]
|
||
|
),
|
||
|
identifier=node._object_identifier,
|
||
|
metadata=node._tracking_metadata,
|
||
|
)
|
||
|
|
||
|
# Log warning if the node's class name conflicts with a Keras
|
||
|
# built-in object.
|
||
|
class_name = node.__class__.__name__
|
||
|
from keras.layers import serialization as layers_serialization
|
||
|
|
||
|
builtin_layer = layers_serialization.get_builtin_layer(class_name)
|
||
|
if builtin_layer:
|
||
|
if not isinstance(node, builtin_layer):
|
||
|
logging.warning(
|
||
|
"%s has the same name '%s' as a built-in Keras "
|
||
|
"object. Consider renaming %s to avoid naming "
|
||
|
"conflicts when loading with "
|
||
|
"`tf.keras.models.load_model`. "
|
||
|
"If renaming is not possible, pass "
|
||
|
"the object in the `custom_objects` "
|
||
|
"parameter of the load "
|
||
|
"function.",
|
||
|
node,
|
||
|
class_name,
|
||
|
node.__class__,
|
||
|
)
|
||
|
|
||
|
return metadata
|