• 教程 >
  • 分布式检查点(DCP)入门
快捷键

分布式检查点(DCP)入门教程 ¶

创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日

作者:张伊莉,库梅拉·罗德里戈,黄建钦,帕斯卡林·卢卡斯

备注

在 github 上查看和编辑此教程。

前提条件:

在分布式训练期间,检查点化 AI 模型可能会很具挑战性,因为参数和梯度被分配到各个训练器,并且当您重新开始训练时,可用的训练器数量可能会发生变化。PyTorch 分布式检查点(DCP)可以帮助简化此过程。

在本教程中,我们将展示如何使用 DCP API 与一个简单的 FSDP 包装模型。

DCP 工作原理 ¶

torch.distributed.checkpoint() 支持并行保存和加载多 rank 的模型。您可以使用此模块并行保存任意数量的 rank,并在加载时重新分片到不同的集群拓扑结构。

此外,通过使用 torch.distributed.checkpoint.state_dict() 中的模块,DCP 提供了在分布式环境中优雅地处理 state_dict 生成和加载的支持。这包括管理模型和优化器之间的完全限定名(FQN)映射,以及为 PyTorch 提供的并行设置设置默认参数。

DCP 在几个重要方面与 torch.save()torch.load() 不同:

  • 它为每个检查点生成多个文件,每个 rank 至少有一个。

  • 它在原地操作,意味着模型应首先分配其数据,DCP 使用该存储。

  • DCP 提供对有状态对象(在 torch.distributed.checkpoint.stateful 中正式定义)的特殊处理,如果它们被定义,则自动调用 state_dict 和 load_state_dict 方法。

备注

本教程中的代码在 8-GPU 服务器上运行,但可以轻松推广到其他环境。

如何使用 DCP ¶

在这里,我们使用一个用 FSDP 包装的玩具模型进行演示。同样,这些 API 和逻辑也可以应用于更大的模型以进行检查点保存。

保存

现在,让我们创建一个玩具模块,用 FSDP 包装它,用一些虚拟输入数据喂它,然后保存。

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355 "

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = FSDP(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    optimizer.zero_grad()
    model(torch.rand(8, 16, device="cuda")).sum().backward()
    optimizer.step()

    state_dict = { "app": AppState(model, optimizer) }
    dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

请前往检查点目录。你应该会看到如下所示的 8 个检查点文件。

Distributed Checkpoint

加载

保存后,让我们创建相同的 FSDP 包装模型,并将存储中保存的状态字典加载到模型中。您可以在相同的世界大小或不同的世界大小中加载。

请注意,在加载之前,您必须调用 model.state_dict() ,并将其传递给 DCP 的 load_state_dict() API。这与 torch.load() 有根本的不同,因为 torch.load() 只需要加载检查点的路径。我们需要在加载之前使用 state_dict 的原因是:

  • DCP 使用从模型状态字典预分配的存储来从检查点目录加载。在加载过程中,传入的状态字典将就地更新。

  • DCP 在加载之前需要从模型获取分片信息以支持重新分片。

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355 "

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_load_example(rank, world_size):
    print(f"Running basic FSDP checkpoint loading example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = FSDP(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    state_dict = { "app": AppState(model, optimizer)}
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_load_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

如果您想在非 FSDP 包装的模型中加载已保存的检查点,尤其是在非分布式设置中进行推理,您也可以使用 DCP 来实现。默认情况下,DCP 以单程序多数据(SPMD)风格保存和加载分布式 state_dict 。但是,如果没有初始化进程组,DCP 会推断出您的意图是按“非分布式”风格保存或加载,即完全在当前进程中。

备注

多程序多数据(MPMD)的分布式检查点支持仍在开发中。

import os

import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn


CHECKPOINT_DIR = "checkpoint"


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def run_checkpoint_load_example():
    # create the non FSDP-wrapped toy model
    model = ToyModel()
    state_dict = {
        "model": model.state_dict(),
    }

    # since no progress group is initialized, DCP will disable any collectives.
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )
    model.load_state_dict(state_dict["model"])

if __name__ == "__main__":
    print(f"Running basic DCP checkpoint loading example.")
    run_checkpoint_load_example()

格式

尚未提及的一个缺点是,DCP 保存的检查点格式与使用 torch.save 生成的格式本质上不同。当用户希望与习惯于 torch.save 格式的用户共享模型,或者通常只想为他们的应用程序添加格式灵活性时,这可能会成为一个问题。在这种情况下,我们提供了 torch.distributed.checkpoint.format_utils 中的 format_utils 模块。

为方便用户,提供了一个命令行工具,其格式如下:

python -m torch.distributed.checkpoint.format_utils <mode> <checkpoint location> <location to write formats to>

在上述命令中, modetorch_to_dcpdcp_to_torch 之一。

对于可能希望直接转换检查点的用户,也提供了相应的方法。

import os

import torch
import torch.distributed.checkpoint as DCP
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp

CHECKPOINT_DIR = "checkpoint"
TORCH_SAVE_CHECKPOINT_DIR = "torch_save_checkpoint.pth"

# convert dcp model to torch.save (assumes checkpoint was generated as above)
dcp_to_torch_save(CHECKPOINT_DIR, TORCH_SAVE_CHECKPOINT_DIR)

# converts the torch.save model back to DCP
dcp_to_torch_save(TORCH_SAVE_CHECKPOINT_DIR, f"{CHECKPOINT_DIR}_new")

结论 ¶

总结来说,我们学习了如何使用 DCP 的 save()load() API,以及它们与 torch.save()torch.load() 的区别。此外,我们还学习了如何使用 get_state_dict()set_state_dict() 在状态字典生成和加载过程中自动管理与并行性相关的 FQN 和默认值。

如需更多信息,请参阅以下内容:


评分这个教程

© 版权所有 2024,PyTorch。

使用 Sphinx 构建,主题由 Read the Docs 提供。
//暂时添加调查链接

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源