TL;DR. 我们展示了如何使用加速 PyTorch 2.0 Transformer 和新引入的 torch.compile()
方法来加速大型语言模型,以 nanoGPT 为例,nanoGPT 是 Andrej Karpathy 的 GPT 模型的一个紧凑的开源实现。使用与加速 PT2 Transformer 一起引入的新缩放点积注意力算子,我们选择了 flash_attention 自定义内核,实现了每批次的更快训练时间(使用 Nvidia A100 GPU 测量),从 ~143ms/批次的基础值降低到 ~113 ms/批次。此外,使用 SDPA 算子增强的实现提供了更好的数值稳定性。最后,通过使用填充输入进一步优化,与 flash attention 结合后,每批次时间降低到 ~87ms。
近年来,大型语言模型(LLMs)和生成式 AI 在日常生活中得到了指数级的采用。与这些不断增长模型紧密相连的是不断增长的训练成本——无论是时间还是硬件利用率。PyTorch 团队直面这些挑战,推出了加速 PyTorch 2 Transformers(之前称为“更好的 Transformer”)和 PyTorch 2.0 中的即时编译(JIT Compilation)。
在本文中,我们探讨了通过使用 SDPA(也称为缩放点积注意力)的定制内核实现来获得的训练优化,SDPA 是 Transformer 模型中的一个关键层。定制的 SDPA 内核将几个离散的顺序操作替换为一个全局优化的内核,从而避免了分配大量中间 CUDA 内存。这种方法具有许多优点,包括但不限于:通过减少内存带宽瓶颈来提高 SDPA 的计算性能,减少内存占用以支持更大的批量大小,以及通过预缩放输入张量来提高数值稳定性。这些优化在 nanoGPT 上得到了演示,它是 Andrej Karpathy 的开源 GPT 实现。
背景
扩展点积注意力是多头注意力的基本构建块,如“Attention is All You Need”中所述,在LLM和生成式 AI 模型中具有广泛的应用。
图 1:基于“Attention is All You Need”的 Transformer 模型架构。利用新的 PyTorch SDPA 算子,通过线性层进行内投影、SDPA 算子和线性层进行外投影,高效地实现了多头注意力。
利用新的 scaled_dot_product_attention 算子,多头注意力可以通过以下 3 个步骤实现:使用线性层进行内投影、SDPA 和使用线性层进行外投影。
# In Projection
# variable descriptions:
# q,k,v = Query, Key, Value tensors
# bsz = batch size
# num_heads = Numner of heads for Multihead Attention
# tgt_len = Target length
# src_len = Source Length
# head_dim: Head Dimension
q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
q = q.view(bsz, num_heads, tgt_len, head_dim)
k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim)
# Scaled Dot Product Attention
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
# Out Projection
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
PyTorch 2.支持针对特定用例优化的多种不同内核,具有特定要求。内核选择器会为特定输入参数组合选择最佳内核。如果无法识别针对特定输入参数组合的优化“自定义内核”,内核选择器将选择可以处理所有输入组合的通用内核。
虽然未来的版本可能会扩展此操作符集,但 PyTorch 2.0 的发布包含了 SDPA 操作符的 3 种实现:
- 一个通用的内核,实现了 SDPA 的数学方程式,函数为
sdpa_math()
- 基于论文“Flash Attention”的优化内核,支持在计算架构 SM80(A100)上使用 16 位浮点数据类型评估 SDPA。
- 基于论文“Self-Attention Does Not Need O(n^2) Memory”并实现于 xFormer 的优化内核,支持在更广泛的架构(SM40 及以后)上使用 32 位和 16 位浮点数据类型。本博客文章将此内核称为
mem_efficient
内核。
注意,上述列出的优化内核(两个和三个),都支持关键填充掩码并限制支持的注意力掩码为因果注意力。今天加速 PyTorch 2.0 Transformer 仅支持当使用 is_causal
布尔值指定时才支持因果掩码。当指定掩码时,将选择通用内核,因为分析提供的掩码内容以确定它是否是因果掩码的成本太高。有关每个内核的约束条件的更多解释,请参阅加速 PT2 Transformer 博客。
启用 nanoGPT 加速 Transformer
SDPA 运算符是 GPT 模型的关键组件,我们确定了开源 nanoGPT 模型作为展示 PyTorch 2.0 加速 Transformer 实现简便性和益处的优秀候选者。以下展示了在 nanoGPT 上启用加速 Transformer 的确切过程。
这个过程主要围绕用功能模块中的新添加的 F.scaled_dot_product_attention 操作符替换现有的 SDPA 实现。这个过程可以轻松地适应,以便在许多其他 LLMs 中启用该操作符。或者,用户可以选择在适用的情况下调用 F.multi_head_attention_forward() 或直接使用 nn.MultiHeadAttention 模块。以下代码片段改编自 Karpathy 的 nanoGPT 仓库。
第一步:识别现有的 SDPA 实现
在 nanoGPT 的情况下,SDPA 实现在模型的 CausalSelfAttention 类中。以下是本文中使用的原始实现。
第二步:用 PyTorch 的 scaled_dot_product_attention 替换
在这一点上,我们可以注意以下:
- 第 36 至 42 行定义了 SDPA 的数学实现,我们将对其进行替换
- 第 39 行应用的面罩现在不再相关,因为我们使用了 scaled_dot_product_attention 的
is_causal
标志。 - 第 41 行使用的 dropout 层现在也变得不再必要。
将 SDPA 实现替换为 torch 的 scaled_dot_product_attention 并移除现在冗余的代码,得到以下实现。
或者,可以将原始掩码传递到 attn_mask
字段,但由于提到的内核约束,这将限制实现仅支持通用的 sdpa_math
内核。
第 3 步(加分项):使用填充加快 matmuls
在 SDPA 的性能改进的基础上,我们的分析还带来了一些额外的收获。用 Andrej 的话说:“迄今为止对 nanoGPT 最显著的优化(约 25%的速度提升)就是简单地增加词汇量从 50257 到 50304(64 的最近倍数)。”
词汇量大小决定了 GPT 输出层中 matmuls 的维度,这些维度非常大,以至于它们占据了整个训练循环的大部分时间!我们发现它们的性能远低于 A100 GPU 可实现的峰值吞吐量,并从 NVIDIA 的 matmul 文档中猜测 64 元素对齐将产生更好的结果。确实,对这些 matmuls 进行填充几乎实现了 3 倍的速度提升!其根本原因是未对齐的内存访问会显著降低效率。更深入的分析可以在这个 Twitter 帖子中找到。
通过这次优化,我们进一步将训练时间从大约 113 毫秒(使用 flash attention)降低到每批约 87 毫秒。
结果
下图展示了使用 Pytorch 自定义内核获得的效果。以下是具体的数字:
- 基准(nanoGPT 实现):约 143 毫秒
- sdpa_math (通用): ~134ms (快 6.71%)
-
mem_efficient
内核: ~119ms (快 20.16%) -
flash_attention
内核: ~113ms (快 26.54%) - flash_attention + 填充词汇表: ~87ms (快 64.37%)
所有代码均在 8 x NVIDIA 公司 A100 服务器(A100 SXM4 80GB,80 GB HBM)上运行,并且为了本次实验,dropout 被设置为 0。
图 2:使用缩放点积注意力机制和自定义内核以及 torch.compile 可以为训练大型语言模型(如本例中的 nanoGPT)带来显著的加速。
提高数值模型稳定性
除了速度更快之外,PyTorch 的实现通过避免许多执行场景中的精度损失,提供了更高的数值稳定性。这里有一个很好的解释,但基本上 PyTorch 实现是在乘法之前对 Query 和 Key 矩阵进行缩放,这被认为更加稳定并避免了精度损失。由于 SDPA 的合并自定义内核架构,这种缩放不会在计算注意力结果时引入额外的开销。相比之下,来自单个计算组件的实现将需要额外的预缩放,这将产生额外的成本。有关更多解释,请参阅附录 A。
改进的内存消耗
使用 torch SDPA 内核的另一个显著优势是减少内存占用,这允许使用更大的批量大小。以下图表比较了两种实现(闪存注意力与因果注意力基线)在训练一小时后的最佳验证损失。如图所示,基线因果注意力实现(在 8 x NVIDIA Corporation A100 服务器上,80 GB HBM)达到的最大批量大小为 24,远低于闪存注意力实现达到的 39。
图 3:使用闪存注意力可以启用更大的批量大小,使用户在训练一小时后达到更低的验证损失(越小越好)。
结论
加速 PyTorch 2 Transformer 旨在使最先进的 Transformer 模型的训练和生产部署变得经济实惠,并与 PyTorch 2.0 模型 JIT 编译集成。新引入的 PyTorch SDPA 算子为训练 Transformer 模型提供了改进的性能,尤其是在昂贵的 Large Language Model 训练中特别有价值。在这篇文章中,我们展示了在 nanoGPT 模型上的一些优化,包括:
- 与基准模型相比,在保持批次大小不变的情况下,训练速度提高了 26%以上
- 通过填充词汇表实现的额外加速,使总优化效果达到基准的约 64%
- 额外的数值稳定性
附录 A:分析注意力数值稳定性
在本节中,我们提供了对之前提到的通过预缩放 SDPA 输入向量获得的增强数值稳定性的更深入解释。以下是对 nanoGPT 中 SDPA 数学实现的简化版本。这里需要注意的是,查询在进行矩阵乘法之前没有被缩放。
# nanoGPT implementation of SDPA
# notice q (our query vector) is not scaled !
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
# Dropout is set to 0, so we can safely ignore this line in the implementation# att = self.attn_dropout(att)
y_nanogpt = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
以下是在 torch 的 scaled_dot_product_attention
中的等效数学实现。
# PyTorch implementation of SDPA
embed_size = q.size(-1)
scaling_factor = math.sqrt(math.sqrt(embed_size))
q = q / scaling_factor # notice q _is_ scaled here !
# same as above, but with scaling factor
att = q @ (k.transpose(-2, -1) / scaling_factor)
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att0, dim=-1)
# Dropout is set to 0, so we can safely ignore this line in the implementation# att = self.attn_dropout(att)
y_scale_before = att @ v
从数学上讲,两种方法应该是等效的,然而我们的实验表明,在实践中,我们得到了两种方法的不同结果。
使用上述方法,我们验证了 y_scale_before
与使用 scaled_dot_product_attention
方法得到的预期输出匹配,而 y_nanogpt
则不匹配。
使用了 torch.allclose
方法来测试等价性。具体来说,我们展示了:
y_sdpa = torch.nn.functional._scaled_dot_product_attention(
q,
k,
v,
attn_mask=self.bias[:,:,:T,:T] != 0,
dropout_p=0.0,
need_attn_weights=False,
is_causal=False,
)
torch.allclose(y_sdpa, y_nanogpt) # False, indicating fp issues
torch.allclose(y_sdpa, y_scale_before) # True, as expected
附录 B:重现实验结果
寻求重现这些结果的学者应从 Andrej 的 nanoGPT 仓库的以下提交开始 - b3c17c6a363357623f223aaa4a8b1e89d0a465。这个提交被用作测量每批速度提升的基准。对于包含填充词汇优化(这带来了对批速度的最显著提升)的结果,请使用以下提交 - 77e7e04c2657846ddf30c1ca2dd9f7cbb93ddeab。从任一检查点,使用 torch.backends API 选择用于实验的内核变得非常简单。
期望的内核可以通过上下文管理器进行选择:
with torch.backends.cuda.sdp_kernel (
enable_math = False,
enable_flash = False,
enable_mem_efficient = True
):
train(model)