分布式优化器 ¶
警告
当使用 CUDA 张量时,目前不支持分布式优化器
torch.distributed.optim
暴露了 DistributedOptimizer,它接受一个远程参数列表( RRef
)并在参数所在的工人节点上本地运行优化器。分布式优化器可以使用任何本地优化器基类来在每个工人节点上应用梯度。
- class torch.distributed.optim.DistributedOptimizer(optimizer_class, params_rref, *args, **kwargs)[source][source]¶
DistributedOptimizer 接收分散在各个工作进程中的参数的远程引用,并为每个参数在本地应用给定的优化器。
该类使用
get_gradients()
来检索特定参数的梯度。对
step()
的并发调用,无论是来自同一客户端还是不同客户端,将在每个工作进程中序列化 - 因为每个工作进程的优化器一次只能处理一组梯度。然而,不能保证整个前向-反向-优化器序列将按顺序为单个客户端执行。这意味着应用的梯度可能不对应于给定工作进程上执行的最新的前向传递。此外,也不能保证跨工作进程的顺序。DistributedOptimizer 默认通过 TorchScript 启用本地优化器,因此优化器更新不会因多线程训练(例如分布式模型并行)中的 Python 全局解释器锁(GIL)而被阻塞。此功能目前对大多数优化器已启用。您还可以按照 PyTorch 教程中的说明为您的自定义优化器启用 TorchScript 支持。
- 参数:
optimizer_class (optim.Optimizer) – 在每个工作节点上实例化的优化器类。
params_rref (list[RRef]) – 要优化的本地或远程参数的 RRef 列表。
args – 要传递给每个工作节点上优化器构造函数的参数。
kwargs – 将传递给每个工作进程的优化器构造函数的参数。
- 示例::
>>> import torch.distributed.autograd as dist_autograd >>> import torch.distributed.rpc as rpc >>> from torch import optim >>> from torch.distributed.optim import DistributedOptimizer >>> >>> with dist_autograd.context() as context_id: >>> # Forward pass. >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3)) >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1)) >>> loss = rref1.to_here() + rref2.to_here() >>> >>> # Backward pass. >>> dist_autograd.backward(context_id, [loss.sum()]) >>> >>> # Optimizer. >>> dist_optim = DistributedOptimizer( >>> optim.SGD, >>> [rref1, rref2], >>> lr=0.05, >>> ) >>> dist_optim.step(context_id)
- step(context_id)[源码][源码] ¶
执行单个优化步骤。
这将在包含要优化的参数的每个工作进程中调用
torch.optim.Optimizer.step()
,并且将阻塞,直到所有工作进程返回。提供的context_id
将用于检索相应的context
,其中包含应应用于参数的梯度。- 参数:
context_id – 应运行优化器步骤的自动微分上下文 ID。
- class torch.distributed.optim.PostLocalSGDOptimizer(optim, averager)[source][source]¶
包装任意
torch.optim.Optimizer
并运行后局部 SGD,此优化器在每一步运行局部优化器。在预热阶段之后,在应用局部优化器后定期平均参数。- 参数:
optim (Optimizer) – 局部优化器。
averager (ModelAverager) – 运行 post-localSGD 算法的模型平均实例。
示例:
>>> import torch >>> import torch.distributed as dist >>> import torch.distributed.algorithms.model_averaging.averagers as averagers >>> import torch.nn as nn >>> from torch.distributed.optim import PostLocalSGDOptimizer >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import ( >>> PostLocalSGDState, >>> post_localSGD_hook, >>> ) >>> >>> model = nn.parallel.DistributedDataParallel( >>> module, device_ids=[rank], output_device=rank >>> ) >>> >>> # Register a post-localSGD communication hook. >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100) >>> model.register_comm_hook(state, post_localSGD_hook) >>> >>> # Create a post-localSGD optimizer that wraps a local optimizer. >>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as >>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``. >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01) >>> opt = PostLocalSGDOptimizer( >>> optim=local_optim, >>> averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100) >>> ) >>> >>> # In the first 100 steps, DDP runs global gradient averaging at every step. >>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default), >>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer. >>> for step in range(0, 200): >>> opt.zero_grad() >>> loss = loss_fn(output, labels) >>> loss.backward() >>> opt.step()
- load_state_dict(state_dict)[source][source]¶
这与
torch.optim.Optimizer
load_state_dict()
相同,但还会将模型平均器的步长值恢复到提供的state_dict
中保存的值。如果
state_dict
中没有"step"
条目,将引发警告并将模型平均器的步长初始化为 0。
- class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)[source][source]¶
将任意
optim.Optimizer
包装并跨组内 rank 分配其状态。分享方式如 ZeRO 所述。
每个 rank 的本地优化器实例仅负责更新大约
1 / world_size
参数,因此只需要保留1 / world_size
优化器状态。在本地更新参数后,每个 rank 将广播其参数到所有其他节点,以保持所有模型副本处于相同状态。ZeroRedundancyOptimizer
可以与torch.nn.parallel.DistributedDataParallel
结合使用,以减少每个 rank 的峰值内存消耗。使用排序贪心算法在每个层中打包一定数量的参数。每个参数只属于一个层,不会在层之间分配。分区是任意的,可能与参数注册或使用顺序不匹配。
- 参数:
params (
Iterable
) – 一个Iterable
的torch.Tensor
或dict
,包含所有参数,这些参数将在层之间进行分片。- 关键字参数:
optimizer_class (
torch.nn.Optimizer
) – 本地优化器的类。process_group (
ProcessGroup
,可选) –torch.distributed
ProcessGroup
(默认:dist.group.WORLD
由torch.distributed.init_process_group()
初始化)。parameters_as_bucket_view (bool, optional) – 如果
True
,参数将被打包到桶中以加快通信,并且param.data
字段指向不同偏移量的桶视图;如果False
,每个单独的参数将分别通信,并且每个params.data
保持完整(默认:False
)。overlap_with_ddp (bool, optional) – 如果
True
,step()
将与DistributedDataParallel
的梯度同步重叠;这需要(1)一个用于optimizer_class
参数的功能优化器或具有功能等价物的优化器,以及(2)注册由ddp_zero_hook.py
中的函数之一构建的 DDP 通信钩子;参数将被打包到与DistributedDataParallel
中的那些匹配的桶中,这意味着parameters_as_bucket_view
参数将被忽略。如果False
,step()
在反向传播后(按正常方式)独立运行。 (默认:False
)**默认值 – 任何后续的参数,将转发到本地优化器。
示例:
>>> import torch.nn as nn >>> from torch.distributed.optim import ZeroRedundancyOptimizer >>> from torch.nn.parallel import DistributedDataParallel as DDP >>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)]) >>> ddp = DDP(model, device_ids=[rank]) >>> opt = ZeroRedundancyOptimizer( >>> ddp.parameters(), >>> optimizer_class=torch.optim.Adam, >>> lr=0.01 >>> ) >>> ddp(inputs).sum().backward() >>> opt.step()
警告
目前,
ZeroRedundancyOptimizer
要求所有传入的参数都是相同的密集类型。警告
如果您传递
overlap_with_ddp=True
,请注意以下事项:由于当前重叠DistributedDataParallel
与ZeroRedundancyOptimizer
的实现方式,前两次或三次训练迭代中,优化器步骤不会执行参数更新,这取决于是否static_graph=False
或static_graph=True
。这是因为它需要DistributedDataParallel
使用的梯度桶策略信息,该信息在static_graph=False
的情况下直到第二次前向传递才最终确定,如果是static_graph=True
,则直到第三次前向传递才最终确定。为了调整这一点,一个选项是在前面添加虚拟输入。警告
ZeroRedundancyOptimizer 是实验性的,可能会更改。
- 添加参数组(param_group)[源][源]
向
Optimizer
的param_groups
添加一个参数组。这在微调预训练网络时可能很有用,因为冻结的层可以被设置为可训练的,并在训练过程中添加到
Optimizer
中。- 参数:
param_group (dict) – 指定要优化的参数和组特定的优化选项。
警告
此方法处理更新所有分区上的碎片,但需要在所有 rank 上调用。在 rank 的子集上调用此方法会导致训练挂起,因为通信原语依赖于管理的参数,并期望所有 rank 参与同一组参数。
- consolidate_state_dict(to=0)[source][source]¶
在目标 rank 上合并一个
state_dict
s 列表(每个 rank 一个)。- 参数:
to(int)- 接收优化器状态的 rank(默认:0)。
- 抛出异常:
运行时错误 - 如果在
overlap_with_ddp=True
实例完全初始化之前调用此方法,这发生在DistributedDataParallel
梯度桶重建之后。
警告
需要在所有进程中调用此方法。
- 属性 join_devicedevice ¶
返回默认设备。
- join_hook(**kwargs)[source][source]¶
返回 ZeRO 连接钩子。
通过在优化器步骤中覆盖集体通信,它使训练可以在不均匀的输入上进行。
在调用此钩子之前,梯度必须被正确设置。
- 参数:
kwargs (dict) – 一个
dict
包含任何关键字参数以修改运行时连接钩子的行为;所有共享相同连接上下文管理器的Joinable
实例都转发相同的kwargs
值。
此钩子不支持任何关键字参数;即
kwargs
未使用。
- 属性 join_process_groupAny ¶
返回进程组。
- 加载状态字典(state_dict)[source][source] ¶
从输入中加载给定排名的状态
state_dict
,如有需要更新本地优化器。- 参数:
state_dict (dict) – 优化器状态;应该是一个从调用
state_dict()
返回的对象。- 抛出异常:
运行时错误 - 如果在
overlap_with_ddp=True
实例完全初始化之前调用此方法,这发生在DistributedDataParallel
梯度桶重建之后。