# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logic for restoring checkpointed values for Trackables."""
import collections
from tensorflow.python.checkpoint import checkpoint_view
from tensorflow.python.checkpoint import functional_saver
from tensorflow.python.checkpoint import save_util_v1
from tensorflow.python.checkpoint import saveable_compat
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import registration
from tensorflow.python.trackable import base
from tensorflow.python.trackable import constants
from tensorflow.python.trackable import python_state
from tensorflow.python.trackable import trackable_utils
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import object_identity
class CheckpointPosition(object):
"""Indicates a position within a `_CheckpointRestoreCoordinator`."""
__slots__ = ["_checkpoint", "_proto_id", "skip_restore"]
def __init__(self, checkpoint, proto_id):
"""Specify an object within a checkpoint.
Args:
checkpoint: A _CheckpointRestoreCoordinator object.
proto_id: The index of this object in TrackableObjectGraph.nodes.
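
    Example (a minimal sketch; `coordinator` stands in for the
    `_CheckpointRestoreCoordinator` that the checkpoint-loading machinery
    creates internally):

      position = CheckpointPosition(checkpoint=coordinator, proto_id=0)
      position.restore(my_model)  # Binds proto node 0 and restores values.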
"""
self._checkpoint = checkpoint
self._proto_id = proto_id
# This may be set to True if the registered saver cannot be used with this
# object.
self.skip_restore = False
def restore(self, trackable, reader=None):
"""Restore this value into `trackable`."""
with ops.init_scope():
if self.bind_object(trackable):
# This object's correspondence with a checkpointed object is new, so
# process deferred restorations for it and its dependencies.
restore_ops = self._restore_descendants(reader)
if restore_ops:
self._checkpoint.new_restore_ops(restore_ops)
def bind_object(self, trackable):
"""Set a checkpoint<->object correspondence.
Args:
trackable: The object to record a correspondence for.
Returns:
True if this is a new assignment, False if this object has already been
mapped to a checkpointed `Object` proto.
Raises:
AssertionError: If another object is already bound to the `Object` proto.
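
    Example (illustrative; `pos_a` and `pos_b` are assumed to be two
    `CheckpointPosition`s that reference the same proto node):

      pos_a.bind_object(layer)  # True: first binding for this proto node.
      pos_b.bind_object(layer)  # False: node already mapped to `layer`.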
"""
checkpoint = self.checkpoint
checkpoint.all_python_objects.add(trackable)
current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None)
checkpoint.matched_proto_ids.add(self._proto_id)
if current_assignment is None:
checkpoint.object_by_proto_id[self._proto_id] = trackable
return True # New assignment
else:
# The object was already mapped for this checkpoint load, which means
# we don't need to do anything besides check that the mapping is
# consistent (if the dependency DAG is not a tree then there are
# multiple paths to the same object).
if current_assignment is not trackable:
logging.warning(
"Inconsistent references when loading the checkpoint into this "
"object graph. For example, in the saved checkpoint object, "
"`model.layer.weight` and `model.layer_copy.weight` reference the "
"same variable, while in the current object these are two different"
" variables. The referenced variables are:"
f"({current_assignment} and {trackable}).")
return False # Not a new assignment
def is_simple_variable(self):
"""Determine whether this value is restorable with a Tensor initializer."""
attributes = self.object_proto.attributes
return (len(attributes) == 1 and
attributes[0].name == constants.VARIABLE_VALUE_KEY and
not self.object_proto.children)
def value_tensors(self, shape_and_slices=None):
"""Create value `Tensor`s for this object's attributes.
Does not require that the Python object has been created. Used for
restore-on-create when executing eagerly.
Args:
shape_and_slices: A dict mapping from object attribute names to a shape
and slice string that will be passed to a RestoreV2 op. If the dict is
None or if an object attribute is not in the dict, the full tensor will
be restored.
Returns:
A dictionary mapping from object attribute names to `Tensor`s.
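
    Example (a sketch; assumes this position refers to a variable node whose
    single attribute is named "VARIABLE_VALUE", and uses the RestoreV2
    shape-and-slice spec format "full_shape slice_spec"):

      tensors = position.value_tensors(
          shape_and_slices={"VARIABLE_VALUE": "10 10 0,5:-"})
      first_five_rows = tensors["VARIABLE_VALUE"]  # Rows 0..4 of (10, 10).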
"""
value_tensors = {}
for serialized_tensor in self.object_proto.attributes:
checkpoint_key = serialized_tensor.checkpoint_key
dtype = self._checkpoint.dtype_map[checkpoint_key]
base_type = dtype.base_dtype
io_device = self._checkpoint.options.experimental_io_device or "cpu:0"
with ops.init_scope():
with ops.device(io_device):
          # Run the restore itself on the io_device (CPU, or the device
          # specified in the checkpoint options).
if (shape_and_slices is not None and
serialized_tensor.name in shape_and_slices):
shape_and_slice = shape_and_slices[serialized_tensor.name]
else:
shape_and_slice = ""
value, = io_ops.restore_v2(
prefix=self._checkpoint.save_path_tensor,
tensor_names=[checkpoint_key],
shape_and_slices=[shape_and_slice],
dtypes=[base_type],
name="%s_checkpoint_read" % (serialized_tensor.name,))
# Copy the value to the current device if necessary.
value_tensors[serialized_tensor.name] = array_ops.identity(value)
return value_tensors
def gather_ops_or_named_saveables(self):
"""Looks up or creates SaveableObjects which don't have cached ops.
Returns:
A tuple of (
existing_restore_ops: list,
named_saveables: dict,
python_positions: list,
registered_savers: dict)
"""
recorded_registered_saver = self.get_registered_saver_name()
if not (self.object_proto.attributes or recorded_registered_saver):
return [], {}, [], {}
existing_restore_ops = []
named_saveables = {}
python_positions = []
registered_savers = collections.defaultdict(dict)
saveable_factories = saveable_object_util.saveable_objects_from_trackable(
self.trackable)
saver_name = registration.get_registered_saver_name(self.trackable)
if recorded_registered_saver:
if not self.skip_restore:
name = self.object_proto.registered_saver.object_name
registered_savers[recorded_registered_saver][name] = self.trackable
# Else: Skip restoration of this Trackable. This skip only happens if the
# registered saver has enabled `option_restore`. Otherwise, an error would
# have been raised at `self.get_registered_saver_name()`.
elif saver_name:
# In this case, the checkpoint has a recorded serialized tensor but no
# registered saver, while the Trackable loading the checkpoint has
# migrated to the registered checkpoint functionality (TPUEmbedding is an
# example of this).
      # Set the Trackable's object name to the first checkpoint key that is
      # stored in the checkpoint. If there is a use case that requires the
      # other keys, then we can take another look at this.
registered_savers[saver_name] = {
self.object_proto.attributes[0].checkpoint_key: self.trackable
}
elif isinstance(self.trackable, python_state.PythonState):
python_positions.append(self)
elif saveable_factories.keys() == {
trackable_utils.SERIALIZE_TO_TENSORS_NAME
}:
existing_restore_ops, named_saveables = (
self._create_serialize_to_tensor_saveable(saveable_factories))
elif saveable_factories:
existing_restore_ops, named_saveables = (
self._create_saveables_by_attribute_name(saveable_factories))
else:
      # Neither saveable factories nor a registered saver were found for this
      # Trackable, which means that one or more serialized tensors will never
      # be used.
for serialized_tensor in self.object_proto.attributes:
self._checkpoint.unused_attributes.setdefault(
self._proto_id, []).append(serialized_tensor.name)
return (existing_restore_ops, named_saveables, python_positions,
registered_savers)
def _create_serialize_to_tensor_saveable(self, saveable_factories):
"""Creates a saveable using the _serialize_to_tensor method."""
# Extract the saveable name from the checkpoint key. This will be used as
# the cache key or the name to pass to the saveable factory.
suffix = saveable_compat.get_saveable_name(self.trackable) or ""
saveable_name = _extract_saveable_name(
self.object_proto.attributes[0].checkpoint_key) + suffix
# Try to find the cached saveable (only in graph mode).
if not context.executing_eagerly():
existing_op = self._checkpoint.restore_ops_by_name.get(
saveable_name, None)
if existing_op is not None:
return [existing_op], {}
saveables_cache = self._checkpoint.saveables_cache.setdefault(
self.trackable, {})
if saveable_name in saveables_cache:
return [], {saveable_name: saveables_cache[saveable_name]}
saveable = saveable_factories[trackable_utils.SERIALIZE_TO_TENSORS_NAME](
name=saveable_name)
if not context.executing_eagerly():
saveables_cache[saveable_name] = saveable
return [], {saveable_name: saveable}
def _create_saveables_by_attribute_name(self, saveable_factories):
"""Creates or caches SaveableObjects by matching the attribute names.
    The attribute name keys in `saveable_factories` are used to find the
    corresponding attributes in the object proto. Attributes contain
    checkpoint keys, which are passed to the factory function to generate the
    SaveableObject.
Args:
saveable_factories: a dict mapping attribute name to a callable factory
function that produces a SaveableObject.
Returns:
A tuple of (
existing_restore_ops: list,
named_saveables: dict)
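
    Example (hypothetical): for a `tf.Variable`, `saveable_factories` maps
    the attribute name "VARIABLE_VALUE" to a factory callable; the matching
    attribute's checkpoint key (e.g. "model/v/.ATTRIBUTES/VARIABLE_VALUE")
    is passed to that factory to build the SaveableObject.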
"""
# Name saveables based on the name this object had when it was checkpointed.
named_saveables = {}
existing_restore_ops = []
# Forward compatibility code: when loading a future checkpoint, there may
# be multiple SerializedTensors mapped to a single saveable.
created_compat_names = set()
for serialized_tensor in self.object_proto.attributes:
if context.executing_eagerly():
existing_op = None
else:
existing_op = self._checkpoint.restore_ops_by_name.get(
serialized_tensor.checkpoint_key, None)
if existing_op is not None:
existing_restore_ops.append(existing_op)
continue
if any(serialized_tensor.name.startswith(name)
for name in created_compat_names):
continue # Saveable has already been created for this tensor.
      # If we don't have cached ops for this SaveableObject, check whether the
      # SaveableObject itself has been cached. If not, create it; either way,
      # extract new ops from it (or, if it has Python state to restore, run
      # that).
saveables_cache = self._checkpoint.saveables_cache
if saveables_cache is None:
# No SaveableObject caching when executing eagerly.
saveable = None
else:
# If we've already created and cached a SaveableObject for this
# attribute, we can re-use it to avoid re-creating some ops when graph
# building.
saveable_list = saveables_cache.get(self.trackable,
{}).get(serialized_tensor.name,
(None,))
if len(saveable_list) == 1:
# Almost every attribute will have exactly one SaveableObject.
saveable, = saveable_list
else:
# Don't use cached SaveableObjects for partitioned variables, which is
# the only case where we'd have a list of SaveableObjects. Op caching
# will catch them.
saveable = None
if saveable is not None:
# The name of this attribute has changed, so we need to re-generate
# the SaveableObject.
if serialized_tensor.checkpoint_key not in saveable.name:
saveable = None
del saveables_cache[self.trackable]
if saveable is None:
# If there was no cached SaveableObject, create one.
# Use the name to check if the Python object has the same attribute.
saveable = _get_saveable_from_factory(saveable_factories,
serialized_tensor,
created_compat_names)
if saveable is None:
# Purposefully does not throw an exception if attributes have been
# added or deleted. Stores unused attributes so an exception can be
# raised if the user decides to check that everything in the
# checkpoint was loaded.
self._checkpoint.unused_attributes.setdefault(
self._proto_id, []).append(serialized_tensor.name)
continue
if saveables_cache is not None:
saveables_cache.setdefault(self.trackable,
{})[serialized_tensor.name] = [saveable]
named_saveables[serialized_tensor.checkpoint_key] = saveable
return existing_restore_ops, named_saveables
def restore_ops(self, reader=None):
"""Create or fetch restore ops for this object's attributes.
Requires that the `Trackable` Python object has been bound to an object
ID in the checkpoint.
Args:
reader: A `CheckpointReader`. If None, a new instance will be created.
Returns:
A list of operations when graph building, or an empty list when executing
eagerly.
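
    Example (a sketch; meaningful when graph building, since eager restores
    run immediately and return no ops):

      ops_to_run = position.restore_ops()  # Feed to `session.run(...)`.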
"""
if self._has_registered_saver():
raise ValueError("Unable to run individual checkpoint restore for objects"
" with registered savers.")
(restore_ops, tensor_saveables, python_positions,
_) = self.gather_ops_or_named_saveables()
restore_ops.extend(
self._checkpoint.restore_saveables(
tensor_saveables, python_positions, reader=reader))
return restore_ops
@property
def checkpoint(self):
return self._checkpoint
@property
def trackable(self):
return self._checkpoint.object_by_proto_id[self._proto_id]
@property
def object_proto(self):
return self._checkpoint.object_graph_proto.nodes[self._proto_id]
@property
def proto_id(self):
return self._proto_id
@property
def restore_uid(self):
return self._checkpoint.restore_uid
def __repr__(self):
return repr(self.object_proto)
def value_shape(self):
"""The shape of the VARIABLE_VALUE tensor.
Returns:
If found a TensorShape object, otherwise None.
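
      For example, a checkpointed variable saved with shape (3, 5) yields
      `TensorShape([3, 5])`.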
"""
for serialized_tensor in self.object_proto.attributes:
if serialized_tensor.name == constants.VARIABLE_VALUE_KEY:
return self._checkpoint.shape_map[serialized_tensor.checkpoint_key]
return None
def _has_registered_saver(self):
return bool(self.object_proto.registered_saver.name)
def get_registered_saver_name(self):
"""Returns the registered saver name defined in the Checkpoint."""
if self._has_registered_saver():
saver_name = self.object_proto.registered_saver.name
try:
registration.validate_restore_function(self.trackable, saver_name)
except ValueError as e:
if registration.get_strict_predicate_restore(saver_name):
raise e
self.skip_restore = True
return saver_name
return None
def create_slot_variable_position(self, optimizer_object, variable,
slot_variable_id, slot_name):
"""Generates CheckpointPosition for a slot variable.
Args:
optimizer_object: Optimizer that owns the slot variable.
variable: Variable associated with the slot variable.
slot_variable_id: ID of the slot variable.
slot_name: Name of the slot variable.
Returns:
If there is a slot variable in the `optimizer_object` that has not been
bound to the checkpoint, this function returns a tuple of (
new `CheckpointPosition` for the slot variable,
the slot variable itself).
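
    Example (a sketch; ids and names are hypothetical):

      slot_pos, slot_var = position.create_slot_variable_position(
          optimizer_object=adam, variable=v, slot_variable_id=7,
          slot_name="m")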
"""
slot_variable_position = CheckpointPosition(
checkpoint=self.checkpoint, proto_id=slot_variable_id)
# pylint: disable=protected-access
slot_variable = optimizer_object._create_or_restore_slot_variable(
slot_variable_position=slot_variable_position,
variable=variable,
slot_name=slot_name)
# pylint: enable=protected-access
if (slot_variable is not None and
slot_variable_position.bind_object(slot_variable)):
return slot_variable_position, slot_variable
else:
return None, None
def create_child_position(self, node_id):
return CheckpointPosition(checkpoint=self.checkpoint, proto_id=node_id)
def _restore_descendants(self, reader=None):
"""Restore the bound Trackable and dependencies (may be deferred)."""
# Attempt a breadth-first traversal, since presumably the user has more
# control over shorter paths. If we don't have all of the dependencies at
# this point, the end result is not breadth-first (since other deferred
# traversals will happen later).
    # You may be wondering why elements in the `visit_queue` are tuples that
    # contain both a CheckpointPosition and its Trackable. The reason is that
    # Optimizers will not keep a strong reference to slot variables for
    # ShardedVariables. The slot variable must be kept in memory until the
    # restore saveables have been created.
visit_queue = collections.deque([(self, self.trackable)])
restore_ops = []
tensor_saveables = {}
python_positions = []
registered_savers = collections.defaultdict(dict)
while visit_queue:
current_position, _ = visit_queue.popleft()
# Restore using the ops defined in a Saveable or registered function.
(new_restore_ops, new_tensor_saveables, new_python_positions,
new_registered_savers) = current_position._single_restore() # pylint: disable=protected-access
restore_ops.extend(new_restore_ops)
tensor_saveables.update(new_tensor_saveables)
python_positions.extend(new_python_positions)
for saver_name, trackable_map in new_registered_savers.items():
registered_savers[saver_name].update(trackable_map)
# Pass the restoration to the dependencies.
_queue_children_for_restoration(current_position, visit_queue)
_queue_slot_variables(current_position, visit_queue)
restore_ops.extend(
current_position.checkpoint.restore_saveables(
tensor_saveables,
python_positions,
registered_savers,
reader=reader))
return restore_ops
def _single_restore(self):
"""Restores the trackable."""
trackable = self.trackable
trackable._maybe_initialize_trackable() # pylint: disable=protected-access
checkpoint = self.checkpoint
# If the UID of this restore is lower than our current update UID, we don't
# need to actually restore the object.
if checkpoint.restore_uid > trackable._update_uid: # pylint: disable=protected-access
restore_ops, tensor_saveables, python_positions, registered_savers = (
self.gather_ops_or_named_saveables())
trackable._update_uid = checkpoint.restore_uid # pylint: disable=protected-access
else:
restore_ops = ()
tensor_saveables = {}
python_positions = ()
registered_savers = {}
return restore_ops, tensor_saveables, python_positions, registered_savers
def restore_nodes(save_path, nodes_to_restore):
"""Restores nodes from a dict.
Requires that the `Trackable` Python object has been bound to an object
ID in the checkpoint.
Args:
    save_path: a string that represents the path to the checkpoint.
    nodes_to_restore: a dict mapping `node_id` to the `trackable` to be
      restored.
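
  Example (a minimal sketch; node ids depend on the saved object graph, and
  node 0 is always the root object):

    # Restore the object saved at node_id 1 directly into `my_variable`.
    restore_nodes("/tmp/ckpt/ckpt-1", {1: my_variable})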
"""
  if save_path is None:
    raise ValueError("save_path cannot be None.")
if not isinstance(nodes_to_restore, dict):
raise ValueError(
"Expecting a dictionary of node_id to Trackable for nodes_to_restore.")
ckpt_view = checkpoint_view.CheckpointView(save_path)
ckpt_view_descendants = ckpt_view.descendants()
for node_id, trackable in nodes_to_restore.items():
# node_id does not have a corresponding Checkpoint value.
if (node_id not in ckpt_view_descendants or
ckpt_view._object_graph_proto.nodes[ # pylint: disable=protected-access
node_id] is None):
raise ValueError(
f"The expected node_id: {node_id} to Trackable {trackable} to "
"restore does not exist in the checkpoint.")
    # The Trackable mapped to this node_id is missing or not a Trackable.
if trackable is None or not isinstance(trackable, base.Trackable):
raise ValueError(
f"Expecting a valid Trackable to node_id: {node_id} but got "
f"trackable: {trackable}."
)
serialized_tensors = object_identity.ObjectIdentityDictionary()
for node_id, current_trackable in nodes_to_restore.items():
ckpt_contains_serialized_tensors = ckpt_view._object_graph_proto.nodes[ # pylint: disable=protected-access
node_id].attributes
node = ckpt_view._object_graph_proto.nodes[node_id] # pylint: disable=protected-access
trackable_has_serialize_to_tensor = saveable_object_util.trackable_has_serialize_to_tensor(
current_trackable)
if not trackable_has_serialize_to_tensor:
if not node.attributes:
if saveable_object_util.saveable_objects_from_trackable(
current_trackable):
raise ValueError(
f"Trackable {current_trackable} expects checkpointed values but "
"checkpoint does not contain serialized tensors for node_id: "
f"{node_id}.")
else:
continue
object_names = object_identity.ObjectIdentityDictionary()
object_names[current_trackable] = trackable_utils.extract_object_name(
node.attributes[0].checkpoint_key)
checkpoint_factory_map, _ = save_util_v1.get_checkpoint_factories_and_keys(
object_names, None)
saveable_objects = save_util_v1.generate_saveable_objects(
checkpoint_factory_map)[0]
if len(node.attributes) != len(saveable_objects):
raise ValueError("Size for saveable_objects for Trackable: "
f"{len(saveable_objects)} did not match the size for "
"serialized_tensors for checkpoint: "
f"{len(node.attributes)}.")
current_trackable = saveable_object_util.SaveableCompatibilityConverter(
current_trackable, saveable_objects)
serialized_tensors[
current_trackable] = current_trackable._serialize_to_tensors() # pylint: disable=protected-access
trackable_expects_ckpted_value = bool(serialized_tensors[current_trackable])
if trackable_expects_ckpted_value and not ckpt_contains_serialized_tensors:
raise ValueError(
f"Trackable {current_trackable} expects checkpointed values but "
"checkpoint does not contain serialized tensors for node_id: "
f"{node_id}.")
if not trackable_expects_ckpted_value and ckpt_contains_serialized_tensors:
raise ValueError(
f"Trackable {current_trackable} does not expect checkpointed "
"values but checkpoint contains serialized tensors: "
f"{ckpt_contains_serialized_tensors} for node_id: {node_id}.")
if len(node.attributes) != len(serialized_tensors[current_trackable]):
raise ValueError("Size for serialized_tensors for Trackable: "
f"{len(serialized_tensors[current_trackable])} did not "
"match size for serialized_tensors for checkpoint: "
f"{len(node.attributes)}.")
if not trackable_has_serialize_to_tensor:
functional_saver.MultiDeviceSaver(serialized_tensors).restore(save_path)
else:
      # Converts attribute.name to attribute.checkpoint_key, since that's what
      # the restore method expects; e.g., "a" becomes "/.ATTRIBUTES/a".
serialized_tensors_renamed = object_identity.ObjectIdentityDictionary()
serialized_tensors_renamed[current_trackable] = {}
for attribute in node.attributes:
name = attribute.name
checkpoint_key = attribute.checkpoint_key
serialized_tensors_renamed[current_trackable][
checkpoint_key] = serialized_tensors[current_trackable][name]
functional_saver.MultiDeviceSaver(serialized_tensors_renamed).restore(
save_path)
def _queue_children_for_restoration(checkpoint_position, visit_queue):
"""Queues the restoration of trackable's children or defers them."""
# pylint: disable=protected-access
trackable = checkpoint_position.trackable
trackable_children = trackable._trackable_children()
for child in checkpoint_position.object_proto.children:
# trackable._lookup_dependency can be expensive so first check if this node
# already has an object correspondence. If so we skip this node.
correspondence = checkpoint_position.checkpoint.object_by_proto_id.get(
child.node_id, None
)
if correspondence is not None:
continue
child_position = checkpoint_position.create_child_position(child.node_id)
local_object = trackable._lookup_dependency(child.local_name,
trackable_children)
child_proto = child_position.object_proto
if local_object is None:
# We don't yet have a dependency registered with this name. Save it
# in case we do.
if child_proto.HasField("has_checkpoint_values"):
has_value = child_proto.has_checkpoint_values.value
else:
# If the field is not set, do a simple check to see if the dependency
# has children and/or checkpointed values.
has_value = bool(
child_proto.children or child_proto.attributes or
child_proto.slot_variables or
child_proto.HasField("registered_saver"))
if has_value:
trackable._deferred_dependencies.setdefault(child.local_name,
[]).append(child_position)
else:
if child_position.bind_object(trackable=local_object):
# This object's correspondence is new, so dependencies need to be
# visited. Delay doing it so that we get a breadth-first dependency
# resolution order (shallowest paths first). The caller is responsible
# for emptying visit_queue.
visit_queue.append((child_position, local_object))
_DeferredSlotVariableRestoration = collections.namedtuple(
"_DeferredSlotVariableRestoration", [
"original_variable",
"slot_variable_id",
"slot_name",
])
def _queue_slot_variables(checkpoint_position, visit_queue):
"""Queues slot variables for restoration."""
trackable = checkpoint_position.trackable
checkpoint = checkpoint_position.checkpoint
for deferred_slot_restoration in (checkpoint.deferred_slot_restorations.pop(
checkpoint_position.proto_id, ())):
slot_variable_position, slot_variable = (
checkpoint_position.create_slot_variable_position(
trackable, deferred_slot_restoration.original_variable,
deferred_slot_restoration.slot_variable_id,
deferred_slot_restoration.slot_name))
if slot_variable_position is not None:
visit_queue.append((slot_variable_position, slot_variable))
for slot_restoration in checkpoint.slot_restorations.pop(
checkpoint_position.proto_id, ()):
optimizer_object = checkpoint.object_by_proto_id.get(
slot_restoration.optimizer_id, None)
if optimizer_object is None:
# The optimizer has not yet been created or tracked. Record in the
# checkpoint that the slot variables need to be restored when it is.
checkpoint.deferred_slot_restorations.setdefault(
slot_restoration.optimizer_id, []).append(
_DeferredSlotVariableRestoration(
original_variable=trackable,
slot_variable_id=slot_restoration.slot_variable_id,
slot_name=slot_restoration.slot_name))
    # `optimizer_object` can be a `Checkpoint` when the user only needs the
    # attributes the optimizer holds, such as `iterations`. In those cases,
    # it would not have the optimizer's `_create_or_restore_slot_variable`
    # method.
elif hasattr(optimizer_object, "_create_or_restore_slot_variable"):
slot_variable_position, slot_variable = (
checkpoint_position.create_slot_variable_position(
optimizer_object, trackable, slot_restoration.slot_variable_id,
slot_restoration.slot_name))
if slot_variable_position is not None:
visit_queue.append((slot_variable_position, slot_variable))
def _extract_saveable_name(checkpoint_key):
  # Return the prefix of the checkpoint key up to and including ".ATTRIBUTES/".
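  # For example, "model/v/.ATTRIBUTES/VARIABLE_VALUE" yields
  # "model/v/.ATTRIBUTES/".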
search_key = trackable_utils.OBJECT_ATTRIBUTES_NAME + "/"
return checkpoint_key[:checkpoint_key.index(search_key) + len(search_key)]
def _get_saveable_from_factory(saveable_factories, serialized_tensor,
created_compat_names):
"""Returns the saveable generated from the factory method."""
matched_factory = None
# The `expected_factory_name` is used to find the right saveable factory,
# while the `factory_input_name` is the value that is passed to the factory
# method to instantiate the SaveableObject.
expected_factory_name = serialized_tensor.name
factory_input_name = serialized_tensor.checkpoint_key
# Case 1: the name already exactly matches a key in saveable_factories.
if expected_factory_name in saveable_factories:
matched_factory = saveable_factories[expected_factory_name]
# Case 2: (Forward compat) The serialized name is composed of
# "factory_name" + "SUFFIX". Get the matching factory name.
if matched_factory is None:
for factory_name, factory in saveable_factories.items():
if expected_factory_name.startswith(factory_name):
if matched_factory is not None:
          # This condition is met in the extreme edge case where the object
          # returns two saveable factories with similar names. This is very
          # unlikely because there are zero objects inside TensorFlow that use
          # more than one saveable factory.
raise ValueError("Forward compatibility load error: Unable to load "
"checkpoint saved in future version of TensorFlow. "
"Please update your version of TensorFlow to the "
"version in which the checkpoint was saved.")
matched_factory = factory
factory_input_name = _extract_saveable_name(
serialized_tensor.checkpoint_key) + factory_name
created_compat_names.add(factory_name)
if callable(matched_factory):
return matched_factory(name=factory_input_name)
return matched_factory