# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Manages a graph of Trackable objects."""

import copy
import weakref

from tensorflow.python.checkpoint import save_util_v1
from tensorflow.python.checkpoint import trackable_view
from tensorflow.python.trackable import base
from tensorflow.python.util.tf_export import tf_export


@tf_export("__internal__.tracking.ObjectGraphView", v1=[])
class ObjectGraphView(trackable_view.TrackableView):
  """Gathers and serializes an object graph."""

  def __init__(self, root, attached_dependencies=None):
    """Configure the graph view.

    Args:
      root: A `Trackable` object whose variables (including the variables of
        dependencies, recursively) should be saved. May be a weak reference.
      attached_dependencies: List of dependencies to attach to the root object.
        Used when saving a Checkpoint with a defined root object. To avoid
        reference cycles, this should use the WeakTrackableReference class.
    """
    trackable_view.TrackableView.__init__(self, root)
    # ObjectGraphView should never contain a strong reference to root, since it
    # may result in a cycle:
    #   root -> deferred dependencies -> CheckpointPosition
    #   -> CheckpointRestoreCoordinator -> ObjectGraphView -> root
    self._root_ref = (root if isinstance(root, weakref.ref)
                      else weakref.ref(root))
    self._attached_dependencies = attached_dependencies

  def __deepcopy__(self, memo):
    """Deep-copies the view, re-pointing the root weakref at the copied root.

    Args:
      memo: The `copy.deepcopy` memo dictionary.

    Returns:
      A new instance of `type(self)` whose attributes are deep copies of this
      instance's attributes.
    """
    # By default, weak references are not copied, which leads to surprising
    # deepcopy behavior. To fix this, we first copy the referenced object
    # itself, then seed `memo` with a weak reference to that copy so that
    # deep-copying `self._root_ref` in the loop below resolves to it.
    strong_root = self._root_ref()
    if strong_root is not None:
      strong_copy = copy.deepcopy(strong_root, memo)
      memo[id(self._root_ref)] = weakref.ref(strong_copy)
    # super() does not have a __deepcopy__, so we need to re-implement it.
    copied = super().__new__(type(self))
    memo[id(self)] = copied
    for key, value in vars(self).items():
      setattr(copied, key, copy.deepcopy(value, memo))
    return copied

  def list_children(self, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns list of all child trackables attached to obj.

    Args:
      obj: A `Trackable` object.
      save_type: A `base.SaveType` value; can be 'savedmodel' or 'checkpoint'.
      **kwargs: kwargs to use when retrieving the object's children.

    Returns:
      List of `base.TrackableReference`s for the children of `obj`. When `obj`
      is the root, the attached dependencies (if any) are appended.
    """
    direct_children = super().children(obj, save_type, **kwargs)
    children = [base.TrackableReference(name, ref)
                for name, ref in direct_children.items()]
    # GraphView objects may define children of the root object that are not
    # actually attached, e.g. a Checkpoint object's save_counter.
    # The cheap emptiness check comes first so the root weakref is only
    # dereferenced (and asserted alive) when there is something to attach.
    if self._attached_dependencies and obj is self.root:
      children.extend(self._attached_dependencies)
    return children

  def children(self, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns all child trackables attached to obj.

    Args:
      obj: A `Trackable` object.
      save_type: A `base.SaveType` value; can be 'savedmodel' or 'checkpoint'.
      **kwargs: kwargs to use when retrieving the object's children.

    Returns:
      Dictionary of all children attached to the object with name to trackable.
      If several children share a name, the last one listed wins.
    """
    # NOTE(review): `save_type` is not forwarded to `list_children`, so this
    # lists children using `list_children`'s own default save type. Forwarding
    # it would pass a new keyword to any subclass override of `list_children`
    # that does not accept `save_type` -- confirm overrides before changing.
    return {name: ref for name, ref in self.list_children(obj, **kwargs)}

  @property
  def attached_dependencies(self):
    """Returns list of dependencies that should be saved in the checkpoint.

    These dependencies are not tracked by root, but are in the checkpoint.
    This is defined when the user creates a Checkpoint with both root and kwargs
    set.

    Returns:
      A list of TrackableReferences.
    """
    return self._attached_dependencies

  @property
  def root(self):
    """Returns the root object, dereferencing the stored weak reference.

    Raises:
      AssertionError: If the root has already been garbage collected (only
        when assertions are enabled).
    """
    if isinstance(self._root_ref, weakref.ref):
      derefed = self._root_ref()
      assert derefed is not None
      return derefed
    else:
      return self._root_ref

  def breadth_first_traversal(self):
    """Public alias for `_breadth_first_traversal`."""
    return self._breadth_first_traversal()

  def _breadth_first_traversal(self):
    """Find shortest paths to all dependencies of self.root."""
    return super()._descendants_with_paths()

  def serialize_object_graph(self, saveables_cache=None):
    """Determine checkpoint keys for variables and build a serialized graph.

    Non-slot variables are keyed based on a shortest path from the root saveable
    to the object which owns the variable (i.e. the one which called
    `Trackable._add_variable` to create it).

    Slot variables are keyed based on a shortest path to the variable being
    slotted for, a shortest path to their optimizer, and the slot name.

    Args:
      saveables_cache: An optional cache storing previously created
        SaveableObjects created for each Trackable. Maps Trackables to a
        dictionary of attribute names to Trackable.

    Returns:
      A tuple of (named_saveable_objects, object_graph_proto, feed_additions):
        named_saveable_objects: The named saveable objects produced by
          `save_util_v1.serialize_object_graph_with_registered_savers`
          (historically documented as mapping names to variable objects --
          confirm the exact container type there).
        object_graph_proto: A TrackableObjectGraph protocol buffer
          containing the serialized object graph and variable references.
        feed_additions: A dictionary mapping from Tensors to values which should
          be fed when saving.

    Raises:
      ValueError: If there are invalid characters in an optimizer's slot names.
    """
    # The fourth element (registered savers) is intentionally dropped to keep
    # this method's three-element return contract.
    named_saveable_objects, object_graph_proto, feed_additions, _ = (
        save_util_v1.serialize_object_graph_with_registered_savers(
            self, saveables_cache))
    return named_saveable_objects, object_graph_proto, feed_additions

  def frozen_saveable_objects(self,
                              object_map=None,
                              to_graph=None,
                              call_with_mapped_captures=None):
    """Creates SaveableObjects with the current object graph frozen.

    Args:
      object_map: Passed through to `save_util_v1.frozen_saveables_and_savers`.
      to_graph: Passed through to `save_util_v1.frozen_saveables_and_savers`.
      call_with_mapped_captures: Passed through to
        `save_util_v1.frozen_saveables_and_savers`.

    Returns:
      The first element of the result of
      `save_util_v1.frozen_saveables_and_savers`.
    """
    return save_util_v1.frozen_saveables_and_savers(
        self, object_map, to_graph, call_with_mapped_captures)[0]