torch.linalg.tensorinv¶
- torch.linalg.tensorinv(A, ind=2, *, out=None) Tensor¶
计算表达式
torch.tensordot()的乘法逆。如果 m 是
A的前ind个维度的乘积,而 n 是其余维度的乘积,则此函数期望 m 和 n 相等。如果这种情况成立,则计算一个张量 X,使得 tensordot(A, X,ind) 在维度 m 上是单位矩阵。X 将具有A的形状,但将前ind个维度推到末尾X.shape == A.shape[ind:] + A.shape[:ind]
支持输入浮点型、双精度浮点型、复浮点型和复双精度浮点型数据类型。
注意
当
A是一个二维张量且ind= 1 时,此函数计算A的(乘法)逆(见torch.linalg.inv())。注意
如果可能,请考虑使用
torch.linalg.tensorsolve()在左侧乘以张量逆,如下所示:linalg.tensorsolve(A, B) == torch.tensordot(linalg.tensorinv(A), B) # When B is a tensor with shape A.shape[:B.ndim]
当可能时,始终优先使用
tensorsolve(),因为它比显式计算伪逆更快且更数值稳定。参见
torch.linalg.tensorsolve()计算 torch.tensordot(tensorinv(A),B)。- 参数:
一个(张量)- 要反转的张量。其形状必须满足 prod(
A.shape[:ind]) == prod(A.shape[ind:])。ind(int)- 计算逆的索引位置
torch.tensordot()。默认:2。
- 关键字参数:
out(张量,可选)- 输出张量。如果为 None 则忽略。默认:None。
- 引发:
RuntimeError - 如果重塑后的
A不可逆或前ind个维度的乘积不等于其余维度的乘积。
示例:
>>> A = torch.eye(4 * 6).reshape((4, 6, 8, 3)) >>> Ainv = torch.linalg.tensorinv(A, ind=2) >>> Ainv.shape torch.Size([8, 3, 4, 6]) >>> B = torch.randn(4, 6) >>> torch.allclose(torch.tensordot(Ainv, B), torch.linalg.tensorsolve(A, B)) True >>> A = torch.randn(4, 4) >>> Atensorinv = torch.linalg.tensorinv(A, ind=1) >>> Ainv = torch.linalg.inv(A) >>> torch.allclose(Atensorinv, Ainv) True