jax.numpy.dot
jax.numpy.dot(a, b, *, precision=None)
Dot product of two arrays.
LAX-backend implementation of numpy.dot(). In addition to the original NumPy arguments listed below, it also supports precision, which gives extra control over matrix-multiplication precision on supported devices. precision may be None (the backend's default precision), a lax.Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST), or a tuple of two lax.Precision values giving a separate precision for each argument.
Original docstring below.
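A minimal sketch of the precision argument, assuming a recent JAX install where lax.Precision is importable from jax.lax. The arrays x and y are illustrative only. On CPU the three results are typically identical; the setting matters mainly on accelerators whose default matmul path may use reduced-precision arithmetic.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative inputs; small integers are exact even in reduced precision.
x = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.float32)
y = jnp.array([[5.0, 6.0], [7.0, 8.0]], dtype=jnp.float32)

# precision=None: let the backend choose its default matmul precision.
out_default = jnp.dot(x, y)

# A single enum value applies to both operands.
out_highest = jnp.dot(x, y, precision=lax.Precision.HIGHEST)

# A tuple gives a separate precision for each argument (a, b).
out_mixed = jnp.dot(x, y, precision=(lax.Precision.HIGHEST, lax.Precision.DEFAULT))
```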
dot(a, b, out=None)
- If both a and b are 1-D arrays, it is the inner product of vectors (without complex conjugation).
- If both a and b are 2-D arrays, it is matrix multiplication, but using matmul() or a @ b is preferred.
- If either a or b is 0-D (scalar), it is equivalent to multiply(), and using numpy.multiply(a, b) or a * b is preferred.
- If a is an N-D array and b is a 1-D array, it is a sum product over the last axis of a and b.
- If a is an N-D array and b is an M-D array (where M >= 2), it is a sum product over the last axis of a and the second-to-last axis of b:
  dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])
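The cases above can be checked with a short sketch. It assumes float32 inputs built with jnp.arange; the names v, w, a, b1, b and the index tuple (i, j, k, m) are illustrative. It also compares the N-D by M-D case with jnp.tensordot contracting the last axis of a against the second-to-last axis of b, which is one way (not necessarily how JAX implements it) to express the same sum product.

```python
import jax.numpy as jnp

# 1-D . 1-D: inner product without conjugation.
v = jnp.array([1.0, 2.0, 3.0])
w = jnp.array([4.0, 5.0, 6.0])
inner = jnp.dot(v, w)  # 1*4 + 2*5 + 3*6 = 32

# 0-D operand: equivalent to elementwise multiply.
scaled = jnp.dot(2.0, v)  # same values as 2.0 * v

# N-D . 1-D: sum product over the last axis of a.
a = jnp.arange(24.0).reshape(2, 3, 4)
b1 = jnp.arange(4.0)
nd_by_1d = jnp.dot(a, b1)  # shape (2, 3)

# N-D . M-D (M >= 2): last axis of a against second-to-last axis of b.
b = jnp.arange(60.0).reshape(5, 4, 3)
nd_by_md = jnp.dot(a, b)  # shape (2, 3) + (5,) + (3,) = (2, 3, 5, 3)

# The same contraction expressed with tensordot.
via_tensordot = jnp.tensordot(a, b, axes=([a.ndim - 1], [b.ndim - 2]))

# One entry checked against dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m]).
i, j, k, m = 1, 2, 3, 0
entry = jnp.sum(a[i, j, :] * b[k, :, m])
```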
- Returns
- output : ndarray
  Returns the dot product of a and b. If a and b are both scalars or both 1-D arrays then a scalar is returned; otherwise an array is returned. If out is given, then it is returned.
- Raises
- ValueError
  If the last dimension of a is not the same size as the second-to-last dimension of b.
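A sketch of the return and error behavior under JAX, assuming current JAX semantics: where NumPy returns a scalar, jax.numpy.dot returns a 0-d array, and the signature above has no out parameter. For mismatched contraction sizes, JAX may raise TypeError rather than the ValueError documented for NumPy, so the sketch catches both.

```python
import jax.numpy as jnp

# Both 1-D: the result is a 0-d array (s.shape == ()), value 1*3 + 2*4 = 11.
s = jnp.dot(jnp.array([1, 2]), jnp.array([3, 4]))

# Last dim of a (3) differs from the second-to-last dim of b (2).
a = jnp.ones((2, 3))
b = jnp.ones((2, 4))
try:
    jnp.dot(a, b)
    raised = False
except (TypeError, ValueError):
    # The exact exception type may differ from NumPy's ValueError.
    raised = True
```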
See also
- vdot : Complex-conjugating dot product.
- tensordot : Sum products over arbitrary axes.
- einsum : Einstein summation convention.
- matmul : ‘@’ operator as method with out parameter.
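The related functions listed above differ in ways that are easy to check. This sketch (illustrative arrays only) contrasts dot with vdot on complex input, shows that dot, the @ operator and einsum agree for 2-D matrices, and shows where dot and matmul diverge for stacked 3-D inputs: matmul broadcasts the leading batch axis, while dot pairs every leading index of a with every leading index of b.

```python
import jax.numpy as jnp

# dot does not conjugate; vdot conjugates its first argument.
z = jnp.array([2j, 3j])
no_conj = jnp.dot(z, z)  # (2j)**2 + (3j)**2 = -13
conj = jnp.vdot(z, z)    # conj(2j)*2j + conj(3j)*3j = 13

# For 2-D inputs, dot, @ and einsum compute the same matrix product.
m1 = jnp.array([[1.0, 2.0], [3.0, 4.0]])
m2 = jnp.array([[0.0, 1.0], [1.0, 0.0]])
d = jnp.dot(m1, m2)
mm = m1 @ m2
es = jnp.einsum("ij,jk->ik", m1, m2)

# For stacked inputs the output shapes differ.
s1 = jnp.ones((2, 3, 4))
s2 = jnp.ones((2, 4, 5))
dot_shape = jnp.dot(s1, s2).shape        # (2, 3) + (2,) + (5,)
matmul_shape = jnp.matmul(s1, s2).shape  # batch axis broadcast: (2, 3, 5)
```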
Examples
>>> np.dot(3, 4)
12
Neither argument is complex-conjugated:
>>> np.dot([2j, 3j], [2j, 3j])
(-13+0j)
For 2-D arrays it is the matrix product:
>>> a = [[1, 0], [0, 1]]
>>> b = [[4, 1], [2, 2]]
>>> np.dot(a, b)
array([[4, 1],
       [2, 2]])
For N dimensions it is a sum product over the last axis of a and the second-to-last axis of b:
>>> a = np.arange(3*4*5*6).reshape((3,4,5,6))
>>> b = np.arange(3*4*5*6)[::-1].reshape((5,4,6,3))
>>> np.dot(a, b)[2,3,2,1,2,2]
499128
>>> sum(a[2,3,2,:] * b[1,2,:,2])
499128
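The NumPy doctests above can be reproduced with jax.numpy. This sketch wraps list inputs in jnp.array, on the assumption that JAX functions may reject raw Python lists, and assumes JAX's default 32-bit mode, where integer inputs become int32 (wide enough for these sums).

```python
import jax.numpy as jnp

# Two scalars: plain multiplication.
r0 = jnp.dot(3, 4)  # 12

# 2-D matrix product with the identity leaves b2 unchanged.
a2 = jnp.array([[1, 0], [0, 1]])
b2 = jnp.array([[4, 1], [2, 2]])
r2 = jnp.dot(a2, b2)

# 4-D by 4-D: last axis of a against second-to-last axis of b.
a = jnp.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
b = jnp.arange(3 * 4 * 5 * 6)[::-1].reshape((5, 4, 6, 3))
r3 = jnp.dot(a, b)[2, 3, 2, 1, 2, 2]
check = jnp.sum(a[2, 3, 2, :] * b[1, 2, :, 2])  # same entry, computed by hand
```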