168 lines
6.6 KiB
Python
168 lines
6.6 KiB
Python
"""Manages a graph of Trackable objects."""
|
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
import copy
|
|
import weakref
|
|
|
|
from tensorflow.python.checkpoint import save_util_v1
|
|
from tensorflow.python.checkpoint import trackable_view
|
|
from tensorflow.python.trackable import base
|
|
from tensorflow.python.util.tf_export import tf_export
|
|
|
|
|
|
@tf_export("__internal__.tracking.ObjectGraphView", v1=[])
|
|
class ObjectGraphView(trackable_view.TrackableView):
|
|
"""Gathers and serializes an object graph."""
|
|
|
|
def __init__(self, root, attached_dependencies=None):
|
|
"""Configure the graph view.
|
|
|
|
Args:
|
|
root: A `Trackable` object whose variables (including the variables of
|
|
dependencies, recursively) should be saved. May be a weak reference.
|
|
attached_dependencies: List of dependencies to attach to the root object.
|
|
Used when saving a Checkpoint with a defined root object. To avoid
|
|
reference cycles, this should use the WeakTrackableReference class.
|
|
"""
|
|
trackable_view.TrackableView.__init__(self, root)
|
|
# ObjectGraphView should never contain a strong reference to root, since it
|
|
# may result in a cycle:
|
|
# root -> deferred dependencies -> CheckpointPosition
|
|
# -> CheckpointRestoreCoordinator -> ObjectGraphView -> root
|
|
self._root_ref = (root if isinstance(root, weakref.ref)
|
|
else weakref.ref(root))
|
|
self._attached_dependencies = attached_dependencies
|
|
|
|
def __deepcopy__(self, memo):
|
|
# By default, weak references are not copied, which leads to surprising
|
|
# deepcopy behavior. To fix, we first we copy the object itself, then we
|
|
# make a weak reference to the copy.
|
|
strong_root = self._root_ref()
|
|
if strong_root is not None:
|
|
strong_copy = copy.deepcopy(strong_root, memo)
|
|
memo[id(self._root_ref)] = weakref.ref(strong_copy)
|
|
# super() does not have a __deepcopy__, so we need to re-implement it
|
|
copied = super().__new__(type(self))
|
|
memo[id(self)] = copied
|
|
for key, value in vars(self).items():
|
|
setattr(copied, key, copy.deepcopy(value, memo))
|
|
return copied
|
|
|
|
def list_children(self, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
|
|
"""Returns list of all child trackables attached to obj.
|
|
|
|
Args:
|
|
obj: A `Trackable` object.
|
|
save_type: A string, can be 'savedmodel' or 'checkpoint'.
|
|
**kwargs: kwargs to use when retrieving the object's children.
|
|
|
|
Returns:
|
|
List of all children attached to the object.
|
|
"""
|
|
children = []
|
|
for name, ref in super(ObjectGraphView,
|
|
self).children(obj, save_type, **kwargs).items():
|
|
children.append(base.TrackableReference(name, ref))
|
|
|
|
# GraphView objects may define children of the root object that are not
|
|
# actually attached, e.g. a Checkpoint object's save_counter.
|
|
if obj is self.root and self._attached_dependencies:
|
|
children.extend(self._attached_dependencies)
|
|
return children
|
|
|
|
def children(self, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
|
|
"""Returns all child trackables attached to obj.
|
|
|
|
Args:
|
|
obj: A `Trackable` object.
|
|
save_type: A string, can be 'savedmodel' or 'checkpoint'.
|
|
**kwargs: kwargs to use when retrieving the object's children.
|
|
|
|
Returns:
|
|
Dictionary of all children attached to the object with name to trackable.
|
|
"""
|
|
children = {}
|
|
for name, ref in self.list_children(obj, **kwargs):
|
|
children[name] = ref
|
|
return children
|
|
|
|
@property
|
|
def attached_dependencies(self):
|
|
"""Returns list of dependencies that should be saved in the checkpoint.
|
|
|
|
These dependencies are not tracked by root, but are in the checkpoint.
|
|
This is defined when the user creates a Checkpoint with both root and kwargs
|
|
set.
|
|
|
|
Returns:
|
|
A list of TrackableReferences.
|
|
"""
|
|
return self._attached_dependencies
|
|
|
|
@property
|
|
def root(self):
|
|
if isinstance(self._root_ref, weakref.ref):
|
|
derefed = self._root_ref()
|
|
assert derefed is not None
|
|
return derefed
|
|
else:
|
|
return self._root_ref
|
|
|
|
def breadth_first_traversal(self):
|
|
return self._breadth_first_traversal()
|
|
|
|
def _breadth_first_traversal(self):
|
|
"""Find shortest paths to all dependencies of self.root."""
|
|
return super(ObjectGraphView, self)._descendants_with_paths()
|
|
|
|
def serialize_object_graph(self, saveables_cache=None):
|
|
"""Determine checkpoint keys for variables and build a serialized graph.
|
|
|
|
Non-slot variables are keyed based on a shortest path from the root saveable
|
|
to the object which owns the variable (i.e. the one which called
|
|
`Trackable._add_variable` to create it).
|
|
|
|
Slot variables are keyed based on a shortest path to the variable being
|
|
slotted for, a shortest path to their optimizer, and the slot name.
|
|
|
|
Args:
|
|
saveables_cache: An optional cache storing previously created
|
|
SaveableObjects created for each Trackable. Maps Trackables to a
|
|
dictionary of attribute names to Trackable.
|
|
|
|
Returns:
|
|
A tuple of (named_variables, object_graph_proto, feed_additions):
|
|
named_variables: A dictionary mapping names to variable objects.
|
|
object_graph_proto: A TrackableObjectGraph protocol buffer
|
|
containing the serialized object graph and variable references.
|
|
feed_additions: A dictionary mapping from Tensors to values which should
|
|
be fed when saving.
|
|
|
|
Raises:
|
|
ValueError: If there are invalid characters in an optimizer's slot names.
|
|
"""
|
|
named_saveable_objects, object_graph_proto, feed_additions, _ = (
|
|
save_util_v1.serialize_object_graph_with_registered_savers(
|
|
self, saveables_cache))
|
|
return named_saveable_objects, object_graph_proto, feed_additions
|
|
|
|
def frozen_saveable_objects(self,
|
|
object_map=None,
|
|
to_graph=None,
|
|
call_with_mapped_captures=None):
|
|
"""Creates SaveableObjects with the current object graph frozen."""
|
|
return save_util_v1.frozen_saveables_and_savers(
|
|
self, object_map, to_graph, call_with_mapped_captures)[0]
|