3RNN/Lib/site-packages/tensorflow/python/checkpoint/checkpoint_view.py
2024-05-26 19:49:15 +02:00

302 lines
12 KiB
Python

"""Manages a Checkpoint View."""
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import collections
from tensorflow.core.protobuf import trackable_object_graph_pb2
from tensorflow.python.checkpoint import trackable_view
from tensorflow.python.framework import errors_impl
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import base
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export
@tf_export("train.CheckpointView", v1=[])
class CheckpointView(object):
"""Gathers and serializes a checkpoint view.
This is for loading specific portions of a module from a
checkpoint, and for comparing two modules by matching their components.
Example usage:
>>> class SimpleModule(tf.Module):
... def __init__(self, name=None):
... super().__init__(name=name)
... self.a_var = tf.Variable(5.0)
... self.b_var = tf.Variable(4.0)
... self.vars = [tf.Variable(1.0), tf.Variable(2.0)]
>>> root = SimpleModule(name="root")
>>> root.leaf = SimpleModule(name="leaf")
>>> ckpt = tf.train.Checkpoint(root)
>>> save_path = ckpt.save('/tmp/tf_ckpts')
>>> checkpoint_view = tf.train.CheckpointView(save_path)
Pass `node_id=0` to `tf.train.CheckpointView.children()` to get the dictionary
of all children directly linked to the checkpoint root.
>>> for name, node_id in checkpoint_view.children(0).items():
... print(f"- name: '{name}', node_id: {node_id}")
- name: 'a_var', node_id: 1
- name: 'b_var', node_id: 2
- name: 'vars', node_id: 3
- name: 'leaf', node_id: 4
- name: 'root', node_id: 0
- name: 'save_counter', node_id: 5
"""
def __init__(self, save_path):
  """Loads the object graph stored in a checkpoint.

  The serialized object graph is read from the checkpoint, parsed into a
  `TrackableObjectGraph` proto and kept on the instance for the other
  methods to query.

  Args:
    save_path: The path to the checkpoint.

  Raises:
    ValueError: If reading the object-graph key from the checkpoint raises
      `NotFoundError`, i.e. the checkpoint carries no object graph (the
      error message attributes this to a TF1 name-based checkpoint).
  """
  ckpt_reader = py_checkpoint_reader.NewCheckpointReader(save_path)
  graph_key = base.OBJECT_GRAPH_PROTO_KEY
  try:
    serialized_graph = ckpt_reader.get_tensor(graph_key)
  except errors_impl.NotFoundError as not_found_error:
    raise ValueError(
        f'The specified checkpoint "{save_path}" does not appear to be '
        'object-based (saved with TF2) since it is missing the key '
        f'"{graph_key}". Likely it was created with the '
        'TF1 name-based saver and does not contain an object dependency graph.'
    ) from not_found_error

  parsed_graph = trackable_object_graph_pb2.TrackableObjectGraph()
  parsed_graph.ParseFromString(serialized_graph)
  self._object_graph_proto = parsed_graph
def children(self, node_id):
  """Lists the direct children of one node of the checkpoint's object graph.

  Args:
    node_id: Index of the node in the stored object graph whose children
      should be returned.

  Returns:
    Dictionary mapping each child's local name to that child's node_id.
  """
  parent = self._object_graph_proto.nodes[node_id]
  name_to_node_id = {}
  for reference in parent.children:
    name_to_node_id[reference.local_name] = reference.node_id
  return name_to_node_id
def descendants(self):
  """Returns the node_ids of every node reachable from the checkpoint root.

  The ids come back in the order in which `_descendants_with_paths()`
  recorded them.
  """
  paths_by_node_id = self._descendants_with_paths()
  return list(paths_by_node_id)
def _descendants_with_paths(self):
  """Returns a dict mapping each reachable node_id to a dotted path.

  The object graph is walked breadth-first starting at node 0. Each node is
  recorded the first time it is reached, so a node that can be reached along
  several paths keeps the path through which it was discovered first.

  The names returned by this private method are subject to change.

  Returns:
    Dictionary mapping node_id to a path string such as "root.leaf.a_var".
    Keys are inserted in breadth-first discovery order.
  """
  nodes = self._object_graph_proto.nodes
  # node_id:0 will always be "root". Seeding it here also means any edge that
  # points back at the root is skipped by the membership check below.
  all_nodes_with_paths = {0: "root"}
  to_visit = collections.deque([0])
  while to_visit:
    node_id = to_visit.popleft()
    # Every queued node already has a path; look it up once per parent
    # instead of once per child.
    parent_path = all_nodes_with_paths[node_id]
    for child in nodes[node_id].children:
      if child.node_id in all_nodes_with_paths:
        # Already discovered (or the root itself): keep the first path.
        continue
      all_nodes_with_paths[child.node_id] = (
          parent_path + "." + child.local_name)
      to_visit.append(child.node_id)
  return all_nodes_with_paths
def match(self, obj):
  """Returns all matching trackables between CheckpointView and Trackable.

  A checkpoint node and a live trackable match when they are reached from
  the (always matched) roots through children with the same names, i.e.
  they have the same name and position in the graph.

  Args:
    obj: `Trackable` root.

  Returns:
    Dictionary containing all overlapping trackables that maps `node_id` to
    `Trackable`.

  Raises:
    ValueError: If `obj` is not a `Trackable`.

  Example usage:

  >>> class SimpleModule(tf.Module):
  ...   def __init__(self, name=None):
  ...     super().__init__(name=name)
  ...     self.a_var = tf.Variable(5.0)
  ...     self.b_var = tf.Variable(4.0)
  ...     self.vars = [tf.Variable(1.0), tf.Variable(2.0)]

  >>> root = SimpleModule(name="root")
  >>> leaf = root.leaf = SimpleModule(name="leaf")
  >>> leaf.leaf3 = tf.Variable(6.0, name="leaf3")
  >>> leaf.leaf4 = tf.Variable(7.0, name="leaf4")
  >>> ckpt = tf.train.Checkpoint(root)
  >>> save_path = ckpt.save('/tmp/tf_ckpts')
  >>> checkpoint_view = tf.train.CheckpointView(save_path)

  >>> root2 = SimpleModule(name="root")
  >>> leaf2 = root2.leaf2 = SimpleModule(name="leaf2")
  >>> leaf2.leaf3 = tf.Variable(6.0)
  >>> leaf2.leaf4 = tf.Variable(7.0)

  Match the checkpoint against `root2`. Children are paired by name, so the
  differently named `leaf` / `leaf2` subtrees do not show up in the result.

  >>> checkpoint_view_match = checkpoint_view.match(root2).items()
  >>> for item in checkpoint_view_match:
  ...   print(item)
  (0, ...)
  (1, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>)
  (2, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>)
  (3, ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32,
    numpy=1.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]))
  (6, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>)
  (7, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>)
  """
  if not isinstance(obj, base.Trackable):
    raise ValueError(f"Expected a Trackable, got {obj} of type {type(obj)}.")

  # The checkpoint root (node 0) is paired with `obj` unconditionally.
  overlapping_nodes = {0: obj}
  # Work queue of (checkpoint node_id, live trackable) pairs still to expand.
  pending = collections.deque([(0, obj)])
  # Checkpoint node ids whose children have already been examined.
  expanded_node_ids = set()
  view = trackable_view.TrackableView(obj)

  while pending:
    parent_node_id, parent_trackable = pending.popleft()
    live_children = view.children(parent_trackable)
    for child_name, child_node_id in self.children(parent_node_id).items():
      if child_node_id == 0 or child_node_id in expanded_node_ids:
        continue
      if child_name not in live_children:
        continue

      candidate = live_children[child_name]
      existing = overlapping_nodes.get(child_node_id)
      if existing is None:
        overlapping_nodes[child_node_id] = candidate
        pending.append((child_node_id, candidate))
      elif existing is not candidate:
        # The node was already paired via another path (the dependency graph
        # need not be a tree); nothing to record, but report it when the two
        # paths lead to different live objects.
        logging.warning(
            "Inconsistent references when matching the checkpoint into "
            "this object graph. The referenced objects are: "
            f"({existing} and "
            f"{candidate}).")
    expanded_node_ids.add(parent_node_id)

  return overlapping_nodes
def diff(self, obj):
  """Returns diff between CheckpointView and Trackable.

  This method is intended to be used to compare the object stored in a
  checkpoint vs a live model in Python. For example, if checkpoint
  restoration fails the `assert_consumed()` or
  `assert_existing_objects_matched()` checks, you can use this to list out
  the objects/checkpoint nodes which were not restored.

  Example Usage:

  >>> class SimpleModule(tf.Module):
  ...   def __init__(self, name=None):
  ...     super().__init__(name=name)
  ...     self.a_var = tf.Variable(5.0)
  ...     self.b_var = tf.Variable(4.0)
  ...     self.vars = [tf.Variable(1.0), tf.Variable(2.0)]

  >>> root = SimpleModule(name="root")
  >>> leaf = root.leaf = SimpleModule(name="leaf")
  >>> leaf.leaf3 = tf.Variable(6.0, name="leaf3")
  >>> leaf.leaf4 = tf.Variable(7.0, name="leaf4")
  >>> ckpt = tf.train.Checkpoint(root)
  >>> save_path = ckpt.save('/tmp/tf_ckpts')
  >>> checkpoint_view = tf.train.CheckpointView(save_path)

  >>> root2 = SimpleModule(name="root")
  >>> leaf2 = root2.leaf2 = SimpleModule(name="leaf2")
  >>> leaf2.leaf3 = tf.Variable(6.0)
  >>> leaf2.leaf4 = tf.Variable(7.0)

  The first element of the returned tuple is the same mapping that
  `CheckpointView.match()` returns.

  >>> checkpoint_view_diff = checkpoint_view.diff(root2)
  >>> checkpoint_view_match = checkpoint_view_diff[0].items()
  >>> for item in checkpoint_view_match:
  ...   print(item)
  (0, ...)
  (1, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>)
  (2, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>)
  (3, ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32,
    numpy=1.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]))
  (6, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>)
  (7, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>)

  >>> only_in_checkpoint_view = checkpoint_view_diff[1]
  >>> print(only_in_checkpoint_view)
  [4, 5, 8, 9, 10, 11, 12, 13, 14]

  >>> only_in_trackable = checkpoint_view_diff[2]
  >>> print(only_in_trackable)
  [..., <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>,
  <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>,
  ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
  <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]),
  <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=6.0>,
  <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=7.0>,
  <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
  <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]

  Args:
    obj: `Trackable` root.

  Returns:
    Tuple of (
    - Overlaps: Dictionary containing all overlapping trackables that maps
      `node_id` to `Trackable`, same as CheckpointView.match().
    - Only in CheckpointView: List of `node_id` that only exist in
      CheckpointView.
    - Only in Trackable: List of `Trackable` that only exist in Trackable.
    )

  Raises:
    ValueError: Propagated from `match()` when `obj` is not a `Trackable`.
  """
  overlapping_nodes = self.match(obj)

  # Checkpoint nodes that were not paired with any live trackable.
  only_in_checkpoint_view = [
      node_id for node_id in self.descendants()
      if node_id not in overlapping_nodes
  ]

  # Build the set of matched trackables once, outside the loop; rebuilding
  # it for every descendant made this step quadratic in the graph size.
  # ObjectIdentitySet is kept so membership is tested the same way as before.
  matched_trackables = object_identity.ObjectIdentitySet(
      overlapping_nodes.values())
  only_in_trackable = [
      trackable
      for trackable in trackable_view.TrackableView(obj).descendants()
      if trackable not in matched_trackables
  ]

  return overlapping_nodes, only_in_checkpoint_view, only_in_trackable