由 IBM 的 PyTorch 团队

推测性解码是一种推理优化技术,它在生成当前标记的同时,对未来的标记进行有根据的猜测,所有这些都在单个前向传递中完成。它包含一个验证机制来确保这些推测标记的正确性,从而保证推测性解码的整体输出与普通解码相同。优化大型语言模型(LLMs)的推理成本可能是降低生成式 AI 成本并提高其采用率的最关键因素之一。为此目标,有各种推理优化技术可用,包括自定义内核、动态批处理输入请求以及大型模型的量化。

在本文中,我们提供了推测性解码的指南,并展示了它如何与其他优化技术共存。我们自豪地开源以下内容,其中包括 Llama3 模型的第一个推测器:

  1. Meta Llama3 8B、IBM Granite 7B 实验室、Meta Llama2 13B 和 Meta Code Llama2 13B 的推测器模型。
  2. 通过 IBM 对 HF TGI 的分支进行推理的代码。
  3. 训练您自己的投机者和相应食谱的代码。

我们将这些投机者部署在拥有数千名日常用户的内部生产级环境中,观察到语言模型(Llama3 8B、Llama2 13B、IBM Granite 7B)的速度提升了 2 倍,IBM 的 Granite 20B 代码模型的速度提升了 3 倍。我们在本技术报告中详细解释了我们的方法,并计划在即将发表的 ArXiv 论文中进行深入分析。

投机性解码:推理

我们在我们的内部生产环境中运行 IBM TGIS,该环境具有连续批处理、融合内核和量化内核等优化。为了在 TGIS 中启用推测性解码,我们修改了来自 vLLM 的分页注意力内核。以下,我们将描述为了启用推测性解码而对推理引擎所做的关键更改。

推测性解码基于模型足够强大,可以在单次前向传递中预测多个标记的假设。然而,当前的推理服务器优化为一次只预测一个标记。在我们的方法中,我们在LLM上附加多个推测性头(除了通常的一个之外),以预测 N+1、N+2、N+3 等标记。例如,3 个头将预测 3 个额外的标记。推测性架构的详细信息将在本博客的后续部分中解释。在推理过程中实现效率和正确性有两个挑战 - 一个是在不复制 KV 缓存的情况下进行预测,另一个是验证预测是否与原始模型的输出匹配。

在典型的生成循环中,在单次正向步骤处理完提示后,将序列长度为 1(预测的下一个标记)输入模型的正向传递中,同时输入 kv-cache。在简单的投机解码实现中,每个投机头都有自己的 kv-cache,但我们将 vLLM 项目开发的分页注意力内核进行了修改,以实现高效的 kv-cache 维护。这确保了在更大的批量大小下吞吐量不会降低。此外,我们修改了注意力掩码,以验证 N+1 个标记,从而在不偏离原始模型输出的情况下实现投机解码。此实现的详细信息在此处记录。

结果

我们使用一个简单的提示展示了使用 Meta 的 Llama2 13B 聊天版本的加速效果。

Visual illustration of the non-speculative generation (left) compared to speculative generation (right)

图 2:非投机生成(左侧)与投机生成(右侧)的视觉说明

我们已在内部生产环境中部署了上述解决方案。下图中报告了两个指标——首次标记时间(TTFT)和标记间延迟(ITL),以及不同并发用户数量(这些数据在图表线条上的数字中体现)。我们观察到,对于 Llama2 13B 聊天模型,推测性解码版本的速度几乎是非推测性版本的两倍,对于 Granite 20B 代码模型,速度几乎是三倍,对于所有批处理大小都是如此。对于较小的模型——IBM 的 Granite 7B 和 Meta Llama3 8B 模型,我们也观察到类似的行为。

Time to first token (TTFT - left) and Inter-token latency (ITL - right) for Llama 13B with number of concurrent users indicated on the graph

图 3:Llama 13B 的首次标记时间(TTFT - 左)和标记间延迟(ITL - 右),图中标明了并发用户数量

Time to first token (TTFT - left) and Inter-token latency (ITL - right) for Granite 20B Code with number of concurrent users indicated on the graph

图 4:Granite 20B 代码的首次标记时间(TTFT - 左)和标记间延迟(ITL - 右),图中标明了并发用户数量

效率说明

我们进行了多次实验,以确定投机训练的正确配置。这些配置包括:

  1. 投机架构:当前方法允许修改头数,这对应于我们可以预览的标记数量。增加头数也会增加额外的计算需求以及训练的复杂性。在实践中,对于语言模型,我们发现 3-4 个头数效果较好,而对于代码模型,我们发现 6-8 个头数可以带来好处。
  2. 计算:增加头数会导致计算在两个维度上增加,一方面是单次前向传播的延迟增加,另一方面是多个标记所需的计算量。如果投机模型在更多头数下不准确,将导致计算浪费,增加延迟并降低吞吐量。
  3. 内存:增加的计算需求可以通过每个前向传播所需的 HBM 往返次数来抵消。请注意,如果我们正确预测了 3 个标记的预览,我们就节省了三次 HBM 往返时间。

我们为语言模型选择了 3-4 个头,为代码模型选择了 6-8 个头,在 7B 到 20B 的不同模型规模范围内,我们观察到与非投机解码相比,在吞吐量没有损失的情况下,显著提高了延迟。我们开始观察到在超过 64 个批次的吞吐量下降,这在实际中很少发生。

投机解码:训练

投机解码有两种主要方法,一种是通过使用较小的模型(例如,将 Llama 7B 作为 Llama 70B 的投机者),另一种是附加投机头(并对其进行训练)。在我们的实验中,我们发现将投机头附加到模型上在模型质量和延迟提升方面都更为有效。

投机者架构

美杜莎使推测性解码变得流行;他们的方法是在现有模型上添加一个头部,然后对其进行训练以进行推测。我们通过使“头部”分层来修改美杜莎架构,其中每个头部阶段预测一个标记,然后将其馈送到下一个头部阶段。这些多阶段头部在下面的图中表示。我们正在探索通过在多个阶段和基础模型之间共享这些内容来最小化嵌入表的方法。

A simple architecture diagram for a 3-headed multi-stage  speculator. Z is the state from the base model.

图 4:一个 3 头多阶段推测器的简单架构图。Z 是基础模型的状态。

推测器训练

我们为了效率原因采用了两阶段训练策略。在第一阶段,我们使用长序列长度(4k 个 token)的小批量进行训练,并采用标准的因果 LM 方法进行训练。在第二阶段,我们使用从基础模型生成的短序列长度(256 个 token)的大批量。在这个训练阶段,我们调整头部以匹配基础模型的输出。通过多次实验,我们发现第一阶段与第二阶段 5:2 的步骤比例效果较好。我们在下图中描述了这些阶段的进展。我们使用 PyTorch FSDP 和 IBM FMS 进行投机者的训练。

Per-head training loss curves for Llama2-13B speculator training, phase 1 and 2

图 5:Llama2-13B 投机者训练第一阶段和第二阶段每头训练损失曲线

结论与未来工作

通过这篇博客,我们发布了一种新的投机解码方法以及以下资产:

  1. 提高一系列模型(Llama3 8B、Llama2 13B、Granite 7B 和 CodeLlama 13B)的跨 token 延迟的模型
  2. 推理的生产质量代码
  3. 训练投机者的食谱

我们正在为 Llama3 70B 和 Mistral 模型训练投机者,并邀请社区贡献力量以及帮助我们改进框架。我们也愿意与主要开源服务框架如 vLLM 和 TGI 合作,将我们的投机解码方法回馈社区,以造福大众。

致谢

有几个团队帮助我们实现了这些推理的延迟改进。我们想感谢 vLLM 团队以干净和可重用的方式创建了分页注意力内核。我们向 Meta 的 PyTorch 团队表示感谢,他们帮助提供反馈以及持续优化 PyTorch 的使用。特别感谢 IBM Research 内部生产团队,他们将这个原型推向了生产并使其更加坚固。向 Stas Bekman 致敬,他为博客提供了有见地的评论,从而改进了对计算、内存和投机者效率之间权衡的解释。

分页注意力内核由 Josh Rosenkranz 和 Antoni Viros i Martin 整合到 IBM FMS 中。投机架构和训练由 Davis Wertheimer、Pavithra Ranganathan 和 Sahil Suneja 完成。建模代码与推理服务器的集成由 Thomas Parnell、Nick Hill 和 Prashant Gupta 完成。