• 文档 >
  • torch.utils.checkpoint
快捷键

torch.utils.checkpoint

注意

检查点是通过在反向传播期间为每个检查点段重新运行前向传递段来实现的。这可能导致像 RNG 状态这样的持久状态比不使用检查点时更先进。默认情况下,检查点包括逻辑来处理 RNG 状态,使得使用 RNG(例如通过 dropout)的检查点传递具有与不使用检查点的传递相比的确定性输出。存储和恢复 RNG 状态的逻辑可能会根据检查点操作的运行时间产生适度的性能影响。如果不需要与不使用检查点的传递相比的确定性输出,则提供 preserve_rng_state=Falsecheckpointcheckpoint_sequential 以省略在每个检查点期间存储和恢复 RNG 状态。

存储逻辑保存和恢复 CPU 和另一种设备类型(通过 Tensor 参数排除 CPU 张量推断设备类型)的 RNG 状态到 run_fn 。如果有多个设备,则只保存单个设备类型的设备状态,其余设备将被忽略。因此,如果任何检查点函数涉及随机性,这可能会导致梯度不正确。(注意,如果检测到 CUDA 设备,则优先考虑;否则,将选择遇到的第一个设备。)如果没有 CPU 张量,则将保存和恢复默认设备类型状态(默认值为 cuda,可以通过 DefaultDeviceType 设置为其他设备)。然而,逻辑无法预测用户是否会在 run_fn 中将张量移动到新设备。因此,如果在 run_fn 中将张量移动到新设备(“新”表示不属于[当前设备+Tensor 参数中的设备]的集合),则与未检查点传递相比,确定性输出永远不会得到保证。

torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, **kwargs)[source][source]

检查点模型或模型的一部分。

激活检查点是一种以计算换取内存的技术。在反向传播中,不是将所需的张量保留到用于反向传播的梯度计算中,而是在检查点区域的前向计算中省略保存张量以供反向使用,并在反向传播期间重新计算它们。激活检查点可以应用于模型的任何部分。

目前有两种检查点实现,由 use_reentrant 参数确定。建议您使用 use_reentrant=False 。有关它们之间差异的讨论,请参阅以下注释。

警告

如果在反向传播过程中 function 的调用与正向传播不同,例如由于全局变量,检查点版本可能不等效,可能会导致错误抛出或导致梯度错误地静默。

警告

应显式传递 use_reentrant 参数。在版本 2.4 中,如果未传递 use_reentrant ,将引发异常。如果您正在使用 use_reentrant=True 变体,请参阅以下注意事项和潜在限制。

注意

检查点的可重入变体( use_reentrant=True )和不可重入变体( use_reentrant=False )在以下方面有所不同:

  • 非可重入检查点在所有所需中间激活被重新计算后立即停止重新计算。此功能默认启用,但可以通过 set_checkpoint_early_stop() 禁用。可重入检查点在反向传播期间始终重新计算 function 的全部内容。

  • 可重入变体在正向传播过程中不记录 autograd 图,因为它在 torch.no_grad() 下运行。非可重入版本记录 autograd 图,允许在检查点区域内对图进行反向传播。

  • 可重入检查点仅支持 torch.autograd.backward() API 进行反向传播,且没有其输入参数,而非可重入版本支持所有执行反向传播的方式。

  • 至少一个输入和输出必须对可重入变体有 requires_grad=True 。如果这个条件不满足,模型被检查点的部分将不会有梯度。非可重入版本没有这个要求。

  • 可重入版本不考虑嵌套结构中的张量(例如,自定义对象、列表、字典等)参与 autograd,而非可重入版本则考虑。

  • 重新进入的检查点不支持从计算图中断开张量的检查点区域,而非重新进入版本支持。对于重新进入版本,如果检查点段包含使用 detach()torch.no_grad() 断开张量的情况,反向传播将引发错误。这是因为 checkpoint 使得所有输出都需要梯度,这会导致模型中定义了没有梯度的张量时出现问题。为了避免这种情况,请在 checkpoint 函数外部断开张量。

参数:
  • 函数 - 描述模型或模型部分的前向传递中要运行的内容。它还应该知道如何处理作为元组传递的输入。例如,在 LSTM 中,如果用户传递 (activation, hidden)function 应正确使用第一个输入作为 activation ,第二个输入作为 hidden

  • preserve_rng_state (布尔值,可选) – 在每次检查点期间省略存储和恢复 RNG 状态。注意,在 torch.compile 下,此标志不起作用,我们始终保留 RNG 状态。默认: True

  • use_reentrant (bool) – 指定是否使用需要可重入式自动微分激活检查点的变体。此参数应显式传递。在版本 2.5 中,如果未传递 use_reentrant ,将引发异常。如果 use_reentrant=Falsecheckpoint 将使用不需要可重入式自动微分的实现。这允许 checkpoint 支持附加功能,例如与 torch.autograd.grad 正常工作,并支持将关键字参数输入到检查点函数中。

  • context_fn (Callable, optional) – 一个返回两个上下文管理器元组的可调用对象。该函数及其重新计算将在第一个和第二个上下文管理器下运行。此参数仅在 use_reentrant=False 时受支持。

  • determinism_check (str, optional) – 指定要执行的确定性检查的字符串。默认情况下,它设置为 "default" ,该选项将重新计算的张量的形状、数据类型和设备与保存的张量进行比较。要关闭此检查,请指定 "none" 。目前只支持这两个值。如果您想看到更多的确定性检查,请提交一个问题。此参数仅在 use_reentrant=False 时受支持,如果 use_reentrant=True ,则确定性检查始终禁用。

  • debug(布尔值,可选)- 如果 True ,错误信息将包括原始正向计算以及重新计算期间运行的运算符的跟踪。此参数仅在 use_reentrant=False 支持的情况下有效。

  • args - 包含 function 输入的元组

返回:

运行 function*args 上的输出

torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs)[source][source]

检查点序列模型以节省内存。

顺序模型按顺序执行一系列模块/函数。因此,我们可以将此类模型划分为多个部分,并对每个部分进行检查点。除了最后一个部分外,所有部分都不会存储中间激活。每个检查点部分的输入将被保存,以便在反向传播中重新运行该部分。

警告

应显式传递 use_reentrant 参数。在版本 2.4 中,如果未传递 use_reentrant ,我们将引发异常。如果您正在使用 use_reentrant=True` variant, please see :func:`~torch.utils.checkpoint.checkpoint` for the important considerations and limitations of this variant. It is recommended that you use ``use_reentrant=False

参数:
  • 函数 - 要按顺序运行的 torch.nn.Sequential 或模块/函数的列表(构成模型)。

  • 段落 - 在模型中创建的块的数量

  • 输入 - 一个输入到 functions 的张量

  • preserve_rng_state (布尔值,可选) - 在每个检查点期间省略存储和恢复 RNG 状态。默认: True

  • use_reentrant (布尔值) - 指定是否使用需要可重入 autograd 的激活检查点变体。此参数应显式传递。在版本 2.5 中,如果未传递 use_reentrant ,将引发异常。如果 use_reentrant=Falsecheckpoint 将使用不需要可重入 autograd 的实现。这允许 checkpoint 支持额外的功能,例如与 torch.autograd.grad 正常工作,并支持将关键字参数输入到检查点函数中。

返回:

运行 functions 顺序执行于 *inputs 的输出

示例

>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
torch.utils.checkpoint.set_checkpoint_debug_enabled(enabled)[源代码][源代码] ¶

设置是否在运行时打印额外的调试信息的上下文管理器。有关 checkpoint()debug 标志的更多信息,请参阅。请注意,当设置时,此上下文管理器将覆盖传递给 checkpoint 的 debug 的值。要使用本地设置,请传递 None 给此上下文。

参数:

enabled (布尔值) – 是否打印调试信息。默认为 ‘None’。

class torch.utils.checkpoint.CheckpointPolicy(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source][source]

用于指定在反向传播期间进行检查点策略的枚举。

支持以下策略:

  • {MUST,PREFER}_SAVE : 在前向传播过程中将保存操作的输出,在反向传播过程中不会重新计算。

  • 在前向传播过程中,操作的输出不会被保存,而是在反向传播过程中重新计算

使用 MUST_* 而不是 PREFER_* 来指示策略不应被其他子系统(如 torch.compile)覆盖

注意

总是返回 PREFER_RECOMPUTE 的策略函数与传统的检查点保存等效

每个操作返回 PREFER_SAVE 的策略函数并不等同于不使用检查点。使用此类策略将保存额外的张量,而不仅限于实际用于梯度计算的张量

class torch.utils.checkpoint.SelectiveCheckpointContext(*, is_recompute)[source][source]

在选择性检查点期间传递给策略函数的上下文。

此类用于在选择性检查点期间将相关元数据传递给策略函数。元数据包括当前策略函数调用是否在重新计算期间。

示例

>>>
>>> def policy_fn(ctx, op, *args, **kwargs):
>>>    print(ctx.is_recompute)
>>>
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
>>>
>>> out = torch.utils.checkpoint.checkpoint(
>>>     fn, x, y,
>>>     use_reentrant=False,
>>>     context_fn=context_fn,
>>> )
torch.utils.checkpoint.create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False)[source][source]

激活检查点期间避免重新计算某些操作的辅助工具。

与 torch.utils.checkpoint.checkpoint 一起使用,以控制反向传播期间哪些操作需要重新计算。

参数:
  • policy_fn_or_list (Callable 或 List) –

    • 如果提供了策略函数,它应接受操作的操作数 SelectiveCheckpointContextOpOverload 、参数 args 和关键字参数 kwargs,并返回一个 CheckpointPolicy 枚举值,指示是否需要重新计算操作的执行。

    • 如果提供了操作列表,则相当于一个策略返回指定的操作为 CheckpointPolicy.MUST_SAVE,其他操作为 CheckpointPolicy.PREFER_RECOMPUTE。

  • allow_cache_entry_mutation (布尔值,可选) – 默认情况下,如果任何由选择性激活检查点缓存的张量被修改,则会引发错误,以确保正确性。如果设置为 True,则禁用此检查。

返回:

两个上下文管理器的元组。

示例

>>> import functools
>>>
>>> x = torch.rand(10, 10, requires_grad=True)
>>> y = torch.rand(10, 10, requires_grad=True)
>>>
>>> ops_to_save = [
>>>    torch.ops.aten.mm.default,
>>> ]
>>>
>>> def policy_fn(ctx, op, *args, **kwargs):
>>>    if op in ops_to_save:
>>>        return CheckpointPolicy.MUST_SAVE
>>>    else:
>>>        return CheckpointPolicy.PREFER_RECOMPUTE
>>>
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
>>>
>>> # or equivalently
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save)
>>>
>>> def fn(x, y):
>>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
>>>
>>> out = torch.utils.checkpoint.checkpoint(
>>>     fn, x, y,
>>>     use_reentrant=False,
>>>     context_fn=context_fn,
>>> )

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源