Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/experimental/mesh_utils.py
2023-06-19 00:49:18 +02:00

331 lines
14 KiB
Python

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for building a device mesh."""
import collections
import itertools
import logging
from typing import Any, Dict, List, Optional, Sequence, Tuple
import jax
import numpy as np
logger = logging.getLogger(__name__)

# `device_kind` strings of TPU generations that get special handling when
# building meshes (they are compared against `device.device_kind`).
_TPU_V2 = 'TPU v2'
_TPU_V3 = 'TPU v3'
_TPU_V4 = 'TPU v4'

# Physical topology -> logical mesh shape -> axis permutation used for
# jekbradbury's famous contiguous mesh trick.
#
# The trick only works for the topologies and mesh shapes listed here. Mesh
# shapes may additionally contain trivial dims of size 1; `_transpose_trick`
# strips those before looking the shape up.
_TRANSPOSE_TRICKS: Dict[Tuple[int, ...],
                        Dict[Tuple[int, ...], Tuple[int, ...]]] = {
    (2, 2, 1): {(2, 2): (0, 1, 2)},
    (2, 2, 4): {(4, 4): (0, 1, 2)},
    (4, 4, 4): {(16, 4): (0, 2, 1)},
    (4, 8, 8): {(64, 4): (0, 2, 1), (4, 64): (0, 2, 1)},
    (8, 8, 8): {(64, 8): (0, 2, 1)},
    (8, 16, 16): {(256, 8): (0, 2, 1), (8, 256): (0, 2, 1)},
}

# Physical ordering of core IDs within a tray that forms a ring.
_TRAY_RING_ORDER = (0, 1, 2, 3, 6, 7, 4, 5)
def _create_device_mesh_for_nd_torus(
    physical_mesh: np.ndarray, mesh_shape: Sequence[int],
) -> Tuple[np.ndarray, List[Tuple[int, ...]]]:
  """Assigns logical parallelism axes to physical axes of an N-D torus network.

  Given logical parallelism axes with sizes in `mesh_shape` and devices in an
  N-dimensional torus network represented by `physical_mesh`, maps each logical
  axis to one or more physical axes. More-performance-sensitive logical axes
  (later in `mesh_shape`) are assigned first and preferentially mapped to
  larger numbers of physical axes to maximize the bandwidth available to them.
  Among subsets with the same number of axes, the first match in
  `itertools.combinations` order is taken; there is currently no explicit
  preference for equal-sized (e.g. square) subsets.

  Note that this routine will never split a physical axis over more than one
  logical axis (which would reduce total usable bandwidth but may sometimes be
  desired anyway). As a result, it will error out in cases where this is
  necessary to produce a valid mapping.

  Example: suppose the logical mesh is [data, model], for data and model
  parallelism respectively, and data parallelism is less performance
  sensitive than model parallelism. Consider a 3D TPU pod slice of shape
  4x4x16, represented by a physical mesh of shape (4, 4, 16). A TPU pod slice
  has equal bandwidth along all axes with wraparound links, but a 2D plane of
  size 4x4 may have faster XLA collective implementations than a non-square
  plane or a 1D subgroup. If the mesh_shape is [16, 16], we may want the more
  performance sensitive `model` axis to be mapped to the 4x4 XY plane.

  Args:
    physical_mesh: a np.ndarray of devices in the shape of the N-D torus
      physical topology.
    mesh_shape: shape of the logical mesh (size of the various logical
      parallelism axes), with axes ordered by increasing network intensity.

  Returns:
    A tuple of:
      - An np.ndarray of devices in the shape of the logical mesh
        (mesh_shape), with each logical parallelism axis mapped to one or more
        physical mesh axes.
      - The axis assignment (a list of length num_logical_axes, whose elements
        are tuples of physical axis indices; `()` for an unassigned size-1
        logical axis).

  Raises:
    NotImplementedError: if a logical axis of size > 1 is not the product of
      some subset of the still-unassigned physical axis sizes.
  """
  # Remaining physical axes to be assigned to logical axes. Assigned axes are
  # zeroed so any candidate subset containing them has product 0 and can never
  # match a (positive) logical axis size again.
  assignable_physical_mesh = list(physical_mesh.shape)
  num_physical_axes = len(assignable_physical_mesh)
  # Map each logical axis to a subset of physical axes.
  assignment: List[Tuple[int, ...]] = [() for _ in mesh_shape]
  # Assign logical axes from highest network intensity to lowest.
  # `mesh_shape` is assumed to be ordered by lowest network intensity first,
  # so walk it in reverse.
  for logical_axis_index, logical_axis_size in reversed(
      list(enumerate(mesh_shape))):
    # Preferentially map to more physical axes first for higher bandwidth.
    # Start from *all* physical axes (not a fixed 3) so meshes with more than
    # three physical axes can still be mapped; for <= 3 axes this is the same
    # search as before.
    for num_axes in range(num_physical_axes, 0, -1):
      # Try to assign to any subset of size num_axes; take the first match.
      candidates = zip(
          itertools.combinations(assignable_physical_mesh, num_axes),
          itertools.combinations(range(num_physical_axes), num_axes))
      for c_axes, c_indices in candidates:
        # TODO(zhangqiaorjc): Due to limitations in XLA, 2D collectives only
        # implemented for square 2D plane. Mapping a physical axis to two
        # logical axes might be slower for non-square 2D plane, e.g., map 32 to
        # 4x8 or a single axis. If XLA 2D collectives support non-square plane
        # soon, we can continue to preferentially map to 2D plane in general,
        # otherwise, we should treat non-square 2D plane and 1D submesh equally.
        if np.prod(c_axes) == logical_axis_size:
          assignment[logical_axis_index] = c_indices
          # Zero the assigned physical axes.
          assignable_physical_mesh = [
              0 if i in c_indices else v
              for i, v in enumerate(assignable_physical_mesh)
          ]
          break
      if assignment[logical_axis_index]:
        # We already found an assignment from one candidate above.
        break
    else:
      # Python for-else: reached only when the num_axes loop did not break,
      # i.e. no candidate subset matched. A size-1 logical axis needs no
      # physical axis, so only larger sizes are an error.
      if logical_axis_size > 1:
        raise NotImplementedError(
            'Failed to find assignment for logical_axis_index'
            f' {logical_axis_index} of size {logical_axis_size} with remaining'
            f' assignable mesh {assignable_physical_mesh}. The size of each'
            ' axis in your logical mesh must be equal to the product of'
            ' some subset of the physical mesh axis sizes. E.g logical mesh (4,'
            ' 16) is compatible with physical mesh 4x4x4 since 4=4 and 16=4x4.'
        )
  # Flatten the assignment, e.g., [(), (2,), (0, 1)] -> [2, 0, 1].
  transpose = [int(axis) for axes in assignment for axis in axes]
  # `ndarray.transpose` requires a full permutation of the physical axes.
  # Append any axis no logical axis claimed (these must have size 1 for the
  # final reshape to succeed) instead of failing inside numpy.
  transpose += [i for i in range(num_physical_axes) if i not in transpose]
  return physical_mesh.transpose(transpose).reshape(mesh_shape), assignment
def _bounds_from_last_device(last_device) -> Sequence[int]:
  """Computes mesh bounds from the device at the highest-coordinate corner.

  The caller must pass the device at the highest-coordinate corner of the
  relevant mesh; the original authors note this holds for the last device in
  `jax.devices()`.

  Args:
    last_device: a device exposing a 3-element `coords` and `core_on_chip`.

  Returns:
    The tuple `(x + 1, y + 1, z + 1, core_on_chip + 1)`.
  """
  assert hasattr(last_device, 'coords'), 'Only TPU supported'
  x, y, z = last_device.coords
  spatial_bounds = (x + 1, y + 1, z + 1)
  return spatial_bounds + (last_device.core_on_chip + 1,)
def _get_physical_tpu_mesh(jax_devices: Sequence[Any]) -> np.ndarray:
  r"""Rearrange TPU devices in a slice into a physical mesh.

  Mesh extents are computed per axis from the devices' `coords`, and device
  positions are taken relative to the per-axis minimum coordinate, so a device
  set that does not start at the origin still fills the array densely.

  Args:
    jax_devices: A list of JAX devices in a TPU slice in process-tiled z, y, x,
      core order, e.g. from jax.devices().

  Returns:
    A np.ndarray of JAX devices with shape [global_x, global_y, global_z]. On
    v2 and v3, global_z is instead cores_per_chip (i.e., 2).

  Raises:
    AssertionError: if coordinates are not 3-dimensional, if a v2/v3 device has
      a non-zero z coordinate, or if a newer TPU has more than one core per
      chip.
  """
  device_kind = jax_devices[0].device_kind
  device_coords = [d.coords for d in jax_devices]
  # Per-axis bounds. Note: `max(device_coords)` would compare coordinate
  # tuples lexicographically and can under-estimate an axis extent (e.g.
  # max([(0, 1, 0), (1, 0, 0)]) == (1, 0, 0) hides y == 1).
  per_axis = list(zip(*device_coords))
  min_coords = tuple(min(axis) for axis in per_axis)
  max_coords = tuple(max(axis) for axis in per_axis)
  dims = tuple(hi - lo + 1 for hi, lo in zip(max_coords, min_coords))
  assert len(dims) == 3, dims
  x0, y0, z0 = min_coords
  if device_kind in (_TPU_V2, _TPU_V3):
    cores_per_chip = max(d.core_on_chip for d in jax_devices) + 1
    out = np.empty(dims[:2] + (cores_per_chip,), dtype=object)
    for coords, d in zip(device_coords, jax_devices):
      # v2/v3 chips are addressed by (x, y); the cores of a chip form the
      # third mesh axis instead of z.
      assert coords[2] == 0, d
      out[coords[0] - x0, coords[1] - y0, d.core_on_chip] = d
  else:
    out = np.empty(dims, dtype=object)
    for coords, d in zip(device_coords, jax_devices):
      if d.core_on_chip != 0:
        raise AssertionError(
            'Creating meshes for TPU >v3 requires one device per chip. '
            f'Got core_on_chip {d.core_on_chip} for a device of kind '
            f'{device_kind}: {d}'
        )
      out[coords[0] - x0, coords[1] - y0, coords[2] - z0] = d
  return out
# jekbradbury's famous trick for creating contiguous submeshes (where available)
def _transpose_trick(physical_mesh: np.ndarray,
                     mesh_shape: Sequence[int]) -> np.ndarray:
  """Permutes physical mesh axes using a registered contiguous-submesh trick.

  Looks up `physical_mesh.shape` in `_TRANSPOSE_TRICKS`, then looks up
  `mesh_shape` with its size-1 dims removed, and returns `physical_mesh`
  transposed by the recorded axis permutation.

  Args:
    physical_mesh: array of devices laid out in physical topology order.
    mesh_shape: requested logical mesh shape.

  Returns:
    `physical_mesh` with its axes permuted.

  Raises:
    ValueError: if no trick is registered for the topology, or for the
      non-trivial part of `mesh_shape` on that topology.
  """
  mesh_shape = tuple(mesh_shape)
  topology = physical_mesh.shape
  shape_to_perm = _TRANSPOSE_TRICKS.get(topology)
  if shape_to_perm is None:
    raise ValueError(
        f"create_device_mesh cannot create contiguous submeshes for "
        f"physical mesh topology {topology}")
  # Size-1 logical dims are stripped before the lookup.
  nontrivial_shape = tuple(size for size in mesh_shape if size != 1)
  perm = shape_to_perm.get(nontrivial_shape)
  if perm is None:
    raise ValueError(
        f"create_device_mesh cannot create contiguous submeshes for "
        f"mesh_shape {mesh_shape} and physical mesh topology {topology}. "
        f"Available mesh_shapes: {list(shape_to_perm.keys())}")
  return physical_mesh.transpose(*perm)
def create_device_mesh(
    mesh_shape: Sequence[int],
    devices: Optional[Sequence[Any]] = None, *,
    contiguous_submeshes: bool = False) -> np.ndarray:
  """Creates a performant device mesh for jax.sharding.Mesh.

  TPU v2/v3 devices are reordered into the tray's physical ring order when
  possible. Other TPUs are laid out on their physical torus via
  `_create_device_mesh_for_nd_torus`. All other platforms are reshaped in the
  order given.

  Args:
    mesh_shape: shape of logical mesh, ordered by increasing network-intensity
      e.g. [replica, data, mdl] where mdl has the most network communication
      requirements.
    devices: optionally, the devices to construct a mesh for. Defaults to
      jax.devices().
    contiguous_submeshes: if True, this function will attempt to create a mesh
      where each process's local devices form a contiguous submesh. This is
      required when passing host local inputs to `pjit`. A ValueError will be
      raised if this function can't produce a suitable mesh. Only consulted
      for TPUs other than v2/v3.

  Returns:
    A np.ndarray of JAX devices with mesh_shape as its shape that can be fed
    into jax.sharding.Mesh with good collective performance.

  Raises:
    ValueError: if the product of `mesh_shape` differs from the number of
      devices, or if `contiguous_submeshes` cannot be honored.
  """
  if devices is None:
    devices = jax.devices()
  if np.prod(mesh_shape) != len(devices):
    raise ValueError(f'Number of devices {len(devices)} must equal the product '
                     f'of mesh_shape {mesh_shape}')

  last_device = devices[-1]

  if last_device.device_kind in (_TPU_V2, _TPU_V3):
    if len(devices) == 8:
      logger.info('Reordering mesh to physical ring order on single-tray TPU v2/v3.')
      ring_order = np.array(_TRAY_RING_ORDER)
      return np.asarray(devices)[ring_order].reshape(mesh_shape)
    if mesh_shape[-1] == 8:
      tray_mesh = np.asarray(devices).reshape(mesh_shape)
      logger.info('Reordering mesh to physical ring order on each TPU v2/v3 tray.')
      ring_order = np.array(_TRAY_RING_ORDER)
      return tray_mesh[..., ring_order]
    # TODO(skye): implement 2D mesh_shape logic here:
    # https://github.com/tensorflow/lingvo/blob/0df40cf604dfcd14e28f7087d73687a0bd2fe5c6/lingvo/core/gshard_utils.py#L187
    # (possibly replaces above mesh_shape[-1] == 8 case)
    return np.asarray(devices).reshape(mesh_shape)

  if last_device.platform != 'tpu':
    # Non-TPU platforms: keep the device order as given.
    return np.asarray(devices).reshape(mesh_shape)

  physical_mesh = _get_physical_tpu_mesh(devices)
  if contiguous_submeshes:
    physical_mesh = _transpose_trick(physical_mesh, mesh_shape)
  device_mesh, assignment = _create_device_mesh_for_nd_torus(
      physical_mesh, mesh_shape)
  logger.info('_create_device_mesh_for_nd_torus assignment: %s', assignment)
  return device_mesh
def create_hybrid_device_mesh(mesh_shape: Sequence[int],
                              dcn_mesh_shape: Sequence[int],
                              devices: Optional[Sequence[Any]] = None, *,
                              process_is_granule: bool = False) -> np.ndarray:
  """Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism.

  Devices are grouped into "granules" (slices, or processes when
  `process_is_granule` is set). A mesh of shape `mesh_shape` is built inside
  each granule with `create_device_mesh`, and the per-granule meshes are tiled
  with `np.block` according to `dcn_mesh_shape`.

  Args:
    mesh_shape: shape of the logical mesh for the faster/inner network, ordered
      by increasing network intensity, e.g. [replica, data, mdl] where mdl has
      the most network communication requirements.
    dcn_mesh_shape: shape of the logical mesh for the slower/outer network,
      in the same order as mesh_shape.
    devices: optionally, the devices to construct a mesh for. Defaults to
      jax.devices().
    process_is_granule: if True, this function will treat processes as the units
      of the slower/outer network. Otherwise it will look for slice_index
      attributes on devices and use slices as the units. Enabling this is meant
      as a fallback for platforms (e.g., GPU) that don't set slice_index.

  Returns:
    A np.ndarray of JAX devices with mesh_shape * dcn_mesh_shape as its shape
    that can be fed into jax.sharding.Mesh for hybrid parallelism.

  Raises:
    AssertionError: if the first device lacks the granule attribute
      (`process_index` or `slice_index`).
    ValueError: if the number of granules differs from the product of
      `dcn_mesh_shape`, or (via `create_device_mesh`) if a granule's device
      count differs from the product of `mesh_shape`.
  """
  if devices is None:
    devices = jax.devices()
  attr = 'process_index' if process_is_granule else 'slice_index'
  # Explicit raise instead of `assert` so the check is not stripped by
  # `python -O`; the exception type is kept for backward compatibility.
  if not hasattr(devices[0], attr):
    raise AssertionError(
        f'Device {devices[0]} has no {attr!r} attribute; cannot group devices '
        'into granules for a hybrid mesh.')
  # Group devices by granule id, preserving the given order within each group.
  granule_dict = collections.defaultdict(list)
  for dev in devices:
    granule_dict[getattr(dev, attr)].append(dev)
  granules = [granule_dict[key] for key in sorted(granule_dict)]
  if np.prod(dcn_mesh_shape) != len(granules):
    # Name the granule unit actually in use so the message is not misleading
    # when grouping by process.
    granule_kind = 'processes' if process_is_granule else 'slices'
    raise ValueError(
        f'Number of {granule_kind} {len(granules)} must equal the product of '
        f'dcn_mesh_shape {dcn_mesh_shape}')
  per_granule_meshes = [create_device_mesh(mesh_shape, granule)
                        for granule in granules]
  # TODO(jekbradbury): handle non-uniform DCN topologies
  granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape)
  # Object array holding one per-granule mesh at each DCN position; np.block
  # then concatenates these nested blocks into a single device array.
  blocks = np.vectorize(
      lambda i: per_granule_meshes[i], otypes=[object])(granule_mesh)
  device_mesh = np.block(blocks.tolist())
  return device_mesh