使用 Join 上下文管理器进行不均匀输入的分布式训练
创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日
作者:Andrew Gu
备注
在 github 上查看和编辑此教程。
备注
Join
是在 PyTorch 1.10 版本中引入的作为原型功能。此 API 可能会发生变化。
在本教程中,您将看到:
Join 上下文管理器的概述。
使用
DistributedDataParallel
的上下文管理器的示例。使用上下文管理器的示例,包括
DistributedDataParallel
和ZeroRedundancyOptimizer
。向上下文管理器传递关键字参数的示例。
深入了解 Join 上下文管理器的工作原理。
示例:如何使玩具类与上下文管理器兼容。
需求
PyTorch 1.10+
什么是 Join
? ¶
在《Distributed Data Parallel 入门 - 基本用例》中,您看到了使用 DistributedDataParallel 进行数据并行训练的一般框架。这隐式地安排了每个反向传播中的 all-reduces,以同步不同 rank 的梯度。这种集体通信需要进程组中所有 rank 的参与,因此如果某个 rank 的输入较少,则其他 rank 会挂起或出错(取决于后端)。更普遍地说,任何执行每迭代同步集体通信的类都会存在这个问题。
Join
是一个上下文管理器,用于在您的每轮训练循环中围绕使用,以简化不均匀输入的训练。上下文管理器允许那些提前耗尽输入(即提前加入)的 rank 通过钩子来模拟尚未加入的 rank 所执行的集体通信。
使用 Join
与 DistributedDataParallel
¶
PyTorch 的 DistributedDataParallel 与 Join
上下文管理器无缝配合。以下是一个使用示例:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP
BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5
def worker(rank):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
# Rank 1 gets one more input than rank 0
inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
num_inputs = 0
with Join([model]):
for input in inputs:
num_inputs += 1
loss = model(input).sum()
loss.backward()
print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")
def main():
mp.spawn(worker, nprocs=WORLD_SIZE, join=True)
if __name__ == "__main__":
main()
这将产生以下输出(其中 rank 0 和 rank 1 的 print()
可能任意排序):
Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!
备注
在此之前,DistributedDataParallel 提供了自己的 join() 上下文管理器,该管理器在引入通用的 Join
上下文管理器之前就已经存在。在上面的示例中,使用 with Join([model]):
等同于使用 with model.join():
。现有 DistributedDataParallel.join()
的一个限制是它不允许多个参与类,例如 DistributedDataParallel
和 ZeroRedundancyOptimizer 一起使用。
使用 Join
与 DistributedDataParallel
和 ZeroRedundancyOptimizer
一起使用
Join
上下文管理器不仅与单个类一起工作,还可以与多个类一起工作。PyTorch 的 ZeroRedundancyOptimizer
也与上下文管理器兼容,因此在这里,我们检查如何修改之前的示例以同时使用 DistributedDataParallel
和 ZeroRedundancyOptimizer
:
from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam
def worker(rank):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
optim = ZeRO(model.parameters(), Adam, lr=0.01)
# Rank 1 gets one more input than rank 0
inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
num_inputs = 0
# Pass both `model` and `optim` into `Join()`
with Join([model, optim]):
for input in inputs:
num_inputs += 1
loss = model(input).sum()
loss.backward()
optim.step()
print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")
这将产生与之前相同的结果。值得注意的是,还额外将 ZeroRedundancyOptimizer
实例传递给了 Join()
。
传递关键字参数 ¶
类可以提供关键字参数,在运行时修改上下文管理器的行为。例如, DistributedDataParallel
提供了一个参数 divide_by_initial_world_size
,该参数确定梯度是除以初始世界大小还是除以有效世界大小(即非连接的进程数)。这些关键字参数可以直接传递到上下文管理器中。
with Join([model, optim], divide_by_initial_world_size=False):
for input in inputs:
...
警告
传递给上下文管理器的关键字参数在所有参与类之间共享。这不应该是一个限制,因为我们不期望存在多个 Joinable
需要不同设置的情况。尽管如此,这也是需要注意的一点。
Join
是如何工作的? ¶
现在我们已经看到了一些如何使用 Join
上下文管理器的初步示例,让我们更深入地探讨它是如何工作的。这将为您提供对其全部功能的更深入了解,并为您制作自己的兼容类做好准备。在这里,我们将介绍 Join
类以及支持类 Joinable
和 JoinHook
。
Joinable
¶
首先,与 Join
上下文管理器兼容的类必须继承自抽象基类 Joinable
。特别是,一个 Joinable
必须实现:
join_hook(self, **kwargs) -> JoinHook
这将返回 JoinHook
实例,用于确定如何使连接进程阴影 Joinable
执行的每次迭代的集体通信。
join_device(self) -> torch.device
这将返回一个设备,供 Join
上下文管理器使用以执行集体通信,例如 torch.device("cuda:0")
或 torch.device("cpu")
。
join_process_group(self) -> ProcessGroup
这将返回由 Join
上下文管理器使用的进程组,以执行集体通信。
特别是, join_device
和 join_process_group
是确保上下文管理器可以安排已加入和未加入进程之间的集体通信的必要属性。一种用法是使用 all-reduce 计算每次迭代中未加入进程的数量。另一种用法是实现 throw_on_early_termination=True
所需的机制,我们将在下面进行解释。
DistributedDataParallel
和 ZeroRedundancyOptimizer
已经继承自 Joinable
并实现了上述方法,这就是为什么我们可以在前面的例子中直接使用它们。
Joinable
类应确保调用 Joinable
构造函数,因为它初始化一个 JoinConfig
实例,该实例由上下文管理器内部使用以确保正确性。这将被保存在每个 Joinable
作为一个字段 _join_config
。
JoinHook
¶
接下来,让我们分解一下 JoinHook
类。一个 JoinHook
为上下文管理器提供了两个入口点:
main_hook(self) -> None
此钩子由每个已加入的排名在存在尚未加入的排名的情况下反复调用。它的目的是在每次训练迭代中(例如在一个前向传递、反向传递和优化器步骤中)掩盖 Joinable
执行的集体通信。
post_hook(self, is_last_joiner: bool) -> None
当所有排名都已加入时,此钩子会被调用一次。它传递一个额外的 bool
参数 is_last_joiner
,表示该排名是否是最后加入的。此参数可能对同步有用。
为了具体说明这些钩子的样子,提供的 ZeroRedundancyOptimizer
主钩子按照常规执行优化器步骤,因为已加入的排名仍然负责更新和同步其参数的片段,而提供的 DistributedDataParallel
后钩子从最后加入的排名之一广播最终的更新模型,以确保所有排名上的模型相同。
Join
¶
最后,让我们看看这些如何与 Join
类本身相匹配。
__init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)
如前所述的示例,构造函数接收参与训练循环的 Joinable
列表。这些应该是每个迭代中执行集体通信的类。
enable
是一个可以设置为 False
的 bool
,如果你知道不会有不均匀的输入,在这种情况下,上下文管理器将变得空洞,类似于 contextlib.nullcontext()
。这也可能会在参与 Joinable
的计算中禁用 join 相关的计算。
throw_on_early_termination
是一个可以设置为 True
的 bool
,以便每个 rank 在检测到不均匀输入时立即抛出异常。这对于不符合上下文管理器要求的案例很有用,这通常发生在有来自不同类的集体通信可能任意交错的情况下,例如在使用 DistributedDataParallel
与具有 SyncBatchNorm
层的模型时。在这种情况下,应将此参数设置为 True
,以便应用程序逻辑可以捕获异常并确定如何继续。
核心逻辑发生在
__exit__()
方法中,该方法在存在未连接的等级时循环,调用每个Joinable
的主钩子,然后一旦所有等级都连接,就调用它们的后钩子。主钩子和后钩子都是按照Joinable
传入的顺序迭代的。上下文管理器需要从未连接的进程那里获得心跳。因此,每个
Joinable
类在其每次迭代的集体通信之前都应该调用Join.notify_join_context()
。上下文管理器将确保只有第一个传入的Joinable
实际上发送了心跳。
警告
如上所述,关于 throw_on_early_termination
, Join
上下文管理器与某些类的组合不兼容。 Joinable
的 JoinHook
必须是可序列化的,因为每个钩子都必须在继续到下一个之前完全执行。换句话说,两个钩子不能重叠。此外,目前主钩子和后钩子都是按照相同的确定性顺序迭代的。如果这看起来是一个主要的限制,我们可能会修改 API 以允许自定义排序。
使 Toy 类与 Join
协同工作
由于上一节介绍了几个概念,让我们通过一个玩具示例来实际看看它们。在这里,我们将实现一个类,该类统计在其 rank 加入之前所有 rank 看到的输入数量。这应该能提供一个基本的概念,说明您如何使自己的类与 Join
上下文管理器兼容。
具体来说,以下代码让每个 rank 打印出(1)在其加入之前所有 rank 看到的输入数量和(2)所有 rank 的总输入数量。
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5
class CounterJoinHook(JoinHook):
r"""
Join hook for :class:`Counter`.
Arguments:
counter (Counter): the :class:`Counter` object using this hook.
sync_max_count (bool): whether to sync the max count once all ranks
join.
"""
def __init__(
self,
counter,
sync_max_count
):
self.counter = counter
self.sync_max_count = sync_max_count
def main_hook(self):
r"""
Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
"""
t = torch.zeros(1, device=self.counter.device)
dist.all_reduce(t)
def post_hook(self, is_last_joiner: bool):
r"""
Synchronizes the max count across all :class:`Counter` s if
``sync_max_count=True``.
"""
if not self.sync_max_count:
return
rank = dist.get_rank(self.counter.process_group)
common_rank = self.counter.find_common_rank(rank, is_last_joiner)
if rank == common_rank:
self.counter.max_count = self.counter.count.detach().clone()
dist.broadcast(self.counter.max_count, src=common_rank)
class Counter(Joinable):
r"""
Example :class:`Joinable` that counts the number of training iterations
that it participates in.
"""
def __init__(self, device, process_group):
super(Counter, self).__init__()
self.device = device
self.process_group = process_group
self.count = torch.tensor([0], device=device).float()
self.max_count = torch.tensor([0], device=device).float()
def __call__(self):
r"""
Counts the number of inputs processed on this iteration by all ranks
by all-reducing a dim-1 one tensor; increments its own internal count.
"""
Join.notify_join_context(self)
t = torch.ones(1, device=self.device).float()
dist.all_reduce(t)
self.count += t
def join_hook(self, **kwargs) -> JoinHook:
r"""
Return a join hook that shadows the all-reduce in :meth:`__call__`.
This join hook supports the following keyword arguments:
sync_max_count (bool, optional): whether to synchronize the maximum
count across all ranks once all ranks join; default is ``False``.
"""
sync_max_count = kwargs.get("sync_max_count", False)
return CounterJoinHook(self, sync_max_count)
@property
def join_device(self) -> torch.device:
return self.device
@property
def join_process_group(self):
return self.process_group
def find_common_rank(self, rank, to_consider):
r"""
Returns the max rank of the ones to consider over the process group.
"""
common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
common_rank = common_rank.item()
return common_rank
def worker(rank):
assert torch.cuda.device_count() >= WORLD_SIZE
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
with Join([counter], sync_max_count=True):
for _ in inputs:
counter()
print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
print(f"{int(counter.max_count.item())} inputs processed across all ranks!")
def main():
mp.spawn(worker, nprocs=WORLD_SIZE, join=True)
if __name__ == "__main__":
main()
由于 rank 0 看到了 5 个输入,而 rank 1 看到了 6 个,因此输出结果为:
10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!
一些需要强调的关键点:
每个实例在每个迭代中执行一次单次 all-reduce,因此主钩子也执行一次单次 all-reduce 以覆盖它。
Counter
类在其__call__()
方法开始时调用Join.notify_join_context()
,因为那是在其每迭代集体通信(即其 all-reduce)之前的位置。is_last_joiner
参数用于在后钩子中确定广播源。我们将
sync_max_count
关键字参数传递给上下文管理器,然后将其转发到Counter
的 join 钩子。