由米尔阿德·莫哈马迪、谭洁文、鲁立扬、刘思源、郑耀诺、李元珠、白曼菲、史蒂文·克劳奇克、沙乌欣·扎希拉扎米、亚历克斯·韦尔蒂海姆、梅根·考恩、曹杰、乔·斯皮萨克

背景 & 研究现状

在自然语言处理(NLP)领域,语言模型被设计用来使用过去输入序列的标记(例如单词)生成标记。大型语言模型(LLMs)是这一领域最新的深度学习创新,旨在以人类似的方式生成文本。这些模型通常使用转换器来提高它们对大量输入标记的关注。

Meta AI 开源的 LLaMA 是一个强大的基础模型(LLM),在超过 1T 的标记上进行了训练。LLaMA(13B)与许多顶级模型(如 GPT-3、Chinchilla、PaLM)具有竞争力,并且优于 GPT-3(175B),这突显了它从每个模型参数中提取更多计算能力的能力。

在这篇博客文章中,我们以 LLaMA 为例,展示了 PyTorch/XLA 在LLM推理方面的能力。我们讨论了这里讨论的计算技术和优化如何将 65B 参数的 LLaMA 模型在 Google Cloud TPU v4(v4-16)上的推理延迟提高了 6.4 倍。

模型概述

我们展示了 PyTorch/XLA 在 LLaMA 上的性能能力,LLaMA 是 Meta 的最新LLM。我们展示了一系列常见 LLaMA 配置的性能优化。请注意,175B 参数的模型配置在公共领域不存在。对于下面提到的 175B 参数模型,我们将 OPT 175B 模型配置应用于 LLaMA 代码库。除非另有说明,否则在所有配置中,我们使用 max_seq_len=256dtype=bfloat16 作为权重和激活。

表 1:本文探讨的模型配置

LLaMA 模型超参数
# 参数 维度 N 头 N 层 最大序列长度
7B 4,096 32 32 256
33B 6,656 52 60 256
65B 8,192 64 80 256
175B 12,288 96 96 256

LLMs 的性能挑战

LLMs 具有一些特性,使得它们对编译器优化具有挑战性。(a)LLMs 使用自回归解码来生成基于前一个的下一个标记;这意味着提示张量和教练具有动态形状。(b)LLMs 必须与可变输入提示长度一起工作,而不会因为输入张量形状变化而触发重新编译;输入张量必须进行适当的分桶和填充,以避免重新编译。(c)LLMs 通常需要的内存比单个 TPU(或 GPU)设备可以支持的还要多。需要一个模型分片方案,以便将模型适应分布式计算架构。例如,一个 65B 参数的 LLaMA 模型可以适应 v4-16 Cloud TPU,这相当于 8 个 A100 GPU。(d)在生产中运行 LLMs 可能很昂贵;提高性能/总拥有成本(Perf/TCO)的一种方法是通过量化;量化可能可以减少硬件需求。

PyTorch/XLA 中的推理技术栈

我们的目标是向人工智能社区提供一个高性能的推理栈。PyTorch/XLA 与 TorchDynamo、PjRt、OpenXLA 以及各种模型并行方案集成。TorchDynamo 消除了运行时的跟踪开销,PjRt 实现了高效的宿主-设备通信;PyTorch/XLA 可追踪的集体操作通过 TorchDynamo 在 LLaMA 上实现了模型和数据并行。要尝试我们的结果,请使用我们定制的 torch、torch-xla 轮子来重现我们的 LLaMA 推理解决方案。PyTorch/XLA 2.1 将默认支持本文中讨论的功能。

并行计算

公平分片

LLaMA 使用 FairScale 模型分片 API(fairscale.nn.model_parallel.layers)。我们使用 PyTorch/XLA 通信集体(CC)操作(如 all-reduce )构建了此 API 的等效表示,以在加速器之间通信程序状态(例如激活)。TorchDynamo 目前不完全支持捕获 CC 操作(即可追踪的集体)。没有这种支持,TorchDynamo FX 图会在每次设备通信时被切断,这意味着在每个模型层。图切分会导致性能损失,因为底层的 XLA 编译器失去了完整的图优化机会。为了解决这个问题,我们通过将调度器集体集成到现有的 CC API 中,提供了 PyTorch/XLA 可追踪的集体。不同之处在于,由于 PyTorch/XLA 的懒执行特性,我们不需要在集体之后插入 c10d.wait() 操作。有了对可追踪集体的支持,PyTorch/XLA 允许在 TorchDynamo 中生成单个 FX 图。

PyTorch/XLA 上的自回归解码

需要自回归解码以将前一个单词作为提示来预测下一个标记。自回归解码会导致动态形状问题无界,进而导致每次提示都需要重新编译。我们对 LLaMA 自回归解码器进行了优化,使其能够以固定形状运行,并在每次生成标记时就地更新 KV 缓存、输出序列和注意力掩码。通过结合填充、掩码和索引操作,我们避免了过度图形重新编译,从而实现了高效的自回归解码。

KV-Cache 优化

LLaMA 实现了带有 KV 缓存的自动回归解码。对于每个生成的标记,KV 缓存存储了每个 Transformer 层的注意力键/值激活。因此,在解码新标记时,先前标记的键/值不再需要重新计算。

在 LLaMA 中,KV 缓存张量的切片就地更新;这导致每次生成标记时都会发生重新编译事件。为了解决这个问题,我们使用索引张量和 tensor.index_copy() 操作来替换就地切片更新。注意力掩码和输出序列也受益于相同的优化。

输入提示优化

变长输入提示在LLM应用中很常见。这种属性会导致输入张量形状动态变化,进而引发重新编译事件。在处理提示以填充 KV 缓存时,我们或者(a)逐个处理输入提示的标记,或者(b)在一次迭代中处理整个提示。每种方法的优缺点如下:

  1. 预编译 1 个图,逐个处理提示标记
    • 实用:在预热期间编译 1 个图
    • 慢:O(L) 处理输入提示长度 L - 对于长提示是一个缺点
  2. 预编译所有输入长度从 1 到 max_seq_len(例如 2,048)的图
    • 不切实际:在预热时间内预编译并缓存 max_seq_len 个图
    • 快:执行 1 个图来处理整个提示

我们引入了提示长度分桶技术,这是一种在两种选择之间取得平衡的优化方法。我们定义了一组递增的分桶大小,(b 0 ,b 1 ,b 2 ,…,b B-1 ),然后根据这些桶值预编译程序图,(G 0 ,G 1 ,G 2 ,…,G B-1 );B 是桶的数量。对于给定的输入提示,我们将提示长度向上取整到最接近的桶值 b n ,填充序列,并使用 G n 在一次迭代中处理提示。对填充标记的计算将被丢弃。对于大于最大桶大小的提示,我们将它们分部分处理。

最佳桶大小应由目标应用中的提示长度分布来确定。在这里,我们采用桶长度:128、256、384、512。任何最多包含 2,047 个标记的输入提示都需要最多 4 次图执行。例如,一个 1,500 个标记的输入提示,生成长度为 256,需要 260 次图执行 - 4 次处理输入,256 次生成输出。

量化

量化减少了表示值所需的位数;它减少了通过集体(collectives)在多个加速节点之间通信数据所需的带宽,并降低了服务于特定模型大小的硬件要求。

通常情况下,使用 BF16 权重,一个 175B 参数模型会消耗大约 351GB 的内存,因此需要 v4-32 实例来容纳该模型。通过将权重量化到 INT8 ,我们大致将模型大小减少了 50%,使其能够在更小的 v4-16 实例上运行。由于 LLaMA 分片模型激活,量化提供了微不足道的通信增益。

在我们的实验中,我们对线性层进行了量化。由于 LLaMA 模型检查点未公开,我们的目标是评估性能,因此量化模型使用随机权重初始化。最近的文献,如 AWQ 和整数或浮点?,为 LLaMA 在各种低比特量化方案下的性能特性提供了见解。

批大小对量化性能的影响

TPU v4 在模型批大小(BS)大于 1 时,会在矩阵乘法单元(MXU)上运行 matmul 。对于 BS=1, matmul 将在向量处理器单元(VPU)上运行。由于 MXU 比 VPU 更高效,因此在 BS>1 时, INT8 量化可以提升性能。具体细节请参见性能分析部分。

操作支持

有时,新模型会引入新的数学运算,这需要 PyTorch/XLA 扩展其支持的运算集以进行编译。对于 LLaMA,我们支持了:multinomial。

方法论

LLaMA 在 PyTorch/XLA 上无需额外配置即可直接运行在 LazyTensorCore 上。我们将此配置作为后续分析的基准。所有实验假设输入提示长度为 256。在没有公开可用的模型检查点的情况下,我们使用了随机张量初始化来进行此推理堆栈优化工作。模型检查点不会改变此处讨论的延迟结果。

模型规模

假设 N 是参数数量, dimensions 是隐藏层大小, n_layers 是层数, n_heads 是注意力头数量,以下方程可以用来近似模型大小。详见模型概述部分。

N = (dimensions)^2 * n_layers * 12

n_heads 不影响 N ,但对于开源模型配置,以下方程成立。

dim = 128 * n_heads

缓存规模

模型参数和注意力块中的缓存层都对内存消耗有贡献。由于默认的 LLaMA 模型使用 BF16 权重,本节中的内存消耗计算基于 BF16 权重。

缓存层的大小通过 cache_size = max_batch_size * max_seq_len * dimensions 计算。 max_batch_size = 1max_seq_len = 256 total_cache_size = n_layers * 2 * cache_size * (2 bytes) 在以下计算中用作示例配置。每个注意力块中有 2 个缓存层。因此,LLaMA 的总缓存大小(以字节为单位)为 total_cache_size = n_layers * 2 * cache_size * (2 bytes)

TPU v4 硬件规格

每个 TPU v4 芯片有 32GB 的高带宽内存(HBM)。表 2 中详细说明了内存消耗和容纳 LLaMA 模型所需的 TPU 芯片数量。

表格 2:LLaMA TPU v4 HBM 需求(即 TPU v4 芯片需求)

# 参数 参数(MB) 缓存(MB) 总计(GB) 最小 TPU v4 芯片数量
7B 14,000 134 14.128 1
33B 66,000 408 66.41 3
65B 130,000 671 130.67 5
175B 350,000 1,208 351.21 11

指标

以下是衡量推理速度的有用指标。假设 T 为总时间, B 为批次大小, L 为解码序列长度。

延迟定义

延迟是指获取目标长度解码结果所需的时间,与批次大小无关。延迟表示用户等待从生成模型获取响应的时间。

Latency = T (s)

每个 token 的延迟

自回归解码的一步为批次中的每个样本生成一个 token。每个 token 的延迟是指这一步的平均时间。

Per-token latency = T / L (s/token)

吞吐量

吞吐量衡量每单位时间内生成的标记数量。虽然它不是评估在线服务的有用指标,但它可以用来衡量批量处理的速度。

Throughput = B * L / T (tokens/s)

为了减少混淆和误解,最好避免像 T / (B * L) 这样的指标,它将延迟和吞吐量混合在一起。

结果

图 1 显示了 LLaMA 7B 到 175B 模型的延迟/标记结果。在每种情况下,模型都在一系列 TPU v4 配置上运行。例如,LLaMA 7B 在 v4-8 和 v4-16 上分别显示 4.7ms/标记和 3.8ms/标记。更多比较,请访问 HuggingFace LLM性能排行榜。

在没有本博客文章中讨论的功能的情况下,运行在 v4-32 上的 LLaMA 65B 提供 120ms/标记,而不是这里获得的 14.5ms/标记,从而实现了 8.3 倍的速度提升。如前所述,鼓励开发者尝试我们的自定义 torch、torch-xla wheels,这些工具可以解锁这里共享的 LLaMA 推理结果的重现。

Figure 1: LLaMA Inference Performance on TPU v4 hardware

图 1:LLaMA 在 TPU v4 硬件上的推理性能

PyTorch/XLA:GPU 的性能优于 PyTorch:GPU eager,与 PyTorch Inductor 相似。PyTorch/XLA:TPU 的性能优于 PyTorch/XLA:GPU。在不久的将来,XLA:GPU 将提供优化,以实现与 XLA:TPU 的等效性。单个 A100 配置仅适用于 LLaMA 7B,而 8-A100 不适用于 LLaMA 175B。

Figure 2: LLaMA Inference Performance on GPU A100 hardware

图 2:LLaMA 在 GPU A100 硬件上的推理性能

随着批处理大小的增加,我们观察到每令牌延迟呈次线性增长,这突显了硬件利用率和延迟之间的权衡。

Figure 3: LLaMA Inference Performance across different batch sizes

图 3:LLaMA 在不同批处理大小下的推理性能

我们的研究表明,最大序列输入长度对推理延迟的影响相对较小。我们将其归因于标记生成的顺序和迭代性质。性能的小幅差异可能是因为随着存储大小的增加,KV 缓存访问延迟的变化。

Figure 4: LLaMA Inference Performance across different prompt lengths

图 4:LLaMA 在不同提示长度下的推理性能

LLMs通常是内存受限的应用程序;因此,通过量化模型参数,我们能够在单位时间内加载和执行更大的张量(即 HBM⇒CMEM 和 CMEM⇒MXU 数据移动)。图 5 显示,仅对 INT8 权重进行量化可以提供 1.6x-1.9x 的速度提升,从而在给定硬件上运行更大的模型。

当 BS=1 时,INT8 张量被发送到比 MXU 小的 VPU(参见 TPU v4 论文);否则,使用 MXU。因此,当 BS=1 时,量化内存带宽增益被 MXU 利用率不足所抵消。然而,当 BS>1 时,内存增益在量化模型上提供了更优的延迟。例如,在 175B 参数的 LLaMA 的情况下,v4-16 带量化与 v4-32 不带量化提供了相似的性能。请注意,我们没有提供 FP8 比较,因为 PyTorch 尚未提供此数据类型。

Figure 5: LLaMA Inference Performance vs. weight-only quantization. The missing blue bars suggest the model size doesn’t fit in the specified TPU hardware.

图 5:LLaMA 推理性能与仅权重量化的对比。缺失的蓝色条表示模型大小不适合指定的 TPU 硬件。

图 6 展示了当输入提示长度从 10 个 token 增长到 1,500 个 token 时,PyTorch/XLA 的稳定性能优势。这种强大的扩展能力表明 PyTorch/XLA 的重编译事件最少,从而支持广泛的实际应用。在这个实验中,最大长度为 2,048,最大生成长度为 256。

Figure 6: LLaMA Inference Performance vs. Input Prompt Length

图 6:LLaMA 推理性能与输入提示长度的对比

总结

我们对 PyTorch/XLA 的未来感到兴奋,并邀请社区加入我们。PyTorch/XLA 完全开源开发。因此,请将问题、提交拉取请求和发送 RFC 发送到 GitHub,以便我们可以公开协作。您还可以在包括 TPUs 和 GPU 在内的各种 XLA 设备上尝试 PyTorch/XLA。

干杯,
谷歌 PyTorch/XLA 团队
#PoweredByPyTorch