IBM:段黄中,阿列克谢·卡尔维,杨科伊夫曼,陈林松,迪维雅·库马里,斯维塔·萨拉拉,罗伯特·沃克,普拉内特·阿杜苏米利,尼鲁姆特·德萨伊,拉古·甘蒂,西特拉米·西拉姆
Meta:莱斯特·赖特,魏峰,瓦西里·库兹涅佐夫,德里斯·古塞乌斯
在这篇博客中,我们将展示我们如何在训练过程中通过 FSDP2、DTensor 和 torch.compile 结合 torchao 的 float8 通过线性层更新(计算)和 float8 all_gathers 进行权重通信,实现高达 50%的吞吐量速度提升,同时保持损失和评估基准的平衡。我们在 Meta LLaMa 模型架构大小的范围内展示了这些改进,从小型 1.8B 模型到 405B 模型,使训练速度比以往任何时候都要快。
我们使用 Meta Llama3 架构展示了这些改进,并在两个规模上进行了模型质量研究:100B 个标记在 8B 模型大小,以及 50B 个标记在 70B 模型大小,这提供了浮点 8 和 bf16 训练损失曲线的精确比较。我们展示了与 bf16
对应物相比,损失曲线在这些模型训练运行中实现了相同的损失收敛。此外,我们使用 FineWeb-edu 数据集训练了一个 3B 模型到 1T 个标记,并运行了标准评估基准,以确保模型质量保持完好,并且与 bf16
运行相当。
在 IBM 研究实验室,我们计划采用这些能力来改进我们的数据消融,以增加在给定 GPU 预算下可以进行的实验数量。从长远来看,我们将进行更大规模模型运行,以展示 float8
训练的端到端可行性。
什么是浮点 8?
NVIDIA、ARM 和 Intel 在 2022 年的一篇论文中介绍了 float8
格式,该论文展示了使用低精度 float8 训练模型的可能性,同时不牺牲模型质量。随着新一代 GPU 如 NVIDIA Hopper 系列的推出,FP8 训练成为可能,由于原生 float8 张量核心支持,训练吞吐量有望提高 2 倍以上。实现这一承诺仍面临一些挑战:
(i) 在 float8
中启用核心模型操作如 matmul
和 attention
,
(ii) 在分布式框架中启用 float8
训练,
(iii) 在 float8
中启用 GPU 之间的权重通信。
虽然 NVIDIA 库启用了 float8
matmul
,但后两个是在 FSDP2
和 torchao
的最新更新中提供的。
在这篇博客中,我们使用 torchtitan 作为训练的入口点,IBM 的确定性数据加载器,torchao 的 float8
线性层实现,以及与 FSDP2 结合的最新 PyTorch 夜班版本。对于这次训练,我们使用每个张量(张量级)的 float8 缩放粒度,而不是行级。我们利用 torch.compile
确保获得最大的性能提升。我们正在使用 SDPA 在 bf16
中计算 attention
,并正在努力将其也迁移到 float8。
实验
我们进行了各种实验来展示 float8 训练的好处。首先,我们确保模型质量不会受到影响。为了验证这一点,我们训练了一个 8B 模型和一个 70B 模型,运行了几千步,并比较了 float8 和 bf16 训练运行之间的损失曲线。我们的实验在三个不同的 H100 集群上进行,这些集群配置了 128、256 和 512 个 H100 GPU,环境各不相同,以展示可重复性。第一个集群是在 Meta 的 Grand Teton 上定制的,具有 400Gbps 的定制互连,第二个是 IBM 研究集群,具有 3.2Tbps 的 Infiniband 互连,第三个是 IBM Cloud 集群,具有 3.2Tbps 的 RoCE 互连,用于 GPU 之间的通信。
首先,我们在下面的图中绘制了这两个模型的损失曲线比较,以展示几千步的损失一致性。
图 1:(a)8B 模型 2k 步损失一致性,(b)70B 模型 1k 步损失一致性
我们观察到,在这些不同的模型和不同环境中,我们获得了小规模标记的损失均衡。接下来,我们描述了从 18 亿到 405 亿不同模型大小的吞吐量增益。我们探索了 float8 和 bf16 训练运行的最佳批量大小和激活检查点方案,以确定每秒每 GPU(wps)的指标并报告性能提升。对于 405B 模型,我们利用 DTensor
进行 FSDP2 张量并行训练。我们使用 8K 序列长度进行所有测量。
模型大小 | wps(bf16) | wps(float8) | 百分比提升 |
1.8B | 29K | 35K | 18% |
8B | 8K | 10K | 28% |
70B | 956 | 1430 | 50% |
405B (TP4) | 149 | 227 | 52% |
表 1:相对于 bf16 的性能提升(bf16 和 float8 都使用 torch.compile)
从表 1 中观察到,大型模型(70B 和 405B)的增益可达 50%,小型模型的增益在约 20%到 30%之间。在进一步的实验中,我们发现添加 float8
all_gather
可以使 float8
的计算本身提升约 5%,这与本博客中的观察结果一致。
其次,为了展示 FP8 模型的有效性,我们根据 Llama3 架构训练了一个 3B 模型,在 Hugging Face 的 FineWeb-edu 数据集上进行了 1T 个 token 的训练。我们使用 lm-eval-harness
框架进行了评估,并在下表中展示了一部分结果。我们观察到 bf16
的性能略好于 float8
的分数(大约一个百分点)。虽然一些分数在使用 bf16
时显著提高(例如,MMLU 提高了 3 分),但我们预计当选择合适的超参数并在更大规模的训练运行中(例如, bf16
运行批次大小减半,众所周知,较小的批次大小运行可以提高评估分数)时,这些差距将消失。
基准 | 分数(float8) | 分数(bf16) |
MMLU(5 次射击) | 0.26 | 0.29 |
ARC-e | 0.73 | 0.73 |
ARC-c | 0.43 | 0.46 |
希腊狂热 | 0.65 | 0.67 |
科学 | 0.89 | 0.88 |
开放书问答 | 0.43 | 0.43 |
PIQA | 0.76 | 0.76 |
Winogrande | 0.60 | 0.65 |
平均 | 0.59 | 0.60 |
表格 2:float8 训练模型在 FP16 模式下评估(在 1T FineWeb 预训练令牌)的基准分数。
最后,我们将实验扩展到 IBM 云集群上的 512 个 H100 GPU。我们能够在 512 个 GPU 规模下重现我们观察到的结果和加速效果。以下表格仅总结了大型模型的结果(70B 和 405B)。
模型大小 | wps(bf16) | wps(float8) | 百分比增长 |
70B | 960 | 1448 | 51% |
405B (TP4) | 152 | 217 | 43% |
表 3:相对于 bf16 的性能提升(bf16 和 float8 均使用 torch.compile),512 GPU 规模
未来工作
我们还在评估其他形式的并行性,例如上下文并行性。我们计划评估所有这些特性,以展示其可组合性和为训练大规模模型做出选择的能力。
致谢
我们感谢 IBM 研究实验室的 Davis Wertheimer,他使我们能够为 torchtitan 运行启用数据加载器,使我们能够在多个运行中按相同顺序回放数据。我们还感谢 IBM Cloud,它为我们提供了对 H100 集群的早期测试访问权限。