由佩德罗·库恩卡、帕特里克·冯·普拉滕、苏拉杰·帕提尔、萨亚克·保罗所著

PyTorch 2.0 刚刚发布。其标志性的新特性是 torch.compile() ,一行代码的更改承诺将自动提升代码库的性能。我们之前已经在 Hugging Face Transformers 和 TIMM 模型上验证了这一承诺,并深入探讨了其动机、架构和未来的道路。

尽管 torch.compile() 很重要,但 PyTorch 2.0 还有更多内容。值得注意的是,PyTorch 2.0 整合了多种加速 transformer 块的策略,这些改进对于扩散模型也非常相关。例如,FlashAttention 等技术因其在显著加速 Stable Diffusion 和实现更大批处理大小方面的能力而在扩散社区中变得非常流行,现在这些技术已成为 PyTorch 2.0 的一部分。

在本文中,我们讨论了 PyTorch 2.0 中注意力层的优化方式以及这些优化如何应用于流行的🧨 Diffusers 库。最后,我们通过基准测试展示了使用 PyTorch 2.0 和 Diffusers 如何立即在不同硬件上实现显著的性能提升。

更新(2023 年 6 月):新增了一个部分,展示了使用 PyTorch(2.0.1)最新版本后,经过修复 diffusers 代码库中的图断开问题, torch.compile() 的显著性能提升。关于如何查找和修复图断开问题的更详细分析将单独发布在另一篇博文中。

加速 Transformer 块

PyTorch 2.0 将缩放点积注意力函数作为 torch.nn.functional 的一部分。这个函数包含几个根据输入和硬件使用情况可应用的实现。在 PyTorch 2.0 之前,您需要寻找第三方实现并安装单独的包,才能利用内存优化算法,如 FlashAttention。可用的实现包括:

  • 来自官方 FlashAttention 项目的 FlashAttention。
  • 内存高效的注意力机制,来自 xFormers 项目。
  • 适用于非 CUDA 设备或需要高精度的本地 C++实现。

所有这些方法默认可用,PyTorch 将通过使用新的缩放点积注意力(SDPA)API 自动尝试选择最佳方案。您还可以单独切换它们以获得更精细的控制,请参阅文档以获取详细信息。

在 diffusers 中使用缩放点积注意力机制。

通过使用 set_attn_processor 方法,将加速 PyTorch 2.0 Transformer 注意力机制集成到 Diffusers 库中,实现了可配置的插件式注意力模块。在这种情况下,创建了一个新的注意力处理器,当 PyTorch 2.0 可用时默认启用。为了清晰起见,这是您手动启用它的方法(但通常不需要,因为 diffusers 会自动处理):

from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.to("cuda")
pipe.unet.set_attn_processor(AttnProcessor2_0())

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]

稳定扩散基准

我们在 Diffusers 中使用了 PyTorch 2.0 的加速点积注意力机制进行了一系列测试。我们通过 pip 安装了 diffusers,并使用了 PyTorch 2.0 的夜间版本,因为我们的测试是在官方发布之前进行的。我们还使用了 torch.set_float32_matmul_precision('high') 来启用额外的快速矩阵乘法算法。

我们将结果与 diffusers 中的传统注意力实现(以下称为 vanilla )以及 2.0 版本之前性能最佳的解决方案 PyTorch 1.13.1(安装了 xFormers 包 v0.0.16)进行了比较。

测量结果是在没有编译的情况下进行的(即,没有任何代码更改),同时也通过调用 torch.compile() 来包装 UNet 模块。我们没有编译图像解码器,因为在 50 次去噪迭代中运行 UNet 评估所花费的时间最多。

结果以 float32 格式

Diffusers Speedup vs xFormers float32

以下图表探讨了不同代代表性 GPU 的性能提升与批处理大小的关系。我们收集了每种组合的数据,直到达到最大内存利用率。Vanilla attention 比 xFormers 或 PyTorch 2.0 更早耗尽内存,这解释了为什么在大批处理大小下缺少条形图。同样,A100(我们使用的是 40 GB 版本)能够运行 64 个批处理大小,但其他 GPU 在我们的测试中只能达到 32。

Diffusers Inference Speedup vs Vanilla and xFormers Attention (A100, float32)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (3090, float32)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (4090, float32)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (V100, float32)

我们在所有方面都发现了相对于 vanilla attention 的非常显著的性能提升,甚至没有使用 torch.compile() 。PyTorch 2.0 和 diffusers 的即插即用安装,在 A100 上实现了约 50%的速度提升,在 4090 GPU 上则实现了 35%到 50%的速度提升,具体取决于批处理大小。对于现代 CUDA 架构,如 Ada(4090)或 Ampere(A100),性能提升更为明显,但对于仍在云服务中大量使用的旧架构,性能提升仍然非常显著。

除了更快的速度外,PyTorch 2.0 中加速的 transformers 实现允许使用更大的批处理大小。单个 40GB 的 A100 GPU 在批处理大小为 10 时就会耗尽内存,而 24GB 的高端消费级显卡,如 3090 和 4090,一次无法生成 8 张图像。使用 PyTorch 2.0 和 diffusers,我们能够实现 3090 和 4090 的批处理大小为 48,A100 的批处理大小为 64。这对于云服务和应用来说意义重大,因为它们可以更有效地一次处理更多图像。

与 PyTorch 1.13.1 + xFormers 相比,新的加速 Transformer 实现仍然更快,且无需额外的包或依赖。在这种情况下,我们在 A100 或 T4 等数据中心卡上发现了高达 2%的适度加速,但在消费卡的最后两代上性能出色:3090 上高达 20%的速度提升,4090 上根据批大小在 10%到 45%之间。

当使用 torch.compile() 时,我们还能获得额外的性能提升(通常是)2%和 3%,超过之前的改进。由于编译需要一些时间,这更适合面向用户的推理服务或训练。更新:当最小化图中断时, torch.compile() 实现的改进要大得多,请参阅新部分以获取详细信息。

浮点 16 的结果

Diffusers Speedup vs xFormers float16

Diffusers Inference Speedup vs Vanilla and xFormers Attention (A100, float16)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (4090, float16)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (3090, float16)

当我们考虑 float16 推理时,在 PyTorch 2.0 中加速的 Transformer 实现相对于标准注意力机制,在所有测试的 GPU 上性能提升了 20%至 28%,除了属于更现代 Ada 架构的 4090。这款 GPU 在使用 PyTorch 2.0 的 nightly 版本时,性能得到了显著提升。至于优化后的 SDPA 与 xFormers 相比,大多数 GPU 上的结果通常相当,但同样,4090 是个例外。加入 torch.compile() 后,整体性能又提升了几个百分点。

优化后 torch.compile() 的性能

在前面的章节中,我们看到了使用 PyTorch 2.0 加速的 Transformer 实现相对于 PyTorch 的早期版本(无论是否有 xFormers)提供了重要的性能提升。然而, torch.compile() 只带来了适度的边际改进。在 PyTorch 团队的帮助下,我们发现这些适度改进的原因是 diffusers 源代码中的一些操作导致了图断开,这阻碍了 torch.compile() 充分利用图优化。

在修复了图中断后(请参阅这些 PR 以获取详细信息),我们测量了 torch.compile() 与 PyTorch 2 未编译版本的额外改进,并看到了非常重要的性能提升。以下图表是使用 2023 年 5 月 1 日下载的 PyTorch 2 夜间版本获得的,它显示大多数工作负载的性能改进在 13%到 22%之间。对于现代 GPU 系列,性能提升更明显,对于 A100 来说,性能提升超过 30%。图表中还有两个异常值。首先,我们发现在 T4 上,对于 16 个批次的处理,性能有所下降,这对该卡的内存压力很大。在另一端,当使用仅 1 个批次的处理时,我们在 A100 上看到了超过 100%的性能提升,这很有趣,但并不代表具有如此大量 RAM 的 GPU 的实际使用情况——能够服务多个客户的更大批次通常对 A100 的服务部署更有趣。

Diffusers Speedup using torch.compile() in float16

再次强调,这些性能提升是除了迁移到 PyTorch 2 和使用加速的 Transformer 缩放点积注意力实现所获得的之外。我们建议在生产环境中部署 diffusers 时使用 torch.compile()

结论

PyTorch 2.0 自带多个功能来优化基础 Transformer 块的关键组件,并且可以通过使用 torch.compile 进一步改进。这些优化为扩散模型带来了显著的内存和时间改进,并消除了安装第三方库的需求。

要利用这些速度和内存改进,您只需升级到 PyTorch 2.0 并使用 diffusers >= 0.13.0 即可。

更多示例和详细的基准测试数据,请参阅 PyTorch 2.0 的 Diffusers 文档。

致谢

作者们对 PyTorch 团队创建如此优秀的软件表示衷心的感谢。