由 PyTorch 团队

随着模型在深度、批大小和序列长度等方面扩展,激活内存成为整体内存使用中越来越重要的贡献者。为了帮助解决这个问题,PyTorch 提供了激活检查点工具,通过在需要时重新计算它们来减少保存的张量数量,以内存使用为代价换取额外的计算。

在本文中,我们将介绍激活内存的基本概念,现有激活检查点技术的高层次思想,以及一些旨在提高灵活性和提供更多优化/自动化的新技术。

在我们审视这些技术时,我们将比较这些方法如何适应速度与内存权衡图,并希望能提供一些关于如何为您的用例选择正确策略的见解。

(如果您想直接跳转到新的 API,请跳转到下面的“选择性激活检查点”和“内存预算 API”部分。)

flow diagram


激活内存基础

默认情况下,在急切模式(而不是使用 torch.compile )下,PyTorch 的 autograd 会保留中间激活以供反向计算。例如,如果在正向传递期间对张量 x 调用 sin ,autograd 必须记住 x 以在反向计算期间计算 cos(x)

flow diagram

如果在这个正向传递开始时保存这个张量 x ,它将在正向和反向阶段都保留在内存中。它只能在用于计算梯度的之后被清除,这发生在反向传递的末尾(由于执行顺序的相反)。

因此,随着你通过正向传递并执行更多操作,你会积累越来越多的激活,导致激活内存越来越多,直到它(通常)在反向开始时达到峰值(此时可以开始清除激活)。

flow diagram

在上面的图中,橙色方框代表操作,黑色箭头代表它们的张量输入和输出。穿过右侧的黑色箭头代表 autograd 保存用于反向传播的张量。

一种有用的方法是将这种默认保存行为在急切模式下以及我们即将介绍的技术进行视觉组织,这是基于它们在速度与内存之间的权衡。

flow diagram

在这个图中,理想的位置是左上角,那里你有“高”速度,但内存使用量也低。

我们首先将默认保存行为放在右上角(随着我们介绍更多其他技术的要点,我们将更详细地解释原因)。


激活检查点(AC)

激活检查点(AC)是一种在 PyTorch 中减少内存使用的流行技术。

在前向传播过程中,AC 区域内的任何操作都不会保存张量以供反向传播使用。(只有函数的输入会被保存。)在反向传播过程中,需要用于梯度计算的中间激活通过再次运行函数来重新生成。

flow diagram

在图(右侧)中,黑色方框显示了激活检查点的应用位置。与默认的即时方法(左侧)相比,这种设置导致保存的张量更少(1 个对 3 个)。

在模型的正确部分应用 AC 可以降低峰值内存,因为在内存使用通常达到峰值(反向传播开始时)时,中间激活不再在内存中实现。

在速度与内存权衡图中,AC 位于左下角。相对于贪婪模式,它减少了反向传播中节省的内存量,但需要额外的计算成本,因为需要重新计算。

flow diagram

注意,AC 的速度-内存权衡可以通过选择哪些前向传播部分进行检查点以及定义要使用的检查点区域数量进行调整。然而,实施这些更改可能需要修改您的模型结构,并且根据代码的组织方式可能很麻烦。为了本图的用途,我们假设只有一个区域被检查点;基于这个假设,AC 在权衡图中表现为一个单独的点。

还要注意,这里的“内存”并不指峰值内存使用;而是指为固定区域节省的内存量。


torch.compile 和 min-cut 分割器

另一个值得注意的方法是 torch.compile(在 PyTorch 2.0 中引入)。类似于激活检查点, torch.compile 也可以在底层执行一定程度的重新计算。具体来说,它将正向和反向计算追踪到一个单一的联合图中,然后由一个“min-cut”分割器进行处理。这个分割器使用 min-cut/max-flow 算法来分割图,使得需要保存以供反向传播的张量数量最小化。

初看之下,这听起来可能就像我们想要用于激活内存减少的方法。然而,实际情况更为复杂。默认情况下,分割器的主要目标是减少运行时间。因此,它只重新计算某些类型的操作——主要是简单、可融合的、非计算密集型操作(如点积操作)。

在速度与内存权衡图中放置“编译”...

flow diagram

它位于急切非 AC 点的左上角,正如我们预期的那样, torch.compile 将在速度和内存方面都有所改进。

另一方面,与激活检查点相比,torch.compile 在重新计算方面更为保守,将其放置在速度与内存对比图中的左上角。


选择性激活检查点 [NEW!]

而正常的检查点会重新计算所选区域中的每个操作,选择性激活检查点(SAC)是在激活检查点之上的一种额外设置,您可以通过它来对要重新计算的操作有更细粒度的控制。

这在您有某些更昂贵的操作(如 matmuls)且希望避免重新计算,但仍然希望重新计算更便宜的操作(如逐点操作)时非常有用。

flow diagram

在普通 AC(左侧)中,您会保存单个张量然后重新计算整个 AC 区域,而在 SAC(右侧)中,您可以有选择地保存特定操作(标记为红色)在区域中,从而避免重新计算。

要指定有选择地保存的内容,您可以指定一个 policy_fn。为了说明您可以使用此功能进行哪些额外的权衡,我们展示了两个简单的策略函数。

策略 1:不重新计算 matmuls:

aten = torch.ops.aten
compute_intensive_ops = [  
        aten.mm,
        aten.bmm,
        aten.addmm,
] 
def policy_fn(ctx, op, *args, **kwargs):
    if op in compute_intensive_ops:
        return CheckpointPolicy.MUST_SAVE
    else:
        return CheckpointPolicy.PREFER_RECOMPUTE

flow diagram

政策 2:更积极地保存任何计算密集型任务

# torch/_functorch/partitioners.py
aten = torch.ops.aten
compute_intensive_ops = [  
   aten.mm,
   aten.convolution,
   aten.convolution_backward,
   aten.bmm,
   aten.addmm,
   aten._scaled_dot_product_flash_attention,
   aten._scaled_dot_product_efficient_attention,
   aten._flash_attention_forward,
   aten._efficient_attention_forward,
   aten.upsample_bilinear2d,
   aten._scaled_mm
] 
def policy_fn(ctx, op, *args, **kwargs):
    if op in compute_intensive_ops:
        return CheckpointPolicy.MUST_SAVE
    else:
        return CheckpointPolicy.PREFER_RECOMPUTE

flow diagram

在速度与内存图上,SAC 被绘制为从接近 AC 到接近 Eager 的一系列点,具体取决于您选择的策略。

flow diagram

尝试一下吧!(作为 2.5 版本中的原型功能;更多信息请参阅文档+可复制粘贴的示例)

from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts

# Create a policy function that returns a CheckpointPolicy
def policy_fn(ctx, op, *args, **kwargs):
    if op in ops_to_save:
        return CheckpointPolicy.MUST_SAVE
    else:
        return CheckpointPolicy.PREFER_RECOMPUTE

# Use the context_fn= arg of the existing checkpoint API
out = checkpoint(
    fn, *args,
    use_reentrant=False,
    # Fill in SAC context_fn's policy_fn with functools.partial
    context_fn=partial(create_selective_checkpoint_contexts, policy_fn),
)


(仅编译)内存预算 API [新!]

如前所述,任何给定的 SAC 策略都可以表示为速度-内存权衡图上的一个点。然而,并非所有策略都是平等的。所谓的“最优”策略是那些落在帕累托曲线上的策略,例如,对于所有产生相同内存开销的策略,这个策略是使所需计算量最小化的策略。

对于使用 torch.compile 的用户,我们提供了一个内存预算 API,该 API 可以自动将 SAC 应用于编译区域,并使用用户指定的介于 0 和 1 之间的“内存预算”来提供帕累托最优策略,其中预算为 0 的行为类似于 plain-AC,预算为 1 的行为类似于默认的 torch.compile。

flow diagram

下面是一些在 Transformer 模型上的实际结果:

flow diagram

我们观察到,通过仅重新计算点积操作,内存减少了 50%,随着你重新计算越来越多的 matmuls,这种下降趋势会持续。注意力是最昂贵的,因此你通常希望最后才重新计算这些操作。

尝试一下吧!(2.4 版本中作为实验性功能提供;更多信息请参阅此注释块)

torch._dynamo.config.activation_memory_budget = 0.5

out = torch.compile(fn)(inp)

结论

flow diagram

总结来说,PyTorch 中的激活检查点技术提供了多种平衡内存和计算需求的方法,从简单的基于区域的检查点到更选择性和自动化的方法。通过选择最适合您模型结构和资源限制的选项,您可以在计算上做出可接受的权衡的同时实现显著的内存节省。

致谢

我们想感谢 Meta 的 xformers 团队,包括 Francisco Massa,他们为选择性激活检查点的原始版本做出了贡献。