jax.numpy.nanargmax
jax.numpy.nanargmax(a, axis=None)
Return the indices of the maximum values in the specified axis ignoring NaNs. For all-NaN slices ValueError is raised. Warning: the results cannot be trusted if a slice contains only NaNs and -Infs.
LAX-backend implementation of nanargmax(). Warning: jax.numpy.nanargmax returns -1 for all-NaN slices and does not raise an error, so the ValueError described in the original NumPy docstring does not apply here.
Returns
index_array : ndarray
    An array of indices or a single index value.
See also: argmax, nanargmin
>>> a = np.array([[np.nan, 4], [2, 3]])
>>> np.argmax(a)
0
>>> np.nanargmax(a)
1
>>> np.nanargmax(a, axis=0)
array([1, 0])
>>> np.nanargmax(a, axis=1)
array([1, 1])
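The -1 result for all-NaN slices and the caveat about slices holding only NaNs and -Infs can both be illustrated with a small NumPy sketch. The helper name nanargmax_sketch is hypothetical, and the masking approach shown (replace NaN with -inf, take argmax, then patch all-NaN slices) is an assumption made for illustration, not necessarily how the library implements the function.

```python
import numpy as np

def nanargmax_sketch(a, axis=None):
    """Hypothetical helper: NaN-ignoring argmax that yields -1 for all-NaN slices."""
    a = np.asarray(a, dtype=float)
    nan_mask = np.isnan(a)
    # Replace NaNs with -inf so they can never beat a finite value.
    filled = np.where(nan_mask, -np.inf, a)
    idx = np.argmax(filled, axis=axis)
    # Slices that were entirely NaN get the sentinel -1 instead of an error.
    return np.where(nan_mask.all(axis=axis), -1, idx)

a = np.array([[np.nan, 4.0], [2.0, 3.0], [np.nan, np.nan]])
per_row = nanargmax_sketch(a, axis=1)   # last row is all NaN
flat = nanargmax_sketch(a)              # flattened array, not all NaN

# A slice with only NaN and -inf: after masking, both entries are -inf,
# so argmax picks position 0, which actually holds the NaN.
tricky = nanargmax_sketch(np.array([np.nan, -np.inf]))

print(per_row.tolist(), int(flat), int(tricky))
```

In this sketch the third row maps to -1 rather than raising, and the last call returns the index of the NaN entry instead of the -inf entry, which is the kind of result the warning about NaN and -Inf slices refers to.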