# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Saves and restores variables inside traced @tf.functions."""

import dataclasses
import math
import time
from typing import Any, Callable, Mapping, MutableMapping, MutableSequence, Sequence

from absl import logging

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.checkpoint import checkpoint_options
from tensorflow.python.checkpoint.sharding import sharding_policies
from tensorflow.python.checkpoint.sharding import sharding_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as device_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor as tensor_lib
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.saved_model import registration
from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.trackable import base
from tensorflow.python.trackable import trackable_utils
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.types import core
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity


RegisteredSaversDict = Mapping[
    registration.RegisteredSaver, Mapping[str, base.Trackable]]
MappedCapturesCallable = Callable[
    [core.ConcreteFunction, Sequence[tensor_lib.Tensor]], tensor_lib.Tensor]


def _single_shard_save(
    file_prefix: tensor_lib.Tensor,
    shard: sharding_util.TensorSliceDict,
    task: device_lib.DeviceSpec,
    options: "checkpoint_options.CheckpointOptions | None" = None,
) -> ops.Operation:
  """Save the saveable objects to a checkpoint with `file_prefix`.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix to
      save under.
    shard: Dict containing tensors. {checkpoint key: {slice_spec: tensor} }
    task: The device spec task of the tensors in the shard.
    options: Optional `CheckpointOptions` object.

  Returns:
    An `Operation`, or None when executing eagerly.
  """
  options = options or checkpoint_options.CheckpointOptions()

  tensor_names = []
  tensors = []
  slice_specs = []
  for checkpoint_key, tensor_slices in shard.items():
    for slice_spec, tensor in tensor_slices.items():
      # A tensor value of `None` indicates that this SaveableObject gets
      # recorded in the object graph, but that no value is saved in the
      # checkpoint.
      if tensor is not None:
        # See `MultiDeviceSaver._get_shards_by_task` for an explanation on the
        # wrapped properties.
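        # `_wrapped_name` / `_wrapped_slice_spec` are only attached to the
        # tensor (in `_get_shards_by_task.wrap_tensor`) when the save spec's
        # name or slice spec is itself a Tensor, e.g. to carry DTensor layout
        # information through sharding; otherwise the plain checkpoint key
        # and slice spec strings are used.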
        name = (tensor._wrapped_name  # pylint: disable=protected-access
                if hasattr(tensor, "_wrapped_name")
                else checkpoint_key)
        spec = (tensor._wrapped_slice_spec  # pylint: disable=protected-access
                if hasattr(tensor, "_wrapped_slice_spec")
                else slice_spec)
        tensor_names.append(name)
        tensors.append(tensor)
        slice_specs.append(spec)

  save_device = options.experimental_io_device or (len(tensors) and task)
  with ops.device(save_device or "CPU:0"):
    return io_ops.save_v2(file_prefix, tensor_names, slice_specs, tensors)


def _single_shard_restore(
    file_prefix: tensor_lib.Tensor,
    shardable_tensors: Sequence[sharding_util.ShardableTensor],
    options: "checkpoint_options.CheckpointOptions | None" = None
) -> sharding_util.TensorSliceDict:
  """Restore the saveable objects from a checkpoint with `file_prefix`.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix for
      files to read from.
    shardable_tensors: A list of ShardableTensors to restore.
    options: Optional `CheckpointOptions` object.

  Returns:
    A restored tensor dict (maps checkpoint_key -> slice_spec -> tensor).
  """
  options = options or checkpoint_options.CheckpointOptions()

  tensor_names = []
  tensor_dtypes = []
  slice_specs = []
  for shardable_tensor in shardable_tensors:
    if shardable_tensor._tensor_save_spec:  # pylint: disable=protected-access
      name = shardable_tensor._tensor_save_spec.name  # pylint: disable=protected-access
      spec = shardable_tensor._tensor_save_spec.slice_spec  # pylint: disable=protected-access
    else:
      name, spec = shardable_tensor.checkpoint_key, shardable_tensor.slice_spec
    tensor_names.append(name)
    slice_specs.append(spec)
    tensor_dtypes.append(shardable_tensor.dtype)

  restore_device = options.experimental_io_device or "cpu:0"
  with ops.device(restore_device):
    restored_tensors = io_ops.restore_v2(
        file_prefix, tensor_names, slice_specs, tensor_dtypes)

  restored_tensor_dict = {}
  for shardable_tensor in shardable_tensors:
    restored_tensor = restored_tensors.pop(0)
    restored_tensor_dict.setdefault(
        shardable_tensor.checkpoint_key, {})[
            shardable_tensor.slice_spec] = restored_tensor
  return restored_tensor_dict


def sharded_filename(
    filename_tensor: tensor_lib.Tensor,
    shard: int,
    num_shards: tensor_lib.Tensor
) -> tensor_lib.Tensor:
  """Append sharding information to a filename.

  Args:
    filename_tensor: A string tensor.
    shard: Integer. The shard for the filename.
    num_shards: An int Tensor for the number of shards.

  Returns:
    A string tensor.
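
  Example (illustrative sketch; the zero-padded format comes from the
  underlying `ShardedFilename` op):
    sharded_filename(constant_op.constant("/tmp/ckpt_temp/part"), 0,
                     constant_op.constant(4))
    # -> b"/tmp/ckpt_temp/part-00000-of-00004"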
""" return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards) def registered_saver_filename( filename_tensor: tensor_lib.Tensor, saver_name: registration.RegisteredSaver ) -> tensor_lib.Tensor: return string_ops.string_join( [filename_tensor, constant_op.constant(f"-{saver_name}")]) def _get_mapped_registered_save_fn( fn: Callable[..., tensor_lib.Tensor], trackables: Sequence[base.Trackable], call_with_mapped_captures: MappedCapturesCallable ) -> Callable[[tensor_lib.Tensor], MappedCapturesCallable]: """Converts the function to a python or tf.function with a single file arg.""" def save_fn(file_prefix: tensor_lib.Tensor) -> tensor_lib.Tensor: return fn(trackables=trackables, file_prefix=file_prefix) if call_with_mapped_captures is None: return save_fn else: tf_fn = def_function.function(save_fn, autograph=False) concrete = tf_fn.get_concrete_function( file_prefix=tensor_spec.TensorSpec(shape=(), dtype=dtypes.string)) def save_fn_with_replaced_captures( file_prefix: tensor_lib.Tensor) -> tensor_lib.Tensor: return call_with_mapped_captures(concrete, [file_prefix]) return save_fn_with_replaced_captures def _get_mapped_registered_restore_fn( fn: Callable[..., tensor_lib.Tensor], trackables: Sequence[base.Trackable], call_with_mapped_captures: MappedCapturesCallable ) -> Callable[..., tensor_lib.Tensor]: """Converts the function to a python or tf.function with a single file arg.""" def restore_fn(merged_prefix: tensor_lib.Tensor) -> tensor_lib.Tensor: return fn(trackables=trackables, merged_prefix=merged_prefix) if call_with_mapped_captures is None: return restore_fn else: tf_fn = def_function.function(restore_fn, autograph=False) concrete = tf_fn.get_concrete_function( merged_prefix=tensor_spec.TensorSpec(shape=(), dtype=dtypes.string)) def restore_fn_with_replaced_captures( merged_prefix: tensor_lib.Tensor) -> tensor_lib.Tensor: return call_with_mapped_captures(concrete, [merged_prefix]) return restore_fn_with_replaced_captures _restore_noop = lambda *args, **kwargs: None class MultiDeviceSaver: """Saves checkpoints directly from multiple devices. Note that this is a low-level utility which stores Tensors in the keys specified by `SaveableObject`s. Higher-level utilities for object-based checkpointing are built on top of it. """ def __init__( self, serialized_tensors: Mapping[ base.Trackable, sharding_util.TensorSliceDict], registered_savers: "RegisteredSaversDict | None" = None, call_with_mapped_captures: "MappedCapturesCallable | None" = None): """Specify a list of `SaveableObject`s to save and restore. Args: serialized_tensors: A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. The `Trackable` key is used to get the `restore_from_tensors` function, and may be `None` if the tensor is not meant to be restored. registered_savers: A dictionary mapping `registration.RegisteredSaver` namedtuples to a dictionary of named Trackables. The keys of the Trackable dictionary are string names that uniquely identify the Trackable in the checkpoint. call_with_mapped_captures: TODO """ self._shardable_tensors_by_task: MutableMapping[ device_lib.DeviceSpec, MutableSequence[sharding_util.ShardableTensor]] = {} # Keep these two data structures so that we can map restored tensors to # the Trackable restore functions. 
    self._keys_to_restore_fn: MutableMapping[
        sharding_util.TensorSlice,
        Callable[[Mapping[str, tensor_lib.Tensor]], Any]] = {}
    self._restore_fn_to_keys: MutableMapping[
        Callable[[Mapping[str, tensor_lib.Tensor]], Any],
        MutableSequence[sharding_util.TensorSlice]] = {}

    unique_tasks = set()
    for obj, tensor_dict in serialized_tensors.items():
      restore_fn = _restore_noop if obj is None else obj._restore_from_tensors  # pylint: disable=protected-access

      # Divide tensor_dict by task.
      for checkpoint_key, tensor_slice_dict in tensor_dict.items():
        if not isinstance(tensor_slice_dict, dict):
          # Make sure that maybe_tensor is structured as {slice_spec -> tensor}.
          tensor_slice_dict = {"": tensor_slice_dict}

        for slice_spec, tensor_save_spec in tensor_slice_dict.items():
          tensor_value = None
          if not isinstance(tensor_save_spec, saveable_object.SaveSpec):
            tensor_value = tensor_save_spec
            tensor_save_spec = saveable_object.SaveSpec(
                tensor=tensor_value,
                slice_spec=slice_spec,
                name=checkpoint_key,
                dtype=tensor_value.dtype,
                device=tensor_value.device)

          if (checkpoint_key, slice_spec) in self._keys_to_restore_fn:
            raise ValueError(
                "Received multiple tensors with the same checkpoint key and "
                "slice spec. This is invalid because one will overwrite the "
                "other in the checkpoint. This indicates a bug in the "
                "Checkpoint key-generation.")
          self._keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn
          self._restore_fn_to_keys.setdefault(restore_fn, []).append(
              (checkpoint_key, slice_spec))

          if isinstance(tensor_save_spec.device, str):
            device = device_lib.DeviceSpec.from_string(tensor_save_spec.device)
            task = device_lib.DeviceSpec.from_string(
                saveable_object_util.set_cpu0(tensor_save_spec.device))
          else:
            device = tensor_save_spec.device
            task = device_lib.DeviceSpec.from_string(
                saveable_object_util.set_cpu0(device.to_string()))

          self._shardable_tensors_by_task.setdefault(task, []).append(
              sharding_util.ShardableTensor(
                  _tensor_save_spec=tensor_save_spec,
                  tensor=tensor_value,
                  dtype=tensor_save_spec.dtype,
                  device=device,
                  name=tensor_save_spec.name,
                  shape=None,
                  slice_spec=slice_spec.strip(),
                  checkpoint_key=checkpoint_key,
                  trackable=obj))
          unique_tasks.add(
              saveable_object_util.set_cpu0(device.to_string()))

    self._num_unique_tasks = len(unique_tasks)

    self._registered_savers = {}
    if registered_savers:
      for registered_name, trackables in registered_savers.items():
        save_fn = _get_mapped_registered_save_fn(
            registration.get_save_function(registered_name),
            trackables, call_with_mapped_captures)
        restore_fn = _get_mapped_registered_restore_fn(
            registration.get_restore_function(registered_name),
            trackables, call_with_mapped_captures)
        self._registered_savers[registered_name] = (save_fn, restore_fn)

  @classmethod
  def from_saveables(
      cls,
      saveables: Sequence[base.Trackable],
      registered_savers: "RegisteredSaversDict | None" = None,
      call_with_mapped_captures: "MappedCapturesCallable | None" = None
  ) -> "MultiDeviceSaver":
    """Constructs a MultiDeviceSaver from a list of `SaveableObject`s."""
    serialized_tensors = object_identity.ObjectIdentityDictionary()
    for saveable in saveables:
      trackable = saveable_object_util.SaveableCompatibilityConverter(
          saveable, saveables=[saveable])
      serialized_tensors[trackable] = trackable._serialize_to_tensors()  # pylint: disable=protected-access
    return cls(serialized_tensors, registered_savers,
               call_with_mapped_captures)

  def to_proto(self) -> saver_pb2.SaverDef:
    """Serializes to a SaverDef referencing the current graph."""
    filename_tensor = array_ops.placeholder(
        shape=[], dtype=dtypes.string, name="saver_filename")
    save_tensor = self._traced_save(filename_tensor)
    restore_op = self._traced_restore(filename_tensor).op
    return saver_pb2.SaverDef(
        filename_tensor_name=filename_tensor.name,
        save_tensor_name=save_tensor.name,
        restore_op_name=restore_op.name,
        version=saver_pb2.SaverDef.V2)

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),),
      autograph=False)
  def _traced_save(self, file_prefix: tensor_lib.Tensor) -> tensor_lib.Tensor:
    save_op = self.save(file_prefix)
    with ops.device("cpu:0"):
      with ops.control_dependencies([save_op]):
        return array_ops.identity(file_prefix)

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),),
      autograph=False)
  def _traced_restore(
      self, file_prefix: tensor_lib.Tensor) -> tensor_lib.Tensor:
    restore_ops = self.restore(file_prefix)
    with ops.device("cpu:0"):
      with ops.control_dependencies(restore_ops.values()):
        return array_ops.identity(file_prefix)

  def _get_shards_by_task(
      self,
      sharding_callback: sharding_util.ShardingCallback
  ) -> Sequence[sharding_util.TensorSliceDict]:
    """Calls the sharding callback with shardable_tensors.

    Args:
      sharding_callback: ShardingCallback. The callback function wrapper that
        splits shardable_tensors into shards.

    Returns:
      A list of shards.
    """

    def wrap_tensor(shardable_tensor):
      tensor_val = shardable_tensor.tensor
      tensor_shape = shardable_tensor.shape
      save_spec = shardable_tensor._tensor_save_spec  # pylint: disable=protected-access
      with ops.device(shardable_tensor.device):
        save_spec_tensor = save_spec.tensor

      if tensor_val is None and save_spec_tensor is None:
        # A tensor value of `None` indicates that this SaveableObject gets
        # recorded in the object graph, but that no value is saved in the
        # checkpoint.
        return None
      elif save_spec_tensor is not None:
        # Pull the tensor value from _tensor_save_spec.
        tensor_val = save_spec_tensor
        tensor_shape = save_spec_tensor.shape

        # Propagate the save spec name and/or slice spec when they are
        # tensors. This makes sure properties like `layout` for dtensor
        # names/slice specs are preserved during sharding.
        if isinstance(save_spec.name, tensor_lib.Tensor):
          tensor_val._wrapped_name = save_spec.name  # pylint: disable=protected-access
        if isinstance(shardable_tensor.slice_spec, tensor_lib.Tensor):
          tensor_val._wrapped_slice_spec = save_spec.slice_spec  # pylint: disable=protected-access

      return dataclasses.replace(
          shardable_tensor,
          tensor=tensor_val,
          shape=tensor_shape)

    shardable_tensors_by_task = {
        task: [shardable_tensor
               for shardable_tensor in map(wrap_tensor, shardable_tensors)
               if shardable_tensor is not None]
        for task, shardable_tensors in self._shardable_tensors_by_task.items()}

    sharding_callback = (
        sharding_callback or sharding_policies.ShardByTaskPolicy())
    metrics.SetShardingCallbackDescription(
        description=sharding_callback.description)

    start_time = time.time() * 1e6
    shards_by_task = [
        (task, sharding_callback(shardable_tensors))
        for task, shardable_tensors in shardable_tensors_by_task.items()]
    callback_duration = math.ceil(time.time() * 1e6 - start_time)

    metrics.AddShardingCallbackDuration(
        callback_duration=max(1, callback_duration))  # in microseconds
    logging.info("Sharding callback duration: %s", callback_duration)

    return shards_by_task

  def save(
      self,
      file_prefix: tensor_lib.Tensor,
      options: "checkpoint_options.CheckpointOptions | None" = None
  ) -> ops.Operation:
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.

    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()

    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Suffix for any well-formed "checkpoint_prefix", when sharded.
    # Transformations:
    # * Users pass in "save_path" in save() and restore(). Say "myckpt".
    # * checkpoint_prefix gets fed <save_path><sharded_suffix>.
    #
    # Example:
    #   During runtime, a temporary directory is first created, which
    #   contains files
    #
    #     <train dir>/myckpt_temp/
    #        part-?????-of-?????{.index, .data-00000-of-00001}
    #
    #   Before .save() finishes, they will be (hopefully, atomically) renamed
    #   to
    #
    #     <train dir>/
    #        myckpt{.index, .data-?????-of-?????}
    #
    #   Filesystems with eventual consistency (such as S3) don't need a
    #   temporary location. Using a temporary directory in those cases might
    #   cause situations where files are not available during copy.
    #
    #   Users only need to interact with the user-specified prefix, which is
    #   "<train dir>/myckpt" in this case. Save() and Restore() work with the
    #   prefix directly, instead of any physical pathname. (On failure and
    #   subsequent restore, an outdated and orphaned temporary directory can
    #   be safely removed.)
    with ops.device("CPU"):
      sharded_suffix = array_ops.where(
          string_ops.regex_full_match(file_prefix, "^s3://.*"),
          constant_op.constant(".part"),
          constant_op.constant("_temp/part"))
      tmp_checkpoint_prefix = string_ops.string_join(
          [file_prefix, sharded_suffix])
      registered_paths = {
          saver_name: registered_saver_filename(file_prefix, saver_name)
          for saver_name in self._registered_savers
      }

    def save_fn() -> ops.Operation:
      saved_prefixes = []
      # Save with the registered savers. These run before default savers due
      # to the API contract.
      for saver_name, (save_fn, _) in self._registered_savers.items():
        maybe_saved_prefixes = save_fn(registered_paths[saver_name])
        if maybe_saved_prefixes is not None:
          flattened_saved_prefixes = nest.flatten(maybe_saved_prefixes)
          if not all(
              tensor_util.is_tf_type(x) and x.dtype == dtypes.string
              for x in flattened_saved_prefixes):
            raise ValueError(
                "Registered saver must return a (maybe empty) list of "
                f"string type tensors. Got {maybe_saved_prefixes}.")
          saved_prefixes.extend(flattened_saved_prefixes)

      shards_by_task = self._get_shards_by_task(
          options.experimental_sharding_callback)
      num_shards = sum([len(shards) for _, shards in shards_by_task])
      metrics.AddNumCheckpointShardsWritten(num_shards=num_shards)
      num_shards_tensor = constant_op.constant(num_shards, name="num_shards")

      sharded_saves = []
      shard_idx = 0
      for task, shards in shards_by_task:
        for shard in shards:
          with ops.device(task):
            shard_prefix = sharded_filename(
                tmp_checkpoint_prefix, shard_idx, num_shards_tensor)
          shard_idx += 1
          saved_prefixes.append(shard_prefix)
          sharded_saves.append(
              _single_shard_save(shard_prefix, shard, task, options))

      with ops.control_dependencies(sharded_saves):
        # Merge on the io_device if specified, otherwise co-locates the merge
        # op with the last device used.
        tensor_device_spec = list(self._shardable_tensors_by_task.keys())[-1]
        merge_device_spec = (
            options.experimental_io_device or
            saveable_object_util.set_cpu0(tensor_device_spec.to_string()))
        with ops.device(merge_device_spec):
          # V2 format write path consists of a metadata merge step. Once
          # merged, attempts to delete the temporary directory,
          # "<prefix>_temp".
          return gen_io_ops.merge_v2_checkpoints(
              saved_prefixes, file_prefix, delete_old_dirs=True)

    # Since this will cause a function re-trace on each save, limit this to
    # the cases where it is needed: eager and when there are multiple tasks.
    # Note that the retrace is needed to ensure we pick up the latest values
    # of options like experimental_io_device.
    if context.executing_eagerly() and self._num_unique_tasks > 1:
      # Explicitly place the identity op on the first device.
      @def_function.function(jit_compile=False)
      def tf_function_save() -> None:
        save_fn()

      tf_function_save()
    else:
      return save_fn()

  def restore(
      self,
      file_prefix: tensor_lib.Tensor,
      options: "checkpoint_options.CheckpointOptions | None" = None
  ) -> Mapping[str, ops.Operation]:
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      When not run eagerly or when saving on a single device, returns a
      dictionary mapping from SaveableObject names to restore operations;
      otherwise, returns an empty dict.
    """
    options = options or checkpoint_options.CheckpointOptions()

    def restore_fn() -> Mapping[str, ops.Operation]:
      restore_fn_inputs = {}
      restore_fn_input_count = {
          fn: len(keys) for fn, keys in self._restore_fn_to_keys.items()}

      restore_ops = {}
      for task, shard in self._shardable_tensors_by_task.items():
        with ops.device(task):
          # Load values from checkpoint
          restored_tensor_dict = _single_shard_restore(
              file_prefix, shard, options)

          # Map restored tensors to the corresponding restore_fn, and see if
          # all of its inputs have been loaded. Call `restore_fn` if that is
          # the case.
          for ckpt_key, slice_and_tensor in restored_tensor_dict.items():
            for slice_spec, tensor in slice_and_tensor.items():
              restore_fn = self._keys_to_restore_fn[(ckpt_key, slice_spec)]

              # Process the returned restored_tensor_dict to prepare inputs
              # for the Trackable `restore` function. The `restore` function
              # expects a map of `string name (checkpoint_key) -> Tensor`,
              # unless there is a slice_spec, in which case the map is of
              # `string name (checkpoint_key) -> slice_spec -> Tensor`.
              if slice_spec:
                (restore_fn_inputs.setdefault(restore_fn, {}).setdefault(
                    ckpt_key, {})[slice_spec]) = tensor
              else:
                restore_fn_inputs.setdefault(restore_fn, {})[ckpt_key] = tensor
              restore_fn_input_count[restore_fn] -= 1

              if restore_fn_input_count[restore_fn] == 0:
                restored_tensors = {}
                # Extract the substring after "/.ATTRIBUTES/" in each
                # ckpt_key of restore_fn_inputs[restore_fn] and copy the
                # entries into restored_tensors. For example, if
                # restore_fn_inputs[restore_fn] is the dict
                # {"/.ATTRIBUTES/a": Tensor}, restored_tensors becomes the
                # dict {"a": Tensor}.
                for ckpt_key, tensor in restore_fn_inputs[restore_fn].items():
                  restored_tensors[trackable_utils.extract_local_name(
                      ckpt_key)] = tensor
                ret = restore_fn(restored_tensors)
                if isinstance(ret, dict):
                  restore_ops.update(ret)
      # Run registered restore methods after the default restore ops.
      for _, (_, restore_fn) in self._registered_savers.items():
        restore_fn(file_prefix)
      return restore_ops

    has_custom_device_saver = False
    for sts in self._shardable_tensors_by_task.values():
      if any([context.is_custom_device(st.device.to_string()) for st in sts]):
        has_custom_device_saver = True
        break
    # Since this will cause a function re-trace on each restore, limit this
    # to cases where it is needed: eager and when there are multiple tasks or
    # any device_spec is a custom device.
    # Note that the retrace is needed to ensure we pick up the latest values
    # of options like experimental_io_device.
    #
    # We run in a function when there is a custom device saver because custom
    # devices, such as DTensor, usually do a sharded save and restore.
    # Doing a sharded save and restore requires knowledge about what shards
    # of variables we are restoring to. In practice, this means that custom
    # devices need the AssignVariableOps along with the Restore op within the
    # same graph to infer shapes and shard specs for the Restore op.
    if context.executing_eagerly() and (self._num_unique_tasks > 1 or
                                        has_custom_device_saver):
      @def_function.function(jit_compile=False, autograph=False)
      def tf_function_restore() -> Mapping[str, ops.Operation]:
        restore_fn()
        return {}

      restore_ops = tf_function_restore()
    else:
      restore_ops = restore_fn()

    return restore_ops
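

# Example usage (illustrative sketch only; `MultiDeviceSaver` is an internal
# utility, and the checkpoint key below is hypothetical -- higher-level
# checkpointing code normally derives keys from the object graph):
#
#   v = tf.Variable([1.0, 2.0])
#   serialized = object_identity.ObjectIdentityDictionary()
#   # Map the Trackable to {checkpoint_key: tensor}; a bare tensor (no inner
#   # dict) is treated as having an empty slice spec.
#   serialized[v] = {"v/.ATTRIBUTES/VARIABLE_VALUE": v.read_value()}
#   saver = MultiDeviceSaver(serialized)
#   saver.save(constant_op.constant("/tmp/ckpt/model"))
#   saver.restore(constant_op.constant("/tmp/ckpt/model"))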