由阿丹·侯克、莱斯·赖特、拉古·甘蒂和穆达卡·斯里瓦萨著

在本文中,我们讨论了使用 OpenAI 的 Triton 语言实现 FP16 推理的方法,这些方法适用于 Meta 的 Llama3-8B 和 IBM 的 Granite-8B 代码等流行的LLM模型,其中 100%的计算都使用 OpenAI 的 Triton 语言完成。
对于使用基于 Triton 内核的模型进行单令牌生成的时间,我们能够将 Llama 和 Granite 在 Nvidia H100 GPU 上的 CUDA 内核主导工作流程的性能提升到 0.76-0.78 倍,在 Nvidia A100 GPU 上提升到 0.62-0.82 倍。

为什么探索使用 100%的 Triton?Triton 为在 NVIDIA、AMD 等不同类型的 GPU 上运行LLMs提供了一条路径,未来还将支持 Intel 和其他基于 GPU 的加速器。它还为 Python 编程 GPU 提供了更高层次的抽象,使我们能够比使用供应商特定 API 编写性能更好的内核更快地编写内核。在本文的其余部分,我们将分享如何实现 CUDA-free 计算,对单个内核进行微基准测试以进行比较,并讨论如何进一步改进未来的 Triton 内核以缩小差距。

图 1. 使用 Triton 和 CUDA 变体的 Llama3-8B 和 Granite-8B 推理吞吐量基准,在 NVIDIA H100 和 A100 上
设置:批处理大小=2,输入序列长度=512,输出序列长度=256

2.0 Transformer 块的组成

我们从 Transformer 模型中发生的计算分解开始。下方的图展示了典型 Transformer 块的“内核”。

图 2. Transformer 块的核心内核

Llama3 架构的核心操作总结如下列表:

  1. RMSNorm
  2. 矩阵乘法:Fused QKV
  3. RoPE
  4. 注意
  5. 矩阵乘法:输出投影
  6. RMSNorm
  7. 矩阵乘法:融合门 + 上投影
  8. 激活函数:SiLU
  9. 元素级乘法
  10. 矩阵乘法:下投影

这些操作都是通过执行一个(或多个)内核在 GPU 上进行的。虽然不同变压器模型中每个内核的具体细节可能有所不同,但核心操作保持不变。例如,IBM 的 Granite 8B Code 模型在 MLP 层使用偏差,与 Llama3 不同。这些更改确实需要修改内核。典型的模型是由这些变压器块堆叠而成,并通过嵌入层连接在一起。

3.0 模型推理

典型的模型架构代码与 Python 模型.py 文件共享,该文件由 PyTorch 启动。在默认的 PyTorch eager 执行模式下,这些内核都使用 CUDA 执行。为了实现 Llama3-8B 和 Granite-8B 端到端推理的 100% Triton,我们需要编写和集成手写的 Triton 内核,并利用 torch.compile(以生成 Triton 操作)。首先,我们将较小的操作替换为编译器生成的 Triton 内核,其次,我们将更昂贵和复杂的计算(例如矩阵乘法和闪存注意力)替换为手写的 Triton 内核。

Torch.compile 自动为 RMSNorm、RoPE、SiLU 和逐元素乘法生成 Triton 内核。使用 Nsight Systems 等工具可以观察这些生成的内核;它们在矩阵乘法和注意力之间以微小的深绿色内核形式出现。

图 3. 使用 torch.compile 的 Llama3-8B 追踪,显示用于矩阵乘法和闪存注意力的 CUDA 内核。

对于上述追踪,我们注意到构成 Llama3-8B 风格模型端到端延迟 80% 的两个主要操作是矩阵乘法和注意力内核,并且两者仍然是 CUDA 内核。因此,为了缩小剩余的差距,我们将矩阵乘法和注意力内核都替换为手写的 Triton 内核。

4.0 Triton SplitK GEMM 内核

对于线性层中的矩阵乘法,我们编写了一个定制的 FP16 Triton GEMM(通用矩阵-矩阵乘法)内核,该内核利用了 SplitK 工作分解。我们之前在其他博客中讨论了这种并行化方法,作为加速LLM推理解码部分的一种方式。

5.0 GEMM 内核调优

为了实现最佳性能,我们使用了穷举搜索方法来调整我们的 SplitK GEMM 内核。Granite-8B 和 Llama3-8B 的线性层具有以下形状:

线性层 形状(in_features, out_features)
混合 QKV 投影 (4096, 6144)
输出投影 (4096, 4096)
混合门 + 上投影 (4096, 28672)
下降投影 (14336, 4096)

图 4. Granite-8B 和 Llama3-8B 线性层权重矩阵形状

这些线性层具有不同的权重矩阵形状。因此,为了获得最佳性能,Triton 内核必须针对每个形状配置文件进行调整。调整每个线性层后,我们能够在 Llama3-8B 和 Granite-8B 上实现 1.20 倍的 E2E 速度提升,相对于未调整的 Triton 内核。

6.0 闪存注意力内核

我们评估了现有 Triton 闪存注意力核的套件,并使用不同的配置进行了测试,具体包括:

  1. AMD 闪存
  2. OpenAI 闪存
  3. Dao AI Lab 闪存
  4. XFormers Flash
  5. PyTorch FlexAttention

我们首先在急切模式下评估了这些核的文本生成质量,然后(如果能够使用标准方法 torch.compile 编译核)在编译模式下进行评估。对于 2-5 号核,我们注意到以下情况:

内核 文本生成质量 Torch.compile 支持任意序列长度
AMD 闪存 一致性 是的 是的
OpenAI 闪 不连贯 未评估。首先调试急切模式下的精度
Dao AI Lab 闪存 不连贯 未进行评估。首先在急切模式下调试精度 是的
Xformers FlashDecoding 在我们能够评估文本质量之前遇到了编译错误 WIP 否(此内核针对解码进行了优化)
PyTorch FlexAttention 一致 WIP WIP

图 5. 我们尝试的不同 Flash Attention 内核的组合表

上述表格总结了我们的初始观察结果。经过一些努力,我们预计内核 2-5 可以修改以满足上述标准。然而,这也表明,拥有一个适用于基准测试的内核通常只是将其作为端到端生产内核使用的起点。
在随后的测试中,我们选择使用 AMD 闪存注意力内核,因为它可以通过 torch.compile 进行编译,并在急切模式和编译模式下产生可读的输出。

为了满足 torch.compile 与 AMD 闪存注意力内核的兼容性,我们必须将其定义为 torch 自定义操作符。该过程在此处有详细说明。教程链接讨论了如何包装简单的图像裁剪操作。然而,我们注意到包装更复杂的闪存注意力内核遵循类似的过程。两步方法如下:

  1. 将函数包装成 PyTorch 自定义操作符

  1. 将一个假 Tensor 内核添加到算子中,该内核根据闪存(q、k 和 v)输入张量的形状提供计算闪存内核输出形状的方法

在将 Triton 闪存内核定义为自定义操作后,我们成功将其编译到我们的端到端运行中

图 6.使用 torch.compile 替换 Triton matmul 和 Triton 闪存注意力内核后 Llama3-8B 的跟踪

从图 5 中,我们注意到现在,在集成 SplitK 矩阵乘法内核、torch op 包装的闪存注意力内核,然后运行 torch.compile 之后,我们能够实现使用 100% Triton 计算内核的前向传递

7.0 端到端基准测试

我们在 NVIDIA H100s 和 A100s(单 GPU)上使用 Granite-8B 和 Llama3-8B 模型进行了端到端测量。我们的基准测试采用了两种不同的配置。

Triton 内核配置使用:

  1. Triton SplitK GEMM
  2. AMD Triton 闪存注意力

CUDA 内核配置使用:

  1. cuBLAS GEMM
  2. cuDNN 闪存注意力 - 比例点积注意力(SDPA)

我们找到了以下吞吐量和两种模式(急切模式和 torch 编译模式)的介词延迟,在典型的推理设置下:

GPU Model 内核配置 中值延迟(急切)[ms/词] 中值延迟(编译)[ms/词]
H100 Granite-8B 特里顿 27.42 11.59
    CUDA 18.84 9.50
  羚羊 3-8B 特里顿 20.36 10.61
    CUDA 16.59 8.59
A100 Granite-8B 特里顿 53.44 16.88
    CUDA 37.13 14.25
  Llama3-8B 特里顿 44.44 17.94
    CUDA 32.45 12.96

图 7. Granite-8B 和 Llama3-8B 在 H100 和 A100 上的单 token 生成延迟,
(批大小=2,输入序列长度=512,输出序列长度=256)

总结来说,Triton 模型在 H100 上可以达到 CUDA 模型的 78%的性能,在 A100 上可以达到 82%。

性能差距可以由我们观察到的矩阵乘法和 Flash 注意力内核延迟来解释,这些内容将在下一节中讨论。

8.0 微基准测试

内核 Triton [微秒] CUDA [美国]
QKV 投影矩阵乘法 25 21
闪速注意力 13 8
输出投影矩阵乘法 21 17
网格 + 上投影矩阵乘法 84 83
下投影矩阵乘法 58 42

图 8. Triton 与 CUDA 内核延迟比较(Llama3-8B 在 NVIDIA H100 上)
输入为任意提示(bs=1,提示=44 个序列长度),解码延迟时间

从上述内容中,我们注意到以下几点:

  1. Triton 矩阵乘法内核比 CUDA 慢 1.2-1.4 倍

  2. AMD 的 Triton Flash Attention 内核比 CUDA SDPA 慢 1.6 倍

这些结果突出了进一步改进核心原语如 GEMM 和 Flash Attention 内核性能的必要性。我们将此作为未来研究课题,因为最近的工作(例如 FlashAttention-3、FlexAttention)提供了更好地利用底层硬件以及 Triton 路径的方法,我们希望在此基础上实现更大的速度提升。为了说明这一点,我们比较了 FlexAttention 与 SDPA 和 AMD 的 Triton Flash 内核。

我们正在努力验证 FlexAttention 的端到端性能。目前,Flex 的初始微基准测试显示出对于更长的上下文长度和解码问题形状的潜力,其中查询向量较小:

图 9. NVIDIA H100 SXM5 80GB 上的 FlexAttention 内核基准测试
(batch=1, num_heads=32, seq_len=seq_len, head_dim=128)

9.0 未来工作

对于未来的工作,我们计划探索进一步优化我们的 matmuls 的方法,使其更好地利用硬件,例如我们发布的关于利用 TMA 进行 H100 的博客,以及不同的工作分解(如 StreamK 等持久内核技术)以获得对基于 Triton 的方法的更大加速。对于闪存注意力,我们计划探索 FlexAttention 和 FlashAttention-3 作为这些内核中使用的技巧可以进一步缩小 Triton 和 CUDA 之间的差距。
我们还注意到,我们之前的工作已经显示出 FP8 Triton GEMM 内核性能相对于 cuBLAS FP8 GEMM 的令人鼓舞的结果,因此在未来的一篇文章中,我们将探索 E2E FP8 LLM推理。