# Copyright 2018 The JAX Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of pmap and related functionality."""
from __future__ import annotations
import enum
from contextlib import contextmanager
from collections import defaultdict, namedtuple
import dataclasses
from functools import partial, lru_cache, cached_property
import itertools as it
import logging
import math
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
Sequence, Set, Tuple, Type, Union, Iterable,
TYPE_CHECKING, cast, TypeVar)
import numpy as np
import jax
from jax.errors import JAXTypeError
from jax.tree_util import tree_map
from jax._src import api_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import op_shardings
from jax._src import sharding_specs
from jax._src import profiler
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import stages
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.abstract_arrays import array_types
from jax._src.config import config
from jax._src.core import ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding_impls import (
ArrayMapping, ArrayMappingOrAutoOrUnspecified,
AUTO, UnspecifiedValue, UNSPECIFIED,
get_array_mapping as _get_array_mapping, is_auto, is_unspecified
from jax._src.util import (unzip3, safe_map, safe_zip, partition_list,
wrap_name, tuple_delete, distributed_debug_log,
unzip2, HashableFunction, weakref_lru_cache)
# Built-in Python lists don't support weak references, but subclasses of list
# do.
class WeakRefList(list):
  """A plain `list` subclass whose instances can be weakly referenced."""
# Short alias for `xc._xla`.
xe = xc._xla
# Keep the builtin `map` reachable as `unsafe_map`; from here on `map` in this
# module refers to `safe_map` from jax._src.util.
unsafe_map, map = map, safe_map # type: ignore
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# An index into an array: an int, a slice, or a tuple of ints/slices.
Index = Union[int, slice, Tuple[Union[int, slice], ...]]
# Local aliases for names defined in `sharding_specs`.
NoSharding = sharding_specs.NoSharding
Chunked = sharding_specs.Chunked
Unstacked = sharding_specs.Unstacked
ShardedAxis = sharding_specs.ShardedAxis
Replicated = sharding_specs.Replicated
# Union of the three per-dimension sharding kinds aliased above.
AvalDimSharding = Union[Unstacked, Chunked, NoSharding]
# Local aliases for the mesh type and the mesh axis name type.
Mesh = mesh_lib.Mesh
MeshAxisName = sharding_impls.MeshAxisName
# Union of the two mesh-dimension assignment kinds aliased above.
MeshDimAssignment = Union[ShardedAxis, Replicated]
ShardingSpec = sharding_specs.ShardingSpec
### util
def identity(x):
  """Return `x` unchanged."""
  return x
def shard_arg(arg, devices, arg_indices, sharding):
  """Returns a list of size len(devices) containing per-device buffers.

  For the C++ pmap path, we fall back to Python (this function) to shard
  arguments that are not supported by the C++ `ShardArg`.

  Args:
    arg: The Python argument.
    devices: The list of devices to shard over.
    arg_indices: A list of `len(devices)` indices to use to shard the argument.
    sharding: Forwarded unchanged to the type-specific handler.

  Returns:
    The result of the handler registered in `shard_arg_handlers` for the type
    of the dtype-canonicalized argument.
  """
  canonical = xla.canonicalize_dtype(arg)
  handler = shard_arg_handlers[type(canonical)]
  return handler(canonical, devices, arg_indices, sharding)
def shard_args(
    devices: Sequence[xb.xla_client.Device],
    indices: Sequence[Sequence[Index]],
    shardings: Sequence[sharding_impls.XLACompatibleSharding],
    args,
) -> Sequence[jax.Array]:
  """Shard each argument data array along its leading axis.

  Args:
    devices: sequence of Devices mapping replica index to a physical device.
    indices: sequence of the same length as `args` describing how each arg
      should be sharded/replicated across `devices`. Each element in `indices`
      is the same length as `devices`.
    shardings: sequence indexed in step with `args`; element `i` is forwarded
      to `shard_arg` together with `args[i]`.
    args: a sequence of JaxTypes representing arguments to be sharded according
      to `indices` and placed on `devices`.

  Returns:
    A list of length matching args, containing the result of `shard_arg` for
    each argument.
  """
  return [shard_arg(arg, devices, indices[i], shardings[i])
          for i, arg in enumerate(args)]
# Registry mapping an argument's Python type to the function that shards it.
# `shard_arg` calls the handler as `handler(arg, devices, indices, sharding)`.
shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any, Any], Any]] = {}
def _shard_token(x, devices, indices, sharding):
  """Shard handler for `core.Token`.

  The token `x` itself is never read: a 0-d boolean zero stands in for it, and
  one reference to that placeholder is supplied per entry of `indices`.
  """
  placeholder = np.zeros((), dtype=np.dtype(np.bool_))
  placeholder_aval = api_util.shaped_abstractify(placeholder)
  per_device = [placeholder for _ in indices]
  return batched_device_put(placeholder_aval, sharding, per_device, devices)


shard_arg_handlers[core.Token] = _shard_token
def _masked_array_error(x, devices, indices, sharding):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
"Use arr.filled() to convert the value to a standard numpy array.")
# Reject numpy masked arrays with an explicit error instead of sharding them.
shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error
def _shard_array(x, devices, indices, sharding):
  """Shard handler for the types listed in `array_types`.

  An input whose dtype is `dtypes.float0` is replaced by a boolean zeros array
  of the same shape. Each entry of `indices` is then used to index the array,
  and the resulting pieces are handed to `batched_device_put`.
  """
  if x.dtype == dtypes.float0:
    x = np.zeros(x.shape, dtype=np.dtype(bool))
  array_aval = api_util.shaped_abstractify(x)
  pieces = [x[index] for index in indices]
  return batched_device_put(array_aval, sharding, pieces, devices)


for _t in array_types:
  shard_arg_handlers[_t] = _shard_array
def shard_device_array(x, devices, indices, sharding):
  """Shard `x` by slicing it once per entry of `indices`.

  Every index is converted with `as_slice_indices`; the three resulting
  sequences are passed to `x._multi_slice`, and the shards it returns are
  handed to `batched_device_put`.
  """
  slice_params = [as_slice_indices(x, idx) for idx in indices]
  start_indices, limit_indices, removed_dims = unzip3(slice_params)
  shards = x._multi_slice(start_indices, limit_indices, removed_dims)
  return batched_device_put(
      api_util.shaped_abstractify(x), sharding, shards, devices)
def batched_device_put(aval: core.ShapedArray,
                       sharding: jax.sharding.Sharding, xs: Sequence[Any],
                       devices: Sequence[jax.Device], committed: bool = True):
  """Assemble an array from `xs`, one element per entry of `devices`.

  If every element of `xs` is an `ArrayImpl` with a single-device sharding
  whose device equals the paired entry of `devices`, the elements are wrapped
  directly in an `ArrayImpl` (with `_skip_checks=True`). Otherwise the work is
  delegated to `xc.batched_device_put`.
  """
  from jax._src import array

  def _already_on(x, device):
    # True when `x` can be reused as-is for `device`.
    return (isinstance(x, array.ArrayImpl) and
            dispatch.is_single_device_sharding(x.sharding) and
            x.device() == device)

  reusable = [x for x, d in safe_zip(xs, devices) if _already_on(x, d)]
  if len(reusable) != len(xs):
    return xc.batched_device_put(aval, sharding, xs, devices, committed)  # type: ignore
  return array.ArrayImpl(
      aval, sharding, reusable, committed=committed, _skip_checks=True)
def refine_shape_polymorphism(module: ir.Module) -> ir.Module:
  """Placeholder for refining the polymorphic shapes in `module`.

  In order to avoid depending on jax2tf/, this is monkey patched from
  jax_export. The unpatched version defined here only raises.

  Raises:
    NotImplementedError: always.
  """
  message = "Compiling modules with shape polymorphism"
  raise NotImplementedError(message)
# NOTE(skye): we could refactor to generate _multi_slice parameters directly
# from the input ShardingSpec, rather than the indices. However, this would
# require duplicating the ordering logic of spec_to_indices, which is more
# subtle and more likely to change than the index logic we have to support here.
def as_slice_indices(arr: Any, idx: Index) -> Tuple[
    Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]:
  """Returns start_indices, limit_indices, removed_dims.

  Args:
    arr: Any object exposing `ndim` and `shape`.
    idx: An int, a slice, or a tuple of ints/slices; entry `k` applies to
      dimension `k` of `arr`. Dimensions without an entry are taken whole.

  Returns:
    A triple of tuples. `start_indices` and `limit_indices` have one entry per
    dimension of `arr` and give the half-open range selected along it.
    `removed_dims` lists the dimensions that were indexed with an int.
  """
  start_indices = [0] * arr.ndim
  limit_indices = list(arr.shape)
  removed_dims = []

  tuple_idx = idx if isinstance(idx, tuple) else (idx,)
  for dim, sub_idx in enumerate(tuple_idx):
    if isinstance(sub_idx, int):
      # An integer index selects one element and drops the dimension.
      start_indices[dim] = sub_idx
      limit_indices[dim] = sub_idx + 1
      removed_dims.append(dim)
    elif sub_idx == slice(None):
      # A full slice keeps the defaults (0 and the dimension size).
      continue
    else:
      assert isinstance(sub_idx, slice), sub_idx
      assert isinstance(sub_idx.start, int), sub_idx
      assert isinstance(sub_idx.stop, int), sub_idx
      start_indices[dim] = sub_idx.start
      limit_indices[dim] = sub_idx.stop

  return tuple(start_indices), tuple(limit_indices), tuple(removed_dims)  # type: ignore
def shard_aval(size, axis: int, aval):
  """Dispatch to the `shard_aval_handlers` entry for `type(aval)`.

  Args:
    size: Forwarded to the handler as its first argument.
    axis: Forwarded to the handler as its second argument.
    aval: The abstract value; its exact type selects the handler.

  Returns:
    Whatever the selected handler returns.

  Raises:
    TypeError: if no handler is registered for `type(aval)`.
  """
  # Only the registry lookup is guarded, so a KeyError raised inside a handler
  # is not misreported as a missing handler.
  try:
    handler = shard_aval_handlers[type(aval)]
  except KeyError as err:
    raise TypeError(f"No shard_aval handler for type: {type(aval)}") from err
  return handler(size, axis, aval)


# Registry mapping an abstract value type to its `(size, axis, aval)` handler.
shard_aval_handlers: Dict[Type[core.AbstractValue], Callable[[int, int, Any], Any]] = {}
def _shard_abstract_array(size, axis: int, x):
if x.shape[axis] != size:
raise ValueError(f"Axis size {size} does not match dimension {axis} of "
f"shape {x.shape}")
except IndexError:
raise ValueError("Cannot split a {x.dim}D value along axis {axis}") from None
return x.update(shape=tuple_delete(x.shape, axis))
# `ShapedArray` avals are sharded by `_shard_abstract_array`.
shard_aval_handlers[ShapedArray] = _shard_abstract_array
def local_aval_to_result_handler(
    aval: core.AbstractValue,
    sharding: sharding_impls.XLACompatibleSharding,
    indices: Optional[Tuple[Index, ...]],
) -> Callable[[List[xc.ArrayImpl]], Any]:
  """Returns a function for handling the raw buffers of a single output aval.

  Args:
    aval: The local output AbstractValue.
    sharding: Indicates how the output is sharded across devices; forwarded to
      the type-specific handler.
    indices: The pre-computed result of spec_to_indices, or None for non-array
      avals.

  Returns:
    A function for handling the Buffers that will eventually be produced
    for this output, as built by the `local_result_handlers` entry for
    `type(aval)`.

  Raises:
    TypeError: if no handler is registered for `type(aval)`.
  """
  # Only the registry lookup is guarded, so a KeyError raised inside a handler
  # is not misreported as a missing handler.
  try:
    make_handler = local_result_handlers[type(aval)]
  except KeyError as err:
    raise TypeError(
        f"No pxla_result_handler for type: {type(aval)}") from err
  return make_handler(aval, sharding, indices)


# A result-handler factory: returns the function that processes raw buffers.
PxlaResultHandler = Callable[..., Callable[[Any], Any]]
# Registry of handler factories for local (per-process) output avals.
local_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
def global_aval_to_result_handler(
    aval: core.AbstractValue, out_sharding, committed: bool,
    is_out_sharding_from_xla: bool
) -> Callable[[Sequence[xc.ArrayImpl]], Any]:
  """Returns a function for handling the raw buffers of a single output aval.

  Args:
    aval: The global output AbstractValue.
    out_sharding: The sharding of the output; forwarded to the type-specific
      handler.
    committed: Forwarded unchanged to the type-specific handler.
    is_out_sharding_from_xla: True, if the out_sharding comes from XLA i.e.
      the sharding is extracted from the HLO.

  Returns:
    A function for handling the Buffers that will eventually be produced
    for this output, as built by the `global_result_handlers` entry for
    `type(aval)`.

  Raises:
    TypeError: if no handler is registered for `type(aval)`.
  """
  # Only the registry lookup is guarded, so a KeyError raised inside a handler
  # is not misreported as a missing handler.
  try:
    make_handler = global_result_handlers[type(aval)]
  except KeyError as err:
    raise TypeError(
        f"No pxla_result_handler for type: {type(aval)}") from err
  return make_handler(aval, out_sharding, committed, is_out_sharding_from_xla)


# Registry of handler factories for global output avals.
global_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
### lazy device-memory persistence and result handling
# TODO(yashkatariya, phawkins): Remove this function after March 15, 2023.
def make_sharded_device_array(
    aval: ShapedArray,
    sharding_spec: Optional[ShardingSpec],
    # Any is for JAX extensions implementing their own buffer.
    device_buffers: List[Any],
    indices: Optional[Tuple[Index, ...]] = None,
):
  """Builds an array from per-device buffers and a sharding spec.

  Args:
    aval: The `ShapedArray` for this array.
    sharding_spec: If `None`, a pmap-style sharding spec is created from
      `aval.shape`.
    device_buffers: The per-device buffers making up the array.
    indices: Accepted for backward compatibility; not read by this function.

  Returns:
    The result of `jax.make_array_from_single_device_arrays` for `aval.shape`,
    the derived sharding and `device_buffers`.
  """
  if sharding_spec is None:
    sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)

  mesh = mesh_lib.thread_resources.env.physical_mesh
  sharding: sharding_impls.XLACompatibleSharding
  if mesh.empty:
    # No physical mesh is active: describe the layout with a PmapSharding
    # built from the devices the buffers live on.
    sharding = sharding_impls.PmapSharding(
        np.asarray([d.device() for d in device_buffers]), sharding_spec)
  else:
    # A mesh is active: translate the spec into a PartitionSpec on that mesh.
    hlo_sharding = sharding_specs.sharding_spec_sharding_proto(sharding_spec)
    pspec = sharding_impls.parse_flatten_op_sharding(
        hlo_sharding, mesh)[0].get_partition_spec()
    sharding = sharding_impls.NamedSharding(mesh, pspec)

  return jax.make_array_from_single_device_arrays(
      aval.shape, sharding, device_buffers)  # type: ignore
if TYPE_CHECKING:
  # Static type checkers only see an opaque alias.
  ShardedDeviceArray = Any
else:
  class ShardedDeviceArray:
    """Backward compatibility shim; instantiating it raises RuntimeError."""

    def __init__(self):
      raise RuntimeError("ShardedDeviceArray is a backward compatibility shim "
                         "and cannot be instantiated.")
def _hashable_index(idx):
return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, idx)
# The fast path is handled directly in shard_args().
# TODO(yashkatariya): Move this to array.py when SDA is deleted. The local
# import of Array should go away at that time.
def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
  """Reshard an already-sharded array, reusing its existing buffers if possible.

  The buffers of `x` are grouped by the (hashable form of the) index they
  cover. For each requested `(index, device)` pair a buffer covering that index
  is picked, preferring one that already lives on `device`. If some requested
  index is not covered by any buffer, the array is resharded from `x._value`
  through the handler registered for that value's type.

  Args:
    x: The sharded array to reshard.
    devices: Target devices, paired element-wise with `indices`.
    indices: The index each target device should hold.
    sharding: Forwarded to `batched_device_put` or the fallback handler.

  Returns:
    The result of `batched_device_put`, or of the fallback handler.
  """
  from jax._src.array import ArrayImpl

  candidates = defaultdict(list)
  if isinstance(x, ArrayImpl):
    bufs = [buf.data for buf in x.addressable_shards]
    arr_indices = tuple(x.sharding.devices_indices_map(x.shape).values())
  else:
    bufs = x.device_buffers
    arr_indices = x.indices
  for buf, idx in safe_zip(bufs, arr_indices):
    candidates[_hashable_index(idx)].append(buf)

  bufs = []
  for idx, device in safe_zip(indices, devices):
    # Look up all buffers that contain the correct slice of the logical array.
    candidates_list = candidates[_hashable_index(idx)]
    if not candidates_list:
      # This array isn't sharded correctly. Reshard it via host roundtrip.
      # TODO(skye): more efficient reshard?
      return shard_arg_handlers[type(x._value)](
          x._value, devices, indices, sharding)
    # Try to find a candidate buffer already on the correct device,
    # otherwise copy one of them.
    for buf in candidates_list:
      if buf.device() == device:
        bufs.append(buf)
        break
    else:
      bufs.append(candidates_list[-1])

  return batched_device_put(x.aval, sharding, bufs, devices)
### the xla_pmap primitive and its rules are comparable to xla_call in xla.py
def xla_pmap_impl_lazy(
    fun: lu.WrappedFun,
    *args,
    backend: Optional[str],
    axis_name: core.AxisName,
    axis_size: int,
    global_axis_size: int,
    devices: Optional[Sequence[Any]],
    name: str,
    in_axes: Sequence[Optional[int]],
    out_axes_thunk: Callable[[], Sequence[Optional[int]]],
    donated_invars: Sequence[bool],
    is_explicit_global_axis_size: bool,
) -> Callable:
  """Return a callable that runs `fun` as a pmap over arguments like `args`.

  When `jax_disable_jit` and `jax_eager_pmap` are both set, no explicit global
  axis size was given and nothing is donated, the returned callable runs the
  eager implementation (`_emap_impl`). Otherwise the function is obtained from
  `parallel_callable` for the abstractified arguments.

  Args:
    fun: The wrapped function to map.
    *args: The concrete arguments; only their abstract values are used here.
    backend, axis_name, axis_size, global_axis_size, devices, name, in_axes,
      out_axes_thunk, donated_invars, is_explicit_global_axis_size: pmap
      parameters, forwarded unchanged.

  Returns:
    A callable taking the concrete arguments.
  """
  if (config.jax_disable_jit and config.jax_eager_pmap and
      not is_explicit_global_axis_size and not any(d for d in donated_invars)):
    def _emap_apply_fn(*args):
      return _emap_impl(fun, *args, backend=backend, axis_name=axis_name,
                        axis_size=axis_size, global_axis_size=global_axis_size,
                        devices=devices, name=name, in_axes=in_axes,
                        out_axes_thunk=out_axes_thunk,
                        donated_invars=donated_invars,
                        is_explicit_global_axis_size=is_explicit_global_axis_size)
    return _emap_apply_fn
  abstract_args = unsafe_map(xla.abstractify, args)
  compiled_fun, fingerprint = parallel_callable(
      fun, backend, axis_name, axis_size, global_axis_size, devices, name,
      in_axes, out_axes_thunk, donated_invars,
      is_explicit_global_axis_size, *abstract_args)

  # Don't re-abstractify args unless logging is enabled for performance.
  if config.jax_distributed_debug:
    distributed_debug_log(("Running pmapped function", name),
                          ("python function", fun.f),
                          ("devices", devices),
                          ("abstract args", map(xla.abstractify, args)),
                          ("fingerprint", fingerprint))
  return compiled_fun
def xla_pmap_impl(fun: lu.WrappedFun, *args, **params):
  """Build the pmap callable for `fun` and immediately apply it to `args`."""
  return xla_pmap_impl_lazy(fun, *args, **params)(*args)
class EmapInfo(NamedTuple):
  # Payload handed to the eager-pmap trace; built from the `backend` and
  # `devices` pmap parameters in `_emap_impl`.
  backend: Optional[str]
  devices: Optional[Sequence[Any]]
def _emap_impl(fun: lu.WrappedFun, *args,
               backend: Optional[str],
               axis_name: core.AxisName,
               axis_size: int,
               global_axis_size: int,
               devices: Optional[Sequence[Any]],
               name: str,
               in_axes: Sequence[Optional[int]],
               out_axes_thunk: Callable[[], Sequence[Optional[int]]],
               donated_invars: Sequence[bool],
               is_explicit_global_axis_size: bool,
               ):
  """Eager implementation of pmap, evaluated under a `MapTrace`.

  Each argument is wrapped in a `MapTracer` recording which of its axes (if
  any) is mapped under `axis_name`, `fun` is called on the tracers, and each
  result is then passed through an identity `jax.pmap` that places the mapped
  axis at the position requested by `out_axes_thunk()`.

  Raises:
    NotImplementedError: if any input is donated or an explicit global axis
      size was given.

  Returns:
    A list with one output value per result of `fun`.
  """
  from jax._src import array
  # TODO(sharadmv,mattjj): implement these cases
  if any(d for d in donated_invars):
    raise NotImplementedError("Buffer donation not supported in eager pmap.")
  if is_explicit_global_axis_size:
    raise NotImplementedError("Non-default global_axis_size not supported in "
                              "eager pmap.")

  emap_info = EmapInfo(backend, devices)
  shard_axes = [{} if in_axis is None else {axis_name: in_axis} for in_axis in in_axes]
  with core.new_base_main(MapTrace, emap_info=emap_info) as main:
    with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, main):
      t = main.with_cur_sublevel()
      tracers = [
          MapTracer(t, arg, s) for arg, s in zip(args, shard_axes)]
      ans = fun.call_wrapped(*tracers)
      out_tracers = map(t.full_raise, ans)
      outvals, out_axes_src = unzip2((t.val, t.shard_axes) for t in out_tracers)
    del main
  out_axes = out_axes_thunk()

  platform = xb.get_backend(backend).platform
  donate_argnums = (1,) if platform in {"cuda", "rocm", "tpu"} else ()
  new_outvals = []
  for out_axis_src, out_axis, outval in zip(out_axes_src, out_axes, outvals):
    with jax.disable_jit(False):
      donate_argnums_ = donate_argnums
      if isinstance(outval, array.ArrayImpl):
        # We don't want to donate if it's already sharded.
        donate_argnums_ = ()
      out = jax.pmap(
          lambda _, x: x,
          in_axes=(0, out_axis_src.get(axis_name)),
          out_axes=out_axis,
          devices=(None if devices is None else list(devices)),
          donate_argnums=donate_argnums_)(np.arange(axis_size), outval)
      new_outvals.append(out)
  return new_outvals
def _map_schedule(idx: Tuple[Optional[int], ...]) -> Tuple[Optional[int], ...]:
# In order to do a multi-map (a simultaneous map over several axes), we will
# nest several maps. Each time we do a map, we "remove" an input axis so we
# need to update the remaining map axes. For example, if we are to map over
# the axes 0, 3, and 4, we make three calls to pmap with in_axes as 0, 2, 2.
return tuple(None if i is None else
i - sum(j is not None and j < i for j in idx[:l])
for l, i in enumerate(idx))
# We're often creating `f`s on the fly and we try to carefully make them have
# the right __hash__ and __eq__. However, despite our attempts pmap's caching
# still ends up not working, because it has a separate cache per
# _function object_. Adding this annotation here lets us reuse the same pmap
# callable for all equivalent primitive pmaps.
def _multi_pmap(f: Callable, info: EmapInfo, names: List[core.AxisName],
all_axes: List[Tuple[Optional[int], ...]]
) -> Tuple[Callable, Dict[core.AxisName, int]]:
used_names = []
for i, name in reversed(list(enumerate(names))):
in_axes = tuple(arg_axis[i] for arg_axis in all_axes)
if any(in_axis is not None for in_axis in in_axes):
f = jax.pmap(
devices=(None if info.devices is None else list(info.devices)))
out_shard_axes = {name: i for i, name in enumerate(reversed(used_names))}
return f, out_shard_axes
# Minimal stand-in for a primitive: carries only the `multiple_results` flag
# and the `bind` callable that `MapTrace.process_primitive` reads.
FakePrimitive = namedtuple("FakePrimitive", ["multiple_results", "bind"])
class MapTrace(core.Trace):
  """Trace used by eager pmap.

  Tracers carry a concrete value plus a dict mapping axis names to the
  position of the corresponding mapped axis in that value. Primitives are
  evaluated by wrapping their `bind` in nested `jax.pmap` calls.
  """

  def __init__(self, *args, emap_info):
    super().__init__(*args)
    self.emap_info = emap_info

  def pure(self, val):
    # A fresh value has no mapped axes.
    return MapTracer(self, val, {})

  def sublift(self, tracer):
    return MapTracer(self, tracer.val, tracer.shard_axes)

  def process_primitive(self, primitive, tracers, params):
    info = self.main.payload["emap_info"]
    vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers])
    # Only the axis frames introduced by this trace are mapped here.
    names = tuple(f.name for f in core.thread_local_state.trace_state.axis_env
                  if f.main_trace is self.main)
    all_axes = tuple(_map_schedule(map(s.get, names)) for s in shard_axes)  # pytype: disable=wrong-arg-types  # always-use-return-annotations
    # Hashable so that equivalent primitive applications can share one pmap.
    f = HashableFunction(lambda *args: primitive.bind(*args, **params),
                         (primitive, tuple(params.items())))
    f_mapped, out_shard_axes = _multi_pmap(f, info, names, all_axes)
    with core.eval_context(), jax.disable_jit(False):
      outvals = f_mapped(*vals)
    if primitive.multiple_results:
      return [MapTracer(self, val, out_shard_axes) for val in outvals]
    return MapTracer(self, outvals, out_shard_axes)

  def process_call(self, call_primitive, fun, tracers, params):
    raise NotImplementedError

  def process_map(self, map_primitive, fun, tracers, params):
    if params['devices'] is not None:
      raise ValueError("Nested pmap with explicit devices argument.")
    if not config.jax_disable_jit:
      # With jit enabled, treat the inner map like any other primitive.
      bind = HashableFunction(
          lambda *args, **kwargs: map_primitive.bind(fun, *args, **kwargs),
          (map_primitive, fun))
      fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
      return self.process_primitive(fake_primitive, tracers, params)
    axis_name, in_axes, out_axes_thunk, axis_size = (params["axis_name"],
        params["in_axes"], params["out_axes_thunk"], params["axis_size"])
    vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers])
    # Record the new axis name on every argument that is mapped by this map.
    shard_axes = [{axis_name: _annot_to_flat(np.ndim(v), s.values(), ax), **s}
                  if ax is not None else s
                  for v, ax, s in zip(vals, in_axes, shard_axes)]
    with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, self.main):
      t = self.main.with_cur_sublevel()
      in_tracers = map(partial(MapTracer, t), vals, shard_axes)
      ans = fun.call_wrapped(*in_tracers)
      out_tracers = map(t.full_raise, ans)
      out, outaxes = unzip2((t.val, t.shard_axes) for t in out_tracers)
      del t, in_tracers, ans, out_tracers
    out, outaxes = unzip2(_match_annot(axis_name, axis_size, v, s, dst)
                           for v, s, dst in zip(out, outaxes, out_axes_thunk()))
    return map(partial(MapTracer, self), out, outaxes)

  def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
    bind = HashableFunction(
        lambda *args, **kwargs: prim.bind(
            fun, jvp, *args, symbolic_zeros=symbolic_zeros, **kwargs),
        (prim, fun, jvp, symbolic_zeros))
    fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
    return self.process_primitive(fake_primitive, tracers, {})

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
                              out_trees, symbolic_zeros):
    bind = HashableFunction(
        lambda *args, **kwargs: primitive.bind(
            fun, fwd, bwd, *args, out_trees=out_trees,
            symbolic_zeros=symbolic_zeros, **kwargs),
        (primitive, fun, fwd, bwd))
    fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
    return self.process_primitive(fake_primitive, tracers, {})

  def process_axis_index(self, frame):
    bind = HashableFunction(
        lambda _: jax.lax.axis_index(frame.name),
        (jax.lax.axis_index, frame.name))
    fake_primitive = FakePrimitive(multiple_results=False, bind=bind)
    with core.eval_context():
      iota = jax.lax.iota(np.int32, frame.size)
    # A dummy operand mapped along the frame's axis drives the pmap.
    dummy_tracer = MapTracer(self, iota, {frame.name: 0})
    return self.process_primitive(fake_primitive, (dummy_tracer,), {})
def _annot_to_flat(ndim: int, mapped_axes: Iterable[int],
annotation: Optional[int]) -> Optional[int]:
if annotation is None: return None
mapped_axes_ = set(mapped_axes)
return [i for i in range(ndim) if i not in mapped_axes_][annotation]
def _match_annot(axis_name: core.AxisName, axis_size: int, val: Any,
                 shard_axis_src: Dict[core.AxisName, int],
                 dst_annotation: Optional[int]
                 ) -> Tuple[Any, Dict[core.AxisName, int]]:
  """Move or create the `axis_name` axis of `val` at the requested position.

  Args:
    axis_name: The axis name being matched.
    axis_size: Size used when the axis has to be created by broadcasting.
    val: The value to adjust.
    shard_axis_src: Current mapping from axis names to positions in `val`.
    dst_annotation: Requested position among the axes not mapped under other
      names, or `None`.

  Returns:
    The adjusted value and its updated name-to-position mapping.

  Raises:
    NotImplementedError: if the axis is present in `val` but no destination
      position was requested.
  """
  shard_axis_out = dict(shard_axis_src)
  src = shard_axis_out.pop(axis_name, None)
  # When the axis is absent, the output has one more dimension than `val`.
  dst = _annot_to_flat(np.ndim(val) + (src is None), shard_axis_out.values(),
                       dst_annotation)
  with core.eval_context():
    if src == dst:
      outval = val
    elif type(src) == type(dst) == int:
      outval = batching.moveaxis(val, src, dst)
      shard_axis_out = _moveaxis(np.ndim(val), shard_axis_src, src, dst)
    elif src is None and dst is not None:
      outval = batching.broadcast(val, axis_size, dst)
      # Axes at or after the inserted position shift right by one.
      shard_axis_out = {n: d + (dst <= d) for n, d in shard_axis_out.items()}
    else:
      raise NotImplementedError
  return outval, shard_axis_out
def _moveaxis(ndim: int, shard_axes: Dict[core.AxisName, int],
src: int, dst: int) -> Dict[core.AxisName, int]:
lst: List[Optional[core.AxisName]] = [None] * ndim
for k, v in shard_axes.items():
lst[v] = k
name = lst.pop(src)
lst.insert(dst - (src < dst), name)
return {name: i for i, name in enumerate(lst) if name is not None}
class MapTracer(core.Tracer):
  """Tracer holding a concrete value and the positions of its mapped axes."""
  __slots__ = ["val", "shard_axes"]

  def __init__(self, trace: MapTrace, val, shard_axes: Dict[core.AxisName, int]):
    self._trace = trace
    self.val = val
    self.shard_axes = shard_axes
    # Every recorded axis position must exist in the value.
    assert all(axis < self.val.ndim for axis in self.shard_axes.values())

  @property
  def aval(self):
    """The abstract value of `val` with every mapped axis removed."""
    aval = xla.abstractify(self.val)
    shard_axes = dict(self.shard_axes)
    # Remove the highest positions first so lower positions stay valid.
    for axis_idx in sorted(shard_axes.values())[::-1]:
      aval = core.mapped_aval(aval.shape[axis_idx], axis_idx, aval)
    return aval

  def full_lower(self):
    return self

  def __str__(self):
    named_axes = [f"{k}={v}" for k, v in self.shard_axes.items()]
    return f"{self.val}{{{','.join(named_axes)}}}"
def parallel_callable(fun: lu.WrappedFun,
                      backend_name: Optional[str],
                      axis_name: core.AxisName,
                      axis_size: int,
                      global_axis_size: int,
                      devices: Optional[Sequence[Any]],
                      name: str,
                      in_axes: Sequence[Optional[int]],
                      out_axes_thunk: Callable[[], Sequence[Optional[int]]],
                      donated_invars: Sequence[bool],
                      is_explicit_global_axis_size: bool,
                      *avals):
  """Lower and compile `fun` as a pmap for the given abstract arguments.

  Args:
    fun: The wrapped function to map.
    backend_name, axis_name, axis_size, global_axis_size, devices, name,
      in_axes, out_axes_thunk, donated_invars, is_explicit_global_axis_size:
      pmap parameters, forwarded to `lower_parallel_callable`.
    *avals: The abstract values of the arguments.

  Returns:
    A `WeakRefList` holding the executable's `unsafe_call` and `fingerprint`.
  """
  pmap_computation = lower_parallel_callable(
      fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
      in_axes, out_axes_thunk, donated_invars,
      is_explicit_global_axis_size, avals, lowering_platform=None)
  pmap_executable = pmap_computation.compile()
  return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint])
@dataclasses.dataclass(frozen=True)
class ParallelCallableInfo:
  """The parameters of one pmap invocation, bundled for the lowering steps."""
  name: str
  backend: xc.Client
  axis_name: core.AxisName
  axis_size: int
  global_axis_size: int
  devices: Optional[Sequence[xc.Device]]
  in_axes: Iterable[Optional[int]]
  out_axes_thunk: Callable[[], Sequence[Optional[int]]]
  avals: Sequence[core.AbstractValue]

  @cached_property
  def local_devices(self):
    """The entries of `devices` owned by this process, or None if unset."""
    if self.devices:
      out = [d for d in self.devices
             if d.process_index == xb.process_index(self.backend)]
      assert len(out) > 0
    else:
      out = None  # type: ignore
    return out

  @cached_property
  def out_axes(self):
    """The result of calling `out_axes_thunk`, computed once."""
    return self.out_axes_thunk()
class ShardInfo(NamedTuple):
  # Built by `stage_parallel_callable`, which passes the same sharded input
  # avals for both `sharded_avals` and `global_sharded_avals`.
  sharded_avals: Sequence[core.AbstractValue]
  # Output avals produced by tracing the function on `sharded_avals`.
  out_sharded_avals: Sequence[core.ShapedArray]
  global_sharded_avals: Sequence[core.AbstractValue]
  # Filled from `ReplicaInfo.num_local_replicas` / `num_global_replicas`.
  num_local_shards: int
  num_global_shards: int
class ReplicaInfo(NamedTuple):
  # Value of `dispatch.jaxpr_replicas(jaxpr)`; see `find_replicas`.
  jaxpr_replicas: int
  # `axis_size * jaxpr_replicas`.
  num_local_replicas: int
  # `global_axis_size * jaxpr_replicas`.
  num_global_replicas: int
def find_replicas(
    jaxpr: core.Jaxpr, axis_size: int, global_axis_size: int
) -> ReplicaInfo:
  """Compute the replica counts for running `jaxpr` under a pmap.

  The local and global counts are the respective axis sizes multiplied by the
  replica count that `dispatch.jaxpr_replicas` reports for `jaxpr`.
  """
  # TODO(skyewm): replace this with a chain of pmaps and/or sharded_jits
  per_jaxpr = dispatch.jaxpr_replicas(jaxpr)
  return ReplicaInfo(
      jaxpr_replicas=per_jaxpr,
      num_local_replicas=axis_size * per_jaxpr,
      num_global_replicas=global_axis_size * per_jaxpr)
def stage_parallel_callable(
    pci: ParallelCallableInfo, fun: lu.WrappedFun
) -> Tuple[core.Jaxpr, List[Any], ReplicaInfo, ShardInfo]:
  """Trace `fun` to a jaxpr on the per-shard input avals.

  Args:
    pci: The pmap parameters.
    fun: The wrapped function to trace.

  Returns:
    The traced jaxpr, its constants, the replica counts and the shard info.
  """
  # Mapped inputs lose their mapped axis; unmapped inputs are used as-is.
  sharded_avals = tuple(
      shard_aval(pci.axis_size, axis, aval) if axis is not None else aval
      for axis, aval in safe_zip(pci.in_axes, pci.avals))

  with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None):  # type: ignore
    with dispatch.log_elapsed_time(
        "Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
        fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
      jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
          fun, sharded_avals, pe.debug_info_final(fun, "pmap"))
  jaxpr = api_util.jaxpr_debug_info(jaxpr, fun.debug_info)
  jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

  assert len(out_sharded_avals) == len(pci.out_axes), (
      len(out_sharded_avals), len(pci.out_axes))

  # TODO(skye,mattjj): allow more collectives on multi-host as we test them, but
  # for now raise an error
  if pci.devices is not None:
    is_multi_host_pmap = len(pci.local_devices) != len(pci.devices)
  else:
    is_multi_host_pmap = xb.process_count(pci.backend) > 1
  if is_multi_host_pmap:
    check_multihost_collective_allowlist(jaxpr)

  replicas = find_replicas(jaxpr, pci.axis_size, pci.global_axis_size)
  num_local_shards = replicas.num_local_replicas
  num_global_shards = replicas.num_global_replicas

  shards = ShardInfo(
      sharded_avals, out_sharded_avals, sharded_avals,
      num_local_shards, num_global_shards)

  return jaxpr, consts, replicas, shards
# Stages `fun` for pmap via `stage_parallel_callable`, validates the
# multi-process configuration, lowers the resulting jaxpr with
# `mlir.lower_jaxpr_to_module`, and returns a `PmapComputation`.
# Raises ValueError for an explicit global axis size that disagrees with
# `axis_size` in a single-process run, for the multi-process device-count and
# nested-sharding checks below, and when the jaxpr has ordered effects.
# NOTE(review): several statements in this function look truncated (dangling
# message fragments, unclosed calls, `if` branches without a visible `else`);
# confirm against the upstream file before editing the logic.
def lower_parallel_callable(
fun: lu.WrappedFun,
backend_name: Optional[str],
axis_name: core.AxisName,
axis_size: int,
global_axis_size: int,
devices: Optional[Sequence[xc.Device]],
name: str,
in_axes: Iterable[Optional[int]],
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
is_explicit_global_axis_size: bool,
avals: Sequence[core.AbstractValue],
lowering_platform: Optional[str]):
# Determine global_axis_size for use in AxisEnv.
# TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
# if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
# raise ValueError("'axis_size' must be specified for nested multi-host pmaps")
if (xb.process_count() == 1 and is_explicit_global_axis_size
and global_axis_size != axis_size):
raise ValueError(
f"Specified axis_size {global_axis_size} doesn't match received "
f"axis_size {axis_size}.")
# Pick the backend: from the first explicit device when no backend name is
# given, and by name via `xb.get_backend` (presumably the else branch).
if devices is not None and backend_name is None:
backend = xb.get_device_backend(devices[0])
backend = xb.get_backend(backend_name)
no_nested_sharding = False
must_run_on_all_devices = False
if not is_explicit_global_axis_size:
if xb.process_count(backend) > 1:
if devices:
# This allows each host in a multi-host pmap to run on a different number
# of devices, but precludes nested sharding (i.e. inner pmaps).
no_nested_sharding = True
# This assumes all hosts run on the same number of devices. We make sure
# this assumption is true by requiring that the pmap is run on all devices
# (and making the further assumption that each host has the same number of
# devices). Nested sharding is ok in this case.
must_run_on_all_devices = True
# Bundle the parameters and trace the function to a jaxpr.
pci = ParallelCallableInfo(
name, backend, axis_name, axis_size, global_axis_size, devices,
in_axes, out_axes_thunk, avals)
jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
if logger.isEnabledFor(logging.DEBUG):
logger.debug("sharded_avals: %s", shards.sharded_avals)
logger.debug("global_sharded_avals: %s", shards.global_sharded_avals)
logger.debug("num_replicas: %d num_local_replicas: %d",
replicas.num_global_replicas, replicas.num_local_replicas)
logger.debug("devices: %s", devices)
logger.debug("local_devices: %s", pci.local_devices)
# Multi-process runs without explicit devices must use every local device.
if (xb.process_count(backend) > 1 and must_run_on_all_devices and
shards.num_local_shards != xb.local_device_count(backend)):
if shards.num_local_shards == axis_size:
raise ValueError(
f"On multi-host platforms, the input to pmapped functions must have "
f"leading axis size equal to the number of local devices if no "
f"`devices` argument is specified. Got {axis_size=}, "
raise ValueError(
f"On multi-host platforms, pmapped functions must run across all "
f"devices, i.e. num_replicas * num_partitions should equal the "
f"number of local devices. Got "
f"num_replicas={replicas.num_local_replicas}, and "
# Explicit devices in a multi-process run rule out inner pmaps.
if no_nested_sharding and replicas.jaxpr_replicas > 1:
raise ValueError(
f"On multi-host platforms, pmapped functions that both have `devices` "
f"specified and contain an inner_pmap must specify an "
f"`axis_size` (or remove the `devices` argument). Got nested_replicas="
# Compilation is reported at WARNING level when `jax_log_compiles` is set,
# otherwise at DEBUG level.
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
if logger.isEnabledFor(log_priority):
"Compiling %s (%d) for %d devices with args %s. (num_replicas=%d)",
fun.__name__, id(fun),
shards.num_global_shards, avals, replicas.num_global_replicas)
axis_env = sharding_impls.AxisEnv(
replicas.num_global_replicas, (axis_name,), (global_axis_size,))
name_stack = source_info_util.new_name_stack(wrap_name(name, 'pmap'))
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
# An argument is replicated when it has no mapped axis.
replicated_args = [axis is None for axis in in_axes]
tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals),
module_name = f"pmap_{fun.__name__}"
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
ordered_effects = list(
if ordered_effects:
raise ValueError("Ordered effects not supported in `pmap`.")
unordered_effects = list(
with dispatch.log_elapsed_time(
"Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
lowering_result = mlir.lower_jaxpr_to_module(
lowering_platform or backend.platform,
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
return PmapComputation(lowering_result.module, pci=pci, replicas=replicas,
shards=shards, tuple_args=tuple_args,
class PmapComputation(stages.XlaLowering):
  """A lowered pmap: the MLIR module plus the arguments needed to compile it."""
  _hlo: ir.Module
  _executable: Optional[PmapExecutable]

  def __init__(self, hlo: ir.Module, **compile_args):
    self._executable = None
    self._hlo = hlo
    self.compile_args = compile_args

  # -- stages.XlaLowering overrides

  def stablehlo(self) -> ir.Module:
    return self._hlo

  def compile(self, compiler_options=None) -> PmapExecutable:
    """Compile the module, caching the result of the default compilation.

    A compilation with explicit `compiler_options` is always performed afresh
    and is never stored in the cache.
    """
    if self._executable is None or compiler_options is not None:
      executable = UnloadedPmapExecutable.from_hlo(
          self._hlo, **self.compile_args,
          compiler_options=compiler_options)
      if compiler_options is None:
        self._executable = executable
      return executable
    return self._executable
def _cast_to_shaped_array(aval: core.AbstractValue) -> ShapedArray:
  """Return `aval` typed as a `ShapedArray`, asserting that it is one."""
  is_shaped = isinstance(aval, ShapedArray)
  assert is_shaped, aval
  return cast(ShapedArray, aval)
@dataclasses.dataclass
class UnloadedPmapExecutable:
  """A compiled pmap computation plus the metadata needed to call it.

  NOTE(review): this class was recovered from a copy in which continuation
  lines (decorators, trailing call arguments, ``else`` branches) were missing.
  The restored pieces follow from the field list, the ``ExecuteReplicated``
  signature and the surrounding code, but should be confirmed upstream.
  """
  compiled: Any
  backend: xb.XlaBackend
  local_input_avals: Sequence[core.AbstractValue]
  input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
  local_output_avals: Sequence[ShapedArray]
  output_shardings: Sequence[sharding_impls.XLACompatibleSharding]
  unordered_effects: List[core.Effect]
  ordered_effects: List[core.Effect]
  keepalive: Sequence[Any]
  host_callbacks: Sequence[Any]
  jaxpr_debug_info: core.JaxprDebugInfo

  def build_execute_fun(self):
    """Builds the ``ExecuteReplicated`` callable for the compiled executable."""
    input_indices = []
    for aval, spec in safe_zip(self.local_input_avals, self.input_shardings):
      assert isinstance(spec, sharding_impls.PmapSharding), spec
      assert isinstance(aval, core.ShapedArray), aval
      input_indices.append(
          sharding_specs.spec_to_indices(aval.shape, spec.sharding_spec)
          if spec.sharding_spec is not None else None)
    handle_outs = local_avals_to_results_handler(self.local_output_avals,
                                                 self.output_shardings)
    handle_args = InputsHandler(self.compiled.local_devices(),
                                self.input_shardings, input_indices)
    # Every input is kept: pmap does not prune unused arguments here.
    execute_fun = ExecuteReplicated(self.compiled, "parallel computation",
                                    self.backend, handle_args, handle_outs,
                                    self.unordered_effects,
                                    self.ordered_effects, self.keepalive,
                                    bool(self.host_callbacks),
                                    set(range(len(input_indices))))
    return execute_fun

  def load(self) -> PmapExecutable:
    """Wraps this object in a ``PmapExecutable`` with a lazily built caller."""
    fingerprint = getattr(self.compiled, "fingerprint", None)
    return PmapExecutable(
        self.compiled, self.build_execute_fun, fingerprint,
        self.local_input_avals, self.jaxpr_debug_info, self)

  @staticmethod
  def from_hlo(hlo: ir.Module,
               pci: ParallelCallableInfo,
               replicas: ReplicaInfo,
               shards: ShardInfo,
               tuple_args: bool,
               unordered_effects: List[core.Effect],
               ordered_effects: List[core.Effect],
               host_callbacks: List[Any],
               keepalive: Any,
               jaxpr_debug_info: core.JaxprDebugInfo,
               compiler_options=None):
    """Picks a device assignment, compiles ``hlo`` and returns an executable.

    Raises:
      ValueError: if the number of shards does not match the available or
        explicitly requested devices.
    """
    devices = pci.devices
    if devices is None:
      if shards.num_global_shards > xb.device_count(pci.backend):
        msg = ("compiling computation that requires {} logical devices, but only {} XLA "
               "devices are available (num_replicas={})")
        raise ValueError(msg.format(shards.num_global_shards,
                                    xb.device_count(pci.backend),
                                    replicas.num_global_replicas))
      # On a single host, we simply grab the first N devices from jax.devices().
      # In the single host case, we want the default device order of pmap to
      # match jax.devices().
      # On multiple hosts, we create a default device assignment that ensures
      # each host is responsible for a contiguous set of replicas.
      if shards.num_global_shards > shards.num_local_shards:
        # TODO(skye): use a locality-aware assignment that satisfies the above
        # constraint.
        devices = [d for process_index in range(xb.process_count(pci.backend))
                   for d in xb.local_devices(process_index, pci.backend)]
      else:
        devices = xb.local_devices(backend=pci.backend)[:shards.num_local_shards]
    else:
      if shards.num_local_shards != len(pci.local_devices):
        local_devices_str = ", ".join(map(str, pci.local_devices))
        if shards.num_local_shards == pci.axis_size:
          raise ValueError(
              f"Leading axis size of input to pmapped function must equal the "
              f"number of local devices passed to pmap. Got axis_size="
              f"{pci.axis_size}, num_local_devices={len(pci.local_devices)}.\n"
              f"(Local devices available to pmap: {local_devices_str})")
        else:
          raise ValueError(
              f"pmapped function requires {shards.num_local_shards} local "
              f"devices to run due to nested pmapped or other parallel "
              f"functions, but only {len(pci.local_devices)} are available.\n"
              f"(outer axis size: {pci.axis_size}, local devices available to "
              f"pmap: {local_devices_str})")
      if shards.num_global_shards != len(devices):
        raise ValueError("compiling computation that creates %s shards, "
                         "but %s devices were specified" %
                         (shards.num_global_shards, len(devices)))

    # 'devices' may be 1D or 2D at this point (e.g.
    # get_default_device_assignment() returns 2D assignment, caller may have
    # provided 1D list of devices).
    # Convert to 2D in case it's 1D and we have > 1 partitions.
    num_partitions = 1
    device_assignment: np.ndarray = np.array(devices).reshape(
        (replicas.num_global_replicas, num_partitions))
    # NOTE(review): keyword names below were restored from memory of
    # `xb.get_compile_options`; confirm against the installed version.
    compile_options = xb.get_compile_options(
        num_replicas=replicas.num_global_replicas,
        num_partitions=num_partitions,
        device_assignment=device_assignment,
        use_spmd_partitioning=False,
        env_options_overrides=compiler_options,
    )
    compile_options.parameter_is_tupled_arguments = tuple_args

    process_index = xb.process_index(pci.backend)
    local_device_assignment = np.array([
        d for d in device_assignment.flat if d.process_index == process_index
    ])

    input_sharding_specs = [
        sharding_specs.pmap_sharding_spec(
            replicas.num_local_replicas, pci.axis_size,
            cast(ShapedArray, aval).shape, in_axis)
        for aval, in_axis in safe_zip(shards.sharded_avals, pci.in_axes)]
    in_shardings = _get_pmap_sharding(local_device_assignment,
                                      input_sharding_specs)

    local_unmapped_avals = [
        _cast_to_shaped_array(
            core.unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval))
        if out_axis is not None else aval
        for aval, out_axis in safe_zip(shards.out_sharded_avals, pci.out_axes)]
    out_specs = [
        sharding_specs.pmap_sharding_spec(
            replicas.num_local_replicas, pci.axis_size, aval.shape, out_axis)
        for aval, out_axis in safe_zip(
            shards.out_sharded_avals, pci.out_axes)]
    out_shardings = _get_pmap_sharding(local_device_assignment, out_specs)

    if hasattr(pci.backend, "compile_replicated"):
      input_indices = [
          sharding_specs.spec_to_indices(aval.shape, spec)
          if spec is not None else None
          for aval, spec in safe_zip(pci.avals, input_sharding_specs)
      ]
      handle_outs = local_avals_to_results_handler(local_unmapped_avals,
                                                   out_shardings)
      return _compile_replicated_pmap_executable_from_hlo(
          hlo, pci, input_indices, in_shardings, handle_outs,
          compile_options, host_callbacks, bool(unordered_effects),
          ordered_effects, jaxpr_debug_info)

    with dispatch.log_elapsed_time(
        "Finished XLA compilation of {fun_name} in {elapsed_time} sec",
        fun_name=pci.name, event=dispatch.BACKEND_COMPILE_EVENT):
      compiled = dispatch.compile_or_get_cached(
          pci.backend, hlo, device_assignment, compile_options,
          host_callbacks)

    # `.load()` keeps the return type consistent with the
    # `compile_replicated` branch above, which returns a PmapExecutable.
    return UnloadedPmapExecutable(
        compiled=compiled,
        backend=pci.backend,
        local_input_avals=pci.avals,
        input_shardings=in_shardings,
        local_output_avals=local_unmapped_avals,
        output_shardings=out_shardings,
        unordered_effects=unordered_effects,
        ordered_effects=ordered_effects,
        keepalive=keepalive,
        host_callbacks=host_callbacks,
        jaxpr_debug_info=jaxpr_debug_info).load()
def _compile_replicated_pmap_executable_from_hlo(
    hlo: ir.Module, pci, input_indices, in_shardings, handle_outs,
    compile_options, host_callbacks, has_unordered_effects, ordered_effects,
    jaxpr_debug_info):
  """Compiles ``hlo`` through the backend's ``compile_replicated`` hook.

  Returns a ``PmapExecutable`` with no XLA executable attached; its call
  function is whatever ``compile_replicated`` returned.
  """
  # Use the standard out_handler.
  execute_fun = pci.backend.compile_replicated(
      is_trivial=False, name=pci.name, computation=hlo,
      compile_options=compile_options, host_callbacks=host_callbacks,
      has_unordered_effects=has_unordered_effects,
      ordered_effects=ordered_effects, in_avals=pci.avals,
      in_indices=input_indices, in_shardings=in_shardings,
      kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs)
  # TODO(frostig): need `compile_replicated` to give us the XLA executable
  return PmapExecutable(None, lambda: execute_fun, None, pci.avals,
                        jaxpr_debug_info, None)
class PmapExecutable(stages.XlaExecutable):
  """Loaded pmap executable whose call function is built on first use."""
  __slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call",
               "fingerprint", "in_avals", "_jaxpr_debug_info",
               "_unloaded_executable"]

  def __init__(self, xla_executable, build_unsafe_call, fingerprint,
               in_avals, jaxpr_debug_info, unloaded_executable):
    self.xla_executable = xla_executable
    self._unsafe_call = None
    self.build_unsafe_call = build_unsafe_call
    self.fingerprint = fingerprint
    self.in_avals = in_avals
    self._jaxpr_debug_info = jaxpr_debug_info
    self._unloaded_executable = unloaded_executable

  @property
  def unsafe_call(self) -> Callable[..., Any]:
    """The underlying callable, built lazily via ``build_unsafe_call``."""
    if self._unsafe_call is None:
      self._unsafe_call = self.build_unsafe_call()
    return self._unsafe_call

  # -- stages.XlaExecutable overrides

  def xla_extension_executable(self):
    return self.xla_executable

  def call(self, *args):
    """Checks argument avals against ``in_avals`` and then executes."""
    # TODO(frostig): do we need to check sharding and sharded avals?
    arg_avals = map(xla.abstractify, args)
    check_arg_avals_for_call(self.in_avals, arg_avals, self._jaxpr_debug_info)
    return self.unsafe_call(*args)  # pylint: disable=not-callable
def _get_pmap_sharding(devices, specs):
return [sharding_impls.PmapSharding(devices, spec) for spec in specs]
# Collective primitives allowed in multi-host computations; starts out empty.
multi_host_supported_collectives: Set[core.Primitive] = set()


def check_multihost_collective_allowlist(jaxpr):
  """Raises ``TypeError`` if ``jaxpr`` uses a collective outside the allowlist."""
  used_collectives = set(xla.jaxpr_collectives(jaxpr))
  bad_collectives = used_collectives - multi_host_supported_collectives
  if bad_collectives:
    msg = "using collectives that aren't supported for multi-host: {}"
    raise TypeError(msg.format(", ".join(map(str, bad_collectives))))
class InputsHandler:
  """Callable that shards input buffers using a pre-bound ``shard_args``."""
  __slots__ = ("handler", "local_devices", "in_shardings", "input_indices")

  def __init__(self, local_devices, in_shardings, input_indices):
    self.handler = partial(
        shard_args, local_devices, input_indices, in_shardings)
    self.local_devices = local_devices
    self.in_shardings = in_shardings
    self.input_indices = input_indices

  def __call__(self, input_buffers):
    return self.handler(input_buffers)

  def __str__(self):
    # NOTE(review): the field lines were missing from the damaged source and
    # are reconstructed from the slot names; confirm the exact format.
    return ("InputsHandler(\n"
            f"local_devices={self.local_devices},\n"
            f"in_shardings={self.in_shardings},\n"
            f"input_indices={self.input_indices})")
class ResultsHandler:
  """Applies one handler per output to the matching group of buffers."""
  # `out_avals` is the `GlobalDeviceArray` global avals when using pjit or xmap
  # with `config.parallel_functions_output_gda=True`. It is the local one
  # otherwise, and also when using `pmap`.
  __slots__ = ("handlers", "out_shardings", "out_avals")

  def __init__(self, handlers, out_shardings, out_avals):
    self.handlers = handlers
    self.out_shardings = out_shardings
    self.out_avals = out_avals

  def __call__(self, out_bufs):
    paired = safe_zip(self.handlers, out_bufs)
    return [handler(bufs) for handler, bufs in paired]
def local_avals_to_results_handler(
    unmapped_local_out_avals: Sequence[ShapedArray],
    local_shardings: Sequence[sharding_impls.XLACompatibleSharding]) -> ResultsHandler:
  """Builds a ``ResultsHandler`` from local output avals and their shardings.

  For each (aval, sharding) pair the per-device indices are taken from
  ``sharding.devices_indices_map(aval.shape)`` and passed to
  ``local_aval_to_result_handler``.
  """
  out_indices = [tuple(s.devices_indices_map(aval.shape).values())
                 for s, aval in safe_zip(local_shardings, unmapped_local_out_avals)]
  handlers = [
      local_aval_to_result_handler(aval, s, idcs)
      for aval, s, idcs in safe_zip(unmapped_local_out_avals, local_shardings, out_indices)
  ]
  return ResultsHandler(handlers, local_shardings, unmapped_local_out_avals)
def global_avals_to_results_handler(
    global_out_avals: Sequence[ShapedArray],
    shardings: Sequence[sharding_impls.XLACompatibleSharding],
    committed: bool,
    are_out_shardings_from_xla: Sequence[bool]) -> ResultsHandler:
  """Builds a ``ResultsHandler`` from global output avals and shardings.

  Each handler comes from ``global_aval_to_result_handler`` and receives the
  matching entry of ``are_out_shardings_from_xla``.
  """
  handlers = [
      global_aval_to_result_handler(global_aval, s, committed, x)
      for global_aval, s, x in safe_zip(global_out_avals, shardings,
                                        are_out_shardings_from_xla)
  ]
  return ResultsHandler(handlers, shardings, global_out_avals)
def replicate(val, axis_size, nrep, devices=None, backend=None, in_axis=0):
  """Replicates ``val`` across multiple devices.

  Args:
    val: the value to be replicated.
    axis_size: the length of the output, i.e. the logical number of replicas to
      create. Usually equal to `nrep`, but in the case of nested pmaps, `nrep` may
      be a multiple of `axis_size`.
    nrep: the number of replicas to create. If ``devices`` is set, must be equal
      to ``len(devices)``.
    devices: the devices to replicate across. If None, ``nrep`` will be used to
      generate a default device assignment.
    backend: string specifying which backend to use.
    in_axis: axis along which the value is to be replicated.

  Returns:
    A ShardedDeviceArray of length `axis_size` where each shard is equal to
    ``val``.

  Raises:
    ValueError: if ``nrep`` exceeds the number of usable devices.
  """
  device_count = (len(devices) if devices else xb.local_device_count(backend))
  if nrep > device_count:
    msg = ("Cannot replicate across %d replicas because only %d local devices "
           "are available." % (nrep, device_count))
    if devices:
      msg += " (local devices = %s)" % ", ".join(map(str, devices))
    raise ValueError(msg)

  if devices is None:
    assert nrep is not None
    # TODO(skye): use different device assignment on multihost
    devices = xb.get_backend(backend).get_default_device_assignment(nrep)
  assert nrep == len(devices)

  aval = xla.abstractify(val)
  if in_axis is not None:
    replicated_aval = aval.update(shape=(axis_size,) + aval.shape)
  else:
    replicated_aval = aval
  # TODO(skye): figure out how partitioning should work here
  sharding_spec = sharding_specs.pmap_sharding_spec(
      nrep, axis_size, aval.shape, in_axis)

  # The value is placed once, on the first device, and that single buffer is
  # handed to `batched_device_put` once per target device.
  buf = jax.device_put(val, devices[0])
  sharding = sharding_impls.PmapSharding(
      np.asarray(list(devices)), sharding_spec)
  return batched_device_put(replicated_aval, sharding, [buf] * len(devices),
                            devices)
class ExecuteReplicated:
  """The logic to shard inputs, execute a replicated model, returning outputs.

  NOTE(review): recovered from a damaged copy. Calls marked below were
  truncated in the source and are restored from memory of the XLA client API;
  confirm them against upstream before relying on this block.
  """
  __slots__ = ['xla_executable', 'name', 'backend', 'in_handler', 'out_handler',
               'has_unordered_effects', 'ordered_effects', 'keepalive',
               'has_host_callbacks', '_local_devices', 'kept_var_idx',
               '__weakref__']

  def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
               out_handler: ResultsHandler,
               unordered_effects: List[core.Effect],
               ordered_effects: List[core.Effect], keepalive: Any,
               has_host_callbacks: bool, kept_var_idx: Set[int]):
    # The damaged source read `self.xla_executable = xla_executable = name`,
    # which stored the name string as the executable and never set `name`.
    self.xla_executable = xla_executable
    self.name = name
    self.backend = backend
    self.in_handler = in_handler
    self.out_handler = out_handler
    self.has_unordered_effects = bool(unordered_effects)
    self.ordered_effects = ordered_effects
    self._local_devices = self.xla_executable.local_devices()
    if ordered_effects:
      # Ordered effects are only supported on a single local device.
      assert len(self._local_devices) == 1
    self.keepalive = keepalive
    self.has_host_callbacks = has_host_callbacks
    self.kept_var_idx = kept_var_idx

  def _add_tokens_to_inputs(self, input_bufs):
    """Prepends one runtime token per ordered effect to ``input_bufs``."""
    if self.ordered_effects:
      device, = self._local_devices
      tokens = [list(dispatch.runtime_tokens.get_token(eff, device))
                for eff in self.ordered_effects]
      input_bufs = [*tokens, *input_bufs]
    return input_bufs

  def _handle_token_bufs(self, token_bufs, sharded_token):
    """Stores the output tokens back into ``dispatch.runtime_tokens``."""
    for i, device in enumerate(self._local_devices):
      # NOTE(review): callee name restored; only its arguments were visible.
      dispatch.runtime_tokens.set_output_runtime_token(
          device, sharded_token.get_token(i))
    for eff, token_buf in zip(self.ordered_effects, token_bufs):
      dispatch.runtime_tokens.update_token(eff, token_buf)

  def _call_with_tokens(self, input_bufs):
    input_bufs = self._add_tokens_to_inputs(input_bufs)
    # NOTE(review): callee name restored; the right-hand side was missing.
    out_bufs, sharded_token = (
        self.xla_executable.execute_sharded_on_local_devices_with_tokens(
            input_bufs))
    num_output_tokens = len(self.ordered_effects)
    token_bufs, out_bufs = util.split_list(out_bufs, [num_output_tokens])
    self._handle_token_bufs(token_bufs, sharded_token)
    return out_bufs

  def __call__(self, *args):
    """Shards ``args``, runs the executable and post-processes the outputs."""
    args = [x for i, x in enumerate(args) if i in self.kept_var_idx]
    input_bufs = self.in_handler(args)
    if (self.ordered_effects or self.has_unordered_effects
        or self.has_host_callbacks):
      input_bufs = self._add_tokens_to_inputs(input_bufs)
      results = self.xla_executable.execute_sharded(
          input_bufs, with_tokens=True
      )
      # NOTE(review): the token post-processing below was absent from the
      # damaged source. Without it the tokens requested via `with_tokens=True`
      # are never consumed. Restored from memory; confirm upstream.
      result_token_bufs = results.disassemble_prefix_into_single_device_arrays(
          len(self.ordered_effects))
      sharded_runtime_token = results.consume_token()
      self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
    else:
      results = self.xla_executable.execute_sharded(input_bufs)
    if dispatch.needs_check_special():
      out_arrays = results.disassemble_into_single_device_arrays()
      for arrays in out_arrays:
        dispatch.check_special(self.name, arrays)
      return self.out_handler(out_arrays)
    return results.consume_with_handlers(self.out_handler.handlers)
# Map primitive named 'xla_pmap'; `xla_pmap` is its `bind` entry point.
xla_pmap_p = core.MapPrimitive('xla_pmap')
xla_pmap = xla_pmap_p.bind
def _pmap_partial_eval_custom_params_updater(
    unks_in, inst_in, kept_outs_known, kept_outs_staged, num_res, params_known,
    params_staged):
  """Adjusts pmap params for the known / staged halves of a partial eval.

  The known half gains ``num_res`` residual outputs (mapped along axis 0); the
  staged half gains ``num_res`` residual inputs (axis 0, never donated).

  Returns:
    A pair ``(new_params_known, new_params_staged)``.
  """
  # prune inputs to jaxpr_known according to unks_in
  donated_invars_known, _ = partition_list(unks_in, params_known['donated_invars'])
  in_axes_known, _ = partition_list(unks_in, params_known['in_axes'])
  _, out_axes_known = partition_list(kept_outs_known, params_known['out_axes'])
  out_axes_known = out_axes_known + [0] * num_res
  new_params_known = dict(params_known, in_axes=tuple(in_axes_known),
                          out_axes=tuple(out_axes_known),
                          donated_invars=tuple(donated_invars_known))

  # added num_res new inputs to jaxpr_staged, pruning according to inst_in
  _, donated_invars_staged = partition_list(inst_in, params_staged['donated_invars'])
  donated_invars_staged = [False] * num_res + donated_invars_staged
  _, in_axes_staged = partition_list(inst_in, params_staged['in_axes'])
  in_axes_staged = [0] * num_res + in_axes_staged
  _, out_axes_staged = partition_list(kept_outs_staged, params_staged['out_axes'])
  new_params_staged = dict(params_staged, in_axes=tuple(in_axes_staged),
                           out_axes=tuple(out_axes_staged),
                           donated_invars=tuple(donated_invars_staged))
  return new_params_known, new_params_staged
def _pmap_partial_eval_custom_res_maker(params_known, aval):
  """Returns ``aval`` unmapped along axis 0 with the known params' axis size."""
  axis_size = params_known['axis_size']
  return core.unmapped_aval(axis_size, core.no_axis_name, 0, aval)
def _pmap_dce_rule(used_outputs, eqn):
  """Dead-code-elimination rule for ``xla_pmap`` equations.

  Returns:
    ``(used_inputs, new_eqn)``, where ``new_eqn`` is ``None`` when nothing is
    used and the pruned jaxpr has no effects.
  """
  # just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes
  with maybe_extend_axis_env(eqn.params['axis_name'],
                             eqn.params['global_axis_size'], None):
    new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
  _, donated_invars = partition_list(used_inputs, eqn.params['donated_invars'])
  _, in_axes = partition_list(used_inputs, eqn.params['in_axes'])
  _, out_axes = partition_list(used_outputs, eqn.params['out_axes'])
  # `donated_invars` must be pruned together with `in_axes` so that both stay
  # aligned with the surviving invars of the new equation.
  new_params = dict(eqn.params, call_jaxpr=new_jaxpr,
                    donated_invars=tuple(donated_invars),
                    in_axes=tuple(in_axes), out_axes=tuple(out_axes))
  if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects:
    return used_inputs, None
  new_eqn = pe.new_jaxpr_eqn(
      [v for v, used in zip(eqn.invars, used_inputs) if used],
      [v for v, used in zip(eqn.outvars, used_outputs) if used],
      eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
  return used_inputs, new_eqn
# Set param update handlers to update `donated_invars` just like xla_call_p
pe.call_param_updaters[xla_pmap_p] = xla.xla_call_partial_eval_update_params
# NOTE(review): the `partial(pe.call_partial_eval_custom_rule, ...)` head and
# the `res_aval=` argument were missing from the damaged source; restored so
# that `_pmap_partial_eval_custom_res_maker` is actually used. Confirm upstream.
pe.partial_eval_jaxpr_custom_rules[xla_pmap_p] = \
    partial(pe.call_partial_eval_custom_rule,
            'call_jaxpr', _pmap_partial_eval_custom_params_updater,
            res_aval=_pmap_partial_eval_custom_res_maker)
pe.dce_rules[xla_pmap_p] = _pmap_dce_rule
ad.call_param_updaters[xla_pmap_p] = xla.xla_call_jvp_update_params
ad.call_transpose_param_updaters[xla_pmap_p] = xla.xla_call_transpose_update_params
ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)
def _pmap_axis_subst(params, subst, traverse):
if 'call_jaxpr' not in params:
return params
if not traverse:
return params
def shadowed_subst(name):
return (name,) if name in params['axis_name'] else subst(name)
with maybe_extend_axis_env(params['axis_name'],
params['global_axis_size'], None):
new_jaxpr = core.subst_axis_names_jaxpr(params['call_jaxpr'],
return dict(params, call_jaxpr=new_jaxpr)
# Register the axis-name substitution rule for the pmap primitive.
core.axis_substitution_rules[xla_pmap_p] = _pmap_axis_subst
def _unravel_index_hlo(axis_env):
  """Emits HLO computing ``(replica_id // div) % axis_env.sizes[-1]``.

  NOTE(review): the divisor expression was deleted in the damaged source
  (``axis_env.nreps //, np.uint32``). It is restored as
  ``axis_env.nreps // math.prod(axis_env.sizes)``; confirm upstream.
  """
  div = mlir.ir_constant(
      np.array(axis_env.nreps // math.prod(axis_env.sizes), np.uint32))
  mod = mlir.ir_constant(np.array(axis_env.sizes[-1], np.uint32))
  return hlo.RemOp(
      hlo.DivOp(hlo.ReplicaIdOp().result, div).result, mod).result
def _hlo_shard(aval, axis_env, xs, in_axis):
  """Emits HLO slicing this replica's shard out of ``xs`` along ``in_axis``.

  Tokens pass through unchanged. For a ``ShapedArray`` a size-1 dynamic slice
  is taken at the index from ``_unravel_index_hlo`` and then reshaped to
  ``aval``'s IR type.

  Raises:
    TypeError: if ``aval`` is neither the abstract token nor a ``ShapedArray``.
  """
  if aval is core.abstract_token:
    return xs
  elif isinstance(aval, core.ShapedArray):
    x, = xs
    dims = list(aval.shape)
    zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
    idxs = [zero] * len(dims)
    idxs.insert(in_axis, _unravel_index_hlo(axis_env))
    dims_unsqueezed = dims.copy()
    dims_unsqueezed.insert(in_axis, 1)
    dynamic_slice_result = hlo.DynamicSliceOp(
        x, idxs, mlir.dense_int_elements(dims_unsqueezed)).result
    return [
        hlo.ReshapeOp(mlir.aval_to_ir_type(aval), dynamic_slice_result).result
    ]
  else:
    raise TypeError(aval)
# TODO(b/110096942): more efficient gather
def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, platform):
  """Emits HLO gathering per-replica values into one array along ``out_axis``.

  Each replica writes its value into a zero-filled buffer at its own index and
  the buffers are summed across the replica group. Tokens pass through.

  Raises:
    TypeError: if ``aval`` is neither the abstract token nor a ``ShapedArray``.
  """
  if aval is core.abstract_token:
    return xs
  elif isinstance(aval, core.ShapedArray):
    x, = xs
    # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
    convert_bool = (np.issubdtype(aval.dtype, np.bool_)
                    and platform in ('cpu', 'gpu'))
    if convert_bool:
      aval = aval.update(dtype=np.dtype(np.float32))
      x = hlo.ConvertOp(mlir.aval_to_ir_type(aval), x).result

    dims = list(aval.shape)
    padded_aval = aval.update(shape=[axis_env.sizes[-1]] + dims)
    padded = mlir.full_like_aval(ctx, 0, padded_aval)
    zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
    idxs = [_unravel_index_hlo(axis_env)] + [zero] * len(dims)
    broadcast_result = hlo.BroadcastOp(
        x, mlir.dense_int_elements([1])).result
    padded = hlo.DynamicUpdateSliceOp(padded, broadcast_result, idxs).result
    replica_groups = mlir.dense_int_elements(
        xla.axis_groups(axis_env, axis_env.names[-1]))
    out = hlo.CrossReplicaSumOp(padded, replica_groups).result
    if out_axis != 0:
      # TODO(apaszke,mattjj): Change the indices to DynamicUpdateSlice instead
      perm = list(range(1, len(dims)))
      perm.insert(out_axis, 0)
      transposed_dims = list(dims)
      transposed_dims.insert(out_axis, axis_env.sizes[-1])
      aval = aval.update(shape=transposed_dims)
      out = hlo.TransposeOp(out, mlir.dense_int_elements(perm)).result

    # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
    if convert_bool:
      float_zero = mlir.full_like_aval(ctx, 0, padded_aval)
      # NOTE(review): the CompareOp arguments were missing from the damaged
      # source; restored as a float "not equal to zero" comparison. Confirm.
      out = hlo.CompareOp(
          out,
          float_zero,
          hlo.ComparisonDirectionAttr.get("NE"),
          compare_type=hlo.ComparisonTypeAttr.get("FLOAT")).result
    return out
  else:
    raise TypeError(aval)
def _pmap_lowering(ctx, *in_nodes, axis_name,
                   axis_size, global_axis_size, devices, name,
                   call_jaxpr, backend=None, in_axes, out_axes,
                   donated_invars, is_explicit_global_axis_size):
  """MLIR lowering for ``xla_pmap`` nested inside another computation.

  Mapped inputs are sliced per replica with ``_hlo_shard``, the body jaxpr is
  lowered inline, and outputs are re-assembled with ``_hlo_unshard``.

  NOTE(review): several continuation lines were missing from the damaged
  source (marked below) and are restored from memory; confirm upstream.

  Raises:
    ValueError: on a nested pmap that specifies explicit ``devices``.
  """
  del donated_invars  # Unused.
  xla.check_backend_matches(backend, ctx.module_context.platform)
  # We in-line here rather than generating a Call HLO as in the xla_call
  # translation rule just because the extra tuple stuff is a pain.
  if ctx.module_context.axis_env.names and devices is not None:
    raise ValueError("Nested pmap with explicit devices argument.")
  new_env = xla.extend_axis_env(ctx.module_context.axis_env, axis_name,
                                global_axis_size)  # restored argument
  # Shard the in_nodes that are mapped
  in_avals = [v.aval for v in call_jaxpr.invars]
  in_nodes_sharded = (
      _hlo_shard(aval, new_env, mlir.wrap_singleton_ir_values(in_node), in_axis)
      if in_axis is not None else mlir.wrap_singleton_ir_values(in_node)
      for aval, in_node, in_axis in zip(in_avals, in_nodes, in_axes))

  with maybe_extend_axis_env(axis_name, global_axis_size, None):  # type: ignore
    sub_ctx = ctx.module_context.replace(
        axis_context=sharding_impls.ReplicaAxisContext(new_env),  # restored
        name_stack=ctx.module_context.name_stack.extend(          # restored
            util.wrap_name(name, 'pmap')))
    sharded_outs, _ = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, mlir.TokenSet(), (),
                                         *in_nodes_sharded,  # restored
                                         dim_var_values=ctx.dim_var_values)
  out_avals = [v.aval for v in call_jaxpr.outvars]
  outs = [_hlo_unshard(ctx, aval, new_env, out_axis, shard,
                       platform=ctx.module_context.platform)  # restored
          for aval, out_axis, shard in zip(out_avals, out_axes, sharded_outs)]
  return outs

mlir.register_lowering(xla_pmap_p, _pmap_lowering)
# ------------------- xmap -------------------
def tile_aval_nd(axis_sizes, in_axes: ArrayMapping, aval):
  """Divides the mapped dimensions of ``aval`` by their axis sizes.

  Every name in ``in_axes`` is also recorded in the result's ``named_shape``.
  Asserts that each mapped dimension is divisible by its axis size and that
  the name is not already present in ``named_shape``.
  """
  assert isinstance(aval, ShapedArray)
  tiled_shape = list(aval.shape)
  names = dict(aval.named_shape)
  for axis_name, dim in in_axes.items():
    tile = axis_sizes[axis_name]
    assert tiled_shape[dim] % tile == 0
    assert axis_name not in names
    names[axis_name] = tile
    tiled_shape[dim] //= tile
  return aval.update(shape=tuple(tiled_shape), named_shape=names)
def untile_aval_nd(axis_sizes, out_axes: ArrayMapping, aval):
  """Multiplies the mapped dimensions of ``aval`` by their axis sizes.

  Names in ``out_axes`` are dropped from the result's ``named_shape``.
  """
  assert isinstance(aval, ShapedArray)
  full_shape = list(aval.shape)
  names = dict(aval.named_shape)
  for axis_name, dim in out_axes.items():
    full_shape[dim] *= axis_sizes[axis_name]
    # The name might be missing --- it's a broadcast.
    names.pop(axis_name, None)
  return aval.update(shape=tuple(full_shape), named_shape=names)
def mesh_local_to_global(mesh, axes: ArrayMapping, aval):
  """Tiles ``aval`` by the local mesh shape, then untiles by the full mesh shape."""
  tiled = tile_aval_nd(mesh.local_mesh.shape, axes, aval)
  return untile_aval_nd(mesh.shape, axes, tiled)
def mesh_global_to_local(mesh, axes: ArrayMapping, aval):
  """Tiles ``aval`` by the full mesh shape, then untiles by the local mesh shape."""
  tiled = tile_aval_nd(mesh.shape, axes, aval)
  return untile_aval_nd(mesh.local_mesh.shape, axes, tiled)
class SPMDBatchTrace(batching.BatchTrace):
  """Batch trace that prefers the rules in ``spmd_primitive_batchers``."""

  def get_axis_primitive_batcher(self, primitive, frame):
    if primitive in spmd_primitive_batchers:
      # NOTE(review): the middle argument was deleted in the damaged source
      # (`frame.size,, ...`); restored as `frame.name`. Confirm upstream.
      return partial(spmd_primitive_batchers[primitive],
                     frame.size, frame.name, frame.main_trace.trace_type)
    return super().get_axis_primitive_batcher(primitive, frame)


# Per-primitive SPMD batching rules consulted by SPMDBatchTrace.
spmd_primitive_batchers: Dict[core.Primitive, Callable] = {}
def vtile_by_mesh(fun: lu.WrappedFun,
                  mesh: Mesh,
                  in_axes: Sequence[ArrayMapping],
                  out_axes: Sequence[ArrayMapping]):
  """Wraps ``fun`` in one ``batching.vtile`` per mesh axis.

  NOTE(review): the trailing keyword arguments of the ``vtile`` call were
  missing from the damaged source and are restored from memory; confirm them.
  """
  # We vectorize in reversed order, because vmap is often biased towards
  # moving the batch axis to the front, and this way of stacking transforms
  # will order the batch axes according to the mesh axis order.
  # Not strictly necessary, but seems nicer than reversing it?
  for name, size in reversed(mesh.shape.items()):
    fun = batching.vtile(fun,
                         tuple(a.get(name, None) for a in in_axes),
                         tuple(a.get(name, None) for a in out_axes),
                         tile_size=size,
                         axis_name=name,
                         main_type=SPMDBatchTrace)
  return fun
full_to_shard_p = core.Primitive('full_to_shard')

# NOTE(review): the registration decorator was missing from the damaged
# source, which left this rule unused; restored here. Confirm upstream.
@full_to_shard_p.def_abstract_eval
def _full_to_shard_abstract_eval(x, axes, mesh, **_):
  """Abstract eval for ``full_to_shard``: tiles ``x`` by the mesh shape."""
  # TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes!
  return tile_aval_nd(mesh.shape, axes, x)
def manual_proto(
    aval: core.ShapedArray,
    manual_axes_set: FrozenSet[sharding_impls.MeshAxisName], mesh: Mesh):
  """Create an OpSharding proto that declares all mesh axes from `axes` as manual
  and all others as replicated.

  The tile assignment has one size-1 dimension per array dimension of ``aval``
  followed by two trailing dimensions: the product of the replicated mesh axis
  sizes, then the product of the manual mesh axis sizes.
  """
  named_mesh_shape = mesh.shape
  mesh_shape = list(named_mesh_shape.values())
  axis_order = {axis: i for i, axis in enumerate(mesh.axis_names)}

  # Sorted by string form for a deterministic order of the manual axes.
  manual_axes = sorted(manual_axes_set, key=str)
  replicated_axes = [axis for axis in mesh.axis_names
                     if axis not in manual_axes_set]

  tad_perm = ([axis_order[a] for a in replicated_axes] +
              [axis_order[a] for a in manual_axes])
  tad_shape = [1] * aval.ndim
  tad_shape.append(math.prod([named_mesh_shape[a] for a in replicated_axes]))
  tad_shape.append(math.prod([named_mesh_shape[a] for a in manual_axes]))

  raw_mesh = np.arange(math.prod(mesh_shape)).reshape(mesh_shape)
  proto = xc.OpSharding()
  proto.type = xc.OpSharding.Type.OTHER
  proto.tile_assignment_dimensions = tad_shape
  proto.tile_assignment_devices = list(raw_mesh.transpose(tad_perm).reshape(tad_shape).flat)
  proto.last_tile_dims = [xc.OpSharding.Type.REPLICATED, xc.OpSharding.Type.MANUAL]
  return proto
@partial(mlir.register_lowering, full_to_shard_p)
def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
                            manual_axes: FrozenSet[sharding_impls.MeshAxisName]):
  """Lowers ``full_to_shard``: a sharding annotation, then a full-to-shard op.

  Array dimensions not named in ``axes`` are passed as unspecified dims.
  Returns a 1-tuple holding the resulting value.
  """
  # TODO: Can we short-circuit for replicated values? Probably not.
  aval_in, = ctx.avals_in
  aval_out, = ctx.avals_out
  spec_builder = mesh_sharding_specs(mesh.shape, mesh.axis_names)
  sharding_proto = spec_builder(aval_in, axes).sharding_proto().to_proto()
  unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
  sharded_x = mlir.wrap_with_sharding_op(
      ctx, x, aval_in, sharding_proto, unspecified_dims=unspecified_dims)
  proto = manual_proto(aval_in, manual_axes, mesh)
  result = mlir.wrap_with_full_to_shard_op(
      ctx, sharded_x, aval_out, proto, unspecified_dims=unspecified_dims)
  return (result,)
shard_to_full_p = core.Primitive('shard_to_full')

# NOTE(review): the registration decorator was missing from the damaged
# source, which left this rule unused; restored here. Confirm upstream.
@shard_to_full_p.def_abstract_eval
def _shard_to_full_abstract_eval(x, axes, mesh, **_):
  """Abstract eval for ``shard_to_full``: untiles ``x`` by the mesh shape."""
  # TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes!
  return untile_aval_nd(mesh.shape, axes, x)
@partial(mlir.register_lowering, shard_to_full_p)
def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapping, mesh: Mesh,
                            manual_axes: FrozenSet[sharding_impls.MeshAxisName]):
  """Lowers ``shard_to_full``: a manual-sharding annotation, then shard-to-full.

  Array dimensions not named in ``axes`` are passed as unspecified dims.
  Returns a 1-tuple holding the resulting value.
  """
  aval_in, = ctx.avals_in
  aval_out, = ctx.avals_out
  proto = manual_proto(aval_in, manual_axes, mesh)  # type: ignore
  unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())  # type: ignore
  sharded_x = mlir.wrap_with_sharding_op(
      ctx, x, aval_in, proto, unspecified_dims=unspecified_dims)
  spec_builder = mesh_sharding_specs(mesh.shape, mesh.axis_names)
  sharding_proto = spec_builder(aval_out, axes).sharding_proto().to_proto()
  result = mlir.wrap_with_shard_to_full_op(
      ctx, sharded_x, aval_out, sharding_proto, unspecified_dims)
  return (result,)
# NOTE(review): the decorator and the `*args` parameter were missing from the
# damaged source. `args` is used in the body, and the two-`yield` shape matches
# the `lu.transformation` generator protocol. Confirm upstream.
@lu.transformation
def vtile_manual(manual_axes: FrozenSet[sharding_impls.MeshAxisName],
                 mesh: Mesh,
                 in_axes: Sequence[ArrayMapping],
                 out_axes: Sequence[ArrayMapping],
                 *args):
  """Binds ``full_to_shard`` on every input and ``shard_to_full`` on every output."""
  tiled_args = [full_to_shard_p.bind(arg, axes=axes, mesh=mesh, manual_axes=manual_axes)
                for arg, axes in zip(args, in_axes)]
  tiled_outs = yield tiled_args, {}
  outs = [shard_to_full_p.bind(out, axes=axes, mesh=mesh, manual_axes=manual_axes)
          for out, axes in zip(tiled_outs, out_axes)]
  yield outs
# NOTE(review): both class bodies/decorators were lost in the damaged source.
# `TileManual` declares an annotated field with no `__init__`, so a dataclass
# is required; `frozen=True` is restored from memory. Confirm upstream.
@dataclasses.dataclass(frozen=True)
class TileVectorize:
  """Marker for the vectorizing tiling method; carries no data."""


@dataclasses.dataclass(frozen=True)
class TileManual:
  """Marker for the manual tiling method."""
  # Mesh axes that are handled manually.
  manual_axes: FrozenSet[sharding_impls.MeshAxisName]


TilingMethod = Union[TileVectorize, TileManual]
def check_if_any_auto(
    shardings: Iterable[Union[sharding_impls.XLACompatibleSharding,
                              AUTO, UnspecifiedValue]]) -> bool:
  """Returns True if ``is_auto`` holds for at least one entry of ``shardings``."""
  return any(is_auto(sharding) for sharding in shardings)
class MismatchType(enum.Enum):
  """Where a device assignment taking part in a mismatch came from.

  NOTE(review): the member definitions and the third ``__str__`` branch were
  missing from the damaged source. ``CONTEXT_DEVICES`` and
  ``SHARDING_INSIDE_COMPUTATION`` are referenced elsewhere in this file and
  ``IN_SHARDING`` / ``OUT_SHARDING`` appear as names in ``__str__``; the
  remaining member and the numeric values are restored from memory.
  """
  ARG_SHARDING = 0
  OUT_SHARDING = 1
  SHARDING_INSIDE_COMPUTATION = 2
  CONTEXT_DEVICES = 3
  IN_SHARDING = 4

  def __str__(self):
    if self.name == 'IN_SHARDING':
      return 'explicit input sharding'
    elif self.name == 'OUT_SHARDING':
      return 'explicit output sharding'
    elif self.name == 'CONTEXT_DEVICES':
      return 'devices'
    return f'{self.name}'
@dataclasses.dataclass
class DeviceAssignmentMismatch:
  """One side of a device-assignment conflict, with message-formatting helpers.

  NOTE(review): the ``@dataclass`` / ``@property`` decorators were missing from
  the damaged source. They are required because the fields have no
  ``__init__`` and ``device_ids`` / ``platform`` are interpolated as
  attributes. The second half of ``_str`` is restored from memory.
  """
  da: Sequence[xc.Device]
  m_type: MismatchType
  source_info: Optional[dispatch.SourceInfo]

  @property
  def device_ids(self) -> Sequence[int]:
    return [d.id for d in self.da]

  @property
  def platform(self) -> str:
    return self.da[0].platform.upper()

  def _maybe_api_name(self, api_name) -> str:
    return f" {api_name}'s" if self.m_type == MismatchType.CONTEXT_DEVICES else ""

  @property
  def source_info_str(self):
    return "" if self.source_info is None else f" at {self.source_info.source_info}"

  @property
  def _dev_ids_plat_str(self):
    return f"device ids {self.device_ids} on platform {self.platform}"

  def m_type_str(self, api_name):
    return (f'{self.source_info and self.source_info.eqn_name} inside {api_name}'
            if self.m_type == MismatchType.SHARDING_INSIDE_COMPUTATION else self.m_type)

  def _str(self, api_name):
    return (f"{self._maybe_api_name(api_name)} {self.m_type_str(api_name)} with "
            f"{self._dev_ids_plat_str}{self.source_info_str}")
class DeviceAssignmentMismatchError(Exception):
  """Raised when shardings or devices disagree on the device assignment."""
# A (sharding, mismatch type, optional source info) triple.
ShardingInfo = Tuple[
Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue, AUTO],
MismatchType, Optional[Any]] # Any is dispatch.SourceInfo to avoid circular imports
def _get_default_device() -> xc.Device:
  """Returns ``config.jax_default_device`` if set, else the first local device."""
  configured = config.jax_default_device
  if configured:
    return configured
  return xb.local_devices()[0]
def _get_and_check_device_assignment(
    shardings: Iterable[ShardingInfo],
    devices: Optional[Sequence[xc.Device]],
) -> Tuple[xc.Client, Tuple[xc.Device, ...]]:
  """Checks that all shardings agree on one device assignment and returns it.

  Unspecified shardings are ignored. With explicit ``devices`` every sharding
  must match them; otherwise every sharding must match the first specified
  one. If nothing constrains the assignment, the default device is used.

  NOTE(review): the ``else``/``continue`` lines and the first element of the
  first error list were missing from the damaged source and are restored.

  Returns:
    ``(backend, device_assignment)``, the backend being that of the first
    device in the assignment.

  Raises:
    DeviceAssignmentMismatchError: carrying the two conflicting assignments.
  """
  first_sharding_info = None
  devices = () if devices is None else tuple(devices)

  for i, s_type, source_info in shardings:
    if is_unspecified(i):
      continue

    arr_device_assignment = i.mesh._flat_devices_tuple if is_auto(i) else i._device_assignment  # type: ignore
    if first_sharding_info is None:
      first_sharding_info = (arr_device_assignment, s_type, source_info)
    if not devices:
      if first_sharding_info[0] != arr_device_assignment:
        raise DeviceAssignmentMismatchError([
            DeviceAssignmentMismatch(*first_sharding_info),
            DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])
    else:
      if devices != arr_device_assignment:
        raise DeviceAssignmentMismatchError([
            DeviceAssignmentMismatch(devices, MismatchType.CONTEXT_DEVICES, None),
            DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])

  if first_sharding_info is None and devices:
    final_device_assignment = devices
  elif first_sharding_info is None:
    final_device_assignment = (_get_default_device(),)
  else:
    final_device_assignment = first_sharding_info[0]
  return xb.get_device_backend(final_device_assignment[0]), final_device_assignment
# Either a concrete XLA-compatible sharding or the unspecified placeholder.
MaybeSharding = Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue]
def cache_wrap(fn):
  """Caches ``fn`` with a cache chosen by the type of its first argument.

  Calls whose first argument is an ``lu.WrappedFun`` go through ``lu.cache``;
  all other calls go through ``weakref_lru_cache``.
  """
  lu_cached = lu.cache(fn)
  weakref_cached = weakref_lru_cache(fn)

  def wrapped(f, *args, **kwargs):
    cached_fn = lu_cached if isinstance(f, lu.WrappedFun) else weakref_cached
    return cached_fn(f, *args, **kwargs)

  return wrapped
# NOTE(review): the `@cache_wrap` decorator and both `else:` lines were missing
# from the damaged source. The `else:` lines are required: without them the
# ClosedJaxpr assert runs for WrappedFun inputs and inputs are always pruned.
# The decorator is restored from memory. Confirm upstream.
@cache_wrap
def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name,
                            keep_unused, donated_invars, auto_spmd_lowering):
  """Obtains a closed jaxpr for ``fun_or_jaxpr`` and prunes unused inputs.

  A ``WrappedFun`` is traced; a ``ClosedJaxpr`` is unpacked as is. Inputs are
  kept in full when ``keep_unused`` or ``auto_spmd_lowering`` is set, or when
  any input aval has a non-constant shape.

  Returns:
    ``(closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
    kept_var_idx, name_stack)``.
  """
  name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))

  if isinstance(fun_or_jaxpr, lu.WrappedFun):
    with dispatch.log_elapsed_time(
        "Finished tracing + transforming {fun_name} in {elapsed_time} sec",
        fun_name=str(name_stack), event=dispatch.JAXPR_TRACE_EVENT):
      jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
          fun_or_jaxpr, global_in_avals)
  else:
    assert isinstance(fun_or_jaxpr, core.ClosedJaxpr)
    jaxpr = fun_or_jaxpr.jaxpr
    global_out_avals = fun_or_jaxpr.out_avals
    consts = fun_or_jaxpr.consts

  if (keep_unused or auto_spmd_lowering or
      any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
          for a in global_in_avals)):
    kept_var_idx = set(range(len(global_in_avals)))
  else:
    jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr)
    consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
    global_in_avals = tuple(a for i, a in enumerate(global_in_avals) if i in kept_var_idx)
    donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx)
    del kept_const_idx

  jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
  return (closed_jaxpr, global_in_avals, tuple(global_out_avals), donated_invars,
          kept_var_idx, name_stack)
# NOTE(review): the dataclass decorator was missing from the damaged source.
# The class declares a field but no `__init__`, so it is restored;
# `frozen=True` is from memory. Confirm upstream.
@dataclasses.dataclass(frozen=True)
class SemanticallyEqualShardings:
  """Tuple of shardings that hashes and compares by HLO-sharding semantics.

  ``GSPMDSharding`` entries are hashed by ``_hlo_sharding_hash`` and compared
  with ``op_shardings.are_op_shardings_equal``; other entries use their own
  ``hash`` / ``==``.
  """
  shardings: Tuple[Union[sharding_impls.GSPMDSharding, UnspecifiedValue], ...]

  def __hash__(self):
    return hash(tuple(
        s._hlo_sharding_hash if isinstance(s, sharding_impls.GSPMDSharding) else s  # type: ignore
        for s in self.shardings))

  def __eq__(self, other):
    if not isinstance(other, SemanticallyEqualShardings):
      return False
    # `zip` stops at the shorter tuple, so without this check a tuple would
    # compare equal to any longer tuple sharing its prefix while hashing
    # differently.
    if len(self.shardings) != len(other.shardings):
      return False
    return all(op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding)
               if (isinstance(s, sharding_impls.GSPMDSharding) and
                   isinstance(o, sharding_impls.GSPMDSharding))
               else s == o for s, o in zip(self.shardings, other.shardings))
# Lowers `closed_jaxpr` to an MLIR module with `mlir.lower_jaxpr_to_module` and
# returns the module together with the metadata collected while lowering:
# (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
#  nreps, tuple_args, shape_poly_state).
# NOTE(review): as shown here several statements are incomplete (the log call
# under `isEnabledFor`, the statement after `nreps = ...`, the `else:` before
# the `jit(pmap)` path, the argument list of `lower_jaxpr_to_module` and the
# `unordered_effects` expression) -- confirm against the intended code.
def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
semantic_in_shardings, semantic_out_shardings,
da_object, lowering_platform,
donated_invars, name_stack):
jaxpr = closed_jaxpr.jaxpr
# Unwrap the sharding wrappers to get the underlying sharding tuples.
in_shardings = semantic_in_shardings.shardings
out_shardings = semantic_out_shardings.shardings
global_in_avals = closed_jaxpr.in_avals
global_out_avals = closed_jaxpr.out_avals
device_assignment = da_object.device_assignment
# Compilation messages are logged at WARNING when `jax_log_compiles` is set,
# otherwise at DEBUG.
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
if logger.isEnabledFor(log_priority):
"Compiling %s for with global shapes and types %s. "
"Argument mapping: %s.",
fun_name, global_in_avals, in_shardings)
# Look at the number of replicas present in the jaxpr. In
# lower_sharding_computation, nreps > 1 during `jit(pmap)` cases. This is
# handled here so as to deprecate the lower_xla_callable codepath when
# `jax.Array` is turned on by default.
# TODO(yashkatariya): Remove this when `jit(pmap)` is removed.
nreps = dispatch.jaxpr_replicas(jaxpr)
nreps, backend, fun_name, jaxpr)
in_mlir_shardings: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
out_mlir_shardings: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
axis_ctx: mlir.AxisContext
# Single-replica case: every argument gets its logical sharding and the
# number of partitions equals the number of assigned devices.
if nreps == 1:
in_mlir_shardings = map(_to_logical_sharding, global_in_avals, in_shardings)
out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings)
replicated_args = [False] * len(global_in_avals)
axis_ctx = sharding_impls.ShardingContext(device_assignment)
num_partitions = len(device_assignment)
# This path is triggered for `jit(pmap)` cases.
replicated_args = None
in_mlir_shardings = None
out_mlir_shardings = None
axis_env = sharding_impls.AxisEnv(nreps, (), ())
axis_ctx = sharding_impls.ReplicaAxisContext(axis_env)
num_partitions = 1
module_name = f"{api_name}_{fun_name}"
# Ordered effects are rejected when more than one device is assigned.
if len(device_assignment) > 1:
if any(effects.ordered_effects.contains(eff) for eff
in closed_jaxpr.effects):
raise ValueError("Ordered effects are not supported for more than 1 device.")
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
with dispatch.log_elapsed_time(
"Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
lowering_result = mlir.lower_jaxpr_to_module(
# Optionally, override the lowering platform
lowering_platform or backend.platform,
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
unordered_effects = list(
return (lowering_result.module, lowering_result.keepalive,
lowering_result.host_callbacks, unordered_effects, ordered_effects,
nreps, tuple_args, lowering_result.shape_poly_state)
class _DeviceAssignment:
device_assignment: Tuple[xc.Device, ...]
def _hash(self):
return hash(self.device_assignment)
def __hash__(self):
return self._hash
def __eq__(self, other):
if not isinstance(other, _DeviceAssignment):
return False
if id(self) == id(other):
return True
return (self.device_assignment == other.device_assignment)
def is_fully_addressable(self):
return len(self.device_assignment) == len(self.addressable_device_assignment)
def addressable_device_assignment(self):
return [d for d in self.device_assignment
if d.process_index == d.client.process_index()]
def _create_da_object(
    device_assignment: Tuple[xc.Device, ...]) -> _DeviceAssignment:
  """Wraps `device_assignment` in a `_DeviceAssignment`."""
  da_object = _DeviceAssignment(device_assignment)
  return da_object
# NOTE(review): as shown here the parameter list has a non-default parameter
# after a defaulted one, and the `_get_and_check_device_assignment(...)` and
# final `MeshComputation(...)` calls are incomplete -- confirm against the
# intended code.
def lower_sharding_computation(
fun_or_jaxpr: Union[lu.WrappedFun, core.ClosedJaxpr],
api_name: str,
fun_name: str,
in_shardings: Sequence[MaybeSharding],
out_shardings: Union[Sequence[MaybeSharding], UnspecifiedValue],
donated_invars: Sequence[bool],
global_in_avals: Sequence[core.ShapedArray],
keep_unused: bool,
inline: bool,
always_lower: bool,
devices_from_context: Optional[Sequence[xc.Device]] = None,
lowering_platform: Optional[str],
) -> MeshComputation:
"""Lowers a computation to XLA. It can take arbitrary shardings as input.

The caller of this code can pass in a singleton UNSPECIFIED because the
number of out_avals might not be known at that time and
lower_sharding_computation calculates the number of out_avals so it can apply
the singleton UNSPECIFIED to all out_avals.
"""
# 1. Trace to jaxpr and preprocess/verify it
# `check_if_any_auto` looks at the inputs only when `out_shardings` is the
# UNSPECIFIED singleton, otherwise at inputs and outputs together.
auto_spmd_lowering = (
check_if_any_auto(in_shardings) if is_unspecified(out_shardings) else
check_if_any_auto(it.chain.from_iterable([in_shardings, out_shardings]))) # type: ignore
(closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
kept_var_idx, name_stack) = _trace_to_jaxpr_and_dce(
fun_or_jaxpr, global_in_avals, api_name, fun_name, keep_unused,
donated_invars, auto_spmd_lowering)
jaxpr = closed_jaxpr.jaxpr
# Drop the shardings of inputs that were pruned during tracing/DCE.
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
# Expand the UNSPECIFIED singleton to one entry per output.
if is_unspecified(out_shardings):
out_shardings = (UNSPECIFIED,) * len(global_out_avals)
assert isinstance(out_shardings, tuple)
assert len(out_shardings) == len(global_out_avals), (
len(out_shardings), len(global_out_avals))
# Device assignment across all inputs, outputs and shardings inside jaxpr
# should be the same.
jaxpr_sharding = list(dispatch.jaxpr_shardings(jaxpr))
backend, device_assignment = _get_and_check_device_assignment(
it.chain([(i, MismatchType.ARG_SHARDING, None) for i in in_shardings],
[(o, MismatchType.OUT_SHARDING, None) for o in out_shardings],
[(js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
for js, source_info in jaxpr_sharding]),
# The result is committed if devices came from the context, if more than one
# device is assigned, or if any sharding (argument, output, or inside the
# jaxpr) was given explicitly.
committed = bool(
devices_from_context or
len(device_assignment) > 1 or
any(not is_unspecified(i) for i in in_shardings) or
any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
any(not is_unspecified(o) for o in out_shardings))
# Unspecified input shardings default to a replicated GSPMDSharding over the
# device assignment.
gs = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
da_object = _create_da_object(tuple(device_assignment))
if not da_object.is_fully_addressable:
if inline and config.jax_spmd_mode != 'allow_all':
raise RuntimeError(
"Running operations on `Array`s that are not fully addressable by this "
"process (i.e. `Array`s with data sharded across multiple devices and "
"processes.) is dangerous. Its very important that all processes run "
"the same cross-process computations in the same order otherwise it "
"can lead to hangs. "
"If youre not already familiar with JAXs multi-process "
"programming model, please read "
" "
"To fix this error, run your `jitted` computation inside "
"`with jax.spmd_mode('allow_all'):` context manager.")
has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
kept_outputs = [True] * len(global_out_avals)
# Computations that only produce constants and/or only rearrange their inputs,
# which are often produced from partial evaluation, don't need compilation,
# and don't need to evaluate their arguments.
if (not always_lower and not (jaxpr.effects or has_outfeed) and
(not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and
all(is_unspecified(o) for o in out_shardings)):
# Trivial computation: no HLO is built (hlo=None, is_trivial=True).
return MeshComputation(
str(name_stack), None, True, donated_invars, jaxpr=jaxpr,
consts=closed_jaxpr.consts, global_in_avals=global_in_avals,
global_out_avals=global_out_avals, in_shardings=in_shardings,
backend=backend, da_object=da_object,
committed=committed, kept_var_idx=kept_var_idx, keepalive=None)
# 2. Build up the HLO
semantic_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore
semantic_out_shardings = SemanticallyEqualShardings(out_shardings)
(module, keepalive, host_callbacks, unordered_effects, ordered_effects,
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
semantic_out_shardings, da_object, lowering_platform,
donated_invars, name_stack)
# backend and device_assignment is passed through to MeshExecutable because
# if keep_unused=False and all in_shardings are pruned, then there is no way
# to get the device_assignment and backend. So pass it to MeshExecutable
# because we calculate the device_assignment and backend before in_shardings,
# etc are pruned.
return MeshComputation(
def _to_logical_sharding(
    aval: core.AbstractValue, sharding: Union[MaybeSharding, AUTO]
) -> Optional[sharding_impls.XLACompatibleSharding]:
  """Returns the sharding to hand to lowering for `aval`, or None.

  Unspecified and AUTO shardings map to None, as do the shardings of token
  avals. For a `ShapedArray` the sharding is returned unchanged after asserting
  that it is an `XLACompatibleSharding`.

  Raises:
    TypeError: if `aval` is neither a `ShapedArray` nor an `AbstractToken`.
  """
  if is_unspecified(sharding) or is_auto(sharding):
    return None
  if isinstance(aval, ShapedArray):
    assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
    return sharding
  if not isinstance(aval, core.AbstractToken):
    raise TypeError(aval)
  return None
# Traces (or unpacks) the computation for `mesh`, lowers it to an MLIR module
# with `mlir.lower_jaxpr_to_module`, and returns a `MeshComputation`.
# NOTE(review): as shown here many statements are incomplete (the
# `out_shardings` annotation, the log call, the `else:` branches separating
# the spmd / non-spmd paths, `AxisEnv(...)`, the effect filters, the lowering
# call arguments and the final `MeshComputation(...)`) -- confirm against the
# intended code.
def lower_mesh_computation(
fun_or_jaxpr: Union[lu.WrappedFun, core.ClosedJaxpr],
api_name: str,
fun_name: str,
mesh: Mesh,
in_shardings: Sequence[Union[sharding_impls.NamedSharding, AUTO]],
out_shardings: Sequence[Union[sharding_impls.NamedSharding, AUTO,
donated_invars: Sequence[bool],
spmd_lowering: bool,
global_in_avals: Sequence[core.ShapedArray],
tiling_method: Optional[TilingMethod],
lowering_platform: Optional[str]) -> MeshComputation:
assert not mesh.empty
# The backend is taken from the first device of the mesh.
backend = xb.get_device_backend(mesh.devices.flat[0])
name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
global_axis_sizes = mesh.shape
# Compilation messages are logged at WARNING when `jax_log_compiles` is set,
# otherwise at DEBUG.
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
if logger.isEnabledFor(log_priority):
"Compiling %s for %s mesh with global shapes and types %s. "
"Argument mapping: %s.",
fun_name, tuple(global_axis_sizes.items()), global_in_avals,
# 1. Trace to jaxpr and preprocess/verify it
if spmd_lowering:
manual_axes: FrozenSet[MeshAxisName] = frozenset()
# TODO: Consider handling xmap's 'vectorize' in here. We can vmap once instead of vtile twice!
if tiling_method is not None:
# Pick the tiling transform matching the requested tiling method.
if isinstance(tiling_method, TileVectorize):
tiling_transform = vtile_by_mesh
elif isinstance(tiling_method, TileManual):
tiling_transform = lambda f, *args: vtile_manual(f, tiling_method.manual_axes, *args) # type: ignore
manual_axes = tiling_method.manual_axes
raise NotImplementedError(f"Unrecognized tiling method: {tiling_method}")
assert not callable(out_shardings)
assert isinstance(fun_or_jaxpr, lu.WrappedFun)
# This is the xmap path where there is no `AUTO` or `UNSPECIFIED`, which
# is why `.spec` can be accessed.
fun_or_jaxpr = tiling_transform(
fun_or_jaxpr, mesh, [get_array_mapping(i.spec) for i in in_shardings], # type: ignore
[get_array_mapping(o.spec) for o in out_shardings]) # type: ignore
in_jaxpr_avals = global_in_avals
assert isinstance(tiling_method, TileVectorize)
# In non-spmd lowering path, there is no `AUTO` or `UNSPECIFIED`, which is
# why `.spec` can be accessed.
in_tiled_avals = [tile_aval_nd(global_axis_sizes, get_array_mapping(i.spec), aval) # type: ignore
for aval, i in safe_zip(global_in_avals, in_shardings)]
in_jaxpr_avals = in_tiled_avals
# Tracing happens with the mesh axes added to the axis environment.
with core.extend_axis_env_nd(mesh.shape.items()):
if isinstance(fun_or_jaxpr, lu.WrappedFun):
with dispatch.log_elapsed_time(
"Finished tracing + transforming {fun_name} in {elapsed_time} sec",
fun_name=str(name_stack), event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(
fun_or_jaxpr, in_jaxpr_avals)
assert isinstance(fun_or_jaxpr, core.ClosedJaxpr)
jaxpr = fun_or_jaxpr.jaxpr
out_jaxpr_avals = fun_or_jaxpr.out_avals
consts = fun_or_jaxpr.consts
assert len(out_shardings) == len(out_jaxpr_avals)
if spmd_lowering:
global_out_avals = out_jaxpr_avals
# In non-spmd lowering path, there is no `AUTO` or `UNSPECIFIED`, which is
# why `.spec` can be accessed.
global_out_avals = [untile_aval_nd(global_axis_sizes, get_array_mapping(o.spec), aval) # type: ignore
for aval, o in safe_zip(out_jaxpr_avals, out_shardings)]
if mesh.is_multi_process:
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
# 2. Build up the HLO
tuple_args = dispatch.should_tuple_args(len(in_jaxpr_avals), backend.platform)
in_partitions: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
out_partitions: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
axis_ctx: mlir.AxisContext
# SPMD lowering uses one replica partitioned across all mesh devices.
if spmd_lowering:
in_partitions = map(_to_logical_sharding, global_in_avals, in_shardings)
out_partitions = map(_to_logical_sharding, global_out_avals, out_shardings)
replicated_args = [False] * len(in_jaxpr_avals)
axis_ctx = sharding_impls.SPMDAxisContext(mesh, manual_axes)
num_replicas = 1
num_partitions = mesh.devices.size
# Replica-based settings: one replica per mesh device and a single partition;
# an argument is marked replicated when its array mapping is empty.
replicated_args = [not get_array_mapping(i.spec) for i in in_shardings] # type: ignore
in_partitions = None
out_partitions = None
axis_env = sharding_impls.AxisEnv(
axis_ctx = sharding_impls.ReplicaAxisContext(axis_env)
num_replicas = mesh.devices.size
num_partitions = 1
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
module_name = f"{api_name}_{fun_name}"
with core.extend_axis_env_nd(mesh.shape.items()):
# Ordered effects are rejected outright for mesh computations.
if any(effects.ordered_effects.contains(eff) for eff
in closed_jaxpr.effects):
raise ValueError("Ordered effects not supported in mesh computations.")
unordered_effects = list(effects.ordered_effects.filter_not_in(
ordered_effects = list(effects.ordered_effects.filter_in(
with dispatch.log_elapsed_time(
"Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
lowering_result = mlir.lower_jaxpr_to_module(
lowering_platform or backend.platform,
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
return MeshComputation(
# Lowered form of a mesh computation: holds the module (None for trivial
# computations), the donation flags and the keyword arguments later used for
# compilation, and caches the executable built by `compile`.
# NOTE(review): as shown here the `compile` signature and the two executable
# construction calls are incomplete -- confirm against the intended code.
class MeshComputation(stages.XlaLowering):
_hlo: Optional[ir.Module]
_executable: Optional[MeshExecutable]
def __init__(self, name: str, hlo: Optional[ir.Module],
is_trivial: bool, donated_invars: Sequence[bool], **compile_args):
self._name = name
self._hlo = hlo
self.is_trivial = is_trivial
self._donated_invars = donated_invars
# Remaining keyword arguments are stored as-is for use at compile time.
self.compile_args = compile_args
# Filled in by the first `compile` call made without compiler options.
self._executable = None
# -- stages.XlaLowering overrides
def stablehlo(self) -> ir.Module:
if self.is_trivial:
raise ValueError("A trivial computation has no HLO")
return self._hlo
def compile(
) -> MeshExecutable:
# Recompile when nothing is cached yet or when explicit compiler options are
# given; only the option-less result is cached on the instance.
if self._executable is None or compiler_options is not None:
if self.is_trivial:
executable = MeshExecutable.from_trivial_jaxpr(
executable = UnloadedMeshExecutable.from_hlo(
if compiler_options is None:
self._executable = executable
return executable
return self._executable
def cost_analysis(self) -> Dict[str, float]:
backend = self.compile_args["backend"]
# Pre-compilation cost analysis is not offered on PJRT C API backends.
if xb.using_pjrt_c_api(backend):
raise NotImplementedError(
"Lowered.cost_analysis not implemented on platform "
f"'{backend.platform}'. Use compile().cost_analysis() for "
"post-compilation cost estimates.")
return xe.hlo_module_cost_analysis(backend, self.hlo().as_hlo_module())
def _get_replicated_slices(num_addressable_devices: int, ndim: Optional[int]):
if ndim is None:
return ((slice(None),),) * num_addressable_devices
return ((slice(None),) * ndim,) * num_addressable_devices
def _get_input_indices(
    avals: Sequence[ShapedArray],
    shardings: Sequence[sharding_impls.XLACompatibleSharding],
    da_object: Union[_DeviceAssignment, Sequence[xc.Device]],
) -> Sequence[Tuple[Optional[Index], ...]]:
  """Computes, for every input, its per-addressable-device indices.

  Args:
    avals: abstract values of the inputs.
    shardings: shardings of the inputs, paired positionally with `avals`.
    da_object: either a `_DeviceAssignment` or a plain sequence of devices;
      only devices addressable by the current process are counted.

  Returns:
    A list with one entry per input. Token avals and fully replicated
    shardings get full-slice indices for every addressable device; any other
    sharding contributes the values of its
    `addressable_devices_indices_map(aval.shape)`.
  """
  if isinstance(da_object, _DeviceAssignment):
    num_addressable_devices = len(da_object.addressable_device_assignment)
  else:
    num_addressable_devices = len(
        [d for d in da_object if d.process_index == d.client.process_index()])

  input_indices = []
  for aval, sharding in zip(avals, shardings):
    if aval is core.abstract_token:
      index = _get_replicated_slices(num_addressable_devices, None)
    elif sharding.is_fully_replicated:
      index = _get_replicated_slices(num_addressable_devices, aval.ndim)
    else:
      index = tuple(
          sharding.addressable_devices_indices_map(aval.shape).values())  # type: ignore
    input_indices.append(index)
  return input_indices
def get_gspmd_shardings_from_executable(
    xla_executable, device_assignment: Sequence[xc.Device],
    num_in_avals: int, num_out_avals: int
) -> Tuple[Sequence[sharding_impls.XLACompatibleSharding],
           Sequence[sharding_impls.XLACompatibleSharding]]:
  """Reads the input and output shardings chosen for a compiled executable.

  Args:
    xla_executable: the compiled executable to query.
    device_assignment: devices the executable runs on.
    num_in_avals: number of inputs of the computation.
    num_out_avals: number of outputs of the computation.

  Returns:
    A pair `(in_shardings, out_shardings)`. With a single device both are
    lists of the same `SingleDeviceSharding`; otherwise they are
    `GSPMDSharding`s built from the executable's op shardings.
  """
  from jax.experimental import pjit

  # When the device assignment only has 1 device, SPMD partitioner will not run.
  # Hence the op shardings will not be set on the `hlo_module`. In that case,
  # just return SingleDeviceShardings since we know the computation is running
  # only on 1 device.
  if len(device_assignment) == 1:
    ss = sharding_impls.SingleDeviceSharding(device_assignment[0])
    return [ss] * num_in_avals, [ss] * num_out_avals

  in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)

  in_shardings_xla = [sharding_impls.GSPMDSharding(device_assignment, i)
                      for i in in_op_shardings]
  out_shardings_xla = [sharding_impls.GSPMDSharding(device_assignment, o)
                       for o in out_op_shardings]
  # This condition happens when all the elements in the output tuple have the
  # same sharding, so XLA decides to run the `FusionTupleDeduplicator` to
  # put the sharding on ROOT instead of the tuple.
  # TODO(b/245667823): Remove this when XLA fixes this.
  if len(out_shardings_xla) == 1 and num_out_avals > 1:
    out_shardings_xla = out_shardings_xla * num_out_avals
  assert len(out_shardings_xla) == num_out_avals, (
      len(out_shardings_xla), num_out_avals)
  return in_shardings_xla, out_shardings_xla
# TODO(yashkatariya): Remove this function after `AUTO` can return shardings
# without mesh.
def _get_mesh_pspec_shardings_from_executable(
    xla_executable, mesh: Mesh
) -> Tuple[Sequence[sharding_impls.NamedSharding],
           Sequence[sharding_impls.NamedSharding]]:
  """Returns the executable's input/output shardings as `NamedSharding`s.

  The partition specs are obtained from `pjit._get_pspec_from_executable` and
  each one is paired with `mesh`.
  """
  from jax.experimental import pjit

  in_pspec, out_pspec = pjit._get_pspec_from_executable(xla_executable, mesh)
  return ([sharding_impls.NamedSharding(mesh, i) for i in in_pspec],
          [sharding_impls.NamedSharding(mesh, o) for o in out_pspec])
# Type variable ranging over `XLACompatibleSharding` subclasses.
SubClassT = TypeVar("SubClassT", bound=sharding_impls.XLACompatibleSharding)
# A handler takes an `xc.OpSharding` plus an existing sharding of some class
# and returns a sharding of that same class; handlers are keyed by the class.
OrigHandlerType = Dict[Type[SubClassT],
Callable[[xc.OpSharding, SubClassT], SubClassT]]
# Registry of handlers; looked up by `type(orig_sharding)` when output
# shardings are mapped back to the type of an original input sharding.
orig_out_sharding_handlers: OrigHandlerType = {}
def _gspmd_to_named_sharding(
    op_sharding: xc.OpSharding,
    self: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
  """Builds a `NamedSharding` on `self.mesh` equivalent to `op_sharding`."""
  mesh = self.mesh
  parsed_pspecs = sharding_impls.parse_flatten_op_sharding(op_sharding, mesh)
  # Only the first parsed spec is used.
  first_pspec = parsed_pspecs[0]
  partition_spec = first_pspec.get_partition_spec()
  return create_mesh_pspec_sharding(mesh, partition_spec, first_pspec)

# Register the handler for `NamedSharding` originals.
orig_out_sharding_handlers[sharding_impls.NamedSharding] = _gspmd_to_named_sharding
def _gspmd_to_positional_sharding(
    op_sharding: xc.OpSharding,
    self: sharding_impls.PositionalSharding) -> sharding_impls.PositionalSharding:
  """Builds a `PositionalSharding` over `self`'s devices from `op_sharding`."""
  devices = self._device_assignment
  return sharding_impls._from_op_sharding_to_pos_sharding(op_sharding, devices)

# Register the handler for `PositionalSharding` originals.
orig_out_sharding_handlers[sharding_impls.PositionalSharding] = _gspmd_to_positional_sharding
# Builds a list of `(sharding, from_xla)` pairs, one per output, converting
# `GSPMDSharding` outputs to the type of the original sharding `orig_s` via the
# handler registered for `type(orig_s)`.
# NOTE(review): as shown here the `safe_zip(...)` loop header, the equality
# call inside the condition, and the branch structure around the four
# `out.append(...)` calls are incomplete -- confirm against the intended code.
def _get_out_sharding_from_orig_sharding(
out_shardings, out_avals, orig_s, orig_aval, are_out_sharding_from_xla):
out = []
# Raises KeyError if no handler is registered for the type of `orig_s`.
orig_handler = orig_out_sharding_handlers[type(orig_s)]
for o, out_aval, from_xla in safe_zip(out_shardings, out_avals,
if isinstance(o, sharding_impls.GSPMDSharding):
# Only return the same input sharding object if the OpShardings and
# in_aval.ndim and out_aval.ndim match. This is because if OpSharding is
# replicated then, it doesn't encode the ndim in it. The devices
# will be the same at this point because those checks happen before.
if (orig_aval is not None and out_aval is not None and
out_aval.ndim == orig_aval.ndim and
o._hlo_sharding, orig_s._to_xla_hlo_sharding(orig_aval.ndim))):
# Reused or converted shardings are always reported as not-from-XLA.
out.append((orig_s, False))
out.append((orig_handler(o._hlo_sharding, orig_s), False))
out.append((o, from_xla))
out.append((o, from_xla))
return out
def maybe_get_orig_out_sharding(
    in_shardings, out_shardings, are_out_shardings_from_xla, in_avals,
    out_avals):
  """Maps output shardings back to user-facing ("original") shardings.

  Args:
    in_shardings: input shardings, paired positionally with `in_avals`.
    out_shardings: output shardings, paired positionally with `out_avals`.
    are_out_shardings_from_xla: per-output flags passed through (or replaced
      by False for outputs that are mapped back to an original sharding).
    in_avals: abstract values of the inputs.
    out_avals: abstract values of the outputs.

  Returns:
    A pair `(out_shardings, are_out_shardings_from_xla)`:
      * if every output sharding carries `_original_sharding`, those originals
        with all flags False;
      * otherwise, if some input sharding carries an `_original_sharding` whose
        type has a registered handler, the result of converting the outputs
        relative to it (when several inputs qualify, the last one is used);
      * otherwise the arguments unchanged.
  """
  if all(hasattr(o, '_original_sharding') for o in out_shardings):
    return ([o._original_sharding for o in out_shardings],
            (False,) * len(out_shardings))

  orig_s = None
  orig_aval = None
  for i, aval in safe_zip(in_shardings, in_avals):
    oi = getattr(i, '_original_sharding', None)
    if type(oi) in orig_out_sharding_handlers:
      orig_s = oi
      orig_aval = aval
  if orig_s is not None:
    # Transpose the list of (sharding, from_xla) pairs into two sequences.
    return zip(*_get_out_sharding_from_orig_sharding(
        out_shardings, out_avals, orig_s, orig_aval, are_out_shardings_from_xla))

  return out_shardings, are_out_shardings_from_xla
# Builds XLA compile options for `computation` and compiles it, returning
# `(xla_executable, compile_options)`. When the backend exposes
# `compile_replicated`, compilation is skipped here and `None` is returned in
# place of the executable.
# NOTE(review): as shown here the signature, the `np.vectorize(...)` call, the
# `else:` branches, `xb.get_compile_options(...)` and the mesh-ids assignment
# are incomplete -- confirm against the intended code.
def _cached_compilation(computation, name, mesh, spmd_lowering,
tuple_args, auto_spmd_lowering,
_allow_propagation_to_outputs, host_callbacks, backend,
da, pmap_nreps, compiler_options_keys,
# `da` is either a `_DeviceAssignment` or a plain sequence of devices.
device_assignment = da.device_assignment if isinstance(
da, _DeviceAssignment) else da
# TODO(phawkins): One would normally just write:
# dev = np.array(device_assignment)
# The formulation below is substantially faster if there are many devices.
# If we were to optimize __getattr__ on xc.Device we might not need this
# workaround.
dev = np.vectorize(lambda i: device_assignment[i], otypes=[object])(
# Replica/partition counts: pmap replicas win, then SPMD (all devices are
# partitions); the remaining assignment treats every device as a replica.
if pmap_nreps > 1:
num_replicas, num_partitions = pmap_nreps, 1
elif spmd_lowering:
num_replicas, num_partitions = 1, dev.size
num_replicas, num_partitions = dev.size, 1
if pmap_nreps > 1:
# In `jit` device_assignment is set to None when num_replicas > 1. Do
# the same thing here too.
xla_device_assignment = None
xla_device_assignment = dev.reshape((num_replicas, num_partitions))
# Compiler options arrive as parallel key/value sequences and are zipped back
# into a dict here.
if compiler_options_keys is None:
compiler_options = None
compiler_options = dict(safe_zip(compiler_options_keys, compiler_options_values))
compile_options = xb.get_compile_options(
opts = compile_options.executable_build_options
if auto_spmd_lowering:
assert mesh is not None
opts.auto_spmd_partitioning_mesh_shape = list(mesh.shape.values())
opts.auto_spmd_partitioning_mesh_ids = (
compile_options.parameter_is_tupled_arguments = tuple_args
opts.allow_spmd_sharding_propagation_to_output = list(_allow_propagation_to_outputs)
# Backends with `compile_replicated` are compiled by the caller instead.
if hasattr(backend, "compile_replicated"):
return None, compile_options
with dispatch.log_elapsed_time(
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
xla_executable = dispatch.compile_or_get_cached(
backend, computation, dev, compile_options, host_callbacks)
return xla_executable, compile_options
# Record of a compiled executable plus everything needed to build the callable
# that runs it (`build_unsafe_call`) and to wrap it in a `MeshExecutable`
# (`load`). `from_hlo` compiles a module and resolves its final shardings.
# NOTE(review): as shown here several statements are incomplete (the
# `_get_input_indices(...)` call, the `ExecuteReplicated(...)` arguments, the
# `from_hlo` signature, the `else:` branches in the sharding resolution, and
# the final `UnloadedMeshExecutable(...)` call) -- confirm against the intended
# code.
class UnloadedMeshExecutable:
xla_executable: Any
device_assignment: Union[_DeviceAssignment, Sequence[xc.Device]]
backend: xb.XlaBackend
input_avals: Sequence[ShapedArray]
input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
output_avals: Sequence[ShapedArray]
output_shardings: Sequence[sharding_impls.XLACompatibleSharding]
committed: bool
are_out_shardings_from_xla: Sequence[bool]
name: str
unordered_effects: List[core.Effect]
ordered_effects: List[core.Effect]
keepalive: Sequence[Any]
host_callbacks: Sequence[Any]
kept_var_idx: Set[int]
auto_spmd_lowering: bool
jaxpr_debug_info: Optional[core.JaxprDebugInfo]
# Assembles the input handler, output handler and `ExecuteReplicated` callable
# for this executable.
def build_unsafe_call(self):
input_indices = _get_input_indices(self.input_avals, self.input_shardings,
handle_args = InputsHandler(self.xla_executable.local_devices(),
self.input_shardings, input_indices)
handle_outs = global_avals_to_results_handler(
self.output_avals, self.output_shardings, self.committed,
self.are_out_shardings_from_xla) # type: ignore # arg-type
unsafe_call = ExecuteReplicated( # type: ignore # assignment
self.xla_executable,, self.backend, handle_args,
handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive,
bool(self.host_callbacks), self.kept_var_idx)
return unsafe_call
# Wraps this record in a `MeshExecutable`; the bound `build_unsafe_call` is
# passed uncalled so the callable is built lazily.
def load(self) -> MeshExecutable:
return MeshExecutable(self.xla_executable, self.build_unsafe_call,
self.input_shardings, self.output_shardings,
self.auto_spmd_lowering, self.kept_var_idx,
self.jaxpr_debug_info, self)
# May return a MeshExecutable in the compile_replicated case.
def from_hlo(name: str,
hlo: ir.Module,
global_in_avals: Sequence[ShapedArray],
global_out_avals: Sequence[ShapedArray],
in_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, AUTO]],
out_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, AUTO,
spmd_lowering: bool,
tuple_args: bool,
auto_spmd_lowering: bool,
unordered_effects: List[core.Effect],
ordered_effects: List[core.Effect],
host_callbacks: List[Any],
keepalive: Any,
kept_var_idx: Set[int],
backend: xb.XlaBackend,
device_assignment: Union[_DeviceAssignment, Sequence[xc.Device]],
committed: bool,
pmap_nreps: int = 1,
jaxpr_debug_info: Optional[core.JaxprDebugInfo] = None,
shape_poly_state: Optional[mlir.ShapePolyLoweringState] = None,
) -> MeshExecutable:
# Modules that use dimension variables are refined before compilation.
if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
hlo = refine_shape_polymorphism(hlo)
# The options dict is split into parallel key/value tuples before being
# passed to `_cached_compilation`.
compiler_options_keys = tuple(
compiler_options.keys()) if compiler_options is not None else None
compiler_options_values = tuple(
compiler_options.values()) if compiler_options is not None else None
# Normalize a plain device sequence to a tuple; keep `_DeviceAssignment` as is.
da = device_assignment if isinstance(
device_assignment, _DeviceAssignment) else tuple(device_assignment)
del device_assignment
# Sharding propagation to an output is allowed exactly when that output's
# sharding is unspecified.
allow_prop_to_outputs = tuple(is_unspecified(o) for o in out_shardings)
mesh = None
if auto_spmd_lowering:
# Take the mesh from an AUTO sharding among the inputs/outputs.
for i in it.chain.from_iterable([in_shardings, out_shardings]):
if is_auto(i):
mesh = i.mesh # type: ignore
xla_executable, compile_options = _cached_compilation(
hlo, name, mesh, spmd_lowering,
tuple_args, auto_spmd_lowering, allow_prop_to_outputs,
tuple(host_callbacks), backend, da, pmap_nreps,
compiler_options_keys, compiler_options_values)
if hasattr(backend, "compile_replicated"):
semantics_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore
semantics_out_shardings = SemanticallyEqualShardings(out_shardings) # type: ignore
return _compile_replicated_mesh_executable_from_hlo(
hlo, name, tuple(global_in_avals), tuple(global_out_avals),
semantics_in_shardings, semantics_out_shardings, auto_spmd_lowering,
compile_options, tuple(host_callbacks), bool(unordered_effects),
tuple(ordered_effects), tuple(kept_var_idx), backend, da, committed,
pmap_nreps, jaxpr_debug_info)
if auto_spmd_lowering:
# AUTO shardings are replaced by the ones read back from the executable;
# explicit ones are kept (inputs prefer their `_original_sharding`).
assert mesh is not None
in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
xla_executable, mesh)
in_shardings = [x if is_auto(i) else getattr(i, '_original_sharding', i) # type: ignore
for x, i in safe_zip(in_shardings_xla, in_shardings)]
out_shardings_tuple = [
(x, True) if is_auto(o) else (o, False)
for x, o in safe_zip(out_shardings_xla, out_shardings)
out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple)
elif (out_shardings and any(is_unspecified(o) for o in out_shardings)
and pmap_nreps == 1):
# Some outputs are unspecified: read the output shardings back from the
# executable.
assert mesh is None
device_assignment = da.device_assignment if isinstance( # type: ignore
da, _DeviceAssignment) else da
_, out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore
xla_executable, device_assignment, # type: ignore
len(global_in_avals), len(global_out_avals))
orig_out_shardings = out_shardings
out_shardings, are_out_shardings_from_xla = [], [] # type: ignore
for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings,
if is_unspecified(orig):
# A mismatch between the sharding read from XLA and `orig` is reported as an
# internal error.
if not op_shardings.are_op_shardings_equal(
xla_s._to_xla_hlo_sharding(aval.ndim), # type: ignore
orig._to_xla_hlo_sharding(aval.ndim)): # type: ignore
raise AssertionError(
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
"(User sharding)")
are_out_shardings_from_xla = (False,) * len(global_out_avals)
if pmap_nreps > 1:
# jit(pmap): shardings, commitment and devices come from the local devices.
in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
xla_executable.local_devices(), len(in_shardings), len(out_shardings))
out_shardings, are_out_shardings_from_xla = maybe_get_orig_out_sharding(
in_shardings, out_shardings, are_out_shardings_from_xla,
global_in_avals, global_out_avals)
return UnloadedMeshExecutable(
device_assignment=da, # type: ignore
input_shardings=in_shardings, # type: ignore
output_shardings=out_shardings, # type: ignore # arg-type
# Immutable record built by `MeshExecutable.create_cpp_call`'s cache-miss
# function when every flat output is an `xc.ArrayImpl`; it is returned
# alongside the outputs of that call.
class MeshExecutableFastpathData(NamedTuple):
xla_executable: xc.LoadedExecutable
out_pytree_def: Any
in_shardings: Sequence[sharding_impls.XLACompatibleSharding]
out_shardings: Sequence[sharding_impls.XLACompatibleSharding]
out_avals: Sequence[ShapedArray]
out_committed: Sequence[bool]
# One flag per flat argument: True if its index is among the kept inputs.
kept_var_bitvec: Iterable[bool]
# Loaded executable for a mesh computation. Holds the XLA executable, its
# input avals/shardings and output shardings, and lazily builds the callable
# that actually runs it.
# NOTE(review): as shown here the `__slots__` list, the final
# `MeshExecutable(...)` call in `from_trivial_jaxpr`, the sharding check in
# `call`, and two expressions in `create_cpp_call` are incomplete -- confirm
# against the intended code.
class MeshExecutable(stages.XlaExecutable):
__slots__ = [
"xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
"_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
"_jaxpr_debug_info", "_unloaded_executable",
def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
out_shardings, auto_spmd_lowering, kept_var_idx,
jaxpr_debug_info=None, unloaded_executable=None):
self.xla_executable = xla_executable
# Zero-argument factory; invoked on first access of `unsafe_call`.
self.build_unsafe_call = build_unsafe_call
# in_avals is a list of global and local avals. Aval is global if input
# is a GDA or jax.Array else local.
self.in_avals = in_avals
self._unsafe_call = None
self._in_shardings = in_shardings
self._out_shardings = out_shardings
self._auto_spmd_lowering = auto_spmd_lowering
self._kept_var_idx = kept_var_idx
self._jaxpr_debug_info = jaxpr_debug_info
self._unloaded_executable = unloaded_executable
# Builds the callable on first use and caches it.
def unsafe_call(self) -> Callable[..., Any]:
if self._unsafe_call is None:
self._unsafe_call = self.build_unsafe_call()
return self._unsafe_call
# Builds an executable for a trivial jaxpr: nothing is compiled
# (`xla_executable` is None) and calls go through `_execute_trivial`.
def from_trivial_jaxpr(jaxpr, consts, global_in_avals, global_out_avals,
in_shardings, backend, da_object,
committed, kept_var_idx, keepalive) -> MeshExecutable:
assert keepalive is None
if hasattr(backend, "compile_replicated"):
return _compile_replicated_mesh_executable_from_trivial_jaxpr(
jaxpr, consts, global_in_avals, global_out_avals, in_shardings,
backend, da_object, committed, kept_var_idx, 1)
out_shardings = _out_shardings_for_trivial(
jaxpr, consts, in_shardings, da_object.device_assignment)
# Note that the input handler is built from the *output* avals/shardings.
indices = _get_input_indices(global_out_avals, out_shardings, da_object)
local_device_assignment = da_object.addressable_device_assignment
handle_ins = InputsHandler(local_device_assignment, out_shardings, indices)
handle_outs = global_avals_to_results_handler(
global_out_avals, out_shardings, committed,
[False] * len(global_out_avals))
unsafe_call = partial(_execute_trivial, jaxpr, consts, handle_ins,
handle_outs, kept_var_idx)
return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
in_shardings, out_shardings, False, kept_var_idx,
# -- stages.XlaExecutable overrides
def xla_extension_executable(self):
return self.xla_executable
# Validates the kept arguments against the compiled avals and shardings, then
# runs the computation with all arguments.
def call(self, *args):
kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
arg_avals = map(xla.abstractify, kept_args)
ref_avals = self.in_avals
check_arg_avals_for_call(ref_avals, arg_avals, self._jaxpr_debug_info)
# Check the GDA sharding and the input sharding.
check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings,
return self.unsafe_call(*args) # pylint: disable=not-callable
def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
return self._in_shardings
def output_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
return self._out_shardings
# Returns a C++ pjit-backed callable, or None when the underlying call is not
# an `ExecuteReplicated` or has unordered effects or host callbacks.
def create_cpp_call(self, no_kwargs, in_tree, out_tree):
if not (isinstance(self.unsafe_call, ExecuteReplicated) and
not self.unsafe_call.has_unordered_effects and
not self.unsafe_call.has_host_callbacks):
return None
def aot_cache_miss(*args, **kwargs):
params = stages.CompiledCallParams(self, no_kwargs, in_tree, out_tree)
outs, out_flat, args_flat =, *args, **kwargs)
# Fast-path data is only produced when every flat output is an `xc.ArrayImpl`.
use_fastpath = (all(isinstance(x, xc.ArrayImpl) for x in out_flat))
if use_fastpath:
out_avals = [o.aval for o in out_flat]
out_committed = [o._committed for o in out_flat]
kept_var_bitvec = [i in self._kept_var_idx
for i in range(len(args_flat))]
fastpath_data = MeshExecutableFastpathData(
self.xla_executable, out_tree, self._in_shardings,
self._out_shardings, out_avals, out_committed, kept_var_bitvec)
fastpath_data = None
return outs, fastpath_data
return xc._xla.pjit(, None, aot_cache_miss, [], [], []) # type: ignore
def check_arg_avals_for_call(ref_avals, arg_avals,
                             jaxpr_debug_info: Optional[core.JaxprDebugInfo] = None):
  """Checks that call-time argument avals match the compiled ones.

  Args:
    ref_avals: avals the computation was compiled for.
    arg_avals: avals of the arguments supplied at call time.
    jaxpr_debug_info: optional debug info; when given, its `arg_names` are used
      to label mismatching arguments in the error message.

  Raises:
    TypeError: if the number of arguments differs, or if any pair of avals
      fails `core.typematch`. At most `num_errors` mismatches are listed.
  """
  if len(ref_avals) != len(arg_avals):
    raise TypeError(
        f"Computation compiled for {len(ref_avals)} inputs "
        f"but called with {len(arg_avals)}")

  arg_names = ([''] * len(ref_avals) if jaxpr_debug_info is None else
               jaxpr_debug_info.arg_names)

  num_errors = 5
  errors = [
      f"Compiled with {ref_aval} and called with {arg_aval} for arg {name}"
      for ref_aval, arg_aval, name in safe_zip(ref_avals, arg_avals, arg_names)
      if not core.typematch(ref_aval, arg_aval)]
  if not errors:
    return

  str_errors = '\n'.join(errors[:num_errors])
  # Use the "N out of M" wording only when the list is actually truncated.
  num_mismatch_str = (
      f'the {len(errors)} mismatches' if len(errors) <= num_errors else
      f"{num_errors} mismatches out of {len(errors)}")
  raise TypeError(
      "Computation was compiled for different input types and called with "
      f"different types. Here are {num_mismatch_str}:\n{str_errors}")
def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
  """Return (in_shardings, out_shardings, committed, devices) for jit(pmap).

  Every input and output gets the same replicated GSPMDSharding built from
  `local_devices`; `committed` is always True and the devices are returned as
  a tuple.
  """
  # Local devices are used because multihost jit(pmap) is not allowed.
  replicated = sharding_impls.GSPMDSharding.get_replicated(local_devices)
  # jit(pmap) produces Arrays with multi-device sharding, and leaving such
  # shardings uncommitted is unsupported, so the outputs are forced to be
  # committed.
  return (
      [replicated] * num_in_shardings,
      [replicated] * num_out_shardings,
      True,
      tuple(local_devices),
  )
def _out_shardings_for_trivial(
    jaxpr: core.Jaxpr, consts: Sequence[Any],
    in_shardings: Sequence[sharding_impls.XLACompatibleSharding],
    device_assignment: Sequence[xc.Device],
  ) -> List[sharding_impls.XLACompatibleSharding]:
  """Compute one sharding per output of a trivial jaxpr.

  For each jaxpr output:
    * if the output is a forwarded input, use the corresponding in_sharding;
    * if the output is a constant Array, use its `.sharding` attribute;
    * otherwise (a literal or a non-Array constant), use a replicated
      sharding.

  With a single device, every input sharding is replaced by a
  SingleDeviceSharding on that device and that sharding also serves as the
  fallback.
  """
  from jax._src import array

  if len(device_assignment) > 1:
    rep = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
    # Prefer the sharding an input was originally given, when it is recorded.
    in_shardings = tuple(
        i._original_sharding if hasattr(i, '_original_sharding') else i
        for i in in_shardings)
  else:
    # This branch must be exclusive with the one above: the unpacking below
    # only succeeds for exactly one device.
    dev, = device_assignment
    rep = sharding_impls.SingleDeviceSharding(dev)
    in_shardings = (sharding_impls.SingleDeviceSharding(dev),) * len(in_shardings)

  shardings: Dict[core.Var, sharding_impls.XLACompatibleSharding] = {}
  for constvar, constval in zip(jaxpr.constvars, consts):
    if isinstance(constval, array.ArrayImpl):
      shardings[constvar] = constval.sharding
  # `setdefault` keeps a constant's sharding if a variable appears in both.
  # An explicit loop makes the side effect visible and does not depend on
  # `map` being eager.
  for invar, in_sharding in safe_zip(jaxpr.invars, in_shardings):
    shardings.setdefault(invar, in_sharding)
  return [rep if isinstance(x, core.Literal) else shardings.get(x, rep)
          for x in jaxpr.outvars]
def _execute_trivial(jaxpr, consts, in_handler, out_handler, kept_var_idx, *args):
  """Evaluate a jaxpr whose outputs are only inputs, constants or literals.

  Args:
    jaxpr: object exposing `invars`, `constvars` and `outvars`.
    consts: values for `jaxpr.constvars`.
    in_handler: callable applied to the list of output values.
    out_handler: callable applied to the result of `in_handler`.
    kept_var_idx: container of positional indices of `args` that were kept.
    *args: all call arguments; those whose index is not in `kept_var_idx`
      are dropped before binding to `jaxpr.invars`.

  Returns:
    `out_handler(in_handler(outs))`.
  """
  env: Dict[core.Var, Any] = {}
  pruned_args = [x for i, x in enumerate(args) if i in kept_var_idx]
  # Explicit loops instead of `map(env.setdefault, ...)`: the binding is a
  # side effect and must happen eagerly whatever `map` is bound to.
  for var, val in safe_zip(jaxpr.invars, pruned_args):
    env.setdefault(var, val)
  for var, val in safe_zip(jaxpr.constvars, consts):
    env.setdefault(var, val)
  outs = [xla.canonicalize_dtype(v.val) if type(v) is core.Literal else env[v]
          for v in jaxpr.outvars]
  return out_handler(in_handler(outs))
def _compile_replicated_mesh_executable_from_hlo(
    computation, name, global_in_avals, global_out_avals, semantics_in_shardings,
    semantics_out_shardings, auto_spmd_lowering, compile_options,
    host_callbacks, has_unordered_effects, ordered_effects, kept_var_idx,
    backend, da, committed, pmap_nreps, jaxpr_debug_info):
  """Build a MeshExecutable by delegating to `backend.compile_replicated`.

  Requires `auto_spmd_lowering` to be falsy. The resulting MeshExecutable is
  created without an XLA executable; its call object is the value returned by
  `backend.compile_replicated`.

  NOTE(review): `has_unordered_effects` is accepted but not referenced in this
  body -- confirm that is intended.
  """
  assert not auto_spmd_lowering
  in_shardings = semantics_in_shardings.shardings
  out_shardings = semantics_out_shardings.shardings

  input_indices = _get_input_indices(global_in_avals, in_shardings, da)  # type: ignore
  if pmap_nreps > 1:
    # For a jit wrapping a pmap, replicate each input index to match the
    # devices of the replicated jit computation.
    input_indices = [idx * pmap_nreps for idx in input_indices]
  kept_var_idx = set(kept_var_idx)

  # The out_handler is computed by the backend with executable information.
  replicated_call = backend.compile_replicated(
      is_trivial=False,
      name=name,
      computation=computation,
      compile_options=compile_options,
      host_callbacks=host_callbacks,
      ordered_effects=ordered_effects,
      in_avals=global_in_avals,
      in_indices=input_indices,
      in_shardings=in_shardings,
      out_avals=global_out_avals,
      out_shardings=out_shardings,
      committed=committed,
      pmap_nreps=pmap_nreps)
  return MeshExecutable(None, lambda: replicated_call, global_in_avals,
                        in_shardings, out_shardings, auto_spmd_lowering,
                        kept_var_idx, jaxpr_debug_info, None)
def _compile_replicated_mesh_executable_from_trivial_jaxpr(
    jaxpr, consts, global_in_avals, global_out_avals, in_shardings, backend,
    da_object, committed, kept_var_idx, pmap_nreps):
  """Build a MeshExecutable for a trivial jaxpr via `backend.compile_replicated`.

  Output shardings come from `_out_shardings_for_trivial`, and the standard
  results handler from `global_avals_to_results_handler` is passed to the
  backend as `out_handler`. The MeshExecutable is created without an XLA
  executable and with auto-SPMD lowering set to False.
  """
  out_shardings = _out_shardings_for_trivial(
      jaxpr, consts, in_shardings, da_object.device_assignment)
  input_indices = _get_input_indices(global_in_avals, in_shardings, da_object)  # type: ignore
  handle_outs = global_avals_to_results_handler(
      global_out_avals, out_shardings, committed,
      [False] * len(global_out_avals))
  # Use the standard out_handler.
  unsafe_call = backend.compile_replicated(
      is_trivial=True, jaxpr=jaxpr, consts=consts,
      device_assignment=da_object.device_assignment, in_avals=global_in_avals,
      in_indices=input_indices, in_shardings=in_shardings,
      kept_var_idx=kept_var_idx, out_handler=handle_outs,
      out_shardings=out_shardings, pmap_nreps=pmap_nreps)
  # NOTE(review): the tail of this call was missing. The two trailing None
  # values mirror the nine-argument MeshExecutable(...) call in
  # `_compile_replicated_mesh_executable_from_hlo` (debug info, then None);
  # this function has no debug info to forward -- confirm.
  return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
                        in_shardings, out_shardings, False, kept_var_idx,
                        None, None)
def create_mesh_pspec_sharding(
    mesh: Mesh, pspec: Optional[PartitionSpec], parsed_pspec=None
) -> sharding_impls.NamedSharding:
  """Create a NamedSharding for `mesh` from `pspec`.

  When `pspec` is None an empty `PartitionSpec()` is used and any supplied
  `parsed_pspec` is discarded.
  """
  if pspec is None:
    return sharding_impls.NamedSharding(mesh, PartitionSpec(), None)
  return sharding_impls.NamedSharding(mesh, pspec, parsed_pspec)
def check_device_backend_on_shardings(shardings) -> bool:
  """Return True if any sharding's original sharding has `_device_backend` set.

  Entries for which `is_unspecified` or `is_auto` is true are skipped. For
  the rest, the result is True as soon as one has an `_original_sharding`
  attribute whose `_device_backend` attribute is truthy; otherwise False.
  """
  for i in shardings:
    if is_unspecified(i) or is_auto(i):
      # The body of this guard was missing; skipping such entries is the
      # only reading that leaves the check below reachable for real shardings.
      continue
    if hasattr(i, '_original_sharding') and getattr(
        i._original_sharding, '_device_backend', False):
      return True
  return False
# Compares the sharding of each argument against the corresponding entry of
# `in_xla_shardings`, collects one message per mismatch, and raises a single
# ValueError listing at most `num_errors` (5) of them.
#
# NOTE(review): this definition is incomplete in this copy of the file and
# does not parse as-is. Three spans are missing and are flagged inline below;
# they are left untouched here because the exact calls cannot be confirmed
# from the surrounding code.
def check_gda_or_array_xla_sharding_match(
args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
jaxpr_debug_info: Optional[core.JaxprDebugInfo]) -> None:
# Imported inside the function rather than at module level.
from jax._src.array import ArrayImpl
# One label per argument: empty strings when no debug info is supplied.
# NOTE(review): the operand after `else` is missing -- presumably
# `jaxpr_debug_info.arg_names)`; confirm.
arg_names = ([''] * len(args) if jaxpr_debug_info is None else
errors = []
num_errors = 5
for arg, xs, name in safe_zip(args, in_xla_shardings, arg_names):
# NOTE(review): the body of this guard is missing -- presumably non-Array
# arguments are skipped with `continue`; confirm.
if not isinstance(arg, ArrayImpl):
# No need to cache this check since MeshExecutable has a C++ fast path
# for AOT compiled call.
if (not check_device_backend_on_shardings([xs]) and
arg._committed and
# NOTE(review): the arguments to `are_op_shardings_equal`, the closing of
# this condition, and the `errors.append(` that the f-string below belongs
# to are missing; confirm against the upstream file.
not op_shardings.are_op_shardings_equal(
f"Got Array sharding: {arg.sharding} and input sharding: {xs} for "
f"arg {name} with shape: {arg.aval.str_short()}")
if errors:
# Only the first `num_errors` messages are shown; the header states how many
# mismatches there were in total.
str_errors = '\n'.join(errors[:num_errors])
num_mismatch_str = (
f'the {len(errors)} mismatches' if len(errors) < num_errors else
f"{num_errors} mismatches out of {len(errors)}")
raise ValueError(
"Array(s) sharding does not match the input(s) sharding. "
f"Here are {num_mismatch_str}:\n{str_errors}")
def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
  """Convert `pspec` to an array mapping.

  The spec is first passed through `sharding_impls.prepare_axis_resources`;
  the first element of its three-element result is handed to
  `_get_array_mapping`.
  """
  prepared = sharding_impls.prepare_axis_resources(
      pspec, "pspec to array_mapping")
  parsed, _unused_second, _unused_third = prepared
  return _get_array_mapping(parsed)
# Maps the name of a primitive that must not appear inside a mesh jaxpr to
# the user-facing name used in the error message.
_forbidden_primitives = {
  'xla_pmap': 'pmap',
}

def _sanitize_mesh_jaxpr(jaxpr):
  """Raise if `jaxpr`, or any jaxpr nested in its equation params, uses a forbidden primitive.

  Args:
    jaxpr: a jaxpr, or a `core.ClosedJaxpr` (which is unwrapped first).

  Raises:
    RuntimeError: if an equation's primitive name is a key of
      `_forbidden_primitives`.
  """
  if isinstance(jaxpr, core.ClosedJaxpr):
    jaxpr = jaxpr.jaxpr
  for eqn in jaxpr.eqns:
    # The table is keyed by primitive name strings.
    prim_name = eqn.primitive.name
    if prim_name in _forbidden_primitives:
      raise RuntimeError(f"Nesting {_forbidden_primitives[prim_name]} "
                         f"inside xmaps not supported!")
    # Recurse into any jaxprs carried in this equation's params.
    core.traverse_jaxpr_params(_sanitize_mesh_jaxpr, eqn.params)
# Registry of per-primitive resource typing rules. A rule is called as
# `rule(in_avals, params, source_info, resource_env, axis_resources)`.
custom_resource_typing_rules: Dict[core.Primitive, Callable] = {}

def resource_typecheck(jaxpr, resource_env, axis_resources, what_jaxpr_thunk):
  """Check that no aval in `jaxpr` maps two named axes to the same resource.

  Args:
    jaxpr: a jaxpr, or a `core.ClosedJaxpr` (which is unwrapped first).
    resource_env: passed through unchanged to custom typing rules.
    axis_resources: mapping from axis name to an iterable of resources; when
      falsy no conflicts are looked for.
    what_jaxpr_thunk: zero-argument callable returning a description of
      `jaxpr`, used only when building an error message.

  Raises:
    JAXTypeError: if two axes of one aval's `named_shape` share a resource.
  """
  if isinstance(jaxpr, core.ClosedJaxpr):
    jaxpr = jaxpr.jaxpr

  def _check_aval(aval, what_thunk):
    # Avals without a named shape have nothing to check.
    if not hasattr(aval, 'named_shape'):
      return
    resource_to_axis = {}
    for axis in aval.named_shape:
      if axis_resources:
        for resource in axis_resources[axis]:
          if resource in resource_to_axis:
            other_axis = resource_to_axis[resource]
            # Sort so the message does not depend on iteration order.
            axis, other_axis = sorted([str(axis), str(other_axis)])
            raise JAXTypeError(
                f"Axes `{axis}` and `{other_axis}` are both mapped to the "
                f"resource `{resource}`, but they coincide in the named_shape "
                f"of {what_thunk()}")
          resource_to_axis[resource] = axis

  what_thunk = lambda: (f"an input to {what_jaxpr_thunk()}")
  for v in jaxpr.constvars:
    _check_aval(v.aval, what_thunk)
  for v in jaxpr.invars:
    _check_aval(v.aval, what_thunk)

  # Both thunks read `eqn` lazily; they are only invoked while the loop below
  # is processing that equation.
  what_thunk = lambda: (f"a value returned from a primitive {eqn.primitive} created "
                        f"at {source_info_util.summarize(eqn.source_info)}")
  rec_what_jaxpr_thunk = lambda: (f"a primitive {eqn.primitive} created at"
                                  f" {source_info_util.summarize(eqn.source_info)}")
  for eqn in jaxpr.eqns:
    typing_rule = custom_resource_typing_rules.get(eqn.primitive, None)
    if typing_rule:
      typing_rule([v.aval for v in eqn.invars], eqn.params, eqn.source_info,
                  resource_env, axis_resources)
    else:
      # NOTE(review): this recursive branch was absent, leaving
      # `rec_what_jaxpr_thunk` unused. Restored so that jaxprs nested in the
      # equation's params are checked too -- confirm against upstream.
      core.traverse_jaxpr_params(
          partial(resource_typecheck,
                  resource_env=resource_env,
                  axis_resources=axis_resources,
                  what_jaxpr_thunk=rec_what_jaxpr_thunk),
          eqn.params)
    for v in eqn.outvars:
      _check_aval(v.aval, what_thunk)
def mesh_sharding_specs(axis_sizes, axis_names, allow_uneven_axes=False):
  """Return a function that builds a sharding spec for an aval on a mesh.

  Args:
    axis_sizes: mapping from mesh axis name to its size.
    axis_names: iterable of mesh axis names; their order defines each axis'
      position.
    allow_uneven_axes: when False, a dimension not divisible by the size of
      the mesh axis it is mapped to raises ValueError.

  Returns:
    `mk_sharding_spec(aval, aval_axes)`, where `aval_axes` maps a mesh axis
    name to a dimension index of `aval`.
  """
  mesh_axis_pos = {axis_name: pos for pos, axis_name in enumerate(axis_names)}

  # NOTE: This takes in the non-sharded avals!
  def mk_sharding_spec(aval, aval_axes):
    if aval is core.abstract_token:
      assert not aval_axes
      replicated = [Replicated(axis_size) for axis_size in axis_sizes.values()]
      return ShardingSpec([], replicated)

    dims = list(aval.shape)
    # NOTE: sorted is stable, which is important when multiple resources
    # map to the same axis.
    ordered_axes = sorted(aval_axes.items(), key=lambda item: item[1])
    for name, axis in ordered_axes:
      if not allow_uneven_axes and dims[axis] % axis_sizes[name] != 0:
        raise ValueError(
            f'The aval shape on dimension {axis} is {dims[axis]} and '
            f'the size of axis {name} is {axis_sizes[name]}. The aval shape % '
            'axis size should be zero but got '
            f'{dims[axis] % axis_sizes[name]}')
      # Shrink the dimension so a later mesh axis mapped to the same
      # dimension is checked against what remains.
      dims[axis] //= axis_sizes[name]
    return sharding_specs.make_sharding_spec(
        axis_sizes, mesh_axis_pos, len(aval.shape), aval_axes)

  return mk_sharding_spec
@contextmanager
def maybe_extend_axis_env(*args, **kwargs):
  """Context manager that runs its body inside `core.extend_axis_env`.

  All positional and keyword arguments are forwarded unchanged to
  `core.extend_axis_env`.
  """
  with core.extend_axis_env(*args, **kwargs):
    # The `with` had no body; yielding here is what makes the caller's block
    # execute with the extended axis environment.
    yield
def device_put(x, devices: Sequence[xc.Device],
               replicate: bool=False) -> List[xc.ArrayImpl]:
  """Call device_put on a sequence of devices and return a flat sequence of buffers.

  Args:
    x: when `replicate` is True, a single value placed on every device;
      otherwise a sequence with one value per device.
    devices: the target devices. (Previously annotated as a sequence of
      arrays; each element is passed to `jax.device_put` as the device.)
    replicate: whether to place the same `x` on every device.

  Returns:
    A list with one `jax.device_put` result per device, in device order.
  """
  if replicate:
    return [jax.device_put(x, device) for device in devices]
  return [jax.device_put(val, device) for val, device in safe_zip(x, devices)]
# TODO(phawkins): fix external users not to use these functions.
def _create_pmap_sharding_spec(aval, sharded_dim=0, sharded_dim_size=None):
  """Forward to `sharding_specs.create_pmap_sharding_spec` using `aval.shape`."""
  shape = aval.shape
  return sharding_specs.create_pmap_sharding_spec(
      shape, sharded_dim, sharded_dim_size)
def _pmap_sharding_spec(nrep, axis_size, npart, parts,
sharded_aval, map_axis: Optional[int]) -> ShardingSpec:
assert npart == 1, npart
assert parts is None, parts
return sharding_specs.pmap_sharding_spec(
nrep, axis_size, sharded_aval.shape, map_axis)