71 lines
2.2 KiB
Python
71 lines
2.2 KiB
Python
|
from typing import List, Tuple
|
||
|
|
||
|
from torch.distributed.checkpoint.metadata import ChunkStorageMetadata
|
||
|
|
||
|
# Export nothing via ``from <module> import *``: the helpers defined in this
# module are private (underscore-prefixed).
__all__: List[str] = []
|
||
|
|
||
|
|
||
|
def _check_shard_metadata_pair_overlap(
    shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata
) -> bool:
    """Check if two shards overlap.

    Along every dimension a shard covers the half-open range
    ``[offset, offset + size)``. Two shards overlap only if those ranges
    intersect on *every* dimension. A single dimension on which one shard
    ends at or before the other begins (e.g. in 2D, one shard lies entirely
    above or to the left of the other) is enough to separate them.

    Args:
        shard1: First chunk; only its ``offsets`` and ``sizes`` are read.
        shard2: Second chunk; only its ``offsets`` and ``sizes`` are read.

    Returns:
        ``True`` if the ranges intersect on all dimensions (vacuously ``True``
        for zero-dimensional shards), ``False`` otherwise.

    Raises:
        ValueError: If ``shard1.offsets``, ``shard1.sizes``, ``shard2.offsets``
            and ``shard2.sizes`` do not all have the same length.
    """
    ndims = len(shard1.offsets)
    # Reject rank mismatches explicitly. Otherwise, depending on which side is
    # shorter, this would either raise IndexError or silently compare only a
    # prefix of the dimensions.
    if not (
        len(shard1.sizes) == ndims
        and len(shard2.offsets) == ndims
        and len(shard2.sizes) == ndims
    ):
        raise ValueError(
            "Shards must have the same number of dimensions, got "
            f"offsets/sizes lengths {len(shard1.offsets)}/{len(shard1.sizes)} "
            f"and {len(shard2.offsets)}/{len(shard2.sizes)}"
        )
    # NOTE(review): a zero-size shard whose offset lies strictly inside the
    # other shard's range on every dimension is reported as overlapping even
    # though it holds no elements; confirm whether callers rely on this.
    return all(
        start1 < start2 + size2 and start2 < start1 + size1
        for start1, size1, start2, size2 in zip(
            shard1.offsets, shard1.sizes, shard2.offsets, shard2.sizes
        )
    )
|
||
|
|
||
|
|
||
|
def _shards_get_overlap_region_wrt_saved_tensor(
    saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata
) -> List[Tuple[int, int, int, int]]:
    """
    Describe, per dimension, where ``saved_shard`` and ``current_shard`` intersect.

    The returned list has one tuple per tensor dimension, in dimension order:

        (dimension, offset into `saved_shard`, offset into `current_shard`, length)

    Both offsets are local to their own shard, i.e. measured from that shard's
    starting coordinate rather than from the origin of the full tensor.

    Callers presumably pass shards that overlap. For a disjoint pair, the length
    reported along a separating dimension is zero or negative.
    """
    region = []
    per_dim = zip(
        saved_shard.offsets,
        current_shard.offsets,
        saved_shard.sizes,
        current_shard.sizes,
    )
    for dim, (saved_lo, current_lo, saved_len, current_len) in enumerate(per_dim):
        # The intersection begins at the later of the two starts and stops at
        # the earlier of the two ends.
        overlap_lo = max(current_lo, saved_lo)
        overlap_hi = min(saved_lo + saved_len, current_lo + current_len)
        region.append(
            (
                dim,
                overlap_lo - saved_lo,
                overlap_lo - current_lo,
                overlap_hi - overlap_lo,
            )
        )
    return region
|