118 lines
4.3 KiB
Python
118 lines
4.3 KiB
Python
"""Manages a Trackable object graph."""
|
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
import collections
|
|
import weakref
|
|
|
|
from tensorflow.python.trackable import base
|
|
from tensorflow.python.trackable import converter
|
|
from tensorflow.python.util import object_identity
|
|
from tensorflow.python.util.tf_export import tf_export
|
|
|
|
|
|
@tf_export("train.TrackableView", v1=[])
|
|
class TrackableView(object):
|
|
"""Gathers and serializes a trackable view.
|
|
|
|
Example usage:
|
|
|
|
>>> class SimpleModule(tf.Module):
|
|
... def __init__(self, name=None):
|
|
... super().__init__(name=name)
|
|
... self.a_var = tf.Variable(5.0)
|
|
... self.b_var = tf.Variable(4.0)
|
|
... self.vars = [tf.Variable(1.0), tf.Variable(2.0)]
|
|
|
|
>>> root = SimpleModule(name="root")
|
|
>>> root.leaf = SimpleModule(name="leaf")
|
|
>>> trackable_view = tf.train.TrackableView(root)
|
|
|
|
Pass root to tf.train.TrackableView.children() to get the dictionary of all
|
|
children directly linked to root by name.
|
|
>>> trackable_view_children = trackable_view.children(root)
|
|
>>> for item in trackable_view_children.items():
|
|
... print(item)
|
|
('a_var', <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>)
|
|
('b_var', <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>)
|
|
('vars', ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32,
|
|
numpy=1.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]))
|
|
('leaf', ...)
|
|
|
|
"""
|
|
|
|
def __init__(self, root):
|
|
"""Configure the trackable view.
|
|
|
|
Args:
|
|
root: A `Trackable` object whose variables (including the variables of
|
|
dependencies, recursively) should be saved. May be a weak reference.
|
|
"""
|
|
# TrackableView should never contain a strong reference to root, since it
|
|
# may result in a cycle:
|
|
# root -> deferred dependencies -> CheckpointPosition
|
|
# -> CheckpointRestoreCoordinator -> TrackableView -> root
|
|
self._root_ref = (root if isinstance(root, weakref.ref)
|
|
else weakref.ref(root))
|
|
|
|
@classmethod
|
|
def children(cls, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
|
|
"""Returns all child trackables attached to obj.
|
|
|
|
Args:
|
|
obj: A `Trackable` object.
|
|
save_type: A string, can be 'savedmodel' or 'checkpoint'.
|
|
**kwargs: kwargs to use when retrieving the object's children.
|
|
|
|
Returns:
|
|
Dictionary of all children attached to the object with name to trackable.
|
|
"""
|
|
# pylint: disable=protected-access
|
|
obj._maybe_initialize_trackable()
|
|
children = {}
|
|
for name, ref in obj._trackable_children(save_type, **kwargs).items():
|
|
ref = converter.convert_to_trackable(ref, parent=obj)
|
|
children[name] = ref
|
|
return children
|
|
|
|
@property
|
|
def root(self):
|
|
if isinstance(self._root_ref, weakref.ref):
|
|
derefed = self._root_ref()
|
|
assert derefed is not None
|
|
return derefed
|
|
else:
|
|
return self._root_ref
|
|
|
|
def descendants(self):
|
|
"""Returns a list of all nodes from self.root using a breadth first traversal."""
|
|
return self._descendants_with_paths()[0]
|
|
|
|
def _descendants_with_paths(self):
|
|
"""Returns a list of all nodes and its paths from self.root using a breadth first traversal."""
|
|
bfs_sorted = []
|
|
to_visit = collections.deque([self.root])
|
|
node_paths = object_identity.ObjectIdentityDictionary()
|
|
node_paths[self.root] = ()
|
|
while to_visit:
|
|
current_trackable = to_visit.popleft()
|
|
bfs_sorted.append(current_trackable)
|
|
for name, dependency in self.children(current_trackable).items():
|
|
if dependency not in node_paths:
|
|
node_paths[dependency] = (
|
|
node_paths[current_trackable] +
|
|
(base.TrackableReference(name, dependency),))
|
|
to_visit.append(dependency)
|
|
return bfs_sorted, node_paths
|