321 lines
13 KiB
Python
321 lines
13 KiB
Python
|
# Copyright 2018 The JAX Authors.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
# A ShardingSpec describes at a high level how a logical array is sharded across
|
||
|
# devices (each ShardedDeviceArray has a ShardingSpec, and ShardingSpecs also
|
||
|
# describe how to shard inputs to a parallel computation). spec_to_indices()
|
||
|
# encodes exactly how a given ShardingSpec is translated to device buffers, i.e.
|
||
|
# how the sharded array is "laid out" across devices. Given a sequence of
|
||
|
# devices, we shard the data across the devices in row-major order, with
|
||
|
# replication treated as an extra inner dimension.
|
||
|
#
|
||
|
# For example, given the logical data array [1, 2, 3, 4], if we were to
|
||
|
# partition this array 4 ways with a replication factor of 2, for a total of 8
|
||
|
# devices, the data on each device would be: [1, 1], [2, 2], [3, 3], [4, 4].
|
||
|
#
|
||
|
# This encoding is assumed by various parts of the system, e.g. generating
|
||
|
# replica groups for collective operations.
|
||
|
|
||
|
import collections
|
||
|
import functools
|
||
|
import itertools
|
||
|
import math
|
||
|
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from jax._src import op_shardings
|
||
|
from jax._src import util
|
||
|
from jax._src.lib import pmap_lib
|
||
|
from jax._src.lib import xla_client as xc
|
||
|
|
||
|
# Shadow the builtin `map` with the length-checking safe variant; keep the
# builtin available as `unsafe_map` for call sites that accept uneven lengths.
unsafe_map, map = map, util.safe_map

# Re-export the C++-backed per-dimension sharding descriptors from pmap_lib:
#   NoSharding - dimension is not partitioned
#   Chunked    - dimension is split into a product of chunk counts
#   Unstacked  - dimension is removed and scattered across devices (pmap)
NoSharding = pmap_lib.NoSharding
Chunked = pmap_lib.Chunked
Unstacked = pmap_lib.Unstacked

# Shared singleton used wherever a dimension is left unsharded.
_UNSHARDED_INSTANCE = NoSharding()

# Mesh-mapping entries: each mesh axis either carries one sharded array axis
# (ShardedAxis) or is replicated over (Replicated).
ShardedAxis = pmap_lib.ShardedAxis
Replicated = pmap_lib.Replicated
MeshDimAssignment = Union[ShardedAxis, Replicated]

# The C++-backed spec class itself; Python methods are attached further below.
ShardingSpec = pmap_lib.ShardingSpec

# Alias for xc.OpSharding.Type values; kept as Any to avoid pinning the
# annotation to the xla_client enum.
OpShardingType = Any
|
||
|
|
||
|
|
||
|
|
||
|
def _sharding_spec_mesh_shape(self):
  """Derive the logical device-mesh shape implied by this ShardingSpec."""
  # Collect the per-axis shard counts contributed by each sharded dimension,
  # in sharding order; NoSharding dimensions contribute nothing.
  axis_sizes = []
  for dim_sharding in self.sharding:
    if isinstance(dim_sharding, Unstacked):
      axis_sizes.append(dim_sharding.size)
    elif isinstance(dim_sharding, Chunked):
      axis_sizes.extend(dim_sharding.chunks)
    elif not isinstance(dim_sharding, NoSharding):
      util.assert_unreachable(dim_sharding)

  def mesh_dim_size(assignment):
    # A ShardedAxis contributes the shard count of the axis it carries;
    # a Replicated entry contributes its replication factor.
    if isinstance(assignment, ShardedAxis):
      return axis_sizes[assignment.axis]
    return assignment.replicas

  return tuple(mesh_dim_size(a) for a in self.mesh_mapping)
|
||
|
|
||
|
|
||
|
def get_logical_mesh_ids(mesh_shape):
  """Return device ids 0..N-1 arranged as an ndarray of shape `mesh_shape`."""
  num_devices = math.prod(mesh_shape)
  return np.arange(num_devices).reshape(mesh_shape)
|
||
|
|
||
|
|
||
|
# Internal alias: mesh axis names can be arbitrary hashables (typically str).
_MeshAxisName = Any
|
||
|
|
||
|
def sharding_spec_sharding_proto(
    self, special_axes: Mapping[int, OpShardingType] = {}) -> xc.HloSharding:
  """Converts a ShardingSpec to an OpSharding proto.

  See
  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto#L601
  for details on the OpSharding proto.
  Unfortunately the semantics are not very well described in the proto spec, but
  the code here might help:
  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/compiler/xla/experimental/xla_sharding/xla_sharding.py

  Args:
    special_axes: maps mesh axis indices (which must be replicated in this
      spec) to a non-REPLICATED OpSharding subgroup type to use for that axis.
      NOTE: only read (`.get`/`.keys`), never mutated, so the shared `{}`
      default is safe here.

  Returns:
    An xc.HloSharding equivalent to this ShardingSpec.

  Raises:
    RuntimeError: if the spec contains an Unstacked (pmap-style) dimension,
      which has no OpSharding equivalent.
  """
  mesh_shape = cast(Tuple[int, ...], self.mesh_shape)

  sharded_axes = {}  # maps sharded axis identifiers to mesh axis indices to which they're mapped
  replicated_maxes = []  # lists mesh axis identifiers to replicate over
  for maxis, assignment in enumerate(self.mesh_mapping):
    if isinstance(assignment, Replicated):
      replicated_maxes.append((maxis, assignment.replicas))
    elif isinstance(assignment, ShardedAxis):
      sharded_axes[assignment.axis] = maxis
    else:
      util.assert_unreachable(assignment)

  # Fully replicated and no special subgroup types requested: emit the
  # canonical replicated sharding.
  if len(replicated_maxes) == len(self.mesh_mapping) and not special_axes:
    return xc.HloSharding.replicate()

  # Express the tile assignment as reshape(mesh) + transpose:
  # `new_mesh_shape` gets one entry per array dimension (plus subgroup dims
  # appended below), and `mesh_permutation` lists mesh axes in tiling order.
  mesh_permutation = []
  new_mesh_shape = []
  next_sharded_axis = 0
  for axis, sharding in enumerate(self.sharding):
    if isinstance(sharding, NoSharding):
      new_mesh_shape.append(1)  # Add a dummy mesh axis we won't be sharding over
    elif isinstance(sharding, Chunked):
      for nchunks in sharding.chunks:
        maxis = sharded_axes[next_sharded_axis]
        assert mesh_shape[maxis] == nchunks
        mesh_permutation.append(maxis)
        next_sharded_axis += 1
      # A dimension chunked several times collapses into one tile dimension
      # of the combined chunk count.
      new_mesh_shape.append(math.prod(sharding.chunks))
    elif isinstance(sharding, Unstacked):
      raise RuntimeError("Cannot convert unstacked sharding specs to XLA OpSharding")
    else:
      util.assert_unreachable(sharding)

  # Create a partial sharding proto if tensor is replicated or partitioned
  # specially over some mesh axes.
  last_tile_dims = []
  if replicated_maxes:
    axes_by_type: Dict[OpShardingType, List[_MeshAxisName]] = {}
    size_by_type: Dict[OpShardingType, int] = collections.defaultdict(lambda: 1)
    # Every specially-typed axis must be one of this spec's replicated axes.
    assert {x[0] for x in replicated_maxes}.issuperset(set(special_axes.keys()))
    for axis, size in replicated_maxes:
      ty = special_axes.get(axis, xc.OpSharding.Type.REPLICATED)
      axes_by_type.setdefault(ty, []).append(axis)
      size_by_type[ty] *= size
    # Subgroup dimensions are appended in canonical (enum-value) order.
    for ty, axes in sorted(axes_by_type.items(), key=lambda x: x[0].value):
      last_tile_dims.append(ty)
      new_mesh_shape.append(size_by_type[ty])
      mesh_permutation.extend(axes)

  return xc.HloSharding.iota_tile(
      dims=new_mesh_shape, reshape_dims=mesh_shape,
      transpose_perm=mesh_permutation, subgroup_types=last_tile_dims)
|
||
|
|
||
|
|
||
|
def _sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
  """Returns NumPy-style indices corresponding to a sharding spec.

  Args:
    shape: The shape of the logical array being sharded.

  Returns:
    An ndarray with the same shape as the logical mesh (as derived from
    `mesh_mapping`). Each entry is a NumPy-style index selecting the subset of
    the data array to be placed on a corresponding device. The indices can be
    ints, slice objects with step=1, or tuples of those.
  """
  assert len(shape) == len(self.sharding), (shape, self.sharding)

  has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding)
  # Take the op sharding indices generation route for pjit/xmap cases.
  if not has_unstacked:
    hlo_sharding = sharding_spec_sharding_proto(self)
    return op_shardings.op_sharding_to_numpy_indices(
        hlo_sharding, shape, math.prod(self.mesh_shape)
    ).reshape(self.mesh_shape)

  # Slow path (pmap-style specs with Unstacked dims): build the per-dimension
  # index candidates, then take their Cartesian product.
  axis_indices: List[Sequence[Index]] = []
  shard_indices_shape = []
  for dim, sharding in enumerate(self.sharding):
    axis_size = shape[dim]
    if isinstance(sharding, NoSharding):
      # An unsharded dimension is taken whole on every device.
      axis_indices.append([slice(None)])
      # NOTE: We don't append unsharded dimensions to shard_indices_shape here,
      # because they do not appear in the mesh mapping.
    elif isinstance(sharding, Unstacked):
      assert axis_size == sharding.size, f'{axis_size} != {sharding.size}'
      # Each device gets a single integer index along this dimension.
      axis_indices.append(range(axis_size))
      shard_indices_shape.append(axis_size)
    elif isinstance(sharding, Chunked):
      total_chunks = math.prod(sharding.chunks)
      # The dimension must divide evenly into the chunks.
      shard_size, ragged = divmod(axis_size, total_chunks)
      assert not ragged, (axis_size, total_chunks, dim)
      axis_indices.append([slice(i * shard_size, (i + 1) * shard_size)
                           for i in range(total_chunks)])
      shard_indices_shape.extend(sharding.chunks)
    else:
      util.assert_unreachable(sharding)

  # shard_indices is an ndarray representing the sharded axes of the logical array,
  # with each dimension having size equal to the number of shards across the corresponding
  # logical array dimension, and each element containing the multi-dimensional index that
  # is used to extract the corresponding shard of the logical array.
  shard_indices = np.empty([math.prod(shard_indices_shape)], dtype=np.object_)
  for i, idxs in enumerate(itertools.product(*axis_indices)):
    shard_indices[i] = idxs
  shard_indices = shard_indices.reshape(shard_indices_shape)

  # Ensure that each sharded axis is used exactly once in the mesh mapping
  num_sharded_dim = len(shard_indices_shape)
  sharded_dim_perm = [a.axis for a in self.mesh_mapping if isinstance(a, ShardedAxis)]
  assert (set(sharded_dim_perm) == set(range(num_sharded_dim)) and
          len(sharded_dim_perm) == num_sharded_dim)
  # Replicate/reorder the indices according to the mesh mapping: broadcasting
  # prepends one dimension per Replicated entry, then the transpose interleaves
  # replica and sharded dimensions into mesh_mapping order.
  replica_sizes = tuple(a.replicas for a in self.mesh_mapping if isinstance(a, Replicated))
  replica_dim, sharded_dim = itertools.count(0), iter(sharded_dim_perm)
  perm = [next(replica_dim) if isinstance(a, Replicated) else
          len(replica_sizes) + next(sharded_dim)
          for a in self.mesh_mapping]
  return (np.broadcast_to(shard_indices, replica_sizes + shard_indices.shape)
          .transpose(perm))
|
||
|
|
||
|
def _sharding_spec_repr(self):
|
||
|
return f'ShardingSpec({self.sharding}, {self.mesh_mapping})'
|
||
|
|
||
|
|
||
|
# Attach the Python implementations above to the C++-backed ShardingSpec
# class (which cannot be subclassed or edited in place).
ShardingSpec.mesh_shape = property(_sharding_spec_mesh_shape)
ShardingSpec.sharding_proto = sharding_spec_sharding_proto
ShardingSpec.indices = _sharding_spec_indices
# mypy raises: error: Cannot assign to a method [assignment]
ShardingSpec.__repr__ = _sharding_spec_repr  # type: ignore
|
||
|
|
||
|
|
||
|
# A per-shard index: an int, a step-1 slice, or a tuple of those.
Index = Union[int, slice, Tuple[Union[int, slice], ...]]


def spec_to_indices(shape: Sequence[int],
                    spec: ShardingSpec) -> Tuple[Index, ...]:
  """Compute one NumPy-style index per device shard.

  The returned tuple enumerates shards in the same row-major order as the
  device_buffers of a ShardedDeviceArray, one entry per mesh position
  (including every replica).

  Args:
    shape: the shape of the logical array being sharded.
    spec: describes how the array is sharded and how the shards are assigned
      to the logical mesh.

  Returns:
    A tuple whose length equals the mesh size (product of sharded dimension
    sizes and all replication factors). Each element is an int, a step-1
    slice, or a tuple of those, indexing into the full logical array.
  """
  mesh_indices = spec.indices(shape)
  # `.flat` walks the mesh-shaped index array in row-major order, matching
  # the device order used elsewhere in the system.
  return tuple(mesh_indices.flat)  # type: ignore
|
||
|
|
||
|
|
||
|
def make_sharding_spec(axis_sizes, mesh_axis_pos, num_dimensions, aval_axes):
  """Build a ShardingSpec chunking `aval_axes` dims over their mesh axes.

  Starts from a fully replicated mapping over every mesh axis and, for each
  (resource name -> array axis) entry, chunks that array axis by the
  resource's size and assigns it to the resource's mesh position.
  """
  mesh_mapping = [Replicated(size) for size in axis_sizes.values()]
  sharding = [_UNSHARDED_INSTANCE] * num_dimensions
  # NOTE: sorted is stable, which is important when multiple resources
  # map to the same axis.
  by_array_axis = sorted(aval_axes.items(), key=lambda item: item[1])
  for sharded_axis, (name, axis) in enumerate(by_array_axis):
    prev = sharding[axis]
    # An axis already chunked by an earlier resource gains one more chunk.
    prev_chunks = [] if isinstance(prev, NoSharding) else list(prev.chunks)
    sharding[axis] = Chunked(prev_chunks + [axis_sizes[name]])
    mesh_pos = mesh_axis_pos[name]
    assert isinstance(mesh_mapping[mesh_pos], Replicated), \
        "Value mapped to the same mesh axis twice"
    mesh_mapping[mesh_pos] = ShardedAxis(sharded_axis)
  return ShardingSpec(sharding, mesh_mapping)
|
||
|
|
||
|
|
||
|
def new_mesh_sharding_specs(axis_sizes, axis_names):
  """Return a factory producing ShardingSpecs for a mesh of the given axes.

  The returned callable takes (num_dimensions, aval_axes) and forwards to
  make_sharding_spec with the mesh geometry pre-bound.
  """
  positions = {}
  for index, name in enumerate(axis_names):
    positions[name] = index
  return functools.partial(make_sharding_spec, axis_sizes, positions)
|
||
|
|
||
|
def pmap_sharding_spec(nrep, axis_size, sharded_shape: Sequence[int],
                       map_axis: Optional[int]) -> ShardingSpec:
  """Sharding spec for arguments or results of a pmap.

  Args:
    nrep: number of local XLA replicas (product of local axis sizes)
    axis_size: local axis size for outer pmap
    sharded_shape: the shape of the value inside the outer pmap, i.e. with
      the mapped axis already removed.
    map_axis: the axis along which the value is mapped in the outer pmap,
      or None if the value is not mapped (replicated over the pmapped axis).

  Returns:
    A ShardingSpec.
  """
  # Replicas beyond the outer pmap axis come from inner pmaps; nrep must be
  # an exact multiple of axis_size.
  replication_factor, ragged = divmod(nrep, axis_size)
  assert not ragged
  # Start from a fully-unsharded spec over the inner (already-mapped) shape.
  pspec = ShardingSpec(sharding=[_UNSHARDED_INSTANCE] * len(sharded_shape),
                       mesh_mapping=())
  maybe_replicate = () if replication_factor == 1 else (Replicated(replication_factor),)
  if map_axis is not None:
    # Position (counting sharded dims only) the new Unstacked axis will take.
    sharded_in_axis = sum(not isinstance(s, NoSharding) for s in pspec.sharding[:map_axis])
    def shift_sharded_axis(a: MeshDimAssignment):
      # Sharded axes at or after the insertion point shift up by one.
      if isinstance(a, ShardedAxis) and a.axis >= sharded_in_axis:
        return ShardedAxis(a.axis + 1)
      return a
    # replication_factor represents the product of inner pmaps, so it goes
    # after the outer pmapped axis at index 0
    return ShardingSpec(
      sharding=util.tuple_insert(
          pspec.sharding, map_axis, Unstacked(axis_size)),
      mesh_mapping=itertools.chain(
          [ShardedAxis(sharded_in_axis)], maybe_replicate,
          map(shift_sharded_axis, pspec.mesh_mapping)))
  else:
    # Unmapped value: replicate it over the outer pmapped axis (and any
    # inner-pmap replication).
    return ShardingSpec(
      sharding=pspec.sharding,
      mesh_mapping=(Replicated(axis_size),) + maybe_replicate + pspec.mesh_mapping)
|
||
|
|
||
|
|
||
|
def create_pmap_sharding_spec(shape: Tuple[int, ...], sharded_dim: int = 0,
                              sharded_dim_size: Optional[int] = None):
  """Build a pmap-style ShardingSpec for `shape` mapped along `sharded_dim`.

  When `sharded_dim` is None the value is treated as unmapped and
  `sharded_dim_size` must be provided explicitly.
  """
  if sharded_dim is None:
    # No mapped dimension: caller must say how many ways to replicate.
    assert sharded_dim_size is not None
    sharded_shape = shape
  else:
    # Remove the mapped dimension; infer its size from `shape` if not given.
    sharded_shape = shape[:sharded_dim] + shape[sharded_dim + 1:]
    if sharded_dim_size is None:
      sharded_dim_size = shape[sharded_dim]

  # nrep == axis_size here, i.e. no extra inner-pmap replication factor.
  return pmap_sharding_spec(sharded_dim_size, sharded_dim_size, sharded_shape,
                            sharded_dim)
|