Warning
This page was created from a pull request.
jax.lax.pswapaxes

jax.lax.pswapaxes(x, axis_name, axis)

Swap the pmapped axis axis_name with the unmapped axis axis.

If x is a pytree, the result is equivalent to mapping this function to each leaf in the tree. The mapped axis size must be equal to the size of the unmapped axis; that is, we must have lax.psum(1, axis_name) == x.shape[axis].

This function is a special case of all_to_all where the pmapped axis of the input is placed at the position axis in the output. That is, it is equivalent to all_to_all(x, axis_name, axis, axis).
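The sketch below illustrates the swap, the stated equivalence with all_to_all, and the per-leaf pytree behaviour. It assumes a recent JAX release and, as a deliberate substitution, uses jax.vmap with an axis_name instead of jax.pmap so that it runs on a single CPU device; lax collectives such as all_to_all also work over named vmap axes. The helper via_pswapaxes and via_all_to_all are hypothetical names for illustration, not part of JAX. For a 2-D input mapped over its first axis, swapping the named axis with local axis 0 amounts to a transpose.

```python
import jax
import jax.numpy as jnp
from jax import lax

n = 4  # size of the named axis 'i'; must equal x.shape[axis] in each instance
# Global (n, n) array: vmap hands row i to instance i as a shape-(n,) array.
x = jnp.arange(n * n).reshape(n, n)

def via_pswapaxes(local):
    # Hypothetical helper: swap the named axis 'i' with local axis 0.
    return lax.pswapaxes(local, 'i', 0)

def via_all_to_all(local):
    # Hypothetical helper: the documented equivalent all_to_all(x, axis_name, axis, axis).
    return lax.all_to_all(local, 'i', 0, 0)

# Assumption: jax.vmap with axis_name stands in for jax.pmap so this runs on one device.
swapped = jax.vmap(via_pswapaxes, axis_name='i')(x)
reference = jax.vmap(via_all_to_all, axis_name='i')(x)
print(jnp.array_equal(swapped, x.T))        # True: for 2-D input the swap is a transpose
print(jnp.array_equal(swapped, reference))  # True

# Pytree input: the same swap is applied to each leaf independently.
# Per-instance leaf shapes are (n,) for "a" and (n, 2) for "b".
tree = {"a": x, "b": jnp.arange(n * n * 2).reshape(n, n, 2)}
tree_out = jax.vmap(via_pswapaxes, axis_name='i')(tree)
print(jnp.array_equal(tree_out["b"], jnp.swapaxes(tree["b"], 0, 1)))  # True
```

Because vmap places the mapped axis at position 0 of its output by default, the stacked results can be compared directly with x.T and with jnp.swapaxes on the global arrays.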
- Parameters
  x – array(s) with a mapped axis named axis_name.
  axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).
  axis – int indicating the unmapped axis of x to map with the name axis_name.
- Returns
  Array(s) with the same shape as x.
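The role of the axis parameter, and the fact that the output keeps the input's shape, can be seen in a sketch that swaps with a local axis other than 0. As a stated assumption, it uses jax.vmap with axis_name in place of jax.pmap (so it runs on a single device) on a recent JAX release; swap_local_axis1 is a hypothetical helper name. Each of the 4 instances sees a (3, 4) block, and axis=1 selects the local axis whose size matches the named axis size of 4.

```python
import jax
import jax.numpy as jnp
from jax import lax

n, a = 4, 3
# Global (n, a, n) array: each instance along the named axis 'i' sees shape (a, n).
x = jnp.arange(n * a * n).reshape(n, a, n)

def swap_local_axis1(block):
    # Hypothetical helper: axis=1 is the local axis whose size (n) matches
    # the size of the named axis 'i', as pswapaxes requires.
    return lax.pswapaxes(block, 'i', 1)

# Assumption: jax.vmap with axis_name stands in for jax.pmap so this runs on one device.
out = jax.vmap(swap_local_axis1, axis_name='i')(x)

print(out.shape)                                    # (4, 3, 4)
print(jnp.array_equal(out, jnp.swapaxes(x, 0, 2)))  # True
```

Viewed as one stacked array, the result exchanges axes 0 and 2, which is why the overall shape (4, 3, 4) is unchanged. Choosing axis=0 here would violate the requirement that lax.psum(1, axis_name) == x.shape[axis], because local axis 0 has size 3 rather than 4.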