"""Manages a Checkpoint View.""" # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import collections from tensorflow.core.protobuf import trackable_object_graph_pb2 from tensorflow.python.checkpoint import trackable_view from tensorflow.python.framework import errors_impl from tensorflow.python.platform import tf_logging as logging from tensorflow.python.trackable import base from tensorflow.python.training import py_checkpoint_reader from tensorflow.python.util import object_identity from tensorflow.python.util.tf_export import tf_export @tf_export("train.CheckpointView", v1=[]) class CheckpointView(object): """Gathers and serializes a checkpoint view. This is for loading specific portions of a module from a checkpoint, and be able to compare two modules by matching components. Example usage: >>> class SimpleModule(tf.Module): ... def __init__(self, name=None): ... super().__init__(name=name) ... self.a_var = tf.Variable(5.0) ... self.b_var = tf.Variable(4.0) ... self.vars = [tf.Variable(1.0), tf.Variable(2.0)] >>> root = SimpleModule(name="root") >>> root.leaf = SimpleModule(name="leaf") >>> ckpt = tf.train.Checkpoint(root) >>> save_path = ckpt.save('/tmp/tf_ckpts') >>> checkpoint_view = tf.train.CheckpointView(save_path) Pass `node_id=0` to `tf.train.CheckpointView.children()` to get the dictionary of all children directly linked to the checkpoint root. >>> for name, node_id in checkpoint_view.children(0).items(): ... print(f"- name: '{name}', node_id: {node_id}") - name: 'a_var', node_id: 1 - name: 'b_var', node_id: 2 - name: 'vars', node_id: 3 - name: 'leaf', node_id: 4 - name: 'root', node_id: 0 - name: 'save_counter', node_id: 5 """ def __init__(self, save_path): """Configure the checkpoint view. Args: save_path: The path to the checkpoint. Raises: ValueError: If the save_path does not lead to a TF2 checkpoint. """ reader = py_checkpoint_reader.NewCheckpointReader(save_path) try: object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY) except errors_impl.NotFoundError as not_found_error: raise ValueError( f"The specified checkpoint \"{save_path}\" does not appear to be " "object-based (saved with TF2) since it is missing the key " f"\"{base.OBJECT_GRAPH_PROTO_KEY}\". Likely it was created with the " "TF1 name-based saver and does not contain an object dependency graph." ) from not_found_error object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph()) object_graph_proto.ParseFromString(object_graph_string) self._object_graph_proto = object_graph_proto def children(self, node_id): """Returns all child trackables attached to obj. Args: node_id: Id of the node to return its children. Returns: Dictionary of all children attached to the object with name to node_id. """ return { child.local_name: child.node_id for child in self._object_graph_proto.nodes[node_id].children } def descendants(self): """Returns a list of trackables by node_id attached to obj.""" return list(self._descendants_with_paths().keys()) def _descendants_with_paths(self): """Returns a dict of descendants by node_id and paths to node. The names returned by this private method are subject to change. """ all_nodes_with_paths = {} to_visit = collections.deque([0]) # node_id:0 will always be "root". all_nodes_with_paths[0] = "root" path = all_nodes_with_paths.get(0) while to_visit: node_id = to_visit.popleft() obj = self._object_graph_proto.nodes[node_id] for child in obj.children: if child.node_id == 0 or child.node_id in all_nodes_with_paths.keys(): continue path = all_nodes_with_paths.get(node_id) if child.node_id not in all_nodes_with_paths.keys(): to_visit.append(child.node_id) all_nodes_with_paths[child.node_id] = path + "." + child.local_name return all_nodes_with_paths def match(self, obj): """Returns all matching trackables between CheckpointView and Trackable. Matching trackables represents trackables with the same name and position in graph. Args: obj: `Trackable` root. Returns: Dictionary containing all overlapping trackables that maps `node_id` to `Trackable`. Example usage: >>> class SimpleModule(tf.Module): ... def __init__(self, name=None): ... super().__init__(name=name) ... self.a_var = tf.Variable(5.0) ... self.b_var = tf.Variable(4.0) ... self.vars = [tf.Variable(1.0), tf.Variable(2.0)] >>> root = SimpleModule(name="root") >>> leaf = root.leaf = SimpleModule(name="leaf") >>> leaf.leaf3 = tf.Variable(6.0, name="leaf3") >>> leaf.leaf4 = tf.Variable(7.0, name="leaf4") >>> ckpt = tf.train.Checkpoint(root) >>> save_path = ckpt.save('/tmp/tf_ckpts') >>> checkpoint_view = tf.train.CheckpointView(save_path) >>> root2 = SimpleModule(name="root") >>> leaf2 = root2.leaf2 = SimpleModule(name="leaf2") >>> leaf2.leaf3 = tf.Variable(6.0) >>> leaf2.leaf4 = tf.Variable(7.0) Pass `node_id=0` to `tf.train.CheckpointView.children()` to get the dictionary of all children directly linked to the checkpoint root. >>> checkpoint_view_match = checkpoint_view.match(root2).items() >>> for item in checkpoint_view_match: ... print(item) (0, ...) (1, ) (2, ) (3, ListWrapper([, ])) (6, ) (7, ) """ if not isinstance(obj, base.Trackable): raise ValueError(f"Expected a Trackable, got {obj} of type {type(obj)}.") overlapping_nodes = {} # Root node is always matched. overlapping_nodes[0] = obj # Queue of tuples of node_id and trackable. to_visit = collections.deque([(0, obj)]) visited = set() view = trackable_view.TrackableView(obj) while to_visit: current_node_id, current_trackable = to_visit.popleft() trackable_children = view.children(current_trackable) for child_name, child_node_id in self.children(current_node_id).items(): if child_node_id in visited or child_node_id == 0: continue if child_name in trackable_children: current_assignment = overlapping_nodes.get(child_node_id) if current_assignment is None: overlapping_nodes[child_node_id] = trackable_children[child_name] to_visit.append((child_node_id, trackable_children[child_name])) else: # The object was already mapped for this checkpoint load, which # means we don't need to do anything besides check that the mapping # is consistent (if the dependency DAG is not a tree then there are # multiple paths to the same object). if current_assignment is not trackable_children[child_name]: logging.warning( "Inconsistent references when matching the checkpoint into " "this object graph. The referenced objects are: " f"({current_assignment} and " f"{trackable_children[child_name]}).") visited.add(current_node_id) return overlapping_nodes def diff(self, obj): """Returns diff between CheckpointView and Trackable. This method is intended to be used to compare the object stored in a checkpoint vs a live model in Python. For example, if checkpoint restoration fails the `assert_consumed()` or `assert_existing_objects_matched()` checks, you can use this to list out the objects/checkpoint nodes which were not restored. Example Usage: >>> class SimpleModule(tf.Module): ... def __init__(self, name=None): ... super().__init__(name=name) ... self.a_var = tf.Variable(5.0) ... self.b_var = tf.Variable(4.0) ... self.vars = [tf.Variable(1.0), tf.Variable(2.0)] >>> root = SimpleModule(name="root") >>> leaf = root.leaf = SimpleModule(name="leaf") >>> leaf.leaf3 = tf.Variable(6.0, name="leaf3") >>> leaf.leaf4 = tf.Variable(7.0, name="leaf4") >>> ckpt = tf.train.Checkpoint(root) >>> save_path = ckpt.save('/tmp/tf_ckpts') >>> checkpoint_view = tf.train.CheckpointView(save_path) >>> root2 = SimpleModule(name="root") >>> leaf2 = root2.leaf2 = SimpleModule(name="leaf2") >>> leaf2.leaf3 = tf.Variable(6.0) >>> leaf2.leaf4 = tf.Variable(7.0) Pass `node_id=0` to `tf.train.CheckpointView.children()` to get the dictionary of all children directly linked to the checkpoint root. >>> checkpoint_view_diff = checkpoint_view.diff(root2) >>> checkpoint_view_match = checkpoint_view_diff[0].items() >>> for item in checkpoint_view_match: ... print(item) (0, ...) (1, ) (2, ) (3, ListWrapper([, ])) (6, ) (7, ) >>> only_in_checkpoint_view = checkpoint_view_diff[1] >>> print(only_in_checkpoint_view) [4, 5, 8, 9, 10, 11, 12, 13, 14] >>> only_in_trackable = checkpoint_view_diff[2] >>> print(only_in_trackable) [..., , , ListWrapper([, ]), , , , ] Args: obj: `Trackable` root. Returns: Tuple of ( - Overlaps: Dictionary containing all overlapping trackables that maps `node_id` to `Trackable`, same as CheckpointView.match(). - Only in CheckpointView: List of `node_id` that only exist in CheckpointView. - Only in Trackable: List of `Trackable` that only exist in Trackable. ) """ overlapping_nodes = self.match(obj) only_in_checkpoint_view = [] only_in_trackable = [] for node_id in self.descendants(): if node_id not in overlapping_nodes.keys(): only_in_checkpoint_view.append(node_id) for trackable in trackable_view.TrackableView(obj).descendants(): if trackable not in object_identity.ObjectIdentitySet( overlapping_nodes.values()): only_in_trackable.append(trackable) return overlapping_nodes, only_in_checkpoint_view, only_in_trackable