1.0 摘要
我们表明,通过实施列主调度以改善数据局部性,我们可以在 A100 上加速 Triton GEMM(通用矩阵-矩阵乘法)核心,对于 MoEs(专家混合)最多提高 4 倍,在 H100 Nvidia GPU 上最多提高 4.4 倍。本文展示了 MoE GEMMs 的几种不同的工作分解和调度算法,并在硬件级别上说明了为什么列主调度产生最高的加速。
代码库和代码可在以下链接获取:https://github.com/pytorch-labs/applied-ai/tree/main/kernels/triton/inference/col_major_moe_gemm。
图 1A. A100 上不同批量大小的优化融合 MoE GEMM 内核 TFLOPs
图 1B. H100 上不同批量大小的优化融合 MoE GEMM 内核 TFLOPs
2.0 背景
OpenAI 的 Triton 是一个硬件无关的语言和编译器,正如我们之前的博客文章所展示的,它可以用于加速量化工作流程。我们还展示了在内核开发方面,许多来自 CUDA 的学习和性能分析工具可以用于提供类似的见解,了解 Triton 内核在底层的工作方式以及如何通过这些方法在延迟敏感的环境中加速这些内核。随着 Triton 在生产环境中的日益普及,了解开发高性能内核的常见技巧以及这些方法对不同架构和工作流程的通用性变得非常重要。因此,本文将探讨我们如何使用经典技术优化了 vLLM 开发的流行混合专家(MoE)Mixtral 模型的 Triton 内核,以及这些技术如何在 Triton 中实现性能提升。
Mixtral 8x7B 是一种稀疏混合专家语言模型。与经典的密集型 Transformer 架构不同,每个 Transformer 块包含 8 个 MLP 层,其中每个 MLP 都是一个‘专家’。当标记流经时,路由网络会选择 8 个专家中的哪两个来处理该标记,然后将结果合并。对于同一标记,所选的专家在每个层中都会变化。因此,虽然 Mixtral 8x7B 总共有 47B 个参数,但在推理过程中只有 13B 个参数是活跃的。
MoE GEMM(通用矩阵-矩阵乘法)内核接收一个包含所有专家的堆叠权重矩阵,并且必须随后通过路由网络产生的映射数组将每个标记路由到 TopK(对于 Mixtral 来说是 2)个专家。在这篇文章中,我们提供了在推理时间,特别是在自回归(或解码阶段)期间高效并行化此计算的方法。
3.0 工作分解 - SplitK
我们之前已经表明,对于在LLM推理中发现的矩阵问题规模,特别是在 W4A16 量化推理的背景下,通过应用 SplitK 工作分解,GEMM 内核可以加速。因此,我们开始我们的 MoE 加速研究,通过在 vLLM MoE 内核中实现 SplitK,这相对于数据并行方法产生了大约 18-20%的速度提升。
这个结果表明,SplitK 优化可以作为改进/开发推理设置中 Triton 内核的更规范方法的一部分。为了对这些不同的工作分解有更直观的了解,让我们考虑一个简单的例子,即两个 4x4 矩阵的乘法,SplitK=2。
在下面的数据并行 GEMM 内核中,输出矩阵的单个块的计算将由 1 个线程块 TB0 处理。
图 2. 数据并行 GEMM
与之相反,在 SplitK 内核中,计算输出矩阵中 1 个块所需的工作被“分割”或共享在两个线程块 TB0 和 TB1 之间。这提供了更好的负载均衡和更高的并行性。
图 3. SplitK GEMM
关键思想是我们将并行性从 MN 增加到 MN*SplitK。这种方法确实会带来一些成本,例如通过原子操作添加线程块间的通信。然而,与节省其他受限 GPU 资源(如共享内存和寄存器)相比,这些成本微乎其微。最重要的是,SplitK 策略为瘦矩阵提供了优越的负载均衡特性(如 MoE 推理中的情况),并且在解码和推理过程中是常见的矩阵配置文件。
4.0 GEMM 硬件调度 - 列主序
为了在 SplitK 的基础上进一步提高约 20%的速度,我们专注于研究控制 Triton 内核中 GEMM 硬件调度的逻辑。我们对 vLLM MoE 内核的剖析显示 L2 缓存命中率较低,因此我们调查了三种调度选项——列主序、行主序和分组发射。由于 MoE 模型的一些固有属性,例如大型专家矩阵,以及内核运行期间需要动态加载 TopK(Mixtral 为 2)矩阵,缓存重用/命中率成为瓶颈,这次优化将针对这一点。
作为背景,在我们之前的博客中,我们提到了“瓦片交换”的概念,这是一种提高 L2 缓存命中率的技巧。这个概念与软件如何将 GEMM 调度到 GPU 的 SM 上有关。在 Triton 中,这个调度由 pid_m 和 pid_n 计算决定。我们的关键洞察是,对于瘦矩阵乘法,列主序排序可以确保权重矩阵 B 的列的最佳重用。为了说明这一点,让我们看看 pid_m 和 pid_n 的列主序计算的代码片段:
图 4. PyTorch 中的列主序排序
从上方可以看出,通过这种映射,我们安排 GEMM 的计算,使得我们按照以下顺序计算 C 的输出块:C(0, 0),C(1, 0),C(2, 0),等等。为了理解其影响,我们提供了以下插图:
图 5.列主序 GEMM 调度缓存重用模式
在上述列主序调度的简化视图中,让我们假设对于具有瘦激活矩阵 A 的 GEMM,整个矩阵可以完全适应 GPU 缓存,这对于我们在 MoE 推理中遇到的问题规模类型来说是一个合理的假设。这允许最大程度地重用权重矩阵 B 的列,因为 B 列可以被重用于相应的输出块计算,如 C(0,0),C(1,0)和 C(2,0)。相反,考虑行主序调度,C(0,0),C(0,1),C(0,2)等等。我们不得不驱逐 B 列,并发出多个加载指令到 DRAM 来计算相同数量的输出块。
在优化内核时,一个重要的设计考虑因素是导致全局加载指令最少的内存访问模式。这种最优的内存访问模式通过列主序调度实现。以下结果展示了我们调查的三种调度方案的性能:
图 6. A100 上不同批大小 M 的 GEMM 调度方案的比较
列主序调度比其他模式快 4 倍以上,正如我们将在下一节中展示的,它提供了最优的内存访问模式,因为数据局部性得到了显著提高。
5.0 Nsight Compute 分析 - 吞吐量和内存访问模式
为了性能分析,我们专注于 H100 的 M=2 情况。对于 A100 也可以进行类似的研究,因为许多相同的观察结果可以延续。我们注意到以下显著结果,这些结果展示了我们优化的影响。
图 7. M=2 时 H100 内存吞吐量图表。注意缓存命中率的大幅提升,L1 缓存命中率(+2696%)和 L2 缓存命中率(+254%)。
图 8. H100 内存指令统计 M=2。注意全局内存加载减少了 49%。
这些统计数据表明我们的优化产生了预期效果,这可以从减少的缓存未命中、减少的内存访问以及结果 2.7 倍的速度提升中看出。更具体地说,跟踪显示 L2 命中率增加了 2.54 倍(图 7),DRAM 访问减少了约 50%(图 8)。
这些改进最终降低了延迟,优化后的内核在 bs=2 时速度提高了 2.7 倍,在 bs=512 时提高了 4.4 倍。
6.0 未来工作
我们的内核在 FP16 下进行了测试,展示了 MoE 的列主调度在数值和性能方面的优势,但大多数生产模型都在使用 BFloat16。我们在 Triton 中遇到了一个限制,即 tl.atomic_add 不支持 BFloat16,并且遇到了启动延迟问题,这需要 cuda graph 支持才能用于列主生产使用。在初步测试中,这转化为 70%的端到端速度提升,但我们遇到了一些在端到端环境中出现而在测试环境中没有反映的专家映射不一致性,因此需要进一步工作才能完全实现这些速度提升。
对于未来的工作,我们打算将其移入 CUDA 内核,这将确保完全支持 BFloat16 并相对于 Triton 降低启动延迟,并可能解决专家路由不一致性问题。我们之前还发表了关于启用 Triton GEMM 内核的 GPTQ W4A16 的工作,因此自然的后续工作将包括将去量化融合到这个内核中,以允许 GPTQ 量化推理路径。
7.0 可重复性
我们已开源 Triton 内核代码,并为对比较或验证其性能感兴趣的读者提供了一个易于运行的性能基准。
致谢
感谢 Daniel Han、Raghu Ganti、Mudhakar Srivatsa、Bert Maher、Gregory Chanan、Eli Uriegas 和 Geeta Chauhan 对所展示材料的审阅,以及 vLLM 团队的 Woosuk,因为我们基于他对 Fused MoE 内核的实现进行了构建。