import contextlib
import platform
import uuid
import warnings
import weakref
from collections import defaultdict
from itertools import count
from typing import (
    Any,
    Callable,
    ContextManager,
    DefaultDict,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
)
from weakref import ReferenceType

import torch
import torch.fx.traceback as fx_traceback
from torch._functorch._aot_autograd.functional_utils import is_fun
from torch.utils._pytree import tree_map
from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode
from torch.utils._python_dispatch import TorchDispatchMode

__all__ = [
    "checkpoint",
    "checkpoint_sequential",
    "CheckpointError",
    "CheckpointFunction",
    "check_backward_validity",
    "detach_variable",
    "get_device_states",
    "set_device_states",
    "noop_context_fn",
    "set_checkpoint_early_stop",
    "DefaultDeviceType",
    "set_checkpoint_debug_enabled",
]

_DEFAULT_DETERMINISM_MODE = "default"

_checkpoint_debug_enabled: Optional[bool] = None


@contextlib.contextmanager
def set_checkpoint_debug_enabled(enabled: Optional[bool]):
    """
    Context manager that sets whether checkpoint should print additional debug
    information when running. See the ``debug`` flag for
    :func:`~torch.utils.checkpoint.checkpoint` for more information. Note that
    when set, this context manager overrides the value of ``debug`` passed to
    checkpoint. To defer to the local setting, pass ``None`` to this context.

    Args:
        enabled (bool): Whether checkpoint should print debug information.
            Default is ``None``.
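
    Example (an illustrative sketch; ``fn`` and ``x`` below are placeholders
    for a checkpointed callable and its input)::

        >>> # xdoctest: +SKIP("stub")
        >>> with set_checkpoint_debug_enabled(True):
        ...     out = checkpoint(fn, x, use_reentrant=False)
        >>> out.sum().backward()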
    """
    global _checkpoint_debug_enabled
    try:
        prev = _checkpoint_debug_enabled
        _checkpoint_debug_enabled = enabled
        yield
    finally:
        _checkpoint_debug_enabled = prev


def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got Unsupported input type: ",
            type(inputs).__name__,
        )


def check_backward_validity(inputs: Iterable[Any]) -> None:
    if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
        warnings.warn(
            "None of the inputs have requires_grad=True. Gradients will be None"
        )


def _get_device_module(device="cuda"):
    device_module = getattr(torch, device)
    return device_module


class DefaultDeviceType:
    r"""
    A class that manages the default device type for checkpointing.

    If no non-CPU tensors are present, the default device type will
    be used. The default value is 'cuda'. The device type is used in
    the checkpointing process when determining which device states
    to save and restore for recomputation.
    """

    _default_device_type = "cuda"

    @staticmethod
    def set_device_type(device: str = "cuda"):
        """
        Set the default device type for checkpointing.

        Args:
            device (str): The device type to be set as default. Default is 'cuda'.
        """
        DefaultDeviceType._default_device_type = device

    @staticmethod
    def get_device_type() -> str:
        """
        Get the current default device type for checkpointing.

        Returns:
            str: The current default device type.
        """
        return DefaultDeviceType._default_device_type
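

# Illustrative usage sketch (not part of the original file): a workload that
# runs entirely on a non-CUDA accelerator can point checkpoint's RNG stashing
# at that backend up front. "xpu" below is a hypothetical device-module name:
#
#   DefaultDeviceType.set_device_type("xpu")
#   assert DefaultDeviceType.get_device_type() == "xpu"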


def _infer_device_type(*args):
    device_types = list(
        {
            arg.device.type
            for arg in args
            if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu"
        }
    )
    if len(device_types) > 1:
        warnings.warn(
            "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. "
            "Device state will only be saved for devices of a single device type, and the remaining "
            "devices will be ignored. Consequently, if any checkpointed functions involve randomness, "
            "this may result in incorrect gradients. (Note that if CUDA devices are among the devices "
            "detected, it will be prioritized; otherwise, the first device encountered will be selected.)"
        )
    if len(device_types) == 0:
        return DefaultDeviceType.get_device_type()
    elif "cuda" in device_types:
        return "cuda"
    else:
        return device_types[0]
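

# For example (illustrative, assumes a CUDA build):
# _infer_device_type(torch.randn(2), torch.randn(2, device="cuda")) returns
# "cuda", while a call with only CPU tensors (or no tensor arguments at all)
# falls back to DefaultDeviceType.get_device_type().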


# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoically stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases. Compromise: Stash the RNG state for
# the device of all Tensor args.
#
# To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_device_ids = list(
        {
            arg.get_device()
            for arg in args
            if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu"
        }
    )

    fwd_device_states = []
    device_module = _get_device_module(_infer_device_type(*args))

    for device_id in fwd_device_ids:
        with device_module.device(device_id):
            fwd_device_states.append(device_module.get_rng_state())

    return fwd_device_ids, fwd_device_states


def set_device_states(devices, states) -> None:
    device_module = _get_device_module(_infer_device_type(*states))
    for device, state in zip(devices, states):
        with device_module.device(device):
            device_module.set_rng_state(state)
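

# The stash/restore pattern these two helpers support, as used by
# CheckpointFunction below (an illustrative sketch; assumes `x` lives on an
# accelerator device):
#
#   devices, states = get_device_states(x)   # stash per-device RNG states
#   ...                                      # run the checkpointed forward
#   set_device_states(devices, states)       # restore before recomputation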


def _get_autocast_kwargs(device="cuda"):
    if device == "cuda":
        device_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(),
            "dtype": torch.get_autocast_gpu_dtype(),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
    elif _supports_autocast(device):
        device_module = _get_device_module(device)
        device_autocast_kwargs = {
            "enabled": device_module.is_autocast_enabled(),
            "dtype": device_module.get_autocast_dtype(),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
    else:
        device_autocast_kwargs = None

    cpu_autocast_kwargs = {
        "enabled": torch.is_autocast_cpu_enabled(),
        "dtype": torch.get_autocast_cpu_dtype(),
        "cache_enabled": torch.is_autocast_cache_enabled(),
    }

    return device_autocast_kwargs, cpu_autocast_kwargs


def _supports_autocast(device):
    device_module = _get_device_module(device)
    return device == "cuda" or (
        hasattr(device_module, "is_autocast_enabled")
        and hasattr(device_module, "get_autocast_dtype")
    )


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.device = _infer_device_type(*args)
        ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
            ctx.device
        )
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_device_in_fwd = False
            device_module = _get_device_module(ctx.device)
            if getattr(device_module, "_initialized", False):
                ctx.had_device_in_fwd = True
                ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument."
            )
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors
        device_module = _get_device_module(ctx.device)

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_device_in_fwd:
            rng_devices = ctx.fwd_devices
        with torch.random.fork_rng(
            devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device
        ):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_device_in_fwd:
                    set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
            detached_inputs = detach_variable(tuple(inputs))

            device_autocast_ctx = device_module.amp.autocast(
                **ctx.device_autocast_kwargs
            ) if _supports_autocast(ctx.device) else contextlib.nullcontext()
            with torch.enable_grad(), device_autocast_ctx, \
                 torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() only with the tensors that require grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of the outputs has requires_grad=True,"
                " this checkpoint() is not necessary"
            )
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None
            for inp in detached_inputs
        )

        return (None, None) + grads


def noop_context_fn():
    return contextlib.nullcontext(), contextlib.nullcontext()
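
# A user-supplied ``context_fn`` follows the same shape as ``noop_context_fn``
# above: it must return a *pair* of freshly-constructed context managers, the
# first wrapping the original forward and the second wrapping recomputation.
# Illustrative sketch (assumes saved tensors should be offloaded to CPU in
# both passes; ``save_to_cpu_context_fn`` is a hypothetical helper, not part
# of this module):
#
#   def save_to_cpu_context_fn():
#       return (
#           torch.autograd.graph.save_on_cpu(pin_memory=True),
#           torch.autograd.graph.save_on_cpu(pin_memory=True),
#       )
#
#   out = checkpoint(fn, x, use_reentrant=False, context_fn=save_to_cpu_context_fn)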

# TorchDynamo does not step inside utils.checkpoint function. The flow
# looks like this
#  1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by
#     speculatively checking if the forward function is safe to trace.
#  2) If yes, then Dynamo-generated Fx graph has the wrapped higher
#     order op. As a result, TorchDynamo does not look inside utils.checkpoint.
#  3) If not, then TorchDynamo falls back to eager by performing a graph
#     break. And here, the following disable wrapper ensures that
#     TorchDynamo does not trigger again on the frames created by
#     utils.checkpoint innards.
@torch._disable_dynamo
def checkpoint(
    function,
    *args,
    use_reentrant: Optional[bool] = None,
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    determinism_check: str = _DEFAULT_DETERMINISM_MODE,
    debug: bool = False,
    **kwargs
):
    r"""Checkpoint a model or part of the model.

    Activation checkpointing is a technique that trades compute for memory.
    Instead of keeping tensors needed for backward alive until they are used in
    gradient computation during backward, forward computation in checkpointed
    regions omits saving tensors for backward and recomputes them during the
    backward pass. Activation checkpointing can be applied to any part of a
    model.

    There are currently two checkpointing implementations available, determined
    by the :attr:`use_reentrant` parameter. It is recommended that you use
    ``use_reentrant=False``. Please refer to the note below for a discussion of
    their differences.

    .. warning::

        If the :attr:`function` invocation during the backward pass differs
        from the forward pass, e.g., due to a global variable, the checkpointed
        version may not be equivalent, potentially causing an
        error being raised or leading to silently incorrect gradients.

    .. warning::

        The ``use_reentrant`` parameter should be passed explicitly. In version
        2.4 we will raise an exception if ``use_reentrant`` is not passed.
        If you are using the ``use_reentrant=True`` variant, please refer to the
        note below for important considerations and potential limitations.

    .. note::

        The reentrant variant of checkpoint (``use_reentrant=True``) and
        the non-reentrant variant of checkpoint (``use_reentrant=False``)
        differ in the following ways:

        * Non-reentrant checkpoint stops recomputation as soon as all needed
          intermediate activations have been recomputed. This feature is enabled
          by default, but can be disabled with :func:`set_checkpoint_early_stop`.
          Reentrant checkpoint always recomputes :attr:`function` in its
          entirety during the backward pass.

        * The reentrant variant does not record the autograd graph during the
          forward pass, as it runs with the forward pass under
          :func:`torch.no_grad`. The non-reentrant version does record the
          autograd graph, allowing one to perform backward on the graph within
          checkpointed regions.

        * The reentrant checkpoint only supports the
          :func:`torch.autograd.backward` API for the backward pass without its
          `inputs` argument, while the non-reentrant version supports all ways
          of performing the backward pass.

        * At least one input and output must have ``requires_grad=True`` for the
          reentrant variant. If this condition is unmet, the checkpointed part
          of the model will not have gradients. The non-reentrant version does
          not have this requirement.

        * The reentrant version does not consider tensors in nested structures
          (e.g., custom objects, lists, dicts, etc) as participating in
          autograd, while the non-reentrant version does.

        * The reentrant checkpoint does not support checkpointed regions with
          detached tensors from the computational graph, whereas the
          non-reentrant version does. For the reentrant variant, if the
          checkpointed segment contains tensors detached using ``detach()`` or
          with :func:`torch.no_grad`, the backward pass will raise an error.
          This is because ``checkpoint`` makes all the outputs require gradients
          and this causes issues when a tensor is defined to have no gradient in
          the model. To avoid this, detach the tensors outside of the
          ``checkpoint`` function.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): if ``False``, omit stashing and
            restoring the RNG state during each checkpoint. Note that under
            torch.compile, this flag doesn't take effect and we always
            preserve RNG state.
            Default: ``True``
        use_reentrant(bool):
            specify whether to use the activation checkpoint variant that
            requires reentrant autograd. This parameter should be passed
            explicitly. In version 2.4 we will raise an exception if
            ``use_reentrant`` is not passed. If ``use_reentrant=False``,
            ``checkpoint`` will use an implementation that does not require
            reentrant autograd. This allows ``checkpoint`` to support additional
            functionality, such as working as expected with
            ``torch.autograd.grad`` and support for keyword arguments input into
            the checkpointed function.
        context_fn(Callable, optional): A callable returning a tuple of two
            context managers. The function and its recomputation will be run
            under the first and second context managers respectively.
            This argument is only supported if ``use_reentrant=False``.
        determinism_check(str, optional): A string specifying the determinism
            check to perform. By default it is set to ``"default"`` which
            compares the shapes, dtypes, and devices of the recomputed tensors
            against those of the saved tensors. To turn off this check, specify
            ``"none"``. Currently these are the only two supported values.
            Please open an issue if you would like to see more determinism
            checks. This argument is only supported if ``use_reentrant=False``;
            if ``use_reentrant=True``, the determinism check is always disabled.
        debug(bool, optional): If ``True``, error messages will also include
            a trace of the operators ran during the original forward computation
            as well as the recomputation. This argument is only supported if
            ``use_reentrant=False``.
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
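
    Example (an illustrative sketch; assumes ``torch.nn`` is imported as ``nn``)::

        >>> # xdoctest: +SKIP("stub")
        >>> block = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
        >>> inp = torch.randn(2, 10, requires_grad=True)
        >>> out = checkpoint(block, inp, use_reentrant=False)
        >>> out.sum().backward()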
    """
    if use_reentrant is None:
        warnings.warn(
            "torch.utils.checkpoint: the use_reentrant parameter should be "
            "passed explicitly. In version 2.4 we will raise an exception "
            "if use_reentrant is not passed. use_reentrant=False is "
            "recommended, but if you need to preserve the current default "
            "behavior, you can pass use_reentrant=True. Refer to docs for more "
            "details on the differences between the two variants."
        )
        use_reentrant = True

    # Hack to mix *args with **kwargs in a python 2.7-compliant way
    preserve = kwargs.pop("preserve_rng_state", True)
    if kwargs and use_reentrant:
        raise ValueError(
            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
        )

    if use_reentrant:
        if context_fn is not noop_context_fn or debug is not False:
            raise ValueError(
                "Passing `context_fn` or `debug` is only supported when "
                "use_reentrant=False."
            )
        return CheckpointFunction.apply(function, preserve, *args)
    else:
        gen = _checkpoint_without_reentrant_generator(
            function, preserve, context_fn, determinism_check, debug, *args, **kwargs
        )
        # Runs pre-forward logic
        next(gen)
        ret = function(*args, **kwargs)
        # Runs post-forward logic
        try:
            next(gen)
        except StopIteration:
            return ret


def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs):
    r"""Checkpoint a sequential model to save memory.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model in various segments
    and checkpoint each segment. All segments except the last will not store
    the intermediate activations. The inputs of each checkpointed segment will
    be saved for re-running the segment in the backward pass.

    .. warning::
        The ``use_reentrant`` parameter should be passed explicitly. In version
        2.4 we will raise an exception if ``use_reentrant`` is not passed.
        If you are using the ``use_reentrant=True`` variant, please see
        :func:`~torch.utils.checkpoint.checkpoint` for
        the important considerations and limitations of this variant. It is
        recommended that you use ``use_reentrant=False``.

    .. warning::
        Since PyTorch 1.4, it allows only one Tensor as the input and
        intermediate outputs, just like :class:`torch.nn.Sequential`.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model
        input: A Tensor that is input to :attr:`functions`
        preserve_rng_state(bool, optional): if ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
            Default: ``True``
        use_reentrant(bool):
            specify whether to use the activation checkpoint variant that
            requires reentrant autograd. This parameter should be passed
            explicitly. In version 2.4 we will raise an exception if
            ``use_reentrant`` is not passed. If ``use_reentrant=False``,
            ``checkpoint`` will use an implementation that does not require
            reentrant autograd. This allows ``checkpoint`` to support additional
            functionality, such as working as expected with
            ``torch.autograd.grad`` and support for keyword arguments input into
            the checkpointed function.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`*inputs`

    Example:
        >>> # xdoctest: +SKIP("stub")
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    if use_reentrant is None:
        warnings.warn(
            "torch.utils.checkpoint.checkpoint_sequential: the use_reentrant "
            "parameter should be passed explicitly. "
            "In version 2.4 we will raise an exception if use_reentrant "
            "is not passed. use_reentrant=False is "
            "recommended, but if you need to preserve the current default "
            "behavior, you can pass use_reentrant=True. Refer to docs for more "
            "details on the differences between the two variants."
        )
        use_reentrant = True

    # Hack for keyword-only parameter in a python 2.7-compliant way
    preserve = kwargs.pop("preserve_rng_state", True)
    if kwargs:
        raise ValueError(
            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
        )

    def run_function(start, end, functions):
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input)
            return input

        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    # the last chunk has to be non-volatile
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        input = checkpoint(
            run_function(start, end, functions),
            input,
            use_reentrant=use_reentrant,
            preserve_rng_state=preserve,
        )
    return run_function(end + 1, len(functions) - 1, functions)(input)
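

# Worked example of the segmentation above (illustrative): with
# len(functions) == 5 and segments == 2, segment_size == 5 // 2 == 2, so the
# loop checkpoints functions[0:2] only, while functions[2:5] (the final,
# non-volatile chunk) run outside of checkpoint so that their activations are
# kept for the backward pass.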


def _internal_assert(cond):
    if not cond:
        raise AssertionError(
            "Something went unexpectedly wrong in activation checkpoint. "
            "Please report this bug by filing an issue to PyTorch."
        )


# NOTE [ Nestable Checkpoint ]
#
# The semantics of nested checkpoint can be defined by two basic rules.
# Following the two rules leads to an important implication that is central
# to motivating the design.
#
# Rule 1. Saved tensors are managed by inner-most checkpoint only and hidden
#         from any outer layers of checkpoint.
#
# Rule 2. The inputs of inner checkpoints are treated as tensors saved to its
#         parent checkpoint.
#
# Implication: To recompute any given saved tensor, we need to recompute all of
#              the checkpoints wrapping it.
#
# Why is this implied? To unpack a saved tensor X during backward we need to
# recompute the inner-most checkpoint (#1), and in order to recompute that
# checkpoint I need to have its inputs, which are managed by that checkpoint's
# parent (#2), which thus also needs to be recomputed first. Continue this line
# of reasoning and we realize that in order to unpack X, all checkpoints that
# were active at the time X was saved need to be recomputed. (unless we have
# already done so in that backward for some other saved tensor).
#
# In practice, we use a noop autograd Function to save inputs as saved tensors.
# During unpack calling ctx.saved_tensor triggers the parent checkpoint to
# recompute.
#
# Rule 3. We should start recomputation as if there are no checkpoints currently
#         active. Checkpoints encountered during recomputation are still
#         respected.
#
# When we start recomputation, we push the saved variable hook meant for
# recomputation on the stack. See examples in Rule 6 for more context.
#
#                                  * * * *
#
# Beyond the basic semantics specific to nested checkpoint, we impose several
# more constraints that may apply to checkpointing in general.
#
# Rule 4. Lifetime of recomputed tensors
#
#         Recomputed tensors are considered specific to particular invocations
#         of backward and are always cleared immediately as they are unpacked.
#         Particularly, we require this to happen even if retain_graph=True.
#
# [ Implementation details of Rule 4 ]
#
# If we were okay with recomputed tensors staying alive after backward is run
# with retain_graph=True, we would store recomputed variables as the values of a
# WeakKeyDictionary and pack strong references to the keys, so that as we
# backward, those packed keys would be cleared as long as retain_graph=False.
# Clearing the packed key clears the corresponding entry in the WKD.
#
# If we wish recomputed variables to be immediately cleared as we unpack them in
# the retain_graph=True case, we cannot rely on the packed keys to be cleared by
# backward automatically. Instead of packing the strong reference to the key
# directly, we pack a container object, which we manually clear as we unpack.
#
# An important detail is that if a second backward happens, the second
# recomputation needs to reset the container with a newly created key.
#
# Rule 5. Stop recomputation as soon as we've recomputed the saved tensors we
#         know we need.
#
# [ Implementation details of Rule 5 ]
#
# During recomputation, raise an exception if the number of recomputed tensors
# matches the number of tensors that we expected to recompute. We wrap the
# recomputation call with a try-catch to catch this specific exception. See
# Rule #6 below for some examples.
#
# Rule 6. We support doing backward inside checkpoint context
#
# [ retain_graph is True ]
#
# def fn(x):
#   y = x.sin()
#   z = y.cos()
#   gx, = torch.autograd.grad(z, x, retain_graph=True)
#   return gx, z
#
# out = checkpoint(fn)(inp)
# out.backward()
#
# Because z is saved by cos while checkpoint is enabled, it would not be
# actually saved, and so the .grad() call inside must trigger a recomputation.
#
# During recomputation the "inner pack hook" has two responsibilities:
#
# 1) As usual, populating the WeakKeyDictionary storing recomputed tensors
# 2) Pack the actual tensor (detached) so that one may perform backward on the
#    recomputed graph. The tensors saved to this graph will live until the end
#    of recomputation, or die earlier if someone performs backward with
#    retain_graph=False.
#
# More generally performing backward on the recomputed graph occurs in the
# following cases:
# - If backward is performed inside forward,
#   - During the original forward IF early-stop is disabled
#   - During the original backward
# - If there are multiple .grad()/.backward() calls, we would perform backward
#   on the recomputed graph even if early-stop is enabled (see the example below)
#
# [ retain_graph is False ]
#
# The example below shows what happens if during recomputation we find that some
# of the tensors we are trying to recompute have already been cleared.
#
# Spoiler: we don't do anything special, we just skip over them!
#
# def fn(x):
#   y = x.sin()                           # (1)
#   z = y.cos()                           # (2)
#   gx, = torch.autograd.grad(z, x)       # (3)
#   return x.cos() * gx                   # (4)
#
# out = checkpoint(fn)(inp)
# out.backward()                          # (5)
#
# 1, 2. Don't save x and y since we are inside a checkpoint.
# 3. Trigger a recompute of fn since x and y weren't saved.
#    And depending on whether early stop is enabled, either stop at (2) or
#    continue running the function.
#    Because we are running backward with retain_graph=False, we clear x and y's
#    holders.
# 4. Don't save x since we are inside a checkpoint.
# 5. Calling backward triggers another recompute of fn. During recompute, we see
#    that x and y have already been cleared in the original graph as indicated
#    by holder=None. We skip over them. We still save x at (4) (since its holder
#    is still alive.)
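#
# A minimal nesting sketch tying Rules 1 and 2 together (illustrative, not
# part of the original notes):
#
# def inner(x):
#     return x.sin()
#
# def outer(x):
#     return checkpoint(inner, x, use_reentrant=False).cos()
#
# out = checkpoint(outer, inp, use_reentrant=False)
#
# The tensor saved by sin() is managed by the inner checkpoint only (Rule 1),
# while the input to the inner checkpoint counts as a tensor saved by the
# outer checkpoint (Rule 2). Unpacking sin()'s saved tensor during backward
# therefore recomputes the outer checkpoint first and then the inner one.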

_enable_checkpoint_early_stop = True


@contextlib.contextmanager
def set_checkpoint_early_stop(enable: bool):
    """Context manager that sets whether checkpoint should stop recomputation early.

    By default, non-reentrant checkpoint stops recomputation as soon as it
    has computed all needed Tensors. This context manager can be used to disable
    that feature if it is problematic for your specific application.

    This context manager only needs to be active when forward is run. It does
    not need to be active during backward.

    Example::

        >>> # xdoctest: +SKIP(failing)
        >>> message = "saved tensors default hooks are disabled"
        >>> with set_checkpoint_early_stop(False):
        ...     # Any checkpoint under this context manager will respect this
        ...     # context manager, even if its backward is performed outside.
        ...     out = checkpoint(fn, inputs)
        ...
        >>> out.backward()
    """
    global _enable_checkpoint_early_stop
    try:
        prev = _enable_checkpoint_early_stop
        _enable_checkpoint_early_stop = enable
        yield
    finally:
        _enable_checkpoint_early_stop = prev


class _Handle:
    pass


class _Holder:
    def __init__(self):
        self.handles: Dict[int, Optional[_Handle]] = dict()


class _NoopSaveInputs(torch.autograd.Function):
    @staticmethod
    def forward(*args):
        return torch.empty((0,))

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
        # Only tensors can be saved with ctx.save_for_backward, everything else
        # is captured by get_args, which is saved directly on ctx
        tensor_indices, tensors = zip(
            *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)]
        )
        idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)}
        # args but with tensors replaced with None as placeholders
        args = [None if isinstance(o, torch.Tensor) else o for o in inputs]

        def get_args(saved_tensors):
            # restore the placeholders with the original tensors grabbed from
            # ctx.saved_tensors (which may be saved on a parent checkpoint if
            # this checkpoint is nested, and that would trigger a recursive
            # unpack!)
            ret = [
                saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o
                for i, o in enumerate(args)
            ]
            # grab the tail since we also saved the dummy to avoid having to explicitly
            # handle the case where there are no tensor inputs
            return ret[1:]

        ctx.get_args = get_args
        ctx.save_for_backward(*tensors)

    @staticmethod
    def backward(ctx, *grad_outputs):
        raise AssertionError("Did not expect to backward on this graph")


class _CheckpointFrame:
    def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn):
        self.recompute_fn = recompute_fn
        self.input_saver = None
        self.weak_holders: List[ReferenceType] = []
        # We store this as a weakkeydictionary so that in the case of a partial
        # backward, the entries in the dict are cleared alongside the Holder
        # which will be removed when the SavedVariable is cleared.
        self.recomputed: DefaultDict[
            int, weakref.WeakKeyDictionary[_Handle, torch.Tensor]
        ] = defaultdict(weakref.WeakKeyDictionary)
        # We need both recomp_counter and recomputed since they can diverge
        # https://github.com/pytorch/pytorch/pull/90105#discussion_r1135889885
        self.recomp_counter: DefaultDict[int, int] = defaultdict(int)
        self.is_recomputed: DefaultDict[int, bool] = defaultdict(bool)

        # See Rule 5
        self.early_stop = early_stop

        # Debugging
        self.metadata_fn = metadata_fn
        self.unpack_error_cb = unpack_error_cb
        self.x_metadatas = []
        self.forward_completed = False
        self.ignore_saved_mismatch = False

    def check_recomputed_tensors_match(self, gid):
        if self.ignore_saved_mismatch:
            # TODO: we can probably make this check stricter by checking that
            #       the metadata of the first tensors still match.
            return
        # NOTE [ Error handling for checkpoint ]
        #
        # At a high level, we need to check that the tensors saved
        # during the original forward match the tensors saved during recompute.
        # This means handling 3 cases:
        #
        # 1. During recompute, more tensors were saved.
        #
        #    Usually this case is hidden by the StopRecomputationError when
        #    early stop is enabled; if early stop is not enabled, we would
        #    have errored anyway because there aren't enough weak_holders.
        #    Either way we want to raise a nice error. See the
        #    _recomputation_hook for details.
        if not len(self.weak_holders) == self.recomp_counter[gid]:
            # 2. During recompute, fewer tensors were saved
            #
            # We know that every time we save something during the original
            # forward we append to weak_holders, and every time we save a
            # tensor during recompute we increment recomp_counter.
            raise CheckpointError(
                "torch.utils.checkpoint: A different number of tensors was saved "
                "during the original forward and recomputation.\n"
                f"Number of tensors saved during forward: {len(self.weak_holders)}\n"
                f"Number of tensors saved during recomputation: {self.recomp_counter[gid]}"
            )

        # 3. During recompute, the same tensors were saved, but they
        #    have different metadata
        nb_meta_different = []
        for idx, weak_holder in enumerate(self.weak_holders):
            holder = weak_holder()
            if holder is None:
                continue
            # We've seen all holders since we iterate over them in order
            # For every holder that is still alive now, it must've been
            # alive when we saw it during recompute, therefore, the
            # gid must be set.
            _internal_assert(gid in holder.handles)
            # We know this is the first unpack, so it couldn't have been set
            # to None yet.
            _internal_assert(holder.handles[gid] is not None)
            # We always set these together in the recomputation hook
            _internal_assert(holder.handles[gid] in self.recomputed[gid])
            # see pack hook, x_metadata is 1:1 with weak_holders.
            x_meta = self.x_metadatas[idx]
            recomputed_x = self.recomputed[gid][holder.handles[gid]]
            if x_meta != self.metadata_fn(recomputed_x):
                nb_meta_different.append((idx, x_meta, self.metadata_fn(recomputed_x)))

        if len(nb_meta_different) > 0:
            mismatched_tensors = ""
            for idx, x_meta, recomputed_meta in nb_meta_different:
                mismatched_tensors += (
                    f"tensor at position {idx}:\n"
                    f"saved metadata: {x_meta}\n"
                    f"recomputed metadata: {recomputed_meta}\n"
                )
            raise CheckpointError(
                "torch.utils.checkpoint: Recomputed values for the following tensors "
                "have different metadata than during the forward pass.\n"
                f"{mismatched_tensors}"
            )


_checkpoint_error_template = """ \
An error happened while unpacking tensors; dumping logs of latest computation
because you passed `debug=True` to `torch.utils.checkpoint.checkpoint()`.
Scroll all the way down for guidance on how to navigate these logs.

+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
|        1. Stack traces of the operators that ran in the original forward     |
+------------------------------------------------------------------------------+

{forward_traces}
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
|        2. Stack traces of the operators that ran during recomputation        |
+------------------------------------------------------------------------------+

{recompute_traces}
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
|       3. Log of operators in the original forward and recomputation          |
+------------------------------------------------------------------------------+
(Scroll up to correlate stack traces with each operation listed below. This
 helps identify their source in the code.)

IMPORTANT: Differences in "detach" calls between the original forward and the
           recomputation are expected. They are introduced by the checkpointing
           mechanism and can be ignored.

Operations executed during the original forward:

{forward_ops}

Operations executed during recomputation:

{recompute_ops}

+------------------------------------------------------------------------------+
 ERROR: Detected non-determinism while running activation checkpointing

 You are seeing this error because you passed `debug=True` to checkpoint and
 the tensors saved during the original forward differ from those saved during
 recomputation. This can happen if different operators were run in the
 original forward and in the recomputation.

 To identify where the mismatch may be coming from, you can do the following:

 1) Compare the operators ran during original forward and recomputation to
    see where they differ. These operators are printed above in the order they
    were executed.

 2) Review the stack trace for each operator to locate its invocation source.
    Each operator's stack trace is printed in their execution order.

 Note that the logs can be quite long. Here's how they are structured:
 (Tip: you can Ctrl-f for these headers)

 1. Stack traces of the operators that ran in the original forward
 2. Stack traces of the operators that ran during recomputation
 3. Log of operators in the original forward and recomputation
 4. Error message                                             <--- You are here
--------------------------------------------------------------------------------
"""

class CheckpointError(RuntimeError):
    pass


def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[CheckpointError], None]]:
    # This function returns the context_fn and error_cb to be used by the
    # checkpointing mechanism. error_cb is invoked when an error is detected
    # during unpack.

    # record_context_cpp is not supported on non-linux non-x86_64 platforms
    cpp_tb = platform.machine() == 'x86_64' and platform.system() == 'Linux'

    class CaptureLogs:
        def __init__(self):
            self.logs = None
            self.tbs = None

        def get_context_manager(self):
            @contextlib.contextmanager
            def logging_mode():
                with LoggingTensorMode(), \
                     capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb:
                    self.logs, self.tbs = logs_and_tb
                    yield logs_and_tb
            return logging_mode()

    capture_logs_fwd = CaptureLogs()
    capture_logs_recompute = CaptureLogs()

    def unpack_error_cb(e: CheckpointError):
        def get_str_tb(label, capture_logs):
            out = ""
            total_len = len(capture_logs.logs)
            for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs)):
                out += f"{log} ({i + 1} of {total_len} in {label})\n\n"
                found_torch_dispatch = False
                for line in tb:
                    # Start printing stack trace only after __torch_dispatch__ is found
                    is_torch_dispatch = line['name'] == '__torch_dispatch__'
                    if not found_torch_dispatch and not is_torch_dispatch:
                        continue
                    elif is_torch_dispatch:
                        found_torch_dispatch = True
                        continue
                    out += f"{line['filename']}:{line['line']}:{line['name']}\n"
                out += "\n\n"
            return out
        assert capture_logs_fwd.logs is not None
        assert capture_logs_recompute.logs is not None
        raise CheckpointError(
            _checkpoint_error_template.format(
                forward_traces=get_str_tb("original", capture_logs_fwd),
                recompute_traces=get_str_tb("recompute", capture_logs_recompute),
                forward_ops="\n".join(capture_logs_fwd.logs),
                recompute_ops="\n".join(capture_logs_recompute.logs)
            )
        ) from e

    def context_fn():
        return capture_logs_fwd.get_context_manager(), capture_logs_recompute.get_context_manager()

    return context_fn, unpack_error_cb


def _default_meta_extractor(x: torch.Tensor) -> Dict[str, Any]:
    # These properties are fast to check, easy to understand
    return {
        "shape": x.shape,
        "dtype": x.dtype,
        "device": x.device
    }

_allowed_determinism_checks_to_fns: Dict[str, Callable[[torch.Tensor], Any]] = {
    _DEFAULT_DETERMINISM_MODE: _default_meta_extractor,
    "none": lambda _: None,
}

# See Rule 5
class _StopRecomputationError(Exception):
    pass


class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self, target_frame_ref: ReferenceType, gid: int):
        def pack_hook(x):
            target_frame = target_frame_ref()
            assert target_frame is not None  # appease mypy
            recomp_idx = target_frame.recomp_counter[gid]
            target_frame.recomp_counter[gid] += 1

            if recomp_idx >= len(target_frame.weak_holders):
                assert not target_frame.early_stop
                if not target_frame.forward_completed:
                    # We run into this case when early stop is not enabled and do
                    # grad within checkpoint.
                    # We need to set this flag, so we don't error out later when
                    # we check if the number of tensors saved during forward and
                    # recomputation match.
                    target_frame.ignore_saved_mismatch = True
                    return x.detach()
                raise CheckpointError(
                    "torch.utils.checkpoint: trying to save more tensors during "
                    "recomputation than during the original forward pass."
                )

            holder = target_frame.weak_holders[recomp_idx]()

            # This holder may have been cleared because someone may have called
            # backward within forward. If so, we don't need to save.
            if holder is not None:
                _internal_assert(holder.handles.get(gid, None) is None)
                holder.handles[gid] = _Handle()
                target_frame.recomputed[gid][holder.handles[gid]] = x.detach()

            if target_frame.early_stop and target_frame.recomp_counter[gid] == len(
                target_frame.weak_holders
            ):
                raise _StopRecomputationError()
            # See Rule 6: [ retain_graph is True ] above
            return x.detach()

        def unpack_hook(x):
            # See Rule 6: [ retain_graph is True ] above for an example of when
            # the graph created during recomputation could be backwarded.
            return x

        super().__init__(pack_hook, unpack_hook)


class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self, frame):
        def pack_hook(x):
            # See Rule 4 above
            holder = _Holder()
            frame.weak_holders.append(weakref.ref(holder))
            # Save metadata to detect non-determinism
            if frame.metadata_fn is not None:
                with torch.no_grad():
                    frame.x_metadatas.append(frame.metadata_fn(x))
            return holder

        def unpack_hook(holder):
            gid = torch._C._current_graph_task_id()
            if gid == -1:
                # generate a temporary id if we trigger unpack outside of a backward call
                gid = int(uuid.uuid4())

            if not frame.is_recomputed[gid]:
                ctx = frame.input_saver.grad_fn
                args = ctx.get_args(ctx.saved_tensors)

                try:
                    with _recomputation_hook(
                        weakref.ref(frame), gid
                    ), torch.autograd.enable_grad():
                        frame.recompute_fn(*args)
                except _StopRecomputationError:
                    pass
                frame.is_recomputed[gid] = True
                frame.check_recomputed_tensors_match(gid)

            _internal_assert(gid in holder.handles)

            if holder.handles[gid] is None:
                raise CheckpointError(
                    "torch.utils.checkpoint: Unpack is being triggered for a tensor that was already "
                    "unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do "
                    "so only once. Otherwise please open an issue with details on your use case."
                )
            _internal_assert(holder.handles[gid] in frame.recomputed[gid])
            ret = frame.recomputed[gid][holder.handles[gid]]
            holder.handles[gid] = None
            return ret

        if frame.unpack_error_cb is not None:
            def unpack_hook_with_error_cb(holder):
                try:
                    return unpack_hook(holder)
                except CheckpointError as e:
                    frame.unpack_error_cb(e)
            super().__init__(pack_hook, unpack_hook_with_error_cb)
        else:
            super().__init__(pack_hook, unpack_hook)


def _is_compiling(func, args, kwargs):
    # Check if we are under AOTAutograd tracing
    # There should probably be a better way to do this...
    # TODO: unify _is_compiling across all compile stacks
    for arg in args:
        if isinstance(arg, torch.Tensor) and is_fun(arg):
            return True
    return False


def _detach(x):
    if isinstance(x, torch.Tensor):
        return x.detach()
    return x


uid = count(1)


# NOTE: torch.utils.checkpoint internal logic will call these two functions an unknown number of times
# (i.e. there could be _CachedTorchDispatchMode calls that don't map to a _CachingTorchDispatchMode call),
# so we ignore these ops and just always recompute them.
_ignored_ops = {
    torch.ops.prim.device.default,
    torch.ops.aten.detach.default,
} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)


class _CachingTorchDispatchMode(TorchDispatchMode):
    r"""
    A :class:`TorchDispatchMode` to implement selective activation checkpointing
    that's compatible with torch.compile. Used together with _CachedTorchDispatchMode.
    """
    def __init__(self, policy_fn, storage):
        self.policy_fn = policy_fn
        self.storage = storage

    def push_into_storage(self, out, func, args, kwargs):
        out_detached = tree_map(_detach, out)
        self.storage[func].append(out_detached)

    def _handle_compile_in_forward_ctx(self, should_not_recompute, func, args, kwargs):
        if func in _ignored_ops:
            return func(*args, **kwargs)
        if should_not_recompute:
            fx_traceback.current_meta["recompute"] = 0
        # NOTE: Here we just store and reuse output of all ops, since in torch.compile mode
        # we decide and handle recomputation in the partitioner.
        out = func(*args, **kwargs)
        self.push_into_storage(out, func, args, kwargs)
        return out

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        should_not_recompute = self.policy_fn("forward", func, *args, **kwargs)
        if _is_compiling(func, args, kwargs):
            return self._handle_compile_in_forward_ctx(should_not_recompute, func, args, kwargs)
        else:
            if should_not_recompute:
                out = func(*args, **kwargs)
                self.push_into_storage(out, func, args, kwargs)
            else:
                out = func(*args, **kwargs)
            return out


class _CachedTorchDispatchMode(TorchDispatchMode):
    r"""
    A :class:`TorchDispatchMode` to implement selective activation checkpointing
    that's compatible with torch.compile. Used together with _CachingTorchDispatchMode.
    """
    def __init__(self, policy_fn, storage):
        self.policy_fn = policy_fn
        self.storage = storage

    def pop_from_storage(self, func, args, kwargs):
        assert func in self.storage
        out = self.storage[func].pop(0)
        return out

    def _handle_compile_in_recompute_ctx(self, should_not_recompute, func, args, kwargs):
        if func in _ignored_ops:
            return func(*args, **kwargs)
        out = self.pop_from_storage(func, args, kwargs)
        return out

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        should_not_recompute = self.policy_fn("recompute", func, *args, **kwargs)
        if _is_compiling(func, args, kwargs):
            return self._handle_compile_in_recompute_ctx(should_not_recompute, func, args, kwargs)
        else:
            if should_not_recompute:
                out = self.pop_from_storage(func, args, kwargs)
            else:
                out = func(*args, **kwargs)
            return out


def _pt2_selective_checkpoint_context_fn_gen(policy_fn):
    """
    A helper function that generates a pair of contexts to be later passed into
    the `torch.utils.checkpoint` API to implement selective checkpointing.

    .. warning::
        This context_fn is intended for use with torch.compile only.

    Args:
        policy_fn (Callable[[Callable, List[Any], Dict[str, Any]], bool]): Policy function
            to decide whether a particular op should be recomputed in backward pass or not.
            In eager mode:
                If policy_fn(...) returns True, the op is guaranteed to NOT be recomputed.
                If policy_fn(...) returns False, the op is guaranteed to be recomputed.
            In torch.compile mode:
                If policy_fn(...) returns True, the op is guaranteed to NOT be recomputed.
                If policy_fn(...) returns False, the op may or may not be recomputed
                (it's up to the partitioner to decide).

    Returns:
        A pair of generated contexts.

    Example:
        >>> # xdoctest: +REQUIRES(LINUX)
        >>>
        >>> def get_custom_policy():
        >>>     no_recompute_list = [
        >>>         torch.ops.aten.mm.default,
        >>>     ]
        >>>     def custom_policy(mode, func, *args, **kwargs):
        >>>         return func in no_recompute_list
        >>>     return custom_policy
        >>>
        >>> def selective_checkpointing_context_fn():
        >>>     return _pt2_selective_checkpoint_context_fn_gen(get_custom_policy())
        >>>
        >>> def gn(x, y):
        >>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
        >>>
        >>> def fn(x, y):
        >>>     return torch.utils.checkpoint.checkpoint(
        >>>         gn, x, y,
        >>>         use_reentrant=False,
        >>>         context_fn=selective_checkpointing_context_fn,
        >>>     )
        >>>
        >>> x = torch.randn(4, 4, requires_grad=True)
        >>> y = torch.randn(4, 4, requires_grad=True)
        >>>
        >>> compiled_fn = torch.compile(fn)
    """
    storage: Dict[Any, List[Any]] = defaultdict(list)
    return _CachingTorchDispatchMode(policy_fn, storage), _CachedTorchDispatchMode(policy_fn, storage)


# NB: this helper wraps fn before calling checkpoint_impl. kwargs and
#     saving/restoring of global state is handled here.

def _checkpoint_without_reentrant_generator(
    fn,
    preserve_rng_state=True,
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    determinism_check: str = _DEFAULT_DETERMINISM_MODE,
    debug: bool = False,
    *args,
    **kwargs
):
    """Checkpointing without reentrant autograd.

    Args:
        fn: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`fn` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): if ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
            Default: ``True``
        context_fn(Callable, optional): A callable returning a tuple of two
            context managers. The function and its recomputation will be run
            under the first and second context managers respectively.
        determinism_check(str, optional): A string specifying the determinism
            check to perform. By default it is set to ``"default"`` which
            compares the shapes, dtypes, and devices of the recomputed tensors
            against those of the saved tensors. To turn off this check, specify
            ``"none"``. Currently these are the only two supported values.
            Please open an issue if you would like to see more determinism
            checks.
        debug(bool, optional): If ``True``, error messages will also include
            a trace of the operators ran during the original forward computation
            as well as the recomputation.
        *args: Arguments to pass in to the given ``fn``.
        **kwargs: Keyword arguments to pass into the given ``fn``.
    """
    unpack_error_cb = None

    if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug:
        if context_fn != noop_context_fn:
            raise ValueError(
                "debug=True is incompatible with non-default context_fn"
            )
        context_fn, unpack_error_cb = _get_debug_context_and_cb()

    if determinism_check in _allowed_determinism_checks_to_fns:
        metadata_fn = _allowed_determinism_checks_to_fns[determinism_check]
    else:
        raise ValueError(
            f"determinism_check should be one of {list(_allowed_determinism_checks_to_fns.keys())}, "
            f"but got {determinism_check}"
        )

    device = _infer_device_type(*args)
    device_module = _get_device_module(device)
    forward_context, recompute_context = context_fn()
    if _is_compiling(fn, args, kwargs) and context_fn != noop_context_fn:
        assert (
            isinstance(forward_context, TorchDispatchMode) and
            isinstance(recompute_context, TorchDispatchMode)
        ), \
            "In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` " + \
            "must generate a tuple of two `TorchDispatchMode`s."
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs(device=device)

    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_device_in_fwd = False
        if getattr(device_module, "_initialized", False):
            had_device_in_fwd = True
            fwd_devices, fwd_device_states = get_device_states(*args)

    def recompute_fn(*inputs):
        kwargs, *args = inputs
        # This will be called later during recomputation. This wrapping enables
        # the necessary global state to be captured.
        rng_devices = []
        if preserve_rng_state and had_device_in_fwd:
            rng_devices = fwd_devices
        with torch.random.fork_rng(
            devices=rng_devices, enabled=preserve_rng_state, device_type=device
        ):
            if preserve_rng_state:
                torch.set_rng_state(fwd_cpu_state)
                if had_device_in_fwd:
                    set_device_states(fwd_devices, fwd_device_states)

            device_autocast_ctx = device_module.amp.autocast(
                **device_autocast_kwargs
            ) if _supports_autocast(device) else contextlib.nullcontext()
            with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
                 recompute_context:
                fn(*args, **kwargs)

    new_frame = _CheckpointFrame(
        recompute_fn,
        _enable_checkpoint_early_stop,
        unpack_error_cb,
        metadata_fn
    )
    dummy = torch.empty((0,), requires_grad=True)
    new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args)

    # When ambient grad_mode is False
    if new_frame.input_saver.grad_fn is None:
        yield
        return

    with _checkpoint_hook(new_frame), forward_context:
        yield
    new_frame.forward_completed = True

    if getattr(device_module, "_initialized", False) and \
       preserve_rng_state and not had_device_in_fwd:  # type: ignore[possibly-undefined]
        # Device was not initialized before running the forward, so we didn't
        # stash the device state.
        raise RuntimeError(
            "PyTorch's device state was initialized in the forward pass "
            "of a Checkpoint, which is not allowed. Please open an issue "
            "if you need this feature."
        )

    return