• 文档 >
  • torch >
  • torch.lu
快捷键

torch.lu

torch.lu(*args, **kwargs)[source]

计算矩阵或矩阵批次的 LU 分解。返回包含 LU 分解和主元的元组。如果设置 pivotTrue ,则进行主元操作。

警告

torch.lu() 已被 torch.linalg.lu_factor()torch.linalg.lu_factor_ex() 取代。 torch.lu() 将在未来的 PyTorch 版本中删除。 LU, pivots, info = torch.lu(A, compute_pivots) 应该替换为

LU, pivots = torch.linalg.lu_factor(A, compute_pivots)

应将 LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True) 替换为

LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)

注意

  • 每个矩阵的返回置换矩阵由一个大小为 min(A.shape[-2], A.shape[-1]) 的 1 索引向量表示。 pivots[i] == j 表示在算法的第 i 步中,第 i 行与第 j-1 行进行了置换。

  • pivot = False 时,CPU 上不可用 LU 分解,尝试这样做将引发错误。然而,当 pivot = False 时,CUDA 上可用的 LU 分解。

  • 如果 get_infosTrue ,则此函数不会检查分解是否成功,因为分解的状态包含在返回元组的第三个元素中。

  • 在 CUDA 设备上,对于大小小于或等于 32 的平方矩阵批次,由于 MAGMA 库中的错误(参见 magma 问题 13),LU 分解会重复进行奇异矩阵的分解。

  • LUP 可以通过 torch.lu_unpack() 推导出来。

警告

A 满秩时,此函数的梯度才会是有限的。这是因为 LU 分解仅在满秩矩阵时才可微。此外,如果 A 接近非满秩,由于它依赖于 L1L^{-1}U1U^{-1} 的计算,梯度将数值上不稳定。

参数:
  • A(张量)- 要分解的张量,大小为 (,m,n)(*, m, n)

  • pivot(布尔值,可选)- 控制是否进行转置。默认: True

  • get_infos(布尔值,可选)- 如果设置为 True ,则返回一个信息 IntTensor。默认: False

  • out(元组,可选)- 可选输出元组。如果 get_infosTrue ,则元组中的元素是 Tensor、IntTensor 和 IntTensor。如果 get_infosFalse ,则元组中的元素是 Tensor、IntTensor。默认: None

返回值:

包含张量的元组

  • (Tensor):大小为 (,m,n)(*, m, n) 的分解

  • (IntTensor):大小为 (,min(m,n))(*, \text{min}(m, n)) 的置换。 pivots 存储所有中间行置换。通过将 swap(perm[i], perm[pivots[i] - 1]) 应用于 i = 0, ..., pivots.size(-1) - 1 ,可以重建最终的置换 perm ,其中 perm 是初始的 mm 元素的单位置换(这实际上就是 torch.lu_unpack() 所做的事情)。

  • (IntTensor, 可选):如果 get_infosTrue ,则这是一个大小为 ()(*) 的张量,其中非零值表示矩阵或每个 minibatch 的分解是否成功或失败

返回类型:

(Tensor, IntTensor, IntTensor (可选))

示例:

>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots = torch.lu(A)
>>> A_LU
tensor([[[ 1.3506,  2.5558, -0.0816],
         [ 0.1684,  1.1551,  0.1940],
         [ 0.1193,  0.6189, -0.5497]],

        [[ 0.4526,  1.2526, -0.3285],
         [-0.7988,  0.7175, -0.9701],
         [ 0.2634, -0.9255, -0.3459]]])
>>> pivots
tensor([[ 3,  3,  3],
        [ 3,  3,  3]], dtype=torch.int32)
>>> A_LU, pivots, info = torch.lu(A, get_infos=True)
>>> if info.nonzero().size(0) == 0:
...     print('LU factorization succeeded for all samples!')
LU factorization succeeded for all samples!

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源