快捷键

torch.linalg.pinv

torch.linalg.pinv(A, *, atol=None, rtol=None, hermitian=False, out=None) Tensor

计算矩阵的伪逆(摩尔-彭罗斯逆)

假逆可以代数定义,但更便于通过奇异值分解(SVD)来理解

支持输入 float、double、cfloat 和 cdouble 数据类型。也支持矩阵批处理,如果 A 是矩阵批,则输出具有相同的批维度。

如果 hermitian 为真,则假设 A 是复数时的厄米矩阵或实数时的对称矩阵,但内部并不进行检查。相反,仅使用矩阵的下三角部分进行计算。

hermitian 为真时,低于 max(atol,σ1rtol)\max(\text{atol}, \sigma_1 \cdot \text{rtol}) 阈值的奇异值(或特征值的范数)被视为零并在计算中被丢弃,其中 σ1\sigma_1 是最大的奇异值(或特征值)。

如果 rtol 未指定且 A 是一个维度为(m,n)的矩阵,则相对容差设置为 rtol=max(m,n)ε\text{rtol} = \max(m, n) \varepsilonε\varepsilonA 数据类型的 epsilon 值(参见 finfo )。如果 rtol 未指定且 atol 指定为大于零,则 rtol 设置为零。

如果 atolrtoltorch.Tensor ,则其形状必须可以广播到由 A 返回的单个值形状,该值由 torch.linalg.svd() 返回。

注意

此函数在 hermitian 为 False 时使用 torch.linalg.svd() ,在 hermitian 为 True 时使用 torch.linalg.eigh() 。对于 CUDA 输入,此函数将设备与 CPU 同步。

注意

如果可能,请考虑使用 torch.linalg.lstsq() 来乘以矩阵的左伪逆,如下所示:

torch.linalg.lstsq(A, B).solution == A.pinv() @ B

当可能时,始终优先使用 lstsq() ,因为它比显式计算伪逆更快且更数值稳定。

注意

此函数具有与 NumPy 兼容的变体 linalg.pinv(A, rcond, hermitian=False)。但是,使用位置参数 rcond 已被弃用,改用 rtol

警告

此函数内部使用 torch.linalg.svd() (当 hermitian = True 时为 torch.linalg.eigh() ),因此其导数与这些函数存在相同的问题。有关更多详细信息,请参阅 torch.linalg.svd()torch.linalg.eigh() 中的警告。

参见

torch.linalg.inv() 计算一个方阵的逆。

torch.linalg.lstsq() 使用数值稳定的算法计算 A .pinv() @ B

参数:
  • A(张量)- 形状为(*, m, n)的张量,其中*表示零个或多个批处理维度。

  • rcond (float, Tensor, 可选) – [NumPy 兼容]. rtol 的别名。默认:None。

关键字参数:
  • atol (float, Tensor, 可选) – 绝对容差值。当为 None 时,视为零。默认:None。

  • rtol (float, Tensor, 可选) – 相对容差值。当为 None 时,其值见上文。默认:None。

  • hermitian (bool, 可选) – 表示 A 是否为复数时的 Hermitian 或实数时的对称。默认:False。

  • out(张量,可选)- 输出张量。如果为 None 则忽略。默认:None。

示例:

>>> A = torch.randn(3, 5)
>>> A
tensor([[ 0.5495,  0.0979, -1.4092, -0.1128,  0.4132],
        [-1.1143, -0.3662,  0.3042,  1.6374, -0.9294],
        [-0.3269, -0.5745, -0.0382, -0.5922, -0.6759]])
>>> torch.linalg.pinv(A)
tensor([[ 0.0600, -0.1933, -0.2090],
        [-0.0903, -0.0817, -0.4752],
        [-0.7124, -0.1631, -0.2272],
        [ 0.1356,  0.3933, -0.5023],
        [-0.0308, -0.1725, -0.5216]])

>>> A = torch.randn(2, 6, 3)
>>> Apinv = torch.linalg.pinv(A)
>>> torch.dist(Apinv @ A, torch.eye(3))
tensor(8.5633e-07)

>>> A = torch.randn(3, 3, dtype=torch.complex64)
>>> A = A + A.T.conj()  # creates a Hermitian matrix
>>> Apinv = torch.linalg.pinv(A, hermitian=True)
>>> torch.dist(Apinv @ A, torch.eye(3))
tensor(1.0830e-06)

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源