自动微分力学 ¶
本笔记将概述自动微分的工作原理和记录操作。虽然不是严格必要的,但我们建议您熟悉它,因为它将帮助您编写更高效、更干净的程序,并有助于调试。
自动微分如何编码历史记录 ¶
Autograd 是一个反向自动微分系统。从概念上讲,autograd 在执行操作时记录创建数据的所有操作,为你提供一个有向无环图,其叶子节点是输入张量,根节点是输出张量。通过从根节点到叶子节点的追踪,你可以使用链式法则自动计算梯度。
在内部,autograd 将这个图表示为 Function
对象(实际上是表达式)的图,这些对象可以被 apply()
用来计算评估图的输出结果。在计算前向传播时,autograd 同时执行所需的计算并构建表示计算梯度的函数的图(每个 torch.Tensor
的 .grad_fn
属性是这个图的入口点)。前向传播完成后,我们在反向传播中评估这个图来计算梯度。
注意事项很重要,即每次迭代时都会从头开始重新创建图,这正是允许使用任意 Python 控制流语句的原因,这些语句可以在每次迭代中改变图的形状和大小。在启动训练之前,你不必编码所有可能的路径——你运行的就是你要优化的。
保存的张量
一些操作需要在正向传播过程中保存中间结果,以便执行反向传播。例如,函数 将输入 保存以计算梯度。
在定义自定义 Python Function
时,你可以使用 save_for_backward()
在正向传播过程中保存张量,并在反向传播过程中使用 saved_tensors
检索它们。有关更多信息,请参阅扩展 PyTorch。
对于 PyTorch 定义的操作(例如 torch.pow()
),张量会根据需要自动保存。你可以探索(用于教育或调试目的)某个 grad_fn
保存了哪些张量,通过查找以 _saved
为前缀的属性来实现。
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self)) # True
print(x is y.grad_fn._saved_self) # True
在之前的代码中, y.grad_fn._saved_self
指的是与 x 相同的 Tensor 对象。但这并不总是如此。例如:
x = torch.randn(5, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result)) # True
print(y is y.grad_fn._saved_result) # False
在底层,为了防止引用循环,PyTorch 在保存时会打包张量,并在读取时将其解包到不同的张量中。在这里,您通过访问 y.grad_fn._saved_result
获取的张量与 y
的张量对象不同(但它们仍然共享相同的存储空间)。
一个张量是否会被打包到不同的张量对象中,取决于它是否是其自身 grad_fn 的输出,这是一个可能会变化且用户不应依赖的实现细节。
您可以使用保存的张量的钩子来控制 PyTorch 的打包/解包方式。
非可微函数的梯度 ¶
使用自动微分进行梯度计算仅在所使用的每个基本函数都是可微的情况下有效。不幸的是,我们实际使用的许多函数都没有这个特性(例如 relu
或 sqrt
在 0
,等等)。为了尽量减少非可微函数的影响,我们按照以下规则定义基本运算的梯度:
如果函数是可微的并且因此在当前点存在梯度,则使用它。
如果函数是凸的(至少在局部),则使用最小范数的子梯度(它是最陡下降方向)。
如果函数是凹的(至少局部),则使用最小范数的超梯度(考虑-f(x)并应用前一点)。
如果函数已定义,则通过连续性在当前点定义梯度(注意这里可能存在
inf
的情况,例如sqrt(0)
)。如果可能存在多个值,则任意选择一个。如果函数未定义(
sqrt(-1)
、log(-1)
或当输入为NaN
时的大多数函数,例如),则用作梯度的值是任意的(我们可能会引发错误,但这并不保证)。大多数函数将使用NaN
作为梯度,但出于性能原因,一些函数将使用其他值(例如log(-1)
)。如果函数不是确定性映射(即它不是一个数学函数),则将其标记为不可微。这将导致在使用需要 grad 的
no_grad
环境之外的张量时出错。
局部禁用梯度计算
Python 提供了多种机制来局部禁用梯度计算:
要在整个代码块中禁用梯度,可以使用类似 no-grad 模式和推理模式这样的上下文管理器。为了更精细地排除子图从梯度计算中,可以设置张量的 requires_grad
字段。
下面,除了讨论上述机制外,我们还描述了评估模式( nn.Module.eval()
),这是一种不用于禁用梯度计算的方法,但由于其名称,常常与上述三种混淆。
设置 requires_grad
¶
requires_grad
是一个标志,默认为 false,除非用 nn.Parameter
包裹,允许对子图进行细粒度排除,以从梯度计算中排除。它在正向和反向传播中都会生效:
在正向传播过程中,只有当至少一个输入张量需要 grad 时,操作才会被记录在反向图中。在反向传播过程中( .backward()
),只有具有 requires_grad=True
的叶张量才会将梯度累积到它们的 .grad
字段中。
需要注意的是,尽管每个张量都有这个标志,但只为叶张量(没有 grad_fn
的张量,例如 nn.Module
的参数)设置它才有意义。非叶张量(具有 grad_fn
的张量)是与它们相关的反向图。因此,它们的梯度将作为中间结果被需要,以计算需要 grad 的叶张量的梯度。从这个定义中可以看出,所有非叶张量都将自动具有 require_grad=True
。
设置 requires_grad
应该是您控制哪些部分模型参与梯度计算的主要方式,例如,如果您需要在模型微调期间冻结预训练模型的某些部分。
要冻结模型的部分,只需将 .requires_grad_(False)
应用于您不希望更新的参数。如上所述,由于使用这些参数作为输入的计算不会记录在正向传递中,因此它们在反向传递中不会更新其 .grad
字段,因为它们一开始就不会是反向图的一部分,正如您所期望的那样。
由于这是一个非常常见的模式, requires_grad
也可以通过 nn.Module.requires_grad_()
在模块级别设置。当应用于模块时, .requires_grad_()
将对模块的所有参数生效(默认情况下,这些参数具有 requires_grad=True
)。
梯度模式
除了设置 requires_grad
之外,还可以从 Python 中选择三种梯度模式,这些模式可以影响 PyTorch 内部如何通过 autograd 处理计算:默认模式(梯度模式)、无梯度模式和推理模式,所有这些都可以通过上下文管理器和装饰器切换。
模式 |
排除操作记录在反向图中 |
跳过额外的 autograd 跟踪开销 |
在启用模式时创建的张量可以在 grad-mode 模式下后续使用 |
示例 |
---|---|---|---|---|
默认 |
✓ |
前向传播 |
||
无梯度 |
✓ |
✓ |
优化器更新 |
|
推理 |
✓ |
✓ |
数据处理、模型评估 |
默认模式(梯度模式)
“默认模式”是指在没有启用其他模式(如无梯度模式和推理模式)时我们隐式所处的模式。与“无梯度模式”相对,默认模式有时也被称为“梯度模式”。
了解默认模式最重要的地方在于,它是唯一一个 requires_grad
生效的模式。在另外两种模式中, requires_grad
总是被覆盖为 False
。
无梯度模式 ¶
在无梯度模式下,计算行为就像没有任何输入需要梯度一样。换句话说,即使在有输入带有 require_grad=True
的情况下,无梯度模式下的计算也永远不会被记录在反向图中。
当你需要执行不应被自动微分记录的操作,但之后又想使用这些计算的输出进入梯度模式时,启用无梯度模式。这个上下文管理器使得在不临时将张量设置为 requires_grad=False
,然后再回到 True
的情况下,禁用梯度变得方便。
例如,在编写优化器时,无梯度模式可能很有用:当执行训练更新时,您希望就地更新参数,而更新不被 autograd 记录。您还打算在下一个前向传递中使用更新后的参数进行计算。
torch.nn.init 中的实现也依赖于无梯度模式来初始化参数,以避免在就地更新初始化的参数时被 autograd 跟踪。
推理模式 ¶
推理模式是无梯度模式的极端版本。就像在无梯度模式中一样,推理模式中的计算不会记录在反向图中,但启用推理模式将使 PyTorch 进一步加快您的模型速度。这种更好的运行时间伴随着一个缺点:在推理模式下创建的张量在退出推理模式后无法用于任何要被 autograd 记录的计算。
当您执行的计算与 autograd 没有交互,并且您不打算在以后要被 autograd 记录的计算中使用在推理模式下创建的张量时,请启用推理模式。
建议您在不需要自动梯度跟踪的部分代码中尝试推理模式(例如数据预处理和模型评估)。如果您的用例可以直接使用,这将是一个免费的性能提升。如果您在启用推理模式后遇到错误,请检查您是否在退出推理模式后使用推理模式中创建的张量进行自动梯度记录的计算。如果在这种情况下无法避免,您始终可以切换回无梯度模式。
关于推理模式的详细信息,请参阅推理模式。
关于推理模式的实现细节,请参阅 RFC-0011-InferenceMode。
评估模式( nn.Module.eval()
)
评估模式并不是本地禁用梯度计算的机制。它之所以被包含在这里,是因为有时人们会将其误认为是这样的机制。
从功能上讲, module.eval()
(或等价于 module.train(False)
)与无梯度模式和推理模式完全正交。 model.eval()
如何影响你的模型完全取决于你模型中使用的特定模块以及它们是否定义了任何训练模式特定的行为。
如果你使用的模型依赖于 torch.nn.Dropout
和 torch.nn.BatchNorm2d
等可能根据训练模式表现不同的模块,那么你有责任调用 model.eval()
和 model.train()
,例如,为了避免在验证数据上更新 BatchNorm 的运行统计信息。
即使你不确定你的模型是否有训练模式特定的行为,也建议你在训练时始终使用 model.train()
,在评估你的模型(验证/测试)时使用 model.eval()
,因为你可能使用的模块可能会更新以在训练和评估模式中表现出不同的行为。
基于 autograd 的本地操作 ¶
在 autograd 中支持本地操作是一项艰巨的任务,我们通常不建议使用它们。autograd 的积极释放和重用缓冲区使其非常高效,而且很少情况下本地操作能显著降低内存使用量。除非你处于严重的内存压力之下,否则你可能永远不需要使用它们。
限制本地操作适用性的主要有两个原因:
本地操作可能会覆盖计算梯度所需的值。
每个就地操作都需要实现重写计算图。就地版本只是分配新的对象并保留对旧图的引用,而就地操作则需要更改所有输入的创建者,以表示此操作。这可能会很棘手,特别是如果有许多引用相同存储的张量(例如通过索引或转置创建),就地函数如果修改后的输入存储被其他任何引用所引用,将会引发错误。
原地操作正确性检查 ¶
每个张量都保留一个版本计数器,每次在操作中标记为脏时都会增加。当函数保存任何张量以进行反向传播时,它们包含的张量的版本计数器也会保存。一旦访问了 self.saved_tensors
,就会进行检查,如果它大于保存的值,则会引发错误。这确保了如果您使用就地函数且没有看到任何错误,您可以确信计算出的梯度是正确的。
多线程 Autograd
自动微分引擎负责运行所有必要的反向操作以计算反向传播。本节将描述所有可以帮助您在多线程环境中充分利用它的细节。(这仅适用于 PyTorch 1.6+,因为之前版本的行为不同。)
用户可以使用多线程代码(例如 Hogwild 训练)来训练他们的模型,并且不会在并发反向计算上阻塞,示例代码可以是:
# Define a train function to be used in different threads
def train_fn():
x = torch.ones(5, 5, requires_grad=True)
# forward
y = (x + 3) * (x + 4) * 0.5
# backward
y.sum().backward()
# potential optimizer update
# User write their own threading code to drive the train_fn
threads = []
for _ in range(10):
p = threading.Thread(target=train_fn, args=())
p.start()
threads.append(p)
for p in threads:
p.join()
注意用户应该注意的一些行为:
CPU 上的并发
当您在 CPU 上使用 Python 或 C++ API 通过多个线程运行 backward()
或 grad()
时,您期望看到额外的并发性,而不是在执行过程中按特定顺序序列化所有反向调用(PyTorch 1.6 之前的行為)。
非确定性
如果您从多个线程并发调用 backward()
并且有共享输入(例如 Hogwild CPU 训练),则应预期出现非确定性。这可能会发生,因为参数在多个线程之间自动共享,因此多个线程可能会在梯度累积期间访问并尝试累积相同的 .grad
属性。这在技术上是不安全的,可能会导致竞态条件,结果可能无法使用。
开发具有共享参数的多线程模型的用户应考虑线程模型,并理解上述问题。
功能性 API torch.autograd.grad()
可以用来计算梯度,而不是使用 backward()
以避免非确定性。
保留图 ¶
如果自动微分图的某部分在多个线程之间共享,即先在单个线程中运行前半部分,然后在多个线程中运行后半部分,那么图的前半部分就是共享的。在这种情况下,不同线程在同一个图上执行 grad()
或 backward()
可能会存在一个线程动态销毁图的问题,另一个线程在这种情况下会崩溃。自动微分会向用户报错,类似于调用 backward()
两次而没有 retain_graph=True
,并让用户知道他们应该使用 retain_graph=True
。
自动微分节点线程安全 ¶
由于 Autograd 允许调用线程驱动其反向执行以实现并行性,因此确保在 CPU 上使用并行 backward()
调用共享部分/整个 GraphTask 时线程安全非常重要。
由于 GIL,自定义 Python autograd.Function
是自动线程安全的。对于内置的 C++ Autograd 节点(例如 AccumulateGrad、CopySlices)和自定义 autograd::Function
,Autograd 引擎使用线程互斥锁来确保可能具有状态写入/读取的 Autograd 节点的线程安全。
C++钩子没有线程安全
Autograd 依赖于用户编写线程安全的 C++钩子。如果您希望钩子在多线程环境中正确应用,您需要编写适当的线程锁定代码以确保钩子是线程安全的。
复数自动微分
简短版:
当您使用 PyTorch 对定义在复数域和/或值域上的任何函数进行微分时,梯度是在假设该函数是更大实值损失函数的一部分的情况下计算的。计算出的梯度是 (注意 z 的共轭),其负值正是梯度下降算法中使用的最速下降方向。因此,存在一条使现有优化器能够直接处理复数参数的可行路径。
这种约定与 TensorFlow 的复数微分约定相匹配,但与 JAX(计算 )不同。
如果您有一个内部使用复杂运算的真实到真实函数,这里的惯例并不重要:您将始终得到与仅使用真实运算实现时相同的结果。
如果您对数学细节感兴趣,或者想知道如何在 PyTorch 中定义复数导数,请继续阅读。
什么是复数导数?¶
复数可微的数学定义是将导数的极限定义推广到复数上。考虑一个函数 ,
其中 和 是两个实值变量函数, 是虚数单位。
使用导数定义,我们可以写出:
为了使这个极限存在,不仅 和 必须是实值可微的,而且 还必须满足柯西-黎曼方程。换句话说:对于实部和虚部( )计算出的极限必须相等。这是一个更严格的条件。
复可微函数通常被称为全纯函数。它们表现良好,具有所有从实值可微函数中看到的良好性质,但在优化世界中实际上没有太大用处。在研究社区中,由于复数不属于任何有序域,因此复值损失函数没有太多意义,所以优化问题中只使用实值目标函数。
结果表明,没有任何有趣的实值目标函数满足柯西-黎曼方程。因此,全纯函数的理论不能用于优化,所以大多数人因此使用 Wirtinger 微分法。
Wirtinger 微分法进入了视野...
因此,我们有了这个关于复可微性和全纯函数的伟大理论,但我们根本无法使用它,因为许多常用的函数都不是全纯的。数学家该怎么办呢?好吧,Wirtinger 观察到,即使 不是全纯的,也可以将其重写为一个总是全纯的二维函数 。这是因为 的实部和虚部可以用 和 表达:
Wirtinger 微分法建议研究 ,如果 是实可微的,则该函数保证是全纯的(另一种思考方式是将它视为坐标变换,从 到 )。这个函数有偏导数 和 。我们可以使用链式法则来建立这些偏导数与 的实部和虚部的偏导数之间的关系。
从上述方程中,我们得到:
这就是你在维基百科上能找到的经典的 Wirtinger 微积分定义。
这个变化有很多美丽的后果。
例如,柯西-黎曼方程可以简单地表述为 (也就是说,函数 可以完全用 来表示,而不需要引用 )。
另一个重要(并且有些反直觉)的结果,正如我们稍后将会看到的,是在对实值损失进行优化时,我们在进行变量更新时应采取的步长由 给出(而不是 )。
想了解更多,请参阅:https://arxiv.org/pdf/0906.4835.pdf
Wirtinger 微积分在优化中有什么用?¶
音频和其他领域的研究人员,更常见的是,使用梯度下降来优化具有复变量的实值损失函数。通常,这些人将实部和虚部视为可以分别更新的独立通道。对于步长 和损失 ,我们可以在 中写出以下方程:
这些方程如何转换到复数空间中 ?
发生了一件非常有趣的事情:Wirtinger 微积分告诉我们,我们可以将上述复变更新公式简化为仅涉及共轭 Wirtinger 导数 ,这正好是我们优化中采取的步骤。
由于共轭 Wirtinger 导数给出了实值损失函数的正确步骤,因此 PyTorch 在对具有实值损失的函数进行微分时提供了这个导数。
PyTorch 是如何计算共轭 Wirtinger 导数的? ¶
通常,我们的导数公式以 grad_output 作为输入,表示我们已计算出的输入向量-雅可比乘积,即 ,其中 是整个计算的损失(产生真实损失)和 是我们函数的输出。这里的目的是计算 ,其中 是函数的输入。实际上,在真实损失的情况下,我们甚至可以只计算 ,尽管链式法则暗示我们还需要访问 。如果您想跳过这个推导,请查看本节最后的方程,然后跳到下一节。
让我们继续使用定义为 的 。如上所述,autograd 的梯度约定是围绕优化真实值损失函数,所以假设 是更大的真实值损失函数 的一部分。使用链式法则,我们可以写出:
(1)¶
现在利用 Wirtinger 导数定义,我们可以写出:
应当在此指出,由于 和 是真实函数,而 根据我们假设 是实值函数的一部分,因此我们有:
(2)¶
即, 等于 。
解上述方程组得到 和 。
(3) ¶
将(3)代入(1),得到:
使用(2),得到:
(4) ¶
这个最后一个方程是编写您自己的梯度的重要方程,因为它将我们的导数公式分解成一个更简单的公式,这个公式很容易手工计算。
我该如何为复杂函数编写自己的导数公式?
上面的方框方程给出了所有复函数导数的通用公式。然而,我们仍然需要计算 和 。你可以通过两种方式来做这件事:
第一种方式是直接使用 Wirtinger 导数的定义,并通过使用 和 (你可以用常规方式计算它们)来计算 和 。
第二种方法是使用变量替换技巧,将 重写为一个双变量函数 ,并通过将 和 视为独立变量来计算共轭 Wirtinger 导数。这通常更容易;例如,如果所讨论的函数是全纯的,则只使用 (而 将为零)。
以函数 为例,其中 。
使用第一种方法计算 Wirtinger 导数,我们有。
使用(4),以及 grad_output = 1.0(这是在 PyTorch 中将 backward()
调用在标量输出上时的默认 grad 输出值),我们得到:
使用第二种方法计算 Wirtinger 导数,我们直接得到:
再次使用(4),我们得到 。正如你所见,第二种方法涉及的计算更少,更适合快速计算。
那么跨域函数呢?¶
一些函数将复杂输入映射到实数输出,或者相反。这些函数是(4)的特殊情况,我们可以使用链式法则推导出来:
对于 ,我们得到:
对于 ,我们得到:
保存张量的钩子
您可以通过定义一对 pack_hook
/ unpack_hook
钩子来控制保存的张量是如何打包/解包的。 pack_hook
函数应将其单个参数作为张量,但可以返回任何 Python 对象(例如另一个张量、一个元组,甚至包含文件名的字符串)。 unpack_hook
函数接受 pack_hook
的输出作为其单个参数,并应返回一个用于反向传播的张量。 unpack_hook
返回的张量只需与作为输入传递给 pack_hook
的张量具有相同的内容即可。特别是,可以忽略任何与 autograd 相关的元数据,因为它们将在解包期间被覆盖。
这样的例子是:
class SelfDeletingTempFile():
def __init__(self):
self.name = os.path.join(tmp_dir, str(uuid.uuid4()))
def __del__(self):
os.remove(self.name)
def pack_hook(tensor):
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(temp_file):
return torch.load(temp_file.name)
注意 unpack_hook
不应该删除临时文件,因为它可能被多次调用:临时文件应该与返回的 SelfDeletingTempFile 对象的生命周期保持一致。在上面的例子中,我们通过在 SelfDeletingTempFile 对象不再需要时关闭它来防止临时文件泄漏。
注意
我们保证 pack_hook
只会被调用一次,但 unpack_hook
可以根据反向传播的需求多次调用,我们期望它每次都返回相同的数据。
警告
禁止对任何函数的输入进行原地操作,因为这可能导致意外的副作用。PyTorch 会在修改 pack 钩子的输入时抛出错误,但不会捕获修改 unpack 钩子输入的情况。
为保存的张量注册钩子
您可以通过在 SavedTensor
对象上调用 register_hooks()
方法来为保存的张量注册一对钩子。这些对象作为 grad_fn
的属性暴露,并以 _raw_saved_
前缀开头。
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)
一旦注册了这对钩子,就会立即调用 pack_hook
方法。每次需要通过 y.grad_fn._saved_self
或反向传播访问保存的张量时,都会调用 unpack_hook
方法。
警告
如果在释放保存的张量之后(即调用反向传播之后)保持对 SavedTensor
的引用,则调用其 register_hooks()
是被禁止的。PyTorch 大多数情况下会抛出错误,但在某些情况下可能失败,并可能导致未定义的行为。
为保存的张量注册默认钩子 ¶
或者,您可以使用上下文管理器 saved_tensors_hooks
来注册一对钩子,这些钩子将应用于在该上下文中创建的所有保存的张量。
示例:
# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000
def pack_hook(x):
if x.numel() < SAVE_ON_DISK_THRESHOLD:
return x
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(tensor_or_sctf):
if isinstance(tensor_or_sctf, torch.Tensor):
return tensor_or_sctf
return torch.load(tensor_or_sctf.name)
class Model(nn.Module):
def forward(self, x):
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
# ... compute output
output = x
return output
model = Model()
net = nn.DataParallel(model)
使用此上下文管理器定义的钩子是线程局部的。因此,以下代码不会产生预期的效果,因为钩子没有通过 DataParallel。
# Example what NOT to do
net = nn.DataParallel(model)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
output = net(input)
注意,使用这些钩子将禁用所有旨在减少 Tensor 对象创建的优化。例如:
with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
x = torch.randn(5, requires_grad=True)
y = x * x
没有钩子时, x
、 y.grad_fn._saved_self
和 y.grad_fn._saved_other
都指向同一个张量对象。有了钩子后,PyTorch 会将 x 打包成两个新的张量对象,这两个对象与原始的 x 共享相同的存储空间(不进行复制)。
反向钩子执行
本节将讨论不同钩子何时触发或未触发。然后讨论它们触发的顺序。将要涵盖的钩子包括:通过 torch.Tensor.register_hook()
注册到 Tensor 的反向钩子、通过 torch.Tensor.register_post_accumulate_grad_hook()
注册到 Tensor 的后累积梯度钩子、通过 torch.autograd.graph.Node.register_hook()
注册到 Node 的后钩子以及通过 torch.autograd.graph.Node.register_prehook()
注册到 Node 的前钩子。
特定钩子是否会触发
通过 torch.Tensor.register_hook()
注册到张量上的钩子将在计算该张量的梯度时执行。(注意,这不需要执行张量的 grad_fn。例如,如果张量作为 inputs
参数传递给 torch.autograd.grad()
,则张量的 grad_fn 可能不会执行,但注册到该张量上的钩子将始终执行。)
通过 torch.Tensor.register_post_accumulate_grad_hook()
注册到张量上的钩子是在该张量的梯度累积之后执行的,这意味着张量的 grad 字段已被设置。而通过 torch.Tensor.register_hook()
注册的钩子是在计算梯度时运行的,通过 torch.Tensor.register_post_accumulate_grad_hook()
注册的钩子仅在反向传播结束时由 autograd 更新张量的 grad 字段后触发。因此,只能为叶张量注册 post-accumulate-grad 钩子。在非叶张量上通过 torch.Tensor.register_post_accumulate_grad_hook()
注册钩子将引发错误,即使你调用 backward(retain_graph=True)。
使用 torch.autograd.graph.Node.register_hook()
或 torch.autograd.graph.Node.register_prehook()
通过 torch.autograd.graph.Node
注册的钩子只有在它所注册的节点执行时才会触发。
一个特定的节点是否被执行可能取决于反向传播是否使用 torch.autograd.grad()
或 torch.autograd.backward()
调用。具体来说,当你在对应于传递给 torch.autograd.grad()
或 torch.autograd.backward()
作为 inputs
参数的 Tensor 的节点上注册钩子时,你应该注意这些差异。
如果你使用 torch.autograd.backward()
,无论是否指定了 inputs
参数,上述所有提到的钩子都将被执行。这是因为.backward()会执行所有节点,即使它们对应于作为输入指定的 Tensor。(注意,执行此附加节点对应于作为 inputs
传递的 Tensor 通常是不必要的,但仍然这样做。此行为可能会更改;你不应该依赖于它。)
另一方面,如果你使用 torch.autograd.grad()
,注册到对应于传递给 input
的 Tensor 的节点的反向钩子可能不会执行,因为这些节点除非有其他输入依赖于该节点的梯度结果,否则不会执行。
不同钩子触发的顺序 ¶
事件发生的顺序是:
注册到 Tensor 的钩子将被执行
如果 Node 被执行,则注册到 Node 的前置钩子将被执行。
对于保留 grad 的 Tensors,将更新
.grad
字段。节点执行(受上述规则约束)
对于累加
.grad
的叶张量,执行后累加梯度钩子执行注册到节点的后钩子(如果节点已执行)
如果同一张量或节点上注册了多个同类型的钩子,它们将按照注册的顺序执行。较晚执行的钩子可以观察到早期钩子对梯度的修改。
特殊钩子 ¶
torch.autograd.graph.register_multi_grad_hook()
是通过注册到张量上的钩子实现的。每个单独的张量钩子都按照上面定义的张量钩子顺序触发,并且在计算最后一个张量梯度时调用注册的多梯度钩子。
torch.nn.modules.module.register_module_full_backward_hook()
是通过注册到节点上实现的。在正向计算过程中,钩子被注册到模块的输入和输出的 grad_fn 上。由于一个模块可能接受多个输入并返回多个输出,因此在正向之前首先对模块的输入应用一个虚拟的自定义自动微分函数,并在正向输出之前返回模块的输出,以确保这些张量共享一个 grad_fn,然后我们可以将钩子附加到它上面。
张量钩子在张量就地修改时的行为 ¶
通常,注册到张量的钩子会接收到相对于该张量输出的梯度,其中张量的值被视为反向计算时该张量的值。
然而,如果您在张量上注册钩子,然后就地修改该张量,则之前注册的就地修改钩子同样会接收到相对于张量输出的梯度,但张量的值被视为就地修改之前的值。
如果您希望有前一种行为,您应该在所有就地修改完成后将它们注册到张量上。例如:
t = torch.tensor(1., requires_grad=True).sin()
t.cos_()
t.register_hook(fn)
t.backward()
此外,了解内部机制可能有所帮助,即当钩子注册到张量时,它们实际上会永久绑定到该张量的 grad_fn,因此如果该张量随后被就地修改,即使张量现在有一个新的 grad_fn,之前修改的就地钩子将继续与旧的 grad_fn 关联,例如,当自动微分引擎在图中到达该张量的旧 grad_fn 时,它们将触发。