使用 ZeroRedundancyOptimizer 进行分片优化状态 ¶
创建时间:2025 年 4 月 1 日 | 最后更新时间:2025 年 4 月 1 日 | 最后验证:未验证
在本菜谱中,您将学习:
零冗余优化器的高级思想
如何在分布式训练中使用零冗余优化器及其影响
需求
PyTorch 1.8+
什么是 ZeroRedundancyOptimizer
? ¶
ZeroRedundancyOptimizer(零冗余优化器)这一概念来源于 DeepSpeed/ZeRO 项目和 Marian,它们通过将优化器状态分片存储在分布式数据并行过程中来减少每个进程的内存占用。在《分布式数据并行入门教程》中,我们展示了如何使用 DistributedDataParallel(DDP)来训练模型。在该教程中,每个进程都保留了一个优化器的专用副本。由于 DDP 在反向传播过程中已经同步了梯度,因此所有优化器副本在每次迭代中都会操作相同的参数和梯度值,这就是 DDP 保持模型副本状态一致的方式。通常,优化器还会维护本地状态。例如, Adam
优化器使用每个参数的 exp_avg
和 exp_avg_sq
状态。因此, Adam
优化器的内存消耗至少是模型大小的两倍。鉴于这一观察结果,我们可以通过在 DDP 进程中分片存储优化器状态来减少优化器的内存占用。更具体地说,我们不是为所有参数创建每个参数的状态,而是在不同的 DDP 进程中,每个优化器实例只保留所有模型参数的一部分的优化器状态。 优化器 step()
只更新其分片中的参数,然后将更新后的参数广播到所有其他 DDP 进程,以确保所有模型副本仍然处于相同的状态。
如何使用 ZeroRedundancyOptimizer
? ¶
下面的代码演示了如何使用 ZeroRedundancyOptimizer。大部分代码与分布式数据并行笔记中展示的简单 DDP 示例类似。主要区别在于 example
函数中的 if-else
子句,它包装了优化器构造,在 ZeroRedundancyOptimizer 和 Adam
优化器之间切换。
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP
def print_peak_memory(prefix, device):
if device == 0:
print(f"{prefix}: {torch.cuda.max_memory_allocated(device) // 1e6}MB ")
def example(rank, world_size, use_zero):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
# create default process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
# create local model
model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
print_peak_memory("Max memory allocated after creating local model", rank)
# construct DDP model
ddp_model = DDP(model, device_ids=[rank])
print_peak_memory("Max memory allocated after creating DDP", rank)
# define loss function and optimizer
loss_fn = nn.MSELoss()
if use_zero:
optimizer = ZeroRedundancyOptimizer(
ddp_model.parameters(),
optimizer_class=torch.optim.Adam,
lr=0.01
)
else:
optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)
# forward pass
outputs = ddp_model(torch.randn(20, 2000).to(rank))
labels = torch.randn(20, 2000).to(rank)
# backward pass
loss_fn(outputs, labels).backward()
# update parameters
print_peak_memory("Max memory allocated before optimizer step()", rank)
optimizer.step()
print_peak_memory("Max memory allocated after optimizer step()", rank)
print(f"params sum is: {sum(model.parameters()).sum()}")
def main():
world_size = 2
print("=== Using ZeroRedundancyOptimizer ===")
mp.spawn(example,
args=(world_size, True),
nprocs=world_size,
join=True)
print("=== Not Using ZeroRedundancyOptimizer ===")
mp.spawn(example,
args=(world_size, False),
nprocs=world_size,
join=True)
if __name__=="__main__":
main()
下面的输出显示了结果。当启用 ZeroRedundancyOptimizer
与 Adam
时,优化器 step()
的峰值内存消耗是 vanilla Adam
内存消耗的一半。这与我们的预期相符,因为我们正在将 Adam
优化器状态跨两个进程进行分片。输出还显示,在 ZeroRedundancyOptimizer
的情况下,模型参数在经过一次迭代后仍然具有相同的值(参数总和在有 ZeroRedundancyOptimizer
和无 ZeroRedundancyOptimizer
的情况下相同)。
=== Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1361.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875
=== Not Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1697.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875