• 文档 >
  • torch.nn >
  • torch.nn.utils.parametrizations.orthogonal
快捷键

torch.nn.utils.parametrizations.orthogonal

torch.nn.utils.parametrizations.orthogonal(module, name='weight', orthogonal_map=None, *, use_trivialization=True)[source][source]

将正交或酉参数化应用于矩阵或矩阵批。

K\mathbb{K}R\mathbb{R}C\mathbb{C} ,则参数矩阵 QKm×nQ \in \mathbb{K}^{m \times n} 是正交的,

QHQ=Inif mnQQH=Imif m<n\begin{align*} Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\ QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n} \end{align*}

其中 QHQ^{\text{H}} 是当 QQ 为复数时的共轭转置,当 QQ 为实值时为转置, In\mathrm{I}_n 是 n 维单位矩阵。简单来说, QQ 将具有正交归一列当 mnm \geq n ,否则具有正交归一行。

如果张量具有超过两个维度,我们将其视为形状为(…,m,n)的矩阵批。

矩阵 QQ 可以通过三种不同的 orthogonal_map 在原始张量的基础上进行参数化:

  • "matrix_exp" / "cayley" : 对一个反对称 matrix_exp() 应用 Q=exp(A)Q = \exp(A) 和凯莱映射 Q=(In+A/2)(InA/2)1Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1} ,得到一个正交矩阵。

  • "householder" : 计算豪斯霍尔德反射器的乘积( householder_product() )。

"matrix_exp" / "cayley" 通常比 "householder" 更快地使参数化权重收敛,但它们在非常薄或非常宽的矩阵上计算较慢。

如果 use_trivialization=True (默认),参数化实现了“动态平凡化框架”,其中在 module.parametrizations.weight[0].base 下存储一个额外的矩阵 BKn×nB \in \mathbb{K}^{n \times n} 。这有助于参数化层的收敛,但会牺牲一些额外的内存使用。参见基于梯度的流形优化中的平凡化。

初始值 QQ :如果原始张量未参数化且 use_trivialization=True (默认),则 QQ 的初始值与原始张量相同,如果它是正交的(或复数情况下的酉矩阵),否则通过 QR 分解进行正交化(见 torch.linalg.qr() )。同样,当它未参数化且 orthogonal_map="householder" 时,即使 use_trivialization=False 也是如此。否则,初始值是所有已注册参数化应用于原始张量的组合结果。

注意

此函数使用 register_parametrization() 中的参数化功能实现。

参数:
  • 模块(nn.Module)- 在其上注册参数化的模块。

  • 名称(str,可选)- 要使其正交的张量的名称。默认: "weight"

  • 正交映射(字符串,可选)- 以下选项之一: "matrix_exp""cayley""householder" 。默认:如果矩阵是方阵或复数,则为 "matrix_exp" ,否则为 "householder"

  • use_trivialization(布尔值,可选)- 是否使用动态平凡化框架。默认: True

返回值:

原始模块,具有注册到指定权重的正交参数化

返回类型:

模块

示例:

>>> orth_linear = orthogonal(nn.Linear(20, 40))
>>> orth_linear
ParametrizedLinear(
in_features=20, out_features=40, bias=True
(parametrizations): ModuleDict(
    (weight): ParametrizationList(
    (0): _Orthogonal()
    )
)
)
>>> Q = orth_linear.weight
>>> torch.dist(Q.T @ Q, torch.eye(20))
tensor(4.9332e-07)

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源