由 Sayak Paul 和 Patrick von Platen(Hugging Face 🤗)撰写

本文是关于如何使用纯原生 PyTorch 加速生成式 AI 模型的系列博客的第三部分。我们很高兴与大家分享一系列新发布的 PyTorch 性能特性,并通过实际示例展示我们可以将 PyTorch 原生性能提升到何种程度。在第一部分中,我们展示了如何仅使用纯原生 PyTorch 将 Segment Anything 加速超过 8 倍。在第二部分中,我们展示了如何仅使用原生 PyTorch 优化将 Llama-7B 加速近 10 倍。在本篇博客中,我们将重点关注如何将文本到图像扩散模型的速度提升至 3 倍。

我们将利用一系列优化,包括:

  • 使用 bfloat16 精度运行
  • 缩放点积注意力(SPDA)
  • torch.compile
  • 结合 q,k,v 投影进行注意力计算
  • 动态 int8 量化

我们将主要关注 Stable Diffusion XL(SDXL),展示 3 倍的延迟提升。这些技术是 PyTorch 原生,这意味着您无需依赖任何第三方库或任何 C++代码即可利用它们。

使用🤗Diffusers 库启用这些优化只需几行代码。如果您已经兴奋不已,迫不及待想要查看代码,请查看随附的仓库:https://github.com/huggingface/diffusion-fast。

SDXL Chart

(所讨论的技术并非 SDXL 专用,可以用于加速其他文本到图像的扩散系统,如后文所示。)

以下您可以找到一些关于类似主题的博客文章:

设置

我们将使用🤗 Diffusers 库演示优化及其相应的加速增益。除此之外,我们还将使用以下 PyTorch 原生库和环境:

  • Torch 夜间版本(以利用高效的注意力核;2.3.0.dev20231218+cu121)
  • 🤗 PEFT(版本:0.7.1)
  • 火烧 (提交 SHA: 54bcd5a10d0abbe7b0c045052029257099f83fd9)
  • CUDA 12.1

为了更简单的复现环境,您也可以参考这个 Dockerfile。本文中展示的基准测试数据来自 400W 80GB A100 GPU(其时钟频率设置为最大容量)。

由于我们在这里使用 A100 GPU(Ampere 架构),我们可以指定 torch.set_float32_matmul_precision("high") 以利用 TF32 精度格式。

使用降低精度运行推理

在 Diffusers 中运行 SDXL 只需几行代码:

from diffusers import StableDiffusionXLPipeline

## Load the pipeline in full-precision and place its model components on CUDA.
pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0").to("cuda")

## Run the attention ops without efficiency.
pipe.unet.set_default_attn_processor()
pipe.vae.set_default_attn_processor()

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt, num_inference_steps=30).images[0]

但这并不实用,因为使用 30 步生成一张图片需要 7.36 秒。这是我们基准,我们将一步步尝试优化。

SDXL Chart

在这里,我们使用全精度运行管道。我们可以通过使用降低精度,如 bfloat16,立即减少推理时间。此外,现代 GPU 配备了用于运行加速计算的专用核心,从而受益于降低精度。要使用 bfloat16 精度运行管道的计算,我们只需在初始化管道时指定数据类型即可:

from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
	"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

## Run the attention ops without efficiency.
pipe.unet.set_default_attn_processor()
pipe.vae.set_default_attn_processor()
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt, num_inference_steps=30).images[0]

SDXL Chart

通过使用降低精度,我们能够将推理延迟从 7.36 秒减少到 4.63 秒。

关于使用 bfloat16 的一些笔记

  • 使用降低的数值精度(如 float16、bfloat16)进行推理不会影响生成质量,但会显著提高延迟。
  • 与 float16 相比,使用 bfloat16 数值精度的优势取决于硬件。现代 GPU 的生成代数通常更倾向于 bfloat16。
  • 此外,在我们的实验中,我们发现与 float16 相比,bfloat16 在量化使用时具有更高的鲁棒性。

我们后来在 float16 上运行了实验,发现 torchao 的最新版本不会因为 float16 而产生数值问题。

使用 SDPA 进行注意力计算

默认情况下,Diffusers 在 PyTorch 2 中使用 scaled_dot_product_attention (SDPA)进行与注意力相关的计算。SDPA 提供了更快、更高效的内核来运行密集的注意力相关操作。要运行 SDPA 管道,我们只需不设置任何注意力处理器,如下所示:

from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
	"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt, num_inference_steps=30).images[0]

SDPA 将性能提升了 4.63 秒到 3.31 秒。

SDXL Chart

编译 UNet 和 VAE

我们可以通过使用 torch.compile 来请求 PyTorch 执行一些低级优化(例如操作融合和通过 CUDA 图启动更快的内核)。对于 StableDiffusionXLPipeline ,我们编译去噪器(UNet)和 VAE:

from diffusers import StableDiffusionXLPipeline
import torch

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

## Compile the UNet and VAE.
pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

## First call to `pipe` will be slow, subsequent ones will be faster.
image = pipe(prompt, num_inference_steps=30).images[0]

使用 SDPA 注意力并编译 UNet 和 VAE 将延迟从 3.31 秒降低到 2.54 秒。

SDXL Chart

关于 torch.compile 的注意事项

提供不同的后端和模式。为了实现最大的推理速度,我们选择使用“max-autotune”的感应器后端。“max-autotune”使用 CUDA 图,并针对延迟优化编译图。使用 CUDA 图大大减少了启动 GPU 操作的开销。通过使用一种机制,通过单个 CPU 操作启动多个 GPU 操作,从而节省时间。

fullgraph 指定为 True 确保底层模型中没有图断开,从而确保 torch.compile 的最大潜力。在我们的案例中,以下编译器标志也被明确设置是很重要的:

torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

有关编译器标志的完整列表,请参阅此文件。

我们还在编译 UNet 和 VAE 时将内存布局更改为“channels_last”,以确保最大速度:

pipe.unet.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

在下一节中,我们将展示如何进一步降低延迟。

其他优化

torch.compile 期间没有图断开

确保底层模型/方法可以完全编译对于性能至关重要( torch.compilefullgraph=True )。这意味着没有图断开。我们通过改变访问返回变量的方式对 UNet 和 VAE 进行了修改。以下是一个示例:

code example

编译后去除 GPU 同步

在迭代反向扩散过程中,每次在降噪器预测出更少的噪声潜在嵌入后,我们都会在调度器上调用 step() 。在 step() 中,对 sigmas 变量进行索引。如果将 sigmas 数组放置在 GPU 上,索引会导致 CPU 和 GPU 之间的通信同步。这会导致延迟,当降噪器已经被编译时,这种现象更为明显。

但如果 sigmas 数组始终保持在 CPU 上(参考此行),则不会发生此同步,从而提高了延迟。一般来说,任何 CPU <-> GPU 通信同步都应该是无的或保持最小,因为它会影响推理延迟。

使用组合投影进行注意力操作

SDXL 中使用的 UNet 和 VAE 都采用了类似 Transformer 的模块。一个 Transformer 模块由注意力模块和前馈模块组成。

在注意力模块中,输入通过三个不同的投影矩阵 Q、K 和 V 被投影到三个子空间。在原始实现中,这些投影是分别对输入进行的。但我们可以将投影矩阵水平组合成一个单一的矩阵,并一次性进行投影。这增加了输入投影的 matmuls 大小,并提高了量化的影响(将在下文中讨论)。

在 Diffusers 中启用这种计算只需一行代码:

pipe.fuse_qkv_projections()

这将使 UNet 和 VAE 的注意力操作都能利用组合投影。对于交叉注意力层,我们只组合键和值矩阵。想了解更多,可以参考官方文档。值得注意的是,我们在这里内部使用了 PyTorch 的 scaled_dot_product_attention

这些额外技术将推理延迟从 2.54 秒降低到 2.52 秒。

SDXL Chart

动态 int8 量化

我们有选择地将动态 int8 量化应用于 UNet 和 VAE。这是因为量化给模型增加了额外的转换开销,但希望这些开销可以通过更快的 matmuls(动态量化)来弥补。如果 matmuls 太小,这些技术可能会降低性能。

通过实验,我们发现 UNet 和 VAE 中的某些线性层并不受益于动态 int8 量化。您可以在此处查看过滤这些层的完整代码(以下称为 dynamic_quant_filter_fn )。

我们利用超轻量级的纯 PyTorch 库 torchao,使用其用户友好的 API 进行量化:

from torchao.quantization import apply_dynamic_quant

apply_dynamic_quant(pipe.unet, dynamic_quant_filter_fn)
apply_dynamic_quant(pipe.vae, dynamic_quant_filter_fn)

由于这种量化支持仅限于线性层,我们还把合适的点卷积层转换为线性层以最大化效益。在使用此选项时,我们还指定了以下编译器标志:

torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

为了防止量化引起的任何数值问题,我们以 bfloat16 格式运行所有操作。

以这种方式应用量化将延迟从 2.52 秒降低到 2.43 秒。

SDXL Chart

资源

欢迎您查看以下代码库以复现这些数字,并将这些技术扩展到其他文本到图像扩散系统中:

其他链接

其他管道的改进

我们将这些技术应用于其他管道以测试我们方法的通用性。以下是我们的发现:

SSD-1B

SSD-1B Chart

稳定扩散 v1-5

Stable Diffusion v1-5 chart

PixArt-alpha/PixArt-XL-2-1024-MS

值得注意的是,PixArt-Alpha 使用基于 Transformer 的架构作为其反向扩散过程的去噪器,而不是 UNet。

PixArt-alpha/PixArt-XL-2-1024-MS chart

注意,对于 Stable Diffusion v1-5 和 PixArt-Alpha,我们没有探索应用动态 int8 量化的最佳形状组合标准。可能通过更好的组合可以得到更好的数值。

我们提出的这些方法总体上在生成质量没有下降的情况下,比基线方法提供了显著的加速。此外,我们相信这些方法应该补充社区中流行的其他优化方法(如 DeepCache、Stable Fast 等)。

结论和下一步

在这篇文章中,我们介绍了一系列简单而有效的技术,可以帮助提高纯 PyTorch 中文本到图像扩散模型的推理延迟。总之:

  • 使用降低精度进行我们的计算
  • 缩放点积注意力,以高效运行注意力块
  • 使用“max-autotune”编译 torch 以提升延迟
  • 将不同的投影组合起来进行注意力计算
  • 动态 int8 量化

我们相信在如何将量化应用于文本到图像的扩散系统方面还有很多可以探索的空间。我们没有彻底探索 UNet 和 VAE 中哪些层倾向于从动态量化中受益。可能有机会通过更好地组合目标量化层来进一步加快速度。

我们保留了 SDXL 的文本编码器,除了在 bfloat16 模式下运行之外,没有对其进行修改。优化它们也可能导致延迟的改进。

致谢

感谢 Ollin Boer Bohan,其 VAE 在整个基准测试过程中被使用,因为它在降低数值精度下数值更稳定。

感谢来自 Hugging Face 的 Hugo Larcher 在基础设施方面的帮助。