torch.triangular_solve¶
- torch.triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None)¶
解具有平方上三角或下三角可逆矩阵 和多个右端项 的方程组。
在符号中,它解出 并假设 是上三角(或下三角,如果
upper
= False)且对角线上没有零。torch.triangular_solve(b, A) 可以接受 2D 输入 b, A 或 2D 矩阵批次的输入。如果输入是批次,则返回批处理输出 X
如果
A
的对角线包含零或非常接近零的元素,并且unitriangular
= False(默认)或输入矩阵条件差,则结果可能包含 NaN。支持浮点数、双精度浮点数、复浮点数和复双精度浮点数的数据类型输入。
警告
torch.triangular_solve()
已弃用,将替换为torch.linalg.solve_triangular()
并将在未来的 PyTorch 版本中删除。torch.linalg.solve_triangular()
的参数顺序已反转,并且不返回输入之一的一个副本。应将
X = torch.triangular_solve(B, A).solution
替换为X = torch.linalg.solve_triangular(A, B)
- 参数:
b(张量)- 大小为 的多个右手边,其中 是零个或多个批处理维度
A(张量)- 大小为 的输入三角形系数矩阵,其中 是零个或多个批处理维度
upper(布尔值,可选)- 是否为上三角或下三角。默认:
True
。transpose(布尔值,可选)- 如果此标志为
True
,则解 op(A)X = b,其中 op(A) = A^T;如果为False
,则 op(A) = A。默认:False
。单位三角(bool,可选)- 是否 为单位三角。如果为 True,则假定 的对角线元素为 1,并且不引用自 。默认:
False
。
- 关键字参数:
out ((Tensor, Tensor),可选)- 要写入输出的两个张量的元组。如果为 None 则忽略。默认:None。
- 返回值:
一个命名元组(solution,cloned_coefficient),其中 cloned_coefficient 是 的副本,solution 是 对 的解(或系统的任何变体,取决于关键字参数。)
示例:
>>> A = torch.randn(2, 2).triu() >>> A tensor([[ 1.1527, -1.0753], [ 0.0000, 0.7986]]) >>> b = torch.randn(2, 3) >>> b tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]]) >>> torch.triangular_solve(b, A) torch.return_types.triangular_solve( solution=tensor([[ 1.7841, 2.9046, -2.5405], [ 1.9320, 0.9270, -1.2826]]), cloned_coefficient=tensor([[ 1.1527, -1.0753], [ 0.0000, 0.7986]]))