在这篇博客中,我们讨论了如何使用 PyTorch 原生优化(如原生快速内核、从 torch compile 编译的转换和用于分布式推理的张量并行)来提高 Llama 2 系列模型的推理延迟。我们的方法使得在 70B LLaMa 模型上(在 8 个 A100 GPU 上测量)的单用户请求延迟达到 29ms/token。我们很高兴与社区分享我们的发现,并将我们的代码在此处提供。
背景
我们正处于一场生成式 AI 革命之中,数十亿参数的大型语言模型正变得商品化并可供使用。然而,社区普遍认为,以成本效益的方式部署这些大型模型仍然是一个关键挑战。已经尝试了多种不同的方法,程度各异,提供了不同的权衡。针对特定硬件的优化(例如,NVIDIA 的 Faster Transformer)仅限于特定目标硬件,而依赖于抽象层的方法(例如,ONNX)可以实现任意模型,但会损失效率。去年,随着 PyTorch compile 的推出,IBM 和 PyTorch 团队开始探索使用模型编译进行推理优化,目标是降低生成模型的每 token 延迟。
模型选择
由于 Llama 2 系列模型广受欢迎,我们选择对其进行基准测试。我们感兴趣的模型及其与本文相关的超参数如下表所示:
模型大小 | 隐藏维度 | 头数 | 层数 | 注意力类型 |
7B | 4096 | 32 | 32 | MHA |
13B | 5120 | 40 | 40 | MHA |
70B | 8192 | 64 | 80 | GQA |
这些模型仅包含解码器,这意味着标记以序列化的方式生成,通常使用 KV 缓存来加速。我们在延迟和吞吐量测量中也采取了类似的方法。
推理方法
我们对推理的目标是提供一个快速实现最佳可能延迟的路径,以跟上社区中新模型架构出现的速度。使用 PyTorch 原生方法很有吸引力,因为它在“覆盖”模型方面提供了最大的灵活性。我们注意到有四种正交技术可以加速推理:(a) 使用 compile 进行内核融合,(b) 更快的内核,(c) 张量并行用于更大模型,以及(d) 量化。在我们的方法中,我们使用了这四种杠杆中的前三种 - 与 SDPA 中的更快内核和自定义张量并行实现一起工作的本地编译,所有这些协同工作,在 8 个 NVIDIA A100 GPU 上以单用户模式测量,实现了 29ms/标记的推理延迟。
全程编译!
PyTorch 编译利用追踪和图捕获来降低 CPU 开销,在理想情况下,从 CPU 到 GPU 的结果是单个图执行/指令。然而,由于模型架构和编译不支持的操作,编译通常会引入图断开。例如,像 einops 这样的复杂操作今天并不支持编译。同样,张量并行推理可以在每一层引入图断开,因为编译需要张量并行实现使用可追踪的通信集合。如果这些图断开没有被移除,编译后实体的性能将会受到影响,甚至可能低于急切模式的执行。为了充分利用编译后的实体,需要移除图断开。
下面,我们将描述我们如何为 70b Llama 2 模型进行这项工作,以及我们克服的挑战,以使编译能够完全工作。
我们第一次尝试使用 torch.compile 编译预置的 Llama 2 模型,但失败了,因为不支持复杂操作。通过设置 TORCH_COMPILE_DEBUG = 1,我们发现了 RoPE 位置编码使用了复数函数,导致图断开和显著降速。我们重写了 RoPE 函数,绕过 torch.einsum(原始实现使用 torch.polar,也会与编译冲突),并使用 torch.cos 和 torch.sin 代替。
self.cached_freqs[dev_idx][alpha] = torch.stack(
[
torch.cos(freqs),
-torch.sin(freqs),
torch.sin(freqs),
torch.cos(freqs),
],
dim=2,
).view(*freqs.shape, 2, 2)
我们实现的频率计算
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
Hugging Face 的频率计算实现
修复 RoPE 后,我们能够在单个 A100 GPU 上成功编译 7B 和 13B 模型,没有任何图断开。
我们使用了 SDPA,这是 PyTorch 的本地实现,具有启用跟踪的高效注意力计算(用于编译)。为了避免使用 Python 上下文强制选择单个算法导致的图断裂问题,我们不得不使用 torch.backends.cuda.enable_*_sdp
函数。
attn = torch.nn.functional.scaled_dot_product_attention(
queries,
keys_e,
values_e,
attn_mask=attn_mask,
dropout_p=self.p_dropout if self.training else 0.0,
is_causal=is_causal_mask,
)
使用 SDPA 进行注意力计算
接下来我们对更大的 70B 模型执行了相同的步骤,发现即使使用半精度,该模型也无法适应单个 GPU,需要张量并行推理。使用 torch.compile 对 70B 模型进行编译导致了 162 个图断裂,这是由于每层有两个 all-reduces 操作,一个用于前向嵌入的全局收集,另一个用于反向嵌入的全局收集。因此,我们没有看到推理延迟的显著提升。在撰写这篇博客时,我们无法使用 PyTorch 的分布式张量实现,因为它不支持 compile。我们从头开始重写了张量并行代码,使其仅依赖于可追踪的集体操作,以便与 compile 一起工作。在这最后一次更改之后,PyTorch 编译器没有引入任何图断裂,我们在推理延迟方面看到了显著的加速。具体来说,我们测量了使用 8 个 A100 GPU 时 Llama 70B 模型的延迟为 29ms/标记,比未优化的推理提高了 2.4 倍。
服务方面
最后,这里有一个需要注意的点,那就是仅仅对模型进行编译在生产环境中并不足够。为了实现上述性能并具有高吞吐量,我们需要支持动态批处理、嵌套张量,以及预热阶段,在这个阶段我们为分桶序列长度预编译。我们正在努力实现这些方面,以便在生产环境中实现这样的性能。
实验和测量
我们使用带有 8 个 A100 NVIDIA GPU 和 80G 显卡的节点,在两个不同的环境中(IBM Cloud 和 AWS,均运行 OpenShift)进行所有测量。首先,我们比较各种技术——急切模式、带有 SDPA Flash 内核、编译以及编译和 SDPA。对于 70B 模型,我们以编译和 SDPA 的方式在 Tensor Parallel 模式下运行它。对于这个实验,我们使用 512 个 token 作为输入长度,生成 50 个 token。对于 7B 和 13B 模型,我们使用单个 A100 来测量延迟,而对于 70B 模型,我们使用 8 个 A100。此外,对于 70B 模型,我们使用 PyTorch 编译中的 reduce-overhead 选项,该选项使用 CudaGraphs 来减少 CPU 到 GPU 内核启动的开销;在 7B 和 13B 模型中使用 CudaGraphs 没有显示出任何好处(因此在此未报告)。从图 1 中我们可以观察到,编译和 SDPA 提供了非常低的延迟,70B Llama 2 模型在 29ms/token。
图 1:不同技术下序列长度为 512 的中间延迟(在 IBM Cloud A100 服务器上测量)
接下来,我们研究序列长度的影响,将序列长度从 1024 增加到 4096,观察到每个标记的中位延迟呈亚线性增长,这表明当我们增加大文档的上下文时,我们不会牺牲响应时间。
图 2:不同序列长度下编译+SDPA 的中位延迟(在 AWS 的 A100s 上测量)
最后,随着批量大小的增加,我们观察到响应延迟呈亚线性增长。对于 13B 模型,在批大小为 8 时,我们遇到了内存不足错误。对于 70B 模型,由于它在 8 个 GPU 上运行,并且使用了张量并行,我们没有看到任何这样的内存不足错误。
图 3:在批大小和序列长度固定为 4096 的情况下编译+SDPA 的中位延迟(在 AWS 的 A100s 上测量)
总结
我们已经展示了如何使用 PyTorch 编译路径进行推理,实现了 70B 模型推理的超低延迟。下一步是启用动态批处理和嵌套张量。
特别感谢来自 PyTorch 团队的 Edward Yang、Elias Ellison、Driss Guessous、Will Feng、Will Constable、Horace He、Less Wright 和 Andrew Gu,他们的 PR 审查和代码贡献使得我们能够使用 PyTorch 原生方法实现延迟。我们感谢不断努力使 PyTorch 变得更好的更广泛的 PyTorch 团队,特别感谢 SDPA 团队使跟踪和编译在快速内核上成为可能,编译团队在如何处理和修复问题(包括在 CUDA 图中识别和报告 NVIDIA 驱动程序错误)方面给予了我们密切的指导。
推理延迟一直是LLM在关键企业工作流程中应用的一个障碍,但另一个主要障碍是需要安全性、可靠性和治理。IBM 的 AI 安全指南和LLM风险指南可在此找到,Meta 的 LLaMa 负责任用户指南可在此找到。
参考文献列表
- GitHub 资源:https://ibm.biz/fm-stack
- 使用 PyTorch/XLA 实现 LLaMa 65B 超低推理延迟的路径
- 速度、Python:只能选两个。CUDA Graph 如何实现深度学习快速 Python 代码
- IBM 关于 AI 伦理和信任的资源:https://www.ibm.com/downloads/cas/E5KE5KRZ
- Meta LLaMa 负责任用户指南:https://ai.meta.com/llama/responsible-use-guide/