jax.numpy.inner
jax.numpy.inner(a, b, *, precision=None)
Inner product of two arrays.
LAX-backend implementation of inner(). In addition to the original NumPy arguments listed below, it also supports precision for extra control over matrix-multiplication precision on supported devices. precision may be set to None, which means default precision for the backend, a lax.Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two lax.Precision enums indicating separate precision for each argument.
Original docstring below.
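A minimal sketch of the three accepted forms of precision, assuming a JAX version whose jnp.inner takes the keyword as described above (the array values are illustrative). The inputs hold small integers, so all three settings should agree exactly here; the choice mainly matters for float32 inputs on accelerators whose default matmul precision is reduced.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])   # shape (2, 3)
b = jnp.array([[1.0, 0.0, 1.0],
               [0.0, 1.0, 0.0]])   # shape (2, 3)

# None: use the backend's default precision.
out_default = jnp.inner(a, b, precision=None)

# A single lax.Precision value applies to both operands.
out_highest = jnp.inner(a, b, precision=lax.Precision.HIGHEST)

# A tuple of two values sets the precision for a and b separately.
out_mixed = jnp.inner(a, b, precision=(lax.Precision.HIGHEST, lax.Precision.DEFAULT))

print(out_highest)  # rows of a dotted with rows of b, shape (2, 2)
```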
inner(a, b)
Ordinary inner product of vectors for 1-D arrays (without complex conjugation), in higher dimensions a sum product over the last axes.
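To make "without complex conjugation" concrete, the sketch below compares jnp.inner with jnp.vdot, which conjugates its first argument. The complex inputs are arbitrary illustrative values.

```python
import jax.numpy as jnp

a = jnp.array([1 + 1j, 2 - 1j])
b = jnp.array([1 + 1j, 3 + 0j])

# inner: plain sum of a * b, no conjugation.
no_conj = jnp.inner(a, b)
manual = jnp.sum(a * b)

# vdot conjugates its first argument, so it differs for complex input.
with_conj = jnp.vdot(a, b)

print(no_conj, manual, with_conj)
```

For real inputs the two functions agree on 1-D arrays; they only diverge once the imaginary parts are nonzero.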
Returns
out : ndarray
    out.shape = a.shape[:-1] + b.shape[:-1]
Raises
ValueError
    If the last dimensions of a and b have different sizes.
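A short sketch checking the documented output shape and the failure on mismatched last dimensions. NumPy documents ValueError; the exception JAX raises here may be a different type (TypeError from the underlying contraction is assumed possible), so the sketch catches both rather than relying on one.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3, 4))
b = jnp.ones((5, 4))

out = jnp.inner(a, b)
# Shape rule: a.shape[:-1] + b.shape[:-1]
expected_shape = a.shape[:-1] + b.shape[:-1]   # (2, 3, 5)
assert out.shape == expected_shape

# Last dimensions 3 and 4 cannot be contracted.
try:
    jnp.inner(jnp.ones((2, 3)), jnp.ones((2, 4)))
    mismatch_raised = False
except (ValueError, TypeError):
    # NumPy documents ValueError; JAX may report the mismatch as TypeError.
    mismatch_raised = True

print(out.shape, mismatch_raised)
```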
See also
tensordot : Sum products over arbitrary axes.
dot : Generalised matrix product, using second last dimension of b.
einsum : Einstein summation convention.
Notes
For vectors (1-D arrays) it computes the ordinary inner product:
np.inner(a, b) = sum(a[:]*b[:])
More generally, if ndim(a) = r > 0 and ndim(b) = s > 0:
np.inner(a, b) = np.tensordot(a, b, axes=(-1,-1))
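A sketch of the tensordot equivalence with jax.numpy, using arbitrary ranks r = 3 and s = 2 chosen for illustration. Integer-valued float inputs keep the comparison exact.

```python
import jax.numpy as jnp

a = jnp.arange(24.0).reshape(2, 3, 4)   # ndim r = 3
b = jnp.arange(8.0).reshape(2, 4)       # ndim s = 2

via_inner = jnp.inner(a, b)
# Contract the last axis of a with the last axis of b.
via_tensordot = jnp.tensordot(a, b, axes=(-1, -1))

print(via_inner.shape, bool(jnp.array_equal(via_inner, via_tensordot)))
```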
or explicitly:
np.inner(a, b)[i0,...,ir-1,j0,...,js-1] = sum(a[i0,...,ir-1,:]*b[j0,...,js-1,:])
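The explicit index formula can be spelled out with einsum for fixed ranks. This sketch assumes r = 2 leading axes on a (i0, i1) and s = 1 on b (j0), with k as the shared last axis, and also evaluates one entry directly from the sum.

```python
import jax.numpy as jnp

a = jnp.arange(24.0).reshape(2, 3, 4)   # indices i0, i1 and last axis k
b = jnp.arange(8.0).reshape(2, 4)       # index j0 and last axis k

# Sum over the shared last index k; keep all other indices.
via_einsum = jnp.einsum("ijk,lk->ijl", a, b)
via_inner = jnp.inner(a, b)

# One entry computed straight from the formula.
i0, i1, j0 = 1, 2, 1
entry = jnp.sum(a[i0, i1, :] * b[j0, :])

print(float(entry), float(via_inner[i0, i1, j0]))
```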
In addition, a or b may be scalars, in which case:
np.inner(a,b) = a*b
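A sketch of the scalar case with illustrative values. Since a scalar has no last axis, the shape rule a.shape[:-1] + b.shape[:-1] does not apply; the result simply keeps the array operand's full shape.

```python
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)

# A scalar operand turns inner into elementwise multiplication.
scaled = jnp.inner(x, 2.0)
also_scaled = jnp.inner(3.0, x)
both_scalar = jnp.inner(2.0, 5.0)

print(scaled.shape, float(both_scalar))
```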
Examples
Ordinary inner product for vectors:
>>> a = np.array([1,2,3])
>>> b = np.array([0,1,0])
>>> np.inner(a, b)
2
A multidimensional example:
>>> a = np.arange(24).reshape((2,3,4))
>>> b = np.arange(4)
>>> np.inner(a, b)
array([[ 14,  38,  62],
       [ 86, 110, 134]])
An example where b is a scalar:
>>> np.inner(np.eye(2), 7)
array([[7., 0.],
       [0., 7.]])
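The NumPy examples above, ported to jax.numpy as a sketch. One difference to expect: the vector case returns a 0-d jax.Array rather than a NumPy scalar.

```python
import jax.numpy as jnp

# Ordinary inner product for vectors
a = jnp.array([1, 2, 3])
b = jnp.array([0, 1, 0])
vec = jnp.inner(a, b)

# Multidimensional example: contract the last axis of a with b
a3 = jnp.arange(24).reshape((2, 3, 4))
b1 = jnp.arange(4)
multi = jnp.inner(a3, b1)

# b is a scalar: elementwise multiplication
scalar_case = jnp.inner(jnp.eye(2), 7)

print(vec, multi, scalar_case, sep="\n")
```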