过去一年中,专家混合(MoE)模型因其强大的开源模型如 DBRX、Mixtral、DeepSeek 等而迅速崛起。在 Databricks,我们与 PyTorch 团队紧密合作,以扩展 MoE 模型的训练。在这篇博客文章中,我们将讨论我们如何使用 PyTorch 分布式和 MegaBlocks(PyTorch 中一个高效的 MoE 开源实现)扩展到超过三千个 GPU。
什么是 MoE?
MoE 模型是一种使用多个专家网络进行预测的模型架构。门控网络用于路由和组合专家的输出,确保每个专家在不同的、专业的标记分布上进行训练。基于 transformer 的大型语言模型架构通常包括一个嵌入层,该层连接到多个 transformer 块(图 1,子图 A)。每个 transformer 块包含一个注意力块和一个密集的前馈网络(图 1,子图 B)。这些 transformer 块堆叠,使得一个 transformer 块的输出成为下一个块的输入。最终输出通过全连接层和 softmax 获得下一个输出的概率。
当在LLMs中使用 MoE 时,密集的前馈层被 MoE 层所取代,该 MoE 层由一个门控网络和多个专家组成(图 1,子图 D)。门控网络通常是一个线性前馈网络,它接收每个标记并生成一组权重,这些权重决定了哪些标记被路由到哪个专家。专家本身通常也实现为前馈网络。在训练过程中,门控网络会适应性地将输入分配给专家,从而使模型能够实现专业化并提高其性能。然后,路由器的输出被用来加权专家的输出,从而给出 MoE 层的最终输出。
图 1:在 transformer 块中使用混合专家
与密集模型相比,MoEs 在给定的计算预算下提供了更有效的训练。这是因为门控网络只将标记发送到专家子集,减少了计算负载。因此,模型的能力(其总参数数量)可以在不按比例增加计算需求的情况下增加。在推理过程中,只使用部分专家,因此 MoE 能够比密集模型更快地执行推理。然而,整个模型需要加载到内存中,而不仅仅是使用的专家。
MoEs 中的稀疏性提高了计算效率,这是因为特定的标记只会被路由到专家子集。专家的数量以及专家的选择取决于门控网络的实现,但常见的方法是选择前 k 个。门控网络首先为每个专家预测一个概率值,然后将标记路由到前 k 个专家以获得输出。然而,如果所有标记总是路由到相同的专家子集,则训练变得低效,其他专家最终会欠训练。为了缓解这个问题,引入了一个负载均衡损失,鼓励均匀地将标记路由到所有专家。
在设计 MoEs 时,专家数量以及选择前 k 个专家是一个重要因素。更多的专家数量允许在不增加计算成本的情况下扩展到更大的模型。这意味着模型具有更高的学习能力,然而,超过一定点后,性能提升往往会减弱。选择专家的数量需要与模型的推理成本相平衡,因为整个模型需要加载到内存中。同样,在选择前 k 个时,训练期间较低的前 k 值会导致较小的矩阵乘法,如果通信成本足够大,就会留下未使用的计算能力。然而,在推理期间,较高的前 k 值通常会导致较慢的推理速度。
MegaBlocks
MegaBlocks 是一种高效的 MoE 实现,它使用稀疏矩阵乘法并行计算专家输出,即使在 token 分配不均匀的情况下也能做到。MegaBlocks 实现了无丢弃的 MoE,在利用 GPU 内核进行高效训练的同时避免了丢弃 token。在 MegaBlocks 之前,动态路由公式迫使在模型质量和硬件效率之间做出权衡。在此之前,用户不得不在计算中丢弃 token 或者在填充上浪费计算和内存。专家可以接收可变数量的 token,并且可以使用块稀疏矩阵乘法高效地执行专家计算。我们已经将 MegaBlocks 集成到 LLM Foundry 中,以实现 MoE 训练扩展到数千个 GPU。
图 2:专家计算的矩阵乘法
专家并行性
随着模型规模扩大,无法适应单个 GPU,我们需要更高级的并行形式。专家并行是一种模型并行形式,我们将不同的专家放置在不同的 GPU 上以获得更好的性能。专家权重不是在所有 GPU 之间通信,而是将标记发送到包含专家的设备。通过移动数据而不是权重,我们可以将多个机器上的数据聚合到单个专家。路由器确定哪些输入序列的标记应该发送到哪些专家。这通常是通过计算每个标记-专家对的门控分数来完成的,然后将每个标记路由到得分最高的专家。一旦确定了标记到专家的分配,就会执行一个全对全的通信步骤,将标记发送到托管相关专家的设备。这涉及到每个设备发送分配给其他设备专家的标记,同时接收分配给其本地专家的标记。
专家并行化的关键优势在于处理少量的大型矩阵乘法,而不是多个小矩阵乘法。由于每个 GPU 只拥有部分专家,因此它只需为这些专家进行计算。相应地,当我们跨多个 GPU 聚合标记时,每个矩阵的大小成比例增大。由于 GPU 针对大规模并行计算进行了优化,较大的操作可以更好地利用其能力,从而提高利用率和效率。关于更大矩阵乘法优势的更深入解释,请参阅此处。计算完成后,将执行另一个全对全通信步骤,将专家输出送回原始设备。
图 3:专家并行中的标记路由
我们利用 PyTorch 的 DTensor,这是一种用于描述张量分片和复制的低级抽象,以有效地实现专家并行。我们首先手动将专家放置在不同的 GPU 上,通常跨节点分片以确保我们可以在路由标记时利用 NVLink 进行快速 GPU 通信。然后,我们可以在这种布局之上构建设备网格,这使得我们可以简洁地描述整个集群的并行性。当我们需要其他形式的并行性时,我们可以使用这个设备网格轻松地检查点或重新排列专家。
使用 PyTorch FSDP 扩展 ZeRO-3
与专家并行性相结合,我们对所有其他层使用数据并行性,其中每个 GPU 存储模型和优化器的副本,并处理不同的数据块。在每个 GPU 完成正向和反向传播后,将梯度累积到所有 GPU 上以进行全局模型更新。
ZeRO-3 是一种数据并行形式,其中权重和优化器在每个 GPU 上分片,而不是复制。每个 GPU 现在只存储完整模型的一部分,大大降低了内存压力。当需要模型的一部分进行计算时,它将在所有 GPU 之间收集,计算完成后,收集到的权重将被丢弃。我们使用 PyTorch 对 ZeRO-3 的实现,称为完全分片数据并行(FSDP)。
随着我们扩展到数千个 GPU,设备间的通信成本增加,导致训练速度减慢。通信增加是因为需要同步和共享所有 GPU 上的模型参数、梯度和优化器状态,这涉及到全聚合和全散射操作。为了减轻这个问题同时保持 FSDP 的优势,我们利用混合分片数据并行(HSDP)将模型和优化器分片到一定数量的 GPU 上,并多次复制以充分利用集群。在 HSDP 中,在反向传播过程中需要额外的全减少操作来同步副本之间的梯度。这种方法允许我们在大规模分布式训练中平衡内存效率和通信成本。要使用 HSDP,我们可以扩展我们之前的设备网格从专家并行性,并让 PyTorch 在需要时进行分片和聚合的重活。
图 4:FSDP 和 HSDP
使用 PyTorch,我们可以有效地结合这两种类型的并行性,在需要实现像专家并行这样的自定义功能时,利用 FSDP 的高级 API,同时使用低级的 DTensor 抽象。我们现在拥有一个具有专家并行分片维度、ZeRO-3 分片维度以及纯数据并行复制维度的 3D 设备网格。这些技术共同实现了在非常大的集群上的近乎线性扩展,使我们能够实现超过 40%的 MFU 数量。
使用 Torch 分布式进行弹性检查点
容错性对于确保LLMs能够在长时间内可靠地训练至关重要,尤其是在节点故障频繁的分布式环境中。为了避免在作业不可避免地遇到故障时丢失进度,我们需要检查点的模型状态,包括参数、优化器状态和其他必要元数据。当发生故障时,系统可以从最后保存的状态恢复,而不是从头开始。为了确保对故障的鲁棒性,我们需要频繁地检查点,并以最高效的方式保存和加载检查点,以最大限度地减少停机时间。此外,如果过多的 GPU 失败,我们的集群大小可能会改变。因此,我们需要能够在不同的 GPU 数量上弹性恢复。
PyTorch 通过其分布式训练框架支持弹性检查点,该框架包括跨不同集群配置保存和加载检查点的实用工具。PyTorch 分布式检查点确保模型的状态可以在训练集群的所有节点上并行保存和恢复,无论集群的组成因节点故障或添加而发生变化。
此外,在训练非常大的模型时,检查点的尺寸可能非常大,导致检查点的上传和下载时间非常慢。PyTorch 分布式检查点支持分片检查点,这使得每个 GPU 只需保存和加载模型的部分。当将分片检查点与弹性训练结合使用时,每个 GPU 会读取元数据文件以确定在恢复时下载哪些分片。元数据文件包含有关每个张量存储在每个分片中的哪些部分的信息。然后 GPU 可以下载其模型部分的分片并加载该部分的检查点。
图 5:在额外 GPU 上对检查点保存和恢复进行分片
通过在 GPU 间并行检查点,我们可以分散网络负载,提高鲁棒性和速度。当使用 3000+个 GPU 训练模型时,网络带宽很快成为瓶颈。我们利用 HSDP 中的复制功能,首先在一个副本上下载检查点,然后将必要的碎片发送到其他副本。在我们的 Composer 集成中,我们可以每隔 30 分钟可靠地将检查点上传到云存储,并在节点故障的情况下在 5 分钟内自动从最新的检查点恢复。
结论
我们非常兴奋地看到 PyTorch 如何通过卓越的性能实现最先进的训练。在我们的文章中,我们展示了如何通过 PyTorch 分布式和 MegaBlocks 在 Foundry 上实现高效的 MoE 训练。此外,PyTorch 弹性检查点允许我们在节点故障发生时快速在不同数量的 GPU 上恢复训练。使用 PyTorch HSDP 使我们能够高效地扩展训练并提高检查点恢复时间。我们期待着继续在一个强大而充满活力的开源社区的基础上构建,以帮助将优秀的 AI 模型带给每个人。加入我们,在LLM Foundry 和 PyTorch 上构建优秀的模型。