由 CK Luk、史道航、黄宇珍、徐佳琪(Jackie)、Nie Jade、王周、方璐、Flavio Sales Truzzi、Devashish Shankar、Dima Ivashchenko、杨春志、Nicolas Macchioni、David Berard、郭宇、王晓东、Bert Maher、梁彦博、杨大伟、Brian Hirsh、Michael Voznesensky、Animesh Jain、Michael Anderson 撰写

1. 引言

PyTorch 2.0(简称 PT2)可以通过使用 torch.compile 编译器显著提高 AI 模型的训练和推理性能,同时与 PyTorch 1.x 保持 100%向后兼容。已有关于 PT2 如何提高常见基准测试(例如 huggingface 的 diffusers)性能的报道。在这篇博客中,我们讨论了在 Meta 将 PT2 应用于生产 AI 模型的经验。

2. 背景

2.1 为什么自动性能优化对生产如此重要?

性能对于生产尤其重要——例如,即使是高度使用的模型训练时间的 5%的减少,也能转化为 GPU 成本和数据中心电力的显著节省。另一个重要的指标是开发效率,它衡量将模型投入生产所需的工程师月数。通常,这部分投入生产的工作中,很大一部分是花费在手动性能调优上,例如重写 GPU 内核以提高训练速度。通过提供自动性能优化,PT2 可以提高成本和开发效率。

2.2 PT2 如何提高性能

作为编译器,PT2 可以查看从模型捕获的训练图中的多个操作(与 PT1.x 不同,在 PT1.x 中一次只执行一个操作)。因此,PT2 可以利用许多性能优化机会,包括:

  • 将多个操作融合到一个 GPU 内核中:
    • 运行 GPU 程序时典型的性能开销之一是 CPU 在启动小 GPU 内核时的开销。通过将多个操作融合成一个单独的 GPU 内核,PT2 可以显著减少 CPU 上的内核启动开销。例如,考虑图 1(a)中的 PyTorch 程序。当它在 GPU 上使用 PT1 执行时,有三个 GPU 内核(两个用于两个 sin()操作,一个用于加法操作)。使用 PT2 时,只生成一个内核,它融合了所有三个操作。
    • 在融合一些操作后,图中的某些操作可能成为死代码,因此可以被优化掉。这可以在 GPU 上节省计算和内存带宽。例如,在图 1(b)中,可以优化掉一个重复的 sin()操作。
    • 此外,融合还可以通过组合逐点内核来减少 GPU 设备内存的读写操作,并有助于提高硬件利用率。

Fig.1  How PT2 improves performance with fusion and dead-code elimination.

图 1:PT2 通过融合和死代码消除提高性能。

  • 降低使用低精度数据类型的类型转换开销:
    • PyTorch 1.x 支持自动混合精度(AMP)。虽然 AMP 可以减少操作的计算时间,但它在操作前后引入了类型转换开销。PT2 通过优化移除不必要的类型转换代码,可以显著降低其开销。例如,图 2(a)在执行矩阵乘法之前将三个 32 位输入张量(a32、b32、c32)转换为 bf16。然而,在这个例子中,a32 和 c32 实际上是同一个张量(a_float32)。因此,没有必要对 a_float32 进行两次转换,如图 2(b)中 torch.compile 生成的代码所示。请注意,虽然这个例子和前面的例子都优化了冗余计算,但它们在类型转换代码的隐式性方面有所不同,这个例子中的类型转换代码是通过 torch.autocast 隐式的,而前面的例子中 torch.sin(x).cuda() 在用户代码中是显式的。

Fig.2  How PT2 reduces type conversion overhead when using AMP.

图 2:PT2 在使用 AMP 时如何降低类型转换开销。

  • 在 GPU 上重用缓冲区:
    • 以全局视角,torch.compile 中的调度器可以在 GPU 上重用缓冲区,从而减少内存分配时间和内存消耗。图 3 显示了调用图 2(a)中程序生成的 Triton 内核的驱动程序。我们可以看到 buf1 被重用为 buf4

Fig.3  Reuse of buffers.

图 3:缓冲区重用。

  • 自动调优:
    • PT2 具有启用矩阵乘法操作、逐点操作和归约操作的自动调优(通过 Triton)的选项。可调参数包括块大小、阶段数和线程束数。通过自动调优,可以经验性地找到操作的最优实现。

3. 生产环境注意事项

在本节中,我们将描述将 PT2 应用于生产时的一些重要考虑因素。

3.1 使用 torch.compile 确保模型质量不下降

将 torch.compile 应用于模型会导致数值变化,因为(1)在融合等优化过程中对浮点运算的重新排序以及(2)启用 AMP 时使用低精度数据类型(如 bf16)。因此,与 PT 1.x 的 100%位兼容性是不预期的。尽管如此,我们仍然需要确保在应用 torch.compile 后模型质量(以某种形式的数值分数衡量)得到保留。通常,每个生产模型都将有其自己的可接受分数范围(例如,百分比变化必须在 0.01%以内)。

如果由于 torch.compile 导致模型质量下降,我们需要进行深度调试。

对于调试与 torch.compile 相关的数值问题,一个有用的技术是使用不同的后端进行 torch.compile,特别是“eager”和“aot_eager”,以及“inductor”:

  • 如果数值问题出现在“eager”后端,那么 torch.compile 构建的前向图可能是不正确的;
  • 如果数值问题在“eager”后端没有出现,但在“aot_eager”后端出现,那么 torch.compile 构建的反向图可能是不正确的;
  • 如果在“eager”或“aot_eager”中都不会出现数字问题,但在“inductor”中出现,那么 inductor 内部的代码生成很可能是不正确的。

3.2 生产环境中的自动调优

默认情况下,torch.inductor 中的自动调优是在模型执行时在线进行的。对于某些生产模型,我们发现自动调优可能需要几个小时,这在生产环境中是不可接受的。因此,我们增加了离线自动调优,其工作原理如图 4 所示。第一次运行模型时,所有需要调优的 ops 的详细信息(例如,输入张量形状、数据类型等)将被记录到数据库中。然后,这些 ops 的调优过程将在夜间运行,以寻找每个 ops 的最优实现;搜索结果将更新到持久缓存中(作为 torch.inductor 的源文件实现)。下次再次运行模型时,每个 ops 的调优实现将可以在缓存中找到,并用于执行。

Fig.4  The offline autotuning used in production.

图 4:生产环境中使用的离线自动调优。

3.3 支持 torch.compile 的配置文件分析

正如我们之前在这篇博客中讨论的,性能分析器对于调试生产模型的性能至关重要。我们增强了性能分析器,使其能够在时间轴上显示 torch.compile 相关的事件。其中最有用的是标记模型中哪些部分正在运行编译后的代码,这样我们可以快速验证模型中应该被编译的部分是否真的被 torch.compile 编译了。例如,图 5 中的跟踪有两个编译区域(带有标签“CompiledFunction”)。其他有用的信息包括编译所花费的时间和访问编译器代码缓存所花费的时间。

Fig.5  A trace with two compiled regions.

图 5:包含两个编译区域的跟踪。

3.4 控制即时编译时间

torch.compile 使用即时编译。编译发生在第一个批次数据训练时。在我们的生产环境中,训练作业达到第一个批次的时间有一个上限,即首次批次时间(TTFB)。我们需要确保启用 torch.compile 不会使 TTFB 超过限制。这可能具有挑战性,因为生产模型很大,而 torch.compile 可能需要大量的编译时间。我们启用并行编译以控制编译时间(这由全局变量 torch/_inductor/config.py 中的 compile_threads 控制,在 OSS Linux 上已设置为 CPU 数量)。一个模型被分解为一个或多个计算图;每个图被分解为多个 Triton 内核。如果启用并行编译,同一图中的所有 Triton 内核可以同时编译(尽管来自不同图的内核仍然按顺序编译)。图 6 说明了并行编译如何帮助。

Fig.6  Using parallel compilation in production.

图 6:在生产中使用并行编译。

4. 结果

在本节中,我们使用三个生产模型来评估 PT2。首先,我们展示了使用 PT2 在不同优化配置下的训练时间加速。其次,我们展示了并行编译对编译时间的重要性。

4.1 使用 torch.compile 进行训练时间加速

图 7 报告了使用 PT2 的训练时间加速。对于每个模型,我们展示了四种情况:(i)不编译使用 bf16,(ii)编译使用 fp32,(iii)编译使用 bf16,(iv)编译使用 bf16 并自动调整。y 轴是相对于基线(不编译使用 fp32)的加速比。请注意,不编译使用 bf16 实际上比不编译使用 fp32 慢,这是因为类型转换的开销。相比之下,使用 bf16 编译通过减少这部分开销实现了更大的加速。总的来说,鉴于这些模型已经通过手工进行了大量优化,我们很高兴看到 torch.compile 仍然可以提供 1.14-1.24 倍的加速。

Fig.7 Training-time speedup with torch.compile (note: the baseline, no-compile/fp32, is  omitted in this figure).

图 7:使用 torch.compile 的训练时间加速(注意:本图省略了基线,即不编译/fp32)。

4.2 并行编译实现编译时间减少

图 8 展示了有并行编译和无并行编译的编译时间。尽管串行编译时间仍有改进空间,但并行编译已将 TTFB 的编译开销降低到可接受的水平。模型 B 和 C 比模型 A 更能从并行编译中受益,因为它们每个图有更多独特的 Triton 内核。

Fig.8 PT2 compilation time.

图 8:PT2 编译时间。

5. 结论

在本文中,我们展示了 PT2 可以显著加快大型和复杂生产 AI 模型的训练速度,同时保持合理的编译时间。在下一篇文章中,我们将讨论 PT2 如何进行通用图转换。

6. 致谢

感谢 Mark Saroufim、Adnan Aziz 和 Gregory Chanan 对他们的详细和有见地的审阅。