109 lines
3.8 KiB
Python
109 lines
3.8 KiB
Python
|
# Copyright 2023 The JAX Authors.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
"""Sharding utilities"""
|
||
|
|
||
|
import itertools
|
||
|
from typing import List, Sequence, Tuple, Union
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from jax._src.lib import xla_client as xc
|
||
|
|
||
|
|
||
|
def get_num_ways_dim_sharded(
    hlo_sharding: xc.HloSharding) -> Tuple[Sequence[int], int]:
  """Splits a sharding's tile dimensions into partition counts and replicas.

  Args:
    hlo_sharding: the sharding to inspect.

  Returns:
    A pair `(partitions, num_replicas)`. For a replicated sharding this is
    `([], 1)`. Otherwise `partitions` is the tile assignment dimensions; when
    the last tile dimension is a replication dimension it is removed from
    `partitions` and returned as `num_replicas` (which is 1 otherwise).

  Raises:
    NotImplementedError: if the sharding reports subgroup types other than
      exactly `[REPLICATED]`.
  """
  if hlo_sharding.is_replicated():  # type: ignore
    return [], 1

  tile_dims = hlo_sharding.tile_assignment_dimensions()
  subgroups = hlo_sharding.subgroup_types()

  # A lone REPLICATED subgroup means the trailing tile dimension only counts
  # replicas; otherwise defer to the sharding's own flag.
  last_dim_is_replication = (subgroups == [xc.OpSharding.Type.REPLICATED])
  if not last_dim_is_replication:
    last_dim_is_replication = hlo_sharding.replicate_on_last_tile_dim()
    if subgroups:
      # Any other subgroup combination is not supported here.
      raise NotImplementedError(
          "Unhandled OpSharding type. Please open a bug report!")

  if last_dim_is_replication:
    return tile_dims[:-1], tile_dims[-1]
  return tile_dims, 1
|
||
|
|
||
|
|
||
|
def is_op_sharding_replicated(op: Union[xc.OpSharding, xc.HloSharding]) -> bool:
  """Returns True if `op` spans a single device or is a replicated sharding.

  Args:
    op: an `OpSharding` proto or an `HloSharding`; a proto is converted with
      `HloSharding.from_proto` before it is inspected.
  """
  hlo = xc.HloSharding.from_proto(op) if isinstance(op, xc.OpSharding) else op
  # A one-device sharding is treated as replicated without asking further.
  return hlo.num_devices() == 1 or hlo.is_replicated()  # type: ignore
|
||
|
|
||
|
|
||
|
def are_op_shardings_equal(op1: Union[xc.OpSharding, xc.HloSharding],
                           op2: Union[xc.OpSharding, xc.HloSharding]) -> bool:
  """Returns True if the two shardings are considered equal.

  Two shardings are equal when they are the same object, when both are
  replicated (per `is_op_sharding_replicated`), or when their `HloSharding`
  forms compare equal. `OpSharding` protos are converted with
  `HloSharding.from_proto` before the final comparison.
  """
  if op1 is op2:
    return True
  if is_op_sharding_replicated(op1) and is_op_sharding_replicated(op2):
    return True
  lhs, rhs = (
      xc.HloSharding.from_proto(op) if isinstance(op, xc.OpSharding) else op
      for op in (op1, op2))
  return lhs == rhs
|
||
|
|
||
|
|
||
|
# Type of one index entry: an int, a slice, or a tuple of ints/slices. Used to
# annotate the per-axis index lists built in `op_sharding_to_numpy_indices`.
_Index = Union[int, slice, Tuple[Union[int, slice], ...]]
|
||
|
|
||
|
|
||
|
def op_sharding_to_numpy_indices(
    hlo_sharding: xc.HloSharding, shape: Sequence[int],
    num_devices: int) -> np.ndarray:
  """Computes, for every device, the index selecting its shard of an array.

  Args:
    hlo_sharding: the sharding describing how the array is tiled.
    shape: the shape of the array being sharded.
    num_devices: the number of devices; it sizes the returned array.

  Returns:
    An object-dtype array of length `num_devices`. Each entry is a tuple with
    one slice per dimension of `shape`, stored at the position given by the
    sharding's tile assignment devices. For a replicated sharding every entry
    selects the whole array.
  """
  result = np.empty(num_devices, dtype=np.object_)

  # num_devices must be passed in because a REPLICATED hlo_sharding does not
  # carry it. `jax.device_count()` cannot be used instead: an opsharding can
  # be created over fewer devices than `jax.device_count()`.
  if is_op_sharding_replicated(hlo_sharding):
    result.fill((slice(None),) * len(shape))
    return result

  assert num_devices == hlo_sharding.num_devices()

  partitions, num_replicas = get_num_ways_dim_sharded(hlo_sharding)
  assert len(partitions) == len(shape), (len(partitions), len(shape))

  # For each axis, the list of slices that cut it into equal shards.
  per_axis: List[Sequence[_Index]] = []
  for dim, n_shards in zip(shape, partitions):
    if n_shards == 1:
      per_axis.append([slice(None)])
      continue
    if not n_shards > 1:
      raise AssertionError('Unrecognized number of shards. Please file a bug!')
    shard_size, ragged = divmod(dim, n_shards)
    assert not ragged, (dim, n_shards)
    starts = [k * shard_size for k in range(n_shards)]
    per_axis.append([slice(lo, lo + shard_size) for lo in starts])

  # Walk the tiles in product order; each tile goes to the next
  # `num_replicas` devices of the tile assignment.
  device_ids = iter(hlo_sharding.tile_assignment_devices())
  for tile in itertools.product(*per_axis):
    for _ in range(num_replicas):
      device = next(device_ids)
      result[device] = tile
  return result
|
||
|
|
||
|
|
||
|
def op_sharding_to_indices(
    op_sharding: xc.HloSharding, shape: Sequence[int],
    num_devices: int) -> Tuple[Tuple[slice, ...], ...]:
  """Returns the per-device indices of `op_sharding` as a tuple.

  Same result as `op_sharding_to_numpy_indices`, with the array flattened
  into a plain tuple holding one entry per device.
  """
  return tuple(
      op_sharding_to_numpy_indices(op_sharding, shape, num_devices).flat)
|