340 lines
12 KiB
Python
340 lines
12 KiB
Python
|
import dataclasses
|
||
|
import traceback
|
||
|
from typing import Any, Callable, Container, Dict, List, Optional, OrderedDict, Tuple, TypeVar, overload
|
||
|
|
||
|
import torch
|
||
|
import torch.distributed as dist
|
||
|
from torch import nn
|
||
|
from torch.nn.parallel._functions import _get_stream
|
||
|
from torch.nn.parallel.scatter_gather import _is_namedtuple
|
||
|
from torch.nn.utils.rnn import PackedSequence
|
||
|
|
||
|
# The helpers defined here are underscore-prefixed (private); an empty
# ``__all__`` means ``from <this module> import *`` exports nothing.
__all__ = []  # type: ignore[var-annotated]
|
||
|
|
||
|
|
||
|
def _pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
|
||
|
"""
|
||
|
Turn argument list into separate key list and value list (unpack_kwargs does the opposite).
|
||
|
|
||
|
Inspiration: https://github.com/facebookresearch/fairscale/blob/eeb6684/fairscale/internal/containers.py#L70
|
||
|
Usage::
|
||
|
|
||
|
kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4)
|
||
|
assert kwarg_keys == ("a", "b")
|
||
|
assert flat_args == (1, 2, 3, 4)
|
||
|
args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
|
||
|
assert args == (1, 2)
|
||
|
assert kwargs == {"a": 3, "b": 4}
|
||
|
Returns:
|
||
|
Tuple[Tuple[Any, ...], Tuple[str, ...]]: The first tuple element gives
|
||
|
gives both positional args and kwarg values, where the positional args
|
||
|
proceed kwarg values and kwarg values are ordered consistently with the
|
||
|
kwarg keys. The second tuple element gives the kwarg keys.
|
||
|
The second tuple element's length is at most the first tuple element's length.
|
||
|
"""
|
||
|
kwarg_keys: List[str] = []
|
||
|
flat_args: List[Any] = list(args)
|
||
|
for k, v in kwargs.items():
|
||
|
kwarg_keys.append(k)
|
||
|
flat_args.append(v)
|
||
|
|
||
|
return tuple(flat_args), tuple(kwarg_keys)
|
||
|
|
||
|
def _cast_forward_inputs(
|
||
|
dtype: Optional[torch.dtype],
|
||
|
*args: Any,
|
||
|
**kwargs: Any,
|
||
|
) -> Tuple[Any, Any]:
|
||
|
"""
|
||
|
Cast floating point tensors in ``args`` and ``kwargs`` to ``input_dtype``.
|
||
|
|
||
|
This respects the existing ``requires_grad`` on the tensors.
|
||
|
"""
|
||
|
if dtype is None:
|
||
|
return args, kwargs
|
||
|
|
||
|
def cast_fn(x: torch.Tensor) -> torch.Tensor:
|
||
|
if not torch.is_floating_point(x) or x.dtype == dtype:
|
||
|
return x
|
||
|
return x.to(dtype)
|
||
|
|
||
|
return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs))
|
||
|
|
||
|
def _unpack_kwargs(flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
|
||
|
"""See _pack_kwargs."""
|
||
|
assert len(kwarg_keys) <= len(
|
||
|
flat_args
|
||
|
), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}"
|
||
|
if len(kwarg_keys) == 0:
|
||
|
return flat_args, {}
|
||
|
args = flat_args[: -len(kwarg_keys)]
|
||
|
kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :]))
|
||
|
return args, kwargs
|
||
|
|
||
|
|
||
|
# Type variables for the ``_recursive_to`` overloads: ``S`` is a container
# input (returned as ``List[S]``) and ``T`` a tensor-like leaf (returned as
# ``Tuple[T]``).
S = TypeVar("S", dict, list, tuple)
T = TypeVar("T", torch.Tensor, PackedSequence)
|
||
|
|
||
|
|
||
|
@overload
def _recursive_to(inputs: S, target_device: torch.device, use_side_stream_for_tensor_copies: bool) -> List[S]:
    ...


@overload
def _recursive_to(inputs: T, target_device: torch.device, use_side_stream_for_tensor_copies: bool) -> Tuple[T]:
    ...


def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies):
    r"""Recursively moves input to the target_device.

    Tensors and ``PackedSequence`` objects found in ``inputs`` (recursing into
    namedtuples, non-empty tuples, lists and dicts) are moved with
    ``.to(target_device)``. Objects already on ``target_device`` and
    non-tensor leaves are passed through unchanged.

    Args:
        inputs: A tensor, ``PackedSequence``, or (possibly nested) container.
        target_device: Device to move tensors to.
        use_side_stream_for_tensor_copies: If ``True`` and the tensor's current
            device is not CPU and has a matching ``torch.<device type>``
            module, the copy is issued on ``_get_stream(target_device)`` and
            the current stream is made to wait on it.

    Returns:
        A 1-tuple ``(moved,)`` for a tensor / ``PackedSequence`` input;
        otherwise a list of moved objects (one element in the common case; an
        empty namedtuple yields an empty list).
    """

    def to_map(obj):
        if isinstance(obj, (torch.Tensor, PackedSequence)):
            # Leaves are returned as 1-tuples so the container branches below
            # can zip over the results of their elements.
            device = obj.data.device if isinstance(obj, PackedSequence) else obj.device
            if device == target_device:
                return (obj,)
            if not use_side_stream_for_tensor_copies:
                return (obj.to(target_device),)
            else:
                # If the custom module is not registered to torch, stream is not used for acceleration
                # NOTE(review): ``device_mod`` comes from the *source* device type
                # but is later used with ``target_device.index``; this presumably
                # assumes source and target share a device type — confirm.
                device_mod = getattr(torch, device.type, None)
                if device.type == "cpu" or device_mod is None:
                    return (obj.to(target_device),)
                # Perform CPU -> target_device copies in a background stream. This code is
                # motivated from similar logic in torch/nn/parallel/_functions.py
                # NOTE(review): CPU-source tensors already returned above, so in
                # practice this path only runs for non-CPU sources — the
                # "CPU -> target_device" wording looks stale; confirm.
                stream = _get_stream(target_device)
                with device_mod.stream(stream):
                    output = obj.to(target_device)
                # synchronize with the copy stream
                with device_mod.device(target_device.index):
                    current_stream = device_mod.current_stream()
                    # Sync the current stream with the copy stream
                    current_stream.wait_stream(stream)
                    # Ensure tensor memory is not reused until work on
                    # main stream is complete
                    if isinstance(obj, PackedSequence):
                        output.data.record_stream(current_stream)  # type: ignore[arg-type]
                    else:
                        assert isinstance(output, torch.Tensor)
                        output.record_stream(current_stream)  # type: ignore[arg-type]
                return (output,)
        # Containers: move every element, then zip the per-element results
        # back together and rebuild the container type.
        # Namedtuples are checked before plain tuples because their
        # constructor takes fields positionally.
        if _is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(to_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(to_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(to_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            # Each (key, value) item is moved as a 2-tuple, then the dict is rebuilt.
            return [type(obj)(i) for i in zip(*map(to_map, obj.items()))]
        # Empty containers and non-tensor leaves are passed through unchanged.
        return [obj]

    # Avoid reference cycle: ``to_map`` refers to itself through its closure
    # cell, so clearing the name lets the closure be freed once we're done.
    try:
        res = to_map(inputs)
    finally:
        to_map = None  # type: ignore[assignment]
    return res
|
||
|
|
||
|
|
||
|
def _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
|
||
|
"""Alternate to ``assert`` when in the backward context to print the error message ``s`` since otherwise, it is swallowed."""
|
||
|
if not cond:
|
||
|
print(s)
|
||
|
traceback.print_stack()
|
||
|
if raise_assertion_error:
|
||
|
raise AssertionError(s)
|
||
|
|
||
|
|
||
|
def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> None:
    """
    Allocate storage for ``tensor`` with the given size.

    If the storage already holds ``size.numel()`` elements this is a no-op.
    Otherwise the storage is expected to be empty (size 0) and is resized to
    ``size.numel()`` elements. Nothing is checked or resized while
    ``torch.distributed._functional_collectives.is_torchdynamo_compiling()``
    returns true.

    Args:
        tensor: Tensor whose underlying storage should be allocated.
        size: Target size; only ``size.numel()`` is used.

    Raises:
        AssertionError: (via ``_p_assert``) if the storage is neither already
            the target size nor empty.
    """
    with torch.no_grad():
        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
            tensor_storage_size = tensor._typed_storage()._size()
            already_allocated = tensor_storage_size == size.numel()
            if not already_allocated:
                _p_assert(
                    tensor_storage_size == 0,
                    # Report the actual size (the previous message contained a
                    # literal placeholder instead of the value).
                    "Tensor storage should have been resized to be 0 but got "
                    f"{tensor_storage_size}",
                )
                tensor._typed_storage()._resize_(size.numel())
|
||
|
|
||
|
|
||
|
def _free_storage(tensor: torch.Tensor) -> None:
    """
    Free the underlying storage of ``tensor`` by resizing it to zero elements.

    This is a no-op if the storage is already empty (size 0), or while
    ``torch.distributed._functional_collectives.is_torchdynamo_compiling()``
    returns true. Only the storage is resized; the tensor's shape is not
    modified.

    Args:
        tensor: Tensor whose storage should be released. It must sit at
            storage offset 0.

    Raises:
        AssertionError: (via ``_p_assert``) if ``tensor`` has a nonzero
            storage offset, i.e. it presumably does not own the whole storage.
    """
    with torch.no_grad():
        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
            already_freed = tensor._typed_storage()._size() == 0
            if not already_freed:
                _p_assert(
                    tensor.storage_offset() == 0,
                    "Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
                    f"storage offset: {tensor.storage_offset()}\n"
                    f"storage size: {tensor._typed_storage()._size()}\n"
                    f"tensor shape: {tensor.shape}",
                )
                tensor._typed_storage()._resize_(0)
|
||
|
|
||
|
|
||
|
|
||
|
# Type variables for the ``_apply_to_tensors`` overloads: ``Q`` is the return
# type of ``fn`` when the container is a bare tensor; ``R`` is the container
# type, which the overload declares is returned unchanged in type.
Q = TypeVar("Q")
R = TypeVar("R", dict, list, tuple, set, OrderedDict, PackedSequence, Any)
|
||
|
|
||
|
|
||
|
@overload
def _apply_to_tensors(fn: Callable[[torch.Tensor], Q], container: torch.Tensor) -> Q:
    ...


@overload
def _apply_to_tensors(fn: Callable[[torch.Tensor], Any], container: R) -> R:
    ...


def _apply_to_tensors(fn, container):
    """Recursively apply ``fn`` to every tensor found in ``container``.

    Handled kinds, checked in this order: tensors, dataclass instances,
    ``OrderedDict``, ``PackedSequence``, other ``dict``s, namedtuples, and
    ``list`` / ``tuple`` / ``set``. Any other object is returned unchanged.

    Args:
        fn: Callable applied to each tensor; its return value replaces the
            tensor in the rebuilt container (except inside a
            ``PackedSequence``, whose data is not replaced).
        container: A tensor or a (possibly nested) container of tensors.

    Returns:
        ``fn(container)`` if ``container`` is a tensor; otherwise a rebuilt
        container with ``fn`` applied to its tensors. ``dict`` subclasses
        other than ``OrderedDict`` come back as plain ``dict``s.
    """

    def apply(x):
        if isinstance(x, torch.Tensor):
            return fn(x)
        elif hasattr(x, "__dataclass_fields__"):
            # Shallow-copy the dataclass, then overwrite each field of the copy.
            # NOTE(review): ``setattr`` raises on a frozen dataclass, and fields
            # declared with ``init=False`` are read from the copy rather than
            # from ``x`` — confirm neither case reaches this branch.
            dc = dataclasses.replace(x)
            for f in dataclasses.fields(dc):
                name = f.name
                setattr(dc, name, apply(getattr(dc, name)))
            return dc
        elif isinstance(x, OrderedDict):
            # Rebuild with the same class so key order and subclass are kept;
            # must precede the plain ``dict`` branch.
            od = x.__class__()
            for key, value in x.items():
                od[key] = apply(value)
            return od
        elif isinstance(x, PackedSequence):
            # NOTE(review): ``fn``'s result on ``x.data`` is discarded and ``x`` is
            # returned as-is, so ``fn`` only acts here through side effects
            # (e.g. a casting ``fn`` leaves the data dtype unchanged) — confirm
            # intended. Checked before the namedtuple branch, presumably because
            # ``PackedSequence`` is namedtuple-based.
            apply(x.data)
            return x
        elif isinstance(x, dict):
            return {key: apply(value) for key, value in x.items()}
        elif _is_namedtuple(x):
            # Namedtuple constructors take fields positionally, so this must
            # precede the generic tuple branch below.
            res = (apply(el) for el in x)
            return type(x)(*res)
        elif isinstance(x, (list, tuple, set)):
            return type(x)(apply(el) for el in x)
        else:
            return x

    return apply(container)
|
||
|
|
||
|
|
||
|
def _to_kwargs(
|
||
|
inputs: Tuple[Any, ...],
|
||
|
kwargs: Optional[Dict[str, Any]],
|
||
|
target_device: torch.device,
|
||
|
use_side_stream_for_tensor_copies: bool,
|
||
|
) -> Tuple[Tuple[Any, ...], Tuple[Dict[str, Any], ...]]:
|
||
|
moved_inputs = (
|
||
|
_recursive_to(inputs, target_device, use_side_stream_for_tensor_copies)
|
||
|
if inputs
|
||
|
else []
|
||
|
)
|
||
|
moved_kwargs = (
|
||
|
_recursive_to(kwargs, target_device, use_side_stream_for_tensor_copies)
|
||
|
if kwargs
|
||
|
else []
|
||
|
)
|
||
|
if len(moved_inputs) < len(moved_kwargs):
|
||
|
moved_inputs.extend([() for _ in range(len(moved_kwargs) - len(inputs))])
|
||
|
elif len(moved_kwargs) < len(moved_inputs):
|
||
|
moved_kwargs.extend([{} for _ in range(len(moved_inputs) - len(moved_kwargs))])
|
||
|
return tuple(moved_inputs), tuple(moved_kwargs)
|
||
|
|
||
|
|
||
|
def _verify_param_shape_across_processes(
    process_group: dist.ProcessGroup,
    tensors: List[torch.Tensor],
    logger: Optional[dist.Logger] = None,
):
    """
    Verify that the shapes of ``tensors`` are consistent across ranks.

    Thin wrapper that forwards its arguments, unchanged, to
    ``dist._verify_params_across_processes`` and returns that call's result.
    """
    verify_fn = dist._verify_params_across_processes
    return verify_fn(process_group, tensors, logger)
|
||
|
|
||
|
|
||
|
def _sync_module_states(
    module: nn.Module,
    process_group: dist.ProcessGroup,
    broadcast_bucket_size: int,
    src: int,
    params_and_buffers_to_ignore: Container[str],
    broadcast_buffers: bool = True,
) -> None:
    """
    Make every rank hold the same parameter (and optionally buffer) values for ``module``.

    The detached parameters of ``module`` come first. When
    ``broadcast_buffers`` is true, its detached buffers follow. Entries whose
    name is in ``params_and_buffers_to_ignore`` are skipped. The resulting
    list is handed to ``_sync_params_and_buffers`` together with
    ``process_group``, ``broadcast_bucket_size`` and ``src``.

    This API assumes parameter shapes already agree across ranks; that can
    be checked beforehand with ``_verify_param_shape_across_processes``.
    """
    named_tensors = list(module.named_parameters())
    if broadcast_buffers:
        named_tensors.extend(module.named_buffers())

    module_states: List[torch.Tensor] = [
        tensor.detach()
        for name, tensor in named_tensors
        if name not in params_and_buffers_to_ignore
    ]
    _sync_params_and_buffers(process_group, module_states, broadcast_bucket_size, src)
|
||
|
|
||
|
|
||
|
def _sync_params_and_buffers(
|
||
|
process_group: dist.ProcessGroup,
|
||
|
module_states: List[torch.Tensor],
|
||
|
broadcast_bucket_size: int,
|
||
|
src: int,
|
||
|
) -> None:
|
||
|
"""Synchronize ``module_states`` (list of tensors) across all processes by broadcasting them from rank 0."""
|
||
|
if len(module_states) > 0:
|
||
|
dist._broadcast_coalesced(
|
||
|
process_group, module_states, broadcast_bucket_size, src
|
||
|
)
|
||
|
|
||
|
|
||
|
def _replace_by_prefix(
|
||
|
state_dict: Dict[str, Any],
|
||
|
old_prefix: str,
|
||
|
new_prefix: str,
|
||
|
) -> None:
|
||
|
"""
|
||
|
Replace all keys that match a given old_prefix with a new_prefix (in-place).
|
||
|
|
||
|
Usage::
|
||
|
|
||
|
state_dict = {"layer.xyz": torch.tensor(1)}
|
||
|
replace_by_prefix_(state_dict, "layer.", "module.layer.")
|
||
|
assert state_dict == {"module.layer.xyz": torch.tensor(1)}
|
||
|
"""
|
||
|
if old_prefix == new_prefix:
|
||
|
raise ValueError("old_prefix and new_prefix must be distinct")
|
||
|
for key in list(state_dict.keys()):
|
||
|
if not key.startswith(old_prefix):
|
||
|
continue
|
||
|
new_key = new_prefix + key[len(old_prefix) :]
|
||
|
state_dict[new_key] = state_dict[key]
|
||
|
del state_dict[key]
|
||
|
|
||
|
|
||
|
def _data_ptr_allocated(tensor: torch.Tensor) -> bool:
|
||
|
return tensor.untyped_storage().data_ptr() > 0
|