在本文中,我们展示了 FSDP 的扩展性,使用一个预训练示例,一个训练了 2T 个 token 的 70 亿参数模型,并分享了各种我们使用的技巧,以实现每秒 3700 个 token/GPU 的快速训练速度,或 128 个 A100 GPU 每天 40B 个 token。这相当于模型 FLOPS 利用率(MFU)和硬件 FLOPS 利用率(HFU)为 57%。此外,我们还观察到 FSDP 在 512 个 GPU 上的接近线性扩展,这意味着使用这种方法在 512 个 GPU 上训练一个 70 亿参数模型到 2T 个 token,只需不到两周的时间。
IBM 的研究人员训练了一个 Meta Llama 2 70 亿参数架构到 2T 个 token,我们将称之为 LlamaT(est)。这个模型在各种学术基准上展示了与 Llama 2 相当的质量。所有训练代码以及我们实现这种吞吐量的方法都可以在这篇博客中找到。我们还分享了适用于 Llama 2 模型(7B、13B、34B 和 70B)的配置旋钮,适用于 A100 和 H100。
在此过程中,我们还提出了一种适用于 FSDP 的新选择性激活检查点机制,这使我们比标准 FSDP 提升了 10%。我们开源了训练代码库和相关可扩展的数据加载器,作为实现这种吞吐量的方法。
使用 PyTorch 原生路径进行训练的一个关键优势是能够无缝地在多个硬件后端上训练。例如,AllenAI 通过 OLMo 最近发布的用于训练的端到端堆栈也利用了 PyTorch FSDP 在 AMD 和 NVIDIA GPU 上进行训练。我们主要利用了 FSDP 的三个组件来实现我们的吞吐量:
- SDPA Flash 注意力,它实现了融合的注意力内核和高效的注意力计算
- 计算和通信的重叠使得 GPU 的利用率更高
- 选择性激活检查点技术使我们能够在 GPU 内存和计算之间进行权衡
国际商业机器公司(IBM)与 Meta 的 PyTorch 团队紧密合作,在 PyTorch FSDP 上投入了近两年的努力:引入速率限制器以在以太网互连上实现更好的吞吐量,分布式检查点以将检查点时间提高一个数量级,以及为 FSDP 的混合分片模式实现检查点的早期版本。去年年底,我们使用 FSDP 进行端到端模型训练。
训练详情
7B 模型在 128 个 A100 GPU 上训练,具有 400Gbps 网络连接和 GPU 直接 RDMA。我们使用 SDPA FlashAttention v2 进行注意力计算,对于这个模型,我们关闭了限制批量大小的激活检查点,但提供了最高的吞吐量——每个批次 1 百万个 token,对于 128 个 GPU,与激活检查点相比提高了约 10%的吞吐量。在这些参数下,我们几乎实现了计算和通信的完全重叠。我们使用 32 位 AdamW 优化器,beta1 为 0.9,beta2 为 0.95,权重衰减为 0.1,学习率最终达到 3e-5,预热到最大学习率为 3e-4,并在 2T 个 token 上使用余弦调度降至 3e-5。训练使用混合精度 bf16 在一个内部数据集上完成。训练栈使用 IBM 的 Foundation Model Stack 进行模型架构,以及 PyTorch nightlies post-2.2 发布版进行 FSDP 和 SDPA。我们在 2023 年 11 月至 2024 年 2 月期间尝试了几种不同的 nightlies,并观察到吞吐量的提升。
选择性激活检查点
我们共同实现了一种简单有效的选择性激活检查点(AC)机制。在 FSDP 中,常见的做法是检查每个 Transformer 块。一个简单的扩展是检查每个_n_个块,以减少重新计算量,同时增加内存需求。这对于 13B 模型大小非常有效,提高了 10%的吞吐量。对于 7B 模型大小,我们根本不需要激活检查点。FSDP 的未来版本将在操作符级别提供选择性激活检查点,从而实现最优的计算-内存权衡。上述代码的实现在这里。
吞吐量和 MFU、HFU 计算
虽然我们只训练了 7B 模型到 2T 个标记,但我们针对其他模型大小进行了多次实验,以提供最佳配置选项。以下表格总结了两种类型的基础设施——一个具有 128 个 GPU 和 400Gbps 节点间互连的 A100 集群,以及一个具有 96 个 GPU 和 800Gbps 节点间互连的 H100 集群。
模型大小 | 批处理大小 | 激活检查点 | 每秒吞吐量令牌/秒/GPU(A100 80GB 和 400Gbps 互连) | MFU%(A100 80GB) | HFU%(A100 80GB) | 每秒吞吐量令牌/个 GPU(H100 80GB 和 800Gbps 互连) | MFU 百分比(H100 80GB) | HFU 百分比(H100 80GB) |
7B | 2 | 否 | 3700 | 0.57 | 0.57 | 7500 | 0.37 | 0.37 |
13B | 2 | 选择性 | 1800 | 0.51 | 0.59 | 3800 | 0.35 | 0.40 |
34B | 2 | 是的 | 700 | 0.47 | 0.64 | 1550 | 0.32 | 0.44 |
70B | 2 | 是的 | 370 | 0.50 | 0.67 | 800 | 0.34 | 0.45 |
表 1:A100 和 H100 GPU 上各种模型尺寸的模型和硬件 FLOPS 利用率
HFU 数量使用 PyTorch FLOP 计数器和 A100 和 H100 GPU 的理论 bf16 性能计算,而 MFU 数量使用 NanoGPT 和 PaLM 论文中概述的方法计算。我们还注意到,我们故意将较大模型的批大小保持在每个 GPU 2 个,以模仿训练 4k 序列长度模型的选择,并在不超过流行的 4M 个 token 批大小的情况下达到 512 个 GPU。超过这个范围,我们需要张量并行或序列并行。
我们在上面的表格中注意到,对于 A100s,激活重计算会导致 MFU 减少,而 HFU 增加!随着更好的激活检查点方案的引入,我们预计 MFU 将增加并赶上 HFU。然而,我们观察到对于 H100s,MFU 和 HFU 都相对较低。我们分析了 H100 的 PyTorch 性能跟踪,发现由于网络“窥视”存在 10%的差距。此外,我们假设 H100s 的 HBM 带宽是导致 H100s 上 HFU/MFU 减少以及无法获得 3 倍提升(H100s 理论上比 A100s 快 3 倍 - 312 vs 989TFLOPS,但 HBM 带宽只有 A100s 的 2 倍多 - 2.0 vs 3.35TBps)的原因。我们计划尝试其他配置选项,如 Tensor Parallel,以改善 H100s 上 70B 模型的调节。
模型详情
训练的损失曲线如下所示。
图 1:LlamaT 训练损失曲线
2T 检查点通过仓库中提供的脚本转换为 Hugging Face 格式,然后我们使用 lm-evaluation-harness 计算关键学术基准,并通过在 Llama2-7B 上运行进行比较。这些结果记录在下表之中。
评估指标 | Llama2-7B(基线) | LlamaT-7B |
MMLU(零样本) | 0.41 | 0.43 |
MMLU(5 次加权平均) | 0.47 | 0.50 |
Arc 挑战 | 0.46 | 0.44 |
Arc 简单 | 0.74 | 0.71 |
Boolq | 0.78 | 0.76 |
科帕 | 0.87 | 0.83 |
希腊狂热 | 0.76 | 0.74 |
Openbookqa | 0.44 | 0.42 |
皮卡 | 0.79 | 0.79 |
Sciq | 0.91 | 0.91 |
Winogrande | 0.69 | 0.67 |
Truthfulqa | 0.39 | 0.39 |
GSM8k(8 次射击) | 0.13 | 0.11 |
表格 1:语言模型评估工具得分
我们观察到该模型与 Llama2(加粗表示更好)表现具有竞争力。
训练日志
训练过程稳定,没有崩溃,但我们确实遇到了一些小问题:
0-200B tokens:我们观察到迭代时间(执行一个训练步骤所需的时间)有所减慢。我们停止了工作以确保数据加载器没有造成任何减慢,并且检查点性能良好且准确。我们没有发现任何问题。此时,HSDP 检查点代码已在 PyTorch 中可用,我们借此机会切换到 PyTorch 检查点代码。
200B tokens-1.9T:我们在 12 月底没有对工作执行任何手动干预。当我们 1 月初回来时,磁盘空间已满,检查点写入失败,尽管训练工作仍在继续。最后一个已知的检查点是 1.5T。
1.5T-1.7T:我们使用 lm-evaluation-harness 评估了 1.5T 检查点,发现由于 Hugging Face 分词器引入的分隔符和我们的数据加载器也附加了自己的文档分隔符,模型在两个文档之间多训练了一个特殊标记。我们修改了数据加载器以消除额外的特殊标记,并从 1.7T 标记开始使用修改后的数据加载器继续训练。
1.7T-2T:由于特殊标记的变化,损失最初急剧上升,但很快在几十亿个标记内恢复。训练完成时没有进行任何其他手动干预!
关键要点和更高的速度
我们展示了如何使用 FSDP 训练一个 2T 标记的模型,其性能卓越,达到 3700 个标记/秒/GPU,并生成高质量的模型。作为这项练习的一部分,我们开源了所有用于训练的代码和实现这种吞吐量的旋钮。这些旋钮不仅可以用于大规模运行,也可以用于较小规模的调整运行。您可以在这里找到代码。
FSDP API 以 PyTorch 原生方式实现了 ZeRO 算法,并允许调整和训练大型模型。在过去,我们已经看到了 FSDP 的证明点(例如斯坦福 Alpaca、Hugging Face、Llama 2 配方),通过简单的训练循环调整各种LLMs(例如 Meta Llama 2 7B 到 70B Llama),实现了良好的吞吐量和训练时间。
最后,我们注意到有几种杠杆可以加快训练速度:
- 可以加速特定操作(例如,使用 Flash Attention V2 进行注意力计算)的节点优化
- 图优化(例如,融合内核、torch.compile)
- 计算与通信重叠
- 激活重计算
在这篇博客中,我们利用了 1、3 以及 4 的变体,并与 Meta 的 PyTorch 团队紧密合作,以获取 torch.compile(2)以及一个更高级的 4 版本,该版本具有按操作符选择的激活重计算功能。我们计划分享一些简单的格式化代码和示例数据,以便将数据加载到我们的数据加载器中,使其他人能够使用代码库进行模型训练。
致谢
有多个团队参与了达到这一成果,我们想感谢 Meta 和 IBM 的各个团队。特别感谢 PyTorch 分布式团队、Facebook Research 和 Applied AI 团队,他们构建了 FSDP API 并基于我们的反馈进行了改进。我们还感谢 IBM Research 的数据团队,他们整理了本次练习中使用的数据语料库,以及 IBM Research 的基础设施团队(特别是 Claudia Misale、Shweta Salaria 和 Seetharami Seelam),他们优化了 NCCL 和网络配置。通过构建和利用所有这些组件,我们成功展示了 LlamaT 成果。
选择性激活检查点是由 IBM 的 Linsong Chu、Davis Wertheimer、Mudhakar Srivatsa 和 Raghu Ganti 提出的,并由 Meta 的 Less Wright 实现的。
特别感谢 Stas Bekman 和 Minjia Zhang,他们提供了广泛的反馈,帮助改进了博客。他们的见解对于突出优化训练和探索进一步改进的关键方面非常有价值。
附录
通信计算重叠
多节点设置训练的关键方面之一是能够重叠通信和计算。在 FSDP 中,存在多个重叠的机会——在正向传播的 FSDP 单元收集阶段以及反向传播的计算阶段。在正向传播期间重叠收集,同时与前一单元的计算重叠以及与下一单元的收集和梯度散射重叠,有助于将 GPU 利用率提高近 2 倍。我们在 400Gbps 网络互连和 A100 80GB GPU 上展示了这一点。在 HSDP 的情况下,正向传播的预取阶段没有跨节点流量,重叠仅限于反向梯度计算阶段。当然,只有当模型可以在单个节点内分片时,HSDP 才是可行的,这限制了模型的大小约为 30B 参数。
下图展示了 FSDP 中的三个步骤,图中底部为节点间的通信,顶部为图像下半部分的计算流。对于没有激活计算重计算的 7B 模型,我们观察到重叠是完全的。在实践中,由于正向传递的第一个块和反向传递的最后一个块无法重叠,重叠的百分比可以达到 90%。
下面展示了上述三个步骤的放大视图,针对单个步骤。我们可以清楚地看到计算的粒度和通信,以及它们如何以交错的方式重叠。