torch.nn.utils.parametrizations.spectral_norm¶
- torch.nn.utils.parametrizations.spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)[source][source]¶
将频谱归一化应用于给定模块中的参数。
当应用于向量时,它简化为
频谱归一化通过减少模型的 Lipschitz 常数来稳定生成对抗网络(GANs)中判别器(批评家)的训练。在每次访问权重时,通过执行一次幂方法迭代来近似 。如果权重张量的维度大于 2,则在幂迭代方法中将它重塑为 2D 以获得频谱范数。
请参阅针对生成对抗网络的频谱归一化。
注意
此函数使用
register_parametrization()
中的参数化功能实现。它是torch.nn.utils.spectral_norm()
的重新实现。注意
当此约束被注册时,与最大的奇异值关联的单个向量是估计的,而不是随机采样。然后,在每次使用模块在训练模式下访问张量时,都会执行
n_power_iterations
的幂方法来更新。注意
如果_SpectralNorm 模块(即 module.parametrization.weight[idx])在移除时处于训练模式,它将执行另一个幂迭代。如果您想避免此迭代,请在移除之前将模块设置为评估模式。
- 参数:
module (nn.Module) – 包含模块
name (str, optional) – 权重参数的名称。默认:
"weight"
。n_power_iterations (int, 可选) – 计算谱范数的幂迭代次数。默认:
1
。eps (float, 可选) – 计算范数时的数值稳定性 epsilon。默认:
1e-12
。dim (int, 可选) – 与输出数量对应的维度。默认:
0
,但对于 ConvTranspose{1,2,3}d 的实例模块,则为1
。
- 返回值:
已注册到指定权重的新的参数化方式的原模块
- 返回类型:
示例:
>>> snm = spectral_norm(nn.Linear(20, 40)) >>> snm ParametrizedLinear( in_features=20, out_features=40, bias=True (parametrizations): ModuleDict( (weight): ParametrizationList( (0): _SpectralNorm() ) ) ) >>> torch.linalg.matrix_norm(snm.weight, 2) tensor(1.0081, grad_fn=<AmaxBackward0>)