# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logic for restoring checkpointed values for Trackables."""

import collections

from tensorflow.python.checkpoint import checkpoint_view
from tensorflow.python.checkpoint import functional_saver
from tensorflow.python.checkpoint import save_util_v1
from tensorflow.python.checkpoint import saveable_compat
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import registration
from tensorflow.python.trackable import base
from tensorflow.python.trackable import constants
from tensorflow.python.trackable import python_state
from tensorflow.python.trackable import trackable_utils
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import object_identity


class CheckpointPosition(object):
  """Indicates a position within a `_CheckpointRestoreCoordinator`."""

  __slots__ = ["_checkpoint", "_proto_id", "skip_restore"]

  def __init__(self, checkpoint, proto_id):
    """Specify an object within a checkpoint.

    Args:
      checkpoint: A _CheckpointRestoreCoordinator object.
      proto_id: The index of this object in TrackableObjectGraph.nodes.
    """
    self._checkpoint = checkpoint
    self._proto_id = proto_id
    # This may be set to True if the registered saver cannot be used with this
    # object.
    self.skip_restore = False

  def restore(self, trackable, reader=None):
    """Restore this value into `trackable`."""
    with ops.init_scope():
      if self.bind_object(trackable):
        # This object's correspondence with a checkpointed object is new, so
        # process deferred restorations for it and its dependencies.
        restore_ops = self._restore_descendants(reader)
        if restore_ops:
          self._checkpoint.new_restore_ops(restore_ops)

  def bind_object(self, trackable):
    """Set a checkpoint<->object correspondence.

    Args:
      trackable: The object to record a correspondence for.

    Returns:
      True if this is a new assignment, False if this object has already been
      mapped to a checkpointed `Object` proto.
    Raises:
      AssertionError: If another object is already bound to the `Object` proto.
    """
    checkpoint = self.checkpoint
    checkpoint.all_python_objects.add(trackable)
    current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None)
    checkpoint.matched_proto_ids.add(self._proto_id)
    if current_assignment is None:
      checkpoint.object_by_proto_id[self._proto_id] = trackable
      return True  # New assignment
    else:
      # The object was already mapped for this checkpoint load, which means
      # we don't need to do anything besides check that the mapping is
      # consistent (if the dependency DAG is not a tree then there are
      # multiple paths to the same object).
      if current_assignment is not trackable:
        logging.warning(
            "Inconsistent references when loading the checkpoint into this "
            "object graph. For example, in the saved checkpoint object, "
            "`model.layer.weight` and `model.layer_copy.weight` reference the "
            "same variable, while in the current object these are two different"
            " variables. The referenced variables are:"
            f"({current_assignment} and {trackable}).")
      return False  # Not a new assignment
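
  # For illustration: calling `bind_object` twice with the same Trackable is a
  # no-op that returns False, while binding a *different* Trackable to an
  # already-claimed proto id logs the "Inconsistent references" warning above.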

  def is_simple_variable(self):
    """Determine whether this value is restorable with a Tensor initializer."""
    attributes = self.object_proto.attributes
    return (len(attributes) == 1 and
            attributes[0].name == constants.VARIABLE_VALUE_KEY and
            not self.object_proto.children)
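
  # For illustration: a node qualifies as a "simple variable" when its proto
  # has exactly one attribute named by `constants.VARIABLE_VALUE_KEY`
  # ("VARIABLE_VALUE") and no children. A sketch of a qualifying node:
  #
  #   nodes {
  #     attributes {
  #       name: "VARIABLE_VALUE"
  #       checkpoint_key: "v/.ATTRIBUTES/VARIABLE_VALUE"
  #     }
  #   }
  #
  # Any children or extra attributes disqualify this fast restore path.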

  def value_tensors(self, shape_and_slices=None):
    """Create value `Tensor`s for this object's attributes.

    Does not require that the Python object has been created. Used for
    restore-on-create when executing eagerly.

    Args:
      shape_and_slices: A dict mapping from object attribute names to a shape
        and slice string that will be passed to a RestoreV2 op. If the dict is
        None or if an object attribute is not in the dict, the full tensor will
        be restored.

    Returns:
      A dictionary mapping from object attribute names to `Tensor`s.
    """
    value_tensors = {}
    for serialized_tensor in self.object_proto.attributes:
      checkpoint_key = serialized_tensor.checkpoint_key
      dtype = self._checkpoint.dtype_map[checkpoint_key]
      base_type = dtype.base_dtype
      io_device = self._checkpoint.options.experimental_io_device or "cpu:0"
      with ops.init_scope():
        with ops.device(io_device):
          # Run the restore itself on the io_device (CPU or specified device).
          if (shape_and_slices is not None and
              serialized_tensor.name in shape_and_slices):
            shape_and_slice = shape_and_slices[serialized_tensor.name]
          else:
            shape_and_slice = ""
          value, = io_ops.restore_v2(
              prefix=self._checkpoint.save_path_tensor,
              tensor_names=[checkpoint_key],
              shape_and_slices=[shape_and_slice],
              dtypes=[base_type],
              name="%s_checkpoint_read" % (serialized_tensor.name,))
          # Copy the value to the current device if necessary.
          value_tensors[serialized_tensor.name] = array_ops.identity(value)
    return value_tensors
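
  # A note on the slice strings accepted above: RestoreV2 uses the TensorSlice
  # encoding of "full_shape dim_slices". As a sketch, for a saved tensor of
  # full shape [4, 10], the string "4 10 0,2:-" restores rows 0..1 and all
  # columns ("start,length" per dimension, "-" for a full dimension), while
  # the empty string "" restores the entire tensor.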

  def gather_ops_or_named_saveables(self):
    """Looks up or creates SaveableObjects which don't have cached ops.

    Returns:
      A tuple of (
          existing_restore_ops: list,
          named_saveables: dict,
          python_positions: list,
          registered_savers: dict)
    """

    recorded_registered_saver = self.get_registered_saver_name()
    if not (self.object_proto.attributes or recorded_registered_saver):
      return [], {}, [], {}

    existing_restore_ops = []
    named_saveables = {}
    python_positions = []
    registered_savers = collections.defaultdict(dict)

    saveable_factories = saveable_object_util.saveable_objects_from_trackable(
        self.trackable)
    saver_name = registration.get_registered_saver_name(self.trackable)

    if recorded_registered_saver:
      if not self.skip_restore:
        name = self.object_proto.registered_saver.object_name
        registered_savers[recorded_registered_saver][name] = self.trackable
      # Else: Skip restoration of this Trackable. This skip only happens if the
      # registered saver has enabled `option_restore`. Otherwise, an error
      # would have been raised at `self.get_registered_saver_name()`.
    elif saver_name:
      # In this case, the checkpoint has a recorded serialized tensor but no
      # registered saver, while the Trackable loading the checkpoint has
      # migrated to the registered checkpoint functionality (TPUEmbedding is an
      # example of this).

      # Set the Trackable's object name to the first checkpoint key that is
      # stored in checkpoint. If there is a use case that requires the other
      # keys, then we can take another look at this.
      registered_savers[saver_name] = {
          self.object_proto.attributes[0].checkpoint_key: self.trackable
      }
    elif isinstance(self.trackable, python_state.PythonState):
      python_positions.append(self)
    elif saveable_factories.keys() == {
        trackable_utils.SERIALIZE_TO_TENSORS_NAME
    }:
      existing_restore_ops, named_saveables = (
          self._create_serialize_to_tensor_saveable(saveable_factories))
    elif saveable_factories:
      existing_restore_ops, named_saveables = (
          self._create_saveables_by_attribute_name(saveable_factories))
    else:
      # If no registered savers were found, then it means that one or more
      # serialized tensors were never used.
      for serialized_tensor in self.object_proto.attributes:
        self._checkpoint.unused_attributes.setdefault(
            self._proto_id, []).append(serialized_tensor.name)
    return (existing_restore_ops, named_saveables, python_positions,
            registered_savers)

  def _create_serialize_to_tensor_saveable(self, saveable_factories):
    """Creates a saveable using the _serialize_to_tensor method."""
    # Extract the saveable name from the checkpoint key. This will be used as
    # the cache key or the name to pass to the saveable factory.
    suffix = saveable_compat.get_saveable_name(self.trackable) or ""
    saveable_name = _extract_saveable_name(
        self.object_proto.attributes[0].checkpoint_key) + suffix

    # Try to find the cached saveable (only in graph mode).
    if not context.executing_eagerly():
      existing_op = self._checkpoint.restore_ops_by_name.get(
          saveable_name, None)
      if existing_op is not None:
        return [existing_op], {}

      saveables_cache = self._checkpoint.saveables_cache.setdefault(
          self.trackable, {})
      if saveable_name in saveables_cache:
        return [], {saveable_name: saveables_cache[saveable_name]}

    saveable = saveable_factories[trackable_utils.SERIALIZE_TO_TENSORS_NAME](
        name=saveable_name)
    if not context.executing_eagerly():
      saveables_cache[saveable_name] = saveable
    return [], {saveable_name: saveable}

  def _create_saveables_by_attribute_name(self, saveable_factories):
    """Creates or caches SaveableObjects by matching the attribute names.

    The attribute name keys in the `saveable_factories` are used to find the
    corresponding attribute in the object proto. Attributes contain checkpoint
    keys which are passed to the factory function to generate the
    SaveableObject.

    Args:
      saveable_factories: a dict mapping attribute name to a callable factory
        function that produces a SaveableObject.

    Returns:
      A tuple of (
          existing_restore_ops: list,
          named_saveables: dict)
    """
    # Name saveables based on the name this object had when it was
    # checkpointed.
    named_saveables = {}
    existing_restore_ops = []

    # Forward compatibility code: when loading a future checkpoint, there may
    # be multiple SerializedTensors mapped to a single saveable.
    created_compat_names = set()

    for serialized_tensor in self.object_proto.attributes:
      if context.executing_eagerly():
        existing_op = None
      else:
        existing_op = self._checkpoint.restore_ops_by_name.get(
            serialized_tensor.checkpoint_key, None)
      if existing_op is not None:
        existing_restore_ops.append(existing_op)
        continue

      if any(serialized_tensor.name.startswith(name)
             for name in created_compat_names):
        continue  # Saveable has already been created for this tensor.

      # Only if we don't have cached ops for this SaveableObject, we'll see if
      # the SaveableObject itself has been cached. If not, we'll make it, and
      # either way we'll extract new ops from it (or if it has Python state to
      # restore, we'll run that).
      saveables_cache = self._checkpoint.saveables_cache
      if saveables_cache is None:
        # No SaveableObject caching when executing eagerly.
        saveable = None
      else:
        # If we've already created and cached a SaveableObject for this
        # attribute, we can re-use it to avoid re-creating some ops when graph
        # building.
        saveable_list = saveables_cache.get(self.trackable,
                                            {}).get(serialized_tensor.name,
                                                    (None,))
        if len(saveable_list) == 1:
          # Almost every attribute will have exactly one SaveableObject.
          saveable, = saveable_list
        else:
          # Don't use cached SaveableObjects for partitioned variables, which
          # is the only case where we'd have a list of SaveableObjects. Op
          # caching will catch them.
          saveable = None
      if saveable is not None:
        # The name of this attribute has changed, so we need to re-generate
        # the SaveableObject.
        if serialized_tensor.checkpoint_key not in saveable.name:
          saveable = None
          del saveables_cache[self.trackable]
      if saveable is None:
        # If there was no cached SaveableObject, create one.
        # Use the name to check if the Python object has the same attribute.
        saveable = _get_saveable_from_factory(saveable_factories,
                                              serialized_tensor,
                                              created_compat_names)
        if saveable is None:
          # Purposefully does not throw an exception if attributes have been
          # added or deleted. Stores unused attributes so an exception can be
          # raised if the user decides to check that everything in the
          # checkpoint was loaded.
          self._checkpoint.unused_attributes.setdefault(
              self._proto_id, []).append(serialized_tensor.name)
          continue
      if saveables_cache is not None:
        saveables_cache.setdefault(self.trackable,
                                   {})[serialized_tensor.name] = [saveable]
      named_saveables[serialized_tensor.checkpoint_key] = saveable

    return existing_restore_ops, named_saveables

  def restore_ops(self, reader=None):
    """Create or fetch restore ops for this object's attributes.

    Requires that the `Trackable` Python object has been bound to an object
    ID in the checkpoint.

    Args:
      reader: A `CheckpointReader`. If None, a new instance will be created.

    Returns:
      A list of operations when graph building, or an empty list when executing
      eagerly.
    """
    if self._has_registered_saver():
      raise ValueError("Unable to run individual checkpoint restore for objects"
                       " with registered savers.")
    (restore_ops, tensor_saveables, python_positions,
     _) = self.gather_ops_or_named_saveables()
    restore_ops.extend(
        self._checkpoint.restore_saveables(
            tensor_saveables, python_positions, reader=reader))
    return restore_ops

  @property
  def checkpoint(self):
    return self._checkpoint

  @property
  def trackable(self):
    return self._checkpoint.object_by_proto_id[self._proto_id]

  @property
  def object_proto(self):
    return self._checkpoint.object_graph_proto.nodes[self._proto_id]

  @property
  def proto_id(self):
    return self._proto_id

  @property
  def restore_uid(self):
    return self._checkpoint.restore_uid

  def __repr__(self):
    return repr(self.object_proto)

  def value_shape(self):
    """The shape of the VARIABLE_VALUE tensor.

    Returns:
      A `TensorShape` object if found, otherwise None.
    """
    for serialized_tensor in self.object_proto.attributes:
      if serialized_tensor.name == constants.VARIABLE_VALUE_KEY:
        return self._checkpoint.shape_map[serialized_tensor.checkpoint_key]
    return None

  def _has_registered_saver(self):
    return bool(self.object_proto.registered_saver.name)

  def get_registered_saver_name(self):
    """Returns the registered saver name defined in the Checkpoint."""
    if self._has_registered_saver():
      saver_name = self.object_proto.registered_saver.name
      try:
        registration.validate_restore_function(self.trackable, saver_name)
      except ValueError as e:
        if registration.get_strict_predicate_restore(saver_name):
          raise e
        self.skip_restore = True
      return saver_name
    return None

  def create_slot_variable_position(self, optimizer_object, variable,
                                    slot_variable_id, slot_name):
    """Generates CheckpointPosition for a slot variable.

    Args:
      optimizer_object: Optimizer that owns the slot variable.
      variable: Variable associated with the slot variable.
      slot_variable_id: ID of the slot variable.
      slot_name: Name of the slot variable.

    Returns:
      If there is a slot variable in the `optimizer_object` that has not been
      bound to the checkpoint, this function returns a tuple of (
          new `CheckpointPosition` for the slot variable,
          the slot variable itself).
      Otherwise, returns (None, None).
    """
    slot_variable_position = CheckpointPosition(
        checkpoint=self.checkpoint, proto_id=slot_variable_id)
    # pylint: disable=protected-access
    slot_variable = optimizer_object._create_or_restore_slot_variable(
        slot_variable_position=slot_variable_position,
        variable=variable,
        slot_name=slot_name)
    # pylint: enable=protected-access
    if (slot_variable is not None and
        slot_variable_position.bind_object(slot_variable)):
      return slot_variable_position, slot_variable
    else:
      return None, None

  def create_child_position(self, node_id):
    return CheckpointPosition(checkpoint=self.checkpoint, proto_id=node_id)

  def _restore_descendants(self, reader=None):
    """Restore the bound Trackable and dependencies (may be deferred)."""
    # Attempt a breadth-first traversal, since presumably the user has more
    # control over shorter paths. If we don't have all of the dependencies at
    # this point, the end result is not breadth-first (since other deferred
    # traversals will happen later).

    # You may be wondering why elements in the `visit_queue` are tuples that
    # contain both a CheckpointPosition and its Trackable. The reason is that
    # Optimizers will not keep a strong reference to slot vars for
    # ShardedVariables. The slot variable must be kept in memory until the
    # restore saveables have been created.
    visit_queue = collections.deque([(self, self.trackable)])
    restore_ops = []
    tensor_saveables = {}
    python_positions = []
    registered_savers = collections.defaultdict(dict)
    while visit_queue:
      current_position, _ = visit_queue.popleft()

      # Restore using the ops defined in a Saveable or registered function.
      (new_restore_ops, new_tensor_saveables, new_python_positions,
       new_registered_savers) = current_position._single_restore()  # pylint: disable=protected-access
      restore_ops.extend(new_restore_ops)
      tensor_saveables.update(new_tensor_saveables)
      python_positions.extend(new_python_positions)
      for saver_name, trackable_map in new_registered_savers.items():
        registered_savers[saver_name].update(trackable_map)

      # Pass the restoration to the dependencies.
      _queue_children_for_restoration(current_position, visit_queue)
      _queue_slot_variables(current_position, visit_queue)

    restore_ops.extend(
        current_position.checkpoint.restore_saveables(
            tensor_saveables,
            python_positions,
            registered_savers,
            reader=reader))
    return restore_ops

  def _single_restore(self):
    """Restores the trackable."""
    trackable = self.trackable
    trackable._maybe_initialize_trackable()  # pylint: disable=protected-access
    checkpoint = self.checkpoint
    # If the UID of this restore is lower than our current update UID, we don't
    # need to actually restore the object.
    if checkpoint.restore_uid > trackable._update_uid:  # pylint: disable=protected-access
      restore_ops, tensor_saveables, python_positions, registered_savers = (
          self.gather_ops_or_named_saveables())
      trackable._update_uid = checkpoint.restore_uid  # pylint: disable=protected-access
    else:
      restore_ops = ()
      tensor_saveables = {}
      python_positions = ()
      registered_savers = {}
    return restore_ops, tensor_saveables, python_positions, registered_savers


def restore_nodes(save_path, nodes_to_restore):
  """Restores nodes from a dict.

  Requires that the `Trackable` Python object has been bound to an object
  ID in the checkpoint.

  Args:
    save_path: a string representing the path to the checkpoint.
    nodes_to_restore: a dict mapping `node_id` to the `trackable` to be
      restored.
  """
  if save_path is None:
    raise ValueError("save_path cannot be None.")
  if not isinstance(nodes_to_restore, dict):
    raise ValueError(
        "Expecting a dictionary of node_id to Trackable for nodes_to_restore.")

  ckpt_view = checkpoint_view.CheckpointView(save_path)
  ckpt_view_descendants = ckpt_view.descendants()
  for node_id, trackable in nodes_to_restore.items():
    # node_id does not have a corresponding Checkpoint value.
    if (node_id not in ckpt_view_descendants or
        ckpt_view._object_graph_proto.nodes[  # pylint: disable=protected-access
            node_id] is None):
      raise ValueError(
          f"The expected node_id: {node_id} to Trackable {trackable} to "
          "restore does not exist in the checkpoint.")
    # Trackable mapped to node_id to restore is empty.
    if trackable is None or not isinstance(trackable, base.Trackable):
      raise ValueError(
          f"Expecting a valid Trackable to node_id: {node_id} but got "
          f"trackable: {trackable}.")

  serialized_tensors = object_identity.ObjectIdentityDictionary()
  for node_id, current_trackable in nodes_to_restore.items():
    ckpt_contains_serialized_tensors = ckpt_view._object_graph_proto.nodes[  # pylint: disable=protected-access
        node_id].attributes
    node = ckpt_view._object_graph_proto.nodes[node_id]  # pylint: disable=protected-access
    trackable_has_serialize_to_tensor = (
        saveable_object_util.trackable_has_serialize_to_tensor(
            current_trackable))
    if not trackable_has_serialize_to_tensor:
      if not node.attributes:
        if saveable_object_util.saveable_objects_from_trackable(
            current_trackable):
          raise ValueError(
              f"Trackable {current_trackable} expects checkpointed values but "
              "checkpoint does not contain serialized tensors for node_id: "
              f"{node_id}.")
        else:
          continue
      object_names = object_identity.ObjectIdentityDictionary()
      object_names[current_trackable] = trackable_utils.extract_object_name(
          node.attributes[0].checkpoint_key)
      checkpoint_factory_map, _ = save_util_v1.get_checkpoint_factories_and_keys(
          object_names, None)
      saveable_objects = save_util_v1.generate_saveable_objects(
          checkpoint_factory_map)[0]
      if len(node.attributes) != len(saveable_objects):
        raise ValueError("Size for saveable_objects for Trackable: "
                         f"{len(saveable_objects)} did not match the size for "
                         "serialized_tensors for checkpoint: "
                         f"{len(node.attributes)}.")
      current_trackable = saveable_object_util.SaveableCompatibilityConverter(
          current_trackable, saveable_objects)

    serialized_tensors[
        current_trackable] = current_trackable._serialize_to_tensors()  # pylint: disable=protected-access
    trackable_expects_ckpted_value = bool(serialized_tensors[current_trackable])

    if trackable_expects_ckpted_value and not ckpt_contains_serialized_tensors:
      raise ValueError(
          f"Trackable {current_trackable} expects checkpointed values but "
          "checkpoint does not contain serialized tensors for node_id: "
          f"{node_id}.")

    if not trackable_expects_ckpted_value and ckpt_contains_serialized_tensors:
      raise ValueError(
          f"Trackable {current_trackable} does not expect checkpointed "
          "values but checkpoint contains serialized tensors: "
          f"{ckpt_contains_serialized_tensors} for node_id: {node_id}.")

    if len(node.attributes) != len(serialized_tensors[current_trackable]):
      raise ValueError("Size for serialized_tensors for Trackable: "
                       f"{len(serialized_tensors[current_trackable])} did not "
                       "match size for serialized_tensors for checkpoint: "
                       f"{len(node.attributes)}.")

    if not trackable_has_serialize_to_tensor:
      functional_saver.MultiDeviceSaver(serialized_tensors).restore(save_path)
    else:
      # Converts attribute.name to attribute.checkpoint_key since that's what
      # the restore method is expecting, i.e., converts "a" to
      # "/.ATTRIBUTES/a".
      serialized_tensors_renamed = object_identity.ObjectIdentityDictionary()
      serialized_tensors_renamed[current_trackable] = {}
      for attribute in node.attributes:
        name = attribute.name
        checkpoint_key = attribute.checkpoint_key
        serialized_tensors_renamed[current_trackable][
            checkpoint_key] = serialized_tensors[current_trackable][name]
      functional_saver.MultiDeviceSaver(serialized_tensors_renamed).restore(
          save_path)
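
# Usage sketch for `restore_nodes` (illustrative; the node id below is
# hypothetical and would normally be discovered via
# `checkpoint_view.CheckpointView(path).descendants()`):
#
#   ckpt = tf.train.Checkpoint(v=tf.Variable(1.))
#   path = ckpt.write("/tmp/ckpt")
#   # Restore the checkpointed value of `v` directly into another variable,
#   # assuming `v` was serialized as node 1 in the object graph:
#   restore_nodes(path, {1: tf.Variable(0.)})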


def _queue_children_for_restoration(checkpoint_position, visit_queue):
  """Queues the restoration of trackable's children or defers them."""
  # pylint: disable=protected-access
  trackable = checkpoint_position.trackable
  trackable_children = trackable._trackable_children()
  for child in checkpoint_position.object_proto.children:
    # trackable._lookup_dependency can be expensive so first check if this node
    # already has an object correspondence. If so we skip this node.
    correspondence = checkpoint_position.checkpoint.object_by_proto_id.get(
        child.node_id, None)
    if correspondence is not None:
      continue
    child_position = checkpoint_position.create_child_position(child.node_id)
    local_object = trackable._lookup_dependency(child.local_name,
                                                trackable_children)
    child_proto = child_position.object_proto
    if local_object is None:
      # We don't yet have a dependency registered with this name. Save it
      # in case we do.
      if child_proto.HasField("has_checkpoint_values"):
        has_value = child_proto.has_checkpoint_values.value
      else:
        # If the field is not set, do a simple check to see if the dependency
        # has children and/or checkpointed values.
        has_value = bool(
            child_proto.children or child_proto.attributes or
            child_proto.slot_variables or
            child_proto.HasField("registered_saver"))
      if has_value:
        trackable._deferred_dependencies.setdefault(child.local_name,
                                                    []).append(child_position)
    else:
      if child_position.bind_object(trackable=local_object):
        # This object's correspondence is new, so dependencies need to be
        # visited. Delay doing it so that we get a breadth-first dependency
        # resolution order (shallowest paths first). The caller is responsible
        # for emptying visit_queue.
        visit_queue.append((child_position, local_object))


_DeferredSlotVariableRestoration = collections.namedtuple(
    "_DeferredSlotVariableRestoration", [
        "original_variable",
        "slot_variable_id",
        "slot_name",
    ])


def _queue_slot_variables(checkpoint_position, visit_queue):
  """Queues slot variables for restoration."""
  trackable = checkpoint_position.trackable
  checkpoint = checkpoint_position.checkpoint
  for deferred_slot_restoration in (checkpoint.deferred_slot_restorations.pop(
      checkpoint_position.proto_id, ())):
    slot_variable_position, slot_variable = (
        checkpoint_position.create_slot_variable_position(
            trackable, deferred_slot_restoration.original_variable,
            deferred_slot_restoration.slot_variable_id,
            deferred_slot_restoration.slot_name))
    if slot_variable_position is not None:
      visit_queue.append((slot_variable_position, slot_variable))
  for slot_restoration in checkpoint.slot_restorations.pop(
      checkpoint_position.proto_id, ()):
    optimizer_object = checkpoint.object_by_proto_id.get(
        slot_restoration.optimizer_id, None)
    if optimizer_object is None:
      # The optimizer has not yet been created or tracked. Record in the
      # checkpoint that the slot variables need to be restored when it is.
      checkpoint.deferred_slot_restorations.setdefault(
          slot_restoration.optimizer_id, []).append(
              _DeferredSlotVariableRestoration(
                  original_variable=trackable,
                  slot_variable_id=slot_restoration.slot_variable_id,
                  slot_name=slot_restoration.slot_name))

    # `optimizer_object` can be a `Checkpoint` when user only needs the
    # attributes the optimizer holds, such as `iterations`. In those cases,
    # it would not have the optimizer's `_create_or_restore_slot_variable`
    # method.
    elif hasattr(optimizer_object, "_create_or_restore_slot_variable"):
      slot_variable_position, slot_variable = (
          checkpoint_position.create_slot_variable_position(
              optimizer_object, trackable, slot_restoration.slot_variable_id,
              slot_restoration.slot_name))
      if slot_variable_position is not None:
        visit_queue.append((slot_variable_position, slot_variable))
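
# For illustration: if a variable is restored before its optimizer exists, the
# slot restoration is parked in `checkpoint.deferred_slot_restorations`, keyed
# by the optimizer's proto id, and replayed by the first loop above once the
# optimizer is created and bound to that id.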


def _extract_saveable_name(checkpoint_key):
  # Truncate the checkpoint key just past the "{...}.ATTRIBUTES/" substring.
  search_key = trackable_utils.OBJECT_ATTRIBUTES_NAME + "/"
  return checkpoint_key[:checkpoint_key.index(search_key) + len(search_key)]
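
# For example: given the checkpoint key
# "model/dense/.ATTRIBUTES/VARIABLE_VALUE", `_extract_saveable_name` returns
# "model/dense/.ATTRIBUTES/" (everything up to and including ".ATTRIBUTES/").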


def _get_saveable_from_factory(saveable_factories, serialized_tensor,
                               created_compat_names):
  """Returns the saveable generated from the factory method."""
  matched_factory = None

  # The `expected_factory_name` is used to find the right saveable factory,
  # while the `factory_input_name` is the value that is passed to the factory
  # method to instantiate the SaveableObject.
  expected_factory_name = serialized_tensor.name
  factory_input_name = serialized_tensor.checkpoint_key

  # Case 1: the name already exactly matches a key in saveable_factories.
  if expected_factory_name in saveable_factories:
    matched_factory = saveable_factories[expected_factory_name]

  # Case 2: (Forward compat) The serialized name is composed of
  # "factory_name" + "SUFFIX". Get the matching factory name.
  if matched_factory is None:
    for factory_name, factory in saveable_factories.items():
      if expected_factory_name.startswith(factory_name):
        if matched_factory is not None:
          # This condition is met in the extreme edge case where the object
          # returns two saveable factories with similar names. This is very
          # unlikely because there are zero objects inside TensorFlow that use
          # more than one saveable factory.
          raise ValueError("Forward compatibility load error: Unable to load "
                           "checkpoint saved in future version of TensorFlow. "
                           "Please update your version of TensorFlow to the "
                           "version in which the checkpoint was saved.")

        matched_factory = factory
        factory_input_name = _extract_saveable_name(
            serialized_tensor.checkpoint_key) + factory_name
        created_compat_names.add(factory_name)

  if callable(matched_factory):
    return matched_factory(name=factory_input_name)
  return matched_factory
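
# For illustration (names hypothetical): if a future checkpoint serialized an
# attribute as "table-keys" while the current object only registers a factory
# named "table", Case 2 matches on the "table" prefix, records "table" in
# `created_compat_names` so later "table-*" attributes are skipped, and calls
# the factory with the checkpoint-key prefix plus "table" as its `name`.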