Warning
This page was created from a pull request.
jax.lax.pmean
jax.lax.pmean(x, axis_name, *, axis_index_groups=None)

Compute an all-reduce mean on x over the pmapped axis axis_name.

If x is a pytree, the result is equivalent to applying this function to each leaf of the tree.

Parameters:
  - x: array(s) with a mapped axis named axis_name.
  - axis_name: hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).
  - axis_index_groups: optional list of lists containing axis indices. For example, for an axis of size 4, [[0, 1], [2, 3]] performs one pmean over the first two replicas and another over the last two. The groups must cover every axis index exactly once, and all groups must be the same size.
Returns:
  Array(s) with the same shape as x, representing the result of an all-reduce mean along the axis axis_name.
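The semantics of the all-reduce mean, including axis_index_groups, can be modelled without any devices. The sketch below is a hypothetical NumPy reference model (pmean_reference is an illustrative name, not a JAX function): the leading axis of x stands in for the pmapped axis, so x[i] is the value held by replica i, and every replica receives the mean over its group. It illustrates the documented behaviour only; it is not how JAX implements the collective.

```python
import numpy as np


def pmean_reference(x, axis_index_groups=None):
    """Hypothetical reference model of jax.lax.pmean (not part of JAX).

    x[i] is the value held by replica i of the simulated mapped axis.
    The result has the same shape as x; each replica holds the mean
    over the replicas in its group.
    """
    x = np.asarray(x)
    n = x.shape[0]  # size of the simulated mapped axis
    if axis_index_groups is None:
        # No groups given: a single group containing every replica.
        axis_index_groups = [list(range(n))]

    # Enforce the documented constraints on axis_index_groups.
    covered = sorted(i for group in axis_index_groups for i in group)
    if covered != list(range(n)):
        raise ValueError("groups must cover all axis indices exactly once")
    if len({len(group) for group in axis_index_groups}) != 1:
        raise ValueError("all groups must be the same size")

    out = np.empty(x.shape, dtype=np.float64)
    for group in axis_index_groups:
        idx = np.asarray(group)
        # Average over this group's replicas, then broadcast the mean
        # back to every member of the group.
        out[idx] = x[idx].mean(axis=0)
    return out


x = np.arange(4)
print(pmean_reference(x))                    # one group spanning the whole axis
print(pmean_reference(x, [[0, 1], [2, 3]]))  # two groups of two replicas
```

With axis_index_groups=[[0, 1], [2, 3]] and x = np.arange(4), replicas 0 and 1 both receive (0 + 1) / 2 = 0.5, and replicas 2 and 3 both receive (2 + 3) / 2 = 2.5.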
For example, with 4 XLA devices available:

>>> import jax
>>> import numpy as np
>>> x = np.arange(4)
>>> y = jax.pmap(lambda x: jax.lax.pmean(x, 'i'), axis_name='i')(x)
>>> print(y)
[1.5 1.5 1.5 1.5]
>>> y = jax.pmap(lambda x: x / jax.lax.pmean(x, 'i'), axis_name='i')(x)
>>> print(y)
[0.        0.6666667 1.3333334 2.       ]
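The pytree behaviour described above can be tried on a machine without accelerators. The sketch below assumes a CPU-only run in a fresh Python process and uses the XLA flag --xla_force_host_platform_device_count to expose 4 host devices; the dictionary keys "a" and "b" are arbitrary illustrations.

```python
import os

# Assumption for this sketch: run on CPU and ask XLA to expose 4 host
# devices. Both variables must be set before jax is imported.
os.environ.setdefault("JAX_PLATFORMS", "cpu")
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import jax
import numpy as np

# A pytree (here a dict) whose leaves all share a leading axis of size 4,
# which pmap maps over, one slice per device.
tree = {
    "a": np.arange(4, dtype=np.float32),                 # shape (4,)
    "b": np.arange(8, dtype=np.float32).reshape(4, 2),   # shape (4, 2)
}

# pmean is applied leaf by leaf: each leaf is averaged across the replicas.
averaged = jax.pmap(lambda t: jax.lax.pmean(t, "i"), axis_name="i")(tree)

print(averaged["a"])
print(averaged["b"])
```

Each leaf is averaged independently across the 4 replicas, and the output keeps the input's tree structure and leaf shapes: every replica ends up with 1.5 for "a" and [3., 4.] for "b".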