Source code for jax._src.lax.parallel

# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Parallelization primitives.
"""

import collections
import warnings

import numpy as np

from jax import core
from jax import dtypes
from jax import tree_util
from jax._src import source_info_util
from . import lax
from jax.core import ShapedArray, raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.util import partial, unzip2, prod
from jax.lib import xla_client as xc
from jax.lib import xla_bridge as xb
from jax.config import config
from jax._src.numpy import lax_numpy

xops = xc.ops


### parallel traceables

def psum(x, axis_name, *, axis_index_groups=None):
  """Compute an all-reduce sum on ``x`` over the pmapped axis ``axis_name``.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  Inputs of boolean dtype are converted to integers before the reduction.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would perform psums over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once, and all groups must be the same size.

  Returns:
    Array(s) with the same shape as ``x`` representing the result of an
    all-reduce sum along the axis ``axis_name``.

  For example, with 4 XLA devices available:

  >>> x = np.arange(4)
  >>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [6 6 6 6]
  >>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [0.         0.16666667 0.33333334 0.5       ]
  """
  if not isinstance(axis_name, (tuple, list)):
    axis_name = (axis_name,)
  _validate_axis_index_groups(axis_index_groups)
  leaves, treedef = tree_util.tree_flatten(x)
  leaves = [lax.convert_element_type(l, np.int32)
            if dtypes.dtype(l) == np.bool_ else l for l in leaves]
  out_flat = psum_p.bind(*leaves, axis_name=axis_name,
                         axis_index_groups=axis_index_groups)
  return tree_util.tree_unflatten(treedef, out_flat)
def pmean(x, axis_name, *, axis_index_groups=None):
  """Compute an all-reduce mean on ``x`` over the pmapped axis ``axis_name``.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would perform pmeans over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once, and all groups must be the same size.

  Returns:
    Array(s) with the same shape as ``x`` representing the result of an
    all-reduce mean along the axis ``axis_name``.

  For example, with 4 XLA devices available:

  >>> x = np.arange(4)
  >>> y = jax.pmap(lambda x: jax.lax.pmean(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [1.5 1.5 1.5 1.5]
  >>> y = jax.pmap(lambda x: x / jax.lax.pmean(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [0.         0.66666667 1.33333334 2.        ]
  """
  x = psum(x, axis_name=axis_name, axis_index_groups=axis_index_groups)
  n = psum(1, axis_name=axis_name, axis_index_groups=axis_index_groups)
  return tree_util.tree_map(lambda v: v / n, x)
def pmax(x, axis_name, *, axis_index_groups=None):
  """Compute an all-reduce max on ``x`` over the pmapped axis ``axis_name``.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would perform pmaxes over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once, and all groups must be the same size.

  Returns:
    Array(s) with the same shape as ``x`` representing the result of an
    all-reduce max along the axis ``axis_name``.
  """
  if not isinstance(axis_name, (tuple, list)):
    axis_name = (axis_name,)
  _validate_axis_index_groups(axis_index_groups)
  leaves, treedef = tree_util.tree_flatten(x)
  out_flat = pmax_p.bind(*leaves, axis_name=axis_name,
                         axis_index_groups=axis_index_groups)
  return tree_util.tree_unflatten(treedef, out_flat)
def pmin(x, axis_name, *, axis_index_groups=None):
  """Compute an all-reduce min on ``x`` over the pmapped axis ``axis_name``.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would perform pmins over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once, and all groups must be the same size.

  Returns:
    Array(s) with the same shape as ``x`` representing the result of an
    all-reduce min along the axis ``axis_name``.
  """
  if not isinstance(axis_name, (tuple, list)):
    axis_name = (axis_name,)
  _validate_axis_index_groups(axis_index_groups)
  leaves, treedef = tree_util.tree_flatten(x)
  out_flat = pmin_p.bind(*leaves, axis_name=axis_name,
                         axis_index_groups=axis_index_groups)
  return tree_util.tree_unflatten(treedef, out_flat)
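# Illustrative usage sketch (added for exposition; not part of the original
# module). pmax/pmin follow the same calling convention as psum but reduce
# with max/min. Assuming 4 XLA devices are available:
#
#   >>> x = np.arange(4)
#   >>> print(jax.pmap(lambda v: jax.lax.pmax(v, 'i'), axis_name='i')(x))
#   [3 3 3 3]
#   >>> print(jax.pmap(lambda v: jax.lax.pmin(v, 'i'), axis_name='i')(x))
#   [0 0 0 0]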
def _validate_axis_index_groups(axis_index_groups):
  if axis_index_groups is None:
    return
  len_0 = len(axis_index_groups[0])
  if any(len(g) != len_0 for g in axis_index_groups):
    raise ValueError("axis_index_groups must all be the same size")
  axis_space = range(len_0 * len(axis_index_groups))
  if {i for g in axis_index_groups for i in g} != set(axis_space):
    raise ValueError("axis_index_groups must cover all indices exactly once")
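# Illustrative sketch (added for exposition; not part of the original module)
# of what _validate_axis_index_groups accepts for an axis of size 4:
#   [[0, 1], [2, 3]]  -> OK: equal group sizes, each index appears exactly once
#   [[0], [1, 2, 3]]  -> ValueError: groups are not all the same size
#   [[0, 1], [2, 2]]  -> ValueError: index 3 is missing, so the indices are not
#                        covered exactly once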
def ppermute(x, axis_name, perm):
  """Perform a collective permutation according to the permutation ``perm``.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  This function is an analog of the CollectivePermute XLA HLO.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    perm: list of pairs of ints, representing
      ``(source_index, destination_index)`` pairs that encode how the mapped
      axis named ``axis_name`` should be shuffled. The integer values are
      treated as indices into the mapped axis ``axis_name``. Any two pairs
      should not have the same source index or the same destination index. For
      each index of the axis ``axis_name`` that does not correspond to a
      destination index in ``perm``, the corresponding values in the result are
      filled with zeros of the appropriate type.

  Returns:
    Array(s) with the same shape as ``x`` with slices along the axis
    ``axis_name`` gathered from ``x`` according to the permutation ``perm``.
  """
  return tree_util.tree_map(
      partial(ppermute_p.bind, axis_name=axis_name, perm=tuple(perm)), x)
def pshuffle(x, axis_name, perm):
  """Convenience wrapper of jax.lax.ppermute with alternate permutation encoding.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    perm: list of ints encoding sources for the permutation to be applied to
      the axis named ``axis_name``, so that the output at axis index i comes
      from the input at axis index perm[i]. Every integer in [0, N) should be
      included exactly once for axis size N.

  Returns:
    Array(s) with the same shape as ``x`` with slices along the axis
    ``axis_name`` gathered from ``x`` according to the permutation ``perm``.
  """
  if set(perm) != set(range(len(perm))):
    raise ValueError(f"`perm` does not represent a permutation: {perm}")
  return ppermute(x, axis_name, list(zip(perm, range(len(perm)))))
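# Illustrative sketch (added for exposition; not part of the original module)
# relating the two permutation encodings above. For an axis 'i' of size 3, a
# rotation that sends device 0 -> 1, 1 -> 2, 2 -> 0 can be written either way;
# the two lambdas below are equivalent:
#
#   >>> rotate = lambda v: jax.lax.ppermute(v, 'i', perm=[(0, 1), (1, 2), (2, 0)])
#   >>> rotate = lambda v: jax.lax.pshuffle(v, 'i', perm=[2, 0, 1])
#
# since pshuffle's perm lists, for each output index i, the source index
# perm[i] it reads from.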
def pswapaxes(x, axis_name, axis):
  """Swap the pmapped axis ``axis_name`` with the unmapped axis ``axis``.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  The mapped axis size must be equal to the size of the unmapped axis; that is,
  we must have ``lax.psum(1, axis_name) == x.shape[axis]``.

  This function is a special case of ``all_to_all`` where the pmapped axis of
  the input is placed at the position ``axis`` in the output. That is, it is
  equivalent to ``all_to_all(x, axis_name, axis, axis)``.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis: int indicating the unmapped axis of ``x`` to map with the name
      ``axis_name``.

  Returns:
    Array(s) with the same shape as ``x``.
  """
  return all_to_all(x, axis_name, axis, axis)
def all_to_all(x, axis_name, split_axis, concat_axis):
  """Materialize the mapped axis and map a different axis.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  In the output, the input mapped axis ``axis_name`` is materialized at the
  logical axis position ``concat_axis``, and the input unmapped axis at position
  ``split_axis`` is mapped with the name ``axis_name``.

  The input mapped axis size must be equal to the size of the axis to be mapped;
  that is, we must have ``lax.psum(1, axis_name) == x.shape[split_axis]``.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    split_axis: int indicating the unmapped axis of ``x`` to map with the name
      ``axis_name``.
    concat_axis: int indicating the position in the output to materialize the
      mapped axis of the input with the name ``axis_name``.

  Returns:
    Array(s) with shape given by the expression::

      np.insert(np.delete(x.shape, split_axis), concat_axis, axis_size)

    where ``axis_size`` is the size of the mapped axis named ``axis_name`` in
    the input ``x``, i.e. ``axis_size = lax.psum(1, axis_name)``.
  """
  def bind(x):
    if psum(1, axis_name) != x.shape[split_axis]:
      msg = ("all_to_all requires the size of the mapped axis axis_name to "
             "equal x.shape[split_axis], but they are {} and {} respectively.")
      raise ValueError(msg.format(psum(1, axis_name), x.shape[split_axis]))
    return all_to_all_p.bind(x, split_axis=split_axis, concat_axis=concat_axis,
                             axis_name=axis_name)
  return tree_util.tree_map(bind, x)
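# Illustrative shape example (added for exposition; not part of the original
# module). With 4 devices, each holding a block of shape (4, 2), mapping axis 0
# and materializing the named axis at position 1 gives per-device outputs of
# shape (2, 4):
#
#   >>> x = np.zeros((4, 4, 2))   # leading axis is mapped by pmap over 'i'
#   >>> f = jax.pmap(lambda v: jax.lax.all_to_all(v, 'i', 0, 1), axis_name='i')
#   >>> f(x).shape
#   (4, 2, 4)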
def axis_index(axis_name):
  """Return the index along the mapped axis ``axis_name``.

  Args:
    axis_name: hashable Python object used to name the mapped axis.

  Returns:
    An integer representing the index.

  For example, with 8 XLA devices available:

  >>> from functools import partial
  >>> @partial(jax.pmap, axis_name='i')
  ... def f(_):
  ...   return lax.axis_index('i')
  ...
  >>> f(np.zeros(4))
  ShardedDeviceArray([0, 1, 2, 3], dtype=int32)
  >>> f(np.zeros(8))
  ShardedDeviceArray([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
  >>> @partial(jax.pmap, axis_name='i')
  ... @partial(jax.pmap, axis_name='j')
  ... def f(_):
  ...   return lax.axis_index('i'), lax.axis_index('j')
  ...
  >>> x, y = f(np.zeros((4, 2)))
  >>> print(x)
  [[0 0]
   [1 1]
   [2 2]
   [3 3]]
  >>> print(y)
  [[0 1]
   [0 1]
   [0 1]
   [0 1]]
  """
  return axis_index_p.bind(axis_name=axis_name)
### parallel primitives

def _allreduce_soft_pmap_rule(prim, reducer, vals, mapped, chunk_size,
                              *, axis_name, axis_index_groups):
  if axis_index_groups is not None:
    raise NotImplementedError("soft_pmap does not yet support axis_index_groups")
  reduced_vals = [reducer(x, [0]) if m else x for x, m in zip(vals, mapped)]
  outs = prim.bind(*reduced_vals, axis_name=axis_name,
                   axis_index_groups=axis_index_groups)
  return outs, (False,) * len(vals)

# This is only used for collectives that do not include the vmapped axis name,
# which is why the rule is so simple.
def _collective_batcher(prim, args, dims, **params):
  return prim.bind(*args, **params), dims if prim.multiple_results else dims[0]

def _batched_reduction_collective(
    prim, if_mapped, if_unmapped, frame, vals_in, dims_in, axis_name,
    axis_index_groups):
  assert prim.multiple_results
  assert frame.name in axis_name
  if axis_index_groups is not None:
    raise NotImplementedError("axis_index_groups not supported in vmap collectives. "
                              "Please open a feature request!")
  vals_out = [if_mapped(v, d) if d is not batching.not_mapped
              else if_unmapped(v, frame.size)
              for v, d in zip(vals_in, dims_in)]
  if len(axis_name) > 1:
    remaining_axis_names = tuple(n for n in axis_name if n != frame.name)
    vals_out = prim.bind(*vals_out, axis_name=remaining_axis_names,
                         axis_index_groups=None)
  return vals_out, [batching.not_mapped] * len(vals_out)

def _replica_groups(axis_env, axis_name, axis_index_groups):
  replica_groups = xla.axis_groups(axis_env, axis_name)
  if axis_index_groups is not None:
    replica_groups = [[axis_group[i] for i in axis_index_group]
                      for axis_group in replica_groups
                      for axis_index_group in axis_index_groups]
  return replica_groups

def _allreduce_translation_rule(prim, c, *args, axis_name, axis_index_groups,
                                axis_env, platform):
  if platform in ("cpu", "tpu"):
    return _notuple_allreduce_translation_rule(
        prim, c, *args, axis_name=axis_name,
        axis_index_groups=axis_index_groups, axis_env=axis_env,
        platform=platform)

  # XLA's tuple all-reduce doesn't support different dtypes in the same
  # allreduce. Instead, we perform one all-reduce for each argument input type.
  args_by_type = collections.defaultdict(lambda: ([], []))
  for i, arg in enumerate(args):
    indices, dtype_args = args_by_type[c.get_shape(arg).numpy_dtype()]
    indices.append(i)
    dtype_args.append(arg)

  # The outputs, in the original argument order.
  out = [None] * len(args)
  replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
    is_complex = dtypes.issubdtype(dtype, np.complexfloating)
    n = len(dtype_args)
    if is_complex and prim is lax.add_p:
      # TODO(b/141575627): we handle complex-dtype sum-reduction directly as a
      # special case because it's not currently handled by XLA:GPU
      dtype_args = ([xops.Real(x) for x in dtype_args] +
                    [xops.Imag(x) for x in dtype_args])
    scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
    computation = xla.primitive_subcomputation(prim, scalar, scalar)
    all_reduce = xops.AllReduce(xops.Tuple(c, dtype_args), computation,
                                replica_groups_protos, None, None)
    if is_complex and prim is lax.add_p:
      xs = [xops.Complex(xops.GetTupleElement(all_reduce, i),
                         xops.GetTupleElement(all_reduce, n + i))
            for i in range(n)]
    else:
      xs = [xops.GetTupleElement(all_reduce, i) for i in range(n)]
    for i, x in zip(indices, xs):
      out[i] = x
  return xops.Tuple(c, out)
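# Illustrative sketch (added for exposition; not part of the original module)
# of the group expansion in _replica_groups. Suppose xla.axis_groups were to
# return two replica groups for the named axis, [[0, 2, 4, 6], [1, 3, 5, 7]],
# and axis_index_groups=[[0, 1], [2, 3]]; then each axis_index_group selects
# the corresponding positions from each replica group, yielding
# [[0, 2], [4, 6], [1, 3], [5, 7]].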
# TODO(b/155446630): An XLA:TPU optimization pass also doesn't support
# tuple all-reduce yet. Meanwhile, rely on deterministic compiler behavior.
def _notuple_allreduce_translation_rule(prim, c, *args, axis_name, axis_env,
                                        axis_index_groups, platform):
  def all_reduce(x):
    replica_groups_protos = xc.make_replica_groups(
        _replica_groups(axis_env, axis_name, axis_index_groups))
    scalar = ShapedArray((), c.get_shape(x).numpy_dtype())
    computation = xla.primitive_subcomputation(prim, scalar, scalar)
    return xops.AllReduce(x, computation, replica_groups_protos, None, None)

  if prim is not lax.add_p:
    outs = [all_reduce(x) for x in args]
  else:
    # TODO(b/141575627): we handle complex-dtype sum-reduction directly as a
    # special case because it's not currently handled by XLA:GPU
    outs = [xops.Complex(all_reduce(xops.Real(x)), all_reduce(xops.Imag(x)))
            if dtypes.issubdtype(c.get_shape(x).numpy_dtype(), np.complexfloating)
            else all_reduce(x) for x in args]
  return xops.Tuple(c, outs)

def _psum_transpose_rule(cts, axis_name, axis_index_groups):
  nonzero_out_cts, treedef = tree_util.tree_flatten(cts)
  nonzero_in_cts = psum_p.bind(*nonzero_out_cts, axis_name=axis_name,
                               axis_index_groups=axis_index_groups)
  return tree_util.tree_unflatten(treedef, nonzero_in_cts)

psum_p = core.Primitive('psum')
psum_p.multiple_results = True
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
pxla.soft_pmap_rules[psum_p] = \
    partial(_allreduce_soft_pmap_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = partial(_allreduce_translation_rule, lax.add_p)  # type: ignore
ad.deflinear(psum_p, _psum_transpose_rule)
pxla.multi_host_supported_collectives.add(psum_p)
batching.primitive_batchers[psum_p] = partial(_collective_batcher, psum_p)
batching.collective_rules[psum_p] = \
    partial(_batched_reduction_collective,
            psum_p,
            lambda v, d: v.sum(d),
            lambda v, axis_size: axis_size * v)
# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at
# tracing time.
@psum_p.def_custom_bind
def psum_bind(*args, axis_name, axis_index_groups):
  if all(not isinstance(x, core.Tracer) for x in args):
    if axis_index_groups is not None:
      size = len(axis_index_groups[0])
    elif isinstance(axis_name, (list, tuple)):
      size = prod([core.axis_frame(name).size for name in axis_name])  # type: ignore
    else:
      size = core.axis_frame(axis_name).size  # type: ignore
    return tuple(size * x for x in args)
  return core.Primitive.bind(
      psum_p, *args, axis_name=axis_name, axis_index_groups=axis_index_groups)


pmax_p = core.Primitive('pmax')
pmax_p.multiple_results = True
pmax_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
xla.parallel_translations[pmax_p] = partial(_allreduce_translation_rule, lax.max_p)
pxla.multi_host_supported_collectives.add(pmax_p)
batching.primitive_batchers[pmax_p] = partial(_collective_batcher, pmax_p)
batching.collective_rules[pmax_p] = \
    partial(_batched_reduction_collective,
            pmax_p,
            lambda v, d: v.max(d),
            lambda v, axis_size: v)


pmin_p = core.Primitive('pmin')
pmin_p.multiple_results = True
pmin_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
xla.parallel_translations[pmin_p] = partial(_allreduce_translation_rule, lax.min_p)
pxla.multi_host_supported_collectives.add(pmin_p)
batching.primitive_batchers[pmin_p] = partial(_collective_batcher, pmin_p)
batching.collective_rules[pmin_p] = \
    partial(_batched_reduction_collective,
            pmin_p,
            lambda v, d: v.min(d),
            lambda v, axis_size: v)


def _ppermute_translation_rule(c, x, *, axis_name, axis_env, perm, platform):
  replica_groups = _replica_groups(axis_env, axis_name, None)
  group_size = len(replica_groups[0])
  srcs, dsts = unzip2((src % group_size, dst % group_size) for src, dst in perm)
  if not (len(srcs) == len(set(srcs)) and len(dsts) == len(set(dsts))):
    msg = "ppermute sources and destinations must be unique, got {}."
    raise ValueError(msg.format(perm))

  full_perm = []
  for grp in replica_groups:
    grp = list(sorted(grp))
    full_perm.extend((grp[src], grp[dst]) for src, dst in perm)
  return xops.CollectivePermute(x, full_perm)

def _ppermute_transpose_rule(t, perm, axis_name):
  srcs, dsts = unzip2(perm)
  inverse_perm = list(zip(dsts, srcs))
  return [ppermute(t, axis_name=axis_name, perm=inverse_perm)]

def _ppermute_batcher(frame, vals_in, dims_in, axis_name, perm):
  assert len(perm) == frame.size, "Permutation doesn't match the axis size!"
  assert axis_name == frame.name, "ppermute batcher called with wrong axis name"
  (v,), (d,) = vals_in, dims_in
  assert d is not batching.not_mapped
  perm_indices = [None] * frame.size
  for src, dst in perm:
    perm_indices[src] = dst
  return lax_numpy.take(v, perm_indices, d), d

ppermute_p = core.Primitive('ppermute')
ppermute_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
ad.deflinear(ppermute_p, _ppermute_transpose_rule)
xla.parallel_translations[ppermute_p] = _ppermute_translation_rule
pxla.multi_host_supported_collectives.add(ppermute_p)
batching.primitive_batchers[ppermute_p] = partial(_collective_batcher, ppermute_p)
batching.collective_rules[ppermute_p] = _ppermute_batcher

def _moveaxis(src, dst, x):
  perm = [i for i in range(x.ndim) if i != src]
  perm.insert(dst, src)
  return lax.transpose(x, perm)

def _all_to_all_via_all_gather(x, *, axis_name, split_axis, concat_axis):
  global_full = all_gather(x, axis_name)
  idx = axis_index(axis_name)
  local_slice = lax.dynamic_index_in_dim(global_full, idx, split_axis + 1,
                                         keepdims=False)
  return _moveaxis(0, concat_axis, local_slice)

def _all_to_all_translation_rule(c, x, *, split_axis, concat_axis, axis_name,
                                 axis_env, platform):
  # Workaround for AllToAll not being implemented on CPU.
  replica_groups = _replica_groups(axis_env, axis_name, None)
  if len(replica_groups[0]) == 1:
    return x
  elif platform != 'tpu':
    warnings.warn("all_to_all (and pswapaxes) are only implemented properly for TPUs. All other "
                  "backends emulate it using a very slow and memory intensive algorithm, so expect "
                  "significant slowdowns.")
    lowering = xla.lower_fun(_all_to_all_via_all_gather,
                             multiple_results=False, parallel=True)
    return lowering(c, x, split_axis=split_axis, concat_axis=concat_axis,
                    axis_name=axis_name, axis_env=axis_env, platform=platform)
  else:
    split_count = len(replica_groups[0])
    if not all(split_count == len(g) for g in replica_groups):
      raise ValueError('Replica groups must be equally sized')
    replica_groups_protos = xc.make_replica_groups(replica_groups)
    if concat_axis == split_axis:
      return xops.AllToAll(x, split_axis, concat_axis, split_count,
                           replica_groups_protos)
    else:
      if concat_axis < split_axis:
        split_axis += 1
      elif split_axis < concat_axis:
        concat_axis += 1
      x = xla.lower_fun(partial(lax.expand_dims, dimensions=(concat_axis,)),
                        multiple_results=False)(c, x)
      x = xops.AllToAll(x, split_axis, concat_axis, split_count,
                        replica_groups_protos)
      x = xla.lower_fun(partial(lax.squeeze, dimensions=(split_axis,)),
                        multiple_results=False)(c, x)
      return x

def _all_to_all_transpose_rule(cts, axis_name, split_axis, concat_axis):
  return (all_to_all(cts, axis_name=axis_name, split_axis=concat_axis,
                     concat_axis=split_axis),)

def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis, concat_axis):
  x, = vals_in
  d, = dims_in
  if d <= split_axis:
    split_axis += 1
  if d <= concat_axis:
    concat_axis += 1
  # Note: At this point split_axis and concat_axis are adjusted to the extra
  # dimension and we have d != split_axis and d != concat_axis.
  if split_axis < d < concat_axis:
    d -= 1
  elif concat_axis < d < split_axis:
    d += 1
  result = all_to_all_p.bind(x, axis_name=axis_name, split_axis=split_axis,
                             concat_axis=concat_axis)
  return result, d

def _all_to_all_batched_collective(frame, vals_in, dims_in,
                                   axis_name, split_axis, concat_axis):
  if isinstance(axis_name, (list, tuple)) and len(axis_name) > 1:
    raise NotImplementedError("update after #4835")  # TODO(mattjj,apaszke)
  x, = vals_in
  d, = dims_in
  split_axis_adj = split_axis + (1 if d <= split_axis else 0)
  concat_axis_adj = concat_axis + (1 if split_axis_adj <= concat_axis else 0)
  if d < split_axis_adj < concat_axis_adj:
    split_axis_adj -= 1
  elif concat_axis_adj < split_axis_adj < d:
    split_axis_adj += 1
  return _moveaxis(d, concat_axis_adj, x), split_axis_adj

def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis):
  input_aval = raise_to_shaped(x)
  shape = list(input_aval.shape)
  size = shape.pop(split_axis)
  shape.insert(concat_axis, size)
  return ShapedArray(tuple(shape), input_aval.dtype, weak_type=False)

all_to_all_p = core.Primitive('all_to_all')
all_to_all_p.def_abstract_eval(_all_to_all_abstract_eval)
xla.parallel_translations[all_to_all_p] = _all_to_all_translation_rule
ad.deflinear(all_to_all_p, _all_to_all_transpose_rule)
pxla.multi_host_supported_collectives.add(all_to_all_p)
batching.primitive_batchers[all_to_all_p] = _all_to_all_batcher
batching.collective_rules[all_to_all_p] = _all_to_all_batched_collective


def _expand(dim, size, index, x):
  shape = list(x.shape)
  shape.insert(dim, size)
  out = lax.full(shape, lax._const(x, 0))
  return lax.dynamic_update_index_in_dim(out, x, index, dim)

def _allgather(x, dim, size, index, axis_name, axis_index_groups=None):
  outs = tree_util.tree_map(partial(_expand, dim, size, index), x)
  return psum(outs, axis_name, axis_index_groups=axis_index_groups)
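# Illustrative sketch (added for exposition; not part of the original module)
# of the gather-via-psum trick implemented by _expand/_allgather. For an axis
# of size 3 with per-device scalars a, b, c, _expand writes each value into a
# zero-filled slot at that device's index:
#   device 0: [a, 0, 0]    device 1: [0, b, 0]    device 2: [0, 0, c]
# and the psum over the axis then leaves [a, b, c] on every device.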
def all_gather(x, axis_name, *, axis_index_groups=None):
  """Gather values of x across all replicas.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  This is equivalent to, but faster than, all_to_all(broadcast(x)).

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would run all gather over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once, and all groups must be the same size.

  Returns:
    Array(s) representing the result of an all-gather along the axis
    ``axis_name``. Shapes are the same as ``x.shape``, but with a leading
    dimension of the axis_size.

  For example, with 4 XLA devices available:

  >>> x = np.arange(4)
  >>> y = jax.pmap(lambda x: jax.lax.all_gather(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [[0 1 2 3]
   [0 1 2 3]
   [0 1 2 3]
   [0 1 2 3]]

  An example of using axis_index_groups, groups split by even & odd device ids:

  >>> x = np.arange(16).reshape(4, 4)
  >>> print(x)
  [[ 0  1  2  3]
   [ 4  5  6  7]
   [ 8  9 10 11]
   [12 13 14 15]]
  >>> y = jax.pmap(lambda x: jax.lax.all_gather(
  ...     x, 'i', axis_index_groups=[[0, 2], [3, 1]]))(x)
  >>> print(y)
  [[[ 0  1  2  3]
    [ 8  9 10 11]]
   [[12 13 14 15]
    [ 4  5  6  7]]
   [[ 0  1  2  3]
    [ 8  9 10 11]]
   [[12 13 14 15]
    [ 4  5  6  7]]]
  """
  index = axis_index(axis_name)
  if axis_index_groups is not None:
    indices = np.array(axis_index_groups).flatten()
    axis_index_to_group_index = indices.argsort() % len(axis_index_groups[0])
    index = lax_numpy.array(axis_index_to_group_index)[index]

  axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups)

  return _allgather(x, 0, axis_size, index, axis_name, axis_index_groups)
def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
  axis_pos = list(axis_env.names).index(axis_name)
  nreplicas = axis_env.nreps // prod(axis_env.sizes)
  div = xb.constant(c, np.array(nreplicas * prod(axis_env.sizes[axis_pos+1:]),
                                dtype=np.uint32))
  mod = xb.constant(c, np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
  unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
  return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))

def _axis_index_soft_pmap_rule(vals, mapped, chunk_size, *, axis_name):
  assert not vals and not mapped
  idx = axis_index(axis_name)  # type: ignore
  return idx * chunk_size + np.arange(chunk_size, dtype=np.int32), True

axis_index_p = core.Primitive('axis_index')
xla.parallel_translations[axis_index_p] = _axis_index_translation_rule
pxla.soft_pmap_rules[axis_index_p] = _axis_index_soft_pmap_rule  # type: ignore
axis_index_p.def_abstract_eval(
    lambda *args, **params: ShapedArray((), np.int32))
pxla.multi_host_supported_collectives.add(axis_index_p)

# Axis index doesn't get any arguments, so that the default bind would have no
# way to call into a data-dependency based trace such as vmap. Each trace that
# wants to bind an axis name has to additionally implement `process_axis_index`
# and put its main trace on the axis env stack.
def _axis_index_bind(*, axis_name):
  if not isinstance(axis_name, (tuple, list)):
    axis_name = (axis_name,)
  inner_size = 1
  index = 0
  for name in reversed(axis_name):
    frame = core.axis_frame(name)
    if frame.main_trace is not None:
      trace = frame.main_trace.with_cur_sublevel()
      name_idx = trace.process_axis_index(frame)
    else:
      name_idx = core.Primitive.bind(axis_index_p, axis_name=name)
    index += name_idx * inner_size
    inner_size *= psum(1, name)
  return index
axis_index_p.def_custom_bind(_axis_index_bind)

def _process_axis_index(self, frame):
  return batching.BatchTracer(
      self, lax_numpy.arange(frame.size, dtype=np.int32), 0)
batching.BatchTrace.process_axis_index = _process_axis_index  # type: ignore


@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  global axis_index

  psum_p.bind = partial(core.Primitive.bind, psum_p)  # type: ignore
  psum_p.def_impl(partial(pxla.apply_parallel_primitive, psum_p))  # type: ignore
  pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args)  # type: ignore

  def _axis_index_bind(*, axis_name):
    dynamic_axis_env = pxla._thread_local_state.dynamic_axis_env
    frame = dynamic_axis_env[axis_name]
    sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame)+1]
    nreps = dynamic_axis_env.nreps
    trace = frame.pmap_trace

    out_aval = ShapedArray((), np.int32)
    out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
    eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
                            dict(nreps=nreps, sizes=sizes, axis_name=axis_name),
                            source_info_util.current())
    out_tracer.recipe = eqn

    return out_tracer

  def _axis_index_translation_rule(c, nreps, sizes, axis_name):
    div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32))
    mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32))
    unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
    return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))

  axis_index_p.def_custom_bind(_axis_index_bind)
  axis_index_p.def_abstract_eval(
      lambda *args, **params: ShapedArray((), np.int32))
  xla.translations[axis_index_p] = _axis_index_translation_rule
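# Worked example (added for exposition; not part of the original module) of the
# replica-id arithmetic in the module-level _axis_index_translation_rule. With
# one replica per mapped device and axis sizes (2, 3), a replica id r decodes as
#   index along axis 0: (r // 3) % 2
#   index along axis 1: (r // 1) % 3
# i.e. index = (ReplicaId // (nreplicas * prod(sizes[pos+1:]))) % sizes[pos].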