torch.nn.utils.spectral_norm¶
- torch.nn.utils.spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)[source][source]¶
将频谱归一化应用于给定模块中的参数。
频谱归一化通过使用幂迭代法计算权重矩阵的频谱范数来调整权重张量,从而稳定生成对抗网络(GANs)中判别器(critics)的训练。如果权重张量的维度大于 2,则在幂迭代法中将其重塑为 2D 以获取频谱范数。这通过一个钩子实现,该钩子在每次
forward()
调用之前计算频谱范数并调整权重。请参阅针对生成对抗网络的频谱归一化。
- 参数:
module (nn.Module) – 包含模块
name(str,可选)- 权重参数的名称
n_power_iterations (int, 可选) – 计算谱范数的幂迭代次数
eps (float, 可选) – 计算范数时的数值稳定性 epsilon
dim (int, 可选) – 与输出数量对应的维度,默认为
0
,除非是 ConvTranspose{1,2,3}d 模块的实例,此时为1
- 返回值:
带有谱范数钩子的原始模块
- 返回类型:
模块 T
注意
此功能已重新实现为
torch.nn.utils.parametrizations.spectral_norm()
,使用了torch.nn.utils.parametrize.register_parametrization()
中的新参数化功能。请使用较新版本。此功能将在 PyTorch 的将来版本中弃用。示例:
>>> m = spectral_norm(nn.Linear(20, 40)) >>> m Linear(in_features=20, out_features=40, bias=True) >>> m.weight_u.size() torch.Size([40])