# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Saves and restores variables inside traced @tf.functions."""

import dataclasses
import math
import time
from typing import Any, Callable, Mapping, MutableMapping, MutableSequence, Sequence

from absl import logging

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.checkpoint import checkpoint_options
from tensorflow.python.checkpoint.sharding import sharding_policies
from tensorflow.python.checkpoint.sharding import sharding_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as device_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor as tensor_lib
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.saved_model import registration
from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.trackable import base
from tensorflow.python.trackable import trackable_utils
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.types import core
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity


RegisteredSaversDict = Mapping[
    registration.RegisteredSaver, Mapping[str, base.Trackable]]
MappedCapturesCallable = Callable[
    [core.ConcreteFunction, Sequence[tensor_lib.Tensor]], tensor_lib.Tensor]


def _single_shard_save(
    file_prefix: tensor_lib.Tensor,
    shard: sharding_util.TensorSliceDict,
    task: device_lib.DeviceSpec,
    options: "checkpoint_options.CheckpointOptions | None" = None,
) -> ops.Operation:
  """Save the saveable objects to a checkpoint with `file_prefix`.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix to
      save under.
    shard: Dict containing tensors. {checkpoint key: {slice_spec: tensor} }
    task: The device spec task of the tensors in the shard.
    options: Optional `CheckpointOptions` object.

  Returns:
    An `Operation`, or None when executing eagerly.
  """
  options = options or checkpoint_options.CheckpointOptions()

  tensor_names = []
  tensors = []
  slice_specs = []
  for checkpoint_key, tensor_slices in shard.items():
    for slice_spec, tensor in tensor_slices.items():
      # A tensor value of `None` indicates that this SaveableObject gets
      # recorded in the object graph, but that no value is saved in the
      # checkpoint.
      if tensor is not None:
        # See `MultiDeviceSaver._get_shards_by_task` for an explanation on the
        # wrapped properties.
        name = (tensor._wrapped_name  # pylint: disable=protected-access
                if hasattr(tensor, "_wrapped_name")
                else checkpoint_key)
        spec = (tensor._wrapped_slice_spec  # pylint: disable=protected-access
                if hasattr(tensor, "_wrapped_slice_spec")
                else slice_spec)

        tensor_names.append(name)
        tensors.append(tensor)
        slice_specs.append(spec)

  save_device = options.experimental_io_device or (len(tensors) and task)
  with ops.device(save_device or "CPU:0"):
    return io_ops.save_v2(file_prefix, tensor_names, slice_specs, tensors)
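

# Illustrative sketch (the variable and checkpoint-key names below are made up
# for illustration): the `shard` argument is a nested dict keyed first by
# checkpoint key and then by slice spec, e.g.
#
#   shard = {
#       "dense/kernel/.ATTRIBUTES/VARIABLE_VALUE": {"": kernel_tensor},
#       "dense/bias/.ATTRIBUTES/VARIABLE_VALUE": {"": bias_tensor},
#   }
#   save_op = _single_shard_save(
#       constant_op.constant("/tmp/ckpt_temp/part-00000-of-00001"),
#       shard,
#       device_lib.DeviceSpec.from_string("/job:localhost/task:0/device:CPU:0"))
#
# Each (checkpoint key, slice spec, tensor) triple is flattened into the three
# parallel lists handed to a single `SaveV2` op.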


def _single_shard_restore(
    file_prefix: tensor_lib.Tensor,
    shardable_tensors: Sequence[sharding_util.ShardableTensor],
    options: "checkpoint_options.CheckpointOptions | None" = None
) -> sharding_util.TensorSliceDict:
  """Restore the saveable objects from a checkpoint with `file_prefix`.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix for
      files to read from.
    shardable_tensors: A list of ShardableTensors to restore.
    options: Optional `CheckpointOptions` object.

  Returns:
    A restored tensor dict (maps checkpoint_key -> slice_spec -> tensor).
  """
  options = options or checkpoint_options.CheckpointOptions()

  tensor_names = []
  tensor_dtypes = []
  slice_specs = []
  for shardable_tensor in shardable_tensors:
    if shardable_tensor._tensor_save_spec:  # pylint: disable=protected-access
      name = shardable_tensor._tensor_save_spec.name  # pylint: disable=protected-access
      spec = shardable_tensor._tensor_save_spec.slice_spec  # pylint: disable=protected-access
    else:
      name, spec = shardable_tensor.checkpoint_key, shardable_tensor.slice_spec
    tensor_names.append(name)
    slice_specs.append(spec)
    tensor_dtypes.append(shardable_tensor.dtype)

  restore_device = options.experimental_io_device or "cpu:0"
  with ops.device(restore_device):
    restored_tensors = io_ops.restore_v2(
        file_prefix, tensor_names, slice_specs, tensor_dtypes)

  restored_tensor_dict = {}
  for shardable_tensor in shardable_tensors:
    restored_tensor = restored_tensors.pop(0)
    restored_tensor_dict.setdefault(
        shardable_tensor.checkpoint_key, {}
    )[shardable_tensor.slice_spec] = restored_tensor
  return restored_tensor_dict
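

# Illustrative sketch (key names invented for illustration): the returned dict
# maps checkpoint_key -> slice_spec -> restored tensor, e.g.
#
#   {
#       "dense/kernel/.ATTRIBUTES/VARIABLE_VALUE": {"": kernel_tensor},
#       "dense/bias/.ATTRIBUTES/VARIABLE_VALUE": {"": bias_tensor},
#   }
#
# Tensors saved without a slice spec sit under the empty-string key; a
# partitioned variable would instead have one entry per slice spec.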


def sharded_filename(
    filename_tensor: tensor_lib.Tensor,
    shard: int,
    num_shards: tensor_lib.Tensor
) -> tensor_lib.Tensor:
  """Append sharding information to a filename.

  Args:
    filename_tensor: A string tensor.
    shard: Integer. The shard for the filename.
    num_shards: An int Tensor for the number of shards.

  Returns:
    A string tensor.
  """
  return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)


def registered_saver_filename(
    filename_tensor: tensor_lib.Tensor,
    saver_name: registration.RegisteredSaver
) -> tensor_lib.Tensor:
  return string_ops.string_join(
      [filename_tensor, constant_op.constant(f"-{saver_name}")])
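

# Illustrative sketch (the paths are made-up examples): `sharded_filename`
# wraps the `ShardedFilename` op, which formats
# "<prefix>-<5-digit shard>-of-<5-digit num_shards>", while
# `registered_saver_filename` simply appends "-<saver_name>" to the prefix:
#
#   sharded_filename(constant_op.constant("/tmp/ckpt_temp/part"), 0,
#                    constant_op.constant(2))
#   # -> b"/tmp/ckpt_temp/part-00000-of-00002"
#
#   registered_saver_filename(constant_op.constant("/tmp/ckpt"), "MySaver")
#   # -> b"/tmp/ckpt-MySaver"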


def _get_mapped_registered_save_fn(
    fn: Callable[..., tensor_lib.Tensor],
    trackables: Sequence[base.Trackable],
    call_with_mapped_captures: MappedCapturesCallable
) -> Callable[[tensor_lib.Tensor], MappedCapturesCallable]:
  """Converts the function to a python or tf.function with a single file arg."""

  def save_fn(file_prefix: tensor_lib.Tensor) -> tensor_lib.Tensor:
    return fn(trackables=trackables, file_prefix=file_prefix)
  if call_with_mapped_captures is None:
    return save_fn
  else:
    tf_fn = def_function.function(save_fn, autograph=False)
    concrete = tf_fn.get_concrete_function(
        file_prefix=tensor_spec.TensorSpec(shape=(), dtype=dtypes.string))

    def save_fn_with_replaced_captures(
        file_prefix: tensor_lib.Tensor) -> tensor_lib.Tensor:
      return call_with_mapped_captures(concrete, [file_prefix])

    return save_fn_with_replaced_captures


def _get_mapped_registered_restore_fn(
    fn: Callable[..., tensor_lib.Tensor],
    trackables: Sequence[base.Trackable],
    call_with_mapped_captures: MappedCapturesCallable
) -> Callable[..., tensor_lib.Tensor]:
  """Converts the function to a python or tf.function with a single file arg."""

  def restore_fn(merged_prefix: tensor_lib.Tensor) -> tensor_lib.Tensor:
    return fn(trackables=trackables, merged_prefix=merged_prefix)
  if call_with_mapped_captures is None:
    return restore_fn
  else:
    tf_fn = def_function.function(restore_fn, autograph=False)
    concrete = tf_fn.get_concrete_function(
        merged_prefix=tensor_spec.TensorSpec(shape=(), dtype=dtypes.string))

    def restore_fn_with_replaced_captures(
        merged_prefix: tensor_lib.Tensor) -> tensor_lib.Tensor:
      return call_with_mapped_captures(concrete, [merged_prefix])

    return restore_fn_with_replaced_captures
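

# Illustrative sketch (a hypothetical saver, not a registered one): the `fn`
# callables wrapped above follow the calling convention used in this module:
# a registered save function receives `trackables` and `file_prefix` keyword
# arguments and returns a (possibly empty) list of string tensors naming any
# extra files it wrote, while a registered restore function receives
# `trackables` and `merged_prefix`.
#
#   def _my_save_fn(trackables, file_prefix):
#     del trackables, file_prefix  # This sketch writes nothing extra.
#     return []
#
#   def _my_restore_fn(trackables, merged_prefix):
#     del trackables, merged_prefix  # Nothing to restore in this sketch.
#
# See the `registration` module for how such functions are associated with a
# `RegisteredSaver` name.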


_restore_noop = lambda *args, **kwargs: None


class MultiDeviceSaver:
  """Saves checkpoints directly from multiple devices.

  Note that this is a low-level utility which stores Tensors in the keys
  specified by `SaveableObject`s. Higher-level utilities for object-based
  checkpointing are built on top of it.
  """

  def __init__(
      self,
      serialized_tensors: Mapping[
          base.Trackable, sharding_util.TensorSliceDict],
      registered_savers: "RegisteredSaversDict | None" = None,
      call_with_mapped_captures: "MappedCapturesCallable | None" = None):
    """Specify a list of `SaveableObject`s to save and restore.

    Args:
      serialized_tensors: A dictionary mapping `Trackable` to a tensor dict,
        which maps checkpoint_key -> (optional slice_spec ->) Tensor/SaveSpec.
        The `Trackable` key is used to get the `restore_from_tensors`
        function, and may be `None` if the tensor is not meant to be restored.
      registered_savers: A dictionary mapping `registration.RegisteredSaver`
        namedtuples to a dictionary of named Trackables. The keys of the
        Trackable dictionary are string names that uniquely identify the
        Trackable in the checkpoint.
      call_with_mapped_captures: Optional callable that takes a
        `ConcreteFunction` and a list of its inputs and calls it with the
        function's captured tensors remapped. Used when saving or restoring
        from within another traced function whose captures must be
        substituted.
    """
    self._shardable_tensors_by_task: MutableMapping[
        device_lib.DeviceSpec,
        MutableSequence[sharding_util.ShardableTensor]] = {}
    # Keep these two data structures so that we can map restored tensors to
    # the Trackable restore functions.
    self._keys_to_restore_fn: MutableMapping[
        sharding_util.TensorSlice,
        Callable[[Mapping[str, tensor_lib.Tensor]], Any]] = {}
    self._restore_fn_to_keys: MutableMapping[
        Callable[[Mapping[str, tensor_lib.Tensor]], Any],
        MutableSequence[sharding_util.TensorSlice]] = {}

    unique_tasks = set()
    for obj, tensor_dict in serialized_tensors.items():
      restore_fn = _restore_noop if obj is None else obj._restore_from_tensors

      # Divide tensor_dict by task.
      for checkpoint_key, tensor_slice_dict in tensor_dict.items():
        if not isinstance(tensor_slice_dict, dict):
          # Make sure the value is structured as {slice_spec -> tensor}.
          tensor_slice_dict = {"": tensor_slice_dict}

        for slice_spec, tensor_save_spec in tensor_slice_dict.items():
          tensor_value = None
          if not isinstance(tensor_save_spec, saveable_object.SaveSpec):
            tensor_value = tensor_save_spec
            tensor_save_spec = saveable_object.SaveSpec(
                tensor=tensor_value,
                slice_spec=slice_spec,
                name=checkpoint_key,
                dtype=tensor_save_spec.dtype,
                device=tensor_save_spec.device)

          if (checkpoint_key, slice_spec) in self._keys_to_restore_fn:
            raise ValueError(
                "Received multiple tensors with the same checkpoint key and "
                "slice spec. This is invalid because one will overwrite the "
                "other in the checkpoint. This indicates a bug in the "
                "checkpoint key generation.")
          self._keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn
          self._restore_fn_to_keys.setdefault(restore_fn, []).append(
              (checkpoint_key, slice_spec))

          if isinstance(tensor_save_spec.device, str):
            device = device_lib.DeviceSpec.from_string(tensor_save_spec.device)
            task = device_lib.DeviceSpec.from_string(
                saveable_object_util.set_cpu0(tensor_save_spec.device))
          else:
            device = tensor_save_spec.device
            task = device_lib.DeviceSpec.from_string(
                saveable_object_util.set_cpu0(device.to_string()))

          self._shardable_tensors_by_task.setdefault(task, []).append(
              sharding_util.ShardableTensor(
                  _tensor_save_spec=tensor_save_spec,
                  tensor=tensor_value,
                  dtype=tensor_save_spec.dtype,
                  device=device,
                  name=tensor_save_spec.name,
                  shape=None,
                  slice_spec=slice_spec.strip(),
                  checkpoint_key=checkpoint_key,
                  trackable=obj))
          unique_tasks.add(
              saveable_object_util.set_cpu0(device.to_string()))

    self._num_unique_tasks = len(unique_tasks)

    self._registered_savers = {}
    if registered_savers:
      for registered_name, trackables in registered_savers.items():
        save_fn = _get_mapped_registered_save_fn(
            registration.get_save_function(registered_name),
            trackables, call_with_mapped_captures)
        restore_fn = _get_mapped_registered_restore_fn(
            registration.get_restore_function(registered_name),
            trackables, call_with_mapped_captures)
        self._registered_savers[registered_name] = (save_fn, restore_fn)

  @classmethod
  def from_saveables(
      cls,
      saveables: Sequence[base.Trackable],
      registered_savers: "RegisteredSaversDict | None" = None,
      call_with_mapped_captures: "MappedCapturesCallable | None" = None
  ) -> "MultiDeviceSaver":
    """Constructs a MultiDeviceSaver from a list of `SaveableObject`s."""
    serialized_tensors = object_identity.ObjectIdentityDictionary()
    for saveable in saveables:
      trackable = saveable_object_util.SaveableCompatibilityConverter(
          saveable, saveables=[saveable])
      serialized_tensors[trackable] = trackable._serialize_to_tensors()  # pylint: disable=protected-access
    return cls(serialized_tensors, registered_savers, call_with_mapped_captures)
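
  # Illustrative sketch (names invented for illustration; `variables` is not
  # imported in this module): callers usually build `serialized_tensors` as a
  # mapping from each Trackable to its {checkpoint_key: tensor} dict,
  # optionally nested by slice spec:
  #
  #   v = variables.Variable([1.0, 2.0])
  #   serialized = object_identity.ObjectIdentityDictionary()
  #   serialized[v] = {"v/.ATTRIBUTES/VARIABLE_VALUE": v.read_value()}
  #   saver = MultiDeviceSaver(serialized)
  #
  # `from_saveables` is the equivalent entry point when starting from legacy
  # `SaveableObject`s instead of serialized tensor dicts.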

  def to_proto(self) -> saver_pb2.SaverDef:
    """Serializes to a SaverDef referencing the current graph."""
    filename_tensor = array_ops.placeholder(
        shape=[], dtype=dtypes.string, name="saver_filename")
    save_tensor = self._traced_save(filename_tensor)
    restore_op = self._traced_restore(filename_tensor).op
    return saver_pb2.SaverDef(
        filename_tensor_name=filename_tensor.name,
        save_tensor_name=save_tensor.name,
        restore_op_name=restore_op.name,
        version=saver_pb2.SaverDef.V2)
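
  # Illustrative sketch (the tensor/op names shown are typical, not
  # guaranteed): the returned SaverDef only records graph element names, e.g.
  #
  #   saver_def = saver.to_proto()
  #   saver_def.filename_tensor_name  # e.g. "saver_filename:0"
  #   saver_def.save_tensor_name      # output of the traced save function
  #   saver_def.restore_op_name       # op name of the traced restore function
  #   saver_def.version               # saver_pb2.SaverDef.V2
  #
  # Because `array_ops.placeholder` is used, this is only meaningful while
  # building a graph, not under eager execution.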

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),),
      autograph=False)
  def _traced_save(self, file_prefix: tensor_lib.Tensor) -> tensor_lib.Tensor:
    save_op = self.save(file_prefix)
    with ops.device("cpu:0"):
      with ops.control_dependencies([save_op]):
        return array_ops.identity(file_prefix)

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),),
      autograph=False)
  def _traced_restore(
      self, file_prefix: tensor_lib.Tensor) -> tensor_lib.Tensor:
    restore_ops = self.restore(file_prefix)
    with ops.device("cpu:0"):
      with ops.control_dependencies(restore_ops.values()):
        return array_ops.identity(file_prefix)

  def _get_shards_by_task(
      self,
      sharding_callback: sharding_util.ShardingCallback
  ) -> Sequence[sharding_util.TensorSliceDict]:
    """Calls the sharding callback with shardable_tensors.

    Args:
      sharding_callback: ShardingCallback. The callback function wrapper that
        splits shardable_tensors into shards.

    Returns:
      A list of (task, shards) tuples.
    """
    def wrap_tensor(shardable_tensor):
      tensor_val = shardable_tensor.tensor
      tensor_shape = shardable_tensor.shape
      save_spec = shardable_tensor._tensor_save_spec  # pylint: disable=protected-access
      with ops.device(shardable_tensor.device):
        save_spec_tensor = save_spec.tensor

      if tensor_val is None and save_spec_tensor is None:
        # A tensor value of `None` indicates that this SaveableObject gets
        # recorded in the object graph, but that no value is saved in the
        # checkpoint.
        return None
      elif save_spec_tensor is not None:
        # Pull the tensor value from _tensor_save_spec.
        tensor_val = save_spec_tensor
        tensor_shape = save_spec_tensor.shape

      # Propagate the save spec name and/or slice spec when they are tensors.
      # This makes sure properties like `layout` for dtensor names/slice specs
      # are preserved during sharding.
      if isinstance(save_spec.name, tensor_lib.Tensor):
        tensor_val._wrapped_name = save_spec.name  # pylint: disable=protected-access
      if isinstance(shardable_tensor.slice_spec, tensor_lib.Tensor):
        tensor_val._wrapped_slice_spec = save_spec.slice_spec  # pylint: disable=protected-access

      return dataclasses.replace(
          shardable_tensor,
          tensor=tensor_val,
          shape=tensor_shape)

    shardable_tensors_by_task = {
        task: [shardable_tensor
               for shardable_tensor in map(wrap_tensor, shardable_tensors)
               if shardable_tensor is not None]
        for task, shardable_tensors in self._shardable_tensors_by_task.items()}

    sharding_callback = (
        sharding_callback or sharding_policies.ShardByTaskPolicy())
    metrics.SetShardingCallbackDescription(
        description=sharding_callback.description)

    start_time = time.time() * 1e6
    shards_by_task = [
        (task, sharding_callback(shardable_tensors))
        for task, shardable_tensors in shardable_tensors_by_task.items()]
    callback_duration = math.ceil(time.time() * 1e6 - start_time)
    metrics.AddShardingCallbackDuration(
        callback_duration=max(1, callback_duration))  # in microseconds
    logging.info("Sharding callback duration: %s", callback_duration)

    return shards_by_task
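
  # Illustrative sketch (task and key strings are examples only): with the
  # default `ShardByTaskPolicy`, each task contributes a single shard, so the
  # return value looks like
  #
  #   [
  #       (device_lib.DeviceSpec.from_string(
  #            "/job:worker/replica:0/task:0/device:CPU:0"),
  #        [{"layer/kernel/.ATTRIBUTES/VARIABLE_VALUE": {"": tensor}}]),
  #   ]
  #
  # i.e. a list of (task, list-of-tensor-slice-dicts) pairs; other sharding
  # callbacks may split one task's tensors across several shards.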

  def save(
      self,
      file_prefix: tensor_lib.Tensor,
      options: "checkpoint_options.CheckpointOptions | None" = None
  ) -> ops.Operation:
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()

    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Suffix for any well-formed "checkpoint_prefix", when sharded.
    # Transformations:
    # * Users pass in "save_path" in save() and restore(). Say "myckpt".
    # * checkpoint_prefix gets fed <save_path><sharded_suffix>.
    #
    # Example:
    # During runtime, a temporary directory is first created, which contains
    # files
    #
    #   <train dir>/myckpt_temp/
    #      part-?????-of-?????{.index, .data-00000-of-00001}
    #
    # Before .save() finishes, they will be (hopefully, atomically) renamed to
    #
    #   <train dir>/
    #      myckpt{.index, .data-?????-of-?????}
    #
    # Filesystems with eventual consistency (such as S3) don't need a
    # temporary location. Using a temporary directory in those cases might
    # cause situations where files are not available during copy.
    #
    # Users only need to interact with the user-specified prefix, which is
    # "<train dir>/myckpt" in this case. Save() and Restore() work with the
    # prefix directly, instead of any physical pathname. (On failure and
    # subsequent restore, an outdated and orphaned temporary directory can be
    # safely removed.)
    with ops.device("CPU"):
      sharded_suffix = array_ops.where(
          string_ops.regex_full_match(file_prefix, "^s3://.*"),
          constant_op.constant(".part"),
          constant_op.constant("_temp/part"))
      tmp_checkpoint_prefix = string_ops.string_join(
          [file_prefix, sharded_suffix])
      registered_paths = {
          saver_name: registered_saver_filename(file_prefix, saver_name)
          for saver_name in self._registered_savers
      }

    def save_fn() -> ops.Operation:
      saved_prefixes = []
      # Save with the registered savers. These run before default savers due to
      # the API contract.
      for saver_name, (save_fn, _) in self._registered_savers.items():
        maybe_saved_prefixes = save_fn(registered_paths[saver_name])
        if maybe_saved_prefixes is not None:
          flattened_saved_prefixes = nest.flatten(maybe_saved_prefixes)
          if not all(
              tensor_util.is_tf_type(x) and x.dtype == dtypes.string
              for x in flattened_saved_prefixes):
            raise ValueError(
                "Registered saver must return a (maybe empty) list of "
                f"string type tensors. Got {maybe_saved_prefixes}.")
          saved_prefixes.extend(flattened_saved_prefixes)

      shards_by_task = self._get_shards_by_task(
          options.experimental_sharding_callback)
      num_shards = sum([len(shards) for _, shards in shards_by_task])
      metrics.AddNumCheckpointShardsWritten(num_shards=num_shards)
      num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
      sharded_saves = []

      shard_idx = 0
      for task, shards in shards_by_task:
        for shard in shards:
          with ops.device(task):
            shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard_idx,
                                            num_shards_tensor)
          shard_idx += 1
          saved_prefixes.append(shard_prefix)
          sharded_saves.append(
              _single_shard_save(shard_prefix, shard, task, options))

      with ops.control_dependencies(sharded_saves):
        # Merge on the io_device if specified, otherwise co-locate the merge op
        # with the last device used.
        tensor_device_spec = list(self._shardable_tensors_by_task.keys())[-1]
        merge_device_spec = (
            options.experimental_io_device or
            saveable_object_util.set_cpu0(tensor_device_spec.to_string()))
        with ops.device(merge_device_spec):
          # The V2 format write path consists of a metadata merge step. Once
          # merged, it attempts to delete the temporary directory,
          # "<user-fed prefix>_temp".
          return gen_io_ops.merge_v2_checkpoints(
              saved_prefixes, file_prefix, delete_old_dirs=True)

    # Since this will cause a function re-trace on each save, limit this to the
    # cases where it is needed: eager and when there are multiple tasks. Note
    # that the retrace is needed to ensure we pick up the latest values of
    # options like experimental_io_device.
    if context.executing_eagerly() and self._num_unique_tasks > 1:
      # Explicitly place the identity op on the first device.
      @def_function.function(jit_compile=False)
      def tf_function_save() -> None:
        save_fn()
      tf_function_save()
    else:
      return save_fn()
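
  # Illustrative sketch (paths invented for illustration): `file_prefix` is a
  # scalar string tensor; after the merge step the checkpoint consists of
  # "<prefix>.index" plus "<prefix>.data-?????-of-?????" files.
  #
  #   saver.save(constant_op.constant("/tmp/training/ckpt-42"))
  #   saver.save(
  #       constant_op.constant("/tmp/training/ckpt-43"),
  #       options=checkpoint_options.CheckpointOptions(
  #           experimental_io_device="/job:localhost"))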

  def restore(
      self,
      file_prefix: tensor_lib.Tensor,
      options: "checkpoint_options.CheckpointOptions | None" = None
  ) -> Mapping[str, ops.Operation]:
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      When not run eagerly or when restoring on a single device, returns a
      dictionary mapping from SaveableObject names to restore operations;
      otherwise, returns an empty dict.
    """
    options = options or checkpoint_options.CheckpointOptions()

    def restore_fn() -> Mapping[str, ops.Operation]:
      restore_fn_inputs = {}
      restore_fn_input_count = {
          fn: len(keys) for fn, keys in self._restore_fn_to_keys.items()}

      restore_ops = {}

      for task, shard in self._shardable_tensors_by_task.items():
        with ops.device(task):
          # Load values from checkpoint
          restored_tensor_dict = _single_shard_restore(
              file_prefix, shard, options)

          # Map restored tensors to the corresponding restore_fn, and check
          # whether all of its inputs have been loaded. Call `restore_fn` once
          # that is the case.
          for ckpt_key, slice_and_tensor in restored_tensor_dict.items():
            for slice_spec, tensor in slice_and_tensor.items():
              restore_fn = self._keys_to_restore_fn[(ckpt_key,
                                                     slice_spec)]

              # Process the returned restored_tensor_dict to prepare for the
              # Trackable `restore` function. The `restore` function expects a
              # map of `string name (checkpoint_key) -> Tensor`, unless there
              # is a slice_spec, in which case the map is of
              # `string name (checkpoint_key) -> slice_spec -> Tensor`.
              if slice_spec:
                (restore_fn_inputs.setdefault(restore_fn, {}).setdefault(
                    ckpt_key, {})[slice_spec]) = tensor
              else:
                restore_fn_inputs.setdefault(restore_fn,
                                             {})[ckpt_key] = tensor
              restore_fn_input_count[restore_fn] -= 1

              if restore_fn_input_count[restore_fn] == 0:
                restored_tensors = {}
                # Extract the substring after "/.ATTRIBUTES/" in each ckpt_key
                # of restore_fn_inputs[restore_fn] when building
                # restored_tensors. For example, if restore_fn_inputs[restore_fn]
                # is the dict {"/.ATTRIBUTES/a": Tensor}, restored_tensors
                # becomes {"a": Tensor}.
                for ckpt_key, tensor in restore_fn_inputs[restore_fn].items():
                  restored_tensors[trackable_utils.extract_local_name(
                      ckpt_key)] = tensor
                ret = restore_fn(restored_tensors)
                if isinstance(ret, dict):
                  restore_ops.update(ret)
      # Run registered restore methods after the default restore ops.
      for _, (_, restore_fn) in self._registered_savers.items():
        restore_fn(file_prefix)
      return restore_ops

    has_custom_device_saver = False
    for sts in self._shardable_tensors_by_task.values():
      if any([context.is_custom_device(st.device.to_string()) for st in sts]):
        has_custom_device_saver = True
        break
    # Since this will cause a function re-trace on each restore, limit this to
    # cases where it is needed: eager and when there are multiple tasks or any
    # device_spec is a custom device. Note that the retrace is needed to ensure
    # we pick up the latest values of options like experimental_io_device.
    #
    # We run in a function when there is a custom device saver because custom
    # devices, such as DTensor, usually do a sharded save and restore.
    # Doing a sharded save and restore requires knowledge about what shards
    # of variables we are restoring to. In practice, this means that custom
    # devices need the AssignVariableOps along with the Restore op within the
    # same graph to infer shapes and shard specs for the Restore op.
    if context.executing_eagerly() and (self._num_unique_tasks > 1 or
                                        has_custom_device_saver):
      @def_function.function(jit_compile=False, autograph=False)
      def tf_function_restore() -> Mapping[str, ops.Operation]:
        restore_fn()
        return {}

      restore_ops = tf_function_restore()
    else:
      restore_ops = restore_fn()

    return restore_ops
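
  # Illustrative sketch (path invented for illustration): restore mirrors save
  # and reads from a previously merged prefix. In graph mode the returned dict
  # can be used to group the restore ops; in eager mode with multiple tasks or
  # a custom device it is empty because the work happens inside a traced
  # function.
  #
  #   restore_ops = saver.restore(constant_op.constant("/tmp/training/ckpt-42"))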