• 文档 >
  • 分布式检查点 - torch.distributed.checkpoint
快捷键

分布式检查点 - torch.distributed.checkpoint ¶

分布式检查点(DCP)支持并行从多个 rank 加载和保存模型。它处理加载时的 resharding,使得可以在一个集群拓扑结构中保存,在另一个中加载。

DCP 与 torch.save 和 torch.load 在几个重要方面有所不同:

  • 它为每个检查点生成多个文件,每个 rank 至少有一个。

  • 它在原地操作,意味着模型应首先分配其数据,DCP 使用该存储。

加载和保存检查点的入口如下:

其他资源:¶

class torch.distributed.checkpoint.state_dict_saver.AsyncCheckpointerType(value, names=<未提供>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source][source] ¶

异步检查点类型枚举。

torch.distributed.checkpoint.state_dict_saver.save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None, no_dist=False)[source][source]

以 SPMD 风格保存分布式模型。

此函数与 torch.save() 不同,因为它通过让每个 rank 只保存其本地碎片来处理 ShardedTensorDTensor

对于每个 Stateful 对象(具有 state_dictload_state_dict ),在序列化之前,save 将调用 state_dict

警告

PyTorch 版本之间保存的状态字典没有向后兼容性的保证。

警告

如果使用 process_group 参数,请确保只有其 rank 调用 save_state_dict,并且 state_dict 中的所有数据都属于它。

注意

当为 FSDP 的 ShardingStrategy.HYBRID_SHARD 保存检查点时,应该只有一个 shard_group 调用 save_state_dict,并且需要传入相应的进程组。

注意

如果没有可用的进程组,此函数假定意图是保存

本地进程中的 state_dict。

参数:
  • state_dict(Dict[str, Any])- 要保存的状态字典。

  • checkpoint_id(Union[str, os.PathLike, None])- 此检查点实例的 ID。checkpoint_id 的含义取决于存储方式。它可以是一个文件夹或文件的路径。如果存储是键值存储,它也可以是一个键。(默认: None

  • storage_writer(Optional[StorageWriter])- 用于执行写入的 StorageWriter 实例。如果没有指定,DCP 将根据 checkpoint_id 自动推断写入器。如果 checkpoint_id 也为 None,将引发异常。(默认: None

  • planner (Optional[SavePlanner]) – SavePlanner 的实例。如果未指定,将使用默认规划器。(默认: None

  • process_group (Optional[ProcessGroup]) – 用于跨等级同步的 ProcessGroup。(默认: None

  • no_dist (bool) – 如果为 True ,此函数将假定意图是加载检查点而不使用跨等级同步。(默认: False

返回:

已保存检查点的元数据对象。

返回类型:

元数据

示例

>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
...     "/checkpoint/1"
... )
>>> torch.distributed.checkpoint.save(
>>>     state_dict=state_dict,
>>>     storage_writer=fs_storage_writer,
>>> )

注意

save_state_dict 使用集体操作来协调跨进程的写入。对于基于 NCCL 的进程组,对象的内部张量表示必须在通信之前移动到 GPU 设备。在这种情况下,使用的设备由 torch.cuda.current_device() 指定,并且用户有责任确保这一点,以便每个进程都有一个单独的 GPU,通过 torch.cuda.set_device()

torch.distributed.checkpoint.state_dict_saver.async_save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None, async_checkpointer_type=AsyncCheckpointerType.THREAD)[source][source]

save 的异步版本。此代码首先将 state_dict 解析到暂存存储(默认为 CPU 内存),然后在单独的线程中调用保存。

警告

该功能处于实验阶段,可能会发生变化。

参数:
  • state_dict (Dict[str, Any]) – 要保存的状态字典。

  • checkpoint_id (Union[str, os.PathLike, None]) – 该检查点实例的 ID。checkpoint_id 的含义取决于存储方式。它可以是一个文件夹或文件的路径。如果存储是键值存储,它也可以是一个键。(默认: None

  • storage_writer (Optional[StorageWriter]) – 用于执行‘阶段’和‘保存’的 StorageWriter 实例。如果没有指定,DCP 将根据 checkpoint_id 自动推断写入器。如果 checkpoint_id 也为 None,将引发异常。(默认: None

  • planner (可选[SavePlanner]) – SavePlanner 的实例。如果没有指定,将使用默认规划器。(默认: None

  • process_group (可选[ProcessGroup]) – 用于跨等级同步的 ProcessGroup。(默认: None

返回:

保存后包含结果 Metadata 对象的 Future。

返回类型:

Future

示例

>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
...     "/checkpoint/1"
... )
>>> checkpoint_future = torch.distributed.checkpoint.async_save(
>>>     state_dict=state_dict,
>>>     storage_writer=fs_storage_writer,
>>> )
>>>
>>> # ... do some work ...
>>>
>>> checkpoint_future.result()
torch.distributed.checkpoint.state_dict_saver.save_state_dict(state_dict, storage_writer, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source][source]

此方法已弃用。请切换到‘save’。

返回类型:

元数据

torch.distributed.checkpoint.state_dict_loader.load(state_dict, *, checkpoint_id=None, storage_reader=None, planner=None, process_group=None, no_dist=False)[source][source]

以 SPMD 风格将检查点加载到分布式状态字典中。

每个 rank 必须在其提供的 state_dict 中具有相同的键。键不匹配可能会导致挂起或错误。如果您不确定,可以使用 utils._assert_same_keys API 进行检查(但可能会产生通信成本)。

每个 rank 将尝试读取最少的必要数据以满足请求的状态字典。当加载 ShardedTensorDTensor 实例时,每个 rank 只读取其本地分片的数据。

对于每个 Stateful 对象(具有 state_dictload_state_dict ),加载将首先调用 state_dict ,然后尝试反序列化,反序列化完成后调用 load_state_dict 。对于每个非 Stateful 对象,加载将反序列化对象,然后将其替换为反序列化的对象。

警告

在调用此函数之前, state_dict 中的所有张量必须在它们的目标设备上分配。

所有非张量数据均使用 torch.load() 加载,并在 state_dict 中就地修改。

警告

用户必须调用根模块的 load_state_dict,以确保加载后处理和非张量数据正确传播。

参数:
  • state_dict (Dict[str, Any]) – 要将检查点加载到其中的状态字典。

  • checkpoint_id(Union[str, os.PathLike, None])- 此检查点实例的 ID。checkpoint_id 的含义取决于存储方式。它可以是一个文件夹或文件的路径。如果存储是键值存储,它也可以是一个键。(默认: None

  • storage_reader(Optional[StorageReader])- 用于执行读取的 StorageWriter 实例。如果没有指定,DCP 将自动根据 checkpoint_id 推断读取器。如果 checkpoint_id 也为 None,将引发异常。(默认: None

  • planner(Optional[LoadPlanner])- LoadPlanner 的实例。如果没有指定,将使用默认的计划器。(默认: None

  • process_group(Optional[ProcessGroup])- 用于跨 rank 同步的 ProcessGroup。(默认: None

  • no_dist(布尔值)- 如果 True ,则此函数将假定意图是加载检查点而不使用跨秩同步。(默认: False

返回:

无。

返回类型:

无。

示例
>>> my_model = MyModule()
>>> optimizer = Adagrad(my_model.parameters())
>>> model_state_dict = my_model.state_dict()
>>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader(
...     "/checkpoint/1"
... )
>>> torch.distributed.checkpoint.load_state_dict(
>>>     state_dict=model_state_dict,
>>>     storage_reader=fs_storage_reader,
>>> )
>>> # module.load_state_dict() function might have customized steps
>>> # to flush the state_dict, must call it to
>>> # ensure correct behavior.
>>> my_model.load_state_dict(model_state_dict)

注意

load_state_dict 使用集体来协调跨秩的读取。对于基于 NCCL 的进程组,对象的内部张量表示必须在通信之前移动到 GPU 设备。在这种情况下,使用的设备由 torch.cuda.current_device() 指定,并且用户有责任确保这样设置,以便每个秩都有一个单独的 GPU,通过 torch.cuda.set_device()

torch.distributed.checkpoint.state_dict_loader.load_state_dict(state_dict, storage_reader, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source][source]

此方法已弃用。请切换到‘load’。

以下模块也适用于对异步检查点使用的阶段机制进行额外自定义(torch.distributed.checkpoint.async_save):

class torch.distributed.checkpoint.staging.AsyncStager(*args, **kwargs)[source][source]

本协议旨在为 dcp.async_save 提供定制和可扩展性,使用户能够自定义在并行执行常规 dcp.save 路径之前如何对数据进行“预置”。预期的操作顺序(具体定义在 torch.distributed.state_dict_saver.async_save 中)如下:

  1. AsyncStager.stage_data(state_dict):

    此调用给 AsyncStager 提供了“预置”state_dict 的机会。在此上下文中,预置的期望和目的是创建一个“训练安全”的 state dict 表示形式,这意味着在预置完成后对模块数据的任何更新都不应反映在此方法返回的 state dict 中。例如,在默认情况下,整个 state dict 的副本被创建在 CPU RAM 上并返回此处,使用户可以在继续训练的同时避免对正在序列化的数据进行更改。

  2. 在预置返回的 state_dict 上并行调用 dcp.save。此调用负责

    用于序列化 state_dict 并将其写入存储。

  3. 如果 AsyncStager.should_synchronize_after_execute 为 True,则此方法将在执行后立即被调用

    序列化线程启动并在从 dcp.async_save 返回之前。如果此设置为 False,则假定用户已为优化训练循环中的保存延迟(例如,通过重叠预取与正向/反向传递)定义了自定义同步点,并且用户负责在适当的时间调用 AsyncStager.synchronize_staging。

属性 should_synchronize_after_execute bool

执行阶段后是否同步。

stage(state_dict)[source][source]

返回一个“已分阶段”的状态字典副本。对于已分阶段的副本,预期它不会受到阶段调用完成后发生的任何更新的影响。

返回类型:

dict[str, Union[~StatefulT, Any]]

synchronize_staging()[source][source]

如果某些情况下阶段是异步的,则应调用此方法以确保阶段完成并且可以安全地开始修改原始状态字典

class torch.distributed.checkpoint.staging.BlockingAsyncStager(cache_staged_state_dict=False, type_check=False)[source][source]

异步阶段器的实现,将状态字典放置在 CPU RAM 中,并在复制完成前阻塞。此实现还提供了使用固定内存优化阶段延迟的选项。

注意:在这种情况下,synchronize_staging 是一个空操作。

stage(state_dict)[源代码][源代码] ¶

返回 state_dict 在 CPU 上的副本。

返回类型:

dict[str, Union[~StatefulT, Any]]

同步预发布()[源][源] ¶

由于预发布是阻塞的,这是一个无操作函数。

除了上述入口点之外,以下所述的 Stateful 对象在保存/加载过程中提供额外的定制化。.. automodule:: torch.distributed.checkpoint.stateful

类 torch.distributed.checkpoint.stateful.Stateful(*args, **kwargs)[源][源] ¶

可检查点和恢复的对象的状态协议。

load_state_dict(state_dict)[source][source]

从提供的 state_dict 恢复对象的状态。

参数:

state_dict (dict[str, Any]) – 从中恢复的状态字典

state_dict()[source][source]

对象应返回其 state_dict 表示的字典。此函数的输出将被检查点记录,并在 load_state_dict()中稍后恢复。

警告

由于检查点恢复的 inplace 特性,此函数也在 torch.distributed.checkpoint.load 期间被调用。

返回:

对象的状态字典

返回类型:

字典

本示例展示了如何使用 PyTorch 分布式检查点保存 FSDP 模型。

以下类型定义了在检查点期间使用的 IO 接口:

class torch.distributed.checkpoint.StorageReader[source][source]

load_state_dict 使用的接口,用于从存储中读取。

一个 StorageReader 实例在分布式检查点中既充当协调者又充当跟随者。作为初始化的一部分,每个实例都会被告知其角色。

子类应期望 load_state_dict 调用的以下顺序:

  1. (所有节点)如果用户传递了有效的 checkpoint_id,则设置 checkpoint_id。

  2. (所有等级) 读取元数据()

  3. (所有等级) 设置存储读取器()

  4. (所有等级) 准备本地计划()

  5. (协调器) 准备全局计划()

  6. (所有等级) read_data()

准备全局计划(plans)[源][源] ¶

执行存储加载的集中式规划。

此方法仅在协调器实例上调用。

虽然这种方法可以生成完全不同的计划,但首选的方式是将特定存储的数据存储在 LoadPlan::storage_data 中。

参数:

plans (列表[torch.distributed.checkpoint.planner.LoadPlan]) – 一个包含 LoadPlan 实例的列表,每个 rank 对应一个实例。

返回:

经过存储全局规划后的转换后的 LoadPlan 列表

返回类型:

列表[torch.distributed.checkpoint.planner.LoadPlan]

准备本地计划(plan)[源][源]

执行特定存储的本地规划。

虽然此方法可以生成完全不同的计划,但推荐的方式是将特定存储数据存储在 LoadPlan::storage_data 中。

参数:

plan (LoadPlan) – 正在使用的 LoadPlan 的本地计划。

返回:

存储后的本地规划后转换的 LoadPlan

返回类型:

加载计划

抽象 read_data(plan, planner)[source][source] ¶

使用 planner 解析数据,从 plan 读取所有项目。

子类应该调用 LoadPlanner::load_bytes 来将 BytesIO 对象反序列化到正确的位置。

子类应该调用 LoadPlanner::resolve_tensor 来获取它应该加载数据的张量。

负责正确安排所需的跨设备复制的责任在于 StorageLayer。

参数:
  • plan(LoadPlan)- 要在本地执行的本地计划。

  • 规划器(LoadPlanner)- 使用该规划器对象来解析项。

返回:

所有读取完成后的未来。

返回类型:

Future[None]

抽象 read_metadata()[source][source] ¶

阅读检查点元数据。

返回:

正在加载的检查点关联的元数据对象。

返回类型:

元数据

abstract reset(检查点 ID=None)[source][source] ¶

调用表示即将发生一个新的检查点读取。如果用户为这次检查点读取设置了检查点 ID,则可能存在检查点 ID。检查点 ID 的含义取决于存储方式。它可以是文件夹/文件的路径,或者键值存储的键。

参数:

checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。检查点 ID 的含义取决于存储方式。它可以是文件夹或文件的路径。如果存储方式更像是键值存储,它也可以是键。(默认值: None

abstract set_up_storage_reader(metadata, is_coordinator)[source][source]

初始化此实例。

参数:
  • 元数据(元数据)- 要使用的元数据模式。

  • is_coordinator(布尔值)- 此实例是否负责协调检查点。

抽象类方法 validate_checkpoint_id(checkpoint_id)[source][source] ¶

检查给定的 checkpoint_id 是否被存储支持。这允许我们启用自动存储选择。

返回类型:

布尔型

class torch.distributed.checkpoint.StorageWriter[source][source]

save_state_dict 使用的接口,用于写入存储。

一个 StorageWriter 实例在分布式检查点中既充当协调者又充当跟随者。作为初始化的一部分,每个实例都会被告知其角色。

子类应期望以下调用序列。

  1. (所有等级)如果用户传递有效的 checkpoint_id,则设置 checkpoint_id。

  2. (所有等级)设置存储写入器()。

  3. (所有等级)准备本地计划()。

  4. (协调器) 准备全局计划()

  5. (所有进程) 写入数据()

  6. (协调器) 完成()

抽象 finish(metadata, results)[来源][来源] ¶

编写元数据并标记当前检查点为成功。

用于序列化元数据的实际格式/模式是实现细节。唯一的要求是它必须能够恢复到相同的对象图中。

参数:
  • 元数据(Metadata)- 新检查点的元数据

  • 结果(list[list[torch.distributed.checkpoint.storage.WriteResult]])- 来自所有排名的 WriteResults 列表。

返回:

None

返回类型:

None

准备全局计划[源代码][源代码] ¶

执行存储的集中式规划。

此方法仅在协调器实例上调用。

虽然此方法可以生成完全不同的计划,但首选方式是将存储特定数据存储在 SavePlan::storage_data 中。

参数:

plans(列表[torch.distributed.checkpoint.planner.SavePlan])- 一个包含 SavePlan 实例的列表,每个 rank 一个。

返回:

存储全局规划后的转换后的 SavePlan 列表。

返回类型:

list[torch.distributed.checkpoint.planner.SavePlan]

准备本地计划(准备局部计划)[源代码][源代码]

执行特定存储的本地规划。

虽然此方法可以生成完全不同的计划,但推荐的方式是将特定存储的数据存储在 SavePlan::storage_data 中。

参数:

本地计划(SavePlan)- 正在使用的来自 SavePlanner 的本地计划。

返回:

存储后的转换 SavePlan

返回类型:

SavePlan

抽象重置(checkpoint_id=None)[源][源]

调用表示即将发生一个新的检查点写入。如果用户为该检查点写入设置了检查点 ID,则可能存在检查点 ID。检查点 ID 的含义依赖于存储。它可以是一个文件夹/文件的路径,也可以是键值存储的键。(默认:@0#)

参数:

checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。检查点 ID 的含义取决于存储。它可以是文件夹或文件的路径。如果存储是键值存储,它也可以是一个键。(默认: None

abstract set_up_storage_writer(is_coordinator)[source][source]

初始化此实例。

参数:

是否为协调器(布尔值)- 此实例是否负责协调检查点。

storage_meta()[source][source]

返回存储特定的元数据。这用于在检查点中存储可用于提供请求级可观察性的附加信息。StorageMeta 在保存调用期间传递给 SavePlanner 。默认返回 None。

TODO:提供示例

返回类型:

Optional[存储元数据]

抽象类方法 validate_checkpoint_id(checkpoint_id)[source][source] ¶

检查给定的 checkpoint_id 是否由存储支持。这允许我们启用自动存储选择。

返回类型:

布尔型

抽象 write_data(plan, planner)[source][source] ¶

使用 planner 解析数据,将 plan 中的所有条目写出来。

子类应该对计划中的每个条目调用 SavePlanner::resolve_data ,以获取底层对象进行写入。

子类应延迟调用 resolve_data,因为它可能会分配内存。对于张量,做以下假设:

  • 它们可能位于任何设备上,包括与 WriteItem::tensor_data 不匹配的设备。

  • 它们可能是视图,也可能不是连续的。只需保存投影即可。

参数:
  • 计划(SavePlan)- 要执行的保存计划。

  • 规划器(SavePlanner)- 用于将项目解析为数据的规划器对象。

返回:

一个完成到 WriteResult 列表的未来。

返回类型:

Future[list[torch.distributed.checkpoint.storage.WriteResult]]

以下类型定义了在检查点期间使用的计划器接口:

class torch.distributed.checkpoint.LoadPlanner[source][source]

定义了 load_state_dict 用来计划加载过程的协议的抽象类。

LoadPlanner 是具有状态的实体,可用于自定义整个加载过程。

LoadPlanner 作为 state_dict 的访问代理,因此对其进行的任何转换都将对整个过程可见。

计划器子类在 load_state_dict 期间可以期待以下调用序列:

  1. set_up_planner - 在所有进程中调用。

    标记检查点加载的开始。

  2. create_local_plan - 在所有进程中调用。

    处理状态字典并生成一个将要发送进行全局规划的 LoadPlan。

  3. create_global_plan - 仅在协调器进程中调用。

    从所有 rank 中获取 LoadPlan 并做出任何全局决策。

  4. load_bytes - 在每个 rank 上被多次调用

    这在每个 state_dict 中的非 tensor 值上只调用一次。

  5. resolve_tensor 和 commit_tensor - 在每个 rank 上被多次调用

    它们成对地用于每个 state_dict 中的 Tensor 值。

建议用户扩展 DefaultLoadPlanner 而不是直接扩展此接口,因为大多数更改都可以通过更改单个方法来表示。

常见的扩展模式有两种:

重写 state_dict。这是扩展加载过程的最简单方法,因为它不需要理解 LoadPlan 的工作复杂性。由于加载是在原地发生的,我们需要保留原始 state_dict 的引用,因此我们需要能够原地执行它。

>>> class RenamePlanner(DefaultLoadPlanner):
>>>     def set_up_planner(
>>>         self,
>>>         state_dict: STATE_DICT_TYPE,
>>>         metadata: Metadata,
>>>         is_coordinator: bool,
>>>     ) -> None:
>>>         self.original_state_dict = state_dict
>>>         state_dict = {"foo_" + k: v for k, v in state_dict.items()}
>>>
>>>         if self.flatten_sharded_tensors:
>>>             state_dict = _flatten_sharded_tensors(state_dict)
>>>
>>>         if self.flatten_state_dict:
>>>             state_dict, self.mappings = flatten_state_dict(state_dict)
>>>
>>>         self.state_dict = state_dict
>>>         self.metadata = metadata
>>>         self.is_coordinator = is_coordinator
>>>
>>>     def load_bytes(self, read_item, value):
>>> # Remove the "foo_" prefix
>>>         self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)

修改 resolve_tensor 和 commit_tensor 以处理加载时转换。

>>> class MetaModelMaterialize(DefaultSavePlanner):
>>>     def resolve_tensor(self, read_item):
>>>         tensor = super().resolve_tensor(read_item)
>>>         return torch.empty_like(tensor, device="cpu")
>>>
>>>     def commit_tensor(self, read_item, tensor):
>>>         self.state_dict[read_item.dest_index.fqn] = tensor
abstract commit_tensor(read_item, tensor)[source][source]

当 StorageReader 完成数据加载到 tensor 后调用一次。

提供的张量与调用 resolve_tensor 返回的相同。如果此 LoadPlanner 需要在将其复制回 state_dict 之前对 tensor 进行后处理,则需要此方法。

张量内容将遵循其设备同步模型。

创建全局计划(create_global_plan)[源][源]

计算全局负载计划并返回每个进程的规划。

. 注意:仅在协调进程上调用

返回类型:

list[torch.distributed.checkpoint.planner.LoadPlan]

抽象 create_local_plan()[来源][来源] ¶

根据由 set_up_planner 提供的状态字典和元数据创建 LoadPlan。

. 注意。这将在每个 rank 上调用。

返回类型:

装载计划

完成计划(中心计划)[来源][来源]

接受协调器的计划并返回最终的装载计划。

返回类型:

装载计划

abstract load_bytes(read_item, value)[source][source]

加载由 read_item``and ``value 描述的项目。

此方法预期将就地修改底层 state_dict。

value 的内容由用于生成正在加载的检查点的 SavePlanner 定义。

resolve_bytes(read_item)[source][source]

返回用于由 StorageReader 加载 read_item 的 BytesIO。

BytesIO 应与底层 state_dict 中的一个别名,因为 StorageReader 将替换其内容。

返回类型:

BytesIO

abstract resolve_tensor(read_item)[source][source]

返回由 read_item 描述的张量,供 StorageReader 加载 read_item 使用。

该张量应与底层 state_dict 中的一个别名,因为 StorageReader 将替换其内容。如果由于任何原因无法实现这一点,规划器可以使用 commit_tensor 方法将数据复制回 state_dict 中的那个。

返回类型:

张量

abstract set_up_planner(state_dict, metadata=None, is_coordinator=False)[source][source]

初始化此实例以将数据加载到 state_dict

. 注意。这将在每个 rank 上调用。

class torch.distributed.checkpoint.LoadPlan(items: list[torch.distributed.checkpoint.planner.ReadItem], storage_data: Any = None, planner_data: Any = None)[source][source]
class torch.distributed.checkpoint.ReadItem(type: torch.distributed.checkpoint.planner.LoadItemType, dest_index: torch.distributed.checkpoint.metadata.MetadataIndex, dest_offsets: torch.Size, storage_index: torch.distributed.checkpoint.metadata.MetadataIndex, storage_offsets: torch.Size, lengths: torch.Size)[source][source]
class torch.distributed.checkpoint.SavePlanner[source][source]

抽象类,定义了 save_state_dict 用于规划保存过程的协议。

SavePlanner 是一种有状态的实体,可以用来自定义整个保存过程。

SavePlanner 作为 state_dict 的访问代理,因此对其进行的任何转换都将对整个过程可见。

规划器子类在 save_state_dict 期间可以期待以下调用序列:

  1. set_up_planner - 在所有进程中调用。

    标记检查点保存的开始。

  2. create_local_plan - 在所有进程中调用。

    处理状态字典并生成一个将要发送进行全局规划的 SavePlan。

  3. create_global_plan - 仅在协调器进程中调用。

    从所有等级获取 SavePlan 并做出任何全局决策。

  4. finish_plan - 在所有等级上调用。

    这给每个等级一个调整全局规划决策的机会。

  5. resolve_data - 在每个等级上多次调用。

    在状态字典中查找存储层的值以进行写入。

建议用户扩展 DefaultSavePlanner 而不是直接扩展此接口,因为大多数更改都可以通过单个方法的更改来表示。

常见的扩展模式有 3 种:

重新编写 state_dict。这是扩展保存过程的最简单方法,因为它不需要理解 SavePlan 的工作复杂性:

>>> class RenamePlanner(DefaultSavePlanner):
>>>     def set_up_planner(
>>>         self,
>>>         state_dict: STATE_DICT_TYPE,
>>>         storage_meta: Optional[StorageMeta],
>>>         is_coordinator: bool,
>>>     ) -> None:
>>> # prefix all keys with `foo_``
>>>         super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)

同时修改本地计划和查找。这在精细控制数据持久化方式时很有用

>>> class FP16Planner(DefaultSavePlanner):
>>>     def create_local_plan(self):
>>>         plan = super().create_local_plan()
>>>         for p in plan:
>>>             if p.tensor_data is not None:
>>>                 p.tensor_data.properties.dtype = torch.float16
>>>         return plan
>>>
>>>     def resolve_data(self, write_item):
>>>         item = super().resolve_data(write_item)
>>>         return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)

使用全局规划步骤来做出每个等级单独无法做出的中央决策

>>> from itertools import zip_longest
>>> from dataclasses import replace
>>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
>>> # This uses the default local plan behavior of having all non-sharded writes in rank 0
>>> # This sample doesn't handle ShardedTensors
>>>     def create_global_plan(self, all_plans):
>>>         iters = [iter(all_plans[0].items)] * len(all_plans)
>>>         items_per_rank = [
>>>             [item for item in items if item is not None]
>>>             for items in zip(*zip_longest(*iters), strict=True)
>>>         ]
>>>         all_plans = [
>>>             replace(plan, items=items)
>>>             for plan, items in zip(all_plans, items_per_rank, strict=True)
>>>         ]
>>>         return super().create_global_plan(all_plans)

最后,一些规划器需要在检查点中保存额外的元数据,这是通过每个等级在本地计划中贡献他们的数据项,然后全局规划器汇总它们来实现的:

>>> class SaveExtraDataPlanner(DefaultSavePlanner):
>>>     def create_local_plan(self) -> SavePlan:
>>>         plan = super().create_local_plan()
>>>         return replace(plan, planner_data="per-rank-data")
>>>
>>>     def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
>>>         global_plan, metadata = super().create_global_plan(all_plans)
>>>         merged_data = [p.planner_data for p in global_plan]
>>>         metadata = replace(metadata, planner_data=merged_data)
>>>         return global_plan, metadata
abstract create_global_plan(all_plans)[source][source]

计算全局检查点计划并返回每个 rank 的本地计划。

仅在协调 rank 上调用此操作。

返回类型:

tuple[list[torch.distributed.checkpoint.planner.SavePlan], torch.distributed.checkpoint.metadata.Metadata]

抽象 create_local_plan()[source][source] ¶

计算当前排名的保存计划。

这将被汇总并传递给 create_global_plan。可以通过 SavePlan::planner_data 传递规划器特定数据。

这将在所有排名上被调用。

返回类型:

保存计划

完成计划抽象 finish_plan(new_plan)[source][source] ¶

合并由 create_local_plan 创建的计划和 create_global_plan 的结果。

这将在所有进程中调用。

返回类型:

保存计划 SavePlan

抽象 resolve_data(write_item)[source][source] ¶

write_item 转换并准备用于存储,确保幂等性和线程安全。

state_dict 中查找与 write_item 关联的对象,并在存储层消费之前应用任何转换(如序列化)。

在每个排名上多次调用,至少在最终的 SavePlan 中的每个 WriteItem 上调用一次。

此方法应该是幂等的且线程安全的。StorageWriter 实现可以随时调用它,以满足其需求。

任何分配内存的转换都应该在调用此方法时懒加载,以减少检查点所需的峰值内存。

当返回张量时,它们可以在任何设备或格式上,它们也可以是视图。确定如何保存它们的责任在于存储层。

返回类型:

Union[Tensor, BytesIO]

abstract set_up_planner(state_dict, storage_meta=None, is_coordinator=False)[source][source]

初始化此规划器以保存 state_dict

实现应将这些值保存下来,因为这些值在保存过程中不会提供。

这将在所有 rank 上调用。

class torch.distributed.checkpoint.SavePlan(items: list[torch.distributed.checkpoint.planner.WriteItem], storage_data: Any = None, planner_data: Any = None, usable: bool = True)[source][source]
class torch.distributed.checkpoint.planner.WriteItem(index, type, tensor_data=None)[source][source]

数据类,用于存储需要写入存储的信息。

tensor_storage_size()[source][source]

计算底层张量的存储大小,如果不是张量写入,则为 None。

返回:

可选的 int 类型,存储大小,如果存在底层张量,则以字节为单位。

返回类型:

可选[int]

我们提供了一个基于文件系统的存储层:

class torch.distributed.checkpoint.FileSystemReader(path, _extension_registry=None)[source][source]
属性 checkpoint_idUnion[strPathLike] ¶

返回将要用于加载检查点的 checkpoint_id。

类 torch.distributed.checkpoint.FileSystemWriter(path, single_file_per_rank=True, sync_files=True, thread_count=1, per_thread_copy_ahead=10000000, cache_staged_state_dict=False, overwrite=True, _extensions=None)[source][source] ¶

使用文件 I/O 的 StorageWriter 的基本实现。

本实现做出以下假设和简化:

  • 检查点路径是一个空目录或不存在目录。

  • 文件创建是原子的

检查点由每个写入请求的一个文件以及一个包含序列化元数据的.metadata 文件组成。

stage(state_dict)[source][source]

AsyncStager.stage 的覆盖

返回类型:

dict[str, Union[~StatefulT, Any]]

我们提供了 LoadPlanner 和 SavePlanner 的默认实现,可以处理所有 torch.distributed 构造,如 FSDP、DDP、ShardedTensor 和 DistributedTensor。

class torch.distributed.checkpoint.DefaultSavePlanner(flatten_state_dict=True, flatten_sharded_tensors=True, dedup_replicated_tensors=None, dedup_save_to_lowest_rank=False, enable_plan_caching=False)[source][source]
lookup_object(index)[source][source]

从规划器接口扩展,使其易于扩展默认规划器。

返回类型:

任何

transform_object(write_item, object)[source][source]

从规划器接口扩展,使其易于扩展默认规划器。

class torch.distributed.checkpoint.DefaultLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)[source][source]

在 LoadPlanner 的基础上增加了多个功能。

特别是增加了以下功能:

flatten_state_dict: 处理嵌套字典的 state_dict flatten_sharded_tensors: 对于 FSDP 在 2D 并行模式下 allow_partial_load: 如果为 False,当 state_dict 中存在但检查点中不存在的键时,将引发运行时错误

lookup_tensor(index)[source][source]

从规划器接口扩展,使其易于扩展默认规划器。

返回类型:

张量

transform_tensor(read_item, tensor)[source][source]

从规划器接口扩展,使其易于扩展默认规划器。

由于遗留的设计决策,FSDP 和 DDP 的状态字典可能具有不同的键或完全限定名称(例如,layer1.weight),即使原始未并行化的模型是相同的。此外,FSDP 提供各种类型的模型状态字典,如完整和分片状态字典。此外,优化器状态字典使用参数 ID 而不是完全限定名称来标识参数,这可能在并行使用时(例如,管道并行)引起问题。

为了应对这些挑战,我们为用户提供了一系列 API,以便轻松管理状态字典。get_model_state_dict()返回一个与未并行化模型状态字典键一致的状态字典。同样,get_optimizer_state_dict()提供具有所有并行应用键一致的优化器状态字典。为了实现这种一致性,get_optimizer_state_dict()将参数 ID 转换为与未并行化模型状态字典中找到的完全限定名称相同的名称。

注意,这些 API 返回的结果可以直接用于 torch.distributed.checkpoint.save()和 torch.distributed.checkpoint.load()方法,无需进行任何额外的转换。

提供了 set_model_state_dict()和 set_optimizer_state_dict()方法来加载由相应 getter API 生成的模型和优化器 state_dict。

注意,set_optimizer_state_dict()只能在调用 backward()之前或 step()在优化器上被调用之后调用。

注意,此功能为实验性,API 签名可能在将来发生变化。

torch.distributed.checkpoint.state_dict.get_state_dict(model, optimizers, *, submodules=None, options=None)[source][source]

返回模型的状态字典和优化器状态字典。

get_state_dict 可以处理由 PyTorch FSDP/fully_shard、DDP/replicate、tensor_parallel/parallelize_module 以及这些并行方式的任意组合并行化的任何模块。 get_state_dict 的主要功能包括:1.) 返回可以与不同数量的训练师和/或不同的并行方式重新分片的模型和优化器状态字典。2.) 隐藏特定于并行状态字典的 API。用户不需要调用这些 API。3.) 对结果状态字典进行合理性检查。

结果状态字典的键是规范的全限定名(FQNs)。规范的全限定名是指基于参数在 nn.Module 层次结构中的位置的全限定名。更具体地说,一个参数的规范全限定名是当模块没有被任何并行方式分布时, module.named_parameters()module.named_buffers() 返回的全限定名。由于优化器内部使用参数 ID 来表示参数,因此在调用此 API 时,将进行参数 ID 到规范全限定名的转换。

也可以处理未并行化的模块。在这种情况下,仅执行一个功能 - 将优化器参数 ID 转换为规范 FQNs。

示例

>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.checkpoint.state_dict import get_state_dict
>>> fsdp_model = FSDP(copy.deepcopy(model))
>>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> ddp_model = DDP(copy.deepcopy(model))
>>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
>>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(
...     fsdp_model, fsdp_optim
... )
>>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),
>>> # the asserts will fail.
>>> assert ddp_state_dict == fsdp_state_dict
>>> assert ddp_optim_state == fsdp_optim_state_dict
参数:
  • 模型(nn.Module)- 模型的 nn.Module。

  • 优化器(Union[None, Optimizer, Iterable[Optimizer]])- 用于优化 model 的优化器。

  • 子模块(已弃用)- Optional[set[nn.Module]]: 仅返回属于子模块的模型参数。

  • options (StateDictOptions) – 控制模型状态字典和优化器状态字典返回方式的选项。请参阅 StateDictOptions 获取详细信息。

返回:

Tuple 包含模型状态字典和优化器状态字典。

返回类型:

Tuple[Dict[str, ValueType], OptimizerStateType]

torch.distributed.checkpoint.state_dict.get_model_state_dict(model, *, submodules=None, options=None)[source][source]

返回模型的 state_dict。

详细用法请见 get_state_dict

参数:
  • model (nn.Module) – 模型的 nn.Module。

  • submodules(已弃用)- Optional[set[nn.Module]]:仅返回属于子模块的模型参数。

  • options (StateDictOptions) – 控制模型状态字典和优化器状态字典返回方式的选项。请参阅 StateDictOptions 获取详细信息。

返回:

The state_dict for model.

返回类型:

Dict[str, ValueType]

torch.distributed.checkpoint.state_dict.get_optimizer_state_dict(model, optimizers, *, submodules=None, options=None)[source][source]

返回优化器的组合状态字典。

详细用法请见 get_state_dict

参数:
  • model (nn.Module) – 模型的 nn.Module 。

  • optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – 用于优化 model 的优化器。

  • 子模块(已弃用)- Optional[set[nn.Module]]: 仅返回属于子模块的模型参数。

  • 选项(StateDictOptions)- 控制如何返回模型状态字典和优化器状态字典的选项。有关详细信息,请参阅 StateDictOptions。

返回:

optimizers 的状态字典。

返回类型:

优化器状态类型

torch.distributed.checkpoint.state_dict.set_state_dict(model, optimizers, *, model_state_dict, optim_state_dict, options=None)[source][source]

加载模型状态字典和优化器状态字典。

get_state_dict 的对应操作,用于将状态字典设置到模型和优化器中。给定的 model_state_dictoptim_state_dict 不必由 get_state_dict 返回,但必须满足以下要求:1) 所有 FQNs 都必须是 get_state_dict 中定义的规范 FQNs,2) 如果张量是分片的,它必须是 ShardedTensor 或 DTensor,3) 优化器状态字典不能包含参数 ID;键应该是规范 FQNs。

警告: set_state_dict 只能在 backward() 之前或 step() 之后调用。

被调用在优化器上。否则,优化器状态将无法正确初始化。

参数:
  • 模型(nn.Module)- 模型的 nn.Module。

  • 优化器(Union[Optimizer, Iterable[Optimizer]]) - 用于优化 model 的优化器。

  • model_state_dict (Dict[str, ValueType]) - (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): 要加载的模型状态字典。如果 model_state_dict 的键为 nn.Module,则键是 model 的子模块,值应该是子模块的状态字典。在加载状态字典时,将向子模块前缀追加到状态字典中。

  • optim_state_dict (OptimizerStateType) – OptimizerStateType: 加载的优化器状态字典。

  • options (StateDictOptions) – 控制如何加载模型状态字典和优化器状态字典的选项。请参阅 StateDictOptions 获取详细信息。

返回:

  • missing_keys 是包含模型状态字典中缺失键的字符串列表。

  • unexpected_keys 是包含模型状态字典中意外键的字符串列表。

返回类型:

使用 missing_keysunexpected_keys 字段

torch.distributed.checkpoint.state_dict.set_model_state_dict(model, model_state_dict, *, options=None)[source][source]

加载模型状态字典。

将状态字典设置到模型的对应方法。详见 set_state_dict 的详细用法。

参数:
  • model (nn.Module) – 模型对应的 nn.Module。

  • model_state_dict (Dict[str, ValueType]) – (Dict[str, ValueType]): 加载的模型状态字典。如果 model_state_dict 的键是 nn.Module,则键是 model 的子模块,值应该是子模块的状态字典。在加载状态字典时,将向子模块前添加前缀。

  • options (StateDictOptions) – 控制如何加载模型状态字典和优化器状态字典的选项。有关详细信息,请参阅 StateDictOptions。

返回:

  • missing_keys 是包含缺失键的字符串列表

  • unexpected_keys 是一个包含意外键的字符串列表

返回类型:

使用 NamedTuplemissing_keysunexpected_keys 字段

torch.distributed.checkpoint.state_dict.set_optimizer_state_dict(model, optimizers, optim_state_dict, *, options=None)[source][source]

加载优化器的状态字典。

get_optimizer_state_dict 的对应操作用于设置优化器的状态字典。详情用法见 set_state_dict

警告: set_optimizer_state_dict 只能在 backward() 之前或之后调用。

在优化器上调用 step() 。否则,优化器状态将无法正确初始化。

参数:
  • model (nn.Module) – 模型的 nn.Module。

  • optimizers (Union[Optimizer, Iterable[Optimizer]]) – 用于优化 model 的优化器。

  • optim_state_dict (OptimizerStateType) – OptimizerStateType:要加载的优化器状态字典。

  • options (StateDictOptions) – 控制模型状态字典和优化器状态字典加载的选项。请参阅 StateDictOptions 获取详细信息。

返回:

None

返回类型:

None

class torch.distributed.checkpoint.state_dict.StateDictOptions(full_state_dict=False, cpu_offload=False, ignore_frozen_params=False, keep_submodule_prefixes=True, strict=True, broadcast_from_rank0=False, flatten_optimizer_state_dict=False, dsd_fqn_modifiers='_fqn_modifiers')[source][source]

此数据类指定了 get_state_dict/set_state_dict 的工作方式。

  • full_state_dict : 如果设置为 True,返回的状态字典中的所有张量都将被收集。返回的状态字典中不会包含 ShardedTensor 和 DTensor。

  • cpu_offload : 将所有张量卸载到 CPU。为防止 CPU OOM,如果 full_state_dict 也为 True,则只有 rank0 将获得状态字典,而所有其他 rank 将获得空状态字典。

  • ignore_frozen_params : 如果值为 True,则返回的状态字典将不包含任何冻结参数 - requires_grad 为 False。默认值为 False。

  • keep_submodule_prefixes (已弃用):当 submodules 不为 None 时,此选项指示是否从 state_dict 键中保留子模块前缀。例如,如果子模块是 module.pretrain ,则参数的完整 FQN 为 pretrain.layer1.weight 的 param。当此选项为 True 时,返回的 state_dict 中的参数键将为 pretrain.layer1.weight 。如果选项为 False,则键将为 layer1.weight 。请注意,如果 keep_submodule_prefixes 为 False,则可能存在冲突的 FQN,因此 submodules 中应只有一个子模块。

  • strict : 当 strict 调用 model.load_state_dict() 时使用的 set_state_dict 选项。

  • broadcast_from_rank0 : 当选项为 True 时,rank0 应接收完整的

    state_dict 并将 state_dict/ optim_state_dict 中的张量逐个广播到其他 ranks。其他 ranks 将接收张量并根据模型和优化器中的本地分片进行分片。 full_state_dict 必须设置为 True 才能使用此选项。此选项目前仅支持 DTensor,不支持传统的 ShardedTensor。

对于习惯使用和共享 torch.save 格式的用户,以下方法提供了在格式之间转换的离线实用工具。

torch.distributed.checkpoint.format_utils.dcp_to_torch_save(dcp_checkpoint_dir, torch_save_path)[source][source]

给定一个包含 DCP 检查点的目录,此函数将将其转换为 Torch 保存文件。

参数:
  • dcp_checkpoint_dir (Union[str, PathLike]) – 包含 DCP 检查点的目录。

  • torch_save_path (Union[str, PathLike]) – 要存储转换后的 Torch 保存文件的文件名。

警告

为了避免内存溢出,建议仅在单个 rank 上运行此函数。

torch.distributed.checkpoint.format_utils.torch_save_to_dcp(torch_save_path, dcp_checkpoint_dir)[source][source]

给定 torch 保存文件的存储位置,将其转换为 DCP 检查点。

参数:
  • torch_save_path (Union[str, PathLike]) – 火炬保存文件的文件名。

  • dcp_checkpoint_dir (Union[str, PathLike]) – 存储 DCP 检查点的目录。

警告

为避免内存溢出,建议仅在单个 rank 上运行此函数。

以下类也可以用于从 torch.save 格式在线加载和重新分片模型。

class torch.distributed.checkpoint.format_utils.BroadcastingTorchSaveReader(checkpoint_id=None, coordinator_rank=0)[source][source]

用于读取 Torch Save 文件的存储读取器。此读取器将在协调器 rank 上读取整个检查点,然后将每个张量广播和分片到所有 rank。

. 注意:建议与 DynamicMetaLoadPlanner 一起使用

警告

当前实现仅支持加载张量。

>>> sd = {"mode": model}
>>> dcp.load(
>>>    sd,
>>>    storage_reader=BroadcastingTorchSaveReader(),
>>>    planner=DynamicMetaLoadPlanner(),
>>>    checkpoint_id="path_to_model.pt"
>>> )
准备全局计划(global_plan)[源][源] ¶

存储读取器方法的实现

返回类型:

列表[torch.distributed.checkpoint.planner.LoadPlan]

准备本地计划(plan)[源][源] ¶

存储读取方法的实现

返回类型:

加载计划

read_data(plan, planner)[source][source]

在协调器 rank 上读取 torch 保存的数据,之后广播,这会产生通信成本,但避免了在每个 rank 上加载整个检查点,希望防止内存溢出问题

返回类型:

未来[无]

读取元数据()[来源][来源] ¶

扩展默认的 StorageReader 以支持构建元数据文件

返回类型:

元数据

重置(checkpoint_id=None)[源][源] ¶

实现 StorageReader 方法

设置存储读取器(metadata,is_coordinator)[源][源] ¶

实现 StorageReader 方法

classmethod validate_checkpoint_id(checkpoint_id)[source][source]

存储读取器方法的实现

返回类型:

布尔型

class torch.distributed.checkpoint.format_utils.DynamicMetaLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)[source][source]

扩展 DefaultLoadPlanner,根据传入的状态字典创建一个新的元数据对象,避免需要从磁盘读取元数据。这在读取没有元数据文件的格式时很有用,例如 Torch 保存文件。

. 注意:建议与 BroadcastingTorchSaveReader 一起使用

警告

当前实现仅支持加载张量。

>>> sd = {"mode": model}
>>> dcp.load(
>>>    sd,
>>>    storage_reader=BroadcastingTorchSaveReader(),
>>>    planner=DynamicMetaLoadPlanner(),
>>>    checkpoint_id="path_to_model.pt"
>>> )
set_up_planner(state_dict, metadata=None, is_coordinator=False)[source][source]

设置规划器,通过从状态字典创建元数据对象扩展默认行为

以下实验性接口提供,以在生产环境中提高可观察性:


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源