torch.linalg.householder_product¶
- torch.linalg.householder_product(A, tau, *, out=None) Tensor ¶
计算 Householder 矩阵乘积的前 n 列。
令 为 或 ,令 为一个矩阵,其列向量由 组成,对应于 ,其中 。记 为零化 的前 个分量并将第 个分量设为 1 得到的向量。对于具有 的向量 ,此函数计算矩阵的前 列
其中 是 m 维单位矩阵, 是复数时的共轭转置,当 是实数时为转置。输出矩阵与输入矩阵
A
大小相同。请参阅正交或酉矩阵的表示以获取更多详细信息。
支持浮点数、双精度浮点数、复浮点数和复双精度浮点数的数据类型输入。也支持矩阵批处理,如果输入是矩阵批处理,则输出具有相同的批处理维度。
参见
torch.geqrf()
可以与该函数一起使用,以从qr()
分解中形成 Q。torch.ormqr()
是一个计算 Householder 矩阵乘积与另一个矩阵矩阵乘积的相关函数。然而,该函数不支持 autograd。警告
只有当 条件满足时,梯度计算才有意义。如果这个条件不满足,不会抛出错误,但产生的梯度可能包含 NaN。
- 参数:
A(张量)- 形状为(*, m, n)的张量,其中*表示零个或多个批处理维度。
tau(张量)- 形状为(*, k)的张量,其中*表示零个或多个批处理维度。
- 关键字参数:
out(张量,可选)- 输出张量。如果为 None 则忽略。默认:None。
- 引发:
运行时错误 - 如果
A
不满足要求 m >= n,或者tau
不满足要求 n >= k。
示例:
>>> A = torch.randn(2, 2) >>> h, tau = torch.geqrf(A) >>> Q = torch.linalg.householder_product(h, tau) >>> torch.dist(Q, torch.linalg.qr(A).Q) tensor(0.) >>> h = torch.randn(3, 2, 2, dtype=torch.complex128) >>> tau = torch.randn(3, 1, dtype=torch.complex128) >>> Q = torch.linalg.householder_product(h, tau) >>> Q tensor([[[ 1.8034+0.4184j, 0.2588-1.0174j], [-0.6853+0.7953j, 2.0790+0.5620j]], [[ 1.4581+1.6989j, -1.5360+0.1193j], [ 1.3877-0.6691j, 1.3512+1.3024j]], [[ 1.4766+0.5783j, 0.0361+0.6587j], [ 0.6396+0.1612j, 1.3693+0.4481j]]], dtype=torch.complex128)