引言
近年来,研究界在自然语言处理、计算机视觉和其他领域使用大型模型取得了许多成功。其中许多成功得益于云 TPU——这是一种强大的分布式训练硬件。为了支持 PyTorch 中的 TPU,PyTorch/XLA 库为 XLA 设备(尤其是 TPU)提供了一个后端,并为在 TPU 上扩展大型 PyTorch 模型奠定了基础。
然而,PyTorch 生态系统中的大多数现有建模扩展工具都假设 GPU(或 CPU)设备,通常依赖于 CUDA 的特定功能,并且不能直接在 TPU 上工作。缺乏扩展工具使得构建无法适应单个 TPU 芯片内存的大型模型变得具有挑战性。
为了支持 TPU 上的模型扩展,我们在 PyTorch/XLA 1.12 版本中实现了广泛采用的完全分片数据并行(FSDP)算法。我们提供了一个与基于 CUDA 的 PyTorch FSDP 类具有相似高级设计的 FSDP 接口,同时处理了 XLA 中的几个限制(更多详细信息请参阅设计笔记)。这个 FSDP 接口使我们能够轻松地在 TPU 上构建具有例如 10B+参数的模型,并已使许多研究探索成为可能。
在 PyTorch/XLA 中使用完全分片数据并行(FSDP)
我们提供了一个包装类 XlaFullyShardedDataParallel
,用于将给定 PyTorch 模型的参数分片到数据并行工作进程中。以下是一个示例用法:
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()
使用 XlaFullyShardedDataParallel
包装 nn.Module
实例,可以在其上启用 ZeRO-2 算法,在整个训练过程中对其梯度优化器状态进行分片。在其正向和反向传播过程中,首先从相应的分片重建包装模块的完整参数以进行计算。
嵌套 FSDP 包装可用于进一步节省内存。这允许模型在任何给定时间只存储单个层的完整参数。对于嵌套 FSDP,应首先使用内部 FSDP 包装其各个子模块,然后再使用外部 FSDP 包装基本模型。这允许模型在任何给定时间只存储单个层的完整参数。并且拥有外部包装可以确保处理任何剩余的参数,对应于 ZeRO-3 算法。嵌套 FSDP 包装可以应用于子模块的任何深度,并且可以有超过 2 层的嵌套。
模型和优化器的检查点保存和加载可以像以前一样通过保存和加载它们的 .state_dict()
来完成。同时,每个训练过程应保存其自己的分片模型参数和优化器状态的检查点文件,并在恢复时加载对应 rank 的检查点文件(无论 ZeRO-2 还是 ZeRO-3,即嵌套包装与否)。提供了一个命令行工具和 Python 接口,用于将分片模型检查点文件合并成一个完整/非共享的模型检查点文件。
梯度检查点(也称为“激活检查点”或“重材料化”)是模型缩放的一种常见技术,可以与 FSDP 结合使用。我们提供了 checkpoint_module
,一个基于 nn.Module
实例的梯度检查点包装函数(基于 torch_xla.utils.checkpoint.checkpoint
)。
下面的 MNIST 和 ImageNet 示例提供了(平面或嵌套)FSDP、模型检查点保存和合并以及梯度检查点的示例用法。
PyTorch/XLA 中 FSDP 的起始示例
使用 FSDP 训练 MNIST 和 ImageNet
MNIST 和 ImageNet 分类通常可以用作构建更复杂深度学习模型的起点。我们提供了以下 FSDP 示例,针对这两个数据集:
- MNIST:test/test_train_mp_mnist_fsdp_with_ckpt.py(它还说明了检查点的保存和合并)
- ImageNet:test/test_train_mp_imagenet_fsdp.py
将它们与 MNIST 和 ImageNet 的 vanilla 数据并行示例进行比较,说明了如何将训练脚本适配以使用 FSDP。需要注意的是,在 FSDP 包装的模型上步进优化器时,应直接调用 optimizer.step()
,而不是 xm.optimizer_step(optimizer)
。后者会减少跨 rank 的梯度,这在 FSDP 中是不需要的,因为在 FSDP 中梯度已经通过其反向传播中的 reduce-scatter 操作进行了减少和分片。
安装
FSDP 从 PyTorch/XLA 1.12 及更高版本的夜间发布版中可用。请参阅 https://github.com/pytorch/xla#-available-images-and-wheels 了解安装指南以及 Cloud TPU 分配。然后在 TPU VM 上克隆 PyTorch/XLA 仓库,具体操作如下
mkdir -p ~/pytorch && cd ~/pytorch
git clone --recursive https://github.com/pytorch/xla.git
cd ~/
在 v3-8 TPU 上训练 MNIST
经过 2 个 epoch 的训练,准确率达到了 98.9:
python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
--batch_size 16 --drop_last --num_epochs 2 \
--use_nested_fsdp
上面的脚本会在最后自动测试分片模型检查点的合并。您也可以通过以下方式手动合并分片检查点文件
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
--ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
--ckpt_suffix "_rank-*-of-*.pth"
在 v3-8 TPU 上使用 ResNet-50 训练 ImageNet
经过 100 个 epoch,准确率达到 75.9,与不使用 FSDP 时的结果相同;下载并预处理 ImageNet-1k 数据集到 /datasets/imagenet-1k
:
python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
--datadir /datasets/imagenet-1k --drop_last \
--model resnet50 --test_set_batch_size 64 --eval_interval 10 \
--lr 0.4 --batch_size 128 --num_warmup_epochs 5 \
--lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 \
--num_epochs 100 \
--use_nested_fsdp
你还可以探索这两个示例中的其他选项,例如 --use_gradient_checkpointing
在 ResNet 块上应用梯度检查点(即激活检查点),或 --compute_dtype bfloat16
以 bfloat16 精度执行正向和反向传播。
大规模模型示例
在 TPU 上构建大型模型时,我们通常需要关注内存限制(例如,TPU v3 每个核心 16 GB,TPU v4 每个芯片 32 GB)。对于无法适应单个 TPU 内存或主机 CPU 内存的大型模型,应使用嵌套 FSDP 实现 ZeRO-3 算法的交错子模块构建,并使用内部 FSDP 包装,以确保在构建过程中模型的全量数据无需存储在内存中。
我们在 https://github.com/ronghanghu/ptxla_scaling_examples 中展示了这些案例,该链接提供了在 TPU v3 pod(具有 128 个核心)上训练 10B+参数的 Vision Transformer(ViT)模型以及其他案例的示例。
设计说明
人们可能会 wonder 为什么需要 在 PyTorch/XLA 中 开发一个独立的 FSDP 类,而不是直接重用 PyTorch 的 FSDP 类或将其扩展到 XLA 后端。在 PyTorch/XLA 中创建一个独立的 FSDP 类的主要动机是,原生的 PyTorch FSDP 类严重依赖于 XLA 设备不支持 CUDA 功能,而 XLA 也具有一些独特的特性,需要特殊处理。这些区别需要 FSDP 的不同实现,这将更容易在一个单独的类中构建。
API 调用 的变化
一个显著的区别是,原生的 PyTorch FSDP 是基于独立的 CUDA 流在 eager 模式下进行异步执行,而 PyTorch/XLA 则在 lazy 模式下运行,并且也不支持流。此外,TPU 要求所有设备统一运行相同的程序。因此,在 PyTorch/XLA 的 FSDP 实现中,需要用 XLA API 和替代的均匀实现来替换 CUDA 调用和进程异构性。
张量存储处理
另一个显著的区别是如何释放张量的存储,这在 XLA 中比在 CUDA 中要困难得多。要实现 ZeRO-3,需要在模块的前向传播后释放全部参数的存储,以便下一个模块可以重用这个内存缓冲区进行后续计算。PyTorch 的 FSPD 通过通过 p.data.storage().resize_(0)
释放参数 p
的实际存储在 CUDA 上完成此操作。然而,由于 XLA HLO IRs 完全功能化且不提供任何操作来释放张量或调整其存储,XLA 张量没有这种 .storage()
句柄。在 PyTorch 接口下方,只有 XLA 编译器可以决定何时释放对应于 XLA 张量的 TPU 设备内存,前提是只有在 Python 中释放张量对象时才能释放内存——这在 FSDP 中不可能发生,因为这些参数张量作为模块属性被引用,并且由 PyTorch autograd 保存以进行反向传播。
我们解决这个问题的方案是将张量的值属性与其 autograd Variable 属性分开,并通过将 nn.Parameter
张量的 .data
属性设置为大小为 1 的虚拟标量来释放 nn.Parameter
张量。这样,Python 中实际的完整参数数据张量就会被解引用,以便 XLA 可以回收其内存用于其他计算,同时 autograd 仍然可以追踪基础 nn.Parameter
作为参数数据的弱引用。为了实现这一点,还需要处理参数的视图,因为在 PyTorch 中视图也持有对其实际数据的引用(这需要修复 PyTorch/XLA 中视图的形状相关问题)。
与 XLA 编译器一起工作
上述解决方案应该足以在 XLA 编译器忠实保留我们的 PyTorch 程序中的操作及其执行顺序的情况下释放全部参数。但还有一个问题——XLA 试图通过在 HLO IRs 上应用常见的子表达式消除(CSE)来优化程序以加快其执行速度。在 FSDP 的简单实现中,XLA 编译器通常会在重建完整参数时,看到它是前向传递中的重复计算,从而消除反向传递中的第 2 次 all-gather,并直接保留和重用我们希望在正向传递后释放的完整参数。为了防止这种不希望的编译器行为,我们在 PyTorch/XLA 中引入了优化屏障操作,并使用它来停止消除第 2 次 all-gather。这种优化屏障也应用于梯度检查点的类似情况,以防止前向和反向传递之间的 CSE 消除重计算。
在未来,如果 CUDA 和 XLA 之间的区别不像上面提到的那么明显,那么考虑将 PyTorch/XLA FSDP 与原生 PyTorch FSDP 合并以拥有统一的接口可能是有意义的。
致谢
感谢 AWS 的郝俊民对 PyTorch/XLA FSDP 拉取请求的审查。感谢 Meta PyTorch 团队的布赖恩·赫什对 PyTorch 核心问题的支持。感谢谷歌的 Isaack Karanja、Will Cromar 和 Blake Hechtman 在 GCP、XLA 和 TPU 问题上的支持。
感谢 Meta FAIR 的 Piotr Dollar、Wan-Yen Lo、Alex Berg、Ryan Mark、Kaiming He、Xinlei Chen、Saining Xie、Shoubhik Debnath、Min Xu 和 Vaibhav Aggarwal 就各种 TPU 相关讨论的贡献。