快捷键

torch.linalg.multi_dot

torch.linalg.multi_dot(tensors, *, out=None)

高效地通过重新排列乘法顺序来乘以两个或多个矩阵,以执行最少的算术运算。

支持浮点数、双精度浮点数、复浮点数和复双精度浮点数的数据类型输入。此函数不支持批量输入。

tensors 中的每个张量必须是二维的,除了第一个和最后一个,它们可以是 1 维的。如果第一个张量是一个形状为 (n,) 的一维向量,它被视为形状为 (1, n) 的行向量,同样地,如果最后一个张量是一个形状为 (n,) 的一维向量,它被视为形状为 (n, 1) 的列向量。

如果第一个和最后一个张量都是矩阵,输出将是一个矩阵。然而,如果其中任何一个是一维向量,则输出将是一维向量。

与 numpy.linalg.multi_dot 的区别:

  • 与 numpy.linalg.multi_dot 不同,第一个和最后一个张量必须是 1D 或 2D,而 NumPy 允许它们是 nD

警告

此函数不支持广播。

注意

此函数通过在计算最佳矩阵乘法顺序后链式调用 torch.mm() 实现。

注意

两个形状为(a, b)和(b, c)的矩阵相乘的成本是 a * b * c。给定形状为(10, 100)、(100, 5)、(5, 50)的矩阵 A、B、C,我们可以计算不同乘法顺序的成本如下:

cost((AB)C)=10×100×5+10×5×50=7500cost(A(BC))=10×100×50+100×5×50=75000\begin{align*} \operatorname{cost}((AB)C) &= 10 \times 100 \times 5 + 10 \times 5 \times 50 = 7500 \\ \operatorname{cost}(A(BC)) &= 10 \times 100 \times 50 + 100 \times 5 \times 50 = 75000 \end{align*}

在这种情况下,先乘以 A 和 B,然后乘以 C,速度可以提高 10 倍。

参数:

张量(Sequence[Tensor])- 两个或多个张量相乘。第一个和最后一个张量可以是 1D 或 2D。其余的张量必须是 2D。

关键字参数:

out(张量,可选)- 输出张量。如果为 None 则忽略。默认:None。

示例:

>>> from torch.linalg import multi_dot

>>> multi_dot([torch.tensor([1, 2]), torch.tensor([2, 3])])
tensor(8)
>>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([2, 3])])
tensor([8])
>>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([[2], [3]])])
tensor([[8]])

>>> A = torch.arange(2 * 3).view(2, 3)
>>> B = torch.arange(3 * 2).view(3, 2)
>>> C = torch.arange(2 * 2).view(2, 2)
>>> multi_dot((A, B, C))
tensor([[ 26,  49],
        [ 80, 148]])

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源