快捷键

torch.linalg.tensorsolve

torch.linalg.tensorsolve(A, B, dims=None, *, out=None) Tensor

计算满足 torch.tensordot(A, X) = B 的解 X。

如果 m 是 A 的前 B 个维度的乘积,n 是其余维度的乘积,则此函数期望 m 和 n 相等。

返回的张量 x 满足 tensordot( A , x, dims=x.ndim) == B 。x 的形状为 A [B.ndim:]。

如果指定了 dims ,则 A 将被重塑为

A = movedim(A, dims, range(len(dims) - A.ndim + 1, 0))

支持浮点、双精度浮点、复浮点型和复双精度浮点型数据类型。

参见

torch.linalg.tensorinv() 计算出 torch.tensordot() 的乘法逆。

参数:
  • A (张量) – 要求解的张量。其形状必须满足 prod( A .shape[: B .ndim]) == prod( A .shape[ B .ndim:])。

  • B (张量) – 形状为 A .shape[: B .ndim] 的张量。

  • dims (可选的 int 元组) – 要移动的 A 的维度。如果为 None,则不移动任何维度。默认:None。

关键字参数:

out(张量,可选)- 输出张量。如果为 None 则忽略。默认:None。

引发:

RuntimeError – 如果重塑后的 A .view(m, m)(其中 m 如上所述)不可逆,或者前 ind 个维度的乘积不等于其余维度的乘积。

示例:

>>> A = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4))
>>> B = torch.randn(2 * 3, 4)
>>> X = torch.linalg.tensorsolve(A, B)
>>> X.shape
torch.Size([2, 3, 4])
>>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B)
True

>>> A = torch.randn(6, 4, 4, 3, 2)
>>> B = torch.randn(4, 3, 2)
>>> X = torch.linalg.tensorsolve(A, B, dims=(0, 2))
>>> X.shape
torch.Size([6, 4])
>>> A = A.permute(1, 3, 4, 0, 2)
>>> A.shape[B.ndim:]
torch.Size([6, 4])
>>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B, atol=1e-6)
True

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源