分布式 Autograd 设计 ¶
本笔记将详细介绍分布式 autograd 的设计,并深入探讨其内部机制。在继续之前,请确保您熟悉 Autograd 机制和分布式 RPC 框架。
背景 ¶
假设你有两个节点和一个非常简单的模型,该模型被分配到两个节点上。这可以通过使用 torch.distributed.rpc 来实现:
import torch
import torch.distributed.rpc as rpc
def my_add(t1, t2):
return torch.add(t1, t2)
# On worker 0:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
# Compute some loss.
loss = t5.sum()
分布式 autograd 的主要动机是能够使用 loss 在这样分布式的模型上运行反向传播,并计算并记录所有需要梯度的张量的适当梯度。
前向传播期间的 autograd 记录
PyTorch 在前向传播期间构建 autograd 图,该图用于执行反向传播。更多详情请参阅如何编码 autograd 的历史记录。
对于分布式自动微分,我们需要在正向传播过程中跟踪所有 RPC,以确保反向传播能够适当地执行。为此,当我们执行 RPC 时,我们会将 send 和 recv 函数附加到自动微分图上。
send函数附加到 RPC 的源处,其输出边指向 RPC 输入张量的自动微分函数。在反向传播过程中,该函数的输入从目的地接收,作为相应recv函数的输出。recv函数附加到 RPC 的目的地,其输入是从目的地执行的操作器使用输入张量检索的。该函数的输出梯度在反向传播过程中发送到源节点,以发送到相应的send函数。每对
send-recv分配一个全局唯一的autograd_message_id,以唯一标识该对。这在反向传播过程中查找远程节点上的相应函数时非常有用。对于 RRef,每次我们调用
torch.distributed.rpc.RRef.to_here()时,都会为涉及的张量附加一个适当的send-recv对。
例如,这是我们上面示例的 autograd 图的样子(为了简单起见,省略了 t5.sum()):
分布式 Autograd 上下文
每次使用分布式 autograd 的前向和反向传递都会分配一个唯一的 torch.distributed.autograd.context ,并且这个上下文有一个全局唯一的 autograd_context_id 。这个上下文在每个节点上按需创建。
此上下文具有以下作用:
多个节点运行分布式反向传播可能会在同一个张量上累积梯度,因此在该张量的
.grad字段中会累积来自多个分布式反向传播的梯度。这类似于在本地多次调用torch.autograd.backward()。为了提供一种分离每个反向传播梯度的方法,每个反向传播的梯度都累积在torch.distributed.autograd.context中。在正向传播过程中,我们将每个 autograd 传播的
send和recv函数存储在此上下文中。这确保我们持有 autograd 图中适当节点的引用以保持其活跃状态。此外,在反向传播期间,查找适当的send和recv函数也很容易。通常,我们还将此上下文用于存储每个分布式 autograd 传播的一些元数据。
从用户的角度来看,自动微分上下文设置如下:
import torch.distributed.autograd as dist_autograd
with dist_autograd.context() as context_id:
loss = model.forward()
dist_autograd.backward(context_id, loss)
需要注意的是,你的模型的前向传播必须在分布式自动微分上下文管理器中调用,因为需要一个有效的上下文来确保所有 send 和 recv 函数被正确存储,以便在所有参与节点上运行反向传播。
分布式反向传播
在本节中,我们概述了在分布式反向传播期间准确计算依赖关系的挑战,并描述了几种算法(及其权衡),说明我们如何执行分布式反向传播。
计算依赖关系
考虑以下代码在单台机器上运行的情况
import torch
a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = a + b
e = b * c
d.sum.().backward()
这就是上述代码的 autograd 图的样子:
在反向传播过程中,autograd 引擎首先执行的计算是计算 autograd 图中每个节点的依赖数量。这有助于 autograd 引擎知道图中的节点何时可以执行。括号中的数字( add(1) 和 mul(0) )表示依赖的数量。正如你所见,这意味着在反向传播过程中, add 节点需要 1 个输入,而 mul 节点不需要任何输入(换句话说,不需要执行)。本地的 autograd 引擎通过从根节点(在本例中为 d )遍历图来计算这些依赖关系。
自动微分图中某些节点可能在反向传播过程中不被执行,这对分布式自动微分构成了挑战。考虑以下使用 RPC 的代码片段。
import torch
import torch.distributed.rpc as rpc
a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = rpc.rpc_sync("worker1", torch.add, args=(a, b))
e = rpc.rpc_sync("worker1", torch.mul, args=(b, c))
loss = d.sum()
上述代码对应的自动微分图如下:
计算分布式自动微分图的依赖关系更具挑战性,需要一些开销(无论是计算开销还是网络通信开销)。
对于性能敏感的应用,我们可以通过假设每个 send 和 recv 函数都是反向传播的一部分来避免很多开销(大多数应用不会执行未使用的 RPC)。这简化了分布式自动微分算法,效率更高,但代价是应用程序需要了解这些限制。这个算法被称为快速模式算法,下面将详细介绍。
在一般情况下,可能没有必要要求每个 send 和 recv 函数在反向传播过程中都有效。为了解决这个问题,我们提出了一种 SMART 模式算法,该算法将在后续章节中描述。请注意,目前仅实现了 FAST 模式算法。
FAST 模式算法 ¶
该算法的关键假设是,当我们进行反向传播时,每个 send 函数的依赖项为 1。换句话说,我们假设我们将从另一个节点接收梯度。
算法如下:
我们从具有反向传播根的工人开始(所有根必须是本地的)。
查找当前分布式自动微分上下文中的所有
send函数。从提供的根和检索到的所有
send函数开始,本地计算依赖关系。计算依赖关系后,使用提供的根启动本地自动微分引擎。
当自动微分引擎执行
recv函数时,recv函数将通过 RPC 将输入梯度发送到相应的 worker。每个recv函数都知道目标 worker ID,因为它在正向传播过程中被记录。recv函数还将autograd_context_id和autograd_message_id发送到远程主机。当在远程主机接收到此请求时,我们使用
autograd_context_id和autograd_message_id查找适当的send函数。如果 worker 首次接收到针对给定
autograd_context_id的请求,它将像上面 1-3 点所述那样在本地计算依赖关系。在 6.中检索到的
send函数随后将被排队在本地自动微分引擎上对该 worker 执行。最后,我们不再在 Tensor 的
.grad字段上累积梯度,而是分别在每个 Distributed Autograd Context 上单独累积梯度。梯度存储在Dict[Tensor, Tensor]中,这基本上是一个从 Tensor 到其相关梯度的映射,该映射可以通过get_gradients()API 检索。
例如,以下是一个带有分布式自动求导的完整代码:
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
def my_add(t1, t2):
return torch.add(t1, t2)
# On worker 0:
# Setup the autograd context. Computations that take
# part in the distributed backward pass must be within
# the distributed autograd context manager.
with dist_autograd.context() as context_id:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
# Compute some loss.
loss = t5.sum()
# Run the backward pass.
dist_autograd.backward(context_id, [loss])
# Retrieve the gradients from the context.
dist_autograd.get_gradients(context_id)
如下是带有依赖关系的分布式自动求导图(为简化起见,省略了 t5.sum()):
在上述示例中应用 FAST 模式算法如下:
在
Worker 0处,我们从根loss和send1开始计算依存关系。结果,send1被标记为依存关系 1,mul对Worker 0的依存关系也被标记为 1。现在,我们在
Worker 0上启动本地 autograd 引擎。我们首先执行mul函数,将其输出累积到 autograd 上下文中作为t4的梯度。然后,我们执行recv2,将梯度发送到Worker 1。由于这是
Worker 1第一次听说这个反向传播,它开始计算依存关系,并适当地为send2、add和recv1标记依存关系。接下来,我们将
send2放入Worker 1的本地 autograd 引擎中排队,然后它依次执行add和recv1。当执行
recv1时,它将梯度发送到Worker 0。由于
Worker 0已经为这次反向传播计算了依赖关系,它只需在本地上队列并执行send1。最后,
t1、t2和t4的梯度在分布式自动微分上下文中累积。
SMART 模式算法 ¶
该算法的详细内容仍在进行中,但关于基本思路,您可以参考 RFC 中的分布式 Autograd 算法智能模式章节。
分布式优化器 ¶
DistributedOptimizer 运作如下:
优化远程参数列表(
RRef)。这些参数也可以是包含在本地RRef中的本地参数。以
Optimizer类作为本地优化器,在所有不同的RRef所有者上运行。分布式优化器在每个工作节点上创建一个本地实例
Optimizer,并持有对它们的引用RRef。当调用
torch.distributed.optim.DistributedOptimizer.step()时,分布式优化器使用 RPC 在适当的远程工作节点上远程执行所有本地优化器。必须向torch.distributed.optim.DistributedOptimizer.step()提供分布式自动微分context_id作为输入。这是由本地优化器用来应用存储在相应上下文中的梯度的。如果多个并发分布式优化器正在更新工作节点上的相同参数,这些更新将通过锁进行序列化。
简单端到端示例 ¶
将所有内容整合在一起,以下是一个使用分布式自动微分和分布式优化器的简单端到端示例。如果将代码放入名为“dist_autograd_simple.py”的文件中,可以使用命令 MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py :运行。
import torch
import torch.multiprocessing as mp
import torch.distributed.autograd as dist_autograd
from torch.distributed import rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer
def random_tensor():
return torch.rand((3, 3), requires_grad=True)
def _run_process(rank, dst_rank, world_size):
name = "worker{}".format(rank)
dst_name = "worker{}".format(dst_rank)
# Initialize RPC.
rpc.init_rpc(
name=name,
rank=rank,
world_size=world_size
)
# Use a distributed autograd context.
with dist_autograd.context() as context_id:
# Forward pass (create references on remote nodes).
rref1 = rpc.remote(dst_name, random_tensor)
rref2 = rpc.remote(dst_name, random_tensor)
loss = rref1.to_here() + rref2.to_here()
# Backward pass (run distributed autograd).
dist_autograd.backward(context_id, [loss.sum()])
# Build DistributedOptimizer.
dist_optim = DistributedOptimizer(
optim.SGD,
[rref1, rref2],
lr=0.05,
)
# Run the distributed optimizer step.
dist_optim.step(context_id)
def run_process(rank, world_size):
dst_rank = (rank + 1) % world_size
_run_process(rank, dst_rank, world_size)
rpc.shutdown()
if __name__ == '__main__':
# Run world_size workers
world_size = 2
mp.spawn(run_process, args=(world_size,), nprocs=world_size)