快捷键

torch.linalg.lstsq ¶ torch.linalg.最小二乘求解器

torch.linalg.lstsq(A, B, rcond=None, *, driver=None) ¶ torch.linalg.lstsq(A, B, rcond=None, *, driver=None) ¶

计算线性方程组的最小二乘问题的解。

K\mathbb{K}R\mathbb{R}C\mathbb{C} ,线性系统 AX=BAX = BAKm×n,BKm×kA \in \mathbb{K}^{m \times n}, B \in \mathbb{K}^{m \times k} 的最小二乘问题定义为

minXKn×kAXBF\min_{X \in \mathbb{K}^{n \times k}} \|AX - B\|_F

其中 F\|-\|_F 表示 Frobenius 范数。

支持浮点数、双精度浮点数、复浮点数和复双精度浮点数的数据类型输入。也支持矩阵批处理,如果输入是矩阵批处理,则输出具有相同的批处理维度。

driver 选择将要使用的后端函数。对于 CPU 输入,有效的值有‘gels’,‘gelsy’,‘gelsd’,‘gelss’。为了在 CPU 上选择最佳驱动器,请考虑:

  • 如果 A 是良态的(其条件数不是太大),或者你不在乎一些精度损失。

    • 对于一般矩阵:‘gelsy’(带置换的 QR 分解)(默认)

    • 如果 A 是满秩的:‘gels’(QR 分解)

  • 如果 A 条件不良。

    • ‘gelsd’(三对角化及奇异值分解)

    • 但如果您遇到内存问题:“gelss”(满秩 SVD)。

对于 CUDA 输入,唯一有效的驱动程序是“gels”,它假定 A 是满秩的。

参考这些驱动程序的完整描述。

rcond 用于确定 A 中矩阵的有效秩,当 driver 为(‘gelsy’,‘gelsd’,‘gelss’)之一时。在这种情况下,如果 σi\sigma_i 是 A 的奇异值,按降序排列, σi\sigma_i 将如果 σircondσ1\sigma_i \leq \text{rcond} \cdot \sigma_1 则向下取整为零。如果 rcond = None(默认),则 rcond 设置为 A 数据类型的机器精度乘以 max(m, n)。

此函数返回问题的解和一些额外信息,以四个张量(solution,residuals,rank,singular_values)组成的命名元组形式。对于输入 AB 的形状分别为(*, m, n),(*, m, k),它包含

  • solution:最小二乘解。其形状为(*, n, k)。

  • residuals:解的平方残差,即 AXBF2\|AX - B\|_F^2 。其形状为(*, k)。当 m > n 且 A 中的每个矩阵都是满秩时,它被计算出来;否则,它是一个空张量。如果 A 是一批矩阵,并且批中的任何矩阵不是满秩,则返回一个空张量。这种行为可能在未来的 PyTorch 版本中发生变化。

  • rank: A 中矩阵的秩张量。其形状等于 A 的批维度。当 driver 是('gelsy','gelsd','gelss')之一时,它被计算出来,否则它是一个空张量。

  • 矩阵 A 的奇异值张量。它具有形状 (*, min(m, n))。当 driver 为 ('gelsd', 'gelss') 之一时,它会被计算,否则为空张量。

注意

此函数以比单独执行计算更快、更数值稳定的方式计算 X = A .pinverse() @ B

警告

在未来的 PyTorch 版本中, rcond 的默认值可能会更改。因此,建议使用固定值以避免潜在的破坏性更改。

参数:
  • (张量) - 形状为 (*, m, n) 的 lhs 张量,其中 * 是零个或多个批处理维度。

  • B(张量)- 形状为(*, m, k)的 rhs 张量,其中*表示零个或多个批处理维度。

  • rcond(可选浮点数)- 用于确定 A 的有效秩。如果 rcond 为 None,则 rcond 设置为 A 数据类型的机器精度乘以 max(m, n)。默认:None。

关键字参数:

driver(可选字符串)- 要使用的 LAPACK/MAGMA 方法的名称。如果为 None,则对于 CPU 输入使用‘gelsy’,对于 CUDA 输入使用‘gels’。默认:None。

返回值:

命名元组(解,残差,秩,奇异值)。

示例:

>>> A = torch.randn(1,3,3)
>>> A
tensor([[[-1.0838,  0.0225,  0.2275],
     [ 0.2438,  0.3844,  0.5499],
     [ 0.1175, -0.9102,  2.0870]]])
>>> B = torch.randn(2,3,3)
>>> B
tensor([[[-0.6772,  0.7758,  0.5109],
     [-1.4382,  1.3769,  1.1818],
     [-0.3450,  0.0806,  0.3967]],
    [[-1.3994, -0.1521, -0.1473],
     [ 1.9194,  1.0458,  0.6705],
     [-1.1802, -0.9796,  1.4086]]])
>>> X = torch.linalg.lstsq(A, B).solution # A is broadcasted to shape (2, 3, 3)
>>> torch.dist(X, torch.linalg.pinv(A) @ B)
tensor(1.5152e-06)

>>> S = torch.linalg.lstsq(A, B, driver='gelsd').singular_values
>>> torch.dist(S, torch.linalg.svdvals(A))
tensor(2.3842e-07)

>>> A[:, 0].zero_()  # Decrease the rank of A
>>> rank = torch.linalg.lstsq(A, B).rank
>>> rank
tensor([2])

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源