torch.linalg.pinv¶
- torch.linalg.pinv(A, *, atol=None, rtol=None, hermitian=False, out=None) Tensor ¶
计算矩阵的伪逆(摩尔-彭罗斯逆)
假逆可以代数定义,但更便于通过奇异值分解(SVD)来理解
支持输入 float、double、cfloat 和 cdouble 数据类型。也支持矩阵批处理,如果
A
是矩阵批,则输出具有相同的批维度。如果
hermitian
为真,则假设A
是复数时的厄米矩阵或实数时的对称矩阵,但内部并不进行检查。相反,仅使用矩阵的下三角部分进行计算。当
hermitian
为真时,低于 阈值的奇异值(或特征值的范数)被视为零并在计算中被丢弃,其中 是最大的奇异值(或特征值)。如果
rtol
未指定且A
是一个维度为(m,n)的矩阵,则相对容差设置为 , 是A
数据类型的 epsilon 值(参见finfo
)。如果rtol
未指定且atol
指定为大于零,则rtol
设置为零。如果
atol
或rtol
是torch.Tensor
,则其形状必须可以广播到由A
返回的单个值形状,该值由torch.linalg.svd()
返回。注意
此函数在
hermitian
为 False 时使用torch.linalg.svd()
,在hermitian
为 True 时使用torch.linalg.eigh()
。对于 CUDA 输入,此函数将设备与 CPU 同步。注意
如果可能,请考虑使用
torch.linalg.lstsq()
来乘以矩阵的左伪逆,如下所示:torch.linalg.lstsq(A, B).solution == A.pinv() @ B
当可能时,始终优先使用
lstsq()
,因为它比显式计算伪逆更快且更数值稳定。注意
此函数具有与 NumPy 兼容的变体 linalg.pinv(A, rcond, hermitian=False)。但是,使用位置参数
rcond
已被弃用,改用rtol
。警告
此函数内部使用
torch.linalg.svd()
(当hermitian
= True 时为torch.linalg.eigh()
),因此其导数与这些函数存在相同的问题。有关更多详细信息,请参阅torch.linalg.svd()
和torch.linalg.eigh()
中的警告。参见
torch.linalg.inv()
计算一个方阵的逆。torch.linalg.lstsq()
使用数值稳定的算法计算A
.pinv() @B
。- 参数:
A(张量)- 形状为(*, m, n)的张量,其中*表示零个或多个批处理维度。
rcond (float, Tensor, 可选) – [NumPy 兼容].
rtol
的别名。默认:None。
- 关键字参数:
atol (float, Tensor, 可选) – 绝对容差值。当为 None 时,视为零。默认:None。
rtol (float, Tensor, 可选) – 相对容差值。当为 None 时,其值见上文。默认:None。
hermitian (bool, 可选) – 表示
A
是否为复数时的 Hermitian 或实数时的对称。默认:False。out(张量,可选)- 输出张量。如果为 None 则忽略。默认:None。
示例:
>>> A = torch.randn(3, 5) >>> A tensor([[ 0.5495, 0.0979, -1.4092, -0.1128, 0.4132], [-1.1143, -0.3662, 0.3042, 1.6374, -0.9294], [-0.3269, -0.5745, -0.0382, -0.5922, -0.6759]]) >>> torch.linalg.pinv(A) tensor([[ 0.0600, -0.1933, -0.2090], [-0.0903, -0.0817, -0.4752], [-0.7124, -0.1631, -0.2272], [ 0.1356, 0.3933, -0.5023], [-0.0308, -0.1725, -0.5216]]) >>> A = torch.randn(2, 6, 3) >>> Apinv = torch.linalg.pinv(A) >>> torch.dist(Apinv @ A, torch.eye(3)) tensor(8.5633e-07) >>> A = torch.randn(3, 3, dtype=torch.complex64) >>> A = A + A.T.conj() # creates a Hermitian matrix >>> Apinv = torch.linalg.pinv(A, hermitian=True) >>> torch.dist(Apinv @ A, torch.eye(3)) tensor(1.0830e-06)