# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for saving/loading Trackable objects asynchronously."""
import atexit
import collections
import copy
import threading
import time
import weakref
from absl import logging
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute.sharded_variable import ShardedVariable
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import executor
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import ops
from tensorflow.python.ops.resource_variable_ops import UninitializedVariable
from tensorflow.python.ops.variables import Variable
from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.tpu.tpu_embedding_v2 import TPUEmbedding
from tensorflow.python.training import optimizer as optimizer_v1
from tensorflow.python.util import object_identity
# Captures the timestamp of the first Checkpoint instantiation or end of a write
# operation. Can be accessed by multiple Checkpoint instances.
_END_TIME_OF_LAST_ASYNC_WRITE = None
_END_TIME_OF_LAST_ASYNC_WRITE_LOCK = threading.Lock()
# API label for cell names used in async checkpoint metrics.
_ASYNC_CHECKPOINT = "async_checkpoint"
def _get_duration_microseconds(start_time_seconds, end_time_seconds):
"""Calculate the duration between start and end time.
Args:
start_time_seconds: The start time in seconds.
end_time_seconds: The end time in seconds.
Returns:
The duration between the start and the end time. Return 0 if
end_time_seconds < start_time_seconds.
"""
if end_time_seconds < start_time_seconds:
# Avoid returning negative value in case of clock skew.
return 0
return round((end_time_seconds - start_time_seconds) * 1000000)
class AsyncCheckpointHelper:
"""Helper class for async checkpoint."""
def __init__(self, checkpointer_impl, root=None, **kwargs):
"""Initialize AsyncCheckpoint.
Args:
checkpointer_impl: The Checkpoint class to power the AsyncCheckpoint.
root: The root object to checkpoint. `root` may be a trackable object or
`WeakRef` of a trackable object.
**kwargs: The keyword arguments representing the checkpointed variables.
"""
# TODO(chienchunh): Make sure the processing for the root object is
# consistent when integrating with the public API, e.g., adding all kwarg
# items as the child of the root object.
if root:
trackable_root = root() if isinstance(root, weakref.ref) else root
kwargs["root"] = trackable_root
trackable_root._maybe_initialize_trackable()
self._checkpointer_impl = checkpointer_impl
self._checkpoint_items = kwargs
# The underlying Checkpoint instance and its items.
self._checkpoint = None
self._checkpoint_options = None
# The callback function that needs to be executed after checkpoint write.
# Currently this is only applied to the scenario where CheckpointManager is
# used, which triggers the _write() method.
self._async_write_done_callback = None
# The list of all nodes from the original checkpoint items.
# TODO(chienchunh): Consider changing this to local variable.
self._original_nodes = None
# The mapping between the original and the copied resource variables.
# The copied variables are used for the underlying checkpointing.
self._object_map = None
# A list of TPUEmbedding objects included in the checkpoint items.
self._tpu_embedding_objects = None
self._default_device = device_util.current() or "CPU:0"
self._default_device = device_util.canonicalize(self._default_device)
self._save_file_prefix = None
self._use_checkpoint_save = False
self._async_save_thread = None
self._async_save_thread_shutdown = False
# Semaphores for writing/reading the cpu-copied variables (self._object_map).
# TODO(chienchunh): Consider Queue/Condition instead of Semaphore.
self._writer_sem = threading.Semaphore(1)
self._reader_sem = threading.Semaphore(0)
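# Protocol sketch (descriptive note): _writer_sem starts at 1 so the first
# write()/save() can populate the cpu copies immediately, while _reader_sem
# starts at 0 so the async save thread blocks until a snapshot is ready.
# write()/save() acquire _writer_sem, copy device values to the CPU, then
# release _reader_sem; _async_save() acquires _reader_sem, writes the
# checkpoint files, then releases _writer_sem so the next snapshot may
# overwrite the copies.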
# Register to join the async save thread upon exit.
atexit.register(self._join_async_save_thread)
global _END_TIME_OF_LAST_ASYNC_WRITE
with _END_TIME_OF_LAST_ASYNC_WRITE_LOCK:
if _END_TIME_OF_LAST_ASYNC_WRITE is None:
_END_TIME_OF_LAST_ASYNC_WRITE = time.time()
@def_function.function
def _copy_from_cpu(self):
"""Copy the checkpointed variables from the host CPU to the accelerator.
TODO(chienchunh): Get the concrete function before the first call to avoid
leaving the accelerators idle during function tracing.
"""
for accelerator_var, cpu_var in self._object_map.items():
if isinstance(accelerator_var, (ShardedVariable, TPUEmbedding)):
# Skip ShardedVariable and TPUEmbedding here, as their sub-variables will
# be copied over separately through other entries in the object map.
continue
with ops.device(accelerator_var.device):
accelerator_var.assign(cpu_var.read_value())
@def_function.function
def _copy_to_cpu(self):
"""Copy the checkpointed variables from the accelerator to the host CPU.
TODO(chienchunh): Get the concrete function before the first call to avoid
leaving the accelerators idle during function tracing.
"""
for accelerator_var, cpu_var in self._object_map.items():
if isinstance(accelerator_var, (ShardedVariable, TPUEmbedding)):
# Skip ShardedVariable and TPUEmbedding here, as their sub-variables will
# be copied over separately through other entries in the object map.
continue
with ops.device(cpu_var.device):
cpu_var.assign(accelerator_var.read_value())
for tpu_embedding in self._tpu_embedding_objects:
tpu_embedding._retrieve_variables() # pylint: disable=protected-access
def _traverse_variables(self, to_traverse, visited):
"""Create the copied nodes and variables while traversing the nodes.
This method performs a BFS to traverse the nodes while avoiding duplicated
visits. Throughout the process, self._original_nodes, self._object_map, and
self._tpu_embedding_objects are populated.
Args:
to_traverse: A deque that stores the nodes to be traversed.
visited: A set of nodes that have already been visited.
"""
# pylint: disable=protected-access
while to_traverse:
current_trackable = to_traverse.popleft()
self._original_nodes.append(current_trackable)
if isinstance(current_trackable, (Variable, ShardedVariable)):
self._copy_trackable(current_trackable)
if isinstance(current_trackable, TPUEmbedding):
self._handle_tpu_embedding(current_trackable)
for child in current_trackable._trackable_children().values():
if child in visited:
continue
visited.add(child)
to_traverse.append(child)
# pylint: enable=protected-access
def _ensure_initialized(self):
"""Initialize the async checkpoint internal state."""
if self._checkpoint is not None:
return
self._original_nodes = []
self._object_map = object_identity.ObjectIdentityDictionary()
self._tpu_embedding_objects = []
# Add the top-level checkpoint items to be traversed.
to_traverse = collections.deque([])
visited = object_identity.ObjectIdentitySet()
for v in self._checkpoint_items.values():
if isinstance(v, (Variable, ShardedVariable)):
self._copy_trackable(v)
elif isinstance(v, TPUEmbedding):
self._handle_tpu_embedding(v)
to_traverse.append(v)
visited.add(v)
self._traverse_variables(to_traverse, visited)
# Copy for the slot variables.
for current_trackable in self._original_nodes:
if (isinstance(current_trackable, optimizer_v1.Optimizer)
# Note: dir() is used rather than hasattr() here to avoid triggering
# custom __getattr__ code, see b/152031870 for context.
or "get_slot_names" in dir(current_trackable)):
slot_names = current_trackable.get_slot_names()
for slot_name in slot_names:
for original_variable in self._original_nodes:
if not isinstance(original_variable, Variable):
continue
try:
original_slot_variable = current_trackable.get_slot(
original_variable, slot_name)
except (AttributeError, KeyError):
continue
if isinstance(original_slot_variable, (Variable, ShardedVariable)):
self._copy_trackable(original_slot_variable)
# Initiate the underlying Checkpoint instance with the copied items.
self._checkpoint = self._checkpointer_impl(**self._checkpoint_items)
# Pass the object map of the copied variables to the underlying Checkpoint.
self._checkpoint._saver._object_map = self._object_map # pylint: disable=protected-access
# Initiate the async thread for checkpoint saving.
self._async_save_thread = threading.Thread(
target=self._async_save, daemon=True)
self._async_save_thread.start()
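# The save thread is a daemon, so it cannot block interpreter exit on its
# own; the orderly shutdown is handled by _join_async_save_thread(), which
# __init__ registers with atexit.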
def _join_async_save_thread(self):
"""Join the async save thread.
The steps for terminating the async save thread:
1). Wait until the last async save event is done.
2). Set the _async_save_thread_shutdown flag to True to indicate termination.
3). Trigger the async save thread to check and fail the while-predicate.
4). Join the async save thread. (The thread may finish before joining.)
"""
if self._writer_sem.acquire(timeout=3600): # Step-1.
self._async_save_thread_shutdown = True # Step-2.
self._reader_sem.release() # Step-3.
logging.info("Joining the async save thread.")
if self._async_save_thread is not None:
self._async_save_thread.join() # Step-4.
else:
logging.error("Timeout waiting for the async save thread; terminating the"
" thread instead. The last checkpoint may be incomeplete.")
def _async_save(self):
"""The thread function for the async checkpoint save."""
with context.executor_scope(
executor.new_executor(
enable_async=False, enable_streaming_enqueue=False)):
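# Each acquire() below blocks until write()/save() signals that a fresh
# cpu copy is ready, or until _join_async_save_thread() releases the
# semaphore with the shutdown flag set, which ends the loop.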
while self._reader_sem.acquire() and not self._async_save_thread_shutdown:
logging.info("Starting async checkpoint save on the device: %s",
self._default_device)
async_save_start_time = time.time()
# Specify the ops placement on the worker if running with
# coordinator-worker mode. This is required as launching a new thread
# would clear the placement policy and make localhost the default
# placement, while the main thread's default placement would be the
# master worker's CPU:0.
with ops.device(self._default_device):
if self._use_checkpoint_save:
self._checkpoint.save(self._save_file_prefix,
self._checkpoint_options)
else:
self._checkpoint._write( # pylint: disable=protected-access
self._save_file_prefix,
options=self._checkpoint_options,
write_done_callback=self._async_write_done_callback)
# Allow the next checkpoint event to overwrite the cpu-copied variables.
self._writer_sem.release()
async_save_end_time = time.time()
metrics.AddAsyncCheckpointWriteDuration(
api_label=_ASYNC_CHECKPOINT,
microseconds=_get_duration_microseconds(async_save_start_time,
async_save_end_time))
# Measure the elapsed time since the last checkpoint.
# Due to the nature of async checkpoint, here it actually captures the
# duration between the start_time of the previous checkpoint and the
# start time of this checkpoint. As a result, the duration of the final
# async checkpoint is excluded, which is fine since it does not take
# much time.
global _END_TIME_OF_LAST_ASYNC_WRITE
with _END_TIME_OF_LAST_ASYNC_WRITE_LOCK:
metrics.AddTrainingTimeSaved(
api_label=_ASYNC_CHECKPOINT,
microseconds=_get_duration_microseconds(
_END_TIME_OF_LAST_ASYNC_WRITE, async_save_start_time))
_END_TIME_OF_LAST_ASYNC_WRITE = async_save_start_time
logging.info("Async save thread reached the end of the execution.")
def _copy_for_variable(self, original_var):
"""Create a new instance for the input trackable.
Args:
original_var: Input Variable object to be copied.
"""
op_device = pydev.DeviceSpec.from_string(original_var.device).replace(
device_type="CPU", device_index=0).to_string()
with ops.device(op_device):
new_var = UninitializedVariable(
trainable=original_var.trainable,
shape=original_var.shape,
dtype=original_var.dtype,
name=original_var._shared_name) # pylint: disable=protected-access
self._object_map[original_var] = new_var
def _copy_for_sharded_variable(self, original_var):
"""Create a new instance for the input ShardedVariable.
Args:
original_var: Input ShardedVariable object to be copied.
"""
copied_vars = []
for v in original_var._variables: # pylint: disable=protected-access
self._copy_for_variable(v)
copied_vars.append(self._object_map[v])
self._object_map[original_var] = ShardedVariable(
copied_vars, name=original_var.name)
def _copy_trackable(self, original_trackable):
"""Create a new instance for the input trackable.
Args:
original_trackable: The trackable instance to be copied.
Raises:
AttributeError: if the input trackable is not Variable or ShardedVariable.
"""
if isinstance(original_trackable, ShardedVariable):
self._copy_for_sharded_variable(original_trackable)
elif isinstance(original_trackable, Variable):
self._copy_for_variable(original_trackable)
else:
raise AttributeError("Only Variable or ShardedVariable can be copied.")
def _handle_tpu_embedding(self, tpu_embedding):
"""Handle TPUEmbedding.
Args:
tpu_embedding: TPUEmbedding object to be handled.
Raises:
AttributeError: if the input trackable is not TPUEmbedding type.
"""
if not isinstance(tpu_embedding, TPUEmbedding):
raise AttributeError("Expecting TPUEmbedding type; got %s" %
type(tpu_embedding))
# Create a dummy TPUEmbedding object and add it to the object_map. This is
# to prevent the TPUEmbedding's save_callback from being triggered because
# the embedding values have already been retrieved by AsyncCheckpoint.
# pylint: disable=protected-access
new_embedding = TPUEmbedding(
feature_config=tpu_embedding._feature_config,
optimizer=tpu_embedding._table_config[0].optimizer,
pipeline_execution_with_tensor_core=tpu_embedding
._pipeline_execution_with_tensor_core)
self._object_map[tpu_embedding] = new_embedding
# pylint: enable=protected-access
if tpu_embedding not in self._tpu_embedding_objects:
self._tpu_embedding_objects.append(tpu_embedding)
@property
def save_counter(self):
"""An integer variable numbering the checkpoint events.
This is maintained by the underlying tf.train.Checkpoint object employed by
the AsyncCheckpoint class. The number starts at 0 and is incremented for each
checkpoint event.
Returns:
The save counter variable.
"""
self._ensure_initialized()
return self._checkpoint.save_counter
def write(self, save_path, options=None):
"""Save the checkpointed variables.
Args:
save_path: The file prefix of the checkpoint file.
options: Optional CheckpointOptions instance.
Returns:
The full path of the checkpoint file.
"""
return self._write(save_path, options)
def _write(self, save_path, options=None, write_done_callback=None):
"""Save the checkpointed variables.
This method has exactly the same logic as save(), except it does not
increment the underlying save_counter, which is done by the caller, e.g.,
CheckpointManager.
Args:
save_path: The file prefix of the checkpoint file.
options: Optional CheckpointOptions instance.
write_done_callback: Optional callback function executed after the async
write is done.
Returns:
The full path of the checkpoint file.
"""
self._ensure_initialized()
write_start_time = time.time()
# Copy the variable values to the host CPU.
if self._writer_sem.acquire():
self._copy_to_cpu()
# Trigger the async thread to checkpoint the cpu-copied variables.
# Need to wait until the weight copying finishes before checkpoint save.
context.async_wait()
self._save_file_prefix = save_path
self._use_checkpoint_save = False
# Ensure that we do not request async checkpointing to the underlying
# checkpointer as this could lead to an infinite loop.
self._checkpoint_options = copy.copy(options) if options else None
if self._checkpoint_options:
self._checkpoint_options.experimental_enable_async_checkpoint = False
self._async_write_done_callback = write_done_callback
self._reader_sem.release()
write_end_time = time.time()
metrics.AddCheckpointWriteDuration(
api_label=_ASYNC_CHECKPOINT,
microseconds=_get_duration_microseconds(write_start_time,
write_end_time))
return save_path
def save(self, save_path, options=None):
"""Save the checkpointed variables.
Args:
save_path: The file prefix of the checkpoint file.
options: Optional CheckpointOptions instance.
Returns:
The full path of the checkpoint file.
"""
# If this is the first time that AsyncCheckpoint.save() is called,
# initialize the cpu-copied variables and create the pair-wise mapping
# between the original model variables and the cpu-copied variables.
#
# This is not performed in the initializer because some variables, e.g.,
# slot variables of the optimizer, were not created until actually running
# the train function, so we could only get the complete list of the
# variables after some train steps were run.
self._ensure_initialized()
save_start_time = time.time()
# Copy the variable values to the host CPU.
if self._writer_sem.acquire():
self._copy_to_cpu()
# Retrieve the save counter from the underlying checkpoint object to
# re-construct the full path of the checkpoint file.
# This step has to happen before triggering the underlying checkpoint;
# otherwise, the save_counter value may or may not have been updated.
save_counter = self._checkpoint.save_counter.numpy() + 1
full_path = "{}-{}".format(save_path, save_counter)
# Trigger the async thread to checkpoint the cpu-copied variables.
# Need to wait until the weight copying finishes before checkpoint save.
context.async_wait()
self._save_file_prefix = save_path
self._use_checkpoint_save = True
# Ensure that we do not request async checkpointing to the underlying
# checkpointer as this could lead to an infinite loop.
self._checkpoint_options = copy.copy(options) if options else None
if self._checkpoint_options:
self._checkpoint_options.experimental_enable_async_checkpoint = False
self._reader_sem.release()
save_end_time = time.time()
metrics.AddCheckpointWriteDuration(
api_label=_ASYNC_CHECKPOINT,
microseconds=_get_duration_microseconds(save_start_time, save_end_time))
return full_path
def read(self, save_path, options=None):
"""Restore the checkpointed variables.
This method has exactly the same logic as restore(). This method is
implemented only to fulfill the duty of subclassing tf.train.Checkpoint.
Args:
save_path: The full name of the checkpoint file to be restored.
options: Optional CheckpointOptions instance.
Returns:
A load status object, which can be used to make assertions about the
status of a checkpoint restoration. See tf.train.Checkpoint.restore()
for more details.
"""
return self.restore(save_path, options)
def restore(self, save_path, options=None):
"""Restore the checkpointed variables.
Args:
save_path: The full name of the checkpoint file to be restored.
options: Optional CheckpointOptions instance.
Returns:
A load status object, which can be used to make assertions about the
status of a checkpoint restoration. See tf.train.Checkpoint.restore()
for more details.
"""
# Ensure that we do not request async checkpointing to the underlying
# checkpointer as this could lead to an infinite loop.
self._checkpoint_options = (
copy.copy(options) if options else self._checkpoint_options)
if self._checkpoint_options:
self._checkpoint_options.experimental_enable_async_checkpoint = False
# Wait for any ongoing checkpoint event to finish.
with self._writer_sem:
# If _checkpoint has not been initialized yet, it means the restore() is
# called right after the coordinator is restarted. We directly restore
# the checkpointed items through tf.train.Checkpoint.restore().
if self._checkpoint is None:
tmp_checkpoint = self._checkpointer_impl(**self._checkpoint_items)
return tmp_checkpoint.restore(save_path, self._checkpoint_options)
# Restore the values of the cpu-copied variables.
status = self._checkpoint.restore(save_path, self._checkpoint_options)
# Restore the values of the original model.
self._copy_from_cpu()
return status
def sync(self):
"""Sync on any ongoing save or restore events."""
with self._writer_sem:
logging.info("Sync on ongoing save/restore.")