jax.random.categorical
jax.random.categorical(key, logits, axis=-1, shape=None)
Sample random values from categorical distributions.
- Parameters
key – a PRNGKey used as the random key.
logits – Unnormalized log probabilities of the categorical distribution(s) to sample from, so that softmax(logits, axis) gives the corresponding probabilities.
axis – Axis along which logits belong to the same categorical distribution.
shape – Optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with np.delete(logits.shape, axis). The default (None) produces a result shape equal to np.delete(logits.shape, axis).
- Returns
A random array with int dtype and shape given by shape if shape is not None, or else np.delete(logits.shape, axis).
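
Example

The following is a minimal sketch of how key, logits, axis and shape interact, assuming a recent JAX release. The probability values, the sample count of 1000 and all variable names are illustrative choices and do not come from this page.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Two categorical distributions over three categories each.
# Taking the log of probabilities gives valid logits, since
# softmax(log(p)) == p when p already sums to 1 (illustrative values).
probs = jnp.array([[0.10, 0.20, 0.70],
                   [0.50, 0.25, 0.25]])
logits = jnp.log(probs)

# Default shape: one draw per distribution, i.e. shape (2,),
# which matches np.delete(logits.shape, -1).
sample = jax.random.categorical(key, logits)

# Explicit shape: (1000, 2) is broadcast-compatible with (2,),
# so this draws 1000 samples from each of the two distributions.
many = jax.random.categorical(key, logits, shape=(1000, 2))

# If the categories lie along axis 0 instead, pass axis=0.
along0 = jax.random.categorical(key, logits.T, axis=0)

# Empirical frequency of each category, per distribution; shape (2, 3).
freqs = jnp.stack([jnp.mean(many == k, axis=0) for k in range(3)], axis=-1)
```

With enough samples, freqs should approach softmax(logits, axis=-1). In real code, a key is normally split with jax.random.split rather than reused across calls as it is here for brevity.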