• 文档 >
  • 通用连接上下文管理器
快捷键

通用连接上下文管理器 ¶

通用连接上下文管理器简化了在非均匀输入上的分布式训练。本页面概述了相关类的 API: JoinJoinable ,和 JoinHook 。有关教程,请参阅使用连接上下文管理器进行非均匀输入的分布式训练。

class torch.distributed.algorithms.Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)[source][source]

该类定义了通用的连接上下文管理器,允许在进程连接后调用自定义钩子。

这些钩子应覆盖未连接进程的集体通信,以防止挂起和错误,并确保算法的正确性。有关钩子定义的详细信息,请参阅 JoinHook

警告

上下文管理器要求每个参与者 Joinable 在其自己的每次迭代集体通信之前调用 notify_join_context() 方法,以确保正确性。

警告

上下文管理器要求所有 process_group 对象中的 JoinHook 属性都相同。如果有多个 JoinHook 对象,则使用第一个的 device 。进程组和设备信息用于检查未连接的进程,并在 throw_on_early_termination 启用时通知进程抛出异常,这两者都使用全归约。

参数:
  • 可连接项(List[Joinable])- 参与的 Joinable 的列表;它们的钩子将按给定顺序迭代。

  • 启用(bool)- 一个标志,用于启用不均匀输入检测;将 False 设置为启用将禁用上下文管理器的功能,并且仅在用户知道输入将不会不均匀时才应设置(默认: True )。

  • 抛出早期终止异常(bool)- 一个标志,用于控制检测到不均匀输入时是否抛出异常(默认: False )。

示例:

>>> import os
>>> import torch
>>> import torch.distributed as dist
>>> import torch.multiprocessing as mp
>>> import torch.nn.parallel.DistributedDataParallel as DDP
>>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO
>>> from torch.distributed.algorithms.join import Join
>>>
>>> # On each spawned worker
>>> def worker(rank):
>>>     dist.init_process_group("nccl", rank=rank, world_size=2)
>>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
>>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
>>>     # Rank 1 gets one more input than rank 0
>>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
>>>     with Join([model, optim]):
>>>         for input in inputs:
>>>             loss = model(input).sum()
>>>             loss.backward()
>>>             optim.step()
>>>     # All ranks reach here without hanging/erroring
静态通知连接上下文(joinable)[source][source]

通知加入上下文管理器,调用进程尚未加入。

然后,如果 throw_on_early_termination=True ,检查是否检测到不均匀的输入(即是否有一个进程已经加入),如果检测到则抛出异常。

此方法应在 Joinable 对象进行每迭代集体通信之前调用。例如,应在 DistributedDataParallel 的前向传递开始时调用此方法。

在此方法中,只有第一个传递给上下文管理器的 Joinable 对象执行集体通信,对于其他对象,此方法为空。

参数:

可加入的(Joinable)- 调用此方法的 Joinable 对象。

返回:

用于 all-reduce 的异步工作句柄,用于通知上下文管理器,如果 joinable 是传递给上下文管理器的第一个参数,则进程尚未加入;否则为 None

class torch.distributed.algorithms.Joinable[source][source]

这定义了一个用于可加入类的抽象基类。

一个可加入的类(继承自 Joinable )应实现 join_hook() ,该函数返回一个 JoinHook 实例,此外还应实现 join_device()join_process_group() ,分别返回设备和进程组信息。

抽象属性 join_devicedevice ¶

返回用于执行加入上下文管理器所需的集体通信的设备。

抽象方法 join_hook(**kwargs)[source][source] ¶

返回给定 JoinableJoinHook 实例。

参数:

kwargs (dict) – 一个 dict ,包含任何关键字参数以修改运行时连接钩子的行为;所有共享相同连接上下文管理器的 Joinable 实例都转发相同的值给 kwargs

返回类型:

JoinHook

抽象属性 join_process_groupAny ¶

返回由 join 上下文管理器本身所需的集体通信所需的过程组。

class torch.distributed.algorithms.JoinHook[source][source]

这定义了一个 join 钩子,它为 join 上下文管理器提供了两个入口点。

入口点:一个主钩子,在存在未加入进程时被反复调用,以及一个后钩子,在所有进程都加入后调用一次。

实现通用连接上下文管理器的连接钩子,定义一个继承自 JoinHook 的类,并适当地重写 main_hook()post_hook()

main_hook()[source][source]

在训练迭代中存在未连接的进程以阴影集体通信时调用此钩子。

训练迭代,即在一次正向传递、反向传递和优化器步骤中。

post_hook(is_last_joiner)[source][source]

所有进程加入后调用钩子。

它传递一个额外的 bool 参数 is_last_joiner ,表示该 rank 是否是最后加入的之一。

参数:

is_last_joiner (bool) – True 如果 rank 是最后加入的之一; False 否则。


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源