
jax.numpy package

Implements the NumPy API, using the primitives in jax.lax.

While JAX tries to follow the NumPy API as closely as possible, sometimes JAX cannot follow NumPy exactly.

  • Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays in-place cannot be implemented in JAX. However, JAX is often able to provide an alternative API that is purely functional. For example, instead of in-place array updates (x[i] = y), JAX provides an alternative pure indexed update function jax.ops.index_update().

  • NumPy is very aggressive about promoting values to the float64 type. JAX is sometimes less aggressive about type promotion.
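Both points can be illustrated with a short sketch. It assumes a recent JAX release, in which jax.ops.index_update() has been removed in favor of the x.at[i].set(y) indexed-update syntax; on older releases, jax.ops.index_update() plays the same role. The default float dtype shown is float32 unless 64-bit mode (the jax_enable_x64 flag) is turned on.

```python
import numpy as np
import jax
import jax.numpy as jnp

x = jnp.zeros(4)

# In-place assignment is rejected because JAX arrays are immutable.
try:
    x[1] = 5.0
except TypeError:
    print("JAX arrays are immutable")

# Purely functional update: returns a new array and leaves x untouched.
y = x.at[1].set(5.0)
print(x)  # [0. 0. 0. 0.]
print(y)  # [0. 5. 0. 0.]

# Type promotion: NumPy defaults to float64 for Python floats,
# while JAX canonicalizes float64 to float32 unless 64-bit mode is enabled.
default_float = jax.dtypes.canonicalize_dtype(np.float64)
print(np.array([1.0, 2.0]).dtype)   # float64
print(jnp.array([1.0, 2.0]).dtype)  # same as default_float
```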

A small number of NumPy operations that have data-dependent output shapes are incompatible with jax.jit() compilation. The XLA compiler requires that shapes of arrays be known at compile time. While it would be possible to provide a JAX implementation of an API such as numpy.nonzero(), we would be unable to JIT-compile it because the shape of its output depends on the contents of the input data.
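The following sketch shows the nonzero() case described above. Called eagerly, it returns as many indices as there are non-zero entries; under jax.jit() the input is an abstract tracer whose values are unknown at compile time, so tracing fails (the exact exception class may vary between JAX versions, so the sketch catches it generically). A common shape-stable workaround is to mask values with where() instead of dropping them; positive_part is a hypothetical helper name used only for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0, 3, 0, 5])

# Eager call: the output length depends on the data (two non-zeros here).
(idx,) = jnp.nonzero(x)
print(idx)  # [1 3]

# Under jit the output shape cannot be determined at compile time.
try:
    jax.jit(jnp.nonzero)(x)
    jit_failed = False
except Exception as err:
    jit_failed = True
    print("jit failed with", type(err).__name__)

@jax.jit
def positive_part(v):
    # Keeps the input shape; non-positive entries are replaced by 0.
    return jnp.where(v > 0, v, 0)

print(positive_part(x))  # [0 3 0 5]
```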

Not every function in NumPy is implemented; contributions are welcome!

abs(x)

Calculate the absolute value element-wise.

absolute(x)

Calculate the absolute value element-wise.

add(x1, x2)

Add arguments element-wise.

all(a[, axis, out, keepdims])

Test whether all array elements along a given axis evaluate to True.

allclose(a, b[, rtol, atol, equal_nan])

Returns True if two arrays are element-wise equal within a tolerance.

alltrue(a[, axis, out, keepdims])

Test whether all array elements along a given axis evaluate to True.

amax(a[, axis, out, keepdims, initial, where])

Return the maximum of an array or maximum along an axis.

amin(a[, axis, out, keepdims, initial, where])

Return the minimum of an array or minimum along an axis.

angle(z)

Return the angle of the complex argument.

any(a[, axis, out, keepdims])

Test whether any array element along a given axis evaluates to True.

append(arr, values[, axis])

Append values to the end of an array.

apply_along_axis(func1d, axis, arr, *args, …)

Apply a function to 1-D slices along the given axis.

apply_over_axes(func, a, axes)

Apply a function repeatedly over multiple axes.

arange(start[, stop, step, dtype])

Return evenly spaced values within a given interval.

arccos(x)

Trigonometric inverse cosine, element-wise.

arccosh(x)

Inverse hyperbolic cosine, element-wise.

arcsin(x)

Inverse sine, element-wise.

arcsinh(x)

Inverse hyperbolic sine element-wise.

arctan(x)

Trigonometric inverse tangent, element-wise.

arctan2(x1, x2)

Element-wise arc tangent of x1/x2 choosing the quadrant correctly.

arctanh(x)

Inverse hyperbolic tangent element-wise.

argmax(a[, axis, out])

Returns the indices of the maximum values along an axis.

argmin(a[, axis, out])

Returns the indices of the minimum values along an axis.

argsort(a[, axis, kind, order])

Returns the indices that would sort an array.

argwhere(a)

Find the indices of array elements that are non-zero, grouped by element.

around(a[, decimals, out])

Round an array to the given number of decimals.

array(object[, dtype, copy, order, ndmin])

Create an array.

array_equal(a1, a2[, equal_nan])

True if two arrays have the same shape and elements, False otherwise.

array_equiv(a1, a2)

Returns True if input arrays are shape consistent and all elements equal.

array_repr(arr[, max_line_width, precision, …])

Return the string representation of an array.

array_split(ary, indices_or_sections[, axis])

Split an array into multiple sub-arrays.

array_str(a[, max_line_width, precision, …])

Return a string representation of the data in an array.

asarray(a[, dtype, order])

Convert the input to an array.

atleast_1d(*arys)

Convert inputs to arrays with at least one dimension.

atleast_2d(*arys)

View inputs as arrays with at least two dimensions.

atleast_3d(*arys)

View inputs as arrays with at least three dimensions.

average(a[, axis, weights, returned])

Compute the weighted average along the specified axis.

bartlett(*args, **kwargs)

Return the Bartlett window.

bincount(x[, weights, minlength, length])

Count number of occurrences of each value in array of non-negative ints.
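The length argument in the signature above is not part of NumPy's bincount. A minimal sketch, assuming a recent JAX release: by fixing the output size statically, it lets bincount run under jax.jit() even though the largest label is only known at run time. count_labels is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

@jax.jit
def count_labels(labels):
    # length fixes the output shape at trace time, independent of the data.
    return jnp.bincount(labels, length=4)

counts = count_labels(jnp.array([0, 1, 1, 3]))
print(counts)  # [1 2 0 1]
```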

bitwise_and(x1, x2)

Compute the bit-wise AND of two arrays element-wise.

bitwise_not(x)

Compute bit-wise inversion, or bit-wise NOT, element-wise.

bitwise_or(x1, x2)

Compute the bit-wise OR of two arrays element-wise.

bitwise_xor(x1, x2)

Compute the bit-wise XOR of two arrays element-wise.

blackman(*args, **kwargs)

Return the Blackman window.

block(arrays)

Assemble an nd-array from nested lists of blocks.

bool_

broadcast_arrays(*args)

Like NumPy’s broadcast_arrays, but doesn’t return views.

broadcast_to(arr, shape)

Broadcast an array to a new shape.

can_cast(from_, to[, casting])

Returns True if cast between data types can occur according to the casting rule.

cbrt(x)

Return the cube-root of an array, element-wise.

cdouble

alias of jax._src.numpy.lax_numpy.complex128

ceil(x)

Return the ceiling of the input, element-wise.

character

Abstract base class of all character string scalar types.

choose(a, choices[, out, mode])

Construct an array from an index array and a set of arrays to choose from.

clip(a[, a_min, a_max, out])

Clip (limit) the values in an array.

column_stack(tup)

Stack 1-D arrays as columns into a 2-D array.

complex_

alias of jax._src.numpy.lax_numpy.complex128

complex128

complex64

complexfloating

Abstract base class of all complex number scalar types that are made up of floating-point numbers.

ComplexWarning

The warning raised when casting a complex dtype to a real dtype.

compress(condition, a[, axis, out])

Return selected slices of an array along given axis.

concatenate(arrays[, axis])

Join a sequence of arrays along an existing axis.

conj(x)

Return the complex conjugate, element-wise.

conjugate(x)

Return the complex conjugate, element-wise.

convolve(a, v[, mode, precision])

Returns the discrete, linear convolution of two one-dimensional sequences.

copysign(x1, x2)

Change the sign of x1 to that of x2, element-wise.

corrcoef(x[, y, rowvar])

Return Pearson product-moment correlation coefficients.

correlate(a, v[, mode, precision])

Cross-correlation of two 1-dimensional sequences.

cos(x)

Cosine element-wise.

cosh(x)

Hyperbolic cosine, element-wise.

count_nonzero(a[, axis, keepdims])

Counts the number of non-zero values in the array a.

cov(m[, y, rowvar, bias, ddof, fweights, …])

Estimate a covariance matrix, given data and weights.

cross(a, b[, axisa, axisb, axisc, axis])

Return the cross product of two (arrays of) vectors.

csingle

alias of jax._src.numpy.lax_numpy.complex64

cumprod(a[, axis, dtype, out])

Return the cumulative product of elements along a given axis.

cumproduct(a[, axis, dtype, out])

Return the cumulative product of elements along a given axis.

cumsum(a[, axis, dtype, out])

Return the cumulative sum of the elements along a given axis.

deg2rad(x)

Convert angles from degrees to radians.

degrees(x)

Convert angles from radians to degrees.

diag(v[, k])

Extract a diagonal or construct a diagonal array.

diagflat(v[, k])

Create a two-dimensional array with the flattened input as a diagonal.

diag_indices(n[, ndim])

Return the indices to access the main diagonal of an array.

diag_indices_from(arr)

Return the indices to access the main diagonal of an n-dimensional array.

diagonal(a[, offset, axis1, axis2])

Return specified diagonals.

diff(a[, n, axis])

Calculate the n-th discrete difference along the given axis.

digitize(x, bins[, right])

Return the indices of the bins to which each value in input array belongs.

divide(x1, x2)

Returns a true division of the inputs, element-wise.

divmod(x1, x2)

Return element-wise quotient and remainder simultaneously.

dot(a, b, *[, precision])

Dot product of two arrays.
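The keyword-only precision argument in dot() (and in matmul(), inner(), vdot(), einsum() and similar functions above) is a JAX addition. A minimal sketch, assuming jax.lax.Precision is available as in recent releases: on some accelerators the default setting may use faster, lower-precision passes, and Precision.HIGHEST requests the most accurate algorithm the backend offers.

```python
import jax
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(3, 4)

# Default precision (backend-dependent speed/accuracy trade-off).
c_default = jnp.dot(a, b)

# Explicitly request the most accurate algorithm.
c_highest = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)

print(c_highest.shape)  # (2, 4)
```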

double

alias of jax._src.numpy.lax_numpy.float64

dsplit(ary, indices_or_sections)

Split array into multiple sub-arrays along the 3rd axis (depth).

dstack(tup)

Stack arrays in sequence depth wise (along third axis).

dtype(obj[, align, copy])

Create a data type object.

ediff1d(ary[, to_end, to_begin])

The differences between consecutive elements of an array.

einsum(*operands[, out, optimize, precision])

Evaluates the Einstein summation convention on the operands.
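A few common einsum() subscript patterns, as a sketch: a plain matrix product, a trace, and a batched matrix product.

```python
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(3, 4)

c = jnp.einsum("ij,jk->ik", a, b)   # matrix product, same as a @ b
t = jnp.einsum("ii->", jnp.eye(3))  # trace of the 3x3 identity
batched = jnp.einsum("bij,bjk->bik",
                     jnp.ones((5, 2, 3)), jnp.ones((5, 3, 4)))  # batch of 5 products

print(t)              # 3.0
print(batched.shape)  # (5, 2, 4)
```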

einsum_path(subscripts, *operands[, optimize])

Evaluates the lowest cost contraction order for an einsum expression by considering the creation of intermediate arrays.

empty(shape[, dtype])

Return a new array of given shape and type, filled with zeros.

empty_like(a[, dtype, shape])

Return an array of zeros with the same shape and type as a given array.

equal(x1, x2)

Return (x1 == x2) element-wise.

exp(x)

Calculate the exponential of all elements in the input array.

exp2(x)

Calculate 2**p for all p in the input array.

expand_dims(a, axis)

Expand the shape of an array.

expm1(x)

Calculate exp(x) - 1 for all elements in the array.

extract(condition, arr)

Return the elements of an array that satisfy some condition.

eye(N[, M, k, dtype])

Return a 2-D array with ones on the diagonal and zeros elsewhere.

fabs(x)

Compute the absolute values element-wise.

finfo(dtype)

Machine limits for floating point types.

fix(x[, out])

Round to nearest integer towards zero.

flatnonzero(a)

Return indices that are non-zero in the flattened version of a.

flexible

Abstract base class of all scalar types without predefined length.

flip(m[, axis])

Reverse the order of elements in an array along the given axis.

fliplr(m)

Flip array in the left/right direction.

flipud(m)

Flip array in the up/down direction.

float_

alias of jax._src.numpy.lax_numpy.float64

float16

float32

float64

floating

Abstract base class of all floating-point scalar types.

float_power(x1, x2)

First array elements raised to powers from second array, element-wise.

floor(x)

Return the floor of the input, element-wise.

floor_divide(x1, x2)

Return the largest integer smaller or equal to the division of the inputs.

fmax(x1, x2)

Element-wise maximum of array elements.

fmin(x1, x2)

Element-wise minimum of array elements.

fmod(x1, x2)

Return the element-wise remainder of division.

frexp(x)

Decompose the elements of x into mantissa and base-2 exponent.

full(shape, fill_value[, dtype])

Return a new array of given shape and type, filled with fill_value.

full_like(a, fill_value[, dtype, shape])

Return a full array with the same shape and type as a given array.

gcd(x1, x2)

Returns the greatest common divisor of |x1| and |x2|.

geomspace(start, stop[, num, endpoint, …])

Return numbers spaced evenly on a log scale (a geometric progression).

gradient(f, *varargs[, axis, edge_order])

Return the gradient of an N-dimensional array.

greater(x1, x2)

Return the truth value of (x1 > x2) element-wise.

greater_equal(x1, x2)

Return the truth value of (x1 >= x2) element-wise.

hamming(*args, **kwargs)

Return the Hamming window.

hanning(*args, **kwargs)

Return the Hanning window.

heaviside(x1, x2)

Compute the Heaviside step function.

histogram(a[, bins, range, weights, density])

Compute the histogram of a set of data.

histogram_bin_edges(a[, bins, range, weights])

Function to calculate only the edges of the bins used by the histogram function.

histogram2d(x, y[, bins, range, weights, …])

Compute the bi-dimensional histogram of two data samples.

histogramdd(sample[, bins, range, weights, …])

Compute the multidimensional histogram of some data.

hsplit(ary, indices_or_sections)

Split an array into multiple sub-arrays horizontally (column-wise).

hstack(tup)

Stack arrays in sequence horizontally (column wise).

hypot(x1, x2)

Given the “legs” of a right triangle, return its hypotenuse.

i0(x)

Modified Bessel function of the first kind, order 0.

identity(n[, dtype])

Return the identity array.

iinfo(type)

Machine limits for integer types.

imag(val)

Return the imaginary part of the complex argument.

in1d(ar1, ar2[, assume_unique, invert])

Test whether each element of a 1-D array is also present in a second array.

indices(dimensions[, dtype, sparse])

Return an array representing the indices of a grid.

inexact

Abstract base class of all numeric scalar types with a (potentially) inexact representation of the values in its range, such as floating-point numbers.

inner(a, b, *[, precision])

Inner product of two arrays.

int_

alias of jax._src.numpy.lax_numpy.int64

int16

int32

int64

int8

integer

Abstract base class of all integer scalar types.

interp(x, xp, fp[, left, right, period])

One-dimensional linear interpolation.

intersect1d(ar1, ar2[, assume_unique, …])

Find the intersection of two arrays.

invert(x)

Compute bit-wise inversion, or bit-wise NOT, element-wise.

isclose(a, b[, rtol, atol, equal_nan])

Returns a boolean array where two arrays are element-wise equal within a tolerance.

iscomplex(x)

Returns a bool array that is True where the input element is complex.

iscomplexobj(x)

Check for a complex type or an array of complex numbers.

isfinite(x)

Test element-wise for finiteness (not infinity and not Not a Number).

isin(element, test_elements[, …])

Calculates element in test_elements, broadcasting over element only.

isinf(x)

Test element-wise for positive or negative infinity.

isnan(x)

Test element-wise for NaN and return result as a boolean array.

isneginf(x[, out])

Test element-wise for negative infinity, return result as bool array.

isposinf(x[, out])

Test element-wise for positive infinity, return result as bool array.

isreal(x)

Returns a bool array that is True where the input element is real.

isrealobj(x)

Return True if x is neither a complex type nor an array of complex numbers.

isscalar(element)

Returns True if the type of element is a scalar type.

issubdtype(arg1, arg2)

Returns True if the first argument is a typecode lower than or equal to the second in the type hierarchy.

issubsctype(arg1, arg2)

Determine if the first argument is a subclass of the second argument.

iterable(y)

Check whether or not an object can be iterated over.

ix_(*args)

Construct an open mesh from multiple sequences.

kaiser(*args, **kwargs)

Return the Kaiser window.

kron(a, b)

Kronecker product of two arrays.

lcm(x1, x2)

Returns the lowest common multiple of |x1| and |x2|.

ldexp(x1, x2)

Returns x1 * 2**x2, element-wise.

left_shift(x1, x2)

Shift the bits of an integer to the left.

less(x1, x2)

Return the truth value of (x1 < x2) element-wise.

less_equal(x1, x2)

Return the truth value of (x1 <= x2) element-wise.

lexsort(keys[, axis])

Perform an indirect stable sort using a sequence of keys.

linspace(start, stop[, num, endpoint, …])

Return evenly spaced numbers over a specified interval.

load(file[, mmap_mode, allow_pickle, …])

Load arrays or pickled objects from .npy, .npz or pickled files.

log(x)

Natural logarithm, element-wise.

log10(x)

Return the base 10 logarithm of the input array, element-wise.

log1p(x)

Return the natural logarithm of one plus the input array, element-wise.

log2(x)

Base-2 logarithm of x.

logaddexp

Logarithm of the sum of exponentiations of the inputs.

logaddexp2

Logarithm of the sum of exponentiations of the inputs in base-2.

logical_and(*args)

Compute the truth value of x1 AND x2 element-wise.

logical_not(*args)

Compute the truth value of NOT x element-wise.

logical_or(*args)

Compute the truth value of x1 OR x2 element-wise.

logical_xor(*args)

Compute the truth value of x1 XOR x2, element-wise.

logspace(start, stop[, num, endpoint, base, …])

Return numbers spaced evenly on a log scale.

mask_indices(*args, **kwargs)

Return the indices to access (n, n) arrays, given a masking function.

matmul(a, b, *[, precision])

Matrix product of two arrays.

max(a[, axis, out, keepdims, initial, where])

Return the maximum of an array or maximum along an axis.

maximum(x1, x2)

Element-wise maximum of array elements.

mean(a[, axis, dtype, out, keepdims])

Compute the arithmetic mean along the specified axis.

median(a[, axis, out, overwrite_input, keepdims])

Compute the median along the specified axis.

meshgrid(*args, **kwargs)

Return coordinate matrices from coordinate vectors.

min(a[, axis, out, keepdims, initial, where])

Return the minimum of an array or minimum along an axis.

minimum(x1, x2)

Element-wise minimum of array elements.

mod(x1, x2)

Return element-wise remainder of division.

modf(x[, out])

Return the fractional and integral parts of an array, element-wise.

moveaxis(a, source, destination)

Move axes of an array to new positions.

msort(a)

Return a copy of an array sorted along the first axis.

multiply(x1, x2)

Multiply arguments element-wise.

nanargmax(a[, axis])

Return the indices of the maximum values in the specified axis, ignoring NaNs.

nanargmin(a[, axis])

Return the indices of the minimum values in the specified axis, ignoring NaNs.

nancumprod(a[, axis, dtype, out])

Return the cumulative product of array elements over a given axis, treating Not a Numbers (NaNs) as one.

nancumsum(a[, axis, dtype, out])

Return the cumulative sum of array elements over a given axis, treating Not a Numbers (NaNs) as zero.

nanmax(a[, axis, out, keepdims])

Return the maximum of an array or maximum along an axis, ignoring any NaNs.

nanmean(a[, axis, dtype, out, keepdims])

Compute the arithmetic mean along the specified axis, ignoring NaNs.

nanmedian(a[, axis, out, overwrite_input, …])

Compute the median along the specified axis, while ignoring NaNs.

nanmin(a[, axis, out, keepdims])

Return minimum of an array or minimum along an axis, ignoring any NaNs.

nanpercentile(a, q[, axis, out, …])

Compute the qth percentile of the data along the specified axis, while ignoring NaN values.

nanprod(a[, axis, dtype, out, keepdims])

Return the product of array elements over a given axis, treating Not a Numbers (NaNs) as one.

nanquantile(a, q[, axis, out, …])

Compute the qth quantile of the data along the specified axis, while ignoring NaN values.

nanstd(a[, axis, dtype, out, ddof, keepdims])

Compute the standard deviation along the specified axis, while ignoring NaNs.

nansum(a[, axis, dtype, out, keepdims])

Return the sum of array elements over a given axis, treating Not a Numbers (NaNs) as zero.

nan_to_num(x[, copy, nan, posinf, neginf])

Replace NaN with zero and infinity with large finite numbers (default behavior), or with the numbers given by the nan, posinf and/or neginf keywords.

nanvar(a[, axis, dtype, out, ddof, keepdims])

Compute the variance along the specified axis, while ignoring NaNs.

ndarray([dtype, buffer, offset, strides, order])

ndim(a)

Return the number of dimensions of an array.

negative(x)

Numerical negative, element-wise.

nextafter(x1, x2)

Return the next floating-point value after x1 towards x2, element-wise.

nonzero(a)

Return the indices of the elements that are non-zero.

not_equal(x1, x2)

Return (x1 != x2) element-wise.

number

Abstract base class of all numeric scalar types.

object_

Any Python object.

ones(shape[, dtype])

Return a new array of given shape and type, filled with ones.

ones_like(a[, dtype, shape])

Return an array of ones with the same shape and type as a given array.

outer(a, b[, out])

Compute the outer product of two vectors.

packbits(a[, axis, bitorder])

Packs the elements of a binary-valued array into bits in a uint8 array.

pad(array, pad_width[, mode, constant_values])

Pad an array.

percentile(a, q[, axis, out, …])

Compute the q-th percentile of the data along the specified axis.

piecewise(x, condlist, funclist, *args, **kw)

Evaluate a piecewise-defined function.

polyadd(a1, a2)

Find the sum of two polynomials.

polyder(p[, m])

Return the derivative of the specified order of a polynomial.

polymul(a1, a2, *[, trim_leading_zeros])

Find the product of two polynomials.

polysub(a1, a2)

Difference (subtraction) of two polynomials.

polyval(p, x)

Evaluate a polynomial at specific values.

positive(x)

Numerical positive, element-wise.

power(x1, x2)

First array elements raised to powers from second array, element-wise.

prod(a[, axis, dtype, out, keepdims, …])

Return the product of array elements over a given axis.

product(a[, axis, dtype, out, keepdims, …])

Return the product of array elements over a given axis.

promote_types(a, b)

Returns the type to which a binary operation should cast its arguments.

ptp(a[, axis, out, keepdims])

Range of values (maximum - minimum) along an axis.

quantile(a, q[, axis, out, overwrite_input, …])

Compute the q-th quantile of the data along the specified axis.

rad2deg(x)

Convert angles from radians to degrees.

radians(x)

Convert angles from degrees to radians.

ravel(a[, order])

Return a contiguous flattened array.

ravel_multi_index(multi_index, dims[, mode, …])

Converts a tuple of index arrays into an array of flat indices, applying boundary modes to the multi-index.

real(val)

Return the real part of the complex argument.

reciprocal(x)

Return the reciprocal of the argument, element-wise.

remainder(x1, x2)

Return element-wise remainder of division.

repeat(a, repeats[, axis, total_repeat_length])

Repeat elements of an array.
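The total_repeat_length argument in the signature above is a JAX addition that addresses the data-dependent output shape problem described at the top of this page. A sketch, assuming a recent JAX release: when repeats is a traced array inside jax.jit(), the output length must be supplied statically. expand is a hypothetical helper name, and the counts are chosen so that they sum exactly to total_repeat_length.

```python
import jax
import jax.numpy as jnp

@jax.jit
def expand(values, counts):
    # counts is traced, so the total output length is fixed up front.
    return jnp.repeat(values, counts, total_repeat_length=5)

out = expand(jnp.array([1, 2, 3]), jnp.array([1, 2, 2]))
print(out)  # [1 2 2 3 3]
```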

reshape(a, newshape[, order])

Gives a new shape to an array without changing its data.

result_type(*args)

Returns the type that results from applying the NumPy type promotion rules to the arguments.

right_shift(x1, x2)

Shift the bits of an integer to the right.

rint(x)

Round elements of the array to the nearest integer.

roll(a, shift[, axis])

Roll array elements along a given axis.

rollaxis(a, axis[, start])

Roll the specified axis backwards, until it lies in a given position.

roots(p, *[, strip_zeros])

Return the roots of a polynomial with coefficients given in p.

rot90(m[, k, axes])

Rotate an array by 90 degrees in the plane specified by axes.

round(a[, decimals, out])

Round an array to the given number of decimals.

row_stack(tup)

Stack arrays in sequence vertically (row wise).

save(file, arr[, allow_pickle, fix_imports])

Save an array to a binary file in NumPy .npy format.

savez(file, *args, **kwds)

Save several arrays into a single file in uncompressed .npz format.

searchsorted(a, v[, side, sorter])

Find indices where elements should be inserted to maintain order.
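Because searchsorted() returns one index per query, its output shape does not depend on the data, which makes it a convenient building block under jax.jit(). A short sketch of the side argument:

```python
import jax.numpy as jnp

sorted_a = jnp.array([1, 3, 5, 7])
queries = jnp.array([0, 3, 6, 8])

left = jnp.searchsorted(sorted_a, queries)                 # insert before equal values
right = jnp.searchsorted(sorted_a, queries, side="right")  # insert after equal values

print(left)   # [0 1 3 4]
print(right)  # [0 2 3 4]
```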

select(condlist, choicelist[, default])

Return an array drawn from elements in choicelist, depending on conditions.

set_printoptions([precision, threshold, …])

Set printing options.

setdiff1d(ar1, ar2[, assume_unique])

Find the set difference of two arrays.

shape(a)

Return the shape of an array.

sign(x)

Returns an element-wise indication of the sign of a number.

signbit(x)

Returns element-wise True where signbit is set (less than zero).

signedinteger

Abstract base class of all signed integer scalar types.

sin(x)

Trigonometric sine, element-wise.

sinc(x)

Return the sinc function.

single

alias of jax._src.numpy.lax_numpy.float32

sinh(x)

Hyperbolic sine, element-wise.

size(a[, axis])

Return the number of elements along a given axis.

sometrue(a[, axis, out, keepdims])

Test whether any array element along a given axis evaluates to True.

sort(a[, axis, kind, order])

Return a sorted copy of an array.

sort_complex(a)

Sort a complex array using the real part first, then the imaginary part.

split(ary, indices_or_sections[, axis])

Split an array into multiple sub-arrays. (JAX arrays are immutable, so these are not views into ary.)

sqrt(x)

Return the non-negative square-root of an array, element-wise.

square(x)

Return the element-wise square of the input.

squeeze(a[, axis])

Remove single-dimensional entries from the shape of an array.

stack(arrays[, axis, out])

Join a sequence of arrays along a new axis.

std(a[, axis, dtype, out, ddof, keepdims])

Compute the standard deviation along the specified axis.

subtract(x1, x2)

Subtract arguments, element-wise.

sum(a[, axis, dtype, out, keepdims, …])

Sum of array elements over a given axis.

swapaxes(a, axis1, axis2)

Interchange two axes of an array.

take(a, indices[, axis, out, mode])

Take elements from an array along an axis.

take_along_axis(arr, indices, axis)

Take values from the input array by matching 1d index and data slices.

tan(x)

Compute tangent element-wise.

tanh(x)

Compute hyperbolic tangent element-wise.

tensordot(a, b[, axes, precision])

Compute tensor dot product along specified axes.

tile(A, reps)

Construct an array by repeating A the number of times given by reps.

trace(a[, offset, axis1, axis2, dtype, out])

Return the sum along diagonals of the array.

transpose(a[, axes])

Reverse or permute the axes of an array; returns the modified array.

trapz(y[, x, dx, axis])

Integrate along the given axis using the composite trapezoidal rule.

tri(N[, M, k, dtype])

An array with ones at and below the given diagonal and zeros elsewhere.

tril(m[, k])

Lower triangle of an array.

tril_indices(*args, **kwargs)

Return the indices for the lower-triangle of an (n, m) array.

tril_indices_from(arr[, k])

Return the indices for the lower-triangle of arr.

trim_zeros(filt[, trim])

Trim the leading and/or trailing zeros from a 1-D array or sequence.

triu(m[, k])

Upper triangle of an array.

triu_indices(*args, **kwargs)

Return the indices for the upper-triangle of an (n, m) array.

triu_indices_from(arr[, k])

Return the indices for the upper-triangle of arr.

true_divide(x1, x2)

Returns a true division of the inputs, element-wise.

trunc(x)

Return the truncated value of the input, element-wise.

uint16

uint32

uint64

uint8

unique(ar[, return_index, return_inverse, …])

Find the unique elements of an array.

unpackbits(a[, axis, count, bitorder])

Unpacks elements of a uint8 array into a binary-valued output array.

unravel_index(indices, shape)

Converts a flat index or array of flat indices into a tuple

unsignedinteger

Abstract base class of all unsigned integer scalar types.

unwrap(p[, discont, axis])

Unwrap by changing deltas between values to 2*pi complement.

vander(x[, N, increasing])

Generate a Vandermonde matrix.

var(a[, axis, dtype, out, ddof, keepdims])

Compute the variance along the specified axis.

vdot(a, b, *[, precision])

Return the dot product of two vectors.

vectorize(pyfunc, *[, excluded, signature])

Define a vectorized function with broadcasting.
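A sketch of the signature argument: a function written for a single 1-D vector is mapped over the leading axes of a batch. l1_norm is a hypothetical helper used only for illustration.

```python
import jax.numpy as jnp

def l1_norm(v):
    # Written for one 1-D vector; returns a scalar.
    return jnp.sum(jnp.abs(v))

# "(n)->()" consumes the last axis and produces a scalar per row.
batched_l1 = jnp.vectorize(l1_norm, signature="(n)->()")

result = batched_l1(-jnp.ones((3, 4)))
print(result)  # [4. 4. 4.]
```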

vsplit(ary, indices_or_sections)

Split an array into multiple sub-arrays vertically (row-wise).

vstack(tup)

Stack arrays in sequence vertically (row wise).

where(condition[, x, y])

Return elements chosen from x or y depending on condition.

zeros(shape[, dtype])

Return a new array of given shape and type, filled with zeros.

zeros_like(a[, dtype, shape])

Return an array of zeros with the same shape and type as a given array.

jax.numpy.fft

fft(a[, n, axis, norm])

Compute the one-dimensional discrete Fourier Transform.

fft2(a[, s, axes, norm])

Compute the 2-dimensional discrete Fourier Transform.

fftfreq(n[, d])

Return the Discrete Fourier Transform sample frequencies.

fftn(a[, s, axes, norm])

Compute the N-dimensional discrete Fourier Transform.

fftshift(x[, axes])

Shift the zero-frequency component to the center of the spectrum.

hfft(a[, n, axis, norm])

Compute the FFT of a signal that has Hermitian symmetry, i.e., a real spectrum.

ifft(a[, n, axis, norm])

Compute the one-dimensional inverse discrete Fourier Transform.

ifft2(a[, s, axes, norm])

Compute the 2-dimensional inverse discrete Fourier Transform.

ifftn(a[, s, axes, norm])

Compute the N-dimensional inverse discrete Fourier Transform.

ifftshift(x[, axes])

The inverse of fftshift. Although identical for even-length x, the functions differ by one sample for odd-length x.

ihfft(a[, n, axis, norm])

Compute the inverse FFT of a signal that has Hermitian symmetry.

irfft(a[, n, axis, norm])

Compute the inverse of the n-point DFT for real input.

irfft2(a[, s, axes, norm])

Compute the 2-dimensional inverse FFT of a real array.

irfftn(a[, s, axes, norm])

Compute the inverse of the N-dimensional FFT of real input.

rfft(a[, n, axis, norm])

Compute the one-dimensional discrete Fourier Transform for real input.

rfft2(a[, s, axes, norm])

Compute the 2-dimensional FFT of a real array.

rfftfreq(n[, d])

Return the Discrete Fourier Transform sample frequencies (for use with rfft and irfft).

rfftn(a[, s, axes, norm])

Compute the N-dimensional discrete Fourier Transform for real input.
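A round-trip sketch for the real-input transforms. For an input of length n, rfft returns n // 2 + 1 coefficients, so irfft needs the original length passed explicitly when n is odd; rfftfreq gives the matching frequency bins.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 0.0, -1.0, 3.0])  # odd length, real-valued
n = x.shape[0]

spectrum = jnp.fft.rfft(x)                # 5 // 2 + 1 = 3 complex coefficients
freqs = jnp.fft.rfftfreq(n, d=1.0)        # frequencies for those coefficients
x_back = jnp.fft.irfft(spectrum, n=n)     # recovers x up to rounding error

print(spectrum.shape, freqs.shape)  # (3,) (3,)
```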

jax.numpy.linalg

cholesky(a)

Cholesky decomposition.

cond(x[, p])

Compute the condition number of a matrix.

det

Compute the determinant of an array.

eig(a)

Compute the eigenvalues and right eigenvectors of a square array.

eigh(a[, UPLO, symmetrize_input])

Return the eigenvalues and eigenvectors of a complex Hermitian (conjugate symmetric) or a real symmetric matrix.

eigvals(a)

Compute the eigenvalues of a general matrix.

eigvalsh(a[, UPLO])

Compute the eigenvalues of a complex Hermitian or real symmetric matrix.

inv(a)

Compute the (multiplicative) inverse of a matrix.

lstsq(a, b[, rcond, numpy_resid])

Return the least-squares solution to a linear matrix equation.

matrix_power(a, n)

Raise a square matrix to the (integer) power n.

matrix_rank(M[, tol])

Return matrix rank of array using SVD method.

multi_dot(arrays, *[, precision])

Compute the dot product of two or more arrays in a single function call, while automatically selecting the fastest evaluation order.

norm(x[, ord, axis, keepdims])

Matrix or vector norm.

pinv

Compute the (Moore-Penrose) pseudo-inverse of a matrix.

qr(a[, mode])

Compute the QR factorization of a matrix.

slogdet

Compute the sign and (natural) logarithm of the determinant of an array.

solve(a, b)

Solve a linear matrix equation, or system of linear scalar equations.

svd(a[, full_matrices, compute_uv])

Singular Value Decomposition.

tensorinv(a[, ind])

Compute the ‘inverse’ of an N-dimensional array.

tensorsolve(a, b[, axes])

Solve the tensor equation a x = b for x.
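A short sketch using two of the functions above on a small symmetric system: solve() for a x = b, and eigh() to check the eigendecomposition by reconstructing the matrix.

```python
import jax.numpy as jnp

a = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])  # real symmetric
b = jnp.array([9.0, 8.0])

# 3*x0 + x1 = 9 and x0 + 2*x1 = 8, so x = [2, 3].
x = jnp.linalg.solve(a, b)
print(x)

# Eigenvalues w (ascending) and orthonormal eigenvectors v (columns).
w, v = jnp.linalg.eigh(a)
reconstructed = v @ jnp.diag(w) @ v.T  # should equal a up to rounding
```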