torch.distributed.fsdp.fully_shard
PyTorch FSDP2 ( fully_shard
)
PyTorch FSDP2 提供了一种针对高效即时模式的完全分片数据并行(FSDP)实现,同时使用按参数分片以提高可用性。
如果你是 FSDP 的新用户,我们建议您从 FSDP2 开始,因为其易用性得到了改进。
如果您目前正在使用 FSDP1,请考虑评估以下差异,看看是否应该切换到 FSDP2:
与 PyTorch FSDP1 ( FullyShardedDataParallel
) 相比:
FSDP2 使用基于
DTensor
的 dim-0 每参数分片,与 FSDP1 的平面参数分片相比,具有更简单的分片表示,同时保持了相似的吞吐量性能。更具体地说,FSDP2 将每个参数在 dim-0 上分块到数据并行工作进程中(使用torch.chunk(dim=0)
),而 FSDP1 将一组张量展平、连接并分块,使得推理每个工作进程上存在哪些数据以及重新分片到不同的并行化变得复杂。每参数分片提供了更直观的用户体验,放宽了对冻结参数的限制,并允许通信免费(分片)的状态字典,这在 FSDP1 中通常需要 all-gathers。FSDP2 实现了一种不同的内存管理方法来处理多流使用,避免了
torch.Tensor.record_stream
。这确保了确定性和预期的内存使用,并且不需要像 FSDP1 的limit_all_gathers=True
那样阻塞 CPU。FSDP2 公开了用于手动控制预取和集体调度的 API,允许高级用户进行更多定制。有关详细信息,请参阅下面的方法
FSDPModule
。FSDP2 简化了一些 API 接口:例如,FSDP2 不直接支持完整的状态字典。相反,用户可以使用
DTensor
API(如DTensor.full_tensor()
)或使用高级 API(如 PyTorch 分布式检查点的分布式状态字典 API)将包含DTensor
的分割状态字典重新分割为完整的状态字典。此外,还删除了一些其他参数;有关详细信息,请参阅此处。
如果您是首次使用 FSDP,或者上述任何一项符合您的用例,我们建议您考虑使用 FSDP2。
请参阅此 RFC 以获取系统设计和实现的详细信息。
注意
torch.distributed.fsdp.fully_shard
目前处于原型状态,正在开发中。核心 API 可能不会改变,但在必要时我们可能会进行一些 API 更改。
前端 API 是 fully_shard
,可以在 module
上调用:
- torch.distributed.fsdp.fully_shard(module, *, mesh=None, reshard_after_forward=True, shard_placement_fn=None, mp_policy=MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True), offload_policy=OffloadPolicy(), ignored_params=None)[source]¶
将完全分片数据并行(FSDP)应用于
module
,其中 FSDP 将模块参数、梯度和优化器状态分片到数据并行工作进程,以节省内存,但会牺牲通信。初始化时,FSDP 将模块参数分片到由
mesh
给出的数据并行工作进程中。在正向传播之前,FSDP 会在数据并行工作进程中全聚合分片参数,以获取正向计算的无分片参数。如果reshard_after_forward
是True
,则 FSDP 在正向传播后释放无分片参数,并在反向传播之前重新全聚合它们。在梯度计算后,FSDP 释放无分片参数,并将无分片梯度在数据并行工作进程中减少散射。此实现将分片参数表示为在 dim-0 上分片的
DTensor
,而无分片参数将类似于原始参数在module
上(例如,如果原始参数是torch.Tensor
,则类似于torch.Tensor
)。模块正向钩子module
会全聚合参数,而模块正向钩子module
(如果需要)会释放它们。类似的反向钩子会全聚合参数,然后在释放参数和减少散射梯度之后。由于将多个张量组合在一起进行集体通信对于通信效率至关重要,因此该实现将这种组合视为一等。调用
fully_shard()
在module
上构建一个包含module.parameters()
中参数的组,除了那些在早期调用子模块时已分配到组的参数。这意味着fully_shard()
应该在您的模型中自下而上调用。每个组的参数都在一个集体中全部收集,其梯度在一个集体中减少分散。将模型划分为多个组(“层叠”)可以实现峰值内存节省和通信/计算重叠。用户通常不应仅在顶层根模块上调用fully_shard()
。- 参数:
模块(Union[nn.Module, List[nn.Module])- 要与 FSDP 分片并一起进行通信的组合模块或模块列表。
mesh(可选[DeviceMesh])- 这个数据并行网格定义了分片和设备。如果是一维,则参数将完全分片在一维网格(FSDP)上,并使用
(Shard(0),)
放置。如果是二维,则参数将分片在第一维,并在零维上复制(HSDP),并使用(Replicate(), Shard(0))
放置。网格的设备类型给出了用于通信的设备类型;如果是 CUDA 或 CUDA 类似设备类型,则我们使用当前设备。reshard_after_forward(Union[bool, int])-
这控制了正向操作后参数的行为,可以权衡内存和通信:如果是
True
,则在正向操作后重新分片参数,并在反向操作中重新收集。如果
False
,则此操作在正向传播后保留未分片的参数在内存中,并在反向传播中避免 all-gather 操作。如果
int
,则此表示正向传播后要重新分片的世界大小。它应该是mesh
分片维度大小的一个非平凡约数(即排除 1 和维度大小本身)。一个选择可以是节点内大小(例如torch.cuda.device_count()
)。这允许反向传播中的 all-gather 操作在一个更小的世界大小上进行,但代价是比设置为True
更高的内存使用。根 FSDP 状态值被特别设置为
False
作为启发式方法,因为其参数通常会立即在反向传播中进行 all-gather 操作。正向传播后,注册到模块的参数取决于此:注册的参数是分片参数如果
True
;未分片参数如果False
;否则是重新分片到更小网格的参数。要在正向和反向传播之间修改参数,注册的参数必须是分片参数。对于False
或int
,可以通过reshard()
手动重新分片来实现。
shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]) – 此可调用函数可用于覆盖参数的划分放置,以便在除 dim-0 以外的维度上划分参数。如果此可调用函数返回
Shard
放置(不是None
),则 FSDP 将根据该放置进行划分(例如Shard(1)
)。在非零维度上进行划分时,我们目前要求进行偶数划分,即在该维度上的张量维度大小必须能被 FSDP 划分网格大小整除。mp_policy (MixedPrecisionPolicy) – 此用于控制混合精度策略,该策略为此模块提供参数/减少混合精度。有关详细信息,请参阅
MixedPrecisionPolicy
。offload_policy (OffloadPolicy) – 此用于控制卸载策略,该策略提供参数/梯度/优化器状态卸载。有关详细信息,请参阅
OffloadPolicy
及其子类。ignored_params (Optional[set[nn.Parameter]]) – 可选(Set[nn.Parameter]):我们不希望与 FSDP 一起划分的参数集合。
- 返回:
应用 FSDP 的模块(原地)。
- 返回类型:
动态构造一个新类,该类继承自 type(module)
和一个 FSDP 类 FSDPModule
。例如,如果我们对一个模块 linear: nn.Linear
调用 fully_shard(linear)
,那么 FSDP 将构造一个新类 FSDPLinear
并将 linear
的类型更改为这个类型。否则, fully_shard
不会改变模块结构以及参数完全限定名。该类 FSDPModule
允许在模块上提供一些 FSDP 特定的方法。
- class torch.distributed.fsdp.FSDPModule(*args, **kwargs)¶
-
- set_all_reduce_hook(hook, *, stream=None)[source][source]¶
- 参数:
hook (Callable[[torch.Tensor], None]) – 用户定义的具有预期签名
hook(reduce_output: torch.Tensor) -> None
的全量减少钩子,其中reduce_output
是仅使用 FSDP 时的减少分散输出或使用原生 HSDP 时的全量减少输出。stream (Optional[torch.cuda.Stream]) – 执行全量减少钩子的流。仅在未使用原生 HSDP 时设置。如果使用原生 HSDP,钩子将在原生 HSDP 全量减少使用的内部定义的全量减少流中运行。
- set_is_last_backward(is_last_backward)[source][source]
设置下一个反向是否为最后一个。在最后一个反向中,FSDP 将等待挂起的梯度减少并清除反向预取的内部数据结构。这可以用于微批处理。
- set_modules_to_backward_prefetch(modules)[source][source]
设置需要在此 FSDP 模块中显式预取所有-gather 的 FSDP 模块。这覆盖了默认的向后 pretching 实现,该实现根据反向 post-forward 顺序预取下一个 FSDP 模块。
传递包含前一个 FSDP 模块的单例列表将产生与默认重叠行为相同的所有-gather 重叠行为。传递长度至少为两个的列表需要更激进的重叠,并将使用更多预留内存。
- 参数:
模块(List[FSDPModule])- 需要预取的 FSDP 模块。
- set_modules_to_forward_prefetch(modules)[source][source]
设置需要在此 FSDP 模块中显式预取所有-gather 的 FSDP 模块。预取操作在此模块的所有-gather 复制输出后运行。
传递包含下一个 FSDP 模块的单例列表将产生与默认重叠行为相同的效果,但预取所有-gather 的发出时间会更早从 CPU 开始。传递长度至少为两个的列表需要更激进的重叠,并将使用更多预留内存。
- 参数:
模块(List[FSDPModule])- 需要预取的 FSDP 模块。
- 设置后优化事件(set_post_optim_event(event)[source][source])
为根 FSDP 模块设置一个后优化步骤事件,以便等待所有-gather 流。
默认情况下,根 FSDP 模块会在当前流上等待所有-gather 流,以确保在所有-gather 之前优化步骤已经完成。然而,如果优化步骤之后有无关的计算,这可能会引入虚假的依赖。此 API 允许用户提供他们自己的事件来等待。在根等待事件后,事件将被丢弃,因此此 API 应在每次迭代中调用新的事件。
- 参数:
事件(torch.Event)- 记录在优化步骤之后,用于等待所有-gather 流的事件。
- set_reduce_scatter_divide_factor(factor)[source][source]
为 reduce-scatter 设置自定义除数因子。这将成为使用 NCCL 的 PreMulSum 的自定义 reduce 操作,允许在归约之前乘以该因子。
- 参数:
因子(浮点数)- 自定义除数因子。
- set_requires_all_reduce(requires_all_reduce, *, recurse=True)[source][source]¶
设置模块是否应该进行 all-reduce 梯度归约。这可以用于仅使用 reduce-scatter 而不使用 all-reduce 来实现 HSDP 的梯度累积。
- set_requires_gradient_sync(requires_gradient_sync, *, recurse=True)[source][source]¶
设置模块是否应同步梯度。这可以用于实现不进行通信的梯度累积。对于 HSDP,这控制了 reduce-scatter 和 all-reduce。这在 FSDP1 中相当于 no_sync。
- 参数:
requires_gradient_sync (bool) – 是否为模块的参数减少梯度。
recurse (bool) – 是否为所有 FSDP 子模块设置,还是仅设置传入的模块。
- set_reshard_after_backward(reshard_after_backward, *, recurse=True)[source][source]¶
设置模块在反向传播后是否应该重新划分参数。这可以在梯度累积期间使用,以牺牲更高的内存换取减少通信,因为未划分的参数在下一个正向传播之前不需要重新全局收集。
- 参数:
reshard_after_backward (bool) – 是否在反向传播后重新划分参数。
recurse (bool) – 是否为所有 FSDP 子模块设置,或仅设置传入的模块。
- set_unshard_in_backward(unshard_in_backward)[source][source]¶
设置 FSDP 模块的参数是否需要在反向传播中进行解分片。这可以在用户知道此 FSDP 模块的参数组中的所有参数都不需要用于反向计算(例如嵌入)的专业情况下使用。
- unshard(async_op=False)[source][source]¶
通过分配内存和所有参数的全局收集来解分片模块的参数。此方法不是递归的。解分片遵循
MixedPrecisionPolicy
,因此如果设置,它将全局收集param_dtype
之后的参数。- 参数:
async_op (bool) – 如果
True
,则返回一个具有wait()
方法的UnshardHandle
,用于等待未分片操作。如果False
,则返回None
并在此函数内部等待句柄。- 返回类型:
注意
如果
async_op=True
,则 FSDP 将在模块的预前向中等待用户的挂起未分片操作。如果需要在预前向之前等待,用户只需显式调用wait()
。
- 类 torch.distributed.fsdp.UnshardHandle ¶
等待操作
FSDPModule.unshard()
的句柄。
- torch.distributed.fsdp.register_fsdp_forward_method(module, method_name)[source]¶
在
module
上注册方法,使其成为 FSDP 的前向方法。FSDP 对所有聚合参数进行预前向和可选的前向后释放(取决于
reshard_after_forward
)。FSDP 默认只知道对nn.Module.forward()
这样做。此函数将用户指定的方法修补为在方法之前/之后运行预/后前向钩子。如果module
不是FSDPModule
,则此操作不执行任何操作。- 参数:
模块(nn.Module)- 要注册前向方法的模块。
method_name(字符串)- 前向方法名称。
- class torch.distributed.fsdp.MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True)¶
此配置了 FSDP 的混合精度。与 autocast 不同,它是在模块级别而不是操作级别应用混合精度,这意味着低精度激活被保存用于反向传播,并且仅在模块边界处进行高到低精度的类型转换。
FSDP 与模块级别的混合精度配合良好,因为它无论如何都会在内存中保留高精度分片参数。换句话说,FSDP 不需要额外的内存来保留参数的高精度副本以供优化器步骤使用。
- 变量:
param_dtype (Optional[torch.dtype]) – 这指定了未分片参数的数据类型,因此也指定了前向/反向计算和参数 all-gather 的数据类型。如果这是
None
,则未分片参数使用原始数据类型。优化器步骤使用原始数据类型的分片参数。 (默认:None
)reduce_dtype (Optional[torch.dtype]) – 这指定了梯度归约的数据类型(即 reduce-scatter 或 all-reduce)。如果这是
None
但param_dtype
不是None
,则归约使用计算数据类型。这可以用于在计算时使用低精度,同时进行全精度梯度归约。如果还通过set_requires_gradient_sync()
禁用了梯度归约,则 FSDP 将使用reduce_dtype
累积梯度。 (默认:None
)output_dtype (Optional[torch.dtype]) – 这指定了将浮点前向输出转换为的数据类型。这可以用于帮助实现不同模块具有不同混合精度策略的情况。 (默认:
None
)cast_forward_inputs (bool) – 这指定了 FSDP 是否应将前向的浮点输入张量转换为
param_dtype
。
- class torch.distributed.fsdp.OffloadPolicy¶
这个基类表示不卸载的策略,仅用作
offload_policy
参数的默认值。
- class torch.distributed.fsdp.CPUOffloadPolicy(pin_memory=True)¶
此卸载策略将参数、梯度以及优化器状态卸载到 CPU。分片参数在 all-gather 之前从主机复制到设备。all-gather 后的参数根据
reshard_after_forward
释放。分片梯度在反向传播时从设备复制到主机,优化器步骤在 CPU 上运行,使用 CPU 优化器状态。- 变量:
pin_memory (布尔值) – 是否固定分片参数和梯度内存。固定内存允许更高效的 H2D/D2H 复制,并且复制可以与计算重叠。然而,固定内存不能被其他进程使用。如果您 CPU 内存不足,请设置为
False
。(默认:True
)