1384 lines
55 KiB
Python
1384 lines
55 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Keras SavedModel deserialization."""
|
|
|
|
import re
|
|
import types
|
|
import warnings
|
|
|
|
import tensorflow.compat.v1.logging as logging
|
|
import tensorflow.compat.v2 as tf
|
|
from google.protobuf import message
|
|
|
|
from keras import backend
|
|
from keras import regularizers
|
|
from keras.engine import input_spec
|
|
from keras.optimizers.legacy import optimizer_v2
|
|
from keras.protobuf import saved_metadata_pb2
|
|
from keras.protobuf import versions_pb2
|
|
from keras.saving import object_registration
|
|
from keras.saving.legacy import saving_utils
|
|
from keras.saving.legacy import serialization
|
|
from keras.saving.legacy.saved_model import constants
|
|
from keras.saving.legacy.saved_model import json_utils
|
|
from keras.saving.legacy.saved_model import utils
|
|
from keras.saving.legacy.saved_model.serialized_attributes import (
|
|
CommonEndpoints,
|
|
)
|
|
from keras.utils import layer_utils
|
|
from keras.utils import metrics_utils
|
|
from keras.utils import tf_inspect
|
|
from keras.utils.generic_utils import LazyLoader
|
|
|
|
# Modules under keras/engine (and friends) are resolved lazily on first
# attribute access: importing them eagerly from keras/saving would create a
# circular dependency.

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.

# Model and layer containers.
models_lib = LazyLoader("models_lib", globals(), "keras.models")
base_layer = LazyLoader("base_layer", globals(), "keras.engine.base_layer")
layers_module = LazyLoader("layers_module", globals(), "keras.layers")
input_layer = LazyLoader("input_layer", globals(), "keras.engine.input_layer")
functional_lib = LazyLoader(
    "functional_lib", globals(), "keras.engine.functional"
)

# Training engines.
training_lib = LazyLoader("training_lib", globals(), "keras.engine.training")
training_lib_v1 = LazyLoader(
    "training_lib_v1", globals(), "keras.engine.training_v1"
)

# Metrics and recurrent layers.
metrics = LazyLoader("metrics", globals(), "keras.metrics")
base_rnn = LazyLoader("base_rnn", globals(), "keras.layers.rnn.base_rnn")


# Names of every serialized function and checkpointable object, plus the
# Keras attribute container itself. `KerasObjectLoader.del_tracking` removes
# these tracked names from revived layers once loading has finished.
PUBLIC_ATTRIBUTES = CommonEndpoints.all_functions.union(
    CommonEndpoints.all_checkpointable_objects,
    [constants.KERAS_ATTR],
)
|
|
|
|
|
|
def load(path, compile=True, options=None):
    """Loads Keras objects from a SavedModel.

    Any Keras layer or model saved to the SavedModel will be loaded back
    as Keras objects. Other objects are loaded as regular trackable objects
    (same as `tf.saved_model.load`).

    Currently, Keras saving/loading only retains the Keras object's weights,
    losses, and call function.

    When `compile` is true and the loaded root object is a Keras model whose
    saved metadata contains a training config, the model is re-compiled from
    that config. Optimizer slot variables are not restored from the
    SavedModel, so a legacy optimizer that uses slots starts out freshly
    initialized (a warning is logged). If no training config was saved, the
    model is returned uncompiled and a warning is logged.

    Args:
      path: Path to SavedModel.
      compile: If true, compile the model after loading it.
      options: Optional `tf.saved_model.LoadOptions` object that specifies
        options for loading from SavedModel.

    Returns:
      Object loaded from SavedModel.

    Raises:
      IOError: If the Keras metadata file exists but cannot be parsed.
    """
    # TODO(kathywu): Add saving/loading of optimizer, compiled losses and
    # metrics.
    # TODO(kathywu): Add code to load from objects that contain all endpoints

    # Look for metadata file or parse the SavedModel
    metadata = saved_metadata_pb2.SavedMetadata()
    meta_graph_def = tf.__internal__.saved_model.parse_saved_model(
        path
    ).meta_graphs[0]
    object_graph_def = meta_graph_def.object_graph_def
    path_to_metadata_pb = tf.io.gfile.join(path, constants.SAVED_METADATA_PATH)
    if tf.io.gfile.exists(path_to_metadata_pb):
        try:
            with tf.io.gfile.GFile(path_to_metadata_pb, "rb") as f:
                file_content = f.read()
            metadata.ParseFromString(file_content)
        except message.DecodeError as e:
            # Chain the decode error so the root cause stays visible.
            raise IOError(
                f"Cannot parse keras metadata at path {path_to_metadata_pb}: "
                f"Received error: {e}"
            ) from e
    else:
        logging.warning(
            "SavedModel saved prior to TF 2.5 detected when loading "
            "Keras model. Please ensure that you are saving the model "
            "with model.save() or tf.keras.models.save_model(), *NOT* "
            "tf.saved_model.save(). To confirm, there should be a file "
            'named "keras_metadata.pb" in the SavedModel directory.'
        )
        # Fall back to the metadata embedded in the object graph proto.
        _read_legacy_metadata(object_graph_def, metadata, path)

    if not metadata.nodes:
        # When there are no Keras objects, return the results from the core
        # loader
        return tf.saved_model.load(path, options=options)

    metadata = _update_to_current_version(metadata)
    # Recreate layers and metrics using the info stored in the metadata.
    keras_loader = KerasObjectLoader(metadata, object_graph_def)
    keras_loader.load_layers(compile=compile)

    # Generate a dictionary of all loaded nodes. The "root" entry is always
    # requested, even when the root itself was not revived by Keras.
    nodes_to_load = {"root": None}
    for node_id, loaded_node in keras_loader.loaded_nodes.items():
        nodes_to_load[keras_loader.get_path(node_id)] = loaded_node
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", message="Trying to load ShardedVariables"
        )
        loaded = tf.__internal__.saved_model.load_partial(
            path, nodes_to_load, options=options
        )

    # Finalize the loaded layers and remove the extra tracked dependencies.
    keras_loader.finalize_objects()
    keras_loader.del_tracking()

    model = loaded["root"]

    if isinstance(model, training_lib.Model) and compile:
        # TODO(kathywu): Use compiled objects from SavedModel, instead of
        # creating new objects from the training config.
        training_config = model._serialized_attributes["metadata"].get(
            "training_config", None
        )
        if training_config is not None:
            model.compile(
                **saving_utils.compile_args_from_training_config(
                    training_config
                ),
                from_serialized=True,
            )
            saving_utils.try_build_compiled_arguments(model)
            if isinstance(model.optimizer, optimizer_v2.OptimizerV2):
                if model.optimizer.get_slot_names():
                    logging.warning(
                        "Your optimizer uses slots. "
                        "Slots cannot be restored from saved_model, "
                        "as a result, your model is starting with "
                        "a new initialized optimizer."
                    )
        else:
            logging.warning(
                "No training configuration found in save file, so the "
                "model was *not* compiled. Compile it manually."
            )

    # Force variables and resources to initialize.
    if not tf.executing_eagerly():
        sess = backend.get_session()  # Variables are initialized by this call.
        sess.run(
            tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.TABLE_INITIALIZERS
            )
        )

    return model
|
|
|
|
|
|
def _update_to_current_version(metadata):
    """Upgrades producer-version-1 model nodes in place; returns `metadata`.

    For model, sequential and network nodes written with producer version 1,
    a stored `save_spec` entry is mirrored into the newer `full_save_spec`
    entry as `([save_spec], {})`, and the node metadata is re-encoded.
    """
    upgradable_identifiers = [
        constants.MODEL_IDENTIFIER,
        constants.SEQUENTIAL_IDENTIFIER,
        constants.NETWORK_IDENTIFIER,
    ]
    for node in metadata.nodes:
        if node.version.producer != 1:
            continue
        if node.identifier not in upgradable_identifiers:
            continue

        decoded = json_utils.decode(node.metadata)
        legacy_spec = decoded.get("save_spec")
        if legacy_spec is None:
            # Nothing to migrate for this node.
            continue

        decoded["full_save_spec"] = ([legacy_spec], {})
        node.metadata = json_utils.Encoder().encode(decoded)
    return metadata
|
|
|
|
|
|
def _read_legacy_metadata(object_graph_def, metadata, path):
    """Fills `metadata` with the Keras nodes found in an ObjectGraphDef.

    Older SavedModels keep the Keras metadata on the object graph nodes
    themselves rather than in a separate pb file. Every user object whose
    identifier is a Keras object identifier is copied into `metadata.nodes`
    with producer version 1.

    Args:
      object_graph_def: The SavedModel's object graph proto.
      metadata: Metadata proto that receives the nodes (mutated in place).
      path: SavedModel path, only used in the error message.

    Raises:
      ValueError: If a Keras user object carries no metadata.
    """
    paths_by_node_id = _generate_object_paths(object_graph_def)
    for node_id, proto in enumerate(object_graph_def.nodes):
        if proto.WhichOneof("kind") != "user_object":
            continue
        user_object = proto.user_object
        if user_object.identifier not in constants.KERAS_OBJECT_IDENTIFIERS:
            continue

        if not user_object.metadata:
            raise ValueError(
                "Unable to create a Keras model from SavedModel at "
                f"{path}. This SavedModel was exported with "
                "`tf.saved_model.save`, and lacks the Keras metadata file. "
                "Please save your Keras model by calling `model.save` "
                "or `tf.keras.models.save_model`. Note that "
                "you can still load this SavedModel with "
                "`tf.saved_model.load`."
            )

        legacy_version = versions_pb2.VersionDef(
            producer=1, min_consumer=1, bad_consumers=[]
        )
        metadata.nodes.add(
            node_id=node_id,
            node_path=paths_by_node_id[node_id],
            version=legacy_version,
            identifier=user_object.identifier,
            metadata=user_object.metadata,
        )
|
|
|
|
|
|
def _generate_object_paths(object_graph_def):
|
|
"""Traverses through an ObjectGraphDef and builds a map of all node
|
|
paths."""
|
|
paths = {0: "root"}
|
|
nodes_to_visit = [0]
|
|
|
|
while nodes_to_visit:
|
|
current_node = nodes_to_visit.pop()
|
|
current_path = paths[current_node]
|
|
for reference in object_graph_def.nodes[current_node].children:
|
|
if reference.node_id in paths:
|
|
continue
|
|
paths[reference.node_id] = f"{current_path}.{reference.local_name}"
|
|
nodes_to_visit.append(reference.node_id)
|
|
|
|
return paths
|
|
|
|
|
|
def _is_graph_network(layer):
    """Returns whether `layer` is a graph network.

    Revived networks never count. Otherwise the layer must be a `Functional`
    instance that is either flagged as a graph network or is a `Sequential`
    model.
    """
    if isinstance(layer, RevivedNetwork):
        return False
    if not isinstance(layer, functional_lib.Functional):
        return False
    return layer._is_graph_network or isinstance(layer, models_lib.Sequential)
|
|
|
|
|
|
class KerasObjectLoader:
|
|
"""Loader that recreates Keras objects (e.g.
|
|
|
|
layers, models).
|
|
|
|
Layers and models are revived from either the config or SavedModel following
|
|
these rules:
|
|
1. If object is a graph network (i.e. Sequential or Functional) then it will
|
|
be initialized using the structure from the config only after the
|
|
children layers have been created. Graph networks must be initialized
|
|
with inputs and outputs, so all child layers must be created beforehand.
|
|
2. If object's config exists and the class can be found, then revive from
|
|
config.
|
|
3. Object may have already been created if its parent was revived from
|
|
config. In this case, do nothing.
|
|
4. If nothing of the above applies, compose the various artifacts from the
|
|
SavedModel to create a subclassed layer or model. At this time, custom
|
|
metrics are not supported.
|
|
|
|
"""
|
|
|
|
def __init__(self, metadata, object_graph_def):
|
|
self._metadata = {x.node_id: x for x in metadata.nodes}
|
|
self._proto = object_graph_def
|
|
|
|
self._node_paths = {
|
|
node_data.node_id: node_data.node_path
|
|
for node_data in metadata.nodes
|
|
}
|
|
self.loaded_nodes = {} # Maps node path -> loaded node
|
|
|
|
# Store all node ids that have already been traversed when tracking
|
|
# nodes that were recreated from the config.
|
|
self._traversed_nodes_from_config = set()
|
|
|
|
# Maps model id -> (blank model obj, list of child layer or their node
|
|
# ids) This tracks all layers in functional and sequential models. These
|
|
# models are only reconstructed after all of their child layers have
|
|
# been created.
|
|
self.model_layer_dependencies = {}
|
|
self._models_to_reconstruct = []
|
|
|
|
def del_tracking(self):
|
|
"""Removes tracked references that are only used when loading the
|
|
model."""
|
|
# Now that the node object has been fully loaded, and the checkpoint has
|
|
# been restored, the object no longer needs to track objects added from
|
|
# SerializedAttributes. (Note that saving a training checkpoint still
|
|
# functions correctly, because layers and variables are tracked
|
|
# separately by the Layer object.)
|
|
# TODO(kathywu): Instead of outright deleting these nodes (which would
|
|
# make restoring from a different checkpoint tricky), mark them as extra
|
|
# dependencies that are OK to overwrite.
|
|
for node in self.loaded_nodes.values():
|
|
node = node[0]
|
|
if not isinstance(node, base_layer.Layer):
|
|
# Loaded nodes can contain other trackable objects created when
|
|
# loading layers from the config, such as variables.
|
|
continue
|
|
for name in PUBLIC_ATTRIBUTES:
|
|
node._delete_tracking(name)
|
|
|
|
if isinstance(node, functional_lib.Functional):
|
|
# Delete the temporary layer dependencies, which were used to
|
|
# restore the checkpointed values. When the model is live, the
|
|
# user can delete or add layers to the model at any time, so
|
|
# these layer dependencies may be obsolete.
|
|
dependencies = list(node._self_unconditional_dependency_names)
|
|
for name in dependencies:
|
|
if (
|
|
re.match(r"^layer(_with_weights)?-[\d+]", name)
|
|
is not None
|
|
):
|
|
node._delete_tracking(name)
|
|
|
|
def _add_children_recreated_from_config(self, obj, proto, node_id):
|
|
"""Recursively records objects recreated from config."""
|
|
|
|
if node_id in self._traversed_nodes_from_config:
|
|
return
|
|
|
|
parent_path = self._node_paths[node_id]
|
|
self._traversed_nodes_from_config.add(node_id)
|
|
obj._maybe_initialize_trackable()
|
|
if isinstance(obj, base_layer.Layer) and not obj.built:
|
|
metadata = json_utils.decode(self._metadata[node_id].metadata)
|
|
self._try_build_layer(
|
|
obj, node_id, metadata.get("build_input_shape")
|
|
)
|
|
|
|
# Create list of all possible children
|
|
children = []
|
|
# Look for direct children
|
|
for reference in proto.children:
|
|
obj_child = obj._lookup_dependency(reference.local_name)
|
|
children.append(
|
|
(obj_child, reference.node_id, reference.local_name)
|
|
)
|
|
|
|
# Add metrics that may have been added to the layer._metrics list.
|
|
# This is stored in the SavedModel as layer.keras_api.layer_metrics in
|
|
# SavedModels created after Tf 2.2.
|
|
metric_list_node_id = self._search_for_child_node(
|
|
node_id, [constants.KERAS_ATTR, "layer_metrics"]
|
|
)
|
|
if metric_list_node_id is not None and hasattr(obj, "_metrics"):
|
|
obj_metrics = {m.name: m for m in obj._metrics}
|
|
for reference in self._proto.nodes[metric_list_node_id].children:
|
|
metric = obj_metrics.get(reference.local_name)
|
|
if metric is not None:
|
|
metric_path = "{}.layer_metrics.{}".format(
|
|
constants.KERAS_ATTR, reference.local_name
|
|
)
|
|
children.append((metric, reference.node_id, metric_path))
|
|
|
|
for obj_child, child_id, child_name in children:
|
|
child_proto = self._proto.nodes[child_id]
|
|
|
|
if not isinstance(obj_child, tf.__internal__.tracking.Trackable):
|
|
continue
|
|
if (
|
|
child_proto.user_object.identifier
|
|
in tf.__internal__.saved_model.load.registered_identifiers()
|
|
):
|
|
setter = tf.__internal__.saved_model.load.get_setter(
|
|
child_proto.user_object
|
|
)
|
|
elif (
|
|
obj_child._object_identifier
|
|
in constants.KERAS_OBJECT_IDENTIFIERS
|
|
):
|
|
setter = _revive_setter
|
|
else:
|
|
setter = setattr
|
|
|
|
if child_id in self.loaded_nodes:
|
|
if self.loaded_nodes[child_id][0] is not obj_child:
|
|
# This means that the same trackable object is referenced by
|
|
# two different objects that were recreated from the config.
|
|
logging.warning(
|
|
"Looks like there is an object (perhaps variable or "
|
|
"layer) that is shared between different "
|
|
"layers/models. This may cause issues when restoring "
|
|
"the variable values. Object: {}".format(obj_child)
|
|
)
|
|
continue
|
|
|
|
# Overwrite variable names with the ones saved in the SavedModel.
|
|
if (
|
|
child_proto.WhichOneof("kind") == "variable"
|
|
and child_proto.variable.name
|
|
):
|
|
obj_child._handle_name = child_proto.variable.name + ":0"
|
|
|
|
if isinstance(
|
|
obj_child, tf.__internal__.tracking.TrackableDataStructure
|
|
):
|
|
setter = lambda *args: None
|
|
|
|
child_path = f"{parent_path}.{child_name}"
|
|
self._node_paths[child_id] = child_path
|
|
self._add_children_recreated_from_config(
|
|
obj_child, child_proto, child_id
|
|
)
|
|
self.loaded_nodes[child_id] = obj_child, setter
|
|
|
|
def load_layers(self, compile=True):
|
|
"""Load all layer nodes from the metadata."""
|
|
# Load metrics after models and layers, since it's likely that models
|
|
# and layers will create the metric when initialized (this avoids
|
|
# wasting time by creating objects multiple times).
|
|
metric_list = []
|
|
for node_metadata in self._metadata.values():
|
|
if node_metadata.identifier == constants.METRIC_IDENTIFIER:
|
|
metric_list.append(node_metadata)
|
|
continue
|
|
|
|
self.loaded_nodes[node_metadata.node_id] = self._load_layer(
|
|
node_metadata.node_id,
|
|
node_metadata.identifier,
|
|
node_metadata.metadata,
|
|
)
|
|
|
|
for node_metadata in metric_list:
|
|
try:
|
|
self.loaded_nodes[node_metadata.node_id] = self._load_layer(
|
|
node_metadata.node_id,
|
|
node_metadata.identifier,
|
|
node_metadata.metadata,
|
|
)
|
|
except ValueError as e:
|
|
# Metrics are only needed when the model is compiled later. We
|
|
# ignore errors when trying to load custom metrics when
|
|
# `compile=False` until custom metrics are serialized properly
|
|
# (b/135550038).
|
|
if compile:
|
|
raise e
|
|
logging.warning(
|
|
"Unable to restore custom metric. Please ensure that "
|
|
"the layer implements `get_config` and `from_config` "
|
|
"when saving. In addition, please use the "
|
|
"`custom_objects` arg when calling `load_model()`."
|
|
)
|
|
|
|
def _load_layer(self, node_id, identifier, metadata):
|
|
"""Load a single layer from a SavedUserObject proto."""
|
|
metadata = json_utils.decode(metadata)
|
|
|
|
# If node was already created
|
|
if node_id in self.loaded_nodes:
|
|
node, setter = self.loaded_nodes[node_id]
|
|
|
|
# Revive setter requires the object to have a
|
|
# `_serialized_attributes` property. Add it here.
|
|
_maybe_add_serialized_attributes(node, metadata)
|
|
|
|
config = metadata.get("config")
|
|
if _is_graph_network(node) and serialization.validate_config(
|
|
config
|
|
):
|
|
child_nodes = self._get_child_layer_node_ids(node_id)
|
|
self.model_layer_dependencies[node_id] = (node, child_nodes)
|
|
if not child_nodes:
|
|
self._models_to_reconstruct.append(node_id)
|
|
return node, setter
|
|
|
|
# Detect whether this object can be revived from the config. If not,
|
|
# then revive from the SavedModel instead.
|
|
obj, setter = self._revive_from_config(identifier, metadata, node_id)
|
|
if obj is None:
|
|
obj, setter = revive_custom_object(identifier, metadata)
|
|
|
|
# Add an attribute that stores the extra functions/objects saved in the
|
|
# SavedModel. Most of these functions/objects are ignored, but some are
|
|
# used later in the loading process (e.g. the list of regularization
|
|
# losses, or the training config of compiled models).
|
|
_maybe_add_serialized_attributes(obj, metadata)
|
|
return obj, setter
|
|
|
|
def _revive_from_config(self, identifier, metadata, node_id):
|
|
"""Revives a layer/model from config, or returns None."""
|
|
if identifier == constants.METRIC_IDENTIFIER:
|
|
obj = self._revive_metric_from_config(metadata)
|
|
else:
|
|
obj = self._revive_graph_network(
|
|
identifier, metadata, node_id
|
|
) or self._revive_layer_or_model_from_config(metadata, node_id)
|
|
|
|
if obj is None:
|
|
return None, None
|
|
|
|
setter = self._config_node_setter(_revive_setter)
|
|
self._add_children_recreated_from_config(
|
|
obj, self._proto.nodes[node_id], node_id
|
|
)
|
|
return obj, setter
|
|
|
|
def _revive_graph_network(self, identifier, metadata, node_id):
|
|
"""Revives a graph network from config."""
|
|
# Determine whether the metadata contains information for reviving a
|
|
# functional or Sequential model.
|
|
config = metadata.get("config")
|
|
if not serialization.validate_config(config):
|
|
return None
|
|
|
|
class_name = tf.compat.as_str(metadata["class_name"])
|
|
if object_registration.get_registered_object(class_name) is not None:
|
|
return None
|
|
model_is_functional_or_sequential = (
|
|
metadata.get("is_graph_network", False)
|
|
or class_name == "Sequential"
|
|
or class_name == "Functional"
|
|
)
|
|
if not model_is_functional_or_sequential:
|
|
return None
|
|
|
|
# Revive functional and sequential models as blank model objects for now
|
|
# ( must be initialized to enable setattr tracking and attribute
|
|
# caching). Reconstruction of the network is deferred until all of the
|
|
# model's layers have been revived.
|
|
if class_name == "Sequential":
|
|
model = models_lib.Sequential(name=config["name"])
|
|
# The model is a custom Sequential model.
|
|
elif identifier == constants.SEQUENTIAL_IDENTIFIER:
|
|
# Uses the custom class name, since the config does not have one.
|
|
model = models_lib.Sequential(name=class_name)
|
|
else:
|
|
model = models_lib.Functional(
|
|
inputs=[], outputs=[], name=config["name"]
|
|
)
|
|
|
|
# Record this model and its layers. This will later be used to
|
|
# reconstruct the model.
|
|
layers = self._get_child_layer_node_ids(node_id)
|
|
self.model_layer_dependencies[node_id] = (model, layers)
|
|
if not layers:
|
|
self._models_to_reconstruct.append(node_id)
|
|
return model
|
|
|
|
def _revive_layer_or_model_from_config(self, metadata, node_id):
|
|
"""Revives a layer/custom model from config; returns None if
|
|
infeasible."""
|
|
# Check that the following requirements are met for reviving from
|
|
# config:
|
|
# 1. Object can be deserialized from config.
|
|
# 2. If the object needs to be built, then the build input shape can
|
|
# be found.
|
|
class_name = metadata.get("class_name")
|
|
config = metadata.get("config")
|
|
shared_object_id = metadata.get("shared_object_id")
|
|
must_restore_from_config = metadata.get("must_restore_from_config")
|
|
if not serialization.validate_config(config):
|
|
return None
|
|
|
|
try:
|
|
try:
|
|
obj = layers_module.deserialize(
|
|
serialization.serialize_keras_class_and_config(
|
|
class_name, config, shared_object_id=shared_object_id
|
|
)
|
|
)
|
|
except (TypeError, KeyError) as e:
|
|
# A name conflict has occurred. The `class_name` is in the Keras
|
|
# native framework; however, the value in the framework is
|
|
# different from the user's class definition which confuses the
|
|
# KerasObjectLoader.
|
|
builtin_layer = layers_module.get_builtin_layer(class_name)
|
|
if builtin_layer:
|
|
raise RuntimeError(
|
|
f"Unable to restore object of class '{class_name}'. "
|
|
"One of several possible causes could be "
|
|
"a missing custom object. "
|
|
"Decorate your custom object with "
|
|
"`@keras.utils.register_keras_serializable` and "
|
|
"include that file in your program, "
|
|
"or pass your class in a "
|
|
"`keras.utils.CustomObjectScope` "
|
|
"that wraps this load call. "
|
|
f"\n\nException: {e}"
|
|
) from e
|
|
else:
|
|
raise
|
|
except Exception as e:
|
|
if must_restore_from_config:
|
|
raise e
|
|
else:
|
|
return None
|
|
|
|
# Use the dtype, name, and trainable status. Often times these are not
|
|
# specified in custom configs, so retrieve their values from the
|
|
# metadata.
|
|
|
|
obj._name = metadata["name"]
|
|
if metadata.get("trainable") is not None:
|
|
obj.trainable = metadata["trainable"]
|
|
if metadata.get("dtype") is not None:
|
|
obj._set_dtype_policy(metadata["dtype"])
|
|
if metadata.get("stateful") is not None:
|
|
obj.stateful = metadata["stateful"]
|
|
if metadata.get("autocast") is not None:
|
|
obj._autocast = metadata["autocast"]
|
|
# Restore model save spec for subclassed models. (layers do not store a
|
|
# SaveSpec)
|
|
if isinstance(obj, training_lib.Model):
|
|
full_save_spec = metadata.get("full_save_spec")
|
|
if full_save_spec is not None:
|
|
args_spec, kwargs_spec = full_save_spec
|
|
inputs_spec = args_spec.pop(0)
|
|
obj._set_save_spec(inputs_spec, args_spec, kwargs_spec)
|
|
|
|
build_input_shape = metadata.get("build_input_shape")
|
|
built = self._try_build_layer(obj, node_id, build_input_shape)
|
|
|
|
if not built:
|
|
# If the layer cannot be built, revive a custom layer instead.
|
|
return None
|
|
return obj
|
|
|
|
def _revive_metric_from_config(self, metadata):
|
|
"""Revives a metric object using the config saved in the metadata."""
|
|
class_name = tf.compat.as_str(metadata["class_name"])
|
|
config = metadata.get("config")
|
|
|
|
if not serialization.validate_config(config):
|
|
return None
|
|
|
|
try:
|
|
obj = metrics.deserialize(
|
|
serialization.serialize_keras_class_and_config(
|
|
class_name, config
|
|
)
|
|
)
|
|
except ValueError:
|
|
return None
|
|
|
|
build_input_shape = metadata.get("build_input_shape")
|
|
if build_input_shape is not None and hasattr(obj, "_build"):
|
|
obj._build(build_input_shape)
|
|
|
|
return obj
|
|
|
|
def _try_build_layer(self, obj, node_id, build_input_shape):
|
|
"""Attempts to build the layer."""
|
|
if obj.built or hasattr(obj.build, "_is_default"):
|
|
obj.built = True
|
|
return True
|
|
|
|
if build_input_shape is None:
|
|
build_input_shape = self._infer_inputs(
|
|
node_id, convert_to_shapes=True
|
|
)
|
|
|
|
if build_input_shape is not None:
|
|
obj.build(build_input_shape)
|
|
base_layer.Layer.build(obj, build_input_shape)
|
|
return True
|
|
|
|
return False
|
|
|
|
def get_path(self, node_id):
|
|
return self._node_paths[node_id]
|
|
|
|
def finalize_objects(self):
|
|
"""Finish setting up Keras objects.
|
|
|
|
This function is executed after all objects and functions have been
|
|
created. Call functions and losses are attached to each layer, and once
|
|
all layers have been fully set up, graph networks are initialized.
|
|
|
|
Subclassed models that are revived from the SavedModel are treated like
|
|
layers, and have their call/loss functions attached here.
|
|
"""
|
|
# Finish setting up layers and subclassed models. This step attaches
|
|
# call functions and losses to each object, and sets model
|
|
# inputs/outputs.
|
|
layers_revived_from_config = []
|
|
layers_revived_from_saved_model = []
|
|
for node_id, (node, _) in self.loaded_nodes.items():
|
|
if (
|
|
not isinstance(node, base_layer.Layer)
|
|
# Don't finalize models until all layers have finished loading.
|
|
or node_id in self.model_layer_dependencies
|
|
):
|
|
continue
|
|
|
|
self._unblock_model_reconstruction(node_id, node)
|
|
|
|
if isinstance(node, input_layer.InputLayer):
|
|
continue
|
|
elif isinstance(node, metrics.Metric):
|
|
continue
|
|
|
|
if isinstance(node, (RevivedLayer, RevivedInputLayer)):
|
|
layers_revived_from_saved_model.append(node)
|
|
else:
|
|
layers_revived_from_config.append(node)
|
|
|
|
_finalize_saved_model_layers(layers_revived_from_saved_model)
|
|
_finalize_config_layers(layers_revived_from_config)
|
|
|
|
# Initialize graph networks, now that layer dependencies have been
|
|
# resolved.
|
|
self._reconstruct_all_models()
|
|
|
|
def _unblock_model_reconstruction(self, layer_id, layer):
|
|
"""Removes layer from blocking model reconstruction."""
|
|
for model_id, v in self.model_layer_dependencies.items():
|
|
_, layers = v
|
|
if layer_id not in layers:
|
|
continue
|
|
layers[layers.index(layer_id)] = layer
|
|
if all(isinstance(x, base_layer.Layer) for x in layers):
|
|
self._models_to_reconstruct.append(model_id)
|
|
|
|
def _reconstruct_all_models(self):
|
|
"""Reconstructs the network structure of all models."""
|
|
all_initialized_models = set()
|
|
while self._models_to_reconstruct:
|
|
model_id = self._models_to_reconstruct.pop(0)
|
|
all_initialized_models.add(model_id)
|
|
model, layers = self.model_layer_dependencies[model_id]
|
|
self._reconstruct_model(model_id, model, layers)
|
|
_finalize_config_layers([model])
|
|
|
|
if all_initialized_models != set(self.model_layer_dependencies.keys()):
|
|
# This should not happen.
|
|
uninitialized_model_ids = (
|
|
set(self.model_layer_dependencies.keys())
|
|
- all_initialized_models
|
|
)
|
|
uninitialized_model_names = [
|
|
self.model_layer_dependencies[model_id][0].name
|
|
for model_id in uninitialized_model_ids
|
|
]
|
|
raise ValueError(
|
|
"Error loading model(s) in the SavedModel format. "
|
|
"The following model(s) could not be initialized: "
|
|
f"{uninitialized_model_names}"
|
|
)
|
|
|
|
def _reconstruct_model(self, model_id, model, layers):
|
|
"""Reconstructs the network structure."""
|
|
config = json_utils.decode(self._metadata[model_id].metadata)["config"]
|
|
|
|
# Set up model inputs
|
|
if model.inputs:
|
|
# Inputs may already be created if the model is instantiated in
|
|
# another object's __init__.
|
|
pass
|
|
elif isinstance(model, models_lib.Sequential):
|
|
if not layers or not isinstance(layers[0], input_layer.InputLayer):
|
|
if config["layers"][0]["class_name"] == "InputLayer":
|
|
layers.insert(
|
|
0,
|
|
input_layer.InputLayer.from_config(
|
|
config["layers"][0]["config"]
|
|
),
|
|
)
|
|
elif "batch_input_shape" in config["layers"][0]["config"]:
|
|
batch_input_shape = config["layers"][0]["config"][
|
|
"batch_input_shape"
|
|
]
|
|
layers.insert(
|
|
0,
|
|
input_layer.InputLayer(
|
|
input_shape=batch_input_shape[1:],
|
|
batch_size=batch_input_shape[0],
|
|
dtype=layers[0].dtype,
|
|
name=layers[0].name + "_input",
|
|
),
|
|
)
|
|
model.__init__(layers, name=config["name"])
|
|
if not model.inputs:
|
|
first_layer = self._get_child_layer_node_ids(model_id)[0]
|
|
input_specs = self._infer_inputs(first_layer)
|
|
input_shapes = self._infer_inputs(
|
|
first_layer, convert_to_shapes=True
|
|
)
|
|
model._set_inputs(input_specs)
|
|
if not model.built and not isinstance(input_specs, dict):
|
|
model.build(input_shapes)
|
|
else: # Reconstruct functional model
|
|
(
|
|
inputs,
|
|
outputs,
|
|
created_layers,
|
|
) = functional_lib.reconstruct_from_config(
|
|
config, created_layers={layer.name: layer for layer in layers}
|
|
)
|
|
model.__init__(inputs, outputs, name=config["name"])
|
|
functional_lib.connect_ancillary_layers(model, created_layers)
|
|
|
|
# Set model dtype.
|
|
_set_network_attributes_from_metadata(model)
|
|
|
|
# Unblock models that are dependent on this model.
|
|
self._unblock_model_reconstruction(model_id, model)
|
|
|
|
def _get_child_layer_node_ids(self, node_id):
|
|
"""Returns the node ids of each layer in a Sequential/Functional
|
|
model."""
|
|
# Sequential and Functional track layers with names following the format
|
|
# "layer-N". Use this to generate the list of layers.
|
|
num_layers = 0
|
|
child_layers = {}
|
|
pattern = re.compile("layer-(\\d+)")
|
|
|
|
for child in self._proto.nodes[node_id].children:
|
|
m = pattern.match(child.local_name)
|
|
if m is None:
|
|
continue
|
|
layer_n = int(m.group(1))
|
|
num_layers = max(layer_n + 1, num_layers)
|
|
child_layers[layer_n] = child.node_id
|
|
|
|
ordered = []
|
|
for n in range(num_layers):
|
|
child = child_layers.get(n)
|
|
if child is None:
|
|
break
|
|
ordered.append(child)
|
|
return ordered
|
|
|
|
def _search_for_child_node(self, parent_id, path_to_child):
|
|
"""Returns node id of child node.
|
|
|
|
A helper method for traversing the object graph proto.
|
|
|
|
As an example, say that the object graph proto in the SavedModel
|
|
contains an object with the following child and grandchild attributes:
|
|
|
|
`parent.child_a.child_b`
|
|
|
|
This method can be used to retrieve the node id of `child_b` using the
|
|
parent's node id by calling:
|
|
|
|
`_search_for_child_node(parent_id, ['child_a', 'child_b'])`.
|
|
|
|
Args:
|
|
parent_id: node id of parent node
|
|
path_to_child: list of children names.
|
|
|
|
Returns:
|
|
node_id of child, or None if child isn't found.
|
|
"""
|
|
if not path_to_child:
|
|
return parent_id
|
|
|
|
for child in self._proto.nodes[parent_id].children:
|
|
if child.local_name == path_to_child[0]:
|
|
return self._search_for_child_node(
|
|
child.node_id, path_to_child[1:]
|
|
)
|
|
return None
|
|
|
|
def _infer_inputs(self, layer_node_id, convert_to_shapes=False):
    """Infers a layer's inputs from its saved call function.

    Looks up the layer's `call_and_return_all_conditional_losses` child and
    decodes the canonicalized input signature of its first concrete
    function.

    Args:
        layer_node_id: node id of the layer in the object graph proto.
        convert_to_shapes: if True, each decoded spec is replaced by its
            `.shape`.

    Returns:
        The first positional argument of the decoded signature (or its
        shapes), or None when the call function or its concrete functions
        are missing.
    """
    call_fn_id = self._search_for_child_node(
        layer_node_id, ["call_and_return_all_conditional_losses"]
    )
    if call_fn_id is None:
        return None

    function_node = self._proto.nodes[call_fn_id].function
    concrete_names = function_node.concrete_functions
    if not concrete_names:
        return None

    concrete_proto = self._proto.concrete_functions[concrete_names[0]]
    signature = tf.__internal__.saved_model.decode_proto(
        concrete_proto.canonicalized_input_signature
    )
    # signature[0] holds the positional args; the inputs come first.
    inputs = signature[0][0]
    if not convert_to_shapes:
        return inputs
    return tf.nest.map_structure(lambda spec: spec.shape, inputs)
|
|
|
|
def _config_node_setter(self, setter):
    """Wraps `setter` for nodes that are recreated from config.

    Args:
        setter: callable taking `(obj, name, value)`.

    Returns:
        A callable with the same signature that only forwards to `setter`
        when `obj._lookup_dependency(name)` returns None.
    """

    def setattr_wrapper(obj, name, value):
        # Objects recreated from the config keep the attributes they
        # already track; only missing edges are filled in.
        if obj._lookup_dependency(name) is not None:
            return
        setter(obj, name, value)

    return setattr_wrapper
|
|
|
|
|
|
def _finalize_saved_model_layers(layers):
    """Runs the final steps of loading Keras Layers from SavedModel.

    Args:
        layers: iterable of layers that were revived from the SavedModel
            (and not from the config).
    """

    # 1. Set up call functions for all layers initialized from the SavedModel (
    # and not the config)
    for layer in layers:
        layer.built = True
        saved_call = getattr(
            _get_keras_attr(layer), "call_and_return_conditional_losses", None
        )
        if not (saved_call and saved_call.concrete_functions):
            # No traced call function was saved: calling the layer raises.
            layer.call = types.MethodType(
                _unable_to_call_layer_due_to_serialization_issue, layer
            )
            continue

        arg_spec = tf_inspect.getfullargspec(saved_call)
        layer.call = utils.use_wrapped_call(
            layer,
            saved_call,
            layer_utils.CallFunctionSpec(arg_spec),
            return_method=True,
        )
        saved_flag = layer._serialized_attributes["metadata"][
            "expects_training_arg"
        ]
        # This could change the value of `expects_training_arg` if this
        # layer doesn't expect a training arg, but has a child layer
        # that does.
        has_training_arg = "training" in saved_call.function_spec.arg_names
        layer._init_call_fn_args(True if has_training_arg else saved_flag)

    for layer in layers:
        # 2. Set model inputs and outputs.
        if isinstance(layer, RevivedNetwork):
            _set_network_attributes_from_metadata(layer)

            keras_attr = _get_keras_attr(layer)
            if hasattr(keras_attr, "call_and_return_conditional_losses"):
                call_fn = keras_attr.call_and_return_conditional_losses
                if not call_fn.concrete_functions:
                    # Skips the remaining steps for this layer as well.
                    continue
                if call_fn.input_signature is None:
                    (
                        signature_args,
                        kwargs,
                    ) = infer_inputs_from_restored_call_function(call_fn)
                else:
                    signature_args = call_fn.input_signature
                    kwargs = None
                # The inputs are the first positional argument.
                args = list(signature_args)
                inputs = args.pop(0)
                layer._set_save_spec(inputs, args, kwargs)

                # V1 models require calling _set_inputs to set the `.inputs`
                # attr. Skip this step when there are multiple tensor inputs
                # (this behavior is not well supported in V1 models).
                extra_specs = tf.nest.flatten([args, kwargs])
                if not any(isinstance(x, tf.TensorSpec) for x in extra_specs):
                    layer._set_inputs(inputs)

        # 3. Add losses that aren't generated by the layer.call function.
        _restore_layer_unconditional_losses(layer)
        _restore_layer_activation_loss(layer)

        # 4. Restore metrics list
        _restore_layer_metrics(layer)
|
|
|
|
|
|
def _unable_to_call_layer_due_to_serialization_issue(
    layer, *unused_args, **unused_kwargs
):
    """Replaces the `layer.call` if the layer was not fully serialized.

    Keras Model/Layer serialization is relatively relaxed because SavedModels
    are not always loaded back as keras models. Thus, when there is an issue
    tracing a non-signature function, a warning is logged instead of raising an
    error. This results in a SavedModel where the model's call function is
    saved, but the internal layer call functions are not.

    When deserialized with `tf.keras.models.load_model`, the internal layers
    which do not have serialized call functions should raise an error when
    called.

    Args:
        layer: Layer without the serialized call function.
        *unused_args: positional call arguments; ignored.
        **unused_kwargs: keyword call arguments; ignored.

    Raises:
        ValueError: always.
    """

    # Adjacent literals are concatenated verbatim, so every sentence that is
    # followed by another one must end with its own separating space.
    raise ValueError(
        f"Cannot call custom layer {layer.name} of type {type(layer)}, because "
        "the call function was not serialized to the SavedModel. "
        "Please try one of the following methods to fix this issue:"
        "\n\n(1) Implement `get_config` and `from_config` in the layer/model "
        "class, and pass the object to the `custom_objects` argument when "
        "loading the model. For more details, see: "
        "https://www.tensorflow.org/guide/keras/save_and_serialize"
        "\n\n(2) Ensure that the subclassed model or layer overwrites `call` "
        "and not `__call__`. The input shape and dtype will be automatically "
        "recorded when the object is called, and used when saving. To manually "
        "specify the input shape/dtype, decorate the call function with "
        "`@tf.function(input_signature=...)`."
    )
|
|
|
|
|
|
def _finalize_config_layers(layers):
    """Runs the final steps of loading Keras Layers from config.

    Args:
        layers: iterable of layers that were recreated from their config.
    """
    for layer in layers:
        # It is assumed that layers define their unconditional losses after
        # being recreated from the config and built. The exceptions to this are
        # Functional and Sequential models, which only store conditional losses
        # (losses dependent on the inputs) in the config. Unconditional losses
        # like weight regularization must be revived from the SavedModel.
        if _is_graph_network(layer):
            _restore_layer_unconditional_losses(layer)

        # Some layers, like Dense, record their activation loss function in the
        # config. However, not all layers do this, so the activation loss may be
        # missing when restored from the config/hdf5.
        # TODO(kathywu): Investigate ways to improve the config to ensure
        # consistent loading behavior between HDF5 and SavedModel.
        _restore_layer_activation_loss(layer)

        # Restore metrics list.
        _restore_layer_metrics(layer)

        # Restore RNN layer states.
        if isinstance(layer, base_rnn.RNN) and layer.stateful:
            keras_attr = _get_keras_attr(layer)
            if hasattr(keras_attr, "states"):
                layer.states = keras_attr.states
                for state_variable in tf.nest.flatten(layer.states):
                    backend.track_variable(state_variable)

        # Perform any layer defined finalization of the layer state.
        layer.finalize_state()
|
|
|
|
|
|
def _finalize_metric(metric):
    """Points a revived metric's `update_state`/`result` at its saved ones.

    Args:
        metric: revived metric exposing `keras_api.update_state` and
            `keras_api.result`.
    """
    keras_api = metric.keras_api
    wrapped_update = metrics_utils.update_state_wrapper(keras_api.update_state)
    # Bind the wrapped function so it is called as a method of `metric`.
    metric.update_state = types.MethodType(wrapped_update, metric)
    metric.result = keras_api.result
|
|
|
|
|
|
def _restore_layer_unconditional_losses(layer):
    """Adds the unconditional losses saved in the SavedModel to `layer`.

    Args:
        layer: revived layer; each restored loss is passed to
            `layer.add_loss`.
    """
    keras_attr = _get_keras_attr(layer)
    if hasattr(keras_attr, "layer_regularization_losses"):
        saved_losses = keras_attr.layer_regularization_losses
    else:
        # Some earlier SavedModels may not have layer_regularization_losses
        # serialized separately. Fall back to using the regularization_losses
        # list if it does not exist.
        saved_losses = layer._serialized_attributes.get(
            "regularization_losses", []
        )
    for saved_loss in saved_losses:
        layer.add_loss(saved_loss)
|
|
|
|
|
|
def _restore_layer_activation_loss(layer):
    """Restores the activation loss from the SavedModel.

    Args:
        layer: revived layer whose `activity_regularizer` may be filled in
            from the saved `activity_regularizer_fn`.
    """
    # Use wrapped activity regularizer function if the layer's activity
    # regularizer wasn't created during initialization.
    saved_regularizer = getattr(
        _get_keras_attr(layer), "activity_regularizer_fn", None
    )
    if not saved_regularizer or layer.activity_regularizer:
        return

    try:
        layer.activity_regularizer = saved_regularizer
    except AttributeError:
        # This may happen if a layer wrapper is saved with an activity
        # regularizer. The wrapper object's activity regularizer is
        # unsettable.
        pass
|
|
|
|
|
|
def revive_custom_object(identifier, metadata):
    """Revives object from SavedModel.

    A new class named after `metadata["class_name"]` is created from the
    revived base classes registered for `identifier`, and is initialized
    through its `_init_from_metadata`.

    Args:
        identifier: object identifier stored in the SavedModel.
        metadata: metadata dictionary of the saved object.

    Returns:
        Whatever `_init_from_metadata` of the created class returns.

    Raises:
        ValueError: if `identifier` has no revived base classes.
    """
    eager = tf.compat.v1.executing_eagerly_outside_functions()
    model_class = training_lib.Model if eager else training_lib_v1.Model

    revived_classes = {
        constants.INPUT_LAYER_IDENTIFIER: (
            RevivedInputLayer,
            input_layer.InputLayer,
        ),
        constants.LAYER_IDENTIFIER: (RevivedLayer, base_layer.Layer),
        constants.MODEL_IDENTIFIER: (RevivedNetwork, model_class),
        constants.NETWORK_IDENTIFIER: (
            RevivedNetwork,
            functional_lib.Functional,
        ),
        constants.SEQUENTIAL_IDENTIFIER: (
            RevivedNetwork,
            models_lib.Sequential,
        ),
    }
    parent_classes = revived_classes.get(identifier, None)
    class_name = tf.compat.as_str(metadata["class_name"])

    if parent_classes is None:
        raise ValueError(
            f'Unable to restore custom object of class "{class_name}" '
            f"(type {identifier}). Please make sure that this class is "
            "included in the `custom_objects` arg when calling `load_model()`. "
            "Also, check that the class implements `get_config` and "
            f"`from_config`.\n\nComplete metadata: {metadata}"
        )

    revived_cls = type(class_name, parent_classes, {})
    return revived_cls._init_from_metadata(metadata)
|
|
|
|
|
|
def _restore_layer_metrics(layer):
    """Appends saved metrics that `layer` does not already have by name.

    Args:
        layer: revived layer whose `_metrics` list is extended in place.
    """
    saved_metrics = getattr(_get_keras_attr(layer), "layer_metrics", {})
    # Metrics may be added during initialization/building of custom
    # layers, so only names that are not present yet are restored.
    existing_names = {existing.name for existing in layer._metrics}
    for metric_name, saved_metric in saved_metrics.items():
        if metric_name in existing_names:
            continue
        layer._metrics.append(saved_metric)
|
|
|
|
|
|
# TODO(kathywu): Centrally define keys and functions for both serialization and
|
|
# deserialization.
|
|
class RevivedLayer:
    """Keras layer loaded from a SavedModel."""

    @classmethod
    def _init_from_metadata(cls, metadata):
        """Creates a revived layer from the SavedModel metadata.

        Args:
            metadata: metadata dictionary stored in the SavedModel proto.

        Returns:
            Tuple of the revived object and the setter function used to
            attach its children.
        """
        constructor_kwargs = {
            "name": metadata["name"],
            "trainable": metadata["trainable"],
        }
        for optional_key in ("dtype", "batch_input_shape"):
            optional_value = metadata.get(optional_key)
            if optional_value is not None:
                constructor_kwargs[optional_key] = optional_value

        layer = cls(**constructor_kwargs)

        with utils.no_automatic_dependency_tracking_scope(layer):
            layer._call_spec.expects_training_arg = metadata[
                "expects_training_arg"
            ]

            saved_config = metadata.get("config")
            if serialization.validate_config(saved_config):
                layer._config = saved_config

            saved_input_spec = metadata.get("input_spec")
            if saved_input_spec is not None:
                layer.input_spec = recursively_deserialize_keras_object(
                    saved_input_spec,
                    module_objects={"InputSpec": input_spec.InputSpec},
                )

            regularizer_config = metadata.get("activity_regularizer")
            if regularizer_config is not None:
                layer.activity_regularizer = regularizers.deserialize(
                    regularizer_config
                )

            # Plain values copied from the metadata when they were saved,
            # as (metadata key, attribute name) pairs.
            copied_attributes = (
                ("_is_feature_layer", "_is_feature_layer"),
                ("stateful", "stateful"),
                ("autocast", "_autocast"),
                (
                    "preserve_input_structure_in_config",
                    "_preserve_input_structure_in_config",
                ),
            )
            for metadata_key, attribute_name in copied_attributes:
                saved_value = metadata.get(metadata_key)
                if saved_value is not None:
                    setattr(layer, attribute_name, saved_value)

        return layer, _revive_setter

    @property
    def keras_api(self):
        """The saved Keras attributes of this layer, or None if absent."""
        return self._serialized_attributes.get(constants.KERAS_ATTR, None)

    def get_config(self):
        """Returns the config stored at load time.

        Raises:
            NotImplementedError: if no `_config` was set on this object.
        """
        if not hasattr(self, "_config"):
            raise NotImplementedError
        return self._config
|
|
|
|
|
|
def _revive_setter(layer, name, value):
    """Setter function that saves some attributes to separate dictionary.

    Args:
        layer: revived layer that receives the attribute.
        name: name of the edge/attribute read from the SavedModel.
        value: revived object to attach.
    """
    # Many attributes in the SavedModel conflict with properties defined in
    # Layer and Model. Save these attributes to a separate dictionary.
    if name in PUBLIC_ATTRIBUTES:
        if isinstance(value, tf.__internal__.tracking.Trackable):
            layer._track_trackable(value, name=name)
        layer._serialized_attributes[name] = value

    elif (
        isinstance(layer, functional_lib.Functional)
        # `\d+` requires at least one digit after the dash. The previous
        # `[\d+]` was a character class that also accepted a literal "+".
        and re.match(r"^layer(_with_weights)?-\d+", name) is not None
    ):
        # Edges named "layer-n" or "layer_with_weights-n", which are tracked in
        # network._track_layers, should not be added as an attribute. They
        # should be temporarily added as a dependency so that checkpointed
        # values can be restored. These dependencies are manually deleted in
        # KerasObjectLoader.del_tracking.

        # Set `overwrite=True` in the case that `layer` already tracks a
        # different layer-n. This may cause variable values to not be loaded
        # properly in the original layer-n, but we already warn the users about
        # this (ctrl-f "shared between different layers/models").
        layer._track_trackable(value, name, overwrite=True)
    elif getattr(layer, name, None) is not None:
        # Don't overwrite already defined attributes.
        pass
    else:
        setattr(layer, name, value)
|
|
|
|
|
|
class RevivedInputLayer:
    """InputLayer loaded from a SavedModel."""

    @classmethod
    def _init_from_metadata(cls, metadata):
        """Revives the saved InputLayer from the Metadata.

        Args:
            metadata: metadata dictionary; must contain "name", "dtype",
                "sparse", "ragged", "batch_input_shape" and "config".

        Returns:
            Tuple of the revived object and `setattr` as its setter.
        """
        constructor_keys = (
            "name",
            "dtype",
            "sparse",
            "ragged",
            "batch_input_shape",
        )
        input_layer_obj = cls(**{key: metadata[key] for key in constructor_keys})
        with utils.no_automatic_dependency_tracking_scope(input_layer_obj):
            input_layer_obj._config = metadata["config"]

        return input_layer_obj, setattr

    def get_config(self):
        """Returns the config stored at load time."""
        return self._config
|
|
|
|
|
|
def recursively_deserialize_keras_object(config, module_objects=None):
    """Deserializes Keras objects found in a nested structure.

    Args:
        config: dictionary, tuple or list. A dictionary with a "class_name"
            key is deserialized as a Keras object; other containers are
            traversed recursively.
        module_objects: passed through to
            `serialization.deserialize_keras_object`.

    Returns:
        The deserialized structure. Tuples and lists both come back as
        lists.

    Raises:
        ValueError: if `config` (or a nested value) is not a dictionary,
            tuple or list.
    """
    if isinstance(config, dict):
        if "class_name" in config:
            return serialization.deserialize_keras_object(
                config, module_objects=module_objects
            )
        return {
            key: recursively_deserialize_keras_object(value, module_objects)
            for key, value in config.items()
        }

    if isinstance(config, (tuple, list)):
        return [
            recursively_deserialize_keras_object(item, module_objects)
            for item in config
        ]

    raise ValueError(
        "Unable to decode Keras layer config. Config should be a "
        f"dictionary, tuple or list. Received: config={config}"
    )
|
|
|
|
|
|
def infer_inputs_from_restored_call_function(fn):
    """Returns TypeSpec of inputs from a restored call function.

    The structured input signatures of all concrete functions of `fn` are
    merged pairwise into their most specific common supertype.

    Args:
      fn: Restored layer call function. It is assumed that `fn` has at least one
        concrete function and that the inputs are in the first argument.

    Returns:
      TypeSpec of call function inputs in the form of (args, kwargs)
    """

    def common_spec(x, y):
        if isinstance(x, tf.TypeSpec):
            merged = x._without_tensor_names().most_specific_common_supertype(
                [y._without_tensor_names()]
            )
            if merged is None:
                # Please file a bug if you are being hindered by this error.
                raise TypeError(f"No common supertype of {x} and {y}.")
            return merged

        # Doesn't particularly matter what is returned in this case because
        # the result will be filtered out in _set_input_shape.
        return x

    concrete_fns = fn.concrete_functions
    merged_signature = concrete_fns[0].structured_input_signature
    for other_fn in concrete_fns[1:]:
        merged_signature = tf.nest.map_structure(
            common_spec, merged_signature, other_fn.structured_input_signature
        )
    return merged_signature
|
|
|
|
|
|
class RevivedNetwork(RevivedLayer):
    """Keras network of layers loaded from a SavedModel."""

    @classmethod
    def _init_from_metadata(cls, metadata):
        """Creates a revived network from the SavedModel metadata.

        Args:
            metadata: metadata dictionary stored in the SavedModel proto.

        Returns:
            Tuple of the revived object and the setter function used to
            attach its children.
        """
        network = cls(name=metadata["name"])

        # Store attributes revived from SerializedAttributes in a un-tracked
        # dictionary. The attributes are the ones listed in CommonEndpoints or
        # "keras_api" for keras-specific attributes.
        with utils.no_automatic_dependency_tracking_scope(network):
            network._call_spec.expects_training_arg = metadata[
                "expects_training_arg"
            ]

            saved_config = metadata.get("config")
            if serialization.validate_config(saved_config):
                network._config = saved_config

            regularizer_config = metadata.get("activity_regularizer")
            if regularizer_config is not None:
                network.activity_regularizer = regularizers.deserialize(
                    regularizer_config
                )

            saved_autocast = metadata.get("autocast")
            if saved_autocast is not None:
                network._autocast = saved_autocast

        return network, _revive_setter
|
|
|
|
|
|
def _set_network_attributes_from_metadata(revived_obj):
    """Sets attributes recorded in the metadata.

    Args:
        revived_obj: revived network whose `_serialized_attributes` holds a
            "metadata" dictionary with "trainable" and, optionally, "dtype".
    """
    with utils.no_automatic_dependency_tracking_scope(revived_obj):
        metadata = revived_obj._serialized_attributes["metadata"]
        saved_dtype = metadata.get("dtype")
        if saved_dtype is not None:
            revived_obj._set_dtype_policy(saved_dtype)
        revived_obj._trainable = metadata["trainable"]
|
|
|
|
|
|
def _maybe_add_serialized_attributes(layer, metadata):
    """Gives `layer` a `_serialized_attributes` dict if it has none.

    Args:
        layer: object to receive the dictionary.
        metadata: stored under the "metadata" key of the new dictionary.
    """
    if hasattr(layer, "_serialized_attributes"):
        return

    # Store attributes revived from SerializedAttributes in a un-tracked
    # dictionary. The attributes are the ones listed in CommonEndpoints or
    # "keras_api" for keras-specific attributes.
    with utils.no_automatic_dependency_tracking_scope(layer):
        layer._serialized_attributes = {"metadata": metadata}
|
|
|
|
|
|
def _get_keras_attr(layer):
    """Returns the saved Keras attributes of `layer`, or None.

    None is returned both when `layer` has no `_serialized_attributes` and
    when that dictionary has no `constants.KERAS_ATTR` entry.
    """
    serialized_attributes = getattr(layer, "_serialized_attributes", {})
    return serialized_attributes.get(constants.KERAS_ATTR, None)
|