torch.linalg.tensorsolve¶
- torch.linalg.tensorsolve(A, B, dims=None, *, out=None) Tensor ¶
计算满足 torch.tensordot(A, X) = B 的解 X。
如果 m 是
A
的前B
个维度的乘积,n 是其余维度的乘积,则此函数期望 m 和 n 相等。返回的张量 x 满足 tensordot(
A
, x, dims=x.ndim) ==B
。x 的形状为A
[B.ndim:]。如果指定了
dims
,则A
将被重塑为A = movedim(A, dims, range(len(dims) - A.ndim + 1, 0))
支持浮点、双精度浮点、复浮点型和复双精度浮点型数据类型。
参见
torch.linalg.tensorinv()
计算出torch.tensordot()
的乘法逆。- 参数:
A (张量) – 要求解的张量。其形状必须满足 prod(
A
.shape[:B
.ndim]) == prod(A
.shape[B
.ndim:])。B (张量) – 形状为
A
.shape[:B
.ndim] 的张量。dims (可选的 int 元组) – 要移动的
A
的维度。如果为 None,则不移动任何维度。默认:None。
- 关键字参数:
out(张量,可选)- 输出张量。如果为 None 则忽略。默认:None。
- 引发:
RuntimeError – 如果重塑后的
A
.view(m, m)(其中 m 如上所述)不可逆,或者前ind
个维度的乘积不等于其余维度的乘积。
示例:
>>> A = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4)) >>> B = torch.randn(2 * 3, 4) >>> X = torch.linalg.tensorsolve(A, B) >>> X.shape torch.Size([2, 3, 4]) >>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B) True >>> A = torch.randn(6, 4, 4, 3, 2) >>> B = torch.randn(4, 3, 2) >>> X = torch.linalg.tensorsolve(A, B, dims=(0, 2)) >>> X.shape torch.Size([6, 4]) >>> A = A.permute(1, 3, 4, 0, 2) >>> A.shape[B.ndim:] torch.Size([6, 4]) >>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B, atol=1e-6) True