FSDP 笔记¶
FSDP 预取细节¶
对于重叠的 forward
全聚合与 forward
计算,有两种可能的机制:
隐式前向预取(始终启用)
显式前向预取(
forward_prefetch=True
)
隐式 forward
预取是指依赖于从单独的 CUDA 流中发出 all-gather 操作,以便与之前发出的 forward
计算操作(从 CPU 的角度看)重叠。例如,如果我们有层 0 all-gather -> 层 0 forward
计算 -> 层 1 all-gather -> …,那么层 1 all-gather 可以与层 0 forward
计算重叠,即使 CPU 线程在之后发出它。第一个 all-gather 将无法与任何操作重叠。
显式 forward
预取是指改变 CPU 线程的发出顺序:例如层 0 all-gather -> 层 1 all-gather -> 层 0 forward
计算 -> …。在急切模式下,在执行层 0 时,通常无法知道下一层是哪一层(例如示例中的层 1)。因此,显式 forward
预取仅适用于每次迭代执行顺序固定的模型(我们有时称之为“静态图”)。不满足此约束的模型示例是 FLAVA。
显式预取仅能节省发出一个层的计算内核所需的时间,但代价是下一个全聚合的输出张量必须在当前张量仍在使用时分配。通过在当前计算内核之前发出下一个全聚合,下一个全聚合可以在 GPU 上更早开始。对于大多数工作负载,情况并非如此,因此没有启用它的动机。
相比之下,对于 backward
,我们必须使用显式预取,否则通信和计算将没有重叠。原因是由于我们使用单个 NCCL 进程组来执行全聚合和 reduce-scatter(部分原因是在早期的 NCCL 版本中,在相同设备上的相同 rank 上使用多个进程组是不安全的)。单个 NCCL 进程组意味着一个内部 NCCL 流,全聚合和 reduce-scatter 在该流上串行运行。因此,除非我们显式地重新排序 CPU 发出顺序为下一个全聚合 -> 当前 reduce-scatter,否则当前的 reduce-scatter 会阻塞下一个全聚合,进而阻止下一个 backward
计算,从而防止当前的 reduce-scatter 重叠。
通信负载大小
在 FSDP 中,通信方式为:
在
forward
参数上执行 all-gather在
backward
参数上执行 all-gather在
backward
梯度上执行 reduce-scatter
如果使用激活检查点( checkpoint()
),则无需额外的通信,因为参数已经在 backward
期间预先获取。
在 FSDP 设计中,每个 rank 的通信负载如下确定:每次调用 FullyShardedDataParallel
都会创建一个通信组,该组由 module.parameters()
中的参数组成,但排除已分配给嵌套 FullyShardedDataParallel
实例的任何参数。例如,对于 Llama,如果您将 FullyShardedDataParallel
应用于每个 transformer 块以及根模块,那么每个 transformer 块都有一个通信组,最后还有一个包含初始嵌入和最终线性层的通信组。每个通信组对应一个单独的全局聚合调用和一个单独的减少分散调用。这样,您应用 FullyShardedDataParallel
的方式决定了通信大小。一般来说,将 FSDP 应用于每个 transformer 块是一个好的启发式方法,在当前设计下很难做得更好。
考虑一个例子,我们有一个基于 Transformer 的模型,该模型在 8 个 GPU 上进行了分片,分片仅在 Transformer 块级别进行,每个 Transformer 块包含 16 亿个参数,参数为 fp32(每个 4 字节)。这意味着一旦分片,每个 Transformer 块将包含每个 rank 上的 0.2 亿个参数。
第 0#次传递将以
0.2*4 = 0.8GB
块的大小进行全聚合通信。第 0#次传递将进行 2 次
0.8GB
通信(1 次全聚合和 1 次 reduce-scatter)。
换句话说,将有 3 次通信,每次通信的负载为 0.8GB
。如果模型由 10 个 Transformer 块组成,则总共将有 30 次通信,总负载为 30*0.8=24GB
。
每个通信每个等级的负载大小为 total_transformer_block_params_in_B*dtype_bytes/num_gpus
(GBs)。
请注意,在这个例子中,我们没有包括嵌入所需的额外通信,这也应该被考虑在内。数学计算将取决于输入和输出嵌入是否绑定。如果它们没有绑定,通信量将是原来的两倍。
FSDP 缓冲区大小 ¶
首先,让我们来了解一下分配给通信的缓冲区:
目前需要 2 倍的全局聚合缓冲区大小。原因如下:
如在 FSDP 预取细节中所述,在显式预取的情况下( forward_prefetch=True`) case of layer 0 all-gather -> layer 0 forward compute -> layer 1
all-gather there is a need for 2 all-gather-sized buffers, because one buffer is used in the current ``forward
,另一个用于预取)。
虽然在理论上,相同序列的隐式预取( forward_prefetch=False
,默认)情况只需要 1 个缓冲区,但实际上仍然是 2 倍的全局聚合大小的缓冲区。原因是,在扁平参数 FSDP 设计中,我们不从全局聚合缓冲区复制。用于计算的参数直接从全局聚合缓冲区读取(事实上,“扁平参数”的主要好处正是因为这一点)。在这种情况下,当“层 1 全局聚合”与“层 0 前向计算”重叠时,“层 0 前向计算”正在使用从“层 0 全局聚合”缓冲区读取的参数。
那么,一个自然的问题就是,你什么时候会想要使用 forward_prefetch=False
?对于静态图模型(如大多数LLMs),这有一个主要的技术原因。实际上,我们为了快速为一些 CPU 密集型内部模型添加了这个选项,并且没有在单元测试中对每个代码路径进行测试,所以我们对此不太自信。 forward_prefetching=False
由于我们不必检查记录的前向顺序作为可能的“失败模式”,因此可以稍微容易一些地进行推理;一个模块的所有-gather 可以在其自己的 record_function
标签下在其分析器跟踪中找到。
backward
目前至少需要 2 倍的全局收集缓冲区大小,可能还需要更多。原因如下:
当前 FSDP 设计使用 recordStream
来管理在一个流中产生并在另一个流中消耗的分配,这可能导致比预期的更多内存使用。这种“非确定性”可以增加多少,因为它取决于 GPU 内核时间相对于 CPU。 limit_all_gathers=True
是一个缓解措施 - 更多细节请参阅 FSDP & CUDACachingAllocator 的讨论。
现有 FSDP 与 autograd 协同工作的方式:
现有的 FSDP 全聚合
flat_param
,这是 autograd 的叶子节点。它调用
torch.split
以获取对应其构成原始参数的flat_param
的 1D 视图。它对每个 1D 分割调用
torch.view
以返回 ND 视图。这意味着在
backward
中,我们最终得到ViewBackward
(ND -> 1D)和SplitWithSizesBackward
(这是一个 concat)。特别是,每个单独的梯度都作为单独的分配来计算,并且显式地 concat 来构建 reduce-scatter 输入缓冲区。这实际上意味着在峰值内存点 reduce-scatter 的 2 倍缓冲区大小。
总结来说,对于 backward
,它大约是 reduce-scatter 缓冲区大小的 2 倍,再加上任何 recordStream
的影响。
第二,让我们讨论额外的缓冲区:
一旦从所有 rank 收集到分片参数,它们需要一个额外的缓冲区,用于存储全部参数,大小为 total_transformer_block_params_in_B*dtype_bytes - 以早先的例子继续,如果每个 transformer 块是 1.6B 参数,并且参数是 fp32 类型,那么缓冲区大小将是 1.6*4=6.4GB。
因此需要两个这样的缓冲区,因为当前正在使用一个,另一个正在预取。
总结来说,我们有:
total_transformer_block_params_in_B*dtype_bytes/num_gpus
的通信缓冲区 2 倍未分片的 Transformer 块参数缓冲区 2 倍
``total_transformer_block_params_in_B*dtype_bytes
或者如果你一直在跟随示例:
2*1.6*4/8=1.6GB
2**1.6*4=12.8GB
并且总计为 14.4GB
。
现在让我们简要讨论一下,由于我们在计算中省略了嵌入,所以嵌入会发生什么:
根据我们在笔记中讨论的规则,即“通信缓冲区大小如下确定”,我们可以这样分析:
假设我们将 FSDP 应用于根模块(例如
Transformer
类)。假设我们进一步将 FSDP 应用于每个 transformer 块(例如TransformerBlock
类)。通常情况下,嵌入和最终线性投影是根节点
Transformer
类的直接子节点。根据我们的规则,这意味着嵌入和最终线性投影被分配给根节点
Transformer
的扁平参数。我们还有 _另一个_ 特殊规则,即根节点在正向传播后不会释放其参数,因为它们无论如何都会在反向传播时立即全部收集。
将这些放在一起,这意味着根节点的扁平参数(包括嵌入和最终投影)在开始正向传播时全部收集,并保持在 GPU 内存中,直到反向传播结束。
如果嵌入和最终线性层没有权重绑定,那么我们_可以_进一步将 FSDP 应用于嵌入和最终线性层。对于权重绑定的参数,它们必须属于同一个扁平参数(否则会重复计数)。这将允许在正向传播后释放嵌入,仅在反向传播结束时进行全聚合。
希望这能更好地说明——每个 FSDP 模块在其
module.parameters
中分配参数,除了那些已经被分配给另一个嵌套 FSDP 模块的参数,而 FSDP 模块的forward
定义了其参数的“活动”区间。因此,嵌套的nn.Module
结构会影响全聚合/释放的调度,从而影响内存/吞吐量性能。