由 Jay Shah 和 Ganesh Bikshandi,Colfax Research,Ying Zhang,Meta,Vijay Thakkar 和 Pradeep Ramani,NVIDIA,Tri Dao,TogetherAI 和普林斯顿大学共同撰写

注意,作为无处不在的 Transformer 架构的核心层,注意力机制是大型语言模型和长上下文应用的一个瓶颈。FlashAttention(以及 FlashAttention-2)开创了一种通过最小化内存读写来加速 GPU 上注意力的方法,现在被大多数库用于加速 Transformer 的训练和推理。这导致了在过去两年中上下文长度在LLM上大幅增加,从 2-4K(GPT-3,OPT)到 128K(GPT-4),甚至 1M(Llama 3)。然而,尽管 FlashAttention 取得了成功,但它尚未充分利用现代硬件的新功能,FlashAttention-2 在 H100 GPU 上仅实现了理论最大 FLOPs 的 35%利用率。在这篇博客文章中,我们描述了三种在 Hopper GPU 上加速注意力的主要技术:利用 Tensor Core 和 TMA 的异步性,通过 warp 专业化(1)重叠整体计算和数据移动,以及(2)交错块矩阵乘法和 softmax 操作,以及(3)利用硬件对 FP8 低精度支持的不可协调处理。

我们非常高兴发布 FlashAttention-3,该版本集成了这些技术。它在 FP16 模式下比 FlashAttention-2 快 1.5-2.0 倍,最高可达 740 TFLOPS,即 H100 理论最大 FLOPS 的 75%利用率。在 FP8 模式下,FlashAttention-3 接近 1.2 PFLOPS,比基线 FP8 注意力机制小 2.6 倍的误差。

FlashAttention-3 可在以下地址获取:https://github.com/Dao-AILab/flash-attention
论文

FlashAttention 概述

FlashAttention 是一种重新排序注意力计算算法,通过使用瓦片化和重新计算来显著提高其速度并减少内存使用,从二次方降低到线性序列长度。我们使用瓦片化从 HBM(GPU 内存)加载输入块到 SRAM(快速缓存),针对该块执行注意力计算,并在 HBM 中更新输出。通过不将大型中间注意力矩阵写入 HBM,我们减少了内存读写次数,从而将墙钟时间速度提高了 2-4 倍。

下面展示了 FlashAttention 正向传播的示意图:通过瓦片化和 softmax 缩放,我们按块操作,避免从 HBM 读取/写入,同时获得正确的输出而不进行近似。

math equations

Hopper GPU 上的新硬件特性 - WGMMA、TMA、FP8

虽然 FlashAttention-2 在 Ampere(A100)GPU 上可以达到 70%的理论最大 FLOPS,但它尚未利用 Hopper GPU 上的新特性来最大化性能。我们在这里描述了一些 Hopper 特定的新特性及其重要性。

1. WGMMA(Warpgroup 矩阵乘加)。这一新特性利用了 Hopper 上的新 Tensor 核心,在吞吐量上比 Ampere 中的 older mma.sync 指令有了显著提升(图片来自 H100 白皮书)。

image from the H100 white paper

2. TMA(张量内存加速器)。这是一个特殊的硬件单元,用于加速全局内存和共享内存之间的数据传输,负责所有索引计算和越界预测。这释放了寄存器资源,这对于增加瓦片大小和提高效率是非常宝贵的。

block diagram

3. 低精度 FP8。这可以将 Tensor Core 吞吐量翻倍(例如,FP16 时为 989 TFLOPS,FP8 时为 1978 TFLOPS),但通过使用更少的位来表示浮点数来牺牲精度。

6x throughput

4. FlashAttention-3 利用 Hopper 的所有这些新特性,使用 NVIDIA 的 CUTLASS 库中的强大抽象。

通过将 FlashAttention 重写以使用这些新特性,我们已能显著提高其速度(例如,从 FlashAttention-2 FP16 正向传递的 350 TFLOPS 提升到约 540-570 TFLOPS)。然而,Hopper(WGMMA 和 TMA)上新指令的异步特性为重叠操作提供了额外的算法机会,从而提取出更高的性能。对于这篇博客,我们将解释两种针对注意力的特定技术。在 GEMM(乘法运算)的上下文中,warp 专业化技术(具有分别执行 TMA 和 WGMMA 的生产者和消费者 warp)已有详细阐述,在这里作用相同。

异步:重叠 GEMM 和 Softmax

为什么需要重叠?

注意力有两个主要操作:GEMM(Q 和 K 之间的乘法以及注意力概率 P 和 V 之间的乘法)和 softmax。为什么需要重叠它们?GEMM 不是已经占用了大部分的 FLOPS 吗?只要 GEMM 运算速度快(例如,使用 WGMMA 指令计算),GPU 不就应该嗖嗖地运行吗?

问题在于,在现代加速器上,非矩阵乘法操作的速度远慢于矩阵乘法操作。像指数(用于 softmax)这样的特殊函数的吞吐量甚至低于浮点乘加;它们是通过多功能单元计算的,这是一个与浮点乘加或矩阵乘加分开的单元。以 H100 GPU SXM5 为例,它有 989 TFLOPS 的 FP16 矩阵乘法,但特殊函数只有 3.9 TFLOPS(吞吐量降低 256 倍);对于头维度为 128 的情况,矩阵乘法的 FLOPS 比指数多 512 倍,这意味着指数可能需要比矩阵乘法多 50%的时间。对于 FP8 来说,情况更糟,矩阵乘法的 FLOPS 速度快一倍,而指数的 FLOPS 速度保持不变。理想情况下,我们希望矩阵乘法和 softmax 能够并行操作。当 Tensor 核心忙于矩阵乘法时,多功能单元应该计算指数!

互战群重叠与 pingpong 调度

第一种也是最简单的方法是什么都不做!warp 调度器已经尝试调度 warp,以便如果某些 warp 被阻塞(例如,等待 GEMM 结果),其他 warp 可以运行。也就是说,warp 调度器为我们免费做了一些这种重叠。

然而,我们可以通过手动进行一些调度来改进这一点。例如,如果我们有 2 个 warp 组(标记为 1 和 2 - 每个 warp 组是一组 4 个 warp),我们可以使用同步屏障(bar.sync)来确保 warp 组 1 首先执行其 GEMM(例如,一个迭代的 GEMM1 和下一个迭代的 GEMM0),然后 warp 组 2 执行其 GEMM,同时 warp 组 1 执行其 softmax,依此类推。下面的图中展示了这种“乒乓”调度,其中相同的颜色表示相同的迭代。

block chart

这将使我们能够在其他 warp 组的 GEMM 的阴影下执行 softmax。当然,这个图只是一个夸张的描绘;在实践中,调度并不真的这么干净。尽管如此,pingpong 调度可以将 FP16 注意力正向传递的性能从大约 570 TFLOPS 提高到 620 TFLOPS(头维度 128,序列长度 8K)。

GEMM 和 Softmax 的 warp 组内重叠

即使在同一个 warpgroup 中,我们也可以在运行该 warpgroup 的 GEMMs 的同时运行一部分 softmax。这在图中得到了说明,其中相同的颜色表示相同的迭代。

block chart

这种流水线技术将 FP16 注意力前向的吞吐量从约 620 TFLOPS 提高到约 640-660 TFLOPS,但代价是更高的寄存器压力。我们需要更多的寄存器来存储 GEMMs 的累加器和 softmax 的输入/输出。总的来说,我们发现这项技术提供了有利的权衡。

低精度:通过非相干处理减少量化误差

激活函数中可能存在比其他特征大得多的异常值。这些异常值使得量化变得困难,产生了更大的量化误差。我们利用了非相干处理技术,这是一种在量化文献中使用的技巧(例如来自 QuIP),通过将查询和键与一个随机的正交矩阵相乘来“分散”异常值并减少量化误差。特别是,我们使用了哈达玛变换(带有随机符号),这可以在每个注意力头中以 O(d log d)的时间复杂度完成,而不是 O(d^2)。

在我们的实验中,Q、K、V 是从标准正态分布生成的,但其中 0.1%的条目具有较大的幅度(以模拟异常值),我们发现非相干处理可以将量化误差降低 2.6 倍。以下表格显示了数值误差的比较,请参阅论文以获取详细信息。

text diagram

注意力基准测试

我们展示了使用 FlashAttention-3 的一些结果,并将其与 FlashAttention-2 进行了比较,以及 Triton 和 cuDNN 中的实现(这两者都已经使用了 Hopper GPU 的新硬件特性)。

对于 FP16,我们看到了比 FlashAttention-2 快 1.6 倍到 1.8 倍的速度提升。

speed charts

speed charts

对于 FP8,我们可以达到接近 1.2PFLOPS!

speed charts

讨论

本博客文章重点介绍了在 Hopper GPU 上可用的 FlashAttention 的一些优化。其他优化(例如可变长度序列、持久内核以及内核内的 FP8 转置)在论文中有所涉及。

我们已经看到,设计利用其运行硬件的算法可以带来显著的效率提升并解锁新的模型能力,如长上下文。我们期待未来对LLM推理的优化工作,以及将我们的技术推广到其他硬件架构。

我们也期待 FlashAttention-3 在未来 PyTorch 的版本中得以集成。

备注

  1. 没有 wgmma 指令,较老的 mma.sync 指令只能达到 Hopper Tensor Cores 峰值吞吐量的约 2/3:https://arxiv.org/abs/2402.13499v1 ↩

  2. CUDA 编程指南规定,特殊函数的吞吐量为每个流多处理器(SM)每个时钟周期 16 个操作。我们将 16 乘以 132 个 SM 和 1830 MHz(用于计算 989 TFLOPS 的 FP16 矩阵乘法的时钟速度),得到 3.9 TFLOPS ↩