torch.einsum¶
- torch.einsum(equation, *operands) Tensor [source][source]¶
沿着使用爱因斯坦求和约定表示法指定的维度对输入
operands
的元素求和。Einsum 允许通过基于爱因斯坦求和公式的简写格式来表示多维线性代数数组运算,从而计算许多常见的运算。该格式的详细信息将在下面描述,但基本思想是为输入的每个维度标注一些下标,并定义哪些下标是输出的一部分。然后通过沿着不是输出部分的下标维度对元素进行求积和来计算输出。例如,可以使用 einsum 来计算矩阵乘法,即 torch.einsum(“ij,jk->ik”,A, B)。在这里,j 是求和下标,i 和 k 是输出下标(更多细节请参阅下文)。
方程式:
equation
字符串指定了输入operands
的每个维度的下标(字母 [a-zA-Z]),其顺序与维度的顺序相同,每个操作数的下标之间用逗号(‘,’)分隔,例如‘ij,jk’指定了两个二维操作数的下标。具有相同下标的维度必须是可广播的,即它们的大小必须匹配或为 1。例外情况是如果同一个输入操作数重复了同一个下标,在这种情况下,该操作数下标所标记的维度大小必须匹配,并且该操作数将沿这些维度替换为其对角线。在equation
中恰好出现一次的下标将包含在输出中,按字母顺序排序。输出是通过逐元素相乘输入operands
,根据下标对齐维度,然后对不包含在输出中的维度进行求和来计算的。可选地,可以通过在等式末尾添加箭头(‘->’)后跟输出下标来显式定义输出下标。例如,以下等式计算矩阵乘积的转置:‘ij,jk->ki’。输出下标必须至少出现在某个输入操作数中一次,并且对于输出最多出现一次。
可以使用省略号(’…’)来代替下标,以广播省略号覆盖的维度。每个输入操作数最多只能包含一个省略号,它将覆盖除下标之外的所有维度,例如,对于一个具有 5 个维度的输入操作数,等式中的省略号‘ab…c’覆盖了第三和第四维度。省略号不需要在
operands
中覆盖相同数量的维度,但省略号的‘形状’(它们覆盖的维度的尺寸)必须能够广播。如果输出没有使用箭头(‘->’)符号显式定义,则省略号将首先出现在输出中(最左侧维度),然后是对于输入操作数恰好出现一次的下标标签。例如,以下等式实现了批量矩阵乘法‘…ij,…jk’。几点最后的说明:方程中可能包含不同元素(下标、省略号、箭头和逗号)之间的空格,但像‘…’这样的内容是不合法的。空字符串‘’对于标量操作数是合法的。
注意
torch.einsum
对省略号(‘…’)的处理与 NumPy 不同,它允许省略号覆盖的维度进行求和,也就是说,省略号不必是输出的一部分。注意
请安装 opt-einsum(https://optimized-einsum.readthedocs.io/en/stable/),以便使用更高效的 einstein 求和。您可以在安装 torch 时安装,如下所示:pip install torch[opt-einsum] 或单独安装:pip install opt-einsum。
如果 opt-einsum 可用,此函数将通过我们的 opt_einsum 后端自动通过优化收缩顺序来加速计算和/或减少内存消耗。这发生在至少有三个输入时,因为在这种情况下顺序并不重要。请注意,找到最佳路径是一个 NP 难题,因此,opt-einsum 依赖于不同的启发式算法来实现近似最优结果。如果 opt-einsum 不可用,默认顺序是从左到右进行收缩。
要绕过此默认行为,请添加以下内容以禁用 opt_einsum 并跳过路径计算:
torch.backends.opt_einsum.enabled = False
要指定 opt_einsum 计算收缩路径的策略,请添加以下行:
torch.backends.opt_einsum.strategy = 'auto'
。默认策略是‘auto’,我们还支持‘greedy’和‘optimal’。请注意,‘optimal’的运行时间是输入数量的阶乘!更多详细信息请参阅 opt_einsum 文档(https://optimized-einsum.readthedocs.io/en/stable/path_finding.html)。注意
截至 PyTorch 1.10
torch.einsum()
也支持子列表格式(以下为示例)。在此格式中,每个操作数的子脚标由子列表指定,子列表是整数列表,范围在[0, 52)之间。这些子列表跟在其操作数后面,输入末尾可以出现一个额外的子列表,用于指定输出的子脚标,例如 torch.einsum(op1, sublist1, op2, sublist2, …, [subslist_out])。Python 的 Ellipsis 对象可以提供在子列表中,以启用如上方程部分所述的广播。- 参数:
方程(字符串)- 爱因斯坦求和的子脚标。
操作数(Tensor 列表)- 计算爱因斯坦求和的张量。
- 返回类型:
示例:
>>> # trace >>> torch.einsum('ii', torch.randn(4, 4)) tensor(-1.2104) >>> # diagonal >>> torch.einsum('ii->i', torch.randn(4, 4)) tensor([-0.1034, 0.7952, -0.2433, 0.4545]) >>> # outer product >>> x = torch.randn(5) >>> y = torch.randn(4) >>> torch.einsum('i,j->ij', x, y) tensor([[ 0.1156, -0.2897, -0.3918, 0.4963], [-0.3744, 0.9381, 1.2685, -1.6070], [ 0.7208, -1.8058, -2.4419, 3.0936], [ 0.1713, -0.4291, -0.5802, 0.7350], [ 0.5704, -1.4290, -1.9323, 2.4480]]) >>> # batch matrix multiplication >>> As = torch.randn(3, 2, 5) >>> Bs = torch.randn(3, 5, 4) >>> torch.einsum('bij,bjk->bik', As, Bs) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]]) >>> # with sublist format and ellipsis >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2]) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]]) >>> # batch permute >>> A = torch.randn(2, 3, 4, 5) >>> torch.einsum('...ij->...ji', A).shape torch.Size([2, 3, 5, 4]) >>> # equivalent to torch.nn.functional.bilinear >>> A = torch.randn(3, 5, 4) >>> l = torch.randn(2, 5) >>> r = torch.randn(2, 4) >>> torch.einsum('bn,anm,bm->ba', l, A, r) tensor([[-0.3430, -5.2405, 0.4494], [ 0.3311, 5.5201, -3.0356]])