• 文档 >
  • 全局分片数据并行
快捷键

全局分片数据并行 ¶

class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None)[source][source]

用于在数据并行工作节点间分片模块参数的包装器。

这是由徐等人以及 DeepSpeed 的 ZeRO Stage 3 所启发。FullyShardedDataParallel 通常简称为 FSDP。

要了解 FSDP 的内部机制,请参阅 FSDP 笔记。

示例:

>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> torch.cuda.set_device(device_id)
>>> sharded_module = FSDP(my_module)
>>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
>>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
>>> loss = x.sum()
>>> loss.backward()
>>> optim.step()

使用 FSDP 需要先包装您的模块,然后初始化您的优化器。这是必需的,因为 FSDP 会改变参数变量。

在设置 FSDP 时,您需要考虑目标 CUDA 设备。如果设备有 ID( dev_id ),您有三个选择:

  • 将模块放置在该设备上

  • 使用 torch.cuda.set_device(dev_id) 设置设备

  • dev_id 传递给 device_id 构造函数参数。

这确保了 FSDP 实例的计算设备是目标设备。对于选项 1 和 3,FSDP 的初始化始终在 GPU 上发生。对于选项 2,FSDP 的初始化发生在模块的当前设备上,这可能是 CPU。

如果你使用 sync_module_states=True 标志,则需要确保模块在 GPU 上,或者使用 device_id 参数在 FSDP 构造函数中指定 FSDP 将移动模块到的 CUDA 设备。这是必要的,因为 sync_module_states=True 需要 GPU 通信。

FSDP 还会负责将输入张量移动到前向方法的 GPU 计算设备,因此你不需要手动从 CPU 移动它们。

对于 use_orig_params=TrueShardingStrategy.SHARD_GRAD_OP 暴露的是未分片的参数,而不是前向后的分片参数,与 ShardingStrategy.FULL_SHARD 不同。如果你想检查梯度,可以使用 summon_full_params 方法与 with_grads=True 一起使用。

使用 limit_all_gathers=True 时,你可能会在 FSDP 预前向中看到 CPU 线程没有发出任何内核的差距。这是故意的,显示了速率限制器的效果。以这种方式同步 CPU 线程可以防止为后续的所有-gather 操作过度分配内存,实际上不应延迟 GPU 内核执行。

FSDP 在正向和反向计算过程中,出于自动微分相关原因,用 torch.Tensor 替换了管理模块的参数。如果你的模块正向依赖于保存的参数引用而不是每次迭代都重新获取引用,那么它将看不到 FSDP 新创建的视图,自动微分将无法正确工作。

最后,当使用 sharding_strategy=ShardingStrategy.HYBRID_SHARD 且分片进程组为节点内且复制进程组为节点间时,设置 NCCL_CROSS_NIC=1 可以帮助提高某些集群配置下复制进程组的全量减少时间。

局限性

使用 FSDP 时需要注意以下限制:

  • FSDP 目前不支持在 no_sync() 之外使用 CPU 卸载进行梯度累积。这是因为 FSDP 使用新减少的梯度而不是与任何现有梯度累积,这可能导致结果不正确。

  • FSDP 不支持在 FSDP 实例中运行的子模块的前向传播。这是因为子模块的参数将被分片,但子模块本身不是 FSDP 实例,因此其前向传播不会正确地 all-gather 完整参数。

  • 由于 FSDP 注册反向钩子的方式,FSDP 不支持双反向。

  • FSDP 在冻结参数时有一些限制。对于 use_orig_params=False ,每个 FSDP 实例必须管理所有冻结或所有非冻结的参数。对于 use_orig_params=True ,FSDP 支持混合冻结和非冻结参数,但建议避免这样做,以防止高于预期的梯度内存使用。

  • 截至 PyTorch 1.12 版本,FSDP 对共享参数的支持有限。如果您需要针对您的用例增强共享参数支持,请在此问题中发布。

  • 应避免在不使用 summon_full_params 上下文的情况下,在正向和反向之间修改参数,因为这些修改可能不会持久保存。

参数:
  • module (nn.Module) – 这是需要用 FSDP 包装的模块。

  • process_group (可选[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]) – 这是模型分片所覆盖的进程组,因此用于 FSDP 的全局聚合和减少分散的集体通信。如果 None ,则 FSDP 使用默认的进程组。对于混合分片策略,如 ShardingStrategy.HYBRID_SHARD ,用户可以传入进程组的元组,分别表示分片和复制的组。如果 None ,则 FSDP 为用户构建进程组以进行节点内分片和节点间复制。(默认: None

  • sharding_strategy (可选[ShardingStrategy]) – 这配置了分片策略,可能会在内存节省和通信开销之间进行权衡。请参阅 ShardingStrategy 以获取详细信息。(默认: FULL_SHARD

  • cpu_offload(可选[CPUOffload])- 这用于配置 CPU 卸载。如果设置为 None ,则不会发生 CPU 卸载。有关详细信息,请参阅 CPUOffload 。(默认: None

  • auto_wrap_policy(可选[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]])-

    这指定了一个策略,将 FSDP 应用于 module 的子模块,这对于通信和计算重叠是必需的,从而影响性能。如果 None ,则 FSDP 仅应用于 module ,用户应手动将 FSDP 应用于父模块(自下而上进行)。为了方便,此策略直接接受 ModuleWrapPolicy ,允许用户指定要包装的模块类(例如,transformer 块)。否则,应是一个接受三个参数 module: nn.Modulerecurse: boolnonwrapped_numel: int 的调用函数,并返回一个 bool ,指定如果 module 传递的参数是否应该应用 FSDP,如果 recurse=False ,或者如果 recurse=True 应继续遍历模块的子树。用户可以为调用函数添加额外的参数。 size_based_auto_wrap_policytorch.distributed.fsdp.wrap.py 中给出了一个示例调用函数,如果模块子树中的参数超过 100M numel,则应用 FSDP。我们建议在应用 FSDP 后打印模型,并根据需要进行调整。

    示例:

    >>> def custom_auto_wrap_policy(
    >>>     module: nn.Module,
    >>>     recurse: bool,
    >>>     nonwrapped_numel: int,
    >>>     # Additional custom arguments
    >>>     min_num_params: int = int(1e8),
    >>> ) -> bool:
    >>>     return nonwrapped_numel >= min_num_params
    >>> # Configure a custom `min_num_params`
    >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
    

  • backward_prefetch (Optional[BackwardPrefetch]) – 此配置显式向后预取所有-gathers。如果为 None ,则 FSDP 不进行向后预取,并且在反向传播过程中没有通信和计算重叠。有关详细信息,请参阅 BackwardPrefetch 。(默认: BACKWARD_PRE

  • mixed_precision (Optional[MixedPrecision]) – 此配置 FSDP 的原生混合精度。如果设置为 None ,则不使用混合精度。否则,可以设置参数、缓冲区和梯度减少的数据类型。有关详细信息,请参阅 MixedPrecision 。(默认: None

  • ignored_modules (Optional[Iterable[torch.nn.Module]]) – 此实例忽略的模块及其子模块的参数和缓冲区。 ignored_modules 中的任何模块都不应该是 FullyShardedDataParallel 实例,并且如果它们嵌套在此实例下,则已构建的 FullyShardedDataParallel 子模块不会被忽略。此参数可用于在使用 auto_wrap_policy 时避免在模块粒度上分片特定参数,或者如果参数的分片不由 FSDP 管理。 (默认: None

  • param_init_fn (Optional[Callable[[nn.Module], None]]) –

    指定如何将当前位于元设备上的模块初始化到实际设备上的 Callable[torch.nn.Module] -> None 。从 v1.12 版本开始,FSDP 通过 is_meta 检测元设备上的具有参数或缓冲区的模块,如果指定了 param_init_fn ,则应用它,否则调用 nn.Module.reset_parameters() 。在两种情况下,实现应仅初始化模块的参数/缓冲区,而不是其子模块的参数/缓冲区。这是为了避免重新初始化。此外,FSDP 还支持通过 torchdistX 的(https://github.com/pytorch/torchdistX) deferred_init() API 进行延迟初始化,其中延迟初始化的模块通过调用 param_init_fn (如果指定)或 torchdistX 的默认 materialize_module() 进行初始化。如果指定了 param_init_fn ,则它应用于所有元设备模块,这意味着它可能取决于模块类型。FSDP 在参数展平和分片之前调用初始化函数。

    示例:

    >>> module = MyModule(device="meta")
    >>> def my_init_fn(module: nn.Module):
    >>>     # E.g. initialize depending on the module type
    >>>     ...
    >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
    >>> print(next(fsdp_model.parameters()).device) # current CUDA device
    >>> # With torchdistX
    >>> module = deferred_init.deferred_init(MyModule, device="cuda")
    >>> # Will initialize via deferred_init.materialize_module().
    >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
    

  • device_id(可选[Union[int, torch.device]])- 一个 inttorch.device ,表示 FSDP 初始化发生的 CUDA 设备,包括如果需要的话模块初始化和参数分片。如果 module 在 CPU 上,则应指定此参数以提高初始化速度。如果已设置默认 CUDA 设备(例如通过 torch.cuda.set_device ),则用户可以将 torch.cuda.current_device 传递给此参数。(默认: None

  • sync_module_states(布尔值)- 如果 True ,则每个 FSDP 模块将广播从 rank 0 的模块参数和缓冲区以确保它们在各个 rank 之间复制(向此构造函数添加通信开销)。这有助于以内存高效的方式通过 load_state_dict 加载 state_dict 检查点。请参阅 FullStateDictConfig 以获取此示例。(默认: False

  • forward_prefetch(布尔值)- 如果 True ,则 FSDP 在当前前向计算之前显式预取下一个前向传递的所有-gather。这对于 CPU 密集型工作负载很有用,在这种情况下,提前发出下一个 all-gather 可能有助于重叠。这仅适用于静态图模型,因为预取遵循第一次迭代的执行顺序。(默认: False

  • limit_all_gathers (bool) – 如果 True ,则 FSDP 明确同步 CPU 线程以确保 GPU 内存使用仅来自两个连续的 FSDP 实例(当前运行计算的实例和预取所有-gather 的下一个实例)。如果 False ,则 FSDP 允许 CPU 线程发出所有-gather 而无需任何额外同步。(默认: True )我们通常将此功能称为“速率限制器”。此标志应仅设置为 False ,用于具有低内存压力的特定 CPU 密集型工作负载,在这种情况下,CPU 线程可以积极发出所有内核,无需担心 GPU 内存使用。

  • 使用原参数(bool)- 将此设置为 True 使 FSDP 使用 module 的原参数。FSDP 通过 nn.Module.named_parameters() 而不是 FSDP 内部的 FlatParameter 将那些原参数暴露给用户。这意味着优化器步骤在原参数上运行,允许每个原参数的超参数。FSDP 保留原参数变量,并在非分片和分片形式之间操作它们的数据,其中它们始终是底层非分片或分片 FlatParameter 的视图。根据当前算法,分片形式始终是 1D,丢失了原始张量结构。一个原参数可能在其给定秩的所有、一些或没有数据。在没有数据的情况下,其数据将像一个大小为 0 的空张量。用户不应编写依赖于给定原参数在其分片形式中存在哪些数据的程序。 True 是使用 torch.compile() 所必需的。将此设置为 False 通过 nn.Module.named_parameters() 将 FSDP 内部的 FlatParameter 暴露给用户。(默认: False

  • ignored_states (Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]) – 被忽略的参数或模块,这些参数将由该 FSDP 实例管理,意味着参数不会被分片,其梯度也不会在各个 rank 之间进行汇总。此参数与现有的 ignored_modules 参数统一,我们可能会很快弃用 ignored_modules 。为了向后兼容,我们保留了 ignored_states 和 ignored_modules,但 FSDP 只允许指定其中一个作为非 None

  • device_mesh (Optional[DeviceMesh]) – DeviceMesh 可以用作 process_group 的替代方案。当传递 device_mesh 时,FSDP 将使用底层进程组进行 all-gather 和 reduce-scatter 集体通信。因此,这两个参数需要互斥。对于混合分片策略如 ShardingStrategy.HYBRID_SHARD ,用户可以传递一个 2D DeviceMesh 而不是进程组的元组。对于 2D FSDP + TP,用户需要传递 device_mesh 而不是 process_group。更多 DeviceMesh 信息,请访问:https://pytorch.org/tutorials/recipes/distributed_device_mesh.html

apply(fn)[source][source]

Apply fn 递归地应用于每个子模块(如 .children() 返回的)以及自身。

典型用途包括初始化模型的参数(参见 torch.nn.init)。

torch.nn.Module.apply 相比,此版本在应用 fn 之前会额外收集所有参数。不应在另一个 summon_full_params 上下文中调用。

参数:

fn ( Module -> None) – 对每个子模块应用此函数

返回:

self

返回类型:

模块

检查是否为根 FSDP 模块()[来源][来源] ¶

检查此实例是否为根 FSDP 模块。

返回类型:

布尔型

clip_grad_norm_(max_norm, norm_type=2.0)[来源][来源] ¶

剪切所有参数的梯度范数。

规范是在将所有参数梯度视为单个向量的情况下计算的,并且梯度是在原地修改的。

参数:
  • max_norm(浮点数或整数)- 梯度的最大范数

  • norm_type(浮点数或整数)- 使用的 p-范数的类型。可以是 'inf' 表示无穷范数。

返回:

将参数的总范数(视为单个向量)。

返回类型:

张量

如果每个 FSDP 实例都使用 NO_SHARD ,即没有梯度在各个 rank 之间分片,那么您可以直接使用 torch.nn.utils.clip_grad_norm_()

如果至少有一个 FSDP 实例使用分片策略(即除 NO_SHARD 之外的其他策略),那么您应该使用这种方法代替 torch.nn.utils.clip_grad_norm_() ,因为这种方法可以处理梯度在各个 rank 之间分片的情况。

返回的总范数将具有所有参数/梯度中“最大”的数据类型,这是由 PyTorch 的类型提升语义定义的。例如,如果所有参数/梯度都使用低精度数据类型,那么返回范数的数据类型将是该低精度数据类型,但如果至少有一个参数/梯度使用 FP32,那么返回范数的数据类型将是 FP32。

警告

由于它使用集体通信,因此需要在所有 rank 上调用此函数。

static flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim)[source][source]

展平分片优化器状态字典。

该 API 与 shard_full_optim_state_dict() 类似。唯一的区别是输入 sharded_optim_state_dict 应从 sharded_optim_state_dict() 返回。因此,每个 rank 都会进行 all-gather 调用以收集 ShardedTensor

参数:
  • sharded_optim_state_dict (Dict[str, Any]) – 对应未展平参数并持有分片优化器状态的优化器状态字典。

  • 模型(torch.nn.Module)- 参考第 0#条。

  • optim (torch.optim.Optimizer) – 用于 model 参数的优化器。

返回:

参考第 0#条。

返回类型:

dict[str, Any]

forward(*args, **kwargs)[source][source]

运行包装模块的前向传播,并插入 FSDP 特定的前向和后向分片逻辑。

返回类型:

任何

static fsdp_modules(module, root_only=False)[source][source]

返回所有嵌套的 FSDP 实例。

这可能包括 module 自身,并且仅在 root_only=True 的情况下包括 FSDP 根模块。

参数:
  • 模块(torch.nn.Module)- 根模块,可能或可能不是 FSDP 模块。

  • root_only(布尔值)- 是否仅返回 FSDP 根模块。(默认: False

返回:

嵌套在输入 module 中的 FSDP 模块。

返回类型:

FullyShardedDataParallel 列表

static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None)[source][source]

返回完整的优化器状态字典。

在 rank 0 上合并完整的优化器状态并返回,遵循 torch.optim.Optimizer.state_dict() 的约定,即具有 "state""param_groups" 键。将包含在 FSDP 模块中的展平参数映射回其未展平的参数。

由于使用集体通信,需要在所有 rank 上调用此函数。但是,如果 rank0_only=True ,则状态字典仅在 rank 0 上填充,其他所有 rank 返回空的 dict

torch.optim.Optimizer.state_dict() 不同,此方法使用完整的参数名称作为键,而不是参数 ID。

torch.optim.Optimizer.state_dict() 中所述,优化器状态字典中包含的张量没有被克隆,因此可能会有别名惊喜。为了最佳实践,请考虑立即保存返回的优化器状态字典,例如使用 torch.save()

参数:
  • model (torch.nn.Module) – 根模块(可能或可能不是 FullyShardedDataParallel 实例),其参数被传递到优化器 optim 中。

  • optim (torch.optim.Optimizer) – 用于 model 参数的优化器。

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 输入到优化器 optim ,表示参数组 list 或参数的迭代器;如果 None ,则此方法假定输入为 model.parameters() 。此参数已弃用,无需再传入。(默认: None )

  • rank0_only (bool) – 如果 True ,则仅在 rank 0 上保存填充的 dict ;如果 False ,则在所有 rank 上保存。(默认: True )

  • group (dist.ProcessGroup) – 模型的进程组或 None 如果使用默认进程组。(默认: None )

返回:

A dict 包含优化器状态,用于 model 的原始未展平参数,并包含“state”和“param_groups”键,遵循 torch.optim.Optimizer.state_dict() 的约定。如果 rank0_only=True ,则非零 rank 返回空的 dict

返回类型:

字典[str, Any]

static get_state_dict_type(module)[source][source]

获取根位于 module 的 FSDP 模块的状态字典类型及其对应配置。

目标模块不必是 FSDP 模块。

返回:

包含当前设置的 state_dict_type 和 state_dict / optim_state_dict 配置的 StateDictSettings

引发:
  • 如果 StateDictSettings 对于不同的 –

  • FSDP 子模块不同。 –

返回类型:

StateDictSettings

属性模块 Module ¶

返回包装的模块。

named_buffers(*args, **kwargs)[source][source]

返回模块缓冲区的迭代器,生成缓冲区的名称及其本身。

拦截缓冲区名称,并在 summon_full_params() 上下文管理器内部删除所有 FSDP 特定的扁平缓冲区前缀的实例。

返回类型:

迭代器[tuple[str, torch.Tensor]]

named_parameters(*args, **kwargs)[source][source]

返回模块参数的迭代器,同时产生参数的名称和参数本身。

summon_full_params() 上下文中拦截参数名称,并移除所有 FSDP 特定展开参数前缀的实例。

返回类型:

迭代器[tuple[str, torch.nn.parameter.Parameter]]

no_sync()[source][source]

禁用 FSDP 实例间的梯度同步。

在此上下文中,梯度将累积在模块变量中,之后将在退出上下文后的第一次前向-反向传递中进行同步。这仅应在根 FSDP 实例上使用,并将递归地应用于所有子 FSDP 实例。

注意

这可能导致更高的内存使用,因为 FSDP 将累积完整的模型梯度(而不是梯度碎片)直到最终同步。

注意

当与 CPU 卸载一起使用时,梯度在上下文管理器内部不会被卸载到 CPU。相反,它们将在最终同步后卸载。

返回类型:

生成器

static optim_state_dict(model, optim, optim_state_dict=None, group=None)[source][source]

将分片模型的优化器对应的状态字典进行转换。

给定的状态字典可以转换为三种类型之一:1)完整优化器状态字典,2)分片优化器状态字典,3)本地优化器状态字典。

对于完整优化器状态字典,所有状态都是未展平且未分片的。可以通过 state_dict_type() 指定 Rank0 和 CPU 仅,以避免内存溢出。

对于分片优化器状态字典,所有状态都是未展平且分片的。可以通过 state_dict_type() 指定 CPU 以进一步节省内存。

对于本地状态字典,不会执行任何转换。但会将状态从 nn.Tensor 转换为 ShardedTensor 以表示其分片特性(目前尚不支持)。

示例:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkpoint()
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     model, optim, optim_state_dict
>>> )
>>> optim.load_state_dict(optim_state_dict)
参数:
  • model (torch.nn.Module) – 根模块(可能或可能不是 FullyShardedDataParallel 实例),其参数被传递到优化器 optim 中。

  • optim (torch.optim.Optimizer) – 用于 model 参数的优化器。

  • optim_state_dict (Dict[str, Any]) – 要转换的目标优化器状态字典。如果值为 None,则使用 optim.state_dict()。 (默认: None )

  • group (dist.ProcessGroup) – 模型参数分片所跨越的进程组。如果使用默认进程组,则为 None 。 (默认: None )

返回:

包含 model 的优化器状态的 dict 。优化器状态的分片基于 state_dict_type

返回类型:

字典[str, Any]

static optim_state_dict_to_load(model, optim, optim_state_dict, is_named_optimizer=False, load_directly=False, group=None)[source][source]

将优化器状态字典转换为可以加载到与 FSDP 模型关联的优化器中的形式。

给定一个通过 optim_state_dict() 转换的 optim_state_dict ,它被转换为可以加载到 optim 的扁平化优化器状态字典, optimmodel 的优化器。 model 必须由 FullyShardedDataParallel 分片。

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> original_osd = optim.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(
>>>     model,
>>>     optim,
>>>     optim_state_dict=original_osd
>>> )
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkpoint()
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     model, optim, optim_state_dict
>>> )
>>> optim.load_state_dict(optim_state_dict)
参数:
  • model (torch.nn.Module) – 根模块(可能或可能不是 FullyShardedDataParallel 实例),其参数被传递到优化器 optim 中。

  • optim (torch.optim.Optimizer) – 用于 model 参数的优化器。

  • optim_state_dict(Dict[str, Any])- 要加载的优化器状态。

  • is_named_optimizer(bool)- 此优化器是否为 NamedOptimizer 或 KeyedOptimizer。仅在 optim 是 TorchRec 的 KeyedOptimizer 或 torch.distributed 的 NamedOptimizer 时设置为 True。

  • load_directly (bool) – 如果设置为 True,此 API 在返回结果之前还将调用 optim.load_state_dict(result)。否则,用户负责调用 optim.load_state_dict() (默认: False

  • group (dist.ProcessGroup) – 模型参数分片所跨越的进程组。如果使用默认进程组,则为 None 。 (默认: None )

返回类型:

dict[str, Any]

register_comm_hook(state, hook)[source][source]

注册通信钩子。

这是一个为用户提供灵活钩子的增强功能,用户可以指定 FSDP 如何聚合多个工作进程之间的梯度。此钩子可用于实现 GossipGrad 和梯度压缩等算法,这些算法涉及不同的通信策略以进行参数同步训练,同时使用 FullyShardedDataParallel

警告

FSDP 通信钩子应在运行初始前向传递之前注册,并且只注册一次。

参数:
  • state (对象) –

    传递给钩子以在训练过程中维护任何状态信息。例如,包括梯度压缩中的错误反馈、GossipGrad 中下一个要通信的节点等。它由每个工作进程本地存储并由该工作进程上的所有梯度张量共享。

  • hook (Callable) – 具有以下签名之一的 Callable:1) hook: Callable[torch.Tensor] -> None :此函数接收一个 Python 张量,该张量表示与该 FSDP 单元包装的模型的所有变量(不包括被其他 FSDP 子单元包装的变量)相关的完整、平坦、未分片的梯度。然后执行所有必要的处理并返回 None ;2) hook: Callable[torch.Tensor, torch.Tensor] -> None :此函数接收两个 Python 张量,第一个张量表示与该 FSDP 单元包装的模型的所有变量(不包括被其他 FSDP 子单元包装的变量)相关的完整、平坦、未分片的梯度。后者表示一个预分配大小的张量,用于存储在缩减后的分片梯度的一部分。在两种情况下,Callable 都执行所有必要的处理并返回 None 。具有签名 1 的 Callable 预期处理 NO_SHARD 情况下的梯度通信。具有签名 2 的 Callable 预期处理分片情况下的梯度通信。

static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None)[source][source]

重新键控优化器状态字典 optim_state_dict 以使用键类型 optim_state_key_type

这可以用于实现具有 FSDP 实例和无 FSDP 实例的模型优化器状态字典之间的兼容性。

将 FSDP 完整优化器状态字典(即来自 full_optim_state_dict() )重新键入以使用参数 ID 并可加载到非包装模型中:

>>> wrapped_model, wrapped_optim = ...
>>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
>>> nonwrapped_model, nonwrapped_optim = ...
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
>>> nonwrapped_optim.load_state_dict(rekeyed_osd)

将非包装模型的普通优化器状态字典重新键入以可加载到包装模型中:

>>> nonwrapped_model, nonwrapped_optim = ...
>>> osd = nonwrapped_optim.state_dict()
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
>>> wrapped_model, wrapped_optim = ...
>>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
>>> wrapped_optim.load_state_dict(sharded_osd)
返回:

使用指定的参数键重新键入的优化器状态字典。

返回类型:

字典[str, Any]

static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None)[source][source]

将全优化器状态字典从 rank 0 分散到所有其他 rank。

返回每个 rank 上的分片优化器状态字典。返回值与 shard_full_optim_state_dict() 相同,在 rank 0 上,第一个参数应该是 full_optim_state_dict() 的返回值。

示例:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
>>> # Define new model with possibly different world size
>>> new_model, new_optim, new_group = ...
>>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
>>> new_optim.load_state_dict(sharded_osd)

注意

shard_full_optim_state_dict()scatter_full_optim_state_dict() 都可以用来获取分片优化器状态字典进行加载。假设完整的优化器状态字典位于 CPU 内存中,前者要求每个 rank 都拥有完整的字典在 CPU 内存中,每个 rank 独立分片字典而不进行任何通信,而后者只要求 rank 0 拥有完整的字典在 CPU 内存中,rank 0 将每个分片移动到 GPU 内存(用于 NCCL),并适当地与 rank 进行通信。因此,前者有更高的 CPU 内存成本,而后者有更高的通信成本。

参数:
  • full_optim_state_dict(可选[Dict[str, Any]])- 对应于未展平参数的优化器状态字典,如果位于 rank 0,则包含完整的非分片优化器状态;在非零 rank 上此参数被忽略。

  • model (torch.nn.Module) – 根模块(可能或可能不是 FullyShardedDataParallel 实例),其参数对应于 full_optim_state_dict 中的优化器状态。

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 传递给优化器的输入,表示参数组 list 或参数的可迭代对象;如果是 None ,则此方法假定输入是 model.parameters() 。此参数已弃用,不再需要传递。默认值: None

  • optim (Optional[torch.optim.Optimizer]) – 将加载此方法返回的状态字典的优化器。这是首选参数,而不是 optim_input 。默认值: None

  • group (dist.ProcessGroup) – 模型的进程组或 None 如果使用默认进程组。(默认: None )

返回:

全局优化器状态字典现在已重新映射到展平的参数,而不是未展平的参数,并且仅包括此 rank 的优化器状态部分。

返回类型:

字典[str, Any]

static set_state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source][source]

设置目标模块所有子 FSDP 模块的 state_dict_type

还可以(可选)为模型和优化器的状态字典提供配置。目标模块不一定是 FSDP 模块。如果目标模块是 FSDP 模块,其 state_dict_type 也将被更改。

注意

应仅调用此 API 进行顶级(根)模块。

注意

此 API 允许用户在根 FSDP 模块被其他模块包装的情况下,透明地使用传统的 state_dict API 来获取模型检查点。例如,以下代码将确保在所有非 FSDP 实例上调用 state_dict ,而对于 FSDP 则调度到 sharded_state_dict 实现:

示例:

>>> model = DDP(FSDP(...))
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.SHARDED_STATE_DICT,
>>>     state_dict_config = ShardedStateDictConfig(offload_to_cpu=True),
>>>     optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True),
>>> )
>>> param_state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
参数:
  • 模块(torch.nn.Module)- 根模块。

  • state_dict_type(StateDictType)- 要设置的期望 state_dict_type

  • state_dict_config (Optional[StateDictConfig]) – 目标配置 state_dict_type (可选)。

  • optim_state_dict_config (Optional[OptimStateDictConfig]) – 优化器状态字典的配置。

返回:

包含之前状态字典类型和模块配置的 StateDictSettings。

返回类型:

StateDictSettings

static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None)[source][source]

将完整的优化器状态字典进行分片。

将状态映射到 full_optim_state_dict ,从未展开的参数转换为展开的参数,并仅限制到本 rank 的优化器状态部分。第一个参数应该是 full_optim_state_dict() 的返回值。

示例:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)
>>> torch.save(full_osd, PATH)
>>> # Define new model with possibly different world size
>>> new_model, new_optim = ...
>>> full_osd = torch.load(PATH)
>>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
>>> new_optim.load_state_dict(sharded_osd)

注意

shard_full_optim_state_dict()scatter_full_optim_state_dict() 都可以用来获取分片优化器状态字典进行加载。假设完整的优化器状态字典位于 CPU 内存中,前者要求每个 rank 都拥有完整的字典在 CPU 内存中,每个 rank 独立分片字典而不进行任何通信,而后者只要求 rank 0 拥有完整的字典在 CPU 内存中,rank 0 将每个分片移动到 GPU 内存(用于 NCCL),并适当地与 rank 进行通信。因此,前者有更高的 CPU 内存成本,而后者有更高的通信成本。

参数:
  • full_optim_state_dict (Dict[str, Any]) – 与未展平参数对应的优化器状态字典,包含完整的非分片优化器状态。

  • model (torch.nn.Module) – 根模块(可能或可能不是 FullyShardedDataParallel 实例),其参数对应于 full_optim_state_dict 中的优化器状态。

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 传递给优化器的输入,表示参数组 list 或参数的可迭代对象;如果是 None ,则此方法假定输入是 model.parameters() 。此参数已弃用,不再需要传递。默认值: None

  • optim (Optional[torch.optim.Optimizer]) – 将加载此方法返回的状态字典的优化器。这是首选参数,而不是 optim_input 。默认值: None

返回:

全局优化器状态字典现在已重新映射到展平的参数,而不是未展平的参数,并且仅包括此 rank 的优化器状态部分。

返回类型:

字典[str, Any]

static sharded_optim_state_dict(model, optim, group=None)[source][source]

返回优化器状态字典的碎片化形式。

该 API 与 full_optim_state_dict() 类似,但此 API 将所有非零维度的状态块压缩到 ShardedTensor 以节省内存。此 API 仅应在使用上下文管理器 state_dict 导出模型 with state_dict_type(SHARDED_STATE_DICT): 时使用。

详细用法请参阅 full_optim_state_dict()

警告

返回的状态字典包含 ShardedTensor ,不能直接由常规的 optim.load_state_dict 使用。

返回类型:

dict[str, Any]

static state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source][source]

设置目标模块所有子 FSDP 模块的 state_dict_type

此上下文管理器与 set_state_dict_type() 具有相同的功能。请阅读 set_state_dict_type() 的文档以获取详细信息。

示例:

>>> model = DDP(FSDP(...))
>>> with FSDP.state_dict_type(
>>>     model,
>>>     StateDictType.SHARDED_STATE_DICT,
>>> ):
>>>     checkpoint = model.state_dict()
参数:
  • 模块(torch.nn.Module)- 根模块。

  • state_dict_type(StateDictType)- 要设置的期望 state_dict_type

  • state_dict_config (Optional[StateDictConfig]) – 目标 state_dict_type 的模型 state_dict 配置(可选)。

  • optim_state_dict_config (Optional[OptimStateDictConfig]) – 目标 state_dict_type 的优化器 state_dict 配置(可选)。

返回类型:

生成器

static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False)[source][source]

使用此上下文管理器暴露 FSDP 实例的完整参数。

在模型的前向/反向操作之后,这可能很有用,以便在额外处理或检查时获取参数。它可以接受非 FSDP 模块,并将根据 recurse 参数调用所有包含的 FSDP 模块及其子模块的完整参数。

注意

这可以用于内部 FSDP。

注意

这不能在正向或反向传递中使用。也不能从该上下文中启动正向和反向。

注意

参数将在上下文管理器退出后恢复到其本地分片,存储行为与正向相同。

注意

可以修改全部参数,但只有对应本地参数分片的那个部分会在上下文管理器退出后持久化(除非 writeback=False ,在这种情况下,更改将被丢弃)。在 FSDP 不分片参数的情况下,目前只有当 world_size == 1NO_SHARD 配置时,修改才会持久化,无论 writeback 如何。

注意

此方法适用于不是 FSDP 本身的模块,但可能包含多个独立的 FSDP 单元。在这种情况下,给定的参数将应用于所有包含的 FSDP 单元。

警告

注意, rank0_only=Truewriteback=True 的组合目前不支持,会引发错误。这是因为模型参数的形状会在上下文中跨 rank 不同,写入它们可能导致退出上下文时 rank 之间的不一致性。

警告

注意, offload_to_cpurank0_only=False 将导致完整参数被冗余复制到同一台机器上的 GPU 的 CPU 内存中,这可能会增加 CPU OOM 的风险。建议使用 offload_to_cpurank0_only=True

参数:
  • recurse (bool, Optional) – 递归调用嵌套 FSDP 实例的所有参数(默认:True)。

  • writeback (bool, Optional) – 如果 False ,则在上下文管理器退出后,对参数的修改将被丢弃;禁用此功能可能略微更高效(默认:True)

  • rank0_only (bool, Optional) – 如果 True ,则仅在全局排名 0 上实例化完整参数。这意味着在上下文中,只有排名 0 将拥有完整参数,而其他排名将拥有分片参数。请注意,设置 rank0_only=Truewriteback=True 不受支持,因为在上下文中不同排名的模型参数形状将不同,写入它们可能导致退出上下文时排名之间的一致性问题。

  • offload_to_cpu (bool, 可选) – 如果 True ,则将全部参数卸载到 CPU。请注意,当前仅在参数分片的情况下才会进行卸载(这只有在 world_size = 1 或 NO_SHARD 配置的情况下才不是这种情况)。建议使用 offload_to_cpurank0_only=True 以避免将模型参数的冗余副本卸载到同一 CPU 内存中。

  • with_grads (bool, 可选) – 如果 True ,则梯度也会与参数一起进行非分片处理。当前仅在将 use_orig_params=True 传递给 FSDP 构造函数并将 offload_to_cpu=False 传递给此方法时支持。(默认值: False

返回类型:

生成器

class torch.distributed.fsdp.BackwardPrefetch(value, names=<未提供>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source][source] ¶

这配置了显式的向后预取,通过在向后传递中启用通信和计算重叠来提高吞吐量,但代价是略微增加内存使用。

  • BACKWARD_PRE : 这将实现最大的重叠,但内存使用增加最多。在当前参数集的梯度计算之前预取下一组参数。这重叠了下一个全聚合和当前的梯度计算,在峰值时,它将当前参数集、下一组参数和当前梯度集保持在内存中。

  • BACKWARD_POST : 这将实现较少的重叠,但内存使用较少。在当前参数集的梯度计算之后预取下一组参数。这重叠了当前的 reduce-scatter 和下一个梯度计算,并在分配下一组参数内存之前释放当前参数集,仅在峰值时保持下一组参数和当前梯度集在内存中。

  • FSDP 的 backward_prefetch 参数接受 None ,这将完全禁用向后预取。这没有重叠,也不会增加内存使用。一般来说,我们不推荐这种设置,因为它可能会显著降低吞吐量。

更多技术背景:对于使用 NCCL 后端的单个进程组,任何集体操作,即使来自不同的流,也会竞争相同的设备 NCCL 流,这意味着集体操作的发出顺序对于重叠很重要。两个向后预取值对应不同的发出顺序。

class torch.distributed.fsdp.ShardingStrategy(value, names=<未提供>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source][source] ¶

这指定了 FullyShardedDataParallel 用于分布式训练的分区策略。

  • FULL_SHARD : 参数、梯度和优化器状态被分片。对于参数,这种策略在正向传播前进行解分片(通过 all-gather),在正向传播后重新分片,在反向计算前进行解分片,在反向计算后重新分片。对于梯度,它在反向计算后进行同步和分片(通过 reduce-scatter)。分片优化器状态在每个 rank 上本地更新。

  • SHARD_GRAD_OP : 在计算过程中,梯度和优化器状态被分片,并且参数在计算外部分片。对于参数,这种策略在正向传播前进行解分片,在正向传播后不重新分片,只在反向计算后重新分片。分片优化器状态在每个 rank 上本地更新。在 no_sync() 内,参数在反向计算后不重新分片。

  • NO_SHARD : 参数、梯度和优化器状态没有被分片,而是像 PyTorch 的 DistributedDataParallel API 一样在 rank 间复制。对于梯度,这种策略在反向计算后进行同步(通过 all-reduce)。未分片的优化器状态在每个 rank 上本地更新。

  • 在节点内应用 FULL_SHARD ,并在节点间复制参数。这可以减少通信量,因为昂贵的全局聚合和全局分散操作仅在节点内进行,这对于中等规模的模型来说可能更高效。

  • 在节点内应用 SHARD_GRAD_OP ,并在节点间复制参数。这与 HYBRID_SHARD 类似,但可能提供更高的吞吐量,因为未分片的参数在正向传播后不会被释放,从而节省了预反向传播中的全局聚合操作。

class torch.distributed.fsdp.MixedPrecision(param_dtype=None, reduce_dtype=None, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True, _module_classes_to_ignore=(<class 'torch.nn.modules.batchnorm._BatchNorm'>, ))[source][source]

这配置了 FSDP 原生混合精度训练。

变量:
  • param_dtype (Optional[torch.dtype]) – 这指定了在正向和反向传播过程中模型参数的 dtype,因此也指定了正向和反向计算的 dtype。在正向和反向传播之外,分片参数保持全精度(例如,对于优化器步骤),并且对于模型检查点,参数始终以全精度保存。(默认: None

  • reduce_dtype (Optional[torch.dtype]) – 这指定了梯度归约(即 reduce-scatter 或 all-reduce)的 dtype。如果这是 Noneparam_dtype 不是 None ,则采用 param_dtype 的值,仍然以低精度运行梯度归约。这允许与 param_dtype 不同,例如,强制梯度归约以全精度运行。(默认: None

  • buffer_dtype (Optional[torch.dtype]) – 这指定了缓冲区的 dtype。FSDP 不会分片缓冲区。相反,FSDP 在第一次正向传播中将它们转换为 buffer_dtype ,并在之后保持该 dtype。对于模型检查点,除了 LOCAL_STATE_DICT 之外,缓冲区以全精度保存。(默认: None

  • keep_low_precision_grads (bool) – 如果 False ,则 FSDP 在优化器步骤之前将梯度上转换为全精度。如果 True ,则 FSDP 保持梯度在用于梯度减少的数据类型中,如果使用支持以低精度运行的定制优化器,则可以节省内存。(默认: False

  • cast_forward_inputs (bool) – 如果 True ,则此 FSDP 模块将其前向参数和关键字段转换为 param_dtype 。这是为了确保参数和输入的数据类型匹配,以满足许多操作的要求。当仅对某些但不是所有 FSDP 模块应用混合精度时,可能需要设置为 True ,在这种情况下,混合精度 FSDP 子模块需要重新转换其输入。(默认: False

  • cast_root_forward_inputs (bool) – 如果 True ,则根 FSDP 模块将其前向参数和关键字段转换为 param_dtype ,覆盖 cast_forward_inputs 的值。对于非根 FSDP 模块,此操作不执行任何操作。(默认: True

  • `_module_classes_to_ignore (collections.abc.Sequence[type[torch.nn.modules.module.Module]]) – (Sequence[Type[nn.Module]]): 这指定了在混合精度模式下需要忽略的模块类。当使用 auto_wrap_policy 时,这些类的模块将单独应用 FSDP 并禁用混合精度(这意味着最终的 FSDP 构造将与指定的策略有所偏差)。如果未指定 auto_wrap_policy ,则此操作不会执行任何操作。此 API 为实验性,可能随时更改。(默认: (_BatchNorm,) )`

注意

此 API 为实验性,可能随时更改。

注意

只有浮点张量会被转换为指定的数据类型。

注意

summon_full_params 中,参数被强制转换为全精度,但缓冲区不是。

注意

层归一化和批归一化即使在输入为低精度如 float16bfloat16 时也会在 float32 中累积。仅对那些归一化模块禁用 FSDP 的混合精度意味着保持仿射参数在 float32 中。然而,这会导致那些归一化模块分别进行全聚合和全分散,可能效率低下,因此如果工作负载允许,用户应优先考虑仍然对这些模块应用混合精度。

注意

默认情况下,如果用户传递一个包含任何 _BatchNorm 模块的模型并指定了 auto_wrap_policy ,则批归一化模块将分别应用 FSDP,并禁用混合精度。请参阅 _module_classes_to_ignore 参数。

注意

MixedPrecision 默认具有 cast_root_forward_inputs=Truecast_forward_inputs=False 。对于根 FSDP 实例,其 cast_root_forward_inputs 优先于其 cast_forward_inputs 。对于非根 FSDP 实例,它们的 cast_root_forward_inputs 值将被忽略。默认设置对于典型情况就足够了,在这种情况下,每个 FSDP 实例具有相同的 MixedPrecision 配置,并且只需要在模型前向传递的开始处将输入转换为 param_dtype

注意

对于具有不同 MixedPrecision 配置的嵌套 FSDP 实例,我们建议在每个实例的前向传播之前设置单独的 cast_forward_inputs 值来配置是否进行 casting 输入。在这种情况下,由于 casting 操作在每个 FSDP 实例的前向传播之前发生,因此父 FSDP 实例应该在它的 FSDP 子模块之前运行其非 FSDP 子模块,以避免由于不同的 MixedPrecision 配置而改变激活数据类型。

示例:

>>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
>>> model[1] = FSDP(
>>>     model[1],
>>>     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
>>> )
>>> model = FSDP(
>>>     model,
>>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
>>> )

上文展示了工作示例。另一方面,如果将 model[1] 替换为 model[0] ,即使用不同 MixedPrecision 的子模块首先运行其前向传播,那么 model[1] 将会错误地看到 float16 激活而不是 bfloat16 激活。

class torch.distributed.fsdp.CPUOffload(offload_params=False)[source][source]

此配置 CPU 卸载。

变量:

offload_params (bool) – 此选项指定是否在未参与计算时将参数卸载到 CPU。如果 True ,则还将梯度卸载到 CPU,这意味着优化器步骤在 CPU 上运行。

class torch.distributed.fsdp.StateDictConfig(offload_to_cpu=False)[source][source]

StateDictConfig 是所有 state_dict 配置类的基类。用户应实例化子类(例如 FullStateDictConfig )以配置 FSDP 支持的相应 state_dict 类型的设置。

变量:

offload_to_cpu (bool) – 如果 True ,则 FSDP 将状态字典值卸载到 CPU,如果 False ,则 FSDP 将其保留在 GPU 上。(默认: False

class torch.distributed.fsdp.FullStateDictConfig(offload_to_cpu=False, rank0_only=False)[source][source]

FullStateDictConfig 是一个配置类,旨在与 StateDictType.FULL_STATE_DICT 一起使用。我们建议在保存完整状态字典时启用 offload_to_cpu=Truerank0_only=True 以节省 GPU 内存和 CPU 内存。此配置类应通过 state_dict_type() 上下文管理器使用,如下所示:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> fsdp = FSDP(model, auto_wrap_policy=...)
>>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
>>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
>>>     state = fsdp.state_dict()
>>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
>>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
>>> model = model_fn()  # Initialize model in preparation for wrapping with FSDP
>>> if dist.get_rank() == 0:
>>> # Load checkpoint only on rank 0 to avoid memory redundancy
>>>     state_dict = torch.load("my_checkpoint.pt")
>>>     model.load_state_dict(state_dict)
>>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
>>> # communicates loaded checkpoint states from rank 0 to rest of the world.
>>> fsdp = FSDP(
...     model,
...     device_id=torch.cuda.current_device(),
...     auto_wrap_policy=...,
...     sync_module_states=True,
... )
>>> # After this point, all ranks have FSDP model with loaded checkpoint.
变量:

rank0_only (bool) – 如果 True ,则只有 rank 0 保存完整的状态字典,非零 rank 保存空字典。如果 False ,则所有 rank 都保存完整的状态字典。(默认: False

class torch.distributed.fsdp.ShardedStateDictConfig(offload_to_cpu=False, _use_dtensor=False)[source][source]

ShardedStateDictConfig 是一个配置类,用于与 StateDictType.SHARDED_STATE_DICT 一起使用。

变量:

_use_dtensor (bool) – 如果 True ,则 FSDP 将状态字典值保存为 DTensor ,如果 False ,则 FSDP 将它们保存为 ShardedTensor 。(默认: False

警告

_use_dtensorShardedStateDictConfig 的私有字段,由 FSDP 用于确定状态字典值的类型。用户不应手动修改 _use_dtensor

class torch.distributed.fsdp.LocalStateDictConfig(offload_to_cpu: bool = False)[source][source]
class torch.distributed.fsdp.OptimStateDictConfig(offload_to_cpu=True)[source][source]

OptimStateDictConfig 是所有 optim_state_dict 配置类的基类。用户应实例化子类(例如 FullOptimStateDictConfig )以配置 FSDP 支持的相应 optim_state_dict 类型的设置。

变量:

offload_to_cpu (bool) – 如果 True ,则 FSDP 将状态字典的张量值卸载到 CPU,如果 False ,则 FSDP 保持它们在原始设备上(除非启用参数 CPU 卸载,则原始设备是 GPU)。(默认: True )

class torch.distributed.fsdp.FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)[source][source]
变量:

rank0_only (bool) – 如果 True ,则只有 rank 0 保存完整的状态字典,非零 rank 保存空字典。如果 False ,则所有 rank 都保存完整的状态字典。(默认: False

class torch.distributed.fsdp.ShardedOptimStateDictConfig(offload_to_cpu=True, _use_dtensor=False)[source][source]

ShardedOptimStateDictConfig 是一个配置类,用于与 StateDictType.SHARDED_STATE_DICT 一起使用。

变量:

_use_dtensor (bool) – 如果 True ,则 FSDP 将状态字典值保存为 DTensor ,如果 False ,则 FSDP 将它们保存为 ShardedTensor 。(默认: False

警告

_use_dtensorShardedOptimStateDictConfig 的私有字段,由 FSDP 用于确定状态字典值的类型。用户不应手动修改 _use_dtensor

class torch.distributed.fsdp.LocalOptimStateDictConfig(offload_to_cpu: bool = False)[source][source]
class torch.distributed.fsdp.StateDictSettings(state_dict_type: torch.distributed.fsdp.api.StateDictType, state_dict_config: torch.distributed.fsdp.api.StateDictConfig, optim_state_dict_config: torch.distributed.fsdp.api.OptimStateDictConfig)[source][source]

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源