3RNN/Lib/site-packages/tensorflow/python/checkpoint/util.py
2024-05-26 19:49:15 +02:00

184 lines
7.6 KiB
Python

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for extracting and writing checkpoint info`."""
from tensorflow.core.protobuf import trackable_object_graph_pb2
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.trackable import trackable_utils
from tensorflow.python.util import object_identity
def serialize_slot_variables(trackable_objects, node_ids, object_names):
  """Collects slot variables and registers each one as a new node.

  Every object in `trackable_objects` that lists `get_slot_names` in its
  `dir()` is queried, for each of its slot names, against every object that
  was in `trackable_objects` on entry. Each slot variable found is appended to
  `trackable_objects`, given the next free id in `node_ids`, and named in
  `object_names` via `trackable_utils.slot_variable_key`.

  Args:
    trackable_objects: List of trackable objects; extended in place with the
      slot variables that are found.
    node_ids: Mapping of object -> node id; updated in place.
    object_names: Mapping of object -> checkpoint path string; updated in
      place.

  Returns:
    An `ObjectIdentityDictionary` mapping each slot-owning object to a list of
    `SlotVariableReference` protos.

  Raises:
    NotImplementedError: If a slot variable reports trackable children, or if
      it is already present in `node_ids`.
  """
  # Snapshot the inputs: `trackable_objects` grows while we iterate.
  candidates = list(trackable_objects)
  references_by_owner = object_identity.ObjectIdentityDictionary()
  for owner in candidates:
    # TODO(b/110718070): Fix Keras imports.
    # Note: dir() is used rather than hasattr() here to avoid triggering
    # custom __getattr__ code, see b/152031870 for context.
    if "get_slot_names" not in dir(owner):
      continue
    for slot_name in owner.get_slot_names():
      for variable_node_id, variable in enumerate(candidates):
        try:
          slot = owner.get_slot(variable, slot_name)
        except (AttributeError, KeyError):
          # No slot of this name exists for this candidate.
          continue
        if slot is None:
          continue
        slot._maybe_initialize_trackable()  # pylint: disable=protected-access
        if slot._trackable_children():  # pylint: disable=protected-access
          # TODO(allenl): Gather dependencies of slot variables.
          raise NotImplementedError(
              "Currently only variables with no dependencies can be saved as "
              "slot variables. File a feature request if this limitation "
              "bothers you.")
        if slot in node_ids:
          raise NotImplementedError(
              "A slot variable was re-used as a dependency of a Trackable "
              f"object: {slot}. This is not currently allowed. "
              "File a feature request if this limitation bothers you.")
        object_names[slot] = trackable_utils.slot_variable_key(
            variable_path=object_names[variable],
            optimizer_path=object_names[owner],
            slot_name=slot_name)
        # The slot variable becomes the next node at the end of the list.
        slot_node_id = len(trackable_objects)
        node_ids[slot] = slot_node_id
        trackable_objects.append(slot)
        reference = (
            trackable_object_graph_pb2.TrackableObjectGraph.TrackableObject
            .SlotVariableReference(
                slot_name=slot_name,
                original_variable_node_id=variable_node_id,
                slot_variable_node_id=slot_node_id))
        references_by_owner.setdefault(owner, []).append(reference)
  return references_by_owner
def get_mapped_trackable(trackable, object_map):
  """Looks up `trackable` in `object_map`, falling back to `trackable` itself.

  Args:
    trackable: The object to look up.
    object_map: A mapping queried with `.get`, or None.

  Returns:
    The value mapped to `trackable`; or `trackable` unchanged when
    `object_map` is None or has no entry for it.
  """
  if object_map is not None:
    return object_map.get(trackable, trackable)
  return trackable
def get_full_name(var):
  """Gets the full name of variable for name-based checkpoint compatibility.

  Args:
    var: The object to name.

  Returns:
    An empty string when `var` is neither a `variables.Variable` nor a
    resource variable. Otherwise `var._save_slice_info.full_name` when slice
    info is set, else `var._shared_name`.
  """
  # pylint: disable=protected-access
  # Some objects do not subclass Variable but still act as one.
  is_variable_like = (
      isinstance(var, variables.Variable)
      or resource_variable_ops.is_resource_variable(var))
  if not is_variable_like:
    return ""
  # Use getattr because `var._save_slice_info` may be set as `None`.
  if getattr(var, "_save_slice_info", None) is None:
    return var._shared_name
  return var._save_slice_info.full_name
  # pylint: enable=protected-access
def add_checkpoint_values_check(object_graph_proto):
  """Marks which nodes have checkpoint values and saves this to the proto.

  A node has checkpoint values if it has attributes, slot variables or a
  registered saver itself, or if any node reachable through its children does.
  The result is written to `has_checkpoint_values.value` on every node.

  Args:
    object_graph_proto: A `TrackableObjectGraph` proto; modified in place.
  """
  # Node id -> set of node ids that list it as a child (its "parents").
  # If a node has checkpoint values, then all of its parents can be marked as
  # having checkpoint values.
  parents = {}

  # First pass: build dictionary of parent nodes and the initial set of nodes
  # that directly own checkpoint values.
  checkpointed = set()
  for node_id, object_proto in enumerate(object_graph_proto.nodes):
    if (object_proto.attributes or object_proto.slot_variables or
        object_proto.HasField("registered_saver")):
      checkpointed.add(node_id)
    for child_proto in object_proto.children:
      parents.setdefault(child_proto.node_id, set()).add(node_id)

  # Second pass: propagate the flag upwards to all connected parents.
  to_visit = set(checkpointed)
  while to_visit:
    node_id = to_visit.pop()
    # Popping the entry ensures each node's parents are expanded at most once,
    # which also makes the loop terminate on cyclic graphs.
    current_parents = parents.pop(node_id, None)
    if current_parents is None:
      # Some nodes may not have parents (e.g. slot variables).
      continue
    checkpointed.update(current_parents)
    # Only parents that still have unexpanded parents need another visit.
    to_visit.update(
        parent for parent in current_parents if parent in parents)

  for node_id, object_proto in enumerate(object_graph_proto.nodes):
    object_proto.has_checkpoint_values.value = node_id in checkpointed
def objects_ids_and_slot_variables_and_paths(graph_view,
                                             skip_slot_variables=False):
  """Traverses the object graph and lists all accessible objects.

  The objects and their paths come from `graph_view.breadth_first_traversal()`.
  Unless `skip_slot_variables` is True, slot variables found by
  `serialize_slot_variables` are appended to the object list and recorded in
  the node id and name mappings as well.

  Args:
    graph_view: A GraphView object.
    skip_slot_variables: If True does not return trackables for slot variable.
      Default False.

  Returns:
    A tuple of (trackable objects, paths from root for each object,
    object -> node id, slot variables, object_names)
  """
  trackable_objects, node_paths = graph_view.breadth_first_traversal()

  # Object -> checkpoint path string.
  object_names = object_identity.ObjectIdentityDictionary()
  for trackable, obj_path in node_paths.items():
    object_names[trackable] = trackable_utils.object_path_to_string(obj_path)

  # Object -> position in the traversal order.
  node_ids = object_identity.ObjectIdentityDictionary()
  for position, trackable in enumerate(trackable_objects):
    node_ids[trackable] = position

  if not skip_slot_variables:
    # Extends `trackable_objects`, `node_ids` and `object_names` in place.
    slot_variables = serialize_slot_variables(
        trackable_objects=trackable_objects,
        node_ids=node_ids,
        object_names=object_names,
    )
  else:
    slot_variables = object_identity.ObjectIdentityDictionary()
  return (trackable_objects, node_paths, node_ids, slot_variables, object_names)
def list_objects(graph_view, skip_slot_variables=False):
  """Traverses the object graph and returns the list of accessible objects.

  This is the first element of the tuple produced by
  `objects_ids_and_slot_variables_and_paths`.

  Args:
    graph_view: A GraphView object.
    skip_slot_variables: Forwarded to
      `objects_ids_and_slot_variables_and_paths`. Default False.

  Returns:
    The trackable objects.
  """
  trackable_objects, *_ = objects_ids_and_slot_variables_and_paths(
      graph_view, skip_slot_variables)
  return trackable_objects