Warning
This page was created from a pull request.
jax.lax.all_to_all

jax.lax.all_to_all(x, axis_name, split_axis, concat_axis)

Materialize the mapped axis and map a different axis.
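In the simplest case, this summary amounts to a transpose between the mapped axis and one unmapped axis. The following is a minimal sketch. It assumes jax.vmap with axis_name is used to emulate the mapped axis on a single device, whereas the rest of this page describes the jax.pmap setting. The axis name "i", the helper swap, and the shapes are hypothetical choices for illustration:

```python
import jax
import jax.numpy as jnp
from jax import lax

n = 4  # size of the mapped axis "i" (hypothetical)

# Global view: axis 0 is mapped as "i", so each of the n mapped
# positions holds a per-device slice of shape (n, 3).
xs = jnp.arange(n * n * 3).reshape(n, n, 3)

def swap(x):
    # x has shape (n, 3). Axis 0 (size n == size of "i") is scattered
    # across the mapped positions, and the received rows are stacked at axis 0.
    return lax.all_to_all(x, "i", split_axis=0, concat_axis=0)

# Assumption: vmap with axis_name stands in for pmap over n devices.
ys = jax.vmap(swap, axis_name="i")(xs)

# Position d now holds, at row s, the row d that position s started with,
# so in the global view the mapped axis and axis 0 have traded places.
print(bool(jnp.array_equal(ys, jnp.swapaxes(xs, 0, 1))))
```

Under jax.pmap, the same swap body should apply with one device per mapped position, which requires n devices in total.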
If x is a pytree, then the result is equivalent to mapping this function to each leaf in the tree.

In the output, the input mapped axis axis_name is materialized at the logical axis position concat_axis, and the input unmapped axis at position split_axis is mapped with the name axis_name.

The input mapped axis size must be equal to the size of the axis to be mapped; that is, we must have lax.psum(1, axis_name) == x.shape[split_axis].

- Parameters
x – array(s) with a mapped axis named axis_name.
axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).
split_axis – int indicating the unmapped axis of x to map with the name axis_name.
concat_axis – int indicating the position in the output at which to materialize the mapped axis of the input with the name axis_name.
- Returns
Array(s) with shape given by the expression:

np.insert(np.delete(x.shape, split_axis), concat_axis, axis_size)

where axis_size is the size of the mapped axis named axis_name in the input x, i.e. axis_size = lax.psum(1, axis_name).
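The following is a minimal sketch of how the shape expression above can be checked directly. It assumes jax.vmap with axis_name is an acceptable stand-in for the jax.pmap setting, so it runs on a single device. The leaf names "a" and "b", the helper step, and all shapes are hypothetical.

The sketch passes a two-leaf pytree and compares each leaf's per-device output shape with the np.insert(np.delete(...)) formula. It also checks the values against a jnp.moveaxis rearrangement of the global array. That rearrangement is derived here from the semantics described above, not quoted from the JAX docs.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

n = 4  # size of the mapped axis "i"; plays the role of lax.psum(1, "i")
split_axis, concat_axis = 2, 1

def step(tree):
    # Applied leaf by leaf: each leaf's axis 2 (size n) is scattered across
    # the n mapped positions, and the received pieces are stacked at axis 1.
    return lax.all_to_all(tree, "i", split_axis, concat_axis)

# Global arrays: axis 0 is mapped as "i", so the per-device shapes are
# (3, 5, n) and (2, 1, n); x.shape[split_axis] == n holds for both leaves.
tree = {
    "a": jnp.arange(n * 3 * 5 * n).reshape(n, 3, 5, n),
    "b": jnp.arange(n * 2 * 1 * n).reshape(n, 2, 1, n),
}
# Assumption: vmap with axis_name stands in for pmap over n devices.
out = jax.vmap(step, axis_name="i")(tree)

for name, x in tree.items():
    per_device_in = x.shape[1:]
    # Documented formula: drop split_axis, then insert the mapped-axis size at concat_axis.
    expected = tuple(int(d) for d in np.insert(np.delete(per_device_in, split_axis), concat_axis, n))
    # Global-view equivalent: the old mapped axis (global 0) lands at the
    # per-device concat_axis (global concat_axis + 1), and the old split axis
    # (global split_axis + 1) becomes the new mapped axis (global 0).
    reference = jnp.moveaxis(x, (0, split_axis + 1), (concat_axis + 1, 0))
    print(name, out[name].shape[1:], expected, bool(jnp.array_equal(out[name], reference)))
```

The size-1 axis in leaf "b" is deliberate. The formula only requires x.shape[split_axis] to match the mapped-axis size. Other axes may have any size.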