由 Michael Gschwind、Driss Guessous、Christian Puhrsch 著

PyTorch 2.0 版本发布包括 PyTorch Transformer API 的新高性能实现,旨在使最先进的 Transformer 模型的训练和部署变得经济实惠。继“fastpath”推理执行(“更好的 Transformer”)成功发布之后,本版本引入了对使用自定义内核架构进行扩展点积注意力(SPDA)的训练和推理的高性能支持。

您可以通过直接调用新的 SDPA 运算符(如 SDPA 教程中所述)或通过集成到现有的 PyTorch Transformer API 中透明地利用新的融合 SDPA 内核。PyTorch Transformer API 的所有功能将继续兼容工作,许多功能映射到高性能 SDPA 内核,而其他功能可能无法支持更高性能(例如,需要_weights,如下所示),而其他功能的高性能支持可能仍在积极开发中。

与“fastpath”架构类似,自定义内核已完全集成到 PyTorch Transformer API 中——因此,使用原生 Transformer 和 MultiHeadAttention API 将使用户能够透明地看到显著的性能提升。与“fastpath”架构不同,新引入的“自定义内核”支持更多用例,包括使用交叉注意力、Transformer 解码器以及用于训练模型,除了现有的针对固定和可变序列长度 Transformer 编码器和自注意力用例的快速路径推理之外。

为了充分利用不同的硬件模型和 Transformer 使用场景,支持多个 SDPA 自定义内核,并具有自定义内核选择逻辑,该逻辑将根据特定模型和硬件类型选择最高性能的内核。特别是,PyTorch 2.0 版本中包含的第一个自定义内核是 Flash Attention 内核(sdpa_flash,用于 Nvidia GPU 的 SM80+架构级别的 16 位浮点训练和推理)和 xFormers 内存高效注意力内核(sdpa_mem_eff,用于 Nvidia GPU 的 16 位和 32 位浮点训练和推理)。当自定义内核不适用时,通用内核 sdpa_math 提供了一种实现方式。

如前所述,自定义内核提供了更广泛的执行场景支持。为确保高效执行(例如,使用 GPU 张量核心),模型配置需要满足少量要求。随着时间的推移,此要求列表将不断演变,未来可能会放宽限制当前支持的自定义内核使用的约束,或提供额外的内核。

关于自定义内核和调度约束的最新列表,您可以参考 sdp_utils.h。截至 PyTorch 2.0 版本,现有的融合 SDPA 内核有以下约束:

  • Flash Attention 仅支持 16 位浮点数据类型(float16 和 bfloat16)。
  • 头维度必须是 8 的倍数,用于 16 位浮点数,以及 4 的倍数,用于 32 位浮点数。目前,Flash Attention 自定义内核支持的最大 head_dim 为 128。
  • 对于 mem_efficient 内核,CUDA 架构级别必须是 sm5x 或更高,对于 Flash Attention,必须是 sm80。
  • Flash Attention 支持任意 dropout,在 PyTorch 2.0 中,mem_efficient 内核不支持 dropout(即,为了选择此内核,dropout 必须设置为 0)。
  • 为了支持可变序列长度的批次,所有 SDPA 内核都支持使用可变序列长度张量组合输入数据和填充信息的前向 Nested Tensor 输入。(您可以在 Nested Tensor 教程中找到更多关于 Nested Tensor 的信息。)
  • 您可以在将它们传递给 SDPA 操作符之前将 key_padding_mask 和 attn_mask 合并来指定两者。特别是,您可以使用 nn.Transformer API 的每个批次元素的 key padding mask 来实现支持可变序列长度输入的批次的训练。
  • 目前,融合内核实现仅支持用于训练的常见因果掩码。要在自定义内核中指定因果掩码,必须使用 is_causal 布尔值指定,并且 attn_mask 必须为 None。
  • 支持嵌套张量仍在开发中。具体来说,在 PyTorch 2.0 中,只有 sdpa_math 内核支持使用嵌套张量进行训练。此外,PyTorch 2.0 不支持将嵌套张量作为 torch.compile()编译的代码的一部分。
  • SDPA 算子不支持返回平均注意力权重,因为计算这些权重会破坏使融合内核能够更高效执行的优化。torch.nn.MultiheadAttention 的前向函数中的 need_weights 参数默认为 True。为了使用融合内核,需要将 need_weights 设置为 need_weights=False。

我们发现注意力掩码在现实世界中的应用很少,除了训练期间的因果掩码。因此,我们通过内置使用因果掩码作为注意力掩码的选项来降低内核复杂度和计算成本,并通过与新的 SDPA 算子一起引入的 is_causal 参数选择这一新功能。

提供常用的因果掩码的 is_causal 布尔标志,也消除了昂贵的内存密集型因果掩码分配,通过允许更多内存用于大批次大小,提高了训练内存效率,并通过不需要加载注意力掩码张量,减少了内存带宽和缓存竞争——这在 GPU 加速器中都是宝贵的资源。

如果不满足任何可用的自定义内核的约束,则训练将回退到使用默认的 sdpa_math 内核,通过一系列 PyTorch 操作实现 SDPA,实现缩放点积注意力的数学方程。这是最通用的“万用”回退内核,以确保所有模型的训练成功。

除了现有的 Transformer API 之外,模型开发者还可以通过调用新的 scaled_dot_product_attention() 操作符直接使用缩放点积注意力内核。该操作符可以与内投影和投影相结合,以高效实现多头注意力,如 SDPA 教程中所述。

除了添加自定义内核外,加速 PyTorch 2 Transformer 已与 PyTorch 2.0 编译集成。为了在使用模型的同时,从 PT2 编译的额外加速中受益(用于推理或训练),请对模型进行预处理。

model = torch.compile(model)

我们使用自定义内核和 torch.compile() 的组合,在加速 PyTorch 2 Transformer 中实现了训练 Transformer 模型,特别是大型语言模型的重大加速。

Better Transformer chart 图:使用缩放点积注意力机制和自定义内核以及 torch.compile(),为训练大型语言模型(如此处所示的 nanoGPT)提供了显著的加速。

最后,由于自定义内核的内存效率更高,尝试增加训练批的大小,以实现更快训练并提高批大小。

除了自动内核选择之外,上下文管理器还允许开发者覆盖内核选择算法——这并非日常操作所必需,但可以让开发者调试代码,同时也让性能工程师能够覆盖内核选择。SDPA 教程提供了有关使用 SDPA 上下文管理器的更多信息。

除了作为 nn.Transformer API 的一部分提供可用性之外,加速 PyTorch 2 Transformer 自定义内核还与 torchtext、torchvision 和 fairseq 领域库一起在 PyTorch 2.0 的发布时提供。