349 lines
13 KiB
Python
349 lines
13 KiB
Python
|
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||
|
|
||
|
import dataclasses
|
||
|
from typing import cast, Dict, List, Optional, Sequence, Tuple, Union
|
||
|
|
||
|
import torch
|
||
|
import torch.distributed as dist
|
||
|
from torch._utils import _get_device_module
|
||
|
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
|
||
|
from torch.distributed._shard.sharded_tensor.metadata import (
|
||
|
TensorProperties as ShardTensorProperties,
|
||
|
)
|
||
|
from torch.distributed._shard.sharded_tensor.shard import Shard
|
||
|
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
|
||
|
from torch.distributed._tensor import DTensor
|
||
|
from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
|
||
|
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
|
||
|
from torch.distributed.checkpoint.metadata import (
|
||
|
BytesStorageMetadata,
|
||
|
ChunkStorageMetadata,
|
||
|
Metadata,
|
||
|
MetadataIndex,
|
||
|
STATE_DICT_TYPE,
|
||
|
TensorProperties,
|
||
|
TensorStorageMetadata,
|
||
|
)
|
||
|
from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner
|
||
|
from torch.distributed.checkpoint.planner_helpers import (
|
||
|
_create_read_items,
|
||
|
create_read_items_for_chunk_list,
|
||
|
)
|
||
|
from torch.distributed.checkpoint.state_dict_loader import load_state_dict
|
||
|
from torch.distributed.checkpoint.storage import StorageReader
|
||
|
from torch.distributed.checkpoint.utils import (
|
||
|
_element_wise_add,
|
||
|
_element_wise_sub,
|
||
|
_normalize_device_info,
|
||
|
)
|
||
|
from torch.distributed.distributed_c10d import _get_default_group
|
||
|
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
|
||
|
from torch.distributed.remote_device import _remote_device
|
||
|
|
||
|
# Per-key ``(offset, size)`` describing this rank's slice of a model tensor;
# ``offset`` is ``None`` for entries that are not nested ShardedTensors.
STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]]


# TODO: Update docstrings for optimizer.py
__all__ = ["load_sharded_optimizer_state_dict"]
|
||
|
|
||
|
|
||
|
def _gen_rank_device(global_rank: int, device_type: str = "cuda") -> str:
|
||
|
if device_type == "cpu":
|
||
|
return "cpu"
|
||
|
device_module = _get_device_module(device_type)
|
||
|
if device_module.is_available():
|
||
|
return _normalize_device_info(
|
||
|
device_type, global_rank % device_module.device_count()
|
||
|
)
|
||
|
return "cpu"
|
||
|
|
||
|
|
||
|
def _create_colwise_spec(
    pg: Optional[dist.ProcessGroup] = None,
) -> ChunkShardingSpec:
    """Build a dim-0 ``ChunkShardingSpec`` with one placement per rank of ``pg``.

    The ``rank:`` index of each placement is group-local, while its device is
    chosen from the corresponding *global* rank. With ``pg=None`` the default
    group is used and the local index doubles as the global rank.
    """
    device_type = dist.distributed_c10d._get_pg_default_device(pg).type

    if pg is None:
        global_ranks = list(range(dist.get_world_size()))
    else:
        global_ranks = [dist.get_global_rank(pg, i) for i in range(pg.size())]

    placements = [
        f"rank:{local_rank}/{_gen_rank_device(global_rank, device_type)}"
        for local_rank, global_rank in enumerate(global_ranks)
    ]
    return ChunkShardingSpec(
        dim=0,
        placements=cast(List[Union[_remote_device, str]], placements),
    )
|
||
|
|
||
|
|
||
|
def _is_nested_tensor(val: torch.Tensor) -> bool:
    """Report whether ``val`` is a ShardedTensor whose first local shard is a ShardedTensor.

    Only the exact types are checked (no subclasses). A ShardedTensor with no
    local shards is not considered nested.

    Raises:
        ValueError: if a DTensor is nested inside a ShardedTensor, or a DTensor
            wraps a DTensor or ShardedTensor; neither layout is supported.
    """
    outer_type = type(val)

    if outer_type is DTensor:
        inner_type = type(val._local_tensor)
        if inner_type is DTensor or inner_type is ShardedTensor:
            raise ValueError("Cannot handle nested DTensor")
        return False

    if outer_type is not ShardedTensor:
        return False

    shards = val.local_shards()
    if not shards:
        return False

    inner_type = type(shards[0].tensor)
    if inner_type is ShardedTensor:
        return True
    if inner_type is DTensor:
        raise ValueError("Cannot handle DTensor nested insided ShardedTensor")
    return False
|
||
|
|
||
|
|
||
|
def _alloc_tensor(
|
||
|
props: TensorProperties, size: Sequence[int], device_type: str = "cuda"
|
||
|
) -> torch.Tensor:
|
||
|
return torch.empty(
|
||
|
size=size,
|
||
|
dtype=props.dtype,
|
||
|
layout=props.layout,
|
||
|
requires_grad=props.requires_grad,
|
||
|
pin_memory=props.pin_memory,
|
||
|
device=cast(torch.device, _get_device_module(device_type).current_device()),
|
||
|
)
|
||
|
|
||
|
|
||
|
def _get_state_dict_2d_layout(
    state_dict: STATE_DICT_TYPE,
) -> Tuple[STATE_DICT_2D_LAYOUT, Optional[dist.ProcessGroup]]:
    """Compute this rank's tensor-parallel slice for every entry of a model state_dict.

    The per-tensor TP slicing cannot be inferred from checkpoint metadata, so it
    is recovered from the model state_dict: for a nested ShardedTensor (a
    ShardedTensor whose single local shard is itself a ShardedTensor) the outer
    local shard's offsets and sizes describe the slice. This is fragile, and it
    might be easier for FSDP to compute this info for us.

    N.B. The state_dict *MUST* come from FSDP.sharded_state_dict.

    Returns:
        ``(specs, dp_pg)``. ``specs`` maps each key of ``state_dict`` to
        ``(offset, size)``: the local shard's offsets/sizes for nested
        ShardedTensors, otherwise ``(None, value.size())``. ``dp_pg`` is the
        process group of the inner ShardedTensor of the last nested entry
        encountered, or ``None`` if there was none.
    """
    layout: STATE_DICT_2D_LAYOUT = {}
    data_parallel_pg: Optional[dist.ProcessGroup] = None

    for fqn, tensor in state_dict.items():
        full_size = tensor.size()
        if not _is_nested_tensor(tensor):
            layout[fqn] = (None, full_size)
            continue

        assert len(tensor.local_shards()) == 1, "Cannot handle ST with multiple shards"
        assert isinstance(tensor, ShardedTensor), "Can only handle nested ShardedTensor"
        local_shard = tensor.local_shards()[0]
        shard_md = local_shard.metadata
        layout[fqn] = (shard_md.shard_offsets, shard_md.shard_sizes)
        data_parallel_pg = local_shard.tensor._process_group  # type: ignore[attr-defined]

    return layout, data_parallel_pg
|
||
|
|
||
|
|
||
|
class _ReaderWithOffset(DefaultLoadPlanner):
|
||
|
translation: Dict[MetadataIndex, MetadataIndex]
|
||
|
state_dict: STATE_DICT_TYPE
|
||
|
metadata: Metadata
|
||
|
|
||
|
def __init__(self, fqn_to_offset: Dict[str, Sequence[int]]) -> None:
|
||
|
super().__init__()
|
||
|
self.fqn_to_offset = fqn_to_offset
|
||
|
self.metadata = Metadata({})
|
||
|
self.state_dict = {}
|
||
|
self.translation = {}
|
||
|
|
||
|
def create_local_plan(self) -> LoadPlan:
|
||
|
requests = []
|
||
|
self.translation = {}
|
||
|
for fqn, obj in self.state_dict.items():
|
||
|
md = self.metadata.state_dict_metadata[fqn]
|
||
|
if not isinstance(obj, ShardedTensor):
|
||
|
requests += _create_read_items(fqn, md, obj)
|
||
|
continue
|
||
|
|
||
|
if fqn not in self.fqn_to_offset:
|
||
|
requests += _create_read_items(fqn, md, obj)
|
||
|
continue
|
||
|
|
||
|
offset = self.fqn_to_offset[fqn]
|
||
|
|
||
|
assert len(obj.local_shards()) == 1
|
||
|
original_shard = obj.local_shards()[0]
|
||
|
local_chunks = [
|
||
|
ChunkStorageMetadata(
|
||
|
offsets=torch.Size(
|
||
|
_element_wise_add(original_shard.metadata.shard_offsets, offset)
|
||
|
),
|
||
|
sizes=torch.Size(original_shard.metadata.shard_sizes),
|
||
|
)
|
||
|
]
|
||
|
|
||
|
reqs = create_read_items_for_chunk_list(
|
||
|
fqn, cast(TensorStorageMetadata, md), local_chunks
|
||
|
)
|
||
|
# TODO: The ReadItems will have a displaced MetadataIndex, fix it.
|
||
|
# TODO: we should change _create_sharded_read_items to have more ergonomic API
|
||
|
for ri in reqs:
|
||
|
assert ri.dest_index.offset is not None
|
||
|
original_offset = _element_wise_sub(ri.dest_index.offset, offset)
|
||
|
original_index = dataclasses.replace(
|
||
|
ri.dest_index, offset=torch.Size(original_offset)
|
||
|
)
|
||
|
self.translation[ri.dest_index] = original_index
|
||
|
|
||
|
requests += reqs
|
||
|
return LoadPlan(requests)
|
||
|
|
||
|
def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
|
||
|
return super().lookup_tensor(self.translation.get(index, index))
|
||
|
|
||
|
|
||
|
def load_sharded_optimizer_state_dict(
    model_state_dict: STATE_DICT_TYPE,
    optimizer_key: str,
    storage_reader: StorageReader,
    planner: Optional[LoadPlanner] = None,
) -> STATE_DICT_TYPE:
    """
    Load a state_dict in conjunction with FSDP sharded optimizer state.

    This is the current recommended way to checkpoint FSDP.

    Only checkpoint entries whose top-level key is ``optimizer_key`` are loaded.
    Storage is allocated to match this rank's shard: single-element tensors are
    allocated whole, other tensors become ShardedTensors over the default group
    (when ``model_state_dict`` has no nested ShardedTensor) or over the
    data-parallel group of the nested ShardedTensors (2D / TP case), in which
    case only this rank's TP slice is read.

    Args:
        model_state_dict: model state_dict, which must come from
            ``FSDP.sharded_state_dict``; it is used to infer the TP slice and
            the data-parallel process group.
        optimizer_key: top-level key under which the optimizer state was saved.
        storage_reader: reader for the checkpoint.
        planner: load planner to use when ``model_state_dict`` contains no
            nested ShardedTensor; ignored otherwise (an internal offset-aware
            planner is used instead).

    Returns:
        The loaded optimizer entries, unflattened using the checkpoint's
        ``planner_data`` (so the result is keyed by ``optimizer_key``).

    Example:
        >>> # xdoctest: +SKIP
        >>> import torch.distributed.checkpoint as dist_cp
        >>> # Save
        >>> model: torch.nn.Module
        >>> optim_params = model.parameters()
        >>> optim = torch.optim.SGD(optim_params, lr=0.01)
        >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        >>>     state_dict = {
        >>>         "optimizer": FSDP.optim_state_dict(model, optim),
        >>>         "model": model.state_dict()
        >>>     }
        >>>     dist_cp.save_state_dict(
        >>>         state_dict=state_dict,
        >>>         storage_writer=dist_cp.FileSystemWriter("checkpoint"),
        >>>         planner=dist_cp.DefaultSavePlanner(),
        >>>     )
        >>>
        >>> # Load
        >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        >>>     model_state_dict = model.state_dict()
        >>>     checkpoint = {
        >>>         "model": model_state_dict
        >>>     }
        >>>     dist_cp.load_state_dict(
        >>>         state_dict=checkpoint,
        >>>         storage_reader=dist_cp.FileSystemReader("checkpoint"),
        >>>         planner=dist_cp.DefaultLoadPlanner(),
        >>>     )
        >>>     model.load_state_dict(checkpoint["model"])
        >>>
        >>>     optim_state = dist_cp.load_sharded_optimizer_state_dict(
        >>>         model_state_dict,
        >>>         optimizer_key="optimizer",
        >>>         storage_reader=dist_cp.FileSystemReader("checkpoint"),
        >>>     )
        >>>
        >>>     flattened_osd = FSDP.optim_state_dict_to_load(
        >>>         model, optim, optim_state["optimizer"]
        >>>     )
        >>>
        >>>     optim.load_state_dict(flattened_osd)
    """
    metadata = storage_reader.read_metadata()

    layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict)
    dp_pg_device_type = dist.distributed_c10d._get_pg_default_device(dp_pg).type
    device_module = _get_device_module(dp_pg_device_type)

    # The column-wise spec and the group-local rank are only used by the 2D
    # branch below (dp_pg is not None); the other branch shards through
    # _create_chunk_sharded_tensor and needs neither.
    sharding_spec: Optional[ChunkShardingSpec] = None
    dp_rank = -1
    if dp_pg is not None:
        sharding_spec = _create_colwise_spec(dp_pg)
        dp_rank = dist.get_rank(dp_pg)

    # Create a state_dict for optimizer state
    state_dict: STATE_DICT_TYPE = {}
    fqn_to_offset: Dict[str, Sequence[int]] = {}

    for key, value in metadata.state_dict_metadata.items():
        key_path = metadata.planner_data[key]
        if key_path[0] != optimizer_key:
            continue

        if isinstance(value, BytesStorageMetadata):
            # Placeholder; the load planner is expected to replace it with the
            # deserialized object.
            state_dict[key] = "<bytes_io>"
            continue

        # value: TensorStorageMetadata
        if value.size.numel() == 1:
            # Single-element tensors are allocated in full on every rank.
            state_dict[key] = _alloc_tensor(
                value.properties, value.size, dp_pg_device_type
            )
        elif dp_pg is None:
            # No nested ShardedTensor in the model: shard over the default group.
            state_dict[key] = _create_chunk_sharded_tensor(
                _alloc_tensor(value.properties, value.size, dp_pg_device_type),
                rank=dist.get_rank(),
                world_size=dist.get_world_size(),
                num_devices_per_node=device_module.device_count(),
                pg=_get_default_group(),
            )
        else:
            assert sharding_spec is not None
            # NOTE(review): assumes key_path[2] is the model parameter fqn the
            # optimizer entry belongs to (optimizer_key, "state", fqn, ...).
            spec_key = key_path[2]
            # Allocate only this rank's TP slice when the layout is known.
            alloc_size = layout_specs.get(spec_key, (None, value.size))[1]

            properties = ShardTensorProperties(
                dtype=value.properties.dtype,
                layout=value.properties.layout,
                requires_grad=value.properties.requires_grad,
                memory_format=value.properties.memory_format,
                pin_memory=value.properties.pin_memory,
            )

            st_md = sharding_spec.build_metadata(torch.Size(alloc_size), properties)
            # Placements use group-local ranks, so compare against dp_rank.
            local_shards = [
                Shard(
                    tensor=_alloc_tensor(
                        value.properties, shard_md.shard_sizes, dp_pg_device_type
                    ),
                    metadata=shard_md,
                )
                for shard_md in st_md.shards_metadata
                if cast(_remote_device, shard_md.placement).rank() == dp_rank
            ]

            st = ShardedTensor._init_from_local_shards_and_global_metadata(
                local_shards, st_md, process_group=dp_pg
            )

            # Record the TP-slice offset so _ReaderWithOffset reads the right region.
            if spec_key in layout_specs and layout_specs[spec_key][0] is not None:
                fqn_to_offset[key] = cast(Sequence[int], layout_specs[spec_key][0])

            state_dict[key] = st

    # Whether we unflatten before or after doesn't matter
    load_state_dict(
        state_dict=state_dict,
        storage_reader=storage_reader,
        # FIXME the type of planner is wrong in load_state_dict
        planner=_ReaderWithOffset(fqn_to_offset) if dp_pg is not None else planner,
    )

    return unflatten_state_dict(state_dict, metadata.planner_data)
|