# 3RNN/Lib/site-packages/tensorflow/python/checkpoint/async_checkpoint_helper.py

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for saving/loading Trackable objects asynchronously."""
import atexit
import copy
import queue
import threading
import time
import weakref
from absl import logging
from tensorflow.python.checkpoint import checkpoint_context
from tensorflow.python.checkpoint import trackable_view
from tensorflow.python.distribute import device_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import executor
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.trackable import base
from tensorflow.python.util import object_identity
# Captures the timestamp of the first Checkpoint instantiation or end of a write
# operation. Can be accessed by multiple Checkpoint instances.
_END_TIME_OF_LAST_ASYNC_WRITE = None
_END_TIME_OF_LAST_ASYNC_WRITE_LOCK = threading.Lock()
# API label for cell names used in async checkpoint metrics.
_ASYNC_CHECKPOINT = "async_checkpoint"
# Name of TPUEmbedding attribute. This is a temporary workaround
# to identify TPUEmbedding while avoiding import cycles.
_TPU_EMBEDDING_ATTR = "_create_copy_for_async_checkpoint"
def _get_duration_microseconds(start_time_seconds, end_time_seconds):
"""Calculate the duration between start and end time.
Args:
start_time_seconds: The start time in seconds.
end_time_seconds: The end time in seconds.
  Returns:
    The duration between the start and the end time, in microseconds. Returns
    0 if end_time_seconds < start_time_seconds.
"""
if end_time_seconds < start_time_seconds:
# Avoid returning negative value in case of clock skew.
return 0
return round((end_time_seconds - start_time_seconds) * 1000000)
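# A quick illustration of the clamping behavior above (a hedged example; the
# timestamps are hypothetical):
#   _get_duration_microseconds(1.0, 1.5)  # -> 500000
#   _get_duration_microseconds(2.0, 1.0)  # -> 0 (clamped on clock skew)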
def _get_all_trackables(root, exclude_set):
"""Return the list of checkpointable trackables dependent on `root`.
Args:
root: The root trackable from where we get all its dependent trackables.
exclude_set: An ObjectIdentitySet of Trackables to exclude before returning.
Each element in `exclude_set` is a specific instance of a `Trackable`
and appears precisely once in `TrackableView(root).descendants()`.
Returns:
saveable_trackables: All trackables that are saveable in `all_trackables`
(see definition of "saveable" in `_trackable_needs_to_be_saved()`). A
subset of `all_trackables`.
all_trackables: All trackables returned by `TrackableView`'s `descendants()`
after excluding `exclude_set`. A superset of `saveable_trackables`.
"""
all_trackables = trackable_view.TrackableView(root=root).descendants()
  # Kick out the trackables we want to exclude.
  # The loop is written this way so that we scan the list only once and stop
  # scanning as early as possible (unlike filtering with a list comprehension).
  trackable_index = 0
  while trackable_index < len(all_trackables) and exclude_set:
    # Keep going while there are still items to exclude and trackables left.
    if all_trackables[trackable_index] in exclude_set:
      # Exclude this trackable: pop it and do not advance the index.
      exclude_set.discard(all_trackables[trackable_index])
      all_trackables.pop(trackable_index)
    else:
      # Otherwise advance the index.
      trackable_index += 1
  # Kick out trackables that do not need to be saved (e.g. ListWrapper).
  # We treat any trackable that implements neither `_serialize_to_tensors` nor
  # `_gather_saveables_for_checkpoint` as "does not need to be saved". A
  # trackable that defines one or both of those methods should also define
  # `_copy_trackable_to_cpu`; if it does not, we raise a warning in
  # `_copy_to_cpu()`. As a special case, a trackable that defines only
  # `_copy_trackable_to_cpu` (and neither of the other two methods) is still
  # treated as "needs to be saved".
def _trackable_needs_to_be_saved(obj):
"""Returns whether a trackable needs to be saved.
Returns a bool to indicate whether obj's class has `_serialize_to_tensors`,
`gather_saveables_for_checkpoint`, or `_copy_trackable_to_cpu` defined.
Args:
obj: A Trackable object.
"""
if hasattr(obj, "__dict__"):
# Data structure proxy wrappers don't have __dict__.
if ("_serialize_to_tensors" in obj.__dict__
or "_gather_saveables_for_checkpoint" in obj.__dict__
or "_copy_trackable_to_cpu" in obj.__dict__):
return True
    # Use the MRO so that if a parent class has one of the three methods, we
    # still consider `obj` as needing to be saved.
    for t in type(obj).mro():
      if t is base.Trackable:
        # The base Trackable class defines these methods, but calling them
        # would raise NotImplementedError.
        continue
elif ("_serialize_to_tensors" in t.__dict__
or "_gather_saveables_for_checkpoint" in t.__dict__
or "_copy_trackable_to_cpu" in t.__dict__):
return True
return False
saveable_trackables = [x for x in all_trackables if
_trackable_needs_to_be_saved(x)]
return saveable_trackables, all_trackables
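# Illustrative call pattern for `_get_all_trackables`, mirroring how
# `_ensure_initialized` below uses it (a hedged sketch; `ckpt` is a
# hypothetical Checkpoint instance):
#   exclude = object_identity.ObjectIdentitySet()
#   exclude.add(ckpt)               # drop the Checkpoint instance itself
#   exclude.add(ckpt.save_counter)  # ... and its save_counter
#   saveable, everything = _get_all_trackables(root=ckpt, exclude_set=exclude)
#   # `everything` is descendants() minus `exclude`; `saveable` is the subset
#   # defining one of the three serialize/copy hooks checked above.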
class AsyncCheckpointHelper:
"""Helper class for async checkpoint."""
def __init__(self, checkpointer_impl, root=None, **kwargs):
"""Initialize AsyncCheckpoint.
Args:
checkpointer_impl: The Checkpoint class to power the AsyncCheckpoint.
root: The root object to checkpoint. `root` may be a trackable object or
`WeakRef` of a trackable object.
**kwargs: The keyword arguments representing the checkpointed variables.
Raises:
AttributeError: when checkpointer_impl is None.
"""
# TODO(chienchunh): Make sure the processing for the root object is
# consistent when integrating with the public API, e.g., adding all kwarg
# items as the child of the root object.
if root:
trackable_root = root() if isinstance(root, weakref.ref) else root
kwargs["root"] = trackable_root
trackable_root._maybe_initialize_trackable()
# The underlying Checkpoint instance and its items.
if checkpointer_impl is None:
raise AttributeError(
"checkpointer_impl cannot be None for AsyncCheckpointHelper."
)
self._checkpointer_impl = checkpointer_impl
self._checkpoint_items = kwargs
self._checkpoint = None
self.checkpointer()
self._checkpoint_options = None
# Indicate whether async checkpoint has finished traversing the variable
# list and created the object map between the original and copied variables.
self._initialized = False
# The list of all nodes from the original checkpoint items.
# TODO(chienchunh): Consider changing this to local variable.
self._original_nodes = None
# The mapping between the original and the copied resource variables.
# The copied variables are used for the underlying checkpointing.
self._object_map = None
# A list of TPUEmbedding objects included in the checkpoint items.
self._tpu_embedding_objects = None
# A list of highest level `Trackable`s we will copy; does not contain
# TPUEmbedding objects
self._saveable_trackables = None
self._default_device = device_util.current() or "CPU:0"
self._default_device = device_util.canonicalize(self._default_device)
self._save_file_prefix = None
self._use_checkpoint_save = False
self._async_save_thread = None
# Concurrent queue that coordinates the events for writing/reading the
# cpu-copied variables. A 'True' in the queue triggers the async thread to
# perform saving; a 'False' breaks the while loop so that the async thread
# exits; no other values will be added to the queue.
# Maxsize is set to 1 only to ensure the exit procedure. We could have used
# queue.join() in _join_async_save_thread(), but queue.join() does not have
# a timeout argument. Hence we use queue.put(timeout=300), in case the last
# checkpoint takes forever. To achieve that, maxsize needs to be 1.
self._queue = queue.Queue(maxsize=1)
# Register to join the async save thread upon exit.
atexit.register(self._join_async_save_thread)
self._async_error = None
global _END_TIME_OF_LAST_ASYNC_WRITE
with _END_TIME_OF_LAST_ASYNC_WRITE_LOCK:
if _END_TIME_OF_LAST_ASYNC_WRITE is None:
_END_TIME_OF_LAST_ASYNC_WRITE = time.time()
@def_function.function
def _copy_to_cpu(self):
"""Copy the checkpointed variables from the accelerator to the host CPU.
    TODO(chienchunh): Get the concrete function before it is first called to
    avoid hanging the accelerators idle during function tracing.
"""
for t in self._saveable_trackables:
try:
t._copy_trackable_to_cpu(object_map=self._object_map) # pylint: disable=protected-access
except NotImplementedError as e:
logging.warning("Trackable %s skipped due to: %s", t, e)
for tpu_embedding in self._tpu_embedding_objects:
tpu_embedding._retrieve_variables() # pylint: disable=protected-access
def checkpointer(self):
"""Gets or creates the underlying Checkpoint instance."""
if self._checkpoint is None:
self._checkpoint = self._checkpointer_impl(**self._checkpoint_items)
return self._checkpoint
def _ensure_initialized(self):
"""Initialize the async checkpoint internal state."""
# This map will be used to store the CPU copy of all checkpointable objects
self._object_map = object_identity.ObjectIdentityDictionary()
self._tpu_embedding_objects = []
    # Populate the trackable lists, but exclude the checkpoint instance itself
    # and its save_counter, as they will be returned by `descendants()`.
exclude_set = object_identity.ObjectIdentitySet()
exclude_set.add(self.checkpointer())
exclude_set.add(self.checkpointer().save_counter)
self._saveable_trackables, all_trackables = _get_all_trackables(
root=self.checkpointer(), exclude_set=exclude_set)
# Handle special cases: TPU Embedding, and slot variables.
# 1. TPUEmbedding: Different from other trackables, TPUEmbedding needs to
# call `_retrieve_variables` to checkpoint, while populating a dummy copy to
# the object map.
# 2. Slot variables: they need to be handled differently as they cannot be
# retrieved from `TrackableView.descendants()`.
# Note: dir() is used rather than hasattr() here to avoid triggering
# custom __getattr__ code, see b/152031870 for context.
for t in all_trackables:
      # Special case 1: Handle TPUEmbedding by adding a dummy copy to the
      # object map. Also add the TPUEmbedding to a separate list for special
      # handling of the values copy.
if hasattr(type(t), _TPU_EMBEDDING_ATTR):
self._handle_tpu_embedding(t)
# Special case 2: handle slot variables. The object_map is populated later
# when the variable values are being copied to host CPU for the first
# time.
if "get_slot_names" in dir(t):
slot_names = t.get_slot_names()
for slot_name in slot_names:
for original_variable in all_trackables:
if not isinstance(original_variable, variables.Variable):
continue
try:
# Usage of hasattr may result in KeyError
original_slot_variable = t.get_slot(original_variable, slot_name)
except (AttributeError, KeyError):
continue
if isinstance(original_slot_variable, base.Trackable):
self._saveable_trackables.append(original_slot_variable)
    # Initialize the underlying Checkpoint instance's save_counter.
save_counter = self.checkpointer().save_counter.numpy()
logging.info("Initializing async checkpoint's save_counter: %d",
save_counter)
# Pass the object map of the copied variables to the underlying Checkpoint.
self.checkpointer()._saver._object_map = self._object_map # pylint: disable=protected-access
    # We perform a `_copy_to_cpu()` here to populate `self._object_map`,
    # initializing the copies. We do not call `self._copy_to_cpu()` directly
    # because it is a tf.function, which would lead to an out-of-scope access
    # error.
    # TODO(charlieruan) Figure out a better workaround for the out-of-scope
    # access error.
for t in self._saveable_trackables:
try:
t._copy_trackable_to_cpu(object_map=self._object_map) # pylint: disable=protected-access
except NotImplementedError as e:
logging.warning("Trackable %s skipped due to: %s", t, e)
for tpu_embedding in self._tpu_embedding_objects:
tpu_embedding._retrieve_variables() # pylint: disable=protected-access
    # Start the async thread for checkpoint saving.
self._async_save_thread = threading.Thread(
target=self._async_save, daemon=True)
self._async_save_thread.start()
self._initialized = True
def _check_async_thread_error(self):
"""Expose the most recent error from the async saving thread to the caller.
"""
if self._async_error:
e = self._async_error
self._async_error = None
logging.error("Propagating the most recent error from the async thread "
"before joining: %s", str(e))
raise e
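  # Note: because saving happens on the background thread, an exception raised
  # while writing checkpoint N typically surfaces on the *next* write()/save()
  # call, or at exit via _join_async_save_thread(), rather than at the call
  # that scheduled checkpoint N.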
def _join_async_save_thread(self):
"""Join the async save thread.
The steps for terminating the async save thread:
1). Put will succeed when the last async save event is done. Putting a false
triggers the async save thread's while loop to end. We use put instead
of sync because sync does not have a timeout argument.
2). Join the async save thread. (The thread may finish before joining.)
"""
try:
self._queue.put(False, timeout=300) # Step-1.
logging.info("Joining the async save thread.")
if self._async_save_thread is not None:
self._async_save_thread.join() # Step-2.
except queue.Full:
logging.error("Timeout waiting for the async save thread; terminating the"
" thread instead. The last checkpoint may be incomeplete.")
finally:
self._check_async_thread_error()
def _async_save(self):
"""The thread function for the async checkpoint save."""
with context.executor_scope(
executor.new_executor(
enable_async=False, enable_streaming_enqueue=False)):
      # The main thread puts True on the queue when the user calls save,
      # triggering an async save, and False when we exit the Checkpoint
      # instance.
while self._queue.get():
logging.info("Starting async checkpoint save on the device: %s",
self._default_device)
async_save_start_time = time.time()
# Specify the ops placement on the worker if running with
# coordinator-worker mode. This is required as launching a new thread
# would clear the placement policy and make localhost the default
# placement, while the main thread's default placement would be the
# master worker's CPU:0.
try:
with ops.device(self._default_device):
with checkpoint_context.async_metrics_context():
if self._use_checkpoint_save:
self.checkpointer().save(
self._save_file_prefix, self._checkpoint_options
)
else:
self.checkpointer()._write( # pylint: disable=protected-access
self._save_file_prefix,
options=self._checkpoint_options,
)
        except Exception as e:  # pylint: disable=broad-except
self._async_error = e
finally:
self._queue.task_done()
async_save_end_time = time.time()
metrics.AddAsyncCheckpointWriteDuration(
api_label=_ASYNC_CHECKPOINT,
microseconds=_get_duration_microseconds(async_save_start_time,
async_save_end_time))
# Measure the elapsed time since the last checkpoint.
# Due to the nature of async checkpoint, here it actually captures the
# duration between the start_time of the previous checkpoint and the
# start time of this checkpoint. As a result, the duration of the final
# async checkpoint is excluded, which is fine since it does not take
# much time.
global _END_TIME_OF_LAST_ASYNC_WRITE
with _END_TIME_OF_LAST_ASYNC_WRITE_LOCK:
metrics.AddTrainingTimeSaved(
api_label=_ASYNC_CHECKPOINT,
microseconds=_get_duration_microseconds(
_END_TIME_OF_LAST_ASYNC_WRITE, async_save_start_time))
_END_TIME_OF_LAST_ASYNC_WRITE = async_save_start_time
logging.info("Async save thread reached the end of the execution.")
def _handle_tpu_embedding(self, tpu_embedding):
"""Handle TPUEmbedding.
This is the only place where we populate object map in the class of
`AsyncCheckpointHelper`. For all other checkpointable trackables, we
populate object map using the trackable's own `_copy_trackable_to_cpu()`.
Args:
tpu_embedding: TPUEmbedding object to be handled.
Raises:
AttributeError: if the input trackable is not TPUEmbedding type.
"""
if not hasattr(type(tpu_embedding), _TPU_EMBEDDING_ATTR) or not callable(
tpu_embedding._create_copy_for_async_checkpoint # pylint: disable=protected-access
):
raise AttributeError(
"Expecting TPUEmbedding type; got %s" % type(tpu_embedding)
)
    # Create a dummy TPUEmbedding object and add it to the object_map. This
    # prevents the TPUEmbedding's save_callback from being triggered, because
    # the embedding values have already been retrieved by AsyncCheckpoint.
# pylint: disable=protected-access
    new_embedding = tpu_embedding._create_copy_for_async_checkpoint(
        feature_config=tpu_embedding._feature_config,
        optimizer=tpu_embedding._table_config[0].optimizer
        if tpu_embedding._table_config
        else None,
        pipeline_execution_with_tensor_core=tpu_embedding._pipeline_execution_with_tensor_core,
    )
self._object_map[tpu_embedding] = new_embedding
# pylint: enable=protected-access
if tpu_embedding not in self._tpu_embedding_objects:
self._tpu_embedding_objects.append(tpu_embedding)
@property
def save_counter(self):
"""An integer variable numbering the checkpoint events.
This is maintained by the underlying tf.train.Checkpoing object employed by
AsyncCheckpoint class. The number starts at 0 and gets incremented for each
checkpoint event.
Returns:
The save counter variable.
"""
return self.checkpointer().save_counter
def write(self, save_path, options=None):
"""Save the checkpointed variables.
Args:
save_path: The file prefix of the checkpoint file.
      options: Optional CheckpointOptions instance.
Returns:
The full path of the checkpoint file.
"""
return self._write(save_path, options)
def _write(self, save_path, options=None):
"""Save the checkpointed variables.
This method has exactly the same logic as save(), except it does not
increment the underlying save_counter, which is done by the caller, e.g.,
CheckpointManager.
Args:
save_path: The file prefix of the checkpoint file.
      options: Optional CheckpointOptions instance.
Returns:
The full path of the checkpoint file.
"""
write_start_time = time.time()
if not self._initialized:
self._ensure_initialized()
    else:
      # First wait for the async thread to finish the previous save, then copy
      # the variable values to the host CPU.
      self._queue.join()
      self._copy_to_cpu()
    # Surface the error from the async thread, if any. This step must come
    # after the `queue.join()` above, which ensures that the previous async
    # save has finished storing any error before we read it.
    self._check_async_thread_error()
    # Wait until the weight copying finishes, then trigger the async thread to
    # checkpoint the cpu-copied variables.
    context.async_wait()
self._save_file_prefix = save_path
self._use_checkpoint_save = False
# Ensure that we do not request async checkpointing to the underlying
# checkpointer as this could lead to an infinite loop.
self._checkpoint_options = copy.copy(options) if options else None
if self._checkpoint_options:
self._checkpoint_options.experimental_enable_async_checkpoint = False
self._queue.put(True) # Trigger save in async thread
write_end_time = time.time()
metrics.AddCheckpointWriteDuration(
api_label=_ASYNC_CHECKPOINT,
microseconds=_get_duration_microseconds(write_start_time,
write_end_time))
return save_path
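  # Note: unlike `save()`, `_write()`/`write()` leave `save_counter` to the
  # caller (e.g. CheckpointManager) and return the prefix as passed in, not a
  # counter-suffixed path.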
def save(self, save_path, options=None):
"""Save the checkpointed variables.
Args:
save_path: The file prefix of the checkpoint file.
      options: Optional CheckpointOptions instance.
Returns:
The full path of the checkpoint file.
"""
save_start_time = time.time()
# If this is the first time that AsyncCheckpoint.save() is called,
# initialize the internal states like `self._saveable_trackables`. We also
# populate `self._object_map` (i.e. initializing the cpu-copied variables
# and copy over the value for the first time) by essentially performing a
# `self._copy_to_cpu()`, hence the if-else logic here.
#
# This is not performed in the initializer because some variables, e.g.,
# slot variables of the optimizer, were not created until actually running
# the train function, so we could only get the complete list of the
# variables after some train steps were run.
if not self._initialized:
self._ensure_initialized()
    else:
      # First wait for the async thread to finish the previous save, then copy
      # the variable values to the host CPU.
      self._queue.join()
      self._copy_to_cpu()
    # Surface the error from the async thread, if any. This step must come
    # after the `queue.join()` above, which ensures that the previous async
    # save has finished storing any error before we read it.
    self._check_async_thread_error()
# Retrieve the save counter from the underlying checkpoint object to
# re-construct the full path of the checkpoint file.
# This step has to happen before triggering the underlying checkpoint;
# otherwise, the save_counter value may or may not have been updated.
save_counter = self.checkpointer().save_counter.numpy() + 1
full_path = "{}-{}".format(save_path, save_counter)
    # Wait until the weight copying finishes, then trigger the async thread to
    # checkpoint the cpu-copied variables.
    context.async_wait()
self._save_file_prefix = save_path
self._use_checkpoint_save = True
# Ensure that we do not request async checkpointing to the underlying
# checkpointer as this could lead to an infinite loop.
self._checkpoint_options = copy.copy(options) if options else None
if self._checkpoint_options:
self._checkpoint_options.experimental_enable_async_checkpoint = False
self._queue.put(True) # Trigger save in async thread
save_end_time = time.time()
metrics.AddCheckpointWriteDuration(
api_label=_ASYNC_CHECKPOINT,
microseconds=_get_duration_microseconds(save_start_time, save_end_time))
return full_path
def read(self, save_path, options=None):
"""Restore the checkpointed variables.
    This method has exactly the same logic as restore(); it is implemented
    only to fulfill the duty of subclassing tf.train.Checkpoint.
Args:
save_path: The full name of the checkpoint file to be restored.
      options: Optional CheckpointOptions instance.
Returns:
A load status object, which can be used to make assertions about the
status of a checkpoint restoration. See tf.train.Checkpoint.restore()
for more details.
"""
return self.restore(save_path, options)
def restore(self, save_path, options=None):
"""Restore the checkpointed variables.
Args:
save_path: The full name of the checkpoint file to be restored.
      options: Optional CheckpointOptions instance.
Returns:
A load status object, which can be used to make assertions about the
status of a checkpoint restoration. See tf.train.Checkpoint.restore()
for more details.
"""
# Ensure that we do not request async checkpointing to the underlying
# checkpointer as this could lead to an infinite loop.
self._checkpoint_options = (
copy.copy(options) if options else self._checkpoint_options)
if self._checkpoint_options:
self._checkpoint_options.experimental_enable_async_checkpoint = False
# Wait for any ongoing checkpoint event to finish.
self._queue.join()
# Restore values of the cpu-copied variables directly back to accelerators
status = self.checkpointer().restore(save_path, self._checkpoint_options)
return status
def sync(self):
"""Sync on any ongoing save or restore events."""
self._queue.join()
logging.info("Sync on ongoing save/restore.")