jax.random.split
jax.random.split(key, num=2)
Splits a PRNG key into num new keys by adding a leading axis.
- Parameters
key (ndarray) – a PRNGKey (an array with shape (2,) and dtype uint32).
num (int) – optional, a positive integer indicating the number of keys to produce (default 2).
- Return type
ndarray
- Returns
An array with shape (num, 2) and dtype uint32 representing num new keys.
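
- Example
A minimal usage sketch, not part of the original reference text. It assumes the default PRNG implementation and a raw key created with jax.random.PRNGKey, so that keys are uint32 arrays of shape (2,) as described above. The seed 0 and the names subkey, sample and subkeys are illustrative only.

```python
import jax
import jax.numpy as jnp

# Assumption: a raw (legacy-style) key, i.e. a uint32 array with shape (2,).
key = jax.random.PRNGKey(0)

# Default num=2: the result has shape (2, 2), so it unpacks into two keys.
# A common pattern is to carry one key forward and consume the other.
key, subkey = jax.random.split(key)
sample = jax.random.normal(subkey, shape=(3,))

# Ask for more keys explicitly; the new leading axis has length num.
subkeys = jax.random.split(key, num=4)
print(subkeys.shape, subkeys.dtype)
```

Each row of subkeys is itself a key that can be passed to a sampling function or split again. Reusing the same key for two sampling calls yields identical draws, which is why the sketch splits before sampling.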