NumPy dot 與 matmul 的區別

Manav Narula 2021年3月21日
NumPy dot 與 matmul 的區別

在 Python 中,陣列被視為向量。二維陣列也稱為矩陣。我們提供了一些功能,可以在 Python 中執行它們之間的乘法運算。使用的兩種方法是 numpy.dot() 函式和@運算子(陣列的 __matmul__ 方法)。現在看來它們都執行相同的乘法功能。但是,兩者之間存在一些差異,本教程對此進行了說明。

numpy.dot() 函式用於在 Python 中執行矩陣乘法。它還檢查矩陣乘法的條件,即第一個矩陣的列數必須等於第二個矩陣的行數。它也適用於多維陣列。我們還可以指定一個備用陣列作為引數來儲存結果。@乘法運算子呼叫用於執行相同乘法的陣列的 matmul() 函式。例如,

import numpy as np

a = np.array([[1, 2], [2, 3]])
b = np.array(([8, 4], [4, 7]))
print(np.dot(a, b))
print(a @ b)

輸出:

[[16 18]
 [28 29]]
[[16 18]
 [28 29]]

但是,當我們處理多維陣列(N> 2 的 N-D 陣列)時,結果略有不同。你可以在下面看到區別。

a = np.random.rand(2, 3, 3)
b = np.random.rand(2, 3, 3)
c = a @ b
d = np.dot(a, b)
print(c, c.shape)
print(d, d.shape)

輸出:

[[[0.63629871 0.55054463 0.22289276]
  [1.27578425 1.13950519 0.55370078]
  [1.37809353 1.32313811 0.75460862]]

 [[1.63546361 1.54607801 0.67134528]
  [1.05906619 1.07509384 0.42526795]
  [1.38932102 1.32829749 0.47240808]]] (2, 3, 3)
[[[[0.63629871 0.55054463 0.22289276]
   [0.7938068  0.85668481 0.26504028]]

  [[1.27578425 1.13950519 0.55370078]
   [1.55589497 1.45794424 0.5335743 ]]

  [[1.37809353 1.32313811 0.75460862]
   [1.60564885 1.39494713 0.59370927]]]


 [[[1.48529826 1.55580834 0.96142976]
   [1.63546361 1.54607801 0.67134528]]

  [[0.94601586 0.97181894 0.56701004]
   [1.05906619 1.07509384 0.42526795]]

  [[1.13268609 1.00262696 0.47226983]
   [1.38932102 1.32829749 0.47240808]]]] (2, 3, 2, 3)

matmul() 函式像矩陣的堆疊一樣廣播陣列,它們分別作為位於最後兩個索引中的元素。另一方面,numpy.dot() 函式將乘積作為第一個陣列的最後一個軸與第二個陣列的倒數第二個的乘積之和。

matmul()numpy.dot 函式之間的另一個區別是 matmul() 函式無法執行標量值與陣列的乘法。

作者: Manav Narula
Manav Narula avatar Manav Narula avatar

Manav is a IT Professional who has a lot of experience as a core developer in many live projects. He is an avid learner who enjoys learning new things and sharing his findings whenever possible.

LinkedIn