Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/experimental/shard_map.py

1265 lines
55 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import enum
from functools import partial
import inspect
import itertools as it
import math
import operator as op
from typing import (Any, Callable, Dict, Hashable, List, Optional, Sequence,
Set, Tuple, TypeVar, Union, cast)
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec, Mesh
from jax._src import core
from jax._src import dtypes
from jax._src import ad_util
from jax._src import callback
from jax._src import custom_derivatives
from jax._src import debugging
from jax._src import dispatch
from jax._src import linear_util as lu
from jax._src import ops
from jax._src import pjit
from jax._src import prng
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src import array
from jax._src.core import Tracer
from jax._src.api import _shared_code_pmap, _prepare_pmap
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
windowed_reductions, fft, linalg, control_flow)
from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3,
as_hashable_function, memoize, partition_list,
merge_lists, split_list)
from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
from jax.interpreters import ad
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
tree_structure, tree_leaves, keystr)
from jax._src.tree_util import (broadcast_prefix, prefix_errors, PyTreeDef,
generate_key_paths, KeyPath)
from jax.experimental.multihost_utils import (host_local_array_to_global_array,
global_array_to_host_local_array)
P = PartitionSpec
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
traceback_util.register_exclusion(__file__)
# API
Specs = Any # PyTree[PartitionSpec]
@traceback_util.api_boundary
def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
check_rep: bool = True):
return _shard_map(f, mesh, in_specs, out_specs, check_rep)
def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs,
out_specs: Union[Specs, Callable[[], Specs]],
check_rep: bool = True):
if not callable(f):
raise TypeError("shard_map requires a callable for its first argument, "
f"but got {f} of type {type(f)}.")
if not isinstance(mesh, Mesh):
raise TypeError("shard_map requires a `jax.sharding.Mesh` instance for its "
f"second argument, but got {mesh} of type {type(mesh)}.")
_check_specs(SpecErrorType.input, in_specs)
if not callable(out_specs):
_check_specs(SpecErrorType.out, out_specs)
@util.wraps(f)
@traceback_util.api_boundary
def wrapped(*args):
fun = lu.wrap_init(f)
args_flat, in_tree = tree_flatten(args)
try: in_specs_flat = broadcast_prefix(in_specs, args)
except ValueError:
e, *_ = prefix_errors(in_specs, args)
raise e('shard_map in_specs') from None
_check_specs_vs_args(f, mesh, in_tree, in_specs, in_specs_flat, args_flat)
in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat))
flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
@memoize
def out_names_thunk():
if callable(out_specs):
out_specs_ = out_specs()
_check_specs(SpecErrorType.out, out_specs_)
else:
out_specs_ = out_specs
dummy = tree_unflatten(out_tree(), [object()] * out_tree().num_leaves)
try: out_specs_flat = broadcast_prefix(out_specs_, dummy)
except ValueError:
e, *_ = prefix_errors(out_specs_, dummy)
raise e('shard_map out_specs') from None
return tuple(map(_canonicalize_spec, out_specs_flat))
try:
out_flat = shard_map_p.bind(
flat_fun, *args_flat, mesh=mesh, in_names=in_names_flat,
out_names_thunk=out_names_thunk, check_rep=check_rep)
except _SpecError as e:
fails, = e.args
if not callable(out_specs):
msg = _spec_rank_error(SpecErrorType.out, f, out_tree(), out_specs, fails)
if any(fail is not no_fail and not fail.shape for fail in fails):
msg += (" In particular, for rank 0 outputs which are not constant "
"over the mesh, add at least one (singleton) axis to them so "
"that they can be concatenated using out_specs.")
raise ValueError(msg) from None
except _RepError as e:
fails, = e.args
if not callable(out_specs):
msg = _rep_error(f, mesh, out_tree(), out_specs, fails)
raise ValueError(msg) from None
return tree_unflatten(out_tree(), out_flat)
return wrapped
# Internally use AxisNames = Dict[int, Tuple[AxisName, ...]], not PartitionSpecs
AxisName = Hashable
AxisNames = Dict[int, Tuple[AxisName, ...]] # TODO(mattjj): make it hashable
def _canonicalize_spec(spec: PartitionSpec) -> AxisNames:
if isinstance(spec, PartitionSpec):
return {i: names if isinstance(names, tuple) else (names,)
for i, names in enumerate(spec) if names is not None}
else:
return spec
# Error checking and messages
SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out'])
def _check_specs(error_type: SpecErrorType, specs: Any) -> None:
if error_type == SpecErrorType.input and specs is None:
raise TypeError(
"shard_map in_specs argument must be a pytree of "
"`jax.sharding.PartitionSpec` instances, but it was None.\n"
"Instead of `in_specs=None`, did you mean `in_specs=P()`, "
"where `P = jax.sharding.PartitionSpec`?")
if all(isinstance(p, PartitionSpec) for p in tree_leaves(specs)): return
prefix = 'in' if error_type == SpecErrorType.input else 'out'
msgs = [f" {prefix}_specs{keystr(key)} is {x} of type {type(x).__name__}, "
for key, x in generate_key_paths(specs) if not isinstance(x, P)]
raise TypeError(
f"shard_map {prefix}_specs argument must be a pytree of "
f"`jax.sharding.PartitionSpec` instances, but:\n\n"
+ '\n\n'.join(msgs) + '\n\n'
f"Check the {prefix}_specs values passed to shard_map.")
class NoFail: pass
no_fail = NoFail()
def _check_specs_vs_args(
f: Callable, mesh: Mesh, in_tree: PyTreeDef, in_specs: Specs,
in_specs_flat: List[P], xs: List) -> None:
in_avals = map(shaped_abstractify, xs)
fail = [a if not len(p) <= a.ndim else no_fail
for p, a in zip(in_specs_flat, in_avals)]
if any(f is not no_fail for f in fail):
msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail)
raise ValueError(msg)
in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat))
fail = [a if any(a.shape[d] % math.prod(mesh.shape[n] for n in ns)
for d, ns in names.items()) else no_fail
for a, names in zip(in_avals, in_names_flat)]
if any(f is not no_fail for f in fail):
msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail)
raise ValueError(msg)
def _spec_rank_error(
error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs,
fails: List[Union[core.ShapedArray, NoFail]]) -> str:
fun_name = getattr(f, '__name__', str(f))
if error_type == SpecErrorType.input:
prefix, base = 'in', 'args'
ba = _try_infer_args(f, tree)
else:
prefix, base = 'out', f'{fun_name}(*args)'
msgs = []
for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
if error_type == SpecErrorType.input and ba is not None:
arg_key, *_ = fail_key
extra = (f", where {base}[{arg_key}] is bound to {fun_name}'s "
f"parameter '{list(ba.arguments.keys())[arg_key.idx]}',")
else:
extra = ""
msgs.append(
f"* {prefix}_specs{keystr(spec_key)} is {spec} which has length "
f"{len(spec)}, but "
f"{base}{keystr(fail_key)}{extra} has shape {aval.str_short()}, "
f"which has rank {aval.ndim} (and {aval.ndim} < {len(spec)})")
assert msgs
msg = (f"shard_map applied to the function '{fun_name}' was given an "
f"{prefix}_specs entry which is too long to be compatible with the "
f"corresponding {prefix}put value from the function:\n\n"
+ '\n\n'.join(msgs) + '\n\n' +
f"Entries in {prefix}_specs must be of length no greater than the "
f"number of axes in the corresponding {prefix}put value.\n\n"
f"Either revise the spec to be shorter, or modify '{fun_name}' so "
f"that its {prefix}puts have sufficient rank.")
if any(not aval.ndim for _, (_, aval) in _iter_paths(tree, specs, fails)):
msg += (f"\n\nFor scalar values (rank 0), consider using an {prefix}_specs "
"entry of `P()`, where `P = jax.sharding.PartitionSpec`.")
return msg
def _spec_divisibility_error(
f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs,
fails: List[Union[core.ShapedArray, NoFail]]) -> str:
ba = _try_infer_args(f, tree)
fun_name = getattr(f, '__name__', str(f))
msgs = []
for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
if ba is not None:
arg_key, *_ = fail_key
extra = (f", where args[{arg_key}] is bound to {fun_name}'s "
f"parameter '{list(ba.arguments.keys())[arg_key.idx]}',")
names = _canonicalize_spec(spec)
for d, ns in names.items():
if aval.shape[d] % math.prod(mesh.shape[n] for n in ns):
axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'"
total = 'total ' if len(ns) > 1 else ''
sz = math.prod(mesh.shape[n] for n in ns)
msgs.append(
f"* args{keystr(fail_key)} of shape {aval.str_short()}{extra} "
f"corresponds to in_specs{keystr(spec_key)} of value {spec}, "
f"which maps array axis {d} (of size {aval.shape[d]}) to mesh "
f"{axis} (of {total}size {sz}), but {sz} does not evenly divide "
f"{aval.shape[d]}")
assert msgs
msg = (f"shard_map applied to the function '{fun_name}' was given argument "
f"arrays with axis sizes that are not evenly divisible by the "
f"corresponding mesh axis sizes:\n\n"
f"The mesh given has shape {mesh.device_ids.shape} with corresponding "
f"axis names {mesh.axis_names}.\n\n"
+ '\n\n'.join(msgs) + '\n\n' +
f"Array arguments' axis sizes must be evenly divisible by the mesh "
f"axis or axes indicated by the corresponding elements of the "
f"argument's in_specs entry. Consider checking that in_specs are "
f"correct, and if so consider changing the mesh axis sizes or else "
f"padding the input and adapting '{fun_name}' appropriately.")
return msg
def _rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs,
fails: List[Union[Set, NoFail]]) -> str:
fun_name = getattr(f, '__name__', str(f))
msgs = []
for (spec_key, spec), (fail_key, rep) in _iter_paths(tree, specs, fails):
dst = _canonicalize_spec(spec)
unmentioned = _unmentioned(mesh, dst)
if len(unmentioned) > 1:
need_rep = ','.join(map(str, unmentioned))
got_rep = ','.join(map(str, rep))
diff = ','.join(map(str, [n for n in unmentioned if n not in rep]))
msgs.append(
f"* out_specs{keystr(spec_key)} is {spec} which implies that the "
f"corresponding output value is replicated across mesh axes "
f"{{{need_rep}}}, but could only infer replication over {{{got_rep}}}, "
f"which is missing the required axes {diff}")
else:
need_rep_, = unmentioned
msgs.append(
f"* out_specs{keystr(spec_key)} is {spec} which implies that the "
f"corresponding output value is replicated across mesh axis "
f"'{need_rep_}', but could not infer replication over any axes")
assert msgs
msg = (f"shard_map applied to the function '{fun_name}' was given "
f"out_specs which require replication which can't be statically "
f"inferred given the mesh:\n\n"
f"The mesh given has shape {mesh.device_ids.shape} with corresponding "
f"axis names {mesh.axis_names}.\n\n"
+ '\n\n'.join(msgs) + '\n\n' +
"Check if these output values are meant to be replicated over those "
"mesh axes. If not, consider revising the corresponding out_specs "
"entries. If so, consider disabling the check by passing the "
"check_rep=False argument to shard_map.")
return msg
def _unmentioned(mesh: Mesh, names: AxisNames) -> List[AxisName]:
name_set = {n for ns in names.values() for n in ns}
return [n for n in mesh.axis_names if n not in name_set]
def _try_infer_args(f, tree):
dummy_args = tree_unflatten(tree, [False] * tree.num_leaves)
try:
return inspect.signature(f).bind(*dummy_args)
except (TypeError, ValueError):
return None
T = TypeVar('T')
def _iter_paths(tree: PyTreeDef, specs: Specs, fails: List[Union[T, NoFail]]
) -> List[Tuple[Tuple[KeyPath, P], Tuple[KeyPath, T]]]:
failures = tree_unflatten(tree, fails)
failures_aug = generate_key_paths(failures)
specs_ = tree_unflatten(tree_structure(specs), generate_key_paths(specs))
leaf = lambda x: type(x) is tuple and len(x) == 2 and type(x[1]) is P
specs_aug = broadcast_prefix(specs_, failures, is_leaf=leaf)
return [((spec_key, spec), (fail_key, fail_data))
for (spec_key, spec), (fail_key, fail_data)
in zip(specs_aug, failures_aug) if fail_data is not no_fail]
# Primitive
JaxType = Any
MaybeTracer = Union[JaxType, Tracer]
class ShardMapPrimitive(core.Primitive):
multiple_results = True
def bind(self, fun: lu.WrappedFun, *args: MaybeTracer, mesh: Mesh,
in_names: Tuple[AxisNames, ...],
out_names_thunk: Callable[[], Tuple[AxisNames, ...]],
check_rep: bool) -> Sequence[MaybeTracer]:
top_trace = core.find_top_trace(args)
fun, env_todo = process_env_traces(fun, top_trace.level, mesh,
in_names, out_names_thunk, check_rep)
@as_hashable_function(closure=out_names_thunk)
def new_out_names_thunk():
out_names = out_names_thunk()
_, xforms = env_todo()
for t in xforms:
out_names = t(out_names)
return out_names
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_shard_map( # pytype: disable=attribute-error
shard_map_p, fun, tracers, mesh=mesh, in_names=in_names,
out_names_thunk=new_out_names_thunk, check_rep=check_rep)
todos, _ = env_todo()
return map(core.full_lower, core.apply_todos(todos, outs))
def get_bind_params(self, params):
new_params = dict(params)
jaxpr = new_params.pop('jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr), jaxpr, ())
axes = new_params.pop('out_names')
new_params['out_names_thunk'] = HashableFunction(lambda: axes, closure=axes)
return [subfun], new_params
shard_map_p = ShardMapPrimitive('shard_map')
@lu.transformation_with_aux
def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep,
*args: Any):
outs = yield args, {}
todos, out_names_transforms = [], []
while True:
tracers = [x for x in outs if isinstance(x, core.Tracer)
and (level is None or x._trace.level > level)]
if tracers:
ans = max(tracers, key=op.attrgetter('_trace.level'))
else:
break
trace = ans._trace.main.with_cur_sublevel()
outs = map(trace.full_raise, outs)
outs, (todo, xform) = trace.post_process_shard_map(
outs, mesh, in_names, out_names_thunk, check_rep)
todos.append(todo)
out_names_transforms.append(xform)
yield outs, (tuple(todos), tuple(out_names_transforms))
# Staging
def _shard_map_staging(
trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun,
in_tracers: Sequence[pe.DynamicJaxprTracer], *, mesh: Mesh,
in_names: Tuple[AxisNames, ...],
out_names_thunk: Callable[[], Tuple[AxisNames, ...]],
check_rep: bool,
) -> Sequence[pe.DynamicJaxprTracer]:
in_avals = [t.aval for t in in_tracers]
in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
main = trace.main
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, out_avals_generic, consts = pe.trace_to_subjaxpr_dynamic(
f, main, in_avals_)
out_avals_ = map(_check_shapedarray, out_avals_generic)
_check_names(out_names_thunk(), out_avals_)
if check_rep:
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
out_rep = _output_rep(mesh, jaxpr, in_rep)
_check_reps(mesh, out_names_thunk(), out_rep)
out_avals = map(partial(_unshard_aval, mesh), out_names_thunk(), out_avals_)
source_info = source_info_util.current()
out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals]
invars = map(trace.getvar, in_tracers)
constvars = map(trace.getvar, map(trace.instantiate_const, consts))
outvars = map(trace.makevar, out_tracers)
in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr = pe.convert_constvars_jaxpr(jaxpr)
params = dict(mesh=mesh, in_names=in_names_staged,
out_names=tuple(out_names_thunk()), jaxpr=jaxpr,
check_rep=check_rep)
eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params,
jaxpr.effects, source_info)
trace.frame.add_eqn(eqn)
return out_tracers
pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging
def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray:
assert isinstance(aval, core.ShapedArray)
return aval
def _shard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
) -> core.AbstractValue:
if isinstance(aval, core.ShapedArray):
return aval.update(tuple(sz // math.prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape)))
else:
raise NotImplementedError # TODO(mattjj): add table with handlers
def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
) -> core.AbstractValue:
if isinstance(aval, core.ShapedArray):
return aval.update(tuple(sz * math.prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape)),
named_shape={k: v for k, v in aval.named_shape.items()
if k not in mesh.shape})
else:
raise NotImplementedError # TODO(mattjj): add table with handlers
# Type-checking
def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names,
check_rep):
for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names):
if not core.typecompat(v.aval, _shard_aval(mesh, in_name, x.aval)):
raise core.JaxprTypeError("shard_map argument avals not compatible with "
"jaxpr binder avals and in_names")
with core.extend_axis_env_nd(tuple(mesh.shape.items())):
core.check_jaxpr(jaxpr)
if check_rep:
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
out_rep = _output_rep(mesh, jaxpr, in_rep)
for rep, dst in zip(out_rep, out_names):
if not _valid_repeats(mesh, rep, dst):
raise core.JaxprTypeError("shard_map can't prove output is sufficiently "
"replicated")
out_avals_sharded = [x.aval for x in jaxpr.outvars]
out_avals = map(partial(_unshard_aval, mesh), out_names, out_avals_sharded)
return out_avals, jaxpr.effects
core.custom_typechecks[shard_map_p] = _shard_map_typecheck
def _in_names_to_rep(mesh: Mesh, names: AxisNames) -> Set[AxisName]:
return set(mesh.axis_names) - set(n for ns in names.values() for n in ns)
def _output_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[Set[AxisName]],
) -> Sequence[Set[AxisName]]:
env: Dict[core.Var, Set[AxisName]] = {}
def read(x: core.Atom) -> Set[AxisName]:
return env[x] if type(x) is core.Var else set(mesh.axis_names)
def write(v: core.Var, val: Set[AxisName]) -> None:
env[v] = val
map(write, jaxpr.constvars, [set(mesh.axis_names)] * len(jaxpr.constvars))
map(write, jaxpr.invars, in_rep)
for e in jaxpr.eqns:
rule = _rep_rules.get(e.primitive, partial(_rep_rule, e.primitive))
out_rep = rule(mesh, *map(read, e.invars), **e.params)
if e.primitive.multiple_results:
out_rep = [out_rep] * len(e.outvars) if type(out_rep) is set else out_rep
map(write, e.outvars, out_rep)
else:
write(e.outvars[0], out_rep)
return map(read, jaxpr.outvars)
def _valid_repeats(mesh: Mesh, rep: Set[AxisName], dst: AxisNames) -> bool:
return set(_unmentioned(mesh, dst)).issubset(rep)
# Lowering
def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
check_rep):
del check_rep
in_avals_ = [v.aval for v in jaxpr.invars]
out_avals_ = [x.aval for x in jaxpr.outvars]
in_nodes_ = map(partial(_xla_shard, ctx, mesh), in_names, ctx.avals_in,
in_avals_, in_nodes)
new_axis_context = sharding_impls.SPMDAxisContext(
mesh, frozenset(mesh.axis_names)
)
sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
with core.extend_axis_env_nd(tuple(mesh.shape.items())):
out_nodes_, _ = mlir._call_lowering(
"shmap_body", (), jaxpr, None, sub_ctx, in_avals_, out_avals_,
mlir.TokenSet(), *in_nodes_, dim_var_values=ctx.dim_var_values,
arg_names=map(_pspec_mhlo_attrs, in_names, in_avals_),
result_names=map(_pspec_mhlo_attrs, out_names, out_avals_))
return map(partial(_xla_unshard, ctx, mesh), out_names, out_avals_,
ctx.avals_out, out_nodes_)
mlir.register_lowering(shard_map_p, _shard_map_lowering)
def _xla_shard(ctx: mlir.LoweringRuleContext,
mesh, names, aval_in, aval_out, x):
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names), mesh)
axes = {name: i for i, ns in names.items() for name in ns}
shard_proto = NamedSharding(
mesh, sharding_impls.array_mapping_to_axis_resources(axes) # type: ignore
)._to_xla_hlo_sharding(aval_in.ndim)
if dtypes.is_opaque_dtype(aval_in.dtype):
shard_proto = aval_in.dtype._rules.physical_hlo_sharding(aval_in, shard_proto)
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto.to_proto(), # type: ignore
unspecified_dims=set())
return [mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, set())]
def _xla_unshard(ctx: mlir.LoweringRuleContext,
mesh, names, aval_in, aval_out, xs):
x, = xs
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names), mesh)
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=set())
axes = {name: i for i, ns in names.items() for name in ns}
shard_proto = NamedSharding(
mesh, sharding_impls.array_mapping_to_axis_resources(axes) # type: ignore
)._to_xla_hlo_sharding(aval_out.ndim)
if dtypes.is_opaque_dtype(aval_out.dtype):
shard_proto = aval_out.dtype._rules.physical_hlo_sharding(aval_out, shard_proto)
return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out,
shard_proto.to_proto(), set()) # type: ignore
def _pspec_mhlo_attrs(names: AxisNames, aval: core.AbstractValue) -> str:
if isinstance(aval, core.ShapedArray):
return str(map(names.get, range(aval.ndim)))
return ''
# Eager evaluation
def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk,
check_rep):
del prim
args = map(partial(_unmatch_spec, mesh), in_names, args)
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
with core.new_base_main(ShardMapTrace, mesh=mesh, check=check_rep) as main:
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items(), main):
t = main.with_cur_sublevel()
in_tracers = map(partial(ShardMapTracer, t), in_rep, args)
ans = fun.call_wrapped(*in_tracers)
out_tracers = map(t.full_raise, ans)
outs_, out_rep = unzip2((t.val, t.rep) for t in out_tracers)
del main, t, in_tracers, ans, out_tracers
out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs_]
_check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types
if check_rep: _check_reps(mesh, out_names_thunk(), out_rep)
return map(partial(_match_spec, mesh), out_rep, out_names_thunk(), outs_)
core.EvalTrace.process_shard_map = _shard_map_impl
def _names_to_pspec(names: AxisNames) -> PartitionSpec:
ndmin = max(names) + 1 if names else 0
return PartitionSpec(*(names.get(i) for i in range(ndmin)))
def _unmatch_spec(mesh: Mesh, src: AxisNames, x: JaxType) -> JaxType:
with core.eval_context():
return jax.jit(HashablePartial(_unmatch, mesh, tuple(src.items())))(x)
def _unmatch(mesh, src_tup, x):
src = _names_to_pspec(dict(src_tup))
dst = P(mesh.axis_names)
return shard_map(_add_singleton, mesh, (src,), dst)(x)
def _check_names(names: Sequence[AxisNames], avals: Sequence[core.ShapedArray]
) -> None:
fail = [a if n and not max(n) < a.ndim else no_fail
for n, a in zip(names, avals)]
if any(f is not no_fail for f in fail): raise _SpecError(fail)
class _SpecError(Exception): pass
def _check_reps(mesh, names, reps):
fail = [r if not _valid_repeats(mesh, r, n) else no_fail
for n, r in zip(names, reps)]
if any(f is not no_fail for f in fail): raise _RepError(fail)
class _RepError(Exception): pass
def _match_spec(mesh: Mesh, rep: Set[AxisName], dst: AxisNames, x: JaxType
) -> JaxType:
with core.eval_context():
return jax.jit(HashablePartial(_match, mesh, tuple(dst.items())))(x)
def _match(mesh, dst_tup, x):
src = P(mesh.axis_names)
dst = _names_to_pspec(dict(dst_tup))
return shard_map(_rem_singleton, mesh, (src,), dst, check_rep=False)(x)
def _rem_singleton(x): return x.reshape(x.shape[1:])
def _add_singleton(x): return x.reshape(1, *x.shape)
class ShardMapTrace(core.Trace):
mesh: Mesh
check: bool
def __init__(self, *args, mesh, check):
super().__init__(*args)
self.mesh = mesh
self.check = check
def pure(self, val):
val_ = _unmatch_spec(self.mesh, {}, val)
return ShardMapTracer(self, set(self.mesh.axis_names), val_)
def sublift(self, tracer):
return ShardMapTracer(self, tracer.rep, tracer.val)
def process_primitive(self, prim, tracers, params):
in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers)
eager_rule = eager_rules.get(prim)
if eager_rule:
out_vals = eager_rule(self.mesh, *in_vals, **params)
else:
f = HashablePartial(_prim_applier, prim, tuple(params.items()), self.mesh)
with core.eval_context(), jax.disable_jit(False):
out_vals = jax.jit(f)(*in_vals)
rep_rule = _rep_rules.get(prim, partial(_rep_rule, prim))
out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set()
if prim.multiple_results:
out_rep = [out_rep] * len(out_vals) if type(out_rep) is set else out_rep
return map(partial(ShardMapTracer, self), out_rep, out_vals)
return ShardMapTracer(self, out_rep, out_vals)
def process_call(self, call_primitive, fun, tracers, params):
raise NotImplementedError(
f"Eager evaluation of `{call_primitive}` inside a `shard_map` isn't "
"yet supported. Put a `jax.jit` around the `shard_map`-decorated "
"function, and open a feature request at "
"https://github.com/google/jax/issues !")
def process_map(self, map_primitive, fun, tracers, params):
raise NotImplementedError(
"Eager evaluation of `pmap` inside a `shard_map` isn't yet supported."
"Put a `jax.jit` around the `shard_map`-decorated function, and open "
"a feature request at https://github.com/google/jax/issues !")
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
raise NotImplementedError(
"Eager evaluation of a `custom_jvp` inside a `shard_map` isn't yet "
"supported. "
"Put a `jax.jit` around the `shard_map`-decorated function, and open "
"a feature request at https://github.com/google/jax/issues !")
def post_process_custom_jvp_call(self, out_tracers, _):
raise NotImplementedError(
"Eager evaluation of a `custom_jvp` inside a `shard_map` isn't yet "
"supported. "
"Put a `jax.jit` around the `shard_map`-decorated function, and open "
"a feature request at https://github.com/google/jax/issues !")
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
symbolic_zeros):
raise NotImplementedError(
"Eager evaluation of a `custom_vjp` inside a `shard_map` isn't yet "
"supported. "
"Put a `jax.jit` around the `shard_map`-decorated function, and open "
"a feature request at https://github.com/google/jax/issues !")
def post_process_custom_vjp_call(self, out_tracers, _):
raise NotImplementedError(
"Eager evaluation of a `custom_vjp` inside a `shard_map` isn't yet "
"supported. "
"Put a `jax.jit` around the `shard_map`-decorated function, and open "
"a feature request at https://github.com/google/jax/issues !")
def process_axis_index(self, frame):
raise NotImplementedError(
"Eager evaluation of an `axis_index` inside a `shard_map` isn't yet "
"supported. "
"Put a `jax.jit` around the `shard_map`-decorated function, and open "
"a feature request at https://github.com/google/jax/issues !")
class ShardMapTracer(core.Tracer):
rep: Set[AxisName]
val: JaxType
def __init__(self, trace, rep, val):
self._trace = trace
self.rep = rep
self.val = val
@property
def aval(self):
aval = core.get_aval(self.val)
if (isinstance(aval, core.ConcreteArray) and
self.rep == set(self._trace.mesh.axis_names)):
with core.eval_context():
return core.get_aval(self.val[0])
else:
aval = core.raise_to_shaped(aval)
return core.mapped_aval(self._trace.mesh.size, 0, aval)
def full_lower(self) -> ShardMapTracer:
return self
def __str__(self) -> str:
with core.eval_context():
blocks = list(self.val)
mesh = self._trace.mesh
axis_names = f"({', '.join(map(str, mesh.axis_names))},)"
return '\n'.join(
f"On {device} at mesh coordinates {axis_names} = {idx}:\n{block}\n"
for (idx, device), block in zip(np.ndenumerate(mesh.devices), blocks))
def _prim_applier(prim, params_tup, mesh, *args):
def apply(*args):
outs = prim.bind(*map(_rem_singleton, args), **dict(params_tup))
return tree_map(_add_singleton, outs)
spec = P(mesh.axis_names)
return shard_map(apply, mesh, spec, spec, False)(*args)
eager_rules: Dict[core.Primitive, Callable] = {}
# TODO(mattjj): working around an apparent XLA or PjRt bug, remove eventually
def _debug_callback_eager_rule(mesh, *args, callback: Callable[..., Any],
effect: debugging.DebugEffect):
del effect
with core.eval_context():
all_blocks = zip(*map(list, args))
for (idx, device), blocks in zip(np.ndenumerate(mesh.devices), all_blocks):
callback(*blocks)
return []
eager_rules[debugging.debug_callback_p] = _debug_callback_eager_rule
def _device_put_eager_rule(mesh, x, *, src, device):
del mesh, src
if device is None:
return x
else:
raise ValueError("device_put with explicit device not allowed within "
f"shard_map-decorated functions, but got device {device}")
eager_rules[dispatch.device_put_p] = _device_put_eager_rule
# Static replication checking
def _rep_rule(prim: core.Primitive, mesh: Mesh, *in_rep: Set[AxisName],
**params: Any) -> Union[Set[AxisName], List[Set[AxisName]]]:
raise NotImplementedError(
f"No replication rule for {prim}. As a workaround, pass the "
"`check_rep=False` argument to `shard_map`. To get this fixed, open an "
"issue at https://github.com/google/jax/issues")
_rep_rules: Dict[core.Primitive, Callable] = {}
register_rule = lambda prim: lambda rule: _rep_rules.setdefault(prim, rule)
register_standard = lambda prim: _rep_rules.setdefault(prim, _standard_rep_rule)
def _standard_rep_rule(_, *in_rep, **__):
return set.intersection(*in_rep) if in_rep else set()
for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(),
windowed_reductions.__dict__.values(), fft.__dict__.values(),
linalg.__dict__.values(), ops.__dict__.values(),
ad_util.__dict__.values(), prng.__dict__.values(),
custom_derivatives.__dict__.values()):
if isinstance(o, core.Primitive): register_standard(o)
register_standard(lax_parallel.ppermute_p) # doesn't change replication
@register_rule(lax_parallel.psum_p)
def _psum_rule(_, *in_rep, axes, axis_index_groups):
if axis_index_groups is not None: raise NotImplementedError
axes = (axes,) if not isinstance(axes, tuple) else axes
return [r | set(axes) for r in in_rep] # introduces replication
@register_rule(lax_parallel.all_gather_p)
def _all_gather_rule(_, in_rep, *, all_gather_dimension, axis_name, axis_size,
axis_index_groups, tiled):
if axis_index_groups is not None: raise NotImplementedError
if not tiled: raise NotImplementedError
axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
return in_rep | set(axis_name) # introduces replication
@register_rule(lax_parallel.reduce_scatter_p)
def _reduce_scatter_rule(_, in_rep, *, scatter_dimension, axis_name, axis_size,
axis_index_groups, tiled):
if axis_index_groups is not None: raise NotImplementedError
if not tiled: raise NotImplementedError
return in_rep - {axis_name} # removes replication
@register_rule(lax_parallel.all_to_all_p)
def _all_to_all_rule(_, in_rep, *, split_axis, concat_axis, axis_name,
axis_index_groups):
if axis_index_groups is not None: raise NotImplementedError
return in_rep - {axis_name} # removes replication
@register_rule(lax_parallel.axis_index_p)
def _axis_index_rule(mesh, *, axis_name):
axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
return set(mesh.shape) - set(axis_name)
@register_rule(pjit.pjit_p)
def _pjit_rule(mesh, *in_rep, jaxpr, **kwargs):
return _output_rep(mesh, jaxpr.jaxpr, in_rep)
@register_rule(debugging.debug_callback_p)
def _debug_callback_rule(mesh, *in_rep, **_):
return []
@register_rule(callback.pure_callback_p)
def _pure_callback_rule(mesh, *in_rep, result_avals, **_):
return [set()] * len(result_avals)
@register_rule(dispatch.device_put_p)
def _device_put_rep_rule(mesh, x, *, src, device):
return x
@register_rule(control_flow.loops.scan_p)
def _scan_rule(mesh, *in_rep, jaxpr, num_consts, num_carry, linear, length,
reverse, unroll):
const_rep, carry_rep, xs_rep = split_list(in_rep, [num_consts, num_carry])
for _ in range(1 + num_carry):
out_rep = _output_rep(mesh, jaxpr.jaxpr, [*const_rep, *carry_rep, *xs_rep])
if carry_rep == out_rep[:num_carry]:
break
else:
carry_rep = map(op.and_, carry_rep, out_rep[:num_carry])
else:
assert False, 'Fixpoint not reached'
return out_rep
# Batching
def _shard_map_batch(
trace: batching.BatchTrace, prim: core.Primitive, fun: lu.WrappedFun,
in_tracers: Sequence[batching.BatchTracer], mesh: Mesh,
in_names: Tuple[AxisNames, ...],
out_names_thunk: Callable[[], Tuple[AxisNames, ...]],
check_rep: bool) -> Sequence[batching.BatchTracer]:
in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in in_tracers)
if all(bdim is batching.not_mapped for bdim in in_dims):
return prim.bind(fun, *in_vals, mesh=mesh, in_names=in_names,
out_names_thunk=out_names_thunk, check_rep=check_rep)
if any(isinstance(d, batching.RaggedAxis) for d in in_dims):
raise NotImplementedError
fun, out_dims = batching.batch_subtrace(fun, trace.main, tuple(in_dims))
new_in_names = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] # type: ignore
for ax in names} for names, d in zip(in_names, in_dims)]
spmd_axis_name = trace.spmd_axis_name
if spmd_axis_name is not None:
new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped # type: ignore
else ns for ns, d in zip(new_in_names, in_dims)]
@as_hashable_function(closure=out_names_thunk)
def new_out_names_thunk():
return _batch_out_names(spmd_axis_name, out_dims(), out_names_thunk())
new_params = dict(mesh=mesh, in_names=new_in_names,
out_names_thunk=new_out_names_thunk, check_rep=check_rep)
out_vals = prim.bind(fun, *in_vals, **new_params)
make_tracer = partial(batching.BatchTracer, trace,
source_info=source_info_util.current())
return map(make_tracer, out_vals, out_dims())
batching.BatchTrace.process_shard_map = _shard_map_batch
def _shard_map_batch_post_process(trace, out_tracers, mesh, in_names,
out_names_thunk, check_rep):
del mesh, in_names, out_names_thunk, check_rep
vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
for t in out_tracers)
m = trace.main
def todo(vals):
trace = m.with_cur_sublevel()
return map(partial(batching.BatchTracer, trace), vals, dims, srcs)
out_names_transform = partial(_batch_out_names, trace.spmd_axis_name, dims)
return vals, (todo, out_names_transform)
batching.BatchTrace.post_process_shard_map = _shard_map_batch_post_process
def _batch_out_names(spmd_axis_name, dims, out_names):
out_names_ = [{ax + (d is not batching.not_mapped and d <= ax): names[ax]
for ax in names} for names, d in zip(out_names, dims)]
if spmd_axis_name is not None:
out_names_ = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped
else ns for ns, d in zip(out_names_, dims)]
return out_names_
# Autodiff
def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names,
out_names_thunk, check_rep):
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
which_nz = [ type(t) is not ad.Zero for t in tangents]
tangents = [t if type(t) is not ad.Zero else None for t in tangents]
args, in_tree = tree_flatten((primals, tangents))
f_jvp = ad.jvp_subtrace(f, trace.main)
f_jvp, which_nz_out = ad.nonzero_tangent_outputs(f_jvp)
tangent_in_names = [ax for ax, nz in zip(in_names, which_nz) if nz]
@as_hashable_function(closure=out_names_thunk)
def new_out_names_thunk():
out_ax = out_names_thunk()
return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz))
params = dict(mesh=mesh, in_names=(*in_names, *tangent_in_names),
out_names_thunk=new_out_names_thunk, check_rep=check_rep)
f_jvp, out_tree = ad.traceable(f_jvp, in_tree)
result = shard_map_p.bind(f_jvp, *args, **params)
primal_out, tangent_out = tree_unflatten(out_tree(), result)
tangent_out = [ad.Zero(core.get_aval(p).at_least_vspace()) if t is None else t
for p, t in zip(primal_out, tangent_out)]
return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)]
ad.JVPTrace.process_shard_map = _shard_map_jvp
def _shard_map_jvp_post_process(trace, out_tracers, mesh, in_names,
out_names_thunk, check_rep):
del mesh, in_names, out_names_thunk, check_rep
primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
out, treedef = tree_flatten((primals, tangents))
tangents_nz = [type(t) is not ad.Zero for t in tangents]
m = trace.main
def todo(x):
primals, tangents = tree_unflatten(treedef, x)
return map(partial(ad.JVPTracer, m.with_cur_sublevel()), primals, tangents)
def out_names_transform(out_names):
return (*out_names, *(n for n, nz in zip(out_names, tangents_nz) if nz))
return out, (todo, out_names_transform)
ad.JVPTrace.post_process_shard_map = _shard_map_jvp_post_process
def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
out_names_thunk, check_rep):
in_pvals = [t.pval for t in tracers]
in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)
unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names)
in_avals_sharded = map(partial(_shard_aval, mesh), unk_in_names, in_avals)
f = pe.trace_to_subjaxpr_nounits(f, trace.main, False)
f = _promote_scalar_residuals(f)
f_known, aux = pe.partial_eval_wrapper_nounits(
f, (*in_knowns,), (*in_avals_sharded,))
@as_hashable_function(closure=out_names_thunk)
def known_out_names():
out_knowns, _, jaxpr, _ = aux()
_, out_known_names = pe.partition_list(out_knowns, out_names_thunk())
assert not any(not v.aval.shape for v in jaxpr.constvars)
res_names = ({0: (*mesh.axis_names,)},) * len(jaxpr.constvars)
return (*out_known_names, *res_names)
known_params = dict(mesh=mesh, in_names=(*known_in_names,),
out_names_thunk=known_out_names, check_rep=check_rep)
out = shard_map_p.bind(f_known, *in_consts, **known_params)
out_knowns, out_avals_sharded, jaxpr, env = aux()
out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr = pe.convert_constvars_jaxpr(jaxpr)
unk_out_names, _ = pe.partition_list(out_knowns, out_names_thunk())
unk_in_names = (({0: (*mesh.axis_names,)},) * len(res) + ({},) * len(env)
+ (*unk_in_names,))
const_tracers = map(trace.new_instantiated_const, res)
env_tracers = map(trace.full_raise, env)
unk_arg_tracers = [t for t in tracers if not t.is_known()]
unk_params = dict(mesh=mesh, in_names=unk_in_names,
out_names=unk_out_names, jaxpr=jaxpr, check_rep=False)
out_avals = map(partial(_unshard_aval, mesh), unk_out_names, out_avals_sharded)
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
for a in out_avals]
eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers, *unk_arg_tracers), # type: ignore[arg-type]
out_tracers, shard_map_p, unk_params,
jaxpr.effects, source_info_util.current())
for t in out_tracers: t.recipe = eqn
return pe.merge_lists(out_knowns, out_tracers, out_consts)
pe.JaxprTrace.process_shard_map = _shard_map_partial_eval
def _shard_map_partial_eval_post_process(
trace, tracers, mesh, in_names, out_names_thunk, check_rep):
del check_rep
unk_tracers = [t for t in tracers if not t.is_known()]
jaxpr, res, env = pe.tracers_to_jaxpr([], unk_tracers)
jaxpr, res = _promote_scalar_residuals_jaxpr(jaxpr, res)
out_knowns, out_avals_, consts = pe.partition_pvals([t.pval for t in tracers])
out = [*consts, *res]
main = trace.main
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr_ = pe.convert_constvars_jaxpr(jaxpr)
def todo(out):
trace = main.with_cur_sublevel()
out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])
const_tracers = map(trace.new_instantiated_const, res)
env_tracers = map(trace.full_raise, env)
staged_in_names = ({0: (*mesh.axis_names,)},) * len(res) + ({},) * len(env)
staged_params = dict(jaxpr=jaxpr_, mesh=mesh, in_names=staged_in_names,
out_names=(*out_names_unknown,), check_rep=False)
out_avals = map(partial(_unshard_aval, mesh), out_names_unknown, out_avals_)
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
for a in out_avals]
name_stack = trace._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers), out_tracers,
shard_map_p, staged_params, jaxpr.effects, source)
for t in out_tracers: t.recipe = eqn
return merge_lists(out_knowns, out_tracers, out_consts)
def out_names_transform(out_names):
nonlocal out_names_unknown
out_names_unknown, out_names_known = partition_list(out_knowns, out_names)
return (*out_names_known,) + ({0: (*mesh.axis_names,)},) * len(jaxpr.constvars)
out_names_unknown: Optional[list] = None
return out, (todo, out_names_transform)
pe.JaxprTrace.post_process_shard_map = _shard_map_partial_eval_post_process
@lu.transformation
def _promote_scalar_residuals(*args, **kwargs):
jaxpr, (out_pvals, out_consts, env) = yield args, kwargs
jaxpr, out_consts = _promote_scalar_residuals_jaxpr(jaxpr, out_consts)
yield jaxpr, (out_pvals, out_consts, env)
def _promote_scalar_residuals_jaxpr(jaxpr, res):
which = [isinstance(v.aval, core.ShapedArray) and not v.aval.shape
for v in jaxpr.constvars]
res_ = [jax.lax.broadcast(x, (1,)) if s else x for x, s in zip(res, which)]
@lu.wrap_init
def fun(*args):
res = [_rem_singleton(x) if s else x for x, s in zip(res_, which)]
return core.eval_jaxpr(jaxpr, res, *args)
jaxpr, _, res = pe.trace_to_jaxpr_dynamic(fun, [v.aval for v in jaxpr.invars])
return jaxpr, res
def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
check_rep):
mb_div = lambda x, y: x / y if y != 1 else x
out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
else mb_div(x, math.prod(map(mesh.shape.get, _unmentioned(mesh, ns))))
for ns, x in zip(out_names, out_cts)]
args = [x if type(x) is not ad.UndefinedPrimal else
ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval))
for ns, x in zip(in_names, args)]
all_args, in_tree = tree_flatten((out_cts, args))
@lu.wrap_init
def fun_trans(out_cts, args):
res, undefs = partition_list(map(ad.is_undefined_primal, args), args)
jaxpr_known, jaxpr_unknown, _, _ = pe.partial_eval_jaxpr_nounits(
pe.close_jaxpr(jaxpr), map(ad.is_undefined_primal, args), False)
res_reshaped = core.jaxpr_as_fun(jaxpr_known)(*res)
out = ad.backward_pass(
jaxpr_unknown.jaxpr, (), False, (), (*res_reshaped, *undefs), out_cts
)
return [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
else jax.lax.psum(x, tuple(_unmentioned(mesh, ns)))
for ns, x in zip(in_names, out)]
fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans)
fun_trans_flat, out_tree = flatten_fun_nokwargs(fun_trans, in_tree)
new_in_names = \
[n for n, x in zip(out_names, out_cts) if type(x) is not ad.Zero] + \
[n for n, x in zip(in_names, args) if type(x) is not ad.UndefinedPrimal]
def new_out_names_thunk():
return tuple(names for names, nz in zip(in_names, nz_arg_cts()) if nz)
out_flat = shard_map_p.bind(
fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names),
out_names_thunk=new_out_names_thunk, check_rep=check_rep)
return tree_unflatten(out_tree(), out_flat)
ad.primitive_transposes[shard_map_p] = _shard_map_transpose
def _shard_map_axis_subst(params, subst, traverse):
if 'jaxpr' not in params:
return params
if not traverse:
return params
def shadowed_subst(name):
return (name,) if name in params['mesh'].shape else subst(name)
with core.extend_axis_env_nd(params['mesh'].shape.items()):
new_jaxpr = core.subst_axis_names_jaxpr(params['jaxpr'], shadowed_subst)
return dict(params, jaxpr=new_jaxpr)
core.axis_substitution_rules[shard_map_p] = _shard_map_axis_subst
# Remat
def _partial_eval_jaxpr_custom_rule(
saveable: Callable[..., bool], unks_in: Sequence[bool],
inst_in: Sequence[bool], eqn: core.JaxprEqn
) -> Tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool],
List[core.Var]]:
jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh']
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable)
jaxpr_known, jaxpr_staged = _add_reshapes(num_res, jaxpr_known, jaxpr_staged)
ins_known, _ = partition_list(unks_in, eqn.invars)
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
_, ins_staged = partition_list(inst_in, eqn.invars)
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
newvar = core.gensym([jaxpr_known, jaxpr_staged])
params_known, params_staged = _pe_custom_params(
unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res,
dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged))
residuals = [newvar(_unshard_aval(mesh, {0: (*mesh.axis_names,)}, var.aval))
for var in jaxpr_staged.invars[:num_res]]
eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
eqn.primitive, params_known, jaxpr_known.effects,
eqn.source_info)
eqn_staged = pe.new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged,
eqn.primitive, params_staged,
jaxpr_staged.effects, eqn.source_info)
assert len(eqn_staged.invars) == len(jaxpr_staged.invars)
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
if type(x) is core.Var and not inst]
return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \
_partial_eval_jaxpr_custom_rule
def _add_reshapes(num_res, jaxpr_known, jaxpr_staged):
if not num_res: return jaxpr_known, jaxpr_staged
assert not jaxpr_known.constvars and not jaxpr_staged.constvars
@lu.wrap_init
def known(*args):
out = core.eval_jaxpr(jaxpr_known, (), *args)
out_known, res = split_list(out, [len(out) - num_res])
return [*out_known, *map(_add_singleton, res)]
avals_in = [v.aval for v in jaxpr_known.invars]
jaxpr_known, _, () = pe.trace_to_jaxpr_dynamic(known, avals_in)
@lu.wrap_init
def staged(*args):
res_, ins = split_list(args, [num_res])
res = map(_rem_singleton, res_)
return core.eval_jaxpr(jaxpr_staged, (), *res, *ins)
res_avals = [v.aval for v in jaxpr_known.outvars[-num_res:]]
avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[num_res:]]]
jaxpr_staged, _, () = pe.trace_to_jaxpr_dynamic(staged, avals_in)
return jaxpr_known, jaxpr_staged
def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
num_res, params_known, params_staged):
# prune inputs to jaxpr_known according to unks_in
mesh = params_known['mesh']
in_names_known, _ = partition_list(unks_in, params_known['in_names'])
_, out_names_known = partition_list(kept_outs_known, params_known['out_names'])
out_names_known = out_names_known + [{0: (*mesh.axis_names,)}] * num_res
new_params_known = dict(params_known, in_names=tuple(in_names_known),
out_names=tuple(out_names_known))
# added num_res new inputs to jaxpr_staged, pruning according to inst_in
_, in_names_staged = partition_list(inst_in, params_staged['in_names'])
in_names_staged = [{0: (*mesh.axis_names,)}] * num_res + in_names_staged
_, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names'])
new_params_staged = dict(params_staged, in_names=tuple(in_names_staged),
out_names=tuple(out_names_staged), check_rep=False)
return new_params_known, new_params_staged
# DCE
# TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule?
def _shard_map_dce(used_outputs: List[bool], eqn: core.JaxprEqn
) -> Tuple[List[bool], Optional[core.JaxprEqn]]:
with core.extend_axis_env_nd(eqn.params['mesh'].shape.items()):
jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
if not any(used_inputs) and not any(used_outputs) and not jaxpr.effects:
return used_inputs, None
else:
_, in_names = partition_list(used_inputs, eqn.params['in_names'])
_, out_names = partition_list(used_outputs, eqn.params['out_names'])
new_params = dict(eqn.params, jaxpr=jaxpr, in_names=tuple(in_names),
out_names=tuple(out_names))
new_eqn = pe.new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, used_inputs) if used],
[x for x, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, jaxpr.effects, eqn.source_info)
return used_inputs, new_eqn
pe.dce_rules[shard_map_p] = _shard_map_dce
# Implementing pmap in terms of shard_map
def pmap(f, axis_name=None, *, in_axes=0, out_axes=0,
static_broadcasted_argnums=(), devices=None, backend=None,
axis_size=None, donate_argnums=(), global_arg_shapes=None):
devices = tuple(devices) if devices is not None else devices
axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
f, axis_name, static_broadcasted_argnums, donate_argnums, in_axes, out_axes)
def infer_params(*args, **kwargs):
p = _prepare_pmap(f, in_axes, out_axes, static_broadcasted_tuple,
donate_tuple, devices, backend, axis_size, args, kwargs)
for arg in p.flat_args:
dispatch.check_arg(arg)
mesh = Mesh(_get_devices(p, backend), (axis_name,))
_pmapped, in_specs, out_specs = _cached_shard_map(
p.flat_fun, mesh, p.in_axes_flat, p.out_axes_thunk, axis_name)
flat_global_args = host_local_array_to_global_array(
p.flat_args, mesh, list(in_specs))
jitted_f = jax.jit(
_pmapped,
donate_argnums=(i for i, val in enumerate(p.donated_invars) if val))
return jitted_f, flat_global_args, p.out_tree, mesh, out_specs
def wrapped(*args, **kwargs):
(jitted_f, flat_global_args, out_tree, mesh,
out_specs) = infer_params(*args, **kwargs)
with jax.spmd_mode('allow_all'):
outs = jitted_f(*flat_global_args)
outs = global_array_to_host_local_array(outs, mesh, out_specs())
return tree_unflatten(out_tree(), outs)
def lower(*args, **kwargs):
jitted_f, _, _, _, _ = infer_params(*args, **kwargs)
with jax.spmd_mode('allow_all'):
return jitted_f.lower(*args, **kwargs)
wrapped.lower = lower
return wrapped
@lu.cache
def _cached_shard_map(flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name):
in_specs = tuple(map(partial(_axis_to_spec, axis_name), in_axes_flat))
out_specs = lambda: map(partial(_axis_to_spec, axis_name), out_axes_thunk())
fun = _handle_reshapes(flat_fun, in_axes_flat, out_axes_thunk)
return (_shard_map(fun.call_wrapped, mesh, in_specs, out_specs, check_rep=False),
in_specs, out_specs)
@lu.transformation
def _handle_reshapes(in_axes, out_axes_thunk, *args, **kwargs):
args = tree_map(lambda x, ax: x if ax is None else jnp.squeeze(x, axis=ax),
list(args), list(in_axes))
out = yield args, {}
yield tree_map(lambda x, ax: x if ax is None else jnp.expand_dims(x, axis=ax),
list(out), list(out_axes_thunk()))
def _axis_to_spec(axis_name, ax):
if isinstance(ax, int):
specs = [None] * ax + [axis_name]
return P(*specs)
elif ax is None:
return P()
else:
raise TypeError(ax)
def _get_devices(p, backend):
if backend is not None and p.devices is None:
devs = jax.devices(backend=backend)
else:
devs = jax.devices() if p.devices is None else p.devices
if jax.process_count() > 1:
return devs[:p.global_axis_size]
return devs[:p.local_axis_size]