# Copyright 2022 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Utilities for saving/loading Trackable objects asynchronously.""" import atexit import copy import queue import threading import time import weakref from absl import logging from tensorflow.python.checkpoint import checkpoint_context from tensorflow.python.checkpoint import trackable_view from tensorflow.python.distribute import device_util from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import executor from tensorflow.python.framework import ops from tensorflow.python.ops import variables from tensorflow.python.saved_model.pywrap_saved_model import metrics from tensorflow.python.trackable import base from tensorflow.python.util import object_identity # Captures the timestamp of the first Checkpoint instantiation or end of a write # operation. Can be accessed by multiple Checkpoint instances. _END_TIME_OF_LAST_ASYNC_WRITE = None _END_TIME_OF_LAST_ASYNC_WRITE_LOCK = threading.Lock() # API label for cell names used in async checkpoint metrics. _ASYNC_CHECKPOINT = "async_checkpoint" # Name of TPUEmbedding attribute. This is a temporary workaround # to identify TPUEmbedding while avoiding import cycles. _TPU_EMBEDDING_ATTR = "_create_copy_for_async_checkpoint" def _get_duration_microseconds(start_time_seconds, end_time_seconds): """Calculate the duration between start and end time. Args: start_time_seconds: The start time in seconds. end_time_seconds: The end time in seconds. Returns: The duration between the start and the end time. Return 0 if end_time_seconds < start_time_seconds. """ if end_time_seconds < start_time_seconds: # Avoid returning negative value in case of clock skew. return 0 return round((end_time_seconds - start_time_seconds) * 1000000) def _get_all_trackables(root, exclude_set): """Return the list of checkpointable trackables dependent on `root`. Args: root: The root trackable from where we get all its dependent trackables. exclude_set: An ObjectIdentitySet of Trackables to exclude before returning. Each element in `exclude_set` is a specific instance of a `Trackable` and appears precisely once in `TrackableView(root).descendants()`. Returns: saveable_trackables: All trackables that are saveable in `all_trackables` (see definition of "saveable" in `_trackable_needs_to_be_saved()`). A subset of `all_trackables`. all_trackables: All trackables returned by `TrackableView`'s `descendants()` after excluding `exclude_set`. A superset of `saveable_trackables`. """ all_trackables = trackable_view.TrackableView(root=root).descendants() # Kick out the trackable we want to exclude. # The goal of writing such loop is to only scan the list once and stop # scanning as early as possible (unlike filtering with list comprehension). trackable_index = 0 while trackable_index < len(all_trackables) and exclude_set: # While we have not excluded all items, or gone through all trackables. if all_trackables[trackable_index] in exclude_set: # If want to exclude this trackable, we pop it and do not update ptr exclude_set.discard(all_trackables[trackable_index]) all_trackables.pop(trackable_index) else: # Otherwise update ptr trackable_index += 1 # Kick out trackables that do not need to be saved (e.g. ListWrapper, etc.) # We define any trackable that does not implement `_serialize_to_tensor` or # `_gather_saveables` as "no need to be saved". If the trackable has one or # both of the methods defined, it should have `_copy_trackable_to_cpu` # defined; if not, we will raise warning in `_copy_to_cpu()`. In case of # special case, we also check whether a trackable (who has neither of the # other two methods defined) defines `_copy_trackable_to_cpu` only; we still # define such cases as "needs to be saved". def _trackable_needs_to_be_saved(obj): """Returns whether a trackable needs to be saved. Returns a bool to indicate whether obj's class has `_serialize_to_tensors`, `gather_saveables_for_checkpoint`, or `_copy_trackable_to_cpu` defined. Args: obj: A Trackable object. """ if hasattr(obj, "__dict__"): # Data structure proxy wrappers don't have __dict__. if ("_serialize_to_tensors" in obj.__dict__ or "_gather_saveables_for_checkpoint" in obj.__dict__ or "_copy_trackable_to_cpu" in obj.__dict__): return True # Use MRO so that if a parent class has one of the three methods, we still # consider `t` as needed to be saved. for t in type(obj).mro(): if t is base.Trackable: # Base class always has them implemented, but would raise error. continue elif ("_serialize_to_tensors" in t.__dict__ or "_gather_saveables_for_checkpoint" in t.__dict__ or "_copy_trackable_to_cpu" in t.__dict__): return True return False saveable_trackables = [x for x in all_trackables if _trackable_needs_to_be_saved(x)] return saveable_trackables, all_trackables class AsyncCheckpointHelper: """Helper class for async checkpoint.""" def __init__(self, checkpointer_impl, root=None, **kwargs): """Initialize AsyncCheckpoint. Args: checkpointer_impl: The Checkpoint class to power the AsyncCheckpoint. root: The root object to checkpoint. `root` may be a trackable object or `WeakRef` of a trackable object. **kwargs: The keyword arguments representing the checkpointed variables. Raises: AttributeError: when checkpointer_impl is None. """ # TODO(chienchunh): Make sure the processing for the root object is # consistent when integrating with the public API, e.g., adding all kwarg # items as the child of the root object. if root: trackable_root = root() if isinstance(root, weakref.ref) else root kwargs["root"] = trackable_root trackable_root._maybe_initialize_trackable() # The underlying Checkpoint instance and its items. if checkpointer_impl is None: raise AttributeError( "checkpointer_impl cannot be None for AsyncCheckpointHelper." ) self._checkpointer_impl = checkpointer_impl self._checkpoint_items = kwargs self._checkpoint = None self.checkpointer() self._checkpoint_options = None # Indicate whether async checkpoint has finished traversing the variable # list and created the object map between the original and copied variables. self._initialized = False # The list of all nodes from the original checkpoint items. # TODO(chienchunh): Consider changing this to local variable. self._original_nodes = None # The mapping between the original and the copied resource variables. # The copied variables are used for the underlying checkpointing. self._object_map = None # A list of TPUEmbedding objects included in the checkpoint items. self._tpu_embedding_objects = None # A list of highest level `Trackable`s we will copy; does not contain # TPUEmbedding objects self._saveable_trackables = None self._default_device = device_util.current() or "CPU:0" self._default_device = device_util.canonicalize(self._default_device) self._save_file_prefix = None self._use_checkpoint_save = False self._async_save_thread = None # Concurrent queue that coordinates the events for writing/reading the # cpu-copied variables. A 'True' in the queue triggers the async thread to # perform saving; a 'False' breaks the while loop so that the async thread # exits; no other values will be added to the queue. # Maxsize is set to 1 only to ensure the exit procedure. We could have used # queue.join() in _join_async_save_thread(), but queue.join() does not have # a timeout argument. Hence we use queue.put(timeout=300), in case the last # checkpoint takes forever. To achieve that, maxsize needs to be 1. self._queue = queue.Queue(maxsize=1) # Register to join the async save thread upon exit. atexit.register(self._join_async_save_thread) self._async_error = None global _END_TIME_OF_LAST_ASYNC_WRITE with _END_TIME_OF_LAST_ASYNC_WRITE_LOCK: if _END_TIME_OF_LAST_ASYNC_WRITE is None: _END_TIME_OF_LAST_ASYNC_WRITE = time.time() @def_function.function def _copy_to_cpu(self): """Copy the checkpointed variables from the accelerator to the host CPU. TODO(chienchunh): Get the concrete function before firstly called to avoid hangining the accelerators idle during function tracing. """ for t in self._saveable_trackables: try: t._copy_trackable_to_cpu(object_map=self._object_map) # pylint: disable=protected-access except NotImplementedError as e: logging.warning("Trackable %s skipped due to: %s", t, e) for tpu_embedding in self._tpu_embedding_objects: tpu_embedding._retrieve_variables() # pylint: disable=protected-access def checkpointer(self): """Gets or creates the underlying Checkpoint instance.""" if self._checkpoint is None: self._checkpoint = self._checkpointer_impl(**self._checkpoint_items) return self._checkpoint def _ensure_initialized(self): """Initialize the async checkpoint internal state.""" # This map will be used to store the CPU copy of all checkpointable objects self._object_map = object_identity.ObjectIdentityDictionary() self._tpu_embedding_objects = [] # Populate self._all_tracakbles, but exclude the checkpoint instance itself # and its save_counter, as they will be returned by `descendants()`. exclude_set = object_identity.ObjectIdentitySet() exclude_set.add(self.checkpointer()) exclude_set.add(self.checkpointer().save_counter) self._saveable_trackables, all_trackables = _get_all_trackables( root=self.checkpointer(), exclude_set=exclude_set) # Handle special cases: TPU Embedding, and slot variables. # 1. TPUEmbedding: Different from other trackables, TPUEmbedding needs to # call `_retrieve_variables` to checkpoint, while populating a dummy copy to # the object map. # 2. Slot variables: they need to be handled differently as they cannot be # retrieved from `TrackableView.descendants()`. # Note: dir() is used rather than hasattr() here to avoid triggering # custom __getattr__ code, see b/152031870 for context. for t in all_trackables: # Special case 1: TPU Embedding, populate object_map here # Special case 1: Handle TPU Embedding by addnig a dummy instance to the # object map. Also add TPUEmbedding to separate list for special handling # with values copy. if hasattr(type(t), _TPU_EMBEDDING_ATTR): self._handle_tpu_embedding(t) # Special case 2: handle slot variables. The object_map is populated later # when the variable values are being copied to host CPU for the first # time. if "get_slot_names" in dir(t): slot_names = t.get_slot_names() for slot_name in slot_names: for original_variable in all_trackables: if not isinstance(original_variable, variables.Variable): continue try: # Usage of hasattr may result in KeyError original_slot_variable = t.get_slot(original_variable, slot_name) except (AttributeError, KeyError): continue if isinstance(original_slot_variable, base.Trackable): self._saveable_trackables.append(original_slot_variable) # Initiate the underlying Checkpoint instance's save_counter. save_counter = self.checkpointer().save_counter.numpy() logging.info("Initializing async checkpoint's save_counter: %d", save_counter) # Pass the object map of the copied variables to the underlying Checkpoint. self.checkpointer()._saver._object_map = self._object_map # pylint: disable=protected-access # We perform a `_copy_to_cpu()` to populate `self._object_map`, # initializing copies. We do not call `self._copy_to_cpu()` directly # because it is a tf function, which leads to access out of scope error. # TODO(charlieruan) Figure out a better work around to solve the access # out of scope error. for t in self._saveable_trackables: try: t._copy_trackable_to_cpu(object_map=self._object_map) # pylint: disable=protected-access except NotImplementedError as e: logging.warning("Trackable %s skipped due to: %s", t, e) for tpu_embedding in self._tpu_embedding_objects: tpu_embedding._retrieve_variables() # pylint: disable=protected-access # Initiate the async thread for checkpoint saving. self._async_save_thread = threading.Thread( target=self._async_save, daemon=True) self._async_save_thread.start() self._initialized = True def _check_async_thread_error(self): """Expose the most recent error from the async saving thread to the caller. """ if self._async_error: e = self._async_error self._async_error = None logging.error("Propagating the most recent error from the async thread " "before joining: %s", str(e)) raise e def _join_async_save_thread(self): """Join the async save thread. The steps for terminating the async save thread: 1). Put will succeed when the last async save event is done. Putting a false triggers the async save thread's while loop to end. We use put instead of sync because sync does not have a timeout argument. 2). Join the async save thread. (The thread may finish before joining.) """ try: self._queue.put(False, timeout=300) # Step-1. logging.info("Joining the async save thread.") if self._async_save_thread is not None: self._async_save_thread.join() # Step-2. except queue.Full: logging.error("Timeout waiting for the async save thread; terminating the" " thread instead. The last checkpoint may be incomeplete.") finally: self._check_async_thread_error() def _async_save(self): """The thread function for the async checkpoint save.""" with context.executor_scope( executor.new_executor( enable_async=False, enable_streaming_enqueue=False)): # The main thread inserts: a True to the queue when the user calls save, # triggering async save; and a False when we exit the Checkpoint instance. while self._queue.get(): logging.info("Starting async checkpoint save on the device: %s", self._default_device) async_save_start_time = time.time() # Specify the ops placement on the worker if running with # coordinator-worker mode. This is required as launching a new thread # would clear the placement policy and make localhost the default # placement, while the main thread's default placement would be the # master worker's CPU:0. try: with ops.device(self._default_device): with checkpoint_context.async_metrics_context(): if self._use_checkpoint_save: self.checkpointer().save( self._save_file_prefix, self._checkpoint_options ) else: self.checkpointer()._write( # pylint: disable=protected-access self._save_file_prefix, options=self._checkpoint_options, ) except Exception as e: # # pylint: disable=broad-except self._async_error = e finally: self._queue.task_done() async_save_end_time = time.time() metrics.AddAsyncCheckpointWriteDuration( api_label=_ASYNC_CHECKPOINT, microseconds=_get_duration_microseconds(async_save_start_time, async_save_end_time)) # Measure the elapsed time since the last checkpoint. # Due to the nature of async checkpoint, here it actually captures the # duration between the start_time of the previous checkpoint and the # start time of this checkpoint. As a result, the duration of the final # async checkpoint is excluded, which is fine since it does not take # much time. global _END_TIME_OF_LAST_ASYNC_WRITE with _END_TIME_OF_LAST_ASYNC_WRITE_LOCK: metrics.AddTrainingTimeSaved( api_label=_ASYNC_CHECKPOINT, microseconds=_get_duration_microseconds( _END_TIME_OF_LAST_ASYNC_WRITE, async_save_start_time)) _END_TIME_OF_LAST_ASYNC_WRITE = async_save_start_time logging.info("Async save thread reached the end of the execution.") def _handle_tpu_embedding(self, tpu_embedding): """Handle TPUEmbedding. This is the only place where we populate object map in the class of `AsyncCheckpointHelper`. For all other checkpointable trackables, we populate object map using the trackable's own `_copy_trackable_to_cpu()`. Args: tpu_embedding: TPUEmbedding object to be handled. Raises: AttributeError: if the input trackable is not TPUEmbedding type. """ if not hasattr(type(tpu_embedding), _TPU_EMBEDDING_ATTR) or not callable( tpu_embedding._create_copy_for_async_checkpoint # pylint: disable=protected-access ): raise AttributeError( "Expecting TPUEmbedding type; got %s" % type(tpu_embedding) ) # Create a dummy TPUEmbedding object and add it to the object_map. This is # to prevent the TPUEmbedding's save_callback from being triggered because # the embedding values have already being retrieved by AsyncCheckpoint. # pylint: disable=protected-access new_embedding = tpu_embedding._create_copy_for_async_checkpoint( feature_config=tpu_embedding._feature_config, optimizer=tpu_embedding._table_config[0] if tpu_embedding._table_config else None, pipeline_execution_with_tensor_core=tpu_embedding._pipeline_execution_with_tensor_core, ) self._object_map[tpu_embedding] = new_embedding # pylint: enable=protected-access if tpu_embedding not in self._tpu_embedding_objects: self._tpu_embedding_objects.append(tpu_embedding) @property def save_counter(self): """An integer variable numbering the checkpoint events. This is maintained by the underlying tf.train.Checkpoing object employed by AsyncCheckpoint class. The number starts at 0 and gets incremented for each checkpoint event. Returns: The save counter variable. """ return self.checkpointer().save_counter def write(self, save_path, options=None): """Save the checkpointed variables. Args: save_path: The file prefix of the checkpoint file. options: Optional CheckpointOption instance. Returns: The full path of the checkpoint file. """ return self._write(save_path, options) def _write(self, save_path, options=None): """Save the checkpointed variables. This method has exactly the same logic as save(), except it does not increment the underlying save_counter, which is done by the caller, e.g., CheckpointManager. Args: save_path: The file prefix of the checkpoint file. options: Optional CheckpointOption instance. Returns: The full path of the checkpoint file. """ write_start_time = time.time() if not self._initialized: self._ensure_initialized() else: # First wait for async thread to finish the previous save, then copy the # variable values to the host CPU. self._queue.join() self._copy_to_cpu() # Surface the error from the async thread, if any. # This step should come after the sem acquision step in the above, so that # it makes sure it waits until the previous async save finishes storing the # error. self._check_async_thread_error() # Trigger the async thread to checkpoint the cpu-copied variables. # Need to wait until the weight copying finishes before checkpoint save. context.async_wait() self._save_file_prefix = save_path self._use_checkpoint_save = False # Ensure that we do not request async checkpointing to the underlying # checkpointer as this could lead to an infinite loop. self._checkpoint_options = copy.copy(options) if options else None if self._checkpoint_options: self._checkpoint_options.experimental_enable_async_checkpoint = False self._queue.put(True) # Trigger save in async thread write_end_time = time.time() metrics.AddCheckpointWriteDuration( api_label=_ASYNC_CHECKPOINT, microseconds=_get_duration_microseconds(write_start_time, write_end_time)) return save_path def save(self, save_path, options=None): """Save the checkpointed variables. Args: save_path: The file prefix of the checkpoint file. options: Optional CheckpointOption instance. Returns: The full path of the checkpoint file. """ save_start_time = time.time() # If this is the first time that AsyncCheckpoint.save() is called, # initialize the internal states like `self._saveable_trackables`. We also # populate `self._object_map` (i.e. initializing the cpu-copied variables # and copy over the value for the first time) by essentially performing a # `self._copy_to_cpu()`, hence the if-else logic here. # # This is not performed in the initializer because some variables, e.g., # slot variables of the optimizer, were not created until actually running # the train function, so we could only get the complete list of the # variables after some train steps were run. if not self._initialized: self._ensure_initialized() else: # First wait for async thread to finish the previous save, then copy the # variable values to the host CPU. self._queue.join() self._copy_to_cpu() # Surface the error from the async thread, if any. # This step should come after the sem acquision step in the above, so that # it makes sure it waits until the previous async save finishes storing the # error. self._check_async_thread_error() # Retrieve the save counter from the underlying checkpoint object to # re-construct the full path of the checkpoint file. # This step has to happen before triggering the underlying checkpoint; # otherwise, the save_counter value may or may not have been updated. save_counter = self.checkpointer().save_counter.numpy() + 1 full_path = "{}-{}".format(save_path, save_counter) # Trigger the async thread to checkpoint the cpu-copied variables. # Need to wait until the weight copying finishes before checkpoint save. context.async_wait() self._save_file_prefix = save_path self._use_checkpoint_save = True # Ensure that we do not request async checkpointing to the underlying # checkpointer as this could lead to an infinite loop. self._checkpoint_options = copy.copy(options) if options else None if self._checkpoint_options: self._checkpoint_options.experimental_enable_async_checkpoint = False self._queue.put(True) # Trigger save in async thread save_end_time = time.time() metrics.AddCheckpointWriteDuration( api_label=_ASYNC_CHECKPOINT, microseconds=_get_duration_microseconds(save_start_time, save_end_time)) return full_path def read(self, save_path, options=None): """Restore the checkpointed variables. This method has exactly the same logic as restore(). This method is implemented only to fulfill the duty of subclassing tf.train.Checkpoint. Args: save_path: The full name of the checkpoint file to be restored. options: CheckpointOption instance. Returns: A load status object, which can be used to make assertions about the status of a checkpoint restoration. See tf.train.Checkpoint.restore() for more details. """ return self.restore(save_path, options) def restore(self, save_path, options=None): """Restore the checkpointed variables. Args: save_path: The full name of the checkpoint file to be restored. options: CheckpointOption instance. Returns: A load status object, which can be used to make assertions about the status of a checkpoint restoration. See tf.train.Checkpoint.restore() for more details. """ # Ensure that we do not request async checkpointing to the underlying # checkpointer as this could lead to an infinite loop. self._checkpoint_options = ( copy.copy(options) if options else self._checkpoint_options) if self._checkpoint_options: self._checkpoint_options.experimental_enable_async_checkpoint = False # Wait for any ongoing checkpoint event to finish. self._queue.join() # Restore values of the cpu-copied variables directly back to accelerators status = self.checkpointer().restore(save_path, self._checkpoint_options) return status def sync(self): """Sync on any ongoing save or restore events.""" self._queue.join() logging.info("Sync on ongoing save/restore.")