快捷键

RMSNorm

class torch.nn.modules.normalization.RMSNorm(normalized_shape, eps=None, elementwise_affine=True, device=None, dtype=None)[source][source]

在输入的小批量上应用均方根层归一化。

此层实现了论文《均方根层归一化》中描述的操作。

yi=xiRMS(x)γi,whereRMS(x)=ϵ+1ni=1nxi2y_i = \frac{x_i}{\mathrm{RMS}(x)} * \gamma_i, \quad \text{where} \quad \text{RMS}(x) = \sqrt{\epsilon + \frac{1}{n} \sum_{i=1}^{n} x_i^2}

RMS 是在最后 D 个维度上进行的,其中 Dnormalized_shape 的维度。例如,如果 normalized_shape(3, 5) (一个二维形状),则 RMS 是在输入的最后 2 个维度上计算的。

参数:
  • normalized_shape (int 或 list 或 torch.Size) –

    输入形状来自期望的输入大小

    [×normalized_shape[0]×normalized_shape[1]××normalized_shape[1]][* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] \times \ldots \times \text{normalized\_shape}[-1]]

    如果使用单个整数,则将其视为单例列表,此模块将在此特定大小上归一化最后一个维度,该维度预期为该特定大小。

  • eps(可选[浮点])- 为数值稳定性添加到分母的值。默认: torch.finfo(x.dtype).eps()

  • elementwise_affine(布尔值)- 当设置为 True 时,此模块具有可学习的每个元素仿射参数,初始化为 1(对于权重)。默认: True

形状:
  • 输入: (N,)(N, *)

  • 输出: (N,)(N, *) (与输入形状相同)

示例:

>>> rms_norm = nn.RMSNorm([2, 3])
>>> input = torch.randn(2, 2, 3)
>>> rms_norm(input)
extra_repr()[source][source]

关于该模块的额外信息。

返回类型:

str

forward(x)[source][source]

前向传播。

返回类型:

张量

reset_parameters()[source][source]

根据在 __init__ 中使用的初始化重置参数。


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源