jax.lax.all_gather
jax.lax.all_gather(x, axis_name, *, axis_index_groups=None)
Gather values of x across all replicas.
If x is a pytree, the result is equivalent to mapping this function over each leaf in the tree. The all-gather itself is equivalent to, but faster than, all_to_all(broadcast(x)).
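A minimal sketch of the pytree behavior, assuming a recent JAX release. So that it runs on a single device, it uses jax.vmap with axis_name="i" in place of jax.pmap (collectives such as all_gather also work over a named vmap axis); the dict keys "a" and "b" are arbitrary illustrative names.

```python
import jax
import jax.numpy as jnp

# A pytree (here a dict) whose leaves all share a mapped leading axis of size 4.
tree = {"a": jnp.arange(4), "b": jnp.arange(8).reshape(4, 2)}

def gather(t):
    # all_gather accepts the whole pytree and gathers every leaf over axis 'i'.
    return jax.lax.all_gather(t, "i")

# vmap with axis_name stands in for pmap so the sketch runs on one device.
gathered = jax.vmap(gather, axis_name="i")(tree)

# Gathering each leaf on its own should give the same arrays.
leafwise = {k: jax.vmap(gather, axis_name="i")(v) for k, v in tree.items()}

print(gathered["a"].shape, gathered["b"].shape)  # (4, 4) (4, 4, 2)
```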
- Parameters
  - x: array(s) with a mapped axis named axis_name.
  - axis_name: hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).
  - axis_index_groups: optional list of lists containing axis indices (e.g. for an axis of size 4, [[0, 1], [2, 3]] would run all_gather over the first two and last two replicas). Groups must cover all axis indices exactly once, and all groups must be the same size.
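The two constraints on axis_index_groups can be made concrete with a small pure-Python check. validate_axis_index_groups is a hypothetical helper written for illustration only; it is not part of the JAX API.

```python
def validate_axis_index_groups(groups, axis_size):
    """Hypothetical helper: check the documented axis_index_groups rules.

    Raises ValueError unless every index in range(axis_size) appears in
    exactly one group and all groups have the same length.
    """
    flat = [i for group in groups for i in group]
    # Cover all axis indices exactly once: no gaps and no duplicates.
    if sorted(flat) != list(range(axis_size)):
        raise ValueError("groups must cover all axis indices exactly once")
    # All groups must be the same size.
    if len({len(group) for group in groups}) != 1:
        raise ValueError("all groups must be the same size")

# The documented example for an axis of size 4 passes.
validate_axis_index_groups([[0, 1], [2, 3]], axis_size=4)
# So does an even/odd split; order within a group may differ.
validate_axis_index_groups([[0, 2], [3, 1]], axis_size=4)
```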
- Returns
  Array(s) representing the result of an all-gather along the axis axis_name. Shapes are the same as x.shape, but with an added leading dimension whose size is the size of the axis axis_name.
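To see the shape rule with a non-scalar input, the sketch below gives each of 4 mapped instances an x of shape (2, 3). It again uses jax.vmap with axis_name as a single-device stand-in for jax.pmap, which is an assumption of the example rather than part of the documented usage.

```python
import jax
import jax.numpy as jnp

axis_size = 4
# Per-instance x has shape (2, 3); the leading 4 is the mapped axis 'i'.
xs = jnp.arange(axis_size * 2 * 3).reshape(axis_size, 2, 3)

out = jax.vmap(lambda x: jax.lax.all_gather(x, "i"), axis_name="i")(xs)

# Inside the map each result has shape (axis_size, 2, 3); vmap then stacks
# the 4 per-instance results, giving (4, 4, 2, 3) overall.
print(out.shape)  # (4, 4, 2, 3)
```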
For example, with 4 XLA devices available:
>>> import jax
>>> import numpy as np
>>> x = np.arange(4)
>>> y = jax.pmap(lambda x: jax.lax.all_gather(x, 'i'), axis_name='i')(x)
>>> print(y)
[[0 1 2 3]
 [0 1 2 3]
 [0 1 2 3]
 [0 1 2 3]]
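Readers without four devices can reproduce this output on one device. The sketch swaps jax.pmap for jax.vmap with the same axis_name; that substitution is an assumption of the sketch, not part of the original example.

```python
import jax
import numpy as np

x = np.arange(4)
# Same computation as the pmap example, but vmap maps over axis 'i' on one device.
y = jax.vmap(lambda x: jax.lax.all_gather(x, "i"), axis_name="i")(x)
print(y)
# [[0 1 2 3]
#  [0 1 2 3]
#  [0 1 2 3]
#  [0 1 2 3]]
```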
An example of using axis_index_groups, with groups split by even and odd device ids:
>>> x = np.arange(16).reshape(4, 4)
>>> print(x)
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
>>> y = jax.pmap(lambda x: jax.lax.all_gather(
...     x, 'i', axis_index_groups=[[0, 2], [3, 1]]), axis_name='i')(x)
>>> print(y)
[[[ 0  1  2  3]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [ 4  5  6  7]]

 [[ 0  1  2  3]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [ 4  5  6  7]]]
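The grouped result can be checked against a plain NumPy model, which needs no devices at all. reference_grouped_all_gather is a hypothetical helper; it encodes the ordering visible in the example output (replica 1, in group [3, 1], receives replica 3's row before its own), which is an observation from that output rather than a separately documented guarantee.

```python
import numpy as np

def reference_grouped_all_gather(xs, groups):
    """Hypothetical NumPy model of all_gather with axis_index_groups.

    xs[d] is the value held by replica d. Replica d receives the values of
    the replicas in its own group, stacked in the order the group lists them.
    """
    out = []
    for d in range(len(xs)):
        # Find the (unique) group containing replica d.
        group = next(g for g in groups if d in g)
        out.append(np.stack([xs[j] for j in group]))
    return np.stack(out)

x = np.arange(16).reshape(4, 4)
y = reference_grouped_all_gather(x, [[0, 2], [3, 1]])
print(y.shape)  # (4, 2, 4)
```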