由王琳达、伊万·斯莫瑟、卡蒂卡伊·坎达尔瓦尔撰写

在本文中,我们介绍了一项案例研究,即使用 torchtune 的知识蒸馏方法将 Llama 3.1 8B 模型蒸馏为 Llama 3.2 1B 模型。我们展示了知识蒸馏(KD)如何在训练后用于提高指令遵循任务的表现,并展示了用户如何利用此方法。

什么是知识蒸馏?

知识蒸馏是一种广泛使用的压缩技术,它将知识从较大的(教师)模型传递到较小的(学生)模型。较大的模型具有更多的参数和知识容量,然而,这种较大的容量在部署时也更为计算密集。知识蒸馏可以将较大模型的知识压缩到较小模型中。其理念是通过学习较大模型的输出,可以提升较小模型的表现。

知识蒸馏是如何工作的?

通过在迁移集上训练,将知识从教师模型传递到学生模型,学生模型被训练来模仿教师模型的标记级概率分布。假设教师模型的分布与迁移数据集相似。下面的图示是知识蒸馏工作原理的简化表示。

Figure 1: Simplified representation of knowledge transfer from teacher to student model

图 1:教师到学生模型知识转移的简化表示

作为针对LLMs的知识蒸馏是一个活跃的研究领域,存在一些研究不同损失方法的论文,如 MiniLLM、DistiLLM、AKL 和通用 KD。在本案例研究中,我们关注标准交叉熵(CE)损失和前向 Kullback-Leibler(KL)散度损失作为基线。前向 KL 散度旨在通过迫使学生的分布与所有教师的分布对齐来最小化差异。

知识蒸馏为什么有用?

知识蒸馏的思路是,一个较小的模型可以使用教师模型的输出作为额外的信号,而不是从头开始训练或使用监督微调,从而实现更好的性能。例如,Llama 3.2 轻量级 1B 和 3B 文本模型通过结合 Llama 3.1 8B 和 70B 的 logits 来恢复剪枝后的性能。此外,对于指令遵循任务的微调,LLM蒸馏的研究表明,知识蒸馏方法可以优于单独的监督微调(SFT)。

Model 方法 DollyEval 自安装 S-NI
GPT-4 评估 GPT-4 评估 Rouge-L
Llama 7B SFT 73.0 69.2 32.4
KD 73.7 70.5 33.7
MiniLLM 76.4 73.1 35.5
Llama 1.1B SFT 22.1 - 27.8
KD 22.2 - 28.1
AKL 24.4 - 31.4
OpenLlama 3B SFT 47.3 41.7 29.3
KD 44.9 42.1 27.9
SeqKD 48.1 46.0 29.1
DistiLLM 59.9 53.3 37.6

表 1:知识蒸馏方法与监督微调的比较

下面是一个简化的例子,说明了知识蒸馏与监督微调的不同之处。

监督微调 知识蒸馏
   
model = llama3_2_1b()
ce_loss = CrossEntropyLoss()
kd_loss = ForwardKLLoss()

tokens, labels = batch["tokens"], batch["labels"]
logits = model(tokens, ...)

loss = ce_loss(logits, labels)
loss.backward()

   
   
   
model = llama3_2_1b()
teacher_model = llama3_1_8b()
ce_loss = CrossEntropyLoss()
kd_loss = ForwardKLLoss()

tokens, labels = batch["tokens"], batch["labels"]
logits = model(tokens, ...)
teacher_logits = teacher_model(tokens, ...)
loss = ce_loss(logits, labels) + kd_loss(logits, teacher_logits, labels)
loss.backward()
   
   

torchtune 中的 KD 配方

使用 torchtune,我们可以轻松地将知识蒸馏应用于 Llama3,以及其他LLM模型系列,利用 torchtune 的 KD 配方。本配方的目标是微调 Llama3.2-1B,通过从 Llama3.1-8B 蒸馏来在 Alpaca 指令遵循数据集上进行。本配方侧重于训练后,并假设教师和学生模型已经进行了预训练。

首先,我们需要下载模型权重。为了与其他 torchtune 微调配置保持一致,我们将使用 Llama3.1-8B 的指令微调模型作为教师模型,Llama3.2-1B 作为学生模型。

tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

为了使教师模型分布与 Alpaca 数据集相似,我们将使用 LoRA 对教师模型进行微调。根据我们下一节展示的实验,我们发现当教师模型已经在目标数据集上进行微调时,KD 的表现更好。

tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device

最后,我们可以运行以下命令将微调后的 8B 模型蒸馏到 1B 模型,在单个 GPU 上执行。对于这个案例研究,我们使用了一个 A100 80GB GPU。我们还有一个用于在多个设备上运行的分布式配方。

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device

消融研究

在本节中,我们展示了如何通过更改配置和超参数来影响性能。默认情况下,我们的配置使用的是经过 LoRA 微调的 8B 教师模型,下载了 1B 的学生模型,学习率为 3e -4 ,KD 损失比为 0.5。对于这个案例研究,我们在 alpaca_cleaned_dataset 上进行了微调,并通过 EleutherAI LM 评估工具在 truthfulqa_mc2、hellaswag 和 commonsense_qa 任务上评估了模型。让我们来看看以下因素的影响:

  1. 使用微调的教师模型
  2. 使用微调的学生模型
  3. KD 损失比和学习率超参数调整

使用微调的教师模型

配置文件中的默认设置使用微调的教师模型。现在,我们先来看看不微调教师模型的效果。

在损失中考虑损失,使用基线 8B 作为教师模型的结果比使用微调的教师模型损失更高。KD 损失也保持相对稳定,表明教师模型应该与迁移数据集具有相同的分布。

Figure 2: (left to right) KD loss from forward KL divergence, class loss from cross entropy, total loss: even combination of KD and class loss.

图 2:(从左到右)前向 KL 散度中的 KD 损失,交叉熵中的类别损失,总损失:KD 损失和类别损失的合理组合。

在我们的基准测试中,我们可以看到 1B 模型的监督微调比基线 1B 模型实现了更高的准确率。通过使用微调的 8B 教师模型,我们在 truthfulqa 上看到可比的结果,在 hellaswag 和 commonsense 上有所改进。当使用基线 8B 作为教师时,我们在所有指标上都有所改进,但低于其他配置。

Model TruthfulQA hellaswag commonsense
mc2 acc acc_norm acc
基准 Llama 3.1 8B 0.5401 0.5911 0.7915 0.7707
使用 LoRA 微调的 Llama 3.1 8B 0.5475 0.6031 0.7951 0.7789
基准 Llama 3.2 1B 0.4384 0.4517 0.6064 0.5536
使用 LoRA 微调的 Llama 3.2 1B 0.4492 0.4595 0.6132 0.5528
使用基线 8B 作为教师 0.444 0.4576 0.6123 0.5561
使用微调后的 8B 作为教师 0.4481 0.4603 0.6157 0.5569

表 2:使用基线和微调后的 8B 作为教师模型之间的比较

使用微调的学生模型

在这些实验中,我们研究当学生模型已经微调时 KD 的影响。我们使用基线模型和微调的 8B 和 1B 模型的不同组合来分析影响。

根据损失图,使用微调的教师模型无论学生模型是否微调,都能降低损失。还有一个有趣的现象值得注意,即使用微调的学生模型时,类别损失开始增加。

Figure 3: Comparing losses of different teacher and student model initializations

图 3:比较不同教师和学生模型初始化的损失

使用微调的学生模型进一步提高了 truthfulqa 的准确率,但对于 hellaswag 和 commonsense,准确率却下降了。使用微调的教师模型和基线学生模型在 hellaswag 和 commonsense 数据集上取得了最佳结果。基于这些发现,最佳配置将取决于你正在优化的评估数据集和指标。

Model 真实问答 希腊狂人 常识
mc2 精确度 精确度_归一化 精确度
基准 Llama 3.1 8B 0.5401 0.5911 0.7915 0.7707
使用 LoRA 微调的 Llama 3.1 8B 0.5475 0.6031 0.7951 0.7789
基线 Llama 3.2 1B 0.4384 0.4517 0.6064 0.5536
使用 LoRA 微调的 Llama 3.2 1B 0.4492 0.4595 0.6132 0.5528
使用基线 8B 和基线 1B 进行 KD 0.444 0.4576 0.6123 0.5561
使用基线 8B 和微调 1B 的 KD 0.4508 0.448 0.6004 0.5274
使用微调 8B 和基线 1B 的 KD 0.4481 0.4603 0.6157 0.5569
使用微调 8B 和微调 1B 的 KD 0.4713 0.4512 0.599 0.5233

表 3:使用基线和微调的教师和学生模型的比较

超参数调整:学习率

默认情况下,该配方具有 3e-4 的学习率。在这些实验中,我们将学习率从最高 1e-3 调整为最低 1e-5。

根据损失图,所有学习率的结果损失相似,除了 1e-5,它具有更高的 KD 损失和类别损失。

Figure 4: Comparing losses of different learning rates

图 4:比较不同学习率的损失

根据我们的基准测试,最佳学习率取决于你优化的是哪个指标和任务。

Model 学习率 真实问答 希腊狂热者 常识
mc2 acc acc_norm acc
基准 Llama 3.1 8B - 0.5401 0.5911 0.7915 0.7707
使用 LoRA 微调的 Llama 3.1 8B - 0.5475 0.6031 0.7951 0.7789
基准 Llama 3.2 1B - 0.4384 0.4517 0.6064 0.5536
使用 LoRA 微调的 Llama 3.2 1B - 0.4492 0.4595 0.6132 0.5528
使用微调的 8B 和基线 1B 进行 KD 0.0003 0.4481 0.4603 0.6157 0.5569
使用微调的 8B 和基线 1B 进行 KD 0.001 0.4453 0.4535 0.6071 0.5258
使用微调的 8B 和基线 1B 进行 KD 0.0001 0.4489 0.4606 0.6156 0.5586
使用微调的 8B 和基线 1B 进行 KD 0.00001 0.4547 0.4548 0.6114 0.5487

表 4:调整学习率的影响

超参数调整:KD 比率

默认情况下,KD 比率设置为 0.5,对类别损失和 KD 损失给予同等权重。在这些实验中,我们研究了不同 KD 比率的影响,其中 0 仅使用类别损失,1 仅使用 KD 损失。

总体而言,基准测试结果表明,对于这些任务和指标,更高的 KD 比率表现略好。

Model kd_ratio (lr=3e-4) TruthfulQA hellaswag 常识
mc2 acc acc_norm acc
基准 Llama 3.1 8B - 0.5401 0.5911 0.7915 0.7707
使用 LoRA 微调的 Llama 3.1 8B - 0.5475 0.6031 0.7951 0.7789
基准 Llama 3.2 1B - 0.4384 0.4517 0.6064 0.5536
使用 LoRA 微调的 Llama 3.2 1B - 0.4492 0.4595 0.6132 0.5528
使用微调的 8B 和基线 1B 进行 KD 0.25 0.4485 0.4595 0.6155 0.5602
使用微调的 8B 和基线 1B 进行 KD 0.5 0.4481 0.4603 0.6157 0.5569
使用微调的 8B 和基线 1B 进行 KD 0.75 0.4543 0.463 0.6189 0.5643
使用微调的 8B 和基线 1B 的 KD 1.0 0.4537 0.4641 0.6177 0.5717

表 5:调整 KD 比率的效应

展望未来

在本文中,我们介绍了一项研究,通过在 Llama 3.1 8B 和 Llama 3.2 1B logits 上使用前向 KL 散度损失,使用 torchtune 对LLMs进行蒸馏的方法。为了进一步提高性能和提供更多蒸馏方法的灵活性,有许多未来探索的方向。

  • 扩展 KD 损失功能。KD 配方使用前向 KL 散度损失。然而,如上所述,将学生分布与整个教师分布对齐可能并不有效。有多个论文,如 MiniLLM、DistiLLM 和广义 KD,引入了新的 KD 损失和政策来解决这一限制,并已证明优于使用前向 KL 散度损失的标准交叉熵。例如,MiniLLM 使用反向 KL 散度来防止学生高估教师中低概率区域。DistiLLM 引入了偏斜 KL 损失和自适应训练策略。
  • 启用跨 tokenizer 蒸馏。当前配方要求教师和学生模型使用相同的 tokenizer,这限制了跨不同LLM家族的蒸馏能力。已有关于跨 tokenizer 方法的研究(例如,通用对数蒸馏),我们可以探索。
  • 扩展蒸馏到多模态LLMs和编码器模型。KD 食谱的自然扩展是扩展到多模态LLMs。类似于部署更高效的LLMs,也有部署更小、更高效的多模态LLMs的需求。此外,已有工作展示了LLMs作为编码器模型(例如 LLM2Vec)。从LLMs作为编码器到更小的编码器模型的蒸馏也可能是一个值得探索的方向。