4675
Tutorials
JAX Quickstart
The Autodiff Cookbook
Autobatching log-densities example
Training a Simple Neural Network, with TensorFlow Datasets Data Loading
Advanced JAX Tutorials
🔪 JAX - The Sharp Bits 🔪
Custom derivative rules for JAX-transformable Python functions
How JAX primitives work
Writing custom Jaxpr interpreters in JAX
Notes
Change Log
JAX Frequently Asked Questions (FAQ)
Understanding Jaxprs
Asynchronous dispatch
Concurrency
GPU memory allocation
Profiling JAX programs
Device Memory Profiling
Pytrees
Rank promotion warning
Type promotion semantics
Developer documentation
Building from source
Running the tests
Type checking
Update documentation
Internal APIs
API documentation
Public API: jax package
JAX
Docs » Index
Edit on GitHub
Index
_ | A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P | Q | R | S | T | U | V | W | X | Z
_
__init__() (jax.core.ClosedJaxpr method)
(jax.core.Jaxpr method)
(jax.numpy.bool_ method)
(jax.numpy.character method)
(jax.numpy.complex128 method)
(jax.numpy.complex64 method)
(jax.numpy.complexfloating method)
(jax.numpy.dtype method)
(jax.numpy.flexible method)
(jax.numpy.float16 method)
(jax.numpy.float32 method)
(jax.numpy.float64 method)
(jax.numpy.floating method)
(jax.numpy.iinfo method)
(jax.numpy.inexact method)
(jax.numpy.int16 method)
(jax.numpy.int32 method)
(jax.numpy.int64 method)
(jax.numpy.int8 method)
(jax.numpy.integer method)
(jax.numpy.ndarray method)
(jax.numpy.number method)
(jax.numpy.object_ method)
(jax.numpy.signedinteger method)
(jax.numpy.uint16 method)
(jax.numpy.uint32 method)
(jax.numpy.uint64 method)
(jax.numpy.uint8 method)
(jax.numpy.unsignedinteger method)
(jax.profiler.TraceContext method)
A
abs() (in module jax.lax)
(in module jax.numpy)
absolute() (in module jax.numpy)
acos() (in module jax.lax)
adagrad() (in module jax.experimental.optimizers)
adam() (in module jax.experimental.optimizers)
adamax() (in module jax.experimental.optimizers)
add() (in module jax.lax)
(in module jax.numpy)
all() (in module jax.numpy)
all_gather() (in module jax.lax)
all_leaves() (in module jax.tree_util)
all_to_all() (in module jax.lax)
allclose() (in module jax.numpy)
alltrue() (in module jax.numpy)
amax() (in module jax.numpy)
amin() (in module jax.numpy)
angle() (in module jax.numpy)
any() (in module jax.numpy)
append() (in module jax.numpy)
apply_along_axis() (in module jax.numpy)
apply_over_axes() (in module jax.numpy)
arange() (in module jax.numpy)
arccos() (in module jax.numpy)
arccosh() (in module jax.numpy)
arcsin() (in module jax.numpy)
arcsinh() (in module jax.numpy)
arctan() (in module jax.numpy)
arctan2() (in module jax.numpy)
arctanh() (in module jax.numpy)
argmax() (in module jax.lax)
(in module jax.numpy)
argmin() (in module jax.lax)
(in module jax.numpy)
argsort() (in module jax.numpy)
argwhere() (in module jax.numpy)
around() (in module jax.numpy)
array() (in module jax.numpy)
array_equal() (in module jax.numpy)
array_equiv() (in module jax.numpy)
array_repr() (in module jax.numpy)
array_split() (in module jax.numpy)
array_str() (in module jax.numpy)
asarray() (in module jax.numpy)
asin() (in module jax.lax)
associative_scan() (in module jax.lax)
atan() (in module jax.lax)
atan2() (in module jax.lax)
atleast_1d() (in module jax.numpy)
atleast_2d() (in module jax.numpy)
atleast_3d() (in module jax.numpy)
average() (in module jax.numpy)
AvgPool() (in module jax.experimental.stax)
axis_index() (in module jax.lax)
B
bartlett() (in module jax.numpy)
batch_matmul() (in module jax.lax)
BatchNorm() (in module jax.experimental.stax)
bernoulli() (in module jax.random)
bessel_i0e() (in module jax.lax)
bessel_i1e() (in module jax.lax)
beta() (in module jax.random)
betainc() (in module jax.lax)
(in module jax.scipy.special)
bincount() (in module jax.numpy)
bitcast_convert_type() (in module jax.lax)
bits (jax.numpy.iinfo attribute)
bitwise_and() (in module jax.lax)
(in module jax.numpy)
bitwise_not() (in module jax.lax)
(in module jax.numpy)
bitwise_or() (in module jax.lax)
(in module jax.numpy)
bitwise_xor() (in module jax.lax)
(in module jax.numpy)
blackman() (in module jax.numpy)
block() (in module jax.numpy)
block_diag() (in module jax.scipy.linalg)
bool_ (class in jax.numpy)
broadcast() (in module jax.lax)
broadcast_arrays() (in module jax.numpy)
broadcast_in_dim() (in module jax.lax)
broadcast_to() (in module jax.numpy)
broadcasted_iota() (in module jax.lax)
build_tree() (in module jax.tree_util)
C
can_cast() (in module jax.numpy)
categorical() (in module jax.random)
cauchy() (in module jax.random)
cbrt() (in module jax.numpy)
cdf() (in module jax.scipy.stats.laplace)
(in module jax.scipy.stats.logistic)
(in module jax.scipy.stats.norm)
cdouble (in module jax.numpy)
ceil() (in module jax.lax)
(in module jax.numpy)
celu() (in module jax.nn)
cg() (in module jax.scipy.sparse.linalg)
character (class in jax.numpy)
checkpoint() (in module jax)
cho_factor() (in module jax.scipy.linalg)
cho_solve() (in module jax.scipy.linalg)
choice() (in module jax.random)
cholesky() (in module jax.lax.linalg)
(in module jax.numpy.linalg)
(in module jax.scipy.linalg)
choose() (in module jax.numpy)
clamp() (in module jax.lax)
clip() (in module jax.numpy)
clip_grads() (in module jax.experimental.optimizers)
ClosedJaxpr (class in jax.core)
collapse() (in module jax.lax)
column_stack() (in module jax.numpy)
complex() (in module jax.lax)
complex128 (class in jax.numpy)
complex64 (class in jax.numpy)
complex_ (in module jax.numpy)
complexfloating (class in jax.numpy)
ComplexWarning
compress() (in module jax.numpy)
concatenate() (in module jax.lax)
(in module jax.numpy)
cond() (in module jax.lax)
(in module jax.numpy.linalg)
cond_range() (jax.experimental.loops.Scope method)
conj() (in module jax.lax)
(in module jax.numpy)
conjugate() (in module jax.numpy)
constant() (in module jax.experimental.optimizers)
Conv() (in module jax.experimental.stax)
conv() (in module jax.lax)
Conv1DTranspose() (in module jax.experimental.stax)
conv_general_dilated() (in module jax.lax)
conv_general_dilated_patches() (in module jax.lax)
conv_transpose() (in module jax.lax)
conv_with_general_padding() (in module jax.lax)
convert_element_type() (in module jax.lax)
convolve() (in module jax.numpy)
(in module jax.scipy.signal)
convolve2d() (in module jax.scipy.signal)
ConvTranspose() (in module jax.experimental.stax)
copysign() (in module jax.numpy)
corrcoef() (in module jax.numpy)
correlate() (in module jax.numpy)
(in module jax.scipy.signal)
correlate2d() (in module jax.scipy.signal)
cos() (in module jax.lax)
(in module jax.numpy)
cosh() (in module jax.lax)
(in module jax.numpy)
count_nonzero() (in module jax.numpy)
cov() (in module jax.numpy)
cross() (in module jax.numpy)
csingle (in module jax.numpy)
cummax() (in module jax.lax)
cummin() (in module jax.lax)
cumprod() (in module jax.lax)
(in module jax.numpy)
cumproduct() (in module jax.numpy)
cumsum() (in module jax.lax)
(in module jax.numpy)
custom_jvp (class in jax)
custom_linear_solve() (in module jax.lax)
custom_root() (in module jax.lax)
custom_vjp (class in jax)
D
defjvp() (jax.custom_jvp method)
defjvps() (jax.custom_jvp method)
defvjp() (jax.custom_vjp method)
deg2rad() (in module jax.numpy)
degrees() (in module jax.numpy)
Dense() (in module jax.experimental.stax)
det (in module jax.numpy.linalg)
det() (in module jax.scipy.linalg)
device_count() (in module jax)
device_memory_profile() (in module jax.profiler)
device_put() (in module jax)
devices() (in module jax)
diag() (in module jax.numpy)
diag_indices() (in module jax.numpy)
diag_indices_from() (in module jax.numpy)
diagflat() (in module jax.numpy)
diagonal() (in module jax.numpy)
diff() (in module jax.numpy)
digamma() (in module jax.lax)
(in module jax.scipy.special)
digitize() (in module jax.numpy)
dirichlet() (in module jax.random)
disable_jit() (in module jax)
div() (in module jax.lax)
divide() (in module jax.numpy)
divmod() (in module jax.numpy)
dot() (in module jax.lax)
(in module jax.numpy)
dot_general() (in module jax.lax)
double (in module jax.numpy)
double_sided_maxwell() (in module jax.random)
Dropout() (in module jax.experimental.stax)
dsplit() (in module jax.numpy)
dstack() (in module jax.numpy)
dtype (class in jax.numpy)
dynamic_index_in_dim() (in module jax.lax)
dynamic_slice() (in module jax.lax)
dynamic_slice_in_dim() (in module jax.lax)
dynamic_update_index_in_dim() (in module jax.lax)
dynamic_update_slice() (in module jax.lax)
dynamic_update_slice_in_dim() (in module jax.lax)
E
ediff1d() (in module jax.numpy)
eig() (in module jax.lax.linalg)
(in module jax.numpy.linalg)
eigh() (in module jax.lax.linalg)
(in module jax.numpy.linalg)
(in module jax.scipy.linalg)
eigvals() (in module jax.numpy.linalg)
eigvalsh() (in module jax.numpy.linalg)
einsum() (in module jax.numpy)
einsum_path() (in module jax.numpy)
elementwise() (in module jax.experimental.stax)
elu() (in module jax.nn)
empty() (in module jax.numpy)
empty_like() (in module jax.numpy)
entr() (in module jax.scipy.special)
eq() (in module jax.lax)
equal() (in module jax.numpy)
erf() (in module jax.lax)
(in module jax.scipy.special)
erf_inv() (in module jax.lax)
erfc() (in module jax.lax)
(in module jax.scipy.special)
erfinv() (in module jax.scipy.special)
eval_shape() (in module jax)
exp() (in module jax.lax)
(in module jax.numpy)
exp2() (in module jax.numpy)
expand_dims() (in module jax.lax)
(in module jax.numpy)
expit (in module jax.scipy.special)
expm() (in module jax.scipy.linalg)
expm1() (in module jax.lax)
(in module jax.numpy)
expm_frechet() (in module jax.scipy.linalg)
exponential() (in module jax.random)
exponential_decay() (in module jax.experimental.optimizers)
extract() (in module jax.numpy)
eye() (in module jax.numpy)
F
fabs() (in module jax.numpy)
FanInConcat() (in module jax.experimental.stax)
FanOut() (in module jax.experimental.stax)
fft() (in module jax.lax)
(in module jax.numpy.fft)
fft2() (in module jax.numpy.fft)
fftfreq() (in module jax.numpy.fft)
fftn() (in module jax.numpy.fft)
fftshift() (in module jax.numpy.fft)
finfo() (in module jax.numpy)
fix() (in module jax.numpy)
flatnonzero() (in module jax.numpy)
flexible (class in jax.numpy)
flip() (in module jax.numpy)
fliplr() (in module jax.numpy)
flipud() (in module jax.numpy)
float16 (class in jax.numpy)
float32 (class in jax.numpy)
float64 (class in jax.numpy)
float_ (in module jax.numpy)
float_power() (in module jax.numpy)
floating (class in jax.numpy)
floor() (in module jax.lax)
(in module jax.numpy)
floor_divide() (in module jax.numpy)
fmax() (in module jax.numpy)
fmin() (in module jax.numpy)
fmod() (in module jax.numpy)
fold_in() (in module jax.random)
fori_loop() (in module jax.lax)
frexp() (in module jax.numpy)
full() (in module jax.lax)
(in module jax.numpy)
full_like() (in module jax.lax)
(in module jax.numpy)
G
gamma() (in module jax.random)
gammainc() (in module jax.scipy.special)
gammaincc() (in module jax.scipy.special)
gammaln() (in module jax.scipy.special)
gather() (in module jax.lax)
gcd() (in module jax.numpy)
ge() (in module jax.lax)
gelu() (in module jax.nn)
GeneralConv() (in module jax.experimental.stax)
GeneralConvTranspose() (in module jax.experimental.stax)
geomspace() (in module jax.numpy)
glorot_normal() (in module jax.nn.initializers)
glorot_uniform() (in module jax.nn.initializers)
glu() (in module jax.nn)
grad() (in module jax)
gradient() (in module jax.numpy)
greater() (in module jax.numpy)
greater_equal() (in module jax.numpy)
gt() (in module jax.lax)
gumbel() (in module jax.random)
H
hamming() (in module jax.numpy)
hanning() (in module jax.numpy)
hard_sigmoid() (in module jax.nn)
hard_silu() (in module jax.nn)
hard_swish() (in module jax.nn)
hard_tanh() (in module jax.nn)
he_normal() (in module jax.nn.initializers)
he_uniform() (in module jax.nn.initializers)
heaviside() (in module jax.numpy)
hessian() (in module jax)
hfft() (in module jax.numpy.fft)
histogram() (in module jax.numpy)
histogram2d() (in module jax.numpy)
histogram_bin_edges() (in module jax.numpy)
histogramdd() (in module jax.numpy)
host_count() (in module jax)
host_id() (in module jax)
host_ids() (in module jax)
hsplit() (in module jax.numpy)
hstack() (in module jax.numpy)
hypot() (in module jax.numpy)
I
i0() (in module jax.numpy)
(in module jax.scipy.special)
i0e() (in module jax.scipy.special)
i1() (in module jax.scipy.special)
i1e() (in module jax.scipy.special)
id_print() (in module jax.experimental.host_callback)
id_tap() (in module jax.experimental.host_callback)
identity() (in module jax.numpy)
ifft() (in module jax.numpy.fft)
ifft2() (in module jax.numpy.fft)
ifftn() (in module jax.numpy.fft)
ifftshift() (in module jax.numpy.fft)
igamma() (in module jax.lax)
igammac() (in module jax.lax)
ihfft() (in module jax.numpy.fft)
iinfo (class in jax.numpy)
imag() (in module jax.lax)
(in module jax.numpy)
in1d() (in module jax.numpy)
index (in module jax.ops)
index_add() (in module jax.ops)
index_in_dim() (in module jax.lax)
index_max() (in module jax.ops)
index_min() (in module jax.ops)
index_mul() (in module jax.ops)
index_take() (in module jax.lax)
index_update() (in module jax.ops)
indices() (in module jax.numpy)
inexact (class in jax.numpy)
init_fn() (jax.experimental.optimizers.Optimizer property)
inner() (in module jax.numpy)
int16 (class in jax.numpy)
int32 (class in jax.numpy)
int64 (class in jax.numpy)
int8 (class in jax.numpy)
int_ (in module jax.numpy)
integer (class in jax.numpy)
interp() (in module jax.numpy)
intersect1d() (in module jax.numpy)
inv() (in module jax.numpy.linalg)
(in module jax.scipy.linalg)
inverse_time_decay() (in module jax.experimental.optimizers)
invert() (in module jax.numpy)
iota() (in module jax.lax)
irfft() (in module jax.numpy.fft)
irfft2() (in module jax.numpy.fft)
irfftn() (in module jax.numpy.fft)
is_finite() (in module jax.lax)
isclose() (in module jax.numpy)
iscomplex() (in module jax.numpy)
iscomplexobj() (in module jax.numpy)
isf() (in module jax.scipy.stats.logistic)
isfinite() (in module jax.numpy)
isin() (in module jax.numpy)
isinf() (in module jax.numpy)
isnan() (in module jax.numpy)
isneginf() (in module jax.numpy)
isposinf() (in module jax.numpy)
isreal() (in module jax.numpy)
isrealobj() (in module jax.numpy)
isscalar() (in module jax.numpy)
issubdtype() (in module jax.numpy)
issubsctype() (in module jax.numpy)
iterable() (in module jax.numpy)
ix_() (in module jax.numpy)
J
jacfwd() (in module jax)
jacrev() (in module jax)
jax.core (module)
jax.dlpack (module)
jax.experimental (module)
jax.experimental.host_callback (module)
jax.experimental.loops (module)
jax.experimental.optimizers (module)
jax.experimental.stax (module)
jax.image (module)
jax.lax (module)
jax.lax.linalg (module)
jax.nn (module)
jax.nn.initializers (module)
jax.numpy (module)
jax.numpy.fft (module)
jax.numpy.linalg (module)
jax.ops (module)
jax.profiler (module)
jax.random (module)
jax.scipy.linalg (module)
jax.scipy.ndimage (module)
jax.scipy.signal (module)
jax.scipy.sparse.linalg (module)
jax.scipy.special (module)
jax.scipy.stats.bernoulli (module)
jax.scipy.stats.beta (module)
jax.scipy.stats.cauchy (module)
jax.scipy.stats.dirichlet (module)
jax.scipy.stats.expon (module)
jax.scipy.stats.gamma (module)
jax.scipy.stats.geom (module)
jax.scipy.stats.laplace (module)
jax.scipy.stats.logistic (module)
jax.scipy.stats.multivariate_normal (module)
jax.scipy.stats.norm (module)
jax.scipy.stats.pareto (module)
jax.scipy.stats.poisson (module)
jax.scipy.stats.t (module)
jax.scipy.stats.uniform (module)
jax.tree_util (module)
Jaxpr (class in jax.core)
jit() (in module jax)
JoinPoint (class in jax.experimental.optimizers)
jvp() (in module jax)
K
kaiser() (in module jax.numpy)
kron() (in module jax.numpy)
L
l2_norm() (in module jax.experimental.optimizers)
laplace() (in module jax.random)
lcm() (in module jax.numpy)
ldexp() (in module jax.numpy)
le() (in module jax.lax)
leaky_relu() (in module jax.nn)
lecun_normal() (in module jax.nn.initializers)
lecun_uniform() (in module jax.nn.initializers)
left_shift() (in module jax.numpy)
less() (in module jax.numpy)
less_equal() (in module jax.numpy)
lexsort() (in module jax.numpy)
lgamma() (in module jax.lax)
linear_transpose() (in module jax)
linearize() (in module jax)
linspace() (in module jax.numpy)
load() (in module jax.numpy)
local_device_count() (in module jax)
local_devices() (in module jax)
log() (in module jax.lax)
(in module jax.numpy)
log10() (in module jax.numpy)
log1p() (in module jax.lax)
(in module jax.numpy)
log2() (in module jax.numpy)
log_ndtr (in module jax.scipy.special)
log_sigmoid() (in module jax.nn)
log_softmax() (in module jax.nn)
logaddexp (in module jax.numpy)
logaddexp2 (in module jax.numpy)
logcdf() (in module jax.scipy.stats.norm)
logical_and() (in module jax.numpy)
logical_not() (in module jax.numpy)
logical_or() (in module jax.numpy)
logical_xor() (in module jax.numpy)
logistic() (in module jax.random)
logit (in module jax.scipy.special)
logpdf() (in module jax.scipy.stats.beta)
(in module jax.scipy.stats.cauchy)
(in module jax.scipy.stats.dirichlet)
(in module jax.scipy.stats.expon)
(in module jax.scipy.stats.gamma)
(in module jax.scipy.stats.laplace)
(in module jax.scipy.stats.logistic)
(in module jax.scipy.stats.multivariate_normal)
(in module jax.scipy.stats.norm)
(in module jax.scipy.stats.pareto)
(in module jax.scipy.stats.t)
(in module jax.scipy.stats.uniform)
logpmf() (in module jax.scipy.stats.bernoulli)
(in module jax.scipy.stats.geom)
(in module jax.scipy.stats.poisson)
logspace() (in module jax.numpy)
logsumexp() (in module jax.nn)
(in module jax.scipy.special)
lstsq() (in module jax.numpy.linalg)
lt() (in module jax.lax)
lu() (in module jax.lax.linalg)
(in module jax.scipy.linalg)
lu_factor() (in module jax.scipy.linalg)
lu_solve() (in module jax.scipy.linalg)
M
make_jaxpr() (in module jax)
make_schedule() (in module jax.experimental.optimizers)
map() (in module jax.lax)
map_coordinates() (in module jax.scipy.ndimage)
mask_indices() (in module jax.numpy)
matmul() (in module jax.numpy)
matrix_power() (in module jax.numpy.linalg)
matrix_rank() (in module jax.numpy.linalg)
max (jax.numpy.iinfo attribute)
max() (in module jax.lax)
(in module jax.numpy)
maximum() (in module jax.numpy)
MaxPool() (in module jax.experimental.stax)
maxwell() (in module jax.random)
mean() (in module jax.numpy)
median() (in module jax.numpy)
meshgrid() (in module jax.numpy)
min (jax.numpy.iinfo attribute)
min() (in module jax.lax)
(in module jax.numpy)
minimum() (in module jax.numpy)
mod() (in module jax.numpy)
modf() (in module jax.numpy)
momentum() (in module jax.experimental.optimizers)
moveaxis() (in module jax.numpy)
msort() (in module jax.numpy)
mul() (in module jax.lax)
multi_dot() (in module jax.numpy.linalg)
multigammaln() (in module jax.scipy.special)
multiply() (in module jax.numpy)
multivariate_normal() (in module jax.random)
N
nan_to_num() (in module jax.numpy)
nanargmax() (in module jax.numpy)
nanargmin() (in module jax.numpy)
nancumprod() (in module jax.numpy)
nancumsum() (in module jax.numpy)
nanmax() (in module jax.numpy)
nanmean() (in module jax.numpy)
nanmedian() (in module jax.numpy)
nanmin() (in module jax.numpy)
nanpercentile() (in module jax.numpy)
nanprod() (in module jax.numpy)
nanquantile() (in module jax.numpy)
nanstd() (in module jax.numpy)
nansum() (in module jax.numpy)
nanvar() (in module jax.numpy)
ndarray (class in jax.numpy)
ndim() (in module jax.numpy)
ndtr() (in module jax.scipy.special)
ndtri() (in module jax.scipy.special)
ne() (in module jax.lax)
neg() (in module jax.lax)
negative() (in module jax.numpy)
nesterov() (in module jax.experimental.optimizers)
nextafter() (in module jax.lax)
(in module jax.numpy)
nonzero() (in module jax.numpy)
norm() (in module jax.numpy.linalg)
normal() (in module jax.nn.initializers)
(in module jax.random)
normalize() (in module jax.nn)
not_equal() (in module jax.numpy)
number (class in jax.numpy)
O
object_ (class in jax.numpy)
one_hot() (in module jax.nn)
ones() (in module jax.nn.initializers)
(in module jax.numpy)
ones_like() (in module jax.numpy)
Optimizer (class in jax.experimental.optimizers)
optimizer() (in module jax.experimental.optimizers)
OptimizerState (class in jax.experimental.optimizers)
outer() (in module jax.numpy)
outfeed_receiver() (in module jax.experimental.host_callback)
P
pack_optimizer_state() (in module jax.experimental.optimizers)
packbits() (in module jax.numpy)
packed_state() (jax.experimental.optimizers.OptimizerState property)
pad() (in module jax.lax)
(in module jax.numpy)
parallel() (in module jax.experimental.stax)
params_fn() (jax.experimental.optimizers.Optimizer property)
pareto() (in module jax.random)
Partial (class in jax.tree_util)
pdf() (in module jax.scipy.stats.beta)
(in module jax.scipy.stats.cauchy)
(in module jax.scipy.stats.dirichlet)
(in module jax.scipy.stats.expon)
(in module jax.scipy.stats.gamma)
(in module jax.scipy.stats.laplace)
(in module jax.scipy.stats.logistic)
(in module jax.scipy.stats.multivariate_normal)
(in module jax.scipy.stats.norm)
(in module jax.scipy.stats.pareto)
(in module jax.scipy.stats.t)
(in module jax.scipy.stats.uniform)
percentile() (in module jax.numpy)
permutation() (in module jax.random)
piecewise() (in module jax.numpy)
piecewise_constant() (in module jax.experimental.optimizers)
pinv (in module jax.numpy.linalg)
pmap() (in module jax)
pmax() (in module jax.lax)
pmean() (in module jax.lax)
pmf() (in module jax.scipy.stats.bernoulli)
(in module jax.scipy.stats.geom)
(in module jax.scipy.stats.poisson)
pmin() (in module jax.lax)
poisson() (in module jax.random)
polyadd() (in module jax.numpy)
polyder() (in module jax.numpy)
polygamma() (in module jax.scipy.special)
polymul() (in module jax.numpy)
polynomial_decay() (in module jax.experimental.optimizers)
polysub() (in module jax.numpy)
polyval() (in module jax.numpy)
population_count() (in module jax.lax)
positive() (in module jax.numpy)
pow() (in module jax.lax)
power() (in module jax.numpy)
ppermute() (in module jax.lax)
ppf() (in module jax.scipy.stats.logistic)
(in module jax.scipy.stats.norm)
PRNGKey() (in module jax.random)
prod() (in module jax.numpy)
product() (in module jax.numpy)
promote_types() (in module jax.numpy)
pshuffle() (in module jax.lax)
psum() (in module jax.lax)
pswapaxes() (in module jax.lax)
ptp() (in module jax.numpy)
Q
qr() (in module jax.lax.linalg)
(in module jax.numpy.linalg)
(in module jax.scipy.linalg)
quantile() (in module jax.numpy)
R
rad2deg() (in module jax.numpy)
rademacher() (in module jax.random)
radians() (in module jax.numpy)
randint() (in module jax.random)
range() (jax.experimental.loops.Scope method)
ravel() (in module jax.numpy)
ravel_multi_index() (in module jax.numpy)
real() (in module jax.lax)
(in module jax.numpy)
reciprocal() (in module jax.lax)
(in module jax.numpy)
reduce() (in module jax.lax)
reduce_window() (in module jax.lax)
register_pytree_node() (in module jax.tree_util)
register_pytree_node_class() (in module jax.tree_util)
relu (in module jax.nn)
relu6() (in module jax.nn)
rem() (in module jax.lax)
remainder() (in module jax.numpy)
repeat() (in module jax.numpy)
reshape() (in module jax.lax)
(in module jax.numpy)
resize() (in module jax.image)
result_type() (in module jax.numpy)
rev() (in module jax.lax)
rfft() (in module jax.numpy.fft)
rfft2() (in module jax.numpy.fft)
rfftfreq() (in module jax.numpy.fft)
rfftn() (in module jax.numpy.fft)
right_shift() (in module jax.numpy)
rint() (in module jax.numpy)
rmsprop() (in module jax.experimental.optimizers)
rmsprop_momentum() (in module jax.experimental.optimizers)
roll() (in module jax.numpy)
rollaxis() (in module jax.numpy)
roots() (in module jax.numpy)
rot90() (in module jax.numpy)
round() (in module jax.lax)
(in module jax.numpy)
row_stack() (in module jax.numpy)
rsqrt() (in module jax.lax)
S
save() (in module jax.numpy)
save_device_memory_profile() (in module jax.profiler)
savez() (in module jax.numpy)
scale_and_translate() (in module jax.image)
scan() (in module jax.lax)
scatter() (in module jax.lax)
scatter_add() (in module jax.lax)
Scope (class in jax.experimental.loops)
searchsorted() (in module jax.numpy)
segment_sum() (in module jax.ops)
select() (in module jax.lax)
(in module jax.numpy)
selu() (in module jax.nn)
serial() (in module jax.experimental.stax)
set_printoptions() (in module jax.numpy)
setdiff1d() (in module jax.numpy)
sf() (in module jax.scipy.stats.logistic)
sgd() (in module jax.experimental.optimizers)
shape() (in module jax.numpy)
shape_dependent() (in module jax.experimental.stax)
shift_left() (in module jax.lax)
shift_right_arithmetic() (in module jax.lax)
shift_right_logical() (in module jax.lax)
shuffle() (in module jax.random)
sigmoid() (in module jax.nn)
sign() (in module jax.lax)
(in module jax.numpy)
signbit() (in module jax.numpy)
signedinteger (class in jax.numpy)
silu() (in module jax.nn)
sin() (in module jax.lax)
(in module jax.numpy)
sinc() (in module jax.numpy)
single (in module jax.numpy)
sinh() (in module jax.lax)
(in module jax.numpy)
size() (in module jax.numpy)
slice() (in module jax.lax)
slice_in_dim() (in module jax.lax)
slogdet (in module jax.numpy.linalg)
sm3() (in module jax.experimental.optimizers)
soft_sign() (in module jax.nn)
softmax() (in module jax.nn)
softplus() (in module jax.nn)
solve() (in module jax.numpy.linalg)
(in module jax.scipy.linalg)
solve_triangular() (in module jax.scipy.linalg)
sometrue() (in module jax.numpy)
sort() (in module jax.lax)
(in module jax.numpy)
sort_complex() (in module jax.numpy)
sort_key_val() (in module jax.lax)
split() (in module jax.numpy)
(in module jax.random)
sqrt() (in module jax.lax)
(in module jax.numpy)
square() (in module jax.lax)
(in module jax.numpy)
squeeze() (in module jax.lax)
(in module jax.numpy)
stack() (in module jax.numpy)
start_server() (in module jax.profiler)
start_subtrace() (jax.experimental.loops.Scope method)
std() (in module jax.numpy)
stop_gradient() (in module jax.lax)
sub() (in module jax.lax)
subtract() (in module jax.numpy)
subtree_defs() (jax.experimental.optimizers.OptimizerState property)
sum() (in module jax.numpy)
SumPool() (in module jax.experimental.stax)
svd() (in module jax.lax.linalg)
(in module jax.numpy.linalg)
(in module jax.scipy.linalg)
swapaxes() (in module jax.numpy)
swish() (in module jax.nn)
switch() (in module jax.lax)
T
t() (in module jax.random)
take() (in module jax.numpy)
take_along_axis() (in module jax.numpy)
tan() (in module jax.lax)
(in module jax.numpy)
tanh() (in module jax.numpy)
TapFunctionException
tensordot() (in module jax.numpy)
tensorinv() (in module jax.numpy.linalg)
tensorsolve() (in module jax.numpy.linalg)
tie_in() (in module jax.lax)
tile() (in module jax.numpy)
top_k() (in module jax.lax)
trace() (in module jax.numpy)
trace_function() (in module jax.profiler)
TraceContext (class in jax.profiler)
transpose() (in module jax.lax)
(in module jax.numpy)
trapz() (in module jax.numpy)
tree_all() (in module jax.tree_util)
tree_def() (jax.experimental.optimizers.OptimizerState property)
tree_flatten() (in module jax.tree_util)
tree_leaves() (in module jax.tree_util)
tree_map() (in module jax.tree_util)
tree_multimap() (in module jax.tree_util)
tree_reduce() (in module jax.tree_util)
tree_structure() (in module jax.tree_util)
tree_transpose() (in module jax.tree_util)
tree_unflatten() (in module jax.tree_util)
treedef_children() (in module jax.tree_util)
treedef_is_leaf() (in module jax.tree_util)
treedef_tuple() (in module jax.tree_util)
tri() (in module jax.numpy)
triangular_solve() (in module jax.lax.linalg)
tril() (in module jax.numpy)
(in module jax.scipy.linalg)
tril_indices() (in module jax.numpy)
tril_indices_from() (in module jax.numpy)
trim_zeros() (in module jax.numpy)
triu() (in module jax.numpy)
(in module jax.scipy.linalg)
triu_indices() (in module jax.numpy)
triu_indices_from() (in module jax.numpy)
true_divide() (in module jax.numpy)
trunc() (in module jax.numpy)
truncated_normal() (in module jax.random)
U
uint16 (class in jax.numpy)
uint32 (class in jax.numpy)
uint64 (class in jax.numpy)
uint8 (class in jax.numpy)
uniform() (in module jax.nn.initializers)
(in module jax.random)
unique() (in module jax.numpy)
unpack_optimizer_state() (in module jax.experimental.optimizers)
unpackbits() (in module jax.numpy)
unravel_index() (in module jax.numpy)
unsignedinteger (class in jax.numpy)
unwrap() (in module jax.numpy)
update_fn() (jax.experimental.optimizers.Optimizer property)
V
value_and_grad() (in module jax)
vander() (in module jax.numpy)
var() (in module jax.numpy)
variance_scaling() (in module jax.nn.initializers)
vdot() (in module jax.numpy)
vectorize() (in module jax.numpy), [1]
vjp() (in module jax)
vmap() (in module jax)
vsplit() (in module jax.numpy)
vstack() (in module jax.numpy)
W
weibull_min() (in module jax.random)
where() (in module jax.numpy)
while_loop() (in module jax.lax)
while_range() (jax.experimental.loops.Scope method)
X
xla_computation() (in module jax)
xlog1py() (in module jax.scipy.special)
xlogy() (in module jax.scipy.special)
Z
zeros() (in module jax.nn.initializers)
(in module jax.numpy)
zeros_like() (in module jax.numpy)
zeta() (in module jax.scipy.special)