Warning
This page was created from a pull request.
jax.lax.pswapaxes
jax.lax.pswapaxes(x, axis_name, axis)

Swap the pmapped axis axis_name with the unmapped axis axis.

If x is a pytree then the result is equivalent to mapping this function to each leaf in the tree.

The mapped axis size must be equal to the size of the unmapped axis; that is, we must have lax.psum(1, axis_name) == x.shape[axis].

This function is a special case of all_to_all where the pmapped axis of the input is placed at the position axis in the output. That is, it is equivalent to all_to_all(x, axis_name, axis, axis).

- Parameters
x – array(s) with a mapped axis named axis_name.
axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).
axis – int indicating the unmapped axis of x to map with the name axis_name.
- Returns
Array(s) with the same shape as x.
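
The following is a minimal sketch, not taken from the JAX documentation, of how the description above might look in practice. It assumes the pmapped axis is named "i", that the mapped axis size equals jax.local_device_count(), and that the trailing dimension of 3 is arbitrary; the helper names swap and swap_via_all_to_all are hypothetical. Under these assumptions, swapping the mapped axis with unmapped axis 0 should amount to exchanging the first two axes of the stacked input, and should agree with the all_to_all form given in the description.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Size of the pmapped axis (assumption: map over all local devices).
n = jax.local_device_count()

# Stacked input: axis 0 is mapped over devices, so each device sees a
# block of shape (n, 3). The unmapped axis 0 of the block has size n,
# which satisfies lax.psum(1, "i") == block.shape[0].
x = jnp.arange(n * n * 3, dtype=jnp.float32).reshape(n, n, 3)


def swap(block):
    # Hypothetical helper: swap the mapped axis "i" with unmapped axis 0.
    return jax.lax.pswapaxes(block, "i", 0)


def swap_via_all_to_all(block):
    # Hypothetical helper: the equivalent all_to_all(x, axis_name, axis, axis).
    return jax.lax.all_to_all(block, "i", 0, 0)


out = jax.pmap(swap, axis_name="i")(x)
ref = jax.pmap(swap_via_all_to_all, axis_name="i")(x)

# Checks on the stacked results, computed rather than asserted here.
same_shape = out.shape == x.shape
matches_all_to_all = bool(np.array_equal(np.asarray(out), np.asarray(ref)))
matches_transpose = bool(
    np.array_equal(np.asarray(out), np.swapaxes(np.asarray(x), 0, 1))
)
print(same_shape, matches_all_to_all, matches_transpose)
```

On a machine with a single device, n is 1 and the swap is trivial; running with several devices (or several emulated host devices) makes the exchange visible.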