# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# A ShardingSpec describes at a high level how a logical array is sharded across
# devices (each ShardedDeviceArray has a ShardingSpec, and ShardingSpecs also
# describe how to shard inputs to a parallel computation). spec_to_indices()
# encodes exactly how a given ShardingSpec is translated to device buffers, i.e.
# how the sharded array is "laid out" across devices. Given a sequence of
# devices, we shard the data across the devices in row-major order, with
# replication treated as an extra inner dimension.
#
# For example, given the logical data array [1, 2, 3, 4], if we were to
# partition this array 4 ways with a replication factor of 2, for a total of 8
# devices, the data on each device would be: [1, 1], [2, 2], [3, 3], [4, 4].
#
# This encoding is assumed by various parts of the system, e.g. generating
# replica groups for collective operations.

import collections
import functools
import itertools
import math
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast

import numpy as np

from jax._src import op_shardings
from jax._src import util
from jax._src.lib import pmap_lib
from jax._src.lib import xla_client as xc

unsafe_map, map = map, util.safe_map

NoSharding = pmap_lib.NoSharding
Chunked = pmap_lib.Chunked
Unstacked = pmap_lib.Unstacked

_UNSHARDED_INSTANCE = NoSharding()
ShardedAxis = pmap_lib.ShardedAxis
Replicated = pmap_lib.Replicated
MeshDimAssignment = Union[ShardedAxis, Replicated]

ShardingSpec = pmap_lib.ShardingSpec

OpShardingType = Any


def _sharding_spec_mesh_shape(self):
  sharded_axis_sizes = []
  for sharding in self.sharding:
    if isinstance(sharding, NoSharding):
      continue
    elif isinstance(sharding, Unstacked):
      sharded_axis_sizes.append(sharding.size)
    elif isinstance(sharding, Chunked):
      sharded_axis_sizes.extend(sharding.chunks)
    else:
      util.assert_unreachable(sharding)
  return tuple(sharded_axis_sizes[a.axis] if isinstance(a, ShardedAxis) else a.replicas
               for a in self.mesh_mapping)


def get_logical_mesh_ids(mesh_shape):
  return np.arange(math.prod(mesh_shape)).reshape(mesh_shape)


_MeshAxisName = Any

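# For illustration, a sketch of how `_sharding_spec_mesh_shape` (installed
# further below as the `mesh_shape` property) relates a sharding to its mesh
# mapping. The spec here is hypothetical and exists only as a comment:
#
#   spec = ShardingSpec(
#       sharding=(Chunked([2]), NoSharding()),
#       mesh_mapping=(ShardedAxis(0), Replicated(3)))
#   spec.mesh_shape  # (2, 3): one sharded mesh axis of size 2 (the two chunks
#                    # of array dimension 0) followed by a replicated mesh axis
#                    # of size 3.
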

def sharding_spec_sharding_proto(
    self, special_axes: Mapping[int, OpShardingType] = {}) -> xc.HloSharding:
  """Converts a ShardingSpec to an OpSharding proto.

  See
  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto#L601
  for details on the OpSharding proto.
  Unfortunately the semantics are not very well described in the proto spec, but
  the code here might help:
  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/compiler/xla/experimental/xla_sharding/xla_sharding.py
  """
  mesh_shape = cast(Tuple[int, ...], self.mesh_shape)

  sharded_axes = {}  # maps sharded axis identifiers to mesh axis indices to which they're mapped
  replicated_maxes = []  # lists mesh axis identifiers to replicate over
  for maxis, assignment in enumerate(self.mesh_mapping):
    if isinstance(assignment, Replicated):
      replicated_maxes.append((maxis, assignment.replicas))
    elif isinstance(assignment, ShardedAxis):
      sharded_axes[assignment.axis] = maxis
    else:
      util.assert_unreachable(assignment)

  if len(replicated_maxes) == len(self.mesh_mapping) and not special_axes:
    return xc.HloSharding.replicate()

  mesh_permutation = []
  new_mesh_shape = []
  next_sharded_axis = 0
  for axis, sharding in enumerate(self.sharding):
    if isinstance(sharding, NoSharding):
      new_mesh_shape.append(1)  # Add a dummy mesh axis we won't be sharding over
    elif isinstance(sharding, Chunked):
      for nchunks in sharding.chunks:
        maxis = sharded_axes[next_sharded_axis]
        assert mesh_shape[maxis] == nchunks
        mesh_permutation.append(maxis)
        next_sharded_axis += 1
      new_mesh_shape.append(math.prod(sharding.chunks))
    elif isinstance(sharding, Unstacked):
      raise RuntimeError("Cannot convert unstacked sharding specs to XLA OpSharding")
    else:
      util.assert_unreachable(sharding)

  # Create a partial sharding proto if tensor is replicated or partitioned
  # specially over some mesh axes.
  last_tile_dims = []
  if replicated_maxes:
    axes_by_type: Dict[OpShardingType, List[_MeshAxisName]] = {}
    size_by_type: Dict[OpShardingType, int] = collections.defaultdict(lambda: 1)
    assert {x[0] for x in replicated_maxes}.issuperset(set(special_axes.keys()))
    for axis, size in replicated_maxes:
      ty = special_axes.get(axis, xc.OpSharding.Type.REPLICATED)
      axes_by_type.setdefault(ty, []).append(axis)
      size_by_type[ty] *= size
    for ty, axes in sorted(axes_by_type.items(), key=lambda x: x[0].value):
      last_tile_dims.append(ty)
      new_mesh_shape.append(size_by_type[ty])
      mesh_permutation.extend(axes)

  return xc.HloSharding.iota_tile(
      dims=new_mesh_shape, reshape_dims=mesh_shape,
      transpose_perm=mesh_permutation, subgroup_types=last_tile_dims)

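# For illustration, a sketch of the conversion above on a hypothetical spec
# (shown only as a comment, not exercised by this module): a 2-D array whose
# first dimension is chunked 4 ways, with one replicated mesh axis of size 2.
#
#   spec = ShardingSpec(
#       sharding=(Chunked([4]), NoSharding()),
#       mesh_mapping=(ShardedAxis(0), Replicated(2)))
#   hlo_sharding = sharding_spec_sharding_proto(spec)
#   # The result tiles array dimension 0 across 4 devices, keeps dimension 1
#   # whole, and records the trailing mesh axis of size 2 as a REPLICATED
#   # subgroup, i.e. the partially-replicated form of HloSharding.
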

def _sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
  """Returns NumPy-style indices corresponding to a sharding spec.

  Args:
    shape: The shape of the logical array being sharded.

  Returns:
    An ndarray with the same shape as the logical mesh (as derived from
    `mesh_mapping`). Each entry is a NumPy-style index selecting the subset of
    the data array to be placed on a corresponding device. The indices can be
    ints, slice objects with step=1, or tuples of those.
  """
  assert len(shape) == len(self.sharding), (shape, self.sharding)

  has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding)
  # Take the op sharding indices generation route for pjit/xmap cases.
  if not has_unstacked:
    hlo_sharding = sharding_spec_sharding_proto(self)
    return op_shardings.op_sharding_to_numpy_indices(
        hlo_sharding, shape, math.prod(self.mesh_shape)
    ).reshape(self.mesh_shape)

  axis_indices: List[Sequence[Index]] = []
  shard_indices_shape = []
  for dim, sharding in enumerate(self.sharding):
    axis_size = shape[dim]
    if isinstance(sharding, NoSharding):
      axis_indices.append([slice(None)])
      # NOTE: We don't append unsharded dimensions to shard_indices_shape
      # here, because they do not appear in the mesh mapping.
    elif isinstance(sharding, Unstacked):
      assert axis_size == sharding.size, f'{axis_size} != {sharding.size}'
      axis_indices.append(range(axis_size))
      shard_indices_shape.append(axis_size)
    elif isinstance(sharding, Chunked):
      total_chunks = math.prod(sharding.chunks)
      shard_size, ragged = divmod(axis_size, total_chunks)
      assert not ragged, (axis_size, total_chunks, dim)
      axis_indices.append([slice(i * shard_size, (i + 1) * shard_size)
                           for i in range(total_chunks)])
      shard_indices_shape.extend(sharding.chunks)
    else:
      util.assert_unreachable(sharding)

  # shard_indices is an ndarray representing the sharded axes of the logical
  # array, with each dimension having size equal to the number of shards across
  # the corresponding logical array dimension, and each element containing the
  # multi-dimensional index that is used to extract the corresponding shard of
  # the logical array.
  shard_indices = np.empty([math.prod(shard_indices_shape)], dtype=np.object_)
  for i, idxs in enumerate(itertools.product(*axis_indices)):
    shard_indices[i] = idxs
  shard_indices = shard_indices.reshape(shard_indices_shape)

  # Ensure that each sharded axis is used exactly once in the mesh mapping.
  num_sharded_dim = len(shard_indices_shape)
  sharded_dim_perm = [a.axis for a in self.mesh_mapping if isinstance(a, ShardedAxis)]
  assert (set(sharded_dim_perm) == set(range(num_sharded_dim)) and
          len(sharded_dim_perm) == num_sharded_dim)
  # Replicate/reorder the indices according to the mesh mapping.
  replica_sizes = tuple(a.replicas for a in self.mesh_mapping
                        if isinstance(a, Replicated))
  replica_dim, sharded_dim = itertools.count(0), iter(sharded_dim_perm)
  perm = [next(replica_dim) if isinstance(a, Replicated) else
          len(replica_sizes) + next(sharded_dim)
          for a in self.mesh_mapping]
  return (np.broadcast_to(shard_indices, replica_sizes + shard_indices.shape)
            .transpose(perm))


def _sharding_spec_repr(self):
  return f'ShardingSpec({self.sharding}, {self.mesh_mapping})'


ShardingSpec.mesh_shape = property(_sharding_spec_mesh_shape)
ShardingSpec.sharding_proto = sharding_spec_sharding_proto
ShardingSpec.indices = _sharding_spec_indices
# mypy raises: error: Cannot assign to a method  [assignment]
ShardingSpec.__repr__ = _sharding_spec_repr  # type: ignore


Index = Union[int, slice, Tuple[Union[int, slice], ...]]


def spec_to_indices(shape: Sequence[int],
                    spec: ShardingSpec) -> Tuple[Index, ...]:
  """Returns numpy-style indices corresponding to a sharding spec.

  Each index describes a shard of the array. The order of the indices is the
  same as the device_buffers of a ShardedDeviceArray (i.e. the data is laid out
  row-major).

  Args:
    shape: The shape of the logical array being sharded.
    spec: Describes how the array is sharded and how the shards are assigned to
      the logical mesh.

  Returns:
    A tuple of length equal to the size of the mesh (inferred as the product of
    sharded dimension sizes and all replication factors). Each element is an
    int, a slice object with step=1, or a tuple thereof, to be treated as an
    index into the full logical array.
  """
  return tuple(spec.indices(shape).flat)  # type: ignore

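# For illustration, a sketch of the header comment's example expressed with
# this API (hypothetical spec, comment only): a length-4 array partitioned 4
# ways with a replication factor of 2 across 8 devices.
#
#   spec = ShardingSpec(
#       sharding=(Chunked([4]),),
#       mesh_mapping=(ShardedAxis(0), Replicated(2)))
#   indices = spec_to_indices((4,), spec)
#   # `indices` has 8 entries, in row-major device order: the slice selecting
#   # chunk 0 is assigned to devices 0 and 1, chunk 1 to devices 2 and 3, and
#   # so on, matching the [1, 1], [2, 2], [3, 3], [4, 4] layout described at
#   # the top of this file.
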

def make_sharding_spec(axis_sizes, mesh_axis_pos, num_dimensions, aval_axes):
  mesh_mapping = [Replicated(axis_size) for axis_size in axis_sizes.values()]
  sharding = [_UNSHARDED_INSTANCE] * num_dimensions
  next_sharded_axis = 0
  # NOTE: sorted is stable, which is important when multiple resources
  #       map to the same axis.
  for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]):
    chunked = sharding[axis]
    if isinstance(chunked, NoSharding):
      chunked = Chunked([])
    sharding[axis] = Chunked(list(chunked.chunks) + [axis_sizes[name]])
    assert isinstance(mesh_mapping[mesh_axis_pos[name]], Replicated), \
        "Value mapped to the same mesh axis twice"
    mesh_mapping[mesh_axis_pos[name]] = ShardedAxis(next_sharded_axis)
    next_sharded_axis += 1
  return ShardingSpec(sharding, mesh_mapping)


def new_mesh_sharding_specs(axis_sizes, axis_names):
  mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
  return functools.partial(make_sharding_spec, axis_sizes, mesh_axis_pos)


def pmap_sharding_spec(nrep, axis_size, sharded_shape: Sequence[int],
                       map_axis: Optional[int]) -> ShardingSpec:
  """Sharding spec for arguments or results of a pmap.

  Args:
    nrep: number of local XLA replicas (product of local axis sizes)
    axis_size: local axis size for outer pmap
    sharded_shape: the shape of the value inside the outer pmap
    map_axis: the axis along which the value is mapped in the outer pmap

  Returns:
    A ShardingSpec.
  """
  replication_factor, ragged = divmod(nrep, axis_size)
  assert not ragged
  pspec = ShardingSpec(sharding=[_UNSHARDED_INSTANCE] * len(sharded_shape),
                       mesh_mapping=())
  maybe_replicate = () if replication_factor == 1 else (Replicated(replication_factor),)
  if map_axis is not None:
    sharded_in_axis = sum(not isinstance(s, NoSharding)
                          for s in pspec.sharding[:map_axis])
    def shift_sharded_axis(a: MeshDimAssignment):
      if isinstance(a, ShardedAxis) and a.axis >= sharded_in_axis:
        return ShardedAxis(a.axis + 1)
      return a
    # replication_factor represents the product of inner pmaps, so it goes
    # after the outer pmapped axis at index 0
    return ShardingSpec(
        sharding=util.tuple_insert(pspec.sharding, map_axis, Unstacked(axis_size)),
        mesh_mapping=itertools.chain(
            [ShardedAxis(sharded_in_axis)], maybe_replicate,
            map(shift_sharded_axis, pspec.mesh_mapping)))
  else:
    return ShardingSpec(
        sharding=pspec.sharding,
        mesh_mapping=(Replicated(axis_size),) + maybe_replicate + pspec.mesh_mapping)


def create_pmap_sharding_spec(shape: Tuple[int, ...], sharded_dim: int = 0,
                              sharded_dim_size: Optional[int] = None):
  if sharded_dim is not None:
    sharded_shape = shape[:sharded_dim] + shape[sharded_dim+1:]
    if sharded_dim_size is None:
      sharded_dim_size = shape[sharded_dim]
  else:
    assert sharded_dim_size is not None
    sharded_shape = shape

  return pmap_sharding_spec(sharded_dim_size, sharded_dim_size, sharded_shape,
                            sharded_dim)

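# For illustration, a sketch of how the pmap helpers above compose
# (hypothetical shapes, comment only): an (8, 3) array fed to a pmap over its
# leading axis on 8 local devices.
#
#   spec = create_pmap_sharding_spec((8, 3), sharded_dim=0)
#   # spec.sharding is roughly (Unstacked(8), NoSharding()) and
#   # spec.mesh_mapping is roughly (ShardedAxis(0),): the leading axis is
#   # unstacked across the 8 devices and each device holds one (3,)-shaped
#   # slice of the array.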