jax.random.split
jax.random.split(key, num=2)
Splits a PRNG key into num new keys by adding a leading axis.
- Parameters
key (ndarray) – a PRNGKey (an array with shape (2,) and dtype uint32).
num (int) – optional, a positive integer indicating the number of keys to produce (default 2).
- Return type
ndarray
- Returns
An array with shape (num, 2) and dtype uint32 representing num new keys.
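
The following is a minimal usage sketch, not part of the original reference entry. It assumes JAX's default PRNG configuration, in which jax.random.PRNGKey returns a raw uint32 key of shape (2,). The seed 0, the sample shape (3,), and the variable names are illustrative choices only.

```python
import jax

# Root key built from an arbitrary integer seed (0 is an illustrative choice).
key = jax.random.PRNGKey(0)

# Default num=2: unpack the two new keys along the leading axis.
# A common idiom is to carry one key forward and consume the other.
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (3,))

# Explicit num: the new keys are stacked along a leading axis of length num.
subkeys = jax.random.split(key, num=4)
print(subkeys.shape, subkeys.dtype)
```

Rebinding key to one of the outputs, as above, helps avoid reusing a key that has already been passed to a sampling function.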