345 lines
14 KiB
Python
345 lines
14 KiB
Python
|
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Extracts tensors for checkpointing while updating a TrackableObjectGraph.
|
||
|
|
||
|
The tensors are extracted from `Trackable._serialize_to_tensors`.
|
||
|
"""
|
||
|
import collections
|
||
|
|
||
|
from typing import Any, Callable, List, Optional, Tuple, Mapping, Union, Dict
|
||
|
|
||
|
from tensorflow.core.protobuf import trackable_object_graph_pb2
|
||
|
from tensorflow.python.checkpoint import graph_view as graph_view_lib
|
||
|
from tensorflow.python.checkpoint import save_util_v1
|
||
|
from tensorflow.python.checkpoint import saveable_compat
|
||
|
from tensorflow.python.checkpoint import util
|
||
|
from tensorflow.python.framework import constant_op
|
||
|
from tensorflow.python.framework import dtypes
|
||
|
from tensorflow.python.framework import ops
|
||
|
from tensorflow.python.saved_model import registration
|
||
|
from tensorflow.python.trackable import base
|
||
|
from tensorflow.python.trackable import python_state
|
||
|
from tensorflow.python.trackable import trackable_utils
|
||
|
from tensorflow.python.training.saving import saveable_object as saveable_object_lib
|
||
|
from tensorflow.python.training.saving import saveable_object_util
|
||
|
from tensorflow.python.types import core
|
||
|
from tensorflow.python.util import object_identity
|
||
|
|
||
|
# Attributes for each Trackable in the checkpointed object graph.
|
||
|
_TrackableData = collections.namedtuple("_TrackableData", [
|
||
|
# A trackable in the root Trackable object graph.
|
||
|
"trackable",
|
||
|
# The index at which the Trackable appears in TrackableObjectGraph.nodes.
|
||
|
"node_id",
|
||
|
# The BFS-generated path from the root object / used to generate readable
|
||
|
# checkpoint keys.
|
||
|
"object_name",
|
||
|
# A list of ObjectReference for each child connected to this Trackable.
|
||
|
"children_proto",
|
||
|
# A list of SlotVariableReference to save to the object (only valid for
|
||
|
# Optimizer objects).
|
||
|
"slot_variable_proto",
|
||
|
# The object to save to checkpoint. Usually this is the same as `trackable`,
|
||
|
# but can differ when the the caller wants to specify a different object to
|
||
|
# save. For example, when saving checkpoints asynchronously, variables are
|
||
|
# copied to the CPU. `object_to_save` is set as the copied variable.
|
||
|
"object_to_save",
|
||
|
])
|
||
|
|
||
|
|
||
|
def _split_trackables(
    trackable_data: List[_TrackableData]
) -> Tuple[List[_TrackableData], List[_TrackableData],
           Dict[str, List[_TrackableData]]]:
  """Partitions trackables by the mechanism used to checkpoint them.

  Args:
    trackable_data: `_TrackableData` records to classify.

  Returns:
    A tuple `(tensor_trackables, pystate_trackables, registered_trackables)`:
      * `pystate_trackables`: records whose `object_to_save` is a
        `python_state.PythonState`.
      * `registered_trackables`: a `defaultdict(list)` keyed by registered
        saver name, holding records whose `object_to_save` has a registered
        saver. A PythonState is never placed here, even if registered.
      * `tensor_trackables`: all remaining records.
    Within each category the input order is preserved.
  """
  tensor_trackables = []
  pystate_trackables = []
  registered_trackables = collections.defaultdict(list)

  for record in trackable_data:
    target = record.object_to_save
    registered_name = registration.get_registered_saver_name(target)
    if isinstance(target, python_state.PythonState):
      pystate_trackables.append(record)
    elif registered_name:
      registered_trackables[registered_name].append(record)
    else:
      tensor_trackables.append(record)

  return tensor_trackables, pystate_trackables, registered_trackables
|
||
|
|
||
|
|
||
|
def _gather_trackable_data(
    graph_view: graph_view_lib.ObjectGraphView,
    object_map: Mapping[base.Trackable, base.Trackable]
) -> Tuple[List[_TrackableData], Dict[base.Trackable, int]]:
  """Builds one `_TrackableData` per object reachable in `graph_view`.

  Args:
    graph_view: Object graph, traversed breadth-first from its root.
    object_map: Mapping handed to `util.get_mapped_trackable` to select the
      object that is saved for each trackable.

  Returns:
    A tuple `(trackable_data, node_ids)`. `trackable_data` lists the records
    in traversal order. `node_ids` is an identity-keyed dict mapping each
    trackable to its index in that order.
  """
  ordered_objects, paths_by_object = graph_view.breadth_first_traversal()

  # Readable names derived from each object's path from the root.
  names = object_identity.ObjectIdentityDictionary()
  for obj, path in paths_by_object.items():
    names[obj] = trackable_utils.object_path_to_string(path)

  # Node ids are simply positions in the BFS ordering.
  ids = object_identity.ObjectIdentityDictionary()
  for index, obj in enumerate(ordered_objects):
    ids[obj] = index

  slot_refs = util.serialize_slot_variables(
      trackable_objects=ordered_objects,
      node_ids=ids,
      object_names=names)

  object_reference = (
      trackable_object_graph_pb2.TrackableObjectGraph.TrackableObject
      .ObjectReference)

  def _child_references(parent):
    # One ObjectReference proto per child edge, in `list_children` order.
    return [object_reference(node_id=ids[child.ref], local_name=child.name)
            for child in graph_view.list_children(parent)]

  records = []
  for obj in ordered_objects:
    children = _child_references(obj)
    records.append(_TrackableData(
        obj,
        node_id=ids[obj],
        object_name=names[obj],
        children_proto=children,
        slot_variable_proto=slot_refs.get(obj, []),
        object_to_save=util.get_mapped_trackable(obj, object_map)))
  return records, ids
|
||
|
|
||
|
|
||
|
def _fill_object_graph_proto(
    trackable_data: List[_TrackableData]
) -> trackable_object_graph_pb2.TrackableObjectGraph:
  """Creates a `TrackableObjectGraph` containing one node per record.

  Each node receives the record's children and slot variable references.
  Records must be ordered so that each `node_id` equals the record's position
  in `trackable_data`; this is checked with an assertion.

  Args:
    trackable_data: Records ordered by `node_id`.

  Returns:
    The newly built `TrackableObjectGraph` proto.
  """
  graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  for expected_id, record in enumerate(trackable_data):
    # Nodes are appended in order, so a node's position must match the id
    # that other records use to reference it.
    assert record.node_id == expected_id
    graph_proto.nodes.add(
        slot_variables=record.slot_variable_proto,
        children=record.children_proto)
  return graph_proto
|
||
|
|
||
|
|
||
|
def _get_and_write_tensors_to_serialize(
    tensor_trackables: List[_TrackableData],
    node_ids: Dict[base.Trackable, int],
    call_with_mapped_captures: Union[Callable[..., Any], None],
    cache: Union[Dict[base.Trackable, Any], None],
    object_graph_proto: trackable_object_graph_pb2.TrackableObjectGraph
) -> Dict[base.Trackable, Any]:
  """Creates dictionary of tensors to checkpoint, and updates the proto.

  The tensors for each record come from one of two paths:
    * The legacy SaveableObject path, used when `object_to_save` does not
      implement `_serialize_to_tensors` or has a compat saveable name.
    * Otherwise, the object's `_serialize_to_tensors` method.
  Either path writes the matching `attributes` into the record's node of
  `object_graph_proto`.

  Args:
    tensor_trackables: Records whose state is serialized to tensors.
    node_ids: Identity-keyed map from trackable to proto node index. It is
      only forwarded to the legacy SaveableObject path.
    call_with_mapped_captures: Optional function used to call concrete
      functions with mapped captures. May be None.
    cache: Optional dict keyed by `object_to_save`. Each entry stores
      `(trackable, tensor_dict, node_attributes)`. On a hit, the cached
      tensors are reused and the cached attributes are merged into the proto
      instead of re-serializing the object. Cache misses are added to it.
    object_graph_proto: Proto whose nodes receive the attributes.

  Returns:
    An identity-keyed dict mapping each saved trackable (or, on the legacy
    path, its `SaveableCompatibilityConverter`) to a dict of
    checkpoint key -> tensor.
  """
  # Maps trackable to a dictionary of tensors, which maps
  # checkpoint key (-> slice_spec) -> tensor.
  serialized_tensors = object_identity.ObjectIdentityDictionary()

  for td in tensor_trackables:
    if cache is not None and td.object_to_save in cache:
      # Reuse previously serialized tensors and proto attributes.
      trackable, tensor_dict, object_proto = cache[td.object_to_save]
      serialized_tensors[trackable] = tensor_dict
      object_graph_proto.nodes[td.node_id].attributes.MergeFrom(object_proto)
      continue

    legacy_name = saveable_compat.get_saveable_name(td.object_to_save) or ""

    if (not saveable_object_util.trackable_has_serialize_to_tensor(
        td.object_to_save) or legacy_name):
      # Use the legacy code path for objects that are using SaveableObjects
      # or the compat saveable name decorator.
      trackable, tensor_dict = _get_tensors_from_legacy_saveable(
          td, node_ids, call_with_mapped_captures, object_graph_proto)
    else:
      tensor_dict = _get_tensors_from_trackable(
          td, call_with_mapped_captures, object_graph_proto)
      trackable = td.object_to_save
    serialized_tensors[trackable] = tensor_dict

    if cache is not None:
      # Only cache misses reach this point (hits `continue` above), so this
      # always inserts a new entry.
      cache[td.object_to_save] = (
          trackable, tensor_dict,
          object_graph_proto.nodes[td.node_id].attributes)

  return serialized_tensors
|
||
|
|
||
|
|
||
|
def _get_tensors_from_legacy_saveable(
    trackable_data: _TrackableData,
    node_ids: Dict[base.Trackable, int],
    call_with_mapped_captures: Callable[..., Any],
    object_graph_proto: trackable_object_graph_pb2.TrackableObjectGraph
) -> Tuple[base.Trackable, Dict[str, Any]]:
  """Serializes one trackable through the legacy SaveableObject machinery.

  `save_util_v1` helpers create the legacy SaveableObjects and update
  `object_graph_proto`. The SaveableObjects are then wrapped in a
  `SaveableCompatibilityConverter` so that they can be serialized like a
  Trackable that implements `_serialize_to_tensors`.

  Args:
    trackable_data: Record of the trackable to serialize.
    node_ids: Identity-keyed map from trackable to proto node index.
    call_with_mapped_captures: Forwarded to
      `save_util_v1.generate_saveable_objects`.
    object_graph_proto: Proto updated by `save_util_v1`.

  Returns:
    A tuple `(converter, tensors)`. `converter` is the
    `SaveableCompatibilityConverter` and `tensors` is the result of calling its
    `_serialize_to_tensors()`.
  """
  source = trackable_data.trackable

  # Single-entry, identity-keyed maps in the form `save_util_v1` consumes.
  single_name = object_identity.ObjectIdentityDictionary()
  single_name[source] = trackable_data.object_name
  single_map = object_identity.ObjectIdentityDictionary()
  single_map[source] = trackable_data.object_to_save

  factories, _ = save_util_v1.get_checkpoint_factories_and_keys(
      single_name, single_map)
  saveables, _ = save_util_v1.generate_saveable_objects(
      factories,
      object_graph_proto,
      node_ids,
      single_map,
      call_with_mapped_captures,
      saveables_cache=None)

  converter = saveable_object_util.SaveableCompatibilityConverter(
      trackable_data.object_to_save, saveables)
  tensors = converter._serialize_to_tensors()  # pylint: disable=protected-access
  return converter, tensors
|
||
|
|
||
|
|
||
|
def _get_tensors_from_trackable(
    trackable_data: _TrackableData,
    call_with_mapped_captures: Union[Callable[..., Any], None],
    object_graph_proto: trackable_object_graph_pb2.TrackableObjectGraph
) -> Dict[str, Any]:
  """Serializes one trackable via its `_serialize_to_tensors` method.

  Each returned entry is re-keyed with a full checkpoint key built from the
  record's `object_name` and the escaped local name. When
  `object_graph_proto` is not None, a matching attribute is added to the
  record's node.

  Args:
    trackable_data: Record of the trackable to serialize. Its
      `object_to_save` provides `_serialize_to_tensors`.
    call_with_mapped_captures: If truthy and `_serialize_to_tensors` is a
      `core.ConcreteFunction`, the function is invoked through this callable
      (with an empty argument list) instead of being called directly.
    object_graph_proto: Optional proto to record attributes in.

  Returns:
    Dict mapping checkpoint key -> value returned by `_serialize_to_tensors`.
  """
  obj = trackable_data.object_to_save
  serialize = obj._serialize_to_tensors  # pylint: disable=protected-access

  if (not call_with_mapped_captures or
      not isinstance(serialize, core.ConcreteFunction)):
    raw_tensors = serialize()
  else:
    raw_tensors = call_with_mapped_captures(serialize, [])

  checkpoint_tensors = {}
  for name, value in raw_tensors.items():
    escaped = trackable_utils.escape_local_name(name)
    key = trackable_utils.checkpoint_key(trackable_data.object_name, escaped)
    checkpoint_tensors[key] = value

    # TODO(b/261786493): Delete this when DCheckpoint is removed.
    if isinstance(value, saveable_object_lib.SaveSpec):
      # Point the SaveSpec at the generated checkpoint key with no slice.
      value.name = key
      value.slice_spec = ""

    if object_graph_proto is not None:
      object_graph_proto.nodes[trackable_data.node_id].attributes.add(
          name=escaped,
          checkpoint_key=key,
          full_name=util.get_full_name(obj))

  return checkpoint_tensors
|
||
|
|
||
|
|
||
|
def _get_and_write_pystate_feed_additions(
    pystate_trackables: List[_TrackableData],
    cache: Union[Dict[base.Trackable, Any], None],
    object_graph_proto=None
) -> Tuple[Dict[base.Trackable, Any], Dict[Any, Any]]:
  """Gets feed additions needed for checkpointing Python State.

  Python state is not automatically updated within a TF session. Each
  PythonState is therefore represented by a string constant tensor created on
  "/cpu:0". Its current serialized value is returned as a feed addition to be
  passed to `sess.run`.

  Args:
    pystate_trackables: Records whose `object_to_save` is a PythonState.
    cache: Optional dict keyed by PythonState object. It stores
      `{python_state.PYTHON_STATE: placeholder_tensor}`, so the same
      placeholder is reused across calls. When None, a fresh placeholder is
      created on every call and nothing is cached.
    object_graph_proto: Optional proto. When given, a PYTHON_STATE attribute
      is added to each record's node.

  Returns:
    A tuple `(serialized_tensors, feed_additions)`:
      serialized_tensors: identity-keyed dict of PythonState ->
        {checkpoint_key: placeholder_tensor}.
      feed_additions: dict of placeholder_tensor -> value returned by the
        PythonState's `serialize()`.
  """
  serialized_tensors = object_identity.ObjectIdentityDictionary()
  # Maps tensor placeholders to python values.
  feed_additions = {}

  for td in pystate_trackables:
    trackable = td.object_to_save
    checkpoint_key = trackable_utils.checkpoint_key(td.object_name,
                                                    python_state.PYTHON_STATE)
    if cache is not None and trackable in cache:
      save_string = cache[trackable][python_state.PYTHON_STATE]
    else:
      with ops.device("/cpu:0"):
        save_string = constant_op.constant("", dtype=dtypes.string)
      if cache is not None:
        cache[trackable] = {python_state.PYTHON_STATE: save_string}

    with ops.init_scope():
      value = trackable.serialize()
    feed_additions[save_string] = value
    serialized_tensors[trackable] = {checkpoint_key: save_string}

    # Consistent with `_get_tensors_from_trackable`: the proto is optional.
    if object_graph_proto is not None:
      object_graph_proto.nodes[td.node_id].attributes.add(
          name=python_state.PYTHON_STATE,
          checkpoint_key=checkpoint_key,
          full_name=util.get_full_name(trackable))

  return serialized_tensors, feed_additions
|
||
|
|
||
|
|
||
|
def _get_and_write_registered_savers(
    registered_trackables: Dict[str, List[_TrackableData]],
    object_graph_proto: trackable_object_graph_pb2.TrackableObjectGraph
) -> Dict[str, Dict[str, base.Trackable]]:
  """Groups objects by registered saver and records the saver in the proto.

  Args:
    registered_trackables: Map from registered saver name to the records
      handled by that saver.
    object_graph_proto: Proto whose nodes get their `registered_saver` name
      and object name filled in.

  Returns:
    A `defaultdict(dict)` mapping saver name -> {object_name: object_to_save}.
  """
  savers_by_name = collections.defaultdict(dict)
  for name, records in registered_trackables.items():
    for record in records:
      savers_by_name[name][record.object_name] = record.object_to_save

      node = object_graph_proto.nodes[record.node_id]
      node.registered_saver.name = name
      node.registered_saver.object_name = record.object_name
  return savers_by_name
|
||
|
|
||
|
|
||
|
def serialize_graph_view(
    graph_view: graph_view_lib.ObjectGraphView,
    object_map: Optional[Mapping[base.Trackable, base.Trackable]] = None,
    call_with_mapped_captures: Optional[Callable[..., Any]] = None,
    cache: Optional[Dict[base.Trackable, Any]] = None
) -> Tuple[Dict[base.Trackable, Any], Optional[Dict[Any, Any]],
           Dict[str, Dict[str, base.Trackable]],
           trackable_object_graph_pb2.TrackableObjectGraph]:
  """Gathers serialization objects, and creates a TrackableObjectGraph proto.

  Args:
    graph_view: Object graph to serialize.
    object_map: Optional mapping that selects, for each trackable, the object
      actually saved (see `_TrackableData.object_to_save`).
    call_with_mapped_captures: Optional function used to call concrete
      functions with mapped captures.
    cache: Optional serialization cache. Its presence also selects how
      PythonState objects are handled. With a cache (TF1 graph mode), they are
      emitted as placeholder feeds. Without one (eager mode), they are
      serialized directly.

  Returns:
    A tuple `(serialized_tensors, feed_additions, registered_savers,
    object_graph_proto)`:
      serialized_tensors: identity-keyed dict of trackable ->
        {checkpoint key: tensor}.
      feed_additions: dict of placeholder tensor -> Python value when `cache`
        is given, otherwise None.
      registered_savers: dict of saver name -> {object name: trackable}.
      object_graph_proto: the filled `TrackableObjectGraph` proto.
  """
  # There are 3 types of checkpoint serialization types supported:
  # 1. Trackables that override `Trackable._serialize_to_tensors()`.
  # 2. PythonState: A special type of Trackable that serializes a Python string.
  # 3. Registered Trackable Savers: For objects that need to define advanced
  #    checkpointing operations not supported by (1) or (2).
  trackable_data, node_ids = _gather_trackable_data(graph_view, object_map)
  tensor_trackables, pystate_trackables, registered_trackables = (
      _split_trackables(trackable_data))

  object_graph_proto = _fill_object_graph_proto(trackable_data)

  serialized_tensors = _get_and_write_tensors_to_serialize(
      tensor_trackables,
      node_ids,
      call_with_mapped_captures,
      cache,
      object_graph_proto)
  registered_savers = _get_and_write_registered_savers(
      registered_trackables, object_graph_proto)

  # PythonState trackables must be treated differently depending on if the
  # checkpoint is being saved in TF1 graph mode (`cache` exists) or
  # eager mode (`cache` is None).
  if cache is None:
    # When the tensor cache is None, get the serialized tensors directly.
    feed_additions = None
    serialized_tensors.update(_get_and_write_tensors_to_serialize(
        pystate_trackables,
        node_ids,
        call_with_mapped_captures,
        cache,
        object_graph_proto))
  else:
    # Python state is not automatically updated within a TF session so these
    # values must be passed to sess.run(feed_additions=...).
    new_serialized_tensors, feed_additions = (
        _get_and_write_pystate_feed_additions(pystate_trackables,
                                              cache,
                                              object_graph_proto))
    serialized_tensors.update(new_serialized_tensors)

  # Gather all trackables that have checkpoint values or descendants with
  # checkpoint values, and add that info to the proto.
  util.add_checkpoint_values_check(object_graph_proto)
  return (serialized_tensors, feed_additions, registered_savers,
          object_graph_proto)
|
||
|
|