由 IBM 和 Meta 开发

IBM:段黄中,阿列克谢·卡尔维,杨科伊夫曼,陈林松,迪维雅·库马里,斯维塔·萨拉拉,罗伯特·沃克,普拉内特·阿杜苏米利,尼鲁姆特·德萨伊,拉古·甘蒂,西特拉米·西拉姆
Meta:莱斯特·赖特,魏峰,瓦西里·库兹涅佐夫,德里斯·古塞乌斯

在这篇博客中,我们将展示我们如何在训练过程中通过 FSDP2、DTensor 和 torch.compile 结合 torchao 的 float8 通过线性层更新(计算)和 float8 all_gathers 进行权重通信,实现高达 50%的吞吐量速度提升,同时保持损失和评估基准的平衡。我们在 Meta LLaMa 模型架构大小的范围内展示了这些改进,从小型 1.8B 模型到 405B 模型,使训练速度比以往任何时候都要快。

我们使用 Meta Llama3 架构展示了这些改进,并在两个规模上进行了模型质量研究:100B 个标记在 8B 模型大小,以及 50B 个标记在 70B 模型大小,这提供了浮点 8 和 bf16 训练损失曲线的精确比较。我们展示了与 bf16 对应物相比,损失曲线在这些模型训练运行中实现了相同的损失收敛。此外,我们使用 FineWeb-edu 数据集训练了一个 3B 模型到 1T 个标记,并运行了标准评估基准,以确保模型质量保持完好,并且与 bf16 运行相当。

在 IBM 研究实验室,我们计划采用这些能力来改进我们的数据消融,以增加在给定 GPU 预算下可以进行的实验数量。从长远来看,我们将进行更大规模模型运行,以展示 float8 训练的端到端可行性。

什么是浮点 8?

NVIDIA、ARM 和 Intel 在 2022 年的一篇论文中介绍了 float8 格式,该论文展示了使用低精度 float8 训练模型的可能性,同时不牺牲模型质量。随着新一代 GPU 如 NVIDIA Hopper 系列的推出,FP8 训练成为可能,由于原生 float8 张量核心支持,训练吞吐量有望提高 2 倍以上。实现这一承诺仍面临一些挑战:
(i) 在 float8 中启用核心模型操作如 matmulattention
(ii) 在分布式框架中启用 float8 训练,
(iii) 在 float8 中启用 GPU 之间的权重通信。
虽然 NVIDIA 库启用了 float8 matmul ,但后两个是在 FSDP2torchao 的最新更新中提供的。

在这篇博客中,我们使用 torchtitan 作为训练的入口点,IBM 的确定性数据加载器,torchao 的 float8 线性层实现,以及与 FSDP2 结合的最新 PyTorch 夜班版本。对于这次训练,我们使用每个张量(张量级)的 float8 缩放粒度,而不是行级。我们利用 torch.compile 确保获得最大的性能提升。我们正在使用 SDPA 在 bf16 中计算 attention ,并正在努力将其也迁移到 float8。

实验

我们进行了各种实验来展示 float8 训练的好处。首先,我们确保模型质量不会受到影响。为了验证这一点,我们训练了一个 8B 模型和一个 70B 模型,运行了几千步,并比较了 float8 和 bf16 训练运行之间的损失曲线。我们的实验在三个不同的 H100 集群上进行,这些集群配置了 128、256 和 512 个 H100 GPU,环境各不相同,以展示可重复性。第一个集群是在 Meta 的 Grand Teton 上定制的,具有 400Gbps 的定制互连,第二个是 IBM 研究集群,具有 3.2Tbps 的 Infiniband 互连,第三个是 IBM Cloud 集群,具有 3.2Tbps 的 RoCE 互连,用于 GPU 之间的通信。

首先,我们在下面的图中绘制了这两个模型的损失曲线比较,以展示几千步的损失一致性。

Figure 1: (a) 8B model loss parity for 2k steps, (b) 70B loss parity for 1k steps

Figure 1: (a) 8B model loss parity for 2k steps, (b) 70B loss parity for 1k steps

图 1:(a)8B 模型 2k 步损失一致性,(b)70B 模型 1k 步损失一致性

我们观察到,在这些不同的模型和不同环境中,我们获得了小规模标记的损失均衡。接下来,我们描述了从 18 亿到 405 亿不同模型大小的吞吐量增益。我们探索了 float8 和 bf16 训练运行的最佳批量大小和激活检查点方案,以确定每秒每 GPU(wps)的指标并报告性能提升。对于 405B 模型,我们利用 DTensor 进行 FSDP2 张量并行训练。我们使用 8K 序列长度进行所有测量。

模型大小 wps(bf16) wps(float8) 百分比提升
1.8B 29K 35K 18%
8B 8K 10K 28%
70B 956 1430 50%
405B (TP4) 149 227 52%

表 1:相对于 bf16 的性能提升(bf16 和 float8 都使用 torch.compile)

从表 1 中观察到,大型模型(70B 和 405B)的增益可达 50%,小型模型的增益在约 20%到 30%之间。在进一步的实验中,我们发现添加 float8 all_gather 可以使 float8 的计算本身提升约 5%,这与本博客中的观察结果一致。

其次,为了展示 FP8 模型的有效性,我们根据 Llama3 架构训练了一个 3B 模型,在 Hugging Face 的 FineWeb-edu 数据集上进行了 1T 个 token 的训练。我们使用 lm-eval-harness 框架进行了评估,并在下表中展示了一部分结果。我们观察到 bf16 的性能略好于 float8 的分数(大约一个百分点)。虽然一些分数在使用 bf16 时显著提高(例如,MMLU 提高了 3 分),但我们预计当选择合适的超参数并在更大规模的训练运行中(例如, bf16 运行批次大小减半,众所周知,较小的批次大小运行可以提高评估分数)时,这些差距将消失。

基准 分数(float8) 分数(bf16)
MMLU(5 次射击) 0.26 0.29
ARC-e 0.73 0.73
ARC-c 0.43 0.46
希腊狂热 0.65 0.67
科学 0.89 0.88
开放书问答 0.43 0.43
PIQA 0.76 0.76
Winogrande 0.60 0.65
平均 0.59 0.60

表格 2:float8 训练模型在 FP16 模式下评估(在 1T FineWeb 预训练令牌)的基准分数。

最后,我们将实验扩展到 IBM 云集群上的 512 个 H100 GPU。我们能够在 512 个 GPU 规模下重现我们观察到的结果和加速效果。以下表格仅总结了大型模型的结果(70B 和 405B)。

模型大小 wps(bf16) wps(float8) 百分比增长
70B 960 1448 51%
405B (TP4) 152 217 43%

表 3:相对于 bf16 的性能提升(bf16 和 float8 均使用 torch.compile),512 GPU 规模

未来工作

我们还在评估其他形式的并行性,例如上下文并行性。我们计划评估所有这些特性,以展示其可组合性和为训练大规模模型做出选择的能力。

致谢

我们感谢 IBM 研究实验室的 Davis Wertheimer,他使我们能够为 torchtitan 运行启用数据加载器,使我们能够在多个运行中按相同顺序回放数据。我们还感谢 IBM Cloud,它为我们提供了对 H100 集群的早期测试访问权限。