Warning
This page was created from a pull request.
jax.random.multivariate_normal
jax.random.multivariate_normal(key, mean, cov, shape=None, dtype=<class 'numpy.float64'>)
Sample multivariate normal random values with given mean and covariance.
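A minimal usage sketch, assuming a recent JAX install running in the default 32-bit mode. It draws many samples for a single 2-D mean and covariance, then compares the empirical statistics with the inputs. The seed, sample count, and tolerances are arbitrary illustration choices, not values taken from this page.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
mean = jnp.array([1.0, -2.0])
# A symmetric positive definite 2x2 covariance (illustrative values).
cov = jnp.array([[2.0, 0.5],
                 [0.5, 1.0]])

# shape=(10000,) is the result batch shape; the event size n=2 is appended.
samples = jax.random.multivariate_normal(key, mean, cov, shape=(10000,))
print(samples.shape)  # (10000, 2)

# Empirical statistics should be close to mean and cov for a large sample.
emp_mean = samples.mean(axis=0)
emp_cov = jnp.cov(samples, rowvar=False)
```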
- Parameters
  - key (ndarray): a PRNGKey used as the random key.
  - mean (ndarray): a mean vector of shape (..., n).
  - cov (ndarray): a positive definite covariance matrix of shape (..., n, n). The batch shape ... must be broadcast-compatible with that of mean.
  - shape (Optional[Sequence[int]]): optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with mean.shape[:-1] and cov.shape[:-2]. The default (None) produces a result batch shape by broadcasting together the batch shapes of mean and cov.
  - dtype (dtype): optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- Return type
  ndarray
- Returns
  A random array with the specified dtype and shape given by shape + mean.shape[-1:] if shape is not None, or else broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:].
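A sketch of the two shape rules above, assuming default 32-bit mode. The batch shapes (5, 1) for mean and (4,) for cov are arbitrary illustration choices; with shape=None they broadcast to (5, 4), and an explicit shape such as (2, 5, 4) must itself broadcast with both batch shapes. In either case the event size n is appended as the last axis.

```python
import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.PRNGKey(0))

# Batched means: batch shape (5, 1), event size n = 3.
mean = jnp.zeros((5, 1, 3))
# Batched covariances: batch shape (4,), each a 3x3 identity (positive definite).
cov = jnp.broadcast_to(jnp.eye(3), (4, 3, 3))

# Case 1: shape=None, so the batch shape is broadcast_shapes((5, 1), (4,)) == (5, 4).
out_default = jax.random.multivariate_normal(key1, mean, cov)
expected_default = jnp.broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]

# Case 2: explicit shape=(2, 5, 4), which is broadcast-compatible with (5, 1) and (4,).
out_explicit = jax.random.multivariate_normal(
    key2, mean, cov, shape=(2, 5, 4), dtype=jnp.float32)

print(out_default.shape, expected_default)     # (5, 4, 3) (5, 4, 3)
print(out_explicit.shape, out_explicit.dtype)  # (2, 5, 4, 3) float32
```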