• 文档 >
  • torch >
  • torch.tensordot
快捷键

torch.tensordot

torch.tensordot(a, b, dims=2, out=None)[source][source]

返回 a 和 b 在多个维度上的收缩。

实现了一种通用的矩阵乘法。

参数:
  • a (张量) – 要收缩的左张量

  • b (张量) – 要收缩的右张量

  • dims (int 或 Tuple[List[int], List[int]] 或 List[List[int]],包含两个列表或张量) – 要收缩的维度数量或对 ab 分别显式的维度列表

当使用非负整数参数 dims = dd 调用时,如果 ab 的维度分别为 mmnn ,则 tensordot() 进行计算

ri0,...,imd,id,...,in=k0,...,kd1ai0,...,imd,k0,...,kd1×bk0,...,kd1,id,...,in.r_{i_0,...,i_{m-d}, i_d,...,i_n} = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, i_d,...,i_n}.

当使用列表形式的 dims 调用时,将就地收缩 a 的最后 dd 个和 bb 的第一个 dd 个维度。这些维度的大小必须匹配,但 tensordot() 将处理广播维度。

示例:

>>> a = torch.arange(60.).reshape(3, 4, 5)
>>> b = torch.arange(24.).reshape(4, 3, 2)
>>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))
tensor([[4400., 4730.],
        [4532., 4874.],
        [4664., 5018.],
        [4796., 5162.],
        [4928., 5306.]])

>>> a = torch.randn(3, 4, 5, device='cuda')
>>> b = torch.randn(4, 5, 6, device='cuda')
>>> c = torch.tensordot(a, b, dims=2).cpu()
tensor([[ 8.3504, -2.5436,  6.2922,  2.7556, -1.0732,  3.2741],
        [ 3.3161,  0.0704,  5.0187, -0.4079, -4.3126,  4.8744],
        [ 0.8223,  3.9445,  3.2168, -0.2400,  3.4117,  1.7780]])

>>> a = torch.randn(3, 5, 4, 6)
>>> b = torch.randn(6, 4, 5, 3)
>>> torch.tensordot(a, b, dims=([2, 1, 3], [1, 2, 0]))
tensor([[  7.7193,  -2.4867, -10.3204],
        [  1.5513, -14.4737,  -6.5113],
        [ -0.2850,   4.2573,  -3.5997]])

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源