Warning
This page was created from a pull request.
jax.random.normal
Sample standard normal random values with given shape and float dtype.
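A minimal usage sketch, assuming a standard JAX install with default configuration. It shows how a key is built with `jax.random.PRNGKey`, that reusing a key reproduces the same values, and that `jax.random.split` is the usual way to obtain fresh keys. The seed values are arbitrary choices for illustration.

```python
import jax

# Build a PRNG key from an arbitrary integer seed.
key = jax.random.PRNGKey(0)

# A 2x3 array of standard normal samples. dtype is left at its default,
# which is float32 unless jax_enable_x64 is enabled.
x = jax.random.normal(key, shape=(2, 3))

# Calling again with the same key reproduces exactly the same values.
x_again = jax.random.normal(key, shape=(2, 3))

# For fresh values, split the key and use each new subkey only once.
key, k1, k2 = jax.random.split(key, 3)
y = jax.random.normal(k1, shape=(2, 3))

# With the default shape=(), a single scalar sample is returned.
s = jax.random.normal(k2)
```

Because JAX random functions are pure, the key is the only source of randomness: the same key and arguments always give the same output.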
Parameters:
key (ndarray): a PRNGKey used as the random key.
shape (Sequence[int]): optional, a tuple of nonnegative integers representing the result shape. Default: ().
dtype (dtype): optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: a random array with the specified shape and dtype.
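The sketch below checks the documented return behavior under stated assumptions: a large draw should have an empirical mean near 0 and standard deviation near 1, and an explicit `dtype` (here `jnp.float16`, assumed to be supported as a float dtype) together with `shape` fixes the result's shape and dtype. The `loc`/`scale` names are hypothetical and only illustrate the common pattern of shifting and scaling a standard normal draw to sample from N(loc, scale**2); they are not parameters of `jax.random.normal`.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
k1, k2, k3 = jax.random.split(key, 3)

# Many samples: the empirical mean and std should be close to 0 and 1.
samples = jax.random.normal(k1, shape=(100_000,))
mean = float(samples.mean())
std = float(samples.std())

# An explicit float dtype: the result has exactly this shape and dtype.
half = jax.random.normal(k2, shape=(4, 5), dtype=jnp.float16)

# Hypothetical loc/scale values: loc + scale * z, with z standard normal,
# gives samples from a normal distribution with mean loc and std scale.
loc, scale = 3.0, 2.0
shifted = loc + scale * jax.random.normal(k3, shape=(100_000,))
```

Scaling outside the call keeps the sampler itself simple: one standard normal primitive covers every mean and standard deviation.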