摘要:在 PyTorch 分布式的新异步检查点功能中,我们与 IBM 的反馈相结合,展示了 IBM 研究团队如何实现并将有效检查点时间缩短了 10-20 倍。示例:7B 模型检查点的“停机时间”从平均 148.8 秒缩短到 6.3 秒,或提高了 23.62 倍。
这直接转化为在每 24 小时期间获得更多的网络训练进度,同时继续稳健地检查,或者更频繁地进行检查以缩短恢复窗口/时间。
在这篇笔记中,我们展示了实现异步检查点功能的代码和架构,以及由 IBM 研究团队验证的时间结果。
模型检查点是大模型训练的重要组成部分,但检查点是一个昂贵的流程,因为每个检查点过程都需要阻塞训练进度以保存最新的模型权重。然而,不进行检查点或减少检查点频率可能会导致训练进度损失很大。例如,死锁、落后者和 GPU 错误等故障需要重新启动训练过程。为了从故障中恢复,所有(训练)工作者必须停止他们的训练过程,并从最后一个保存的检查点重新启动。
因此,在容错性和训练进度之间固有的紧张关系表现为一种权衡,但现在有了异步检查点,PyTorch 分布式能够显著减少这种紧张关系,并允许频繁地进行检查点,同时对整体训练时间的影响最小。
作为背景,大约一年前我们就展示了分布式检查点如何大幅缩短了与原始 torch.save()功能相比的检查点时间。正如 IBM 研究部门所指出的,torch.save 可能需要 30 分钟才能检查点单个 11B 模型(PyTorch 1.13)。
随着分布式检查点技术的进步,检查点可以在 4 分钟内完成,适用于高达 30B 的模型大小。
使用异步检查点,由于检查点而损失的训练时间现在缩短至 30 秒以下,通常仅为 6 秒。
为了明确,异步检查点并没有像之前的更新所展示的那样压缩实际的序列化检查点时间。相反,它将最终的检查点过程从关键路径(到 CPU 线程)移除,以便在单独的线程中完成检查点的同时继续 GPU 训练。
然而,对于用户来说,由于检查点而导致的训练中断时间几乎相同,在许多情况下减少了 10 倍甚至 20 倍。
如上图所示的速度提升图表显示,异步检查点比一年前的大幅提升又提高了 10 倍到 23 倍。
异步检查点是如何工作的?
异步检查点将检查点过程模块化为两个部分,而不是一个单一的过程。第一阶段将每个 GPU/排名的数据从 GPU 复制到 CPU。这是用户可见的中断时间,对于 7B-13B 模型大小,可能需要 6-14 秒。第二阶段异步将数据从 CPU 内存复制到磁盘以持久化检查点。
在第一阶段数据被复制到 CPU 后,GPU 可以立即继续训练。因此,使用异步检查点,检查点的时间仅仅是复制最新模型状态到 CPU 所需的时间。
同时,当训练恢复时,非阻塞的 CPU 线程会与内存中新到达的数据一起工作,完成完整的检查点/序列化过程到磁盘(即持久保存)。
注意,PyTorch 的分布式检查点依赖于集体通信调用,以优化保存所需的每个 rank 的元数据,以及最终的同步,这标志着检查点完成并使操作原子化。如果检查点线程使用与训练相同的进程组,这可能会干扰分布式训练(因为分布式训练也依赖于类似的调用以同步多个 GPU 之间的训练),这可能会导致真正的集体挂起。
特别是,调用之间的竞争条件可能会导致训练和异步检查点保存线程同时等待集体调用,从而导致真正的集体挂起。
我们通过为异步检查点初始化一个单独的进程组来避免这种场景。这会将检查点集体分离到它们自己的逻辑进程组中,从而确保它不会干扰主训练线程中的集体调用。
我如何在训练中使用异步检查点?
异步检查点的使用相对简单。使用最新的 PyTorch 夜间版本,您需要使用 nccl 和 gloo 初始化您的进程组。Gloo 对于 cpu 线程部分是必需的。
从那里,创建一个用于异步检查点的副本来使用。然后像往常一样进行训练,但在您想要检查点的时候,使用异步保存 API,传入要保存的状态、检查点 ID 和检查点进程组。
异步检查点在 torchtitan 中也得到了全面实现。在这里,它被用于与预训练自己的 Llama2 或 Lllama3 模型一起使用。使用它就像更新 toml 配置文件一样简单:
未来工作
过去一年中,检查点技术取得了巨大进步。从几乎半小时的检查点时间缩短到 5 分钟以下,现在通过分布式检查点技术进一步缩短到 30 秒以下。
最后的边疆——零开销检查点,通过在反向传播过程中流式传输更新的权重,甚至消除了小于 30 秒的检查点时间,这样在异步检查点启动时,检查点数据已经位于 CPU 上。
这将有效地将大型模型训练转移到检查点不会造成中断或停机的地方,从而实现更高的鲁棒性(因为可以更频繁地执行检查点)和更快的训练进度,因为没有停机时间用于检查点。
源代码链接:https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict_saver.py