由 Jade Nie、CK Luk、Xiaodong Wang、Jackie(Jiaqi)Xu 撰写

1. 引言

PyTorch 支持两种执行模式[1]:即时模式和图模式。在即时模式下,模型中的操作在遇到时立即执行。相比之下,在图模式下,操作首先被合成成图,然后整个图被编译并执行。即时模式使用起来更简单,更适合机器学习研究人员,因此是默认的执行模式。另一方面,图模式通常提供更高的性能,因此在生产中得到了广泛的应用。

具体来说,图模式支持操作融合[2],即将一个操作与另一个操作合并,以减少/局部化内存读取以及总的内核启动开销。融合可以是横向的——将单个操作(例如 BatchNorm)独立应用于多个操作数,并将这些操作数合并成一个数组;也可以是纵向的——将一个内核与另一个内核合并,后者消耗第一个内核的输出(例如卷积后跟 ReLU)。

Torch.FX [3, 4](简称 FX)是 PyTorch 包中公开可用的工具包,支持图模式执行。特别是,它(1)从 PyTorch 程序中捕获图,并且(2)允许开发者对捕获的图进行转换。它被 Meta 内部用于优化生产模型的训练吞吐量。通过引入 Meta 开发的基于 FX 的优化,我们展示了使用图转换来优化 PyTorch 生产性能的方法。

2. 背景

嵌入表在推荐系统中无处不在。第 3 节将讨论三种优化嵌入表访问的 FX 转换。在本节中,我们提供一些关于 FX(第 2.1 节)和嵌入表(第 2.2 节)的背景信息。

2.1 FX

图 1 是一个简单的例子,摘自[3],它说明了如何使用 FX 来转换 PyTorch 程序。它包含三个步骤:(1)从程序中捕获图,(2)修改图(在这个例子中,所有使用 RELU 的地方都被 GELU 所替换),(3)从修改后的图中生成新的程序。

图 1:一个 FX 示例,在 PyTorch 模块中将所有 RELU 的使用替换为 GELU。

FX API [4]提供了许多用于检查和转换 PyTorch 程序图的更多功能。

2.2 嵌入表

图 2:稀疏特征嵌入表的示例,批大小为 1

在推荐系统中,稀疏特征(例如,用户 ID,故事 ID)通过嵌入表表示。嵌入表 E 是一个 HxD 矩阵,其中 H 是哈希大小,D 是嵌入维度。E 的每一行都是一个浮点向量。特征哈希[5]用于将稀疏特征映射到 E 的索引列表,例如[S 1 ,S 2 ,…,S k ],其中 0<=S

为了充分利用 GPU,稀疏特征通常以批处理方式处理。批处理中的每个实体都有自己的索引列表。如果一个批处理有 B 个实体,则原始表示有 B 个索引列表。更紧凑的表示是将 B 个索引列表合并为一个索引列表,并添加一个索引长度列表(每个批处理中的实体一个长度)。例如,如果一个批处理有 3 个实体,其索引列表如下:

  • 实体 1:索引 = [10, 20]
  • 实体 2:索引 = [5, 9, 77, 81]
  • 实体 3:索引 = [15, 20, 45]

然后整个批次的索引和长度将是:

  • 索引 = [10, 20, 5, 9, 77, 81, 15, 20, 45]
  • 长度 = [2, 4, 3]

整个批次的嵌入表查找输出是一个 BxD 矩阵。

3. 三个 FX 变换

我们开发了三种 FX 变换,以加速对嵌入表的访问。第 3.1 节讨论了一种将多个小型输入张量合并为单个大张量的变换;第 3.2 节讨论了一种将多个并行计算链融合为单个计算链的变换;第 3.3 节讨论了一种重叠通信与计算的变换。

3.1 结合输入稀疏特征

回忆一下,在批次中,一个输入稀疏特征由两个列表表示:一个索引列表和一个长度为 B 的列表,其中 B 是批次大小。在 PyTorch 中,这两个列表被实现为两个张量。当 PyTorch 模型在 GPU 上运行时,嵌入表通常存储在 GPU 内存中(与 GPU 更近,读写带宽比 CPU 内存高得多)。要使用输入稀疏特征,其两个张量需要首先从 CPU 复制到 GPU。然而,每次主机到设备内存复制都需要一个内核启动,这相对于实际的数据传输时间来说相对昂贵。如果一个模型使用许多输入稀疏特征,这种复制可能会成为性能瓶颈(例如,1000 个输入稀疏特征将需要从主机复制 2000 个张量到设备)。

减少主机到设备 memcpy 次数的一种优化是在将它们发送到设备之前将多个输入稀疏特征组合在一起。例如,给定以下三个输入特征:

  • 特征_A:indices = [106, 211, 7],lengths = [2, 1]
  • Feature_B: 索引 = [52, 498, 616, 870, 1013], 长度 = [3, 2]
  • Feature_C: 索引 = [2011, 19, 351, 790], 长度 = [1, 3]

合并形式为:

  • Features_A_B_C: 索引 = [106, 211, 7, 52, 498, 616, 870, 1013, 2011, 19, 351, 790], 长度 = [2, 1, 3, 2, 1, 3]

因此,我们不需要从主机复制 3x2=6 个张量到设备,只需复制 2 个张量即可。

图 3(b)描述了这种优化的一个实现,它包含两个部分:

  • 在 CPU 端:修改输入管道,将稀疏特征的索引全部合并成一个张量,同样将所有长度合并成另一个张量。然后,将这两个张量复制到 GPU 上。
  • 在 GPU 端:使用 FX,我们在模型图中插入一个 Permute_and_Split 操作,从合并的张量中恢复各个特征的索引和长度张量,并将它们路由到相应的下游节点。

(a). 未进行优化

(b). 进行了优化

图 3:结合输入稀疏特征

3.2 以嵌入表访问开始的计算链的横向融合

在生产模型中,每个 GPU 上通常会有数十个嵌入表。出于性能考虑,对这些表的查找操作会被分组,以便它们的输出在一个大张量中连接(见图 4(a)中的红色部分)。为了对单个特征输出应用计算,使用 Split 操作将大张量分割成 N 个小张量(其中 N 是特征的数量),然后对每个张量应用所需的计算。这如图 4(a)所示,对每个特征输出 O 应用的计算是 Tanh(LayerNorm(O))。所有计算结果都连接回一个大张量,然后传递给下游操作(图 4(a)中的 Op1)。

这里的主要运行时成本是 GPU 内核启动开销。例如,图 4(a)中的 GPU 内核启动次数是 2*N + 3(图中的每个椭圆形都是一个 GPU 内核)。这可能会成为性能问题,因为 LayerNorm 和 Tanh 在 GPU 上的执行时间与它们的内核启动时间相比非常短。此外,Split 操作可能会创建嵌入输出张量的额外副本,消耗额外的 GPU 内存。

我们使用 FX 来实现一种称为水平融合的优化,这大大减少了 GPU 内核调用的次数(在本例中,优化后的 GPU 内核调用次数为 5,见图 4(b))。我们不是进行显式的 Split 操作,而是使用 Add_middle_dim 操作将形状为(B, NxD)的 2D 嵌入张量重塑为形状为(B, N, D)的 3D 张量。然后对它的最后一个维度应用单个 LayerNorm。然后对 LayerNorm 的结果应用单个 Tanh。最后,我们使用 Remove_middle_dim 操作将 Tanh 的结果重塑回 2D 张量。此外,由于 Add_middle_dim 和 Remove_middle_dim 操作仅重塑张量而不创建额外的副本,因此 GPU 内存消耗量也可以减少。

(a). 未进行优化

(b). 进行了优化

图 4:水平融合

3.3 与通信重叠的计算

生产推荐模型的训练通常在分布式 GPU 系统上进行。由于每个 GPU 的设备内存容量不足以容纳模型中的所有嵌入表,它们需要分布在多个 GPU 上。

在训练步骤中,GPU 需要从其他 GPU 上的嵌入表中读取/写入特征值。这被称为全对全通信[6],可能成为性能瓶颈。

我们使用 FX 实现了一种可以与全对全通信重叠计算的重构。图 5(a)显示了具有嵌入表访问(EmbeddingAllToAll)和其他操作的模型图示例。在没有任何优化的情况下,它们在 GPU 流上按顺序执行,如图 5(b)所示。使用 FX,我们将 EmbeddingAllToAll 分解为 EmbeddingAllToAll_Request 和 EmbeddingAllToAll_Wait,并在它们之间安排独立的操作。

(a) 模型图

(b) 原始执行顺序

(c) 优化后的执行顺序

图 5:计算与通信重叠

3.4 摘要

表 1 总结了本节讨论的优化以及相应的性能瓶颈。

优化 解决的性能瓶颈
结合输入稀疏特征 主机到设备内存复制
水平融合 GPU 内核启动开销
重叠计算与通信 嵌入全连接访问时间

表 1:优化总结及解决的性能瓶颈

我们还在此部分未讨论的其他 FX 转换进行了开发,由于篇幅限制。

为了发现哪些模型将受益于这些转换,我们分析了 MAIProf [7]从运行在 Meta 数据中心上的模型收集的性能数据。总的来说,这些转换与生产模型上的急切模式相比,提供了高达 2-3 倍的加速。

4. 结论

由于性能原因,PyTorch 中的图模式比即时模式更适合生产使用。FX 是一个强大的工具,用于捕获和优化 PyTorch 程序的图。我们展示了三种 FX 转换,这些转换用于在 Meta 内部优化生产推荐模型。我们希望这篇博客能够激励其他 PyTorch 模型开发者使用图转换来提升他们模型的性能。

参考文献列表

[1] 端到端机器学习框架

[2] DNNFusion:通过高级算子融合加速深度神经网络执行

[3] 火炬.FX:Python 深度学习中的实用程序捕获和转换,MLSys 2022。

[4] 火炬.fx—PyTorch 1.12 文档

[5] 大规模多任务学习中的特征哈希

[6] NVIDIA 集体通信库文档

[7] Meta 生产 PyTorch 模型的性能调试