摘要
被誉为“首个真正异步的 GPU”的 Hopper(H100)GPU 架构,包括一个全新的、完全异步的硬件复制引擎,用于在全局和共享内存之间进行大量数据移动,称为张量内存加速器(TMA)。虽然 CUTLASS 通过其异步管道范式内置了对 TMA 的支持,但 Triton 通过实验性 API 公开了 TMA 支持。
在本文中,我们深入探讨了 TMA 的工作原理细节,以便开发者了解新的异步复制引擎。我们还展示了利用 TMA 对 H100 内核的重要性,通过在 Triton 中构建 TMA 启用 FP8 GEMM 内核,对于小到中等规模的问题,性能提升可达 1.4-2.2 倍,超过 cuBLAS FP16。最后,我们展示了 Triton 和 CUTLASS 之间的关键实现差异,这些差异可能是 Triton 中 TMA 性能退化的报告原因。我们开源我们的实现,以便于复现和审查,请访问 https://github.com/pytorch-labs/applied-ai/tree/main/kernels。
图 1. 在 M=M,N=4096,K=4096 的情况下,各种 Triton 和 cuBLAS FP8 和 FP16 内核的吞吐量(以 TFLOPs 为单位)。红色线表示 Triton TMA,展示了利用 TMA 的优势。
TMA 背景
TMA 是 H100 硬件扩展,允许应用程序异步、双向地在 GPU 全局和共享内存之间传输 1D-5D 张量。此外,TMA 还可以将相同的数据传输到调用 SM 的共享内存,以及其他属于同一线程块集群的 SM 的共享内存。这被称为“多播”。
TMA 非常轻量,只需要一个线程来启动 TMA 传输。通过直接将数据从 GMEM(全局)移动到 SMEM(共享),这避免了早期 GPU 在移动不同内存空间之间的数据时使用寄存器的需求。
图 2. A100 风格的数据移动与 H100 相比,TMA 硬件消除了大量线程和寄存器参与批量数据传输的需求。(图片来源:Nvidia)
单个线程可以发出大型数据移动指令,允许大多数线程块在数据传输过程中继续执行其他指令。结合异步流水线,这使得内存传输可以轻松隐藏,并确保任何给定线程块集群的大部分可以专注于计算任务。
这种轻量级的数据移动调用使得创建 warp 组专用内核成为可能,其中 warp 组扮演不同的角色,即生产者和消费者。生产者选举一个领导者线程来触发 TMA 请求,这些请求随后通过到达屏障异步协调与消费者(MMA)warp 组。消费者随后使用 warp 组 MMA 处理数据,并在完成从 SMEM 缓冲区读取后向生产者发出信号,然后周期重复。
此外,在 threadblock 集群中,生产者可以降低其最大寄存器需求,因为它们只发出 TMA 调用,并将额外的寄存器有效地转移到 MMA 消费者,这有助于减轻消费者的寄存器压力。
此外,TMA 处理共享内存目标地址的计算,即请求的数据应该放置的位置。这就是为什么调用线程(生产者)可以如此轻量级。
为了确保最大读取访问速度,TMA 可以根据交换指令布局到达的数据,以确保消费者可以尽可能快地读取到达的数据,因为交换模式有助于避免共享内存银行冲突。
最后,对于即将发出的 TMA 指令,或者从 SMEM 移动到 GMEM 的数据,TMA 还可以包括归约操作(加/减/最大值)和位操作(与/或)。
Triton 中 TMA 的使用
预 Hopper 加载:
offs_m = pid_m*block_m + tl.arange(0, block_m)
offs_n = pid_n*block_n + tl.arange(0, block_n)
offs_k = tl.arange(0, block_k)
a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k[None, :]*stride_ak)
b_ptrs = b_ptr + (offs_k[:, None]*stride_bk + offs_bn[None, :]*stride_bn)
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
图 3. Triton 中从全局内存到共享内存的传统风格批量加载
在上述 Triton 示例中,展示了一个预 Hopper 加载,我们看到每个线程块如何通过从其相关的程序 ID(pid_m、pid_n、k)计算全局偏移量(a_ptrs、b_ptrs)来加载张量 a 和 b 的数据,然后请求将 a 和 b 的内存块移动到共享内存中。
现在我们来探讨如何在 Triton 中使用 TMA 进行加载。
与上面直接传递全局内存指针不同,TMA 指令需要一个特殊的数据结构,即张量映射。为了构建张量映射,我们首先在 CPU 上创建一个 TMA 描述符。描述符通过使用 cuTensorMapEncode API 来处理张量映射的创建。张量映射包含诸如张量的全局和共享内存布局等元数据,并作为存储在全局内存中的多维张量结构的压缩表示。
图 4. 通过复制描述符生成 TMA 地址(图片来源:Nvidia)
TMA 描述符包含张量的关键属性:
- 基指针
- 形状和块大小
- 数据类型
TMA 描述符在内核之前在主机上创建,然后通过将描述符传递给 torch 张量将其移动到设备。因此,在 Triton 中,GEMM 内核接收一个指向张量映射的全局指针。
三叉戟主机代码
desc_a = np.empty(TMA_SIZE, dtype=np.int8)
desc_b = np.empty(TMA_SIZE, dtype=np.int8)
desc_c = np.empty(TMA_SIZE, dtype=np.int8)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(a.data_ptr(), m, k, block_m, block_k, a.element_size(), desc_a)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(b.data_ptr(), n, k, block_n, block_k, b.element_size(), desc_b)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(c.data_ptr(), m, n, block_m, block_n, c.element_size(), desc_c)
desc_a = torch.tensor(desc_a, device='cuda')
desc_b = torch.tensor(desc_b, device='cuda')
desc_c = torch.tensor(desc_c, device='cuda')
这是用于在内核调用函数中设置描述符的代码。
三叉戟设备代码
偏移/指针算术:
offs_am = pid_m * block_m
offs_bn = pid_n * block_n
offs_k = 0
加载:
a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [block_m, block_k], tl.float8e4nv)
b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [block_n, block_k], tl.float8e4nv)
存储:
tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])
我们不再需要在内核中为加载和存储函数计算指针数组。相反,我们传递一个单独的描述符指针、偏移量、块大小和输入数据类型。这简化了地址计算并减少了寄存器压力,因为我们不再需要在软件中进行复杂的指针运算并分配 CUDA 核心进行地址计算。
TMA 性能分析
下面,我们讨论 Hopper 上不同加载机制的 PTX 指令。
PTX 用于加载瓦片(cp.async)- H100 无 TMA
add.s32 %r27, %r100, %r8;
add.s32 %r29, %r100, %r9;
selp.b32 %r30, %r102, 0, %p18;
@%p1 cp.async.cg.shared.global [ %r27 + 0 ], [ %rd20 + 0 ], 0x10, %r30;
@%p1 cp.async.cg.shared.global [ %r29 + 0 ], [ %rd21 + 0 ], 0x10, %r30;
cp.async.commit_group ;
在这里,我们观察到负责全局内存复制的较老 cp.async 指令。从下面的跟踪中我们可以看到,两次加载都绕过了 L1 缓存。与较新的 TMA 加载的一个主要区别是,在 A 和 B 瓦片准备好被 Tensor Core 消费之前,我们需要执行一个操作于寄存器文件中的数据的 ldmatrix 指令。在 Hopper 上,现在可以直接从共享内存中重用数据。
图 5. H100 内存图表显示 GMEM 吞吐量=910.22 GB/s(Triton GEMM 无 TMA)对于 M=128,N=4096,K=4096
通过利用上述 Triton API 更改中的 TMA,我们可以调查 Triton 为单个 2D 瓦片加载生成的 PTX。
PTX 用于加载瓦片(cp.async.bulk.tensor)- 使用 TMA 的 H100
bar.sync 0;
shr.u32 %r5, %r4, 5;
shfl.sync.idx.b32 %r66, %r5, 0, 31, -1;
elect.sync _|%p7, 0xffffffff;
add.s32 %r24, %r65, %r67;
shl.b32 %r25, %r66, 7;
@%p8
cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [%r24], [%rd26, {%r25,%r152}], [%r19];
cp.async.bulk.tensor.2d.shared TMA 指令分别传入共享内存中的目标地址、张量映射指针、张量映射坐标以及 mbarrier 对象指针。
图 6. H100 内存图 GMEM 吞吐量=1.45 TB/s(Triton GEMM 使用 TMA)对于 M=128,N=4096,K=4096
为了获得最佳性能,我们对 TMA GEMM 内核进行了大量调整。在瓦片大小、线程束数量和流水线阶段数量等参数中,当我们将 TMA_SIZE(描述符大小)从 128 增加到 512 时,观察到内存吞吐量最大提升。从上述 NCU 配置文件中,我们可以看到最终调整后的内核将全局内存传输吞吐量从 910 GB/s 提升到 1.45 TB/s,GMEM 吞吐量提高了 59%,超过了非 TMA 的 Triton GEMM 内核。
CUTLASS 与 Triton FP8 GEMM 和 TMA 实现的比较 - 内核架构
图 7. Triton 与 CUTLASS Ping-Pong FP8 GEMM TFLOPs,M=M,N=4096,K=4096
上图显示了 CUTLASS Ping-Pong GEMM 内核与 Triton 的性能对比。Ping-Pong 内核在 TMA 的使用上与 Triton 不同。它充分利用了其硬件和软件软件能力,而 Triton 目前尚未做到。具体来说,CUTLASS 支持以下 TMA 特性,有助于解释纯 GEMM 性能差距:
-
TMA 多播
- 使 GMEM 数据能够复制到多个 SM
-
虚拟化特殊化
- 允许线程块内的虚拟组承担不同的角色
-
张量映射(TMA 描述符)预取
- 从 GMEM 预取 Tensor Map 对象,允许 TMA 加载流水线化
为了使性能数据更直观,以下我们展示了一个“加速”图表,突出显示了基于百分比的延迟差异:
图 8:CUTLASS Ping-Pong 与 Triton FP8 使用 TMA 的加速比
这种加速完全是内核吞吐量,不包括下面将要讨论的端到端启动开销。
TMA 描述符移动 - 三叉戟与 CUTLASS 在端到端性能方面的关键区别
如前所述,2D+维度的 TMA 描述符的创建发生在主机上,然后转移到设备。然而,这个传输过程根据实现方式有很大不同。
在这里,我们展示了三叉戟与 CUTLASS 在传输 TMA 描述符方面的差异。
回顾一下,TMA 传输需要创建一个特殊的数据结构,即通过 cuTensorMap API 在 CPU 上创建的张量映射,对于一个 FP8 GEMM 内核,这意味着需要创建三个描述符,分别对应 A、B 和 C。下面我们可以看到,对于三叉戟和 CUTLASS 内核,调用的 CPU 过程是相同的。
图 7. 对 cuTensorMapEncodeTiled 的调用(Triton 和 CUTLASS 都使用此路径)
然而,对于 Triton 来说,每个描述符都通过自己的独立复制内核进行传输,这增加了大量的开销,并成为在端到端推理场景中使用此内核的障碍。
图 8. 在内核执行之前,启动了三个 H2D 复制内核,用于 A、B 和 C
由于 TMA 描述符传递给内核的方式,这些复制在 CUTLASS 实现中并未观察到。从下面的 PTX 中我们可以看到,使用 Cutlass 时,张量映射是通过值传递给内核的。
.entry _ZN7cutlass13device_kernelIN49_GLOBAL__N__8bf0e19b_16_scaled_mm_c3x_cu_2bec3df915cutlass_3x_gemmIaNS_6half_tENS1_14ScaledEpilogueEN4cute5tupleIJNS5_1CILi64EEENS7_ILi128EEES9_EEENS6_IJNS7_ILi2EEENS7_ILi1EEESC_EEENS_4gemm32KernelTmaWarpSpecializedPingpongENS_8epilogue18TmaWarpSpecializedEE10GemmKernelEEEvNT_6ParamsE(
.param .align 64 .b8 _ZN7cutlass13device_kernelIN49_GLOBAL__N__8bf0e19b_16_scaled_mm_c3x_cu_2bec3df915cutlass_3x_gemmIaNS_6half_tENS1_14ScaledEpilogueEN4cute5tupleIJNS5_1CILi64EEENS7_ILi128EEES9_EEENS6_IJNS7_ILi2EEENS7_ILi1EEESC_EEENS_4gemm32KernelTmaWarpSpecializedPingpongENS_8epilogue18TmaWarpSpecializedEE10GemmKernelEEEvNT_6ParamsE_param_0[1024]
mov.b64 %rd110, _ZN7cutlass13device_kernelIN49_GLOBAL__N__8bf0e19b_16_scaled_mm_c3x_cu_2bec3df915cutlass_3x_gemmIaNS_10bfloat16_tENS1_14ScaledEpilogueEN4cute5tupleIJNS5_1CILi64EEES8_NS7_ILi256EEEEEENS6_IJNS7_ILi1EEESB_SB_EEENS_4gemm24KernelTmaWarpSpecializedENS_8epilogue18TmaWarpSpecializedEE10GemmKernelEEEvNT_6ParamsE_param_0;
add.s64 %rd70, %rd110, 704;
cvta.param.u64 %rd69, %rd70;
cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%rd69, {%r284, %r283}], [%r1880];
图 9. CUTLASS 内核 PTX 展示按值传递
与直接传递 TMA 描述符而不是传递全局内存指针相比,CUTLASS 内核避免了三个额外的 H2D 复制内核,而是将这些复制包含在为 GEMM 的单次设备内核启动中。
由于描述符移动到设备的方式不同,包括准备要由 TMA 消耗的张量的时间在内的内核延迟差异很大。对于 M=1-128,N=4096,K=4096,CUTLASS pingpong 内核的平均延迟为 10us,Triton TMA 内核平均完成时间为 4ms。这比慢了约 3330 倍,似乎直接与 Triton 为 TMA 描述符传输进行的三个独立内核启动有关。
Cuda 图可能是一种减少这种延迟的方法,但考虑到 H2D 复制产生的开销,当前 Triton 的实现当从端到端测量时并不具有竞争力。重新设计 Triton 编译器如何管理 TMA 描述符可能会解决这个问题。因此,我们在上面的数据中专注于比较实际的计算内核吞吐量,而不是端到端性能。
结果摘要
图 10. Triton FP8 TMA GEMM TFLOPs 比较
M | Triton TMA | Triton 教程 | 特里顿 SplitK | cuBLAS FP8 | cuBLAS FP16 | CUTLASS Ping-Pong FP8 |
1 | 2.5 | 1 | 2.4 | 1.5 | 1.8 | 3.57 |
2 | 5.1 | 2.5 | 4.8 | 3.1 | 3.6 | 5.9 |
4 | 10.3 | 7.21 | 9.6 | 6.1 | 7.2 | 14.3 |
8 | 21.0 | 16.5 | 19.2 | 12.3 | 14.4 | 28.6 |
16 | 44.5 | 41.0 | 37.2 | 24.5 | 27.7 | 55.1 |
32 | 89.7 | 81.2 | 72.2 | 71.6 | 56.8 | 114.4 |
64 | 178.5 | 163.7 | 130.8 | 144.6 | 105.3 | 228.7 |
128 | 359.7 | 225.9 | 160.1 | 244.0 | 189.2 | 377.7 |
图 11. Triton FP8 TMA GEMM TFLOPs 比较表
上面的图表和表格总结了通过利用 TMA 硬件单元,在单个 NVIDIA H100 上实现 FP8 GEMM 所获得的增益,与不带 TMA 的 Triton 内核和高性能 CUDA(cuBLAS)内核相比。需要注意的是,该内核在批大小上的扩展性(与竞争产品相比)更优越。我们在这些基准测试中使用的规模问题代表了在 1001#推理中找到的小到中等批大小矩阵形状。因此,对于有兴趣利用此内核进行 FP8 LLM部署用例的人来说,TMA GEMM 内核在中 M 区域(M=32 到 M=128)的性能将至关重要,因为 FP8 压缩数据类型可以允许更大的矩阵适合 GPU 内存。
总结我们的分析,Triton 中的 TMA 实现与 CUTLASS 在完整功能集支持(多播、预取等)以及如何将 TMA 描述符传递给 GPU 内核方面有所不同。如果以更接近 CUTLASS 内核(按值传递)的方式传递此描述符,则可以避免额外的 H2D 复制,从而大大提高端到端性能。
未来工作
为了未来的研究,我们计划在此基础上进行改进,通过与社区合作将 TMA 负载的 CUTLASS 架构集成到 Triton 中,并研究 FP8 GEMM 的协同内核,这是一种对 Ping-Pong 内核的修改策略。
此外,一旦在 Triton 中启用线程块集群和 TMA 原子操作,我们可能能够通过利用 TMA GEMM 内核中的 SplitK 策略来获得进一步的加速,因为 Hopper 上的原子操作可以在分布式共享内存(DSMEM)中执行,而不是在 L2 缓存中。我们还注意到 NVIDIA Hopper GPU 与谷歌的 TPU 和 IBM 的 AIU 等其他 AI 硬件加速器在数据流架构上的相似性。由于 TMA 和 DSMEM 的添加,Hopper 上的数据现在可以从 GMEM“流动”到连接的 SM 网络,我们已经在本文中详细讨论了 TMA,计划在未来的文章中介绍 DSMEM。