在本文中,我们介绍了 PyTorch 中针对大型语言模型的端到端量化感知训练(QAT)流程。我们展示了 PyTorch 中的 QAT 如何能够在 hellaswag 上恢复高达 96%的精度下降,在 wikitext 上恢复高达 68%的困惑度下降,与训练后量化(PTQ)相比。我们介绍了 torchao 中的 QAT API,并展示了用户如何利用它们在 torchtune 中进行微调。
图 1:在 C4 数据集(en 子集)上微调的 Llama3-8B,使用 int8 per token 动态激活+int4 按通道分组权重,带有和没有 QAT,在 A100 GPU 上评估 hellaswag 和 wikitext。注意 wikitext 的日志刻度(越低越好)。
为了证明 QAT 在端到端流程中的有效性,我们进一步将量化模型降低到 XNNPACK,这是一个针对 iOS 和 Android 等后端高度优化的神经网络库,通过 executorch 实现。降低到 XNNPACK 后,QAT 模型比 PTQ 模型困惑度降低了 16.8%,同时保持了相同的模型大小和设备上的推理和生成速度。
模型指标降低 | PTQ | QAT |
维基文本词困惑度(↓) | 23.316 | 19.403 |
维基文本字节困惑度(↓) | 1.850 | 1.785 |
维基文本每字节比特数(↓) | 0.887 | 0.836 |
模型大小 | 3.881 GB | 3.881 GB |
设备端推理速度 | 5.065 tok/s | 5.265 每秒词 |
离线生成速度 | 8.369 每秒词 | 8.701 每秒词 |
表格 1:QAT 在 Llama3-8B 模型上实现了 16.8%的更低困惑度,模型大小、设备上的推理和生成速度保持不变,并在 XNNPACK 上降低。线性层使用 int8 按 token 动态激活+int4 按通道分组权重进行量化,嵌入层额外使用 32 组大小进行 int4 量化(QAT 仅应用于线性层)。由于设备上无法进行评估,因此在服务器 CPU 上使用 5 个样本和最大序列长度为 127 进行 Wikitext 评估(所有 Wikitext 结果中,数值越低越好)。在三星 Galaxy S22 智能手机上对设备上的推理和生成进行了基准测试。
QAT API
我们期待用户尝试 torchao 中的 QAT API,该 API 可用于训练和微调。此 API 涉及两个步骤:准备和转换。准备步骤对模型中的线性层应用转换,以模拟训练期间的量化数值;转换步骤则在实际训练后将这些层量化为更低的位宽。转换后的模型可以像 PTQ 模型一样使用:
import torch
from torchtune.models.llama3 import llama3
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
# Smaller version of llama3 to fit in a single GPU
model = llama3(
vocab_size=4096,
num_layers=16,
num_heads=16,
num_kv_heads=4,
embed_dim=2048,
max_seq_len=2048,
).cuda()
# Quantizer for int8 dynamic per token activations +
# int4 grouped per channel weights, only for linear layers
qat_quantizer = Int8DynActInt4WeightQATQuantizer()
# Insert "fake quantize" operations into linear layers.
# These operations simulate quantization numerics during
# training without performing any dtype casting
model = qat_quantizer.prepare(model)
# Standard training loop
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
loss_fn = torch.nn.CrossEntropyLoss()
for i in range(10):
example = torch.randint(0, 4096, (2, 16)).cuda()
target = torch.randn((2, 16, 4096)).cuda()
output = model(example)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
optimizer.zero_grad()
# Convert fake quantize to actual quantize operations
# The quantized model has the exact same structure as the
# quantized model produced in the corresponding PTQ flow
# through `Int8DynActInt4WeightQuantizer`
model = qat_quantizer.convert(model)
# inference or generate
使用 torchtune 进行微调
我们还将其 QAT 流程集成到 torchtune 中,并提供了在分布式环境中运行此流程的配方,类似于现有的全微调分布式配方。用户还可以在LLM微调期间应用 QAT,通过运行以下命令。请参阅此 README 获取更多详细信息。
tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full
什么是量化感知训练?
量化感知训练(QAT)是一种常见的量化技术,用于减轻量化过程中产生的模型精度/困惑度下降。这是通过在训练过程中模拟量化数值,同时保持权重和/或激活在原始数据类型(通常是浮点型)中实现的,有效地“模拟量化”值,而不是实际将它们转换为更低的位宽:
# PTQ: x_q is quantized and cast to int8
# scale and zero point (zp) refer to parameters used to quantize x_float
# qmin and qmax refer to the range of quantized values
x_q = (x_float / scale + zp).round().clamp(qmin, qmax).cast(int8)
# QAT: x_fq is still in float
# Fake quantize simulates the numerics of quantize + dequantize
x_fq = (x_float / scale + zp).round().clamp(qmin, qmax)
x_fq = (x_fq - zp) * scale
由于量化涉及非可微操作,如四舍五入,因此 QAT 反向传播通常使用直通估计器(STE),这是一种估计通过非光滑函数流动的梯度的机制,以确保传递给原始权重的梯度仍然有意义。以这种方式,梯度是在知道权重最终将在训练后量化的情况下计算的,从而有效地允许模型在训练过程中调整量化噪声。请注意,QAT 的替代方法是量化训练,它实际上在训练期间将值转换为更低的位宽数据类型,但之前的努力仅在 8 位宽上取得了成功,而 QAT 即使在更低的位宽下也有效。
PyTorch 中的 QAT
我们在 torchao 原型中添加了初始的 QAT 流程。目前我们支持线性层使用 int8 动态每 token 激活+int4 分组每通道权重(简称 8da4w)。这些设置是由边缘后端上内核的可用性和关于LLM量化的先前研究相结合而激发的,该研究发现在LLMs与其他量化方案相比,每 token 激活和每组权重量化可以实现最佳模型质量。
图 2:torchao QAT 流程。此流程包括两个步骤:(1)准备,将假量化操作插入模型的线性层中;(2)转换,在训练后用实际的量化和反量化操作替换这些假量化操作。
此流程生成的量化模型与 PTQ 流程使用相同的量化设置(通过 Int8DynActInt4WeightQuantizer)完全相同,但量化权重实现了更优的准确率和困惑度。因此,我们可以使用从 QAT 流程转换的模型作为 PTQ 模型的直接替换,并重用所有后端委托逻辑和底层内核。
实验结果
本博客文章中所有实验均使用上述描述的 torchtune QAT 集成进行。我们使用 6-8 个每个 80GB 的 A100 GPU,对 C4 数据集(en 子集)上的 Llama2-7B 和 Llama3-8B 进行微调,共 5000 步。对于所有实验,我们使用批大小=2,学习率=2e-5,最大序列长度=4096(Llama2)和 8192(Llama3),使用完全分片数据并行(FSDP)作为我们的分布式策略,并使用激活检查点来减少内存占用。对于 8da4w 实验,我们使用权重组的规模为 256。
由于预训练数据集不易获取,我们在微调过程中执行 QAT。经验表明,在最初的 N 步中禁用伪量化可以带来更好的结果,这可能是由于这样做允许权重在我们开始引入量化噪声到微调过程中之前稳定下来。对于所有我们的实验,我们禁用伪量化前 1000 步。
我们使用 torchtune 中的 lm-evaluation-harness 集成来评估我们的量化模型。我们报告了来自各种任务的评估结果,这些任务通常用于评估LLMs,包括 hellaswag,一个常识句子补全任务,wikitext,一个下一个标记/字节预测任务,以及一些问答任务,如 arc、openbookqa 和 piqa。对于 wikitext,困惑度指的是模型预测下一个单词或字节的能力的倒数(越低越好),而 bits_per_byte
指的是预测下一个字节所需的位数(这里也越低越好)。对于所有其他任务, acc_norm
指的是按目标字符串的字节长度归一化的准确率。
Int8 动态激活 + Int4 权重量化(8da4w)
从 Llama2 8da4w 量化开始,我们发现 QAT 能够恢复 hellaswag 上与 PTQ 相比的 62%的归一化准确率下降,以及 wikitext 上 58%和 57%的单词和字节困惑度下降(分别)。我们在大多数其他任务上也看到了类似的改进。
图 3a:Llama2-7B 8da4w 量化,带 QAT 和不带 QAT
图 3b:Llama2-7B 8da4w 量化(带和不带 QAT),在维基文本上评估(越低越好)
Llama3 8da4w 量化在 QAT 的帮助下取得了更加显著的改进。在 hellaswag 评估任务中,我们能够恢复与 PTQ 相比在 hellaswag 上的 96%的归一化准确度下降,与未量化的准确度相比整体下降最小(<1%)。在维基文本评估任务中,QAT 恢复了 68%和 65%的词和字节困惑度下降(分别)。即使在 Llama2 QAT 难以处理的 arc_challenge 上,我们也能够恢复 51%的归一化准确度下降。
图 4a:Llama3-8B 8da4w 量化(带和不带 QAT)
图 4b:Llama3-8B 8da4w 量化(带和不带 QAT),在维基文本上评估(越低越好)
仅低比特位量化
我们进一步扩展了 torchao QAT 流程,支持 2 比特和 3 比特权重仅量化,并对 Llama3-8B 进行了相同的实验。在更低的比特宽度下,量化退化更为严重,因此我们为所有实验使用 32 的组大小进行更细粒度的量化。
然而,这仍然不足以满足 2 比特 PTQ 的需求,导致 wikitext 困惑度爆炸。为了缓解这个问题,我们利用先前敏感性分析的知识,即 Llama3 模型的前 3 层和最后 2 层对量化最敏感,因此跳过对这些层的量化,以换取量化模型大小的适度增加(2 比特为 1.78 GB,3 比特为 1.65 GB)。这使 wikitext 单词困惑度从 603336 降至 6766,这是一个显著的改进,但仍然远非可接受的。为了进一步提高量化模型,我们转向 QAT。
图 5a:Llama3-8B 2 比特权重仅量化,带 QAT 和不带 QAT,在 wikitext 上评估(越低越好)。带有“skip”的条形表示跳过对模型前 3 层和最后 2 层的量化,这些层对量化更敏感。注意对数刻度。
我们观察到在跳过前 3 层和最后 2 层的量化操作的同时应用 QAT,进一步将词困惑度降低到更合理的 30(从 6766)。更普遍地说,与 PTQ 相比,QAT 在 hellaswag 上能够恢复 53%的标准化精度下降,以及在 wikitext 上的词和字节困惑度下降分别达到 99%和 89%。然而,如果不跳过敏感层,QAT 在减轻量化模型质量下降方面效果远不如跳过这些层。
图 5b:Llama3-8B 仅 2 位权重量化,带 QAT 和不带 QAT。带有“skip”的条形表示跳过模型前 3 层和最后 2 层的量化,这些层对量化更敏感。
对于仅 3 位权重量化,即使不跳过前 3 层和最后 2 层,QAT 仍然有效,尽管跳过这些层仍然对 PTQ 和 QAT 都带来了更好的结果。在跳过的情况下,QAT 能够恢复与 PTQ 相比在 hellaswag 上的 63%的标准化精度下降,以及在 wikitext 上的词和字节困惑度下降分别达到 72%和 65%。
图 6a:Llama3-8B 3 位权重量化(带和不带 QAT)。带有“跳过”的条形表示跳过对模型前 3 层和最后 2 层的量化,这些层对量化更敏感。
图 6b:Llama3-8B 3 位权重量化(带和不带 QAT),在 wikitext 上评估(越低越好)。带有“跳过”的条形表示跳过对模型前 3 层和最后 2 层的量化,这些层对量化更敏感。注意对数刻度。
QAT 开销
QAT 在模型中插入许多假量化操作,给微调速度和内存使用增加了相当大的开销。例如,对于 Llama3-8B 这样的模型,我们有(32 * 7) + 1 = 225 个线性层,每个层至少有 1 个假量化操作用于权重,可能还有 1 个假量化操作用于输入激活。内存占用增加也很显著,因为我们不能就地修改权重,因此需要在应用假量化之前克隆它们,尽管可以通过启用激活检查点来大部分缓解这种开销。
在我们的微基准测试中,我们发现 8da4w QAT 微调比常规全量微调慢约 34%。使用激活检查点,每个 GPU 的内存增加约为 2.35 GB。尽管我们可能在未来通过 torch.compile 加速计算,但这些开销大多与 QAT 的工作原理密切相关。
每个 GPU 的统计数据 | 全量微调 | QAT 微调 |
每秒中值令牌数 | 546.314 tok/s | 359.637 tok/s |
中值峰值内存 | 67.501 GB | 69.850 GB |
表 2:在 6 个 A100 GPU(每个 80GB 内存)上进行的 Llama3 QAT 微调开销,针对 int8 每 token 动态激活和 int4 按通道分组权重
展望未来
在本文中,我们通过 torchao 提出了一个 QAT 流程,将其与 torchtune 中的微调 API 集成,并展示了其相对于 PTQ 恢复大部分量化降级的潜力,并在某些任务上与非量化性能相匹配。未来有许多探索方向:
- 超参数调整。大量超参数调整可能会进一步改善微调和 QAT 的结果。除了学习率、批量大小、数据集大小和微调步数等一般超参数外,我们还应该调整 QAT 特定的超参数,例如何时开始/停止伪量化、伪量化的步数以及伪量化值的正则化参数。
- 异常值减少技术。在我们的实验中,我们发现 PTQ 和 QAT 都容易受到异常值的影响。除了在微调期间进行简单的钳位和正则化之外,我们可以探索允许网络学习如何控制这些异常值的技术(例如,学习量化范围、剪枝 softmax 和门控注意力),或者甚至可以从训练后设置中借用异常值抑制技术(例如,SpinQuant、SmoothQuant)并在微调过程中适度应用。
- 混合精度和更复杂的数据类型。特别是在低位域,我们发现对于某些敏感层跳过量化对于 PTQ 和 QAT 都有效。我们是否需要完全跳过这些层的量化,或者我们仍然可以量化它们,只是降低位宽?在 QAT 的背景下探索混合精度量化将非常有趣。使用 MX4 等新数据类型进行训练也是另一个有希望的方向,特别是考虑到即将推出的 Blackwell GPU 将不再支持 int4 张量核心。
- 与 LoRA 和 QLoRA 的兼容性。我们目前在 torchtune 中的 QAT 集成仅支持完整的微调工作流程。然而,许多用户希望使用低秩适配器进行微调,以显著减少他们的内存占用。将 QAT 与 LoRA / QLoRA 等技术相结合将使用户能够获得这些方法的内存和性能优势,同时产生一个最终以最小模型质量下降进行量化的模型。
- 使用 torch.compile 实现可组合性。这是在 QAT(量化感知训练)中显著加快假量化计算并减少内存占用的一种潜在方法。目前,torch.compile 与 torchtune 中全分布式微调食谱(带 QAT 或不带 QAT)使用的分布式策略不兼容,但未来将添加支持。
- 其他层的量化。在这项工作中,我们只探索了量化线性层。然而,在长序列长度的情况下,KV 缓存经常成为吞吐量瓶颈,可以达到数十 GB,因此LLM-QAT 在量化激活和权重的同时,也探索了量化 KV 缓存。先前的工作在其他基于 transformer 的模型中将嵌入层量化到 2 位也取得了成功。
- 对性能良好的 CUDA 内核进行端到端评估。这项工作的自然扩展是为性能良好的 CUDA 内核提供端到端 QAT 流程,类似于现有的 8da4w QAT 流程通过 executorch 降低到 XNNPACK 内核。对于仅对 int4 权重进行量化,我们可以利用高效的 int4 权重 mm 内核和位打包进行量化,并且正在进行为该内核添加 QAT 支持的工作:https://github.com/pytorch/ao/pull/383。对于 8da4w 量化,cutlass 中也在添加混合 4 位/8 位 GEMM。这将需要构建一个高效的 8da4w CUDA 内核。
QAT 代码可以在这里找到。请参考这个 torchtune 教程开始使用。如果您有任何进一步的问题,请随时在 torchao github 上打开一个问题或联系 andrewor@meta.com。我们欢迎您的反馈和贡献!