• 文档 >
  • 毕业审核机制
快捷键

毕业审核机制 ¶

本笔记概述了 gradcheck()gradgradcheck() 函数的工作原理。

它将涵盖实值和复值函数的前向和后向模式自动微分,以及高阶导数。本笔记还涵盖了 gradcheck 的默认行为以及传递 fast_mode=True 参数的情况(以下称为快速 gradcheck)。

符号和背景信息 ¶

在本笔记中,我们将使用以下约定:

  1. xxyyaabbvvuuururuiui 是实值向量,而 zz 是一个复值向量,它可以表示为两个实值向量 z=a+ibz = a + i b

  2. NNMM 是两个整数,我们将分别用它们表示输入空间和输出空间的维度。

  3. f:RNRMf: \mathcal{R}^N \to \mathcal{R}^M 是我们的基本实值到实值的函数,满足 y=f(x)y = f(x)

  4. g:CNRMg: \mathcal{C}^N \to \mathcal{R}^M 是我们的基本复值到实值的函数,满足 y=g(z)y = g(z)

对于简单的实到实的情况,我们将其写作 JfJ_fff 大小为 M×NM \times N 的雅可比矩阵。这个矩阵包含所有偏导数,其中位置 (i,j)(i, j) 的项包含 yixj\frac{\partial y_i}{\partial x_j} 。然后反向模式自动微分计算,对于一个大小为 vv 的给定向量 MM ,计算量 vTJfv^T J_f 。另一方面,正向模式自动微分计算,对于一个大小为 NN 的给定向量 uu ,计算量 JfuJ_f u

对于包含复数值的函数,情况要复杂得多。这里只提供概要,完整描述请见复数自动微分。

满足复数可微(柯西-黎曼方程)的约束对于所有实值损失函数来说过于严格,所以我们选择了使用 Wirtinger 微积分。在 Wirtinger 微积分的基本设置中,链式法则需要访问 Wirtinger 导数(以下称为 WW )和共轭 Wirtinger 导数(以下称为 CWCW )。 WWCWCW 都需要传播,因为在一般情况下,尽管它们的名字如此,一个并不是另一个的复共轭。

为了避免传播两个值,对于向后模式 AD,我们始终假设正在计算导数的函数要么是实值函数,要么是更大的实值函数的一部分。这个假设意味着我们在反向传播过程中计算的所有中间梯度都与实值函数相关联。在实践中,当进行优化时,这个假设并不具有约束力,因为此类问题需要实值目标(因为复数没有自然排序)。

在这个假设下,使用 WWCWCW 定义,我们可以证明 W=CWW = CW^* (我们在这里使用 * 表示复共轭)因此只需要将两个值中的一个“反向传播”通过图,另一个可以很容易地恢复。为了简化内部计算,PyTorch 使用 2CW2 * CW 作为反向传播的值,并在用户请求梯度时返回。与实数情况类似,当输出实际上在 RM\mathcal{R}^M 时,向后模式 AD 不计算 2CW2 * CW ,而只计算给定向量 vRMv \in \mathcal{R}^MvT(2CW)v^T (2 * CW)

对于前向模式的 AD,我们使用类似的逻辑,在这种情况下,假设该函数是更大函数的一部分,其输入为 R\mathcal{R} 。基于这个假设,我们可以做出类似的断言,每个中间结果都对应一个输入为 R\mathcal{R} 的函数。在这种情况下,使用 WWCWCW 定义,我们可以证明中间函数的 W=CWW = CW 。为了确保前向和后向模式在一维函数的基本情况下计算相同的量,前向模式也计算 2CW2 * CW 。与实数情况类似,当输入实际上在 RN\mathcal{R}^N 时,前向模式 AD 不计算 2CW2 * CW ,而只计算给定向量 uRNu \in \mathcal{R}^N(2CW)u(2 * CW) u

默认后向模式 gradcheck 行为

实数到实数函数

要测试一个函数 f:RNRM,xyf: \mathcal{R}^N \to \mathcal{R}^M, x \to y ,我们以两种方式重建大小为 M×NM \times N 的完整雅可比矩阵 JfJ_f :解析和数值。解析版本使用我们的后向模式 AD,而数值版本使用有限差分。然后逐元素比较重建的雅可比矩阵是否相等。

默认真实输入数值评估

如果我们考虑一维函数的基本情况( N=M=1N = M = 1 ),那么我们可以使用维基百科文章中的基本有限差分公式。我们使用“中心差分”以获得更好的数值特性:

yxf(x+eps)f(xeps)2eps\frac{\partial y}{\partial x} \approx \frac{f(x + eps) - f(x - eps)}{2 * eps}

该公式可以很容易地推广到多个输出( M>1M \gt 1 ),通过使 yx\frac{\partial y}{\partial x} 成为一个大小为 M×1M \times 1 的列向量,如 f(x+eps)f(x + eps) 。在这种情况下,上述公式可以原样重用,并且仅通过两次评估用户函数(即 f(x+eps)f(x + eps)f(xeps)f(x - eps) )即可近似整个雅可比矩阵。

处理多个输入( N>1N \gt 1 )的情况计算上更昂贵。在这种情况下,我们逐个遍历所有输入,并对 xx 的每个元素依次应用 epseps 扰动。这允许我们按列重建 JfJ_f 矩阵。

默认真实输入分析评估 ¶

对于分析评估,我们使用上述事实,即反向模式自动微分计算 vTJfv^T J_f 。对于只有一个输出的函数,我们只需使用 v=1v = 1 通过单次反向传递恢复完整的雅可比矩阵。

对于具有多个输出的函数,我们求助于一个循环,该循环遍历输出,其中每个 vv 都是一个对应于每个输出的 one-hot 向量,依次排列。这允许逐行重建 JfJ_f 矩阵。

复数到实数函数 ¶

为了测试函数 g:CNRM,zyg: \mathcal{C}^N \to \mathcal{R}^M, z \to yz=a+ibz = a + i b ,我们重构包含 2CW2 * CW 的(复值)矩阵。

默认复数输入数值评估

考虑首先的简单情况 N=M=1N = M = 1 。我们知道,从(本研究的第 3 章)中,我们有:

CW:=yz=12(ya+iyb)CW := \frac{\partial y}{\partial z^*} = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})

注意,在上面的方程中, ya\frac{\partial y}{\partial a}yb\frac{\partial y}{\partial b}RR\mathcal{R} \to \mathcal{R} 导数。为了对这些进行数值评估,我们使用上述实数到实数的计算方法。这允许我们计算 CWCW 矩阵,然后将其乘以 22

注意,截至编写代码时,该代码以略微复杂的方式计算此值:

# Code from https://github.com/pytorch/pytorch/blob/58eb23378f2a376565a66ac32c93a316c45b6131/torch/autograd/gradcheck.py#L99-L105
# Notation changes in this code block:
# s here is y above
# x, y here are a, b above

ds_dx = compute_gradient(eps)
ds_dy = compute_gradient(eps * 1j)
# conjugate wirtinger derivative
conj_w_d = 0.5 * (ds_dx + ds_dy * 1j)
# wirtinger derivative
w_d = 0.5 * (ds_dx - ds_dy * 1j)
d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()

# Since grad_out is always 1, and W and CW are complex conjugate of each other, the last line ends up computing exactly `conj_w_d + w_d.conj() = conj_w_d + conj_w_d = 2 * conj_w_d`.

默认的复杂输入分析评估 ¶

由于反向模式自动微分(AD)已经精确计算了 CWCW 导数的两倍,所以我们在这里简单地使用与实数到实数情况相同的技巧,当存在多个实数输出时,逐行重建矩阵。

复数输出函数 ¶

在这种情况下,用户提供的函数不遵循 autograd 的假设,即我们计算反向自动微分(AD)的函数是实值的。这意味着直接使用 autograd 对这个函数定义不明确。为了解决这个问题,我们将测试函数 h:PNCMh: \mathcal{P}^N \to \mathcal{C}^M (其中 P\mathcal{P} 可以是 R\mathcal{R}C\mathcal{C} )替换为两个函数: hrhrhihi ,使得:

hr(q):=real(f(q))hi(q):=imag(f(q))\begin{aligned} hr(q) &:= real(f(q)) \\ hi(q) &:= imag(f(q)) \end{aligned}

其中 qPq \in \mathcal{P} 。然后,我们使用上述的实数到实数或复数到实数的案例,根据 P\mathcal{P}hrhrhihi 进行基本的 gradcheck。

注意,截至编写代码时,代码并没有显式地创建这些函数,而是通过手动传递 grad_out\text{grad\_out} 参数到不同的函数,通过 realrealimagimag 函数手动执行链式法则。当 grad_out=1\text{grad\_out} = 1 时,我们考虑 hrhr 。当 grad_out=1j\text{grad\_out} = 1j 时,我们考虑 hihi

快速反向模式 gradcheck

虽然上述 gradcheck 的表述很棒,但为了确保正确性和可调试性,它非常慢,因为它会重建完整的雅可比矩阵。本节介绍了一种以更快的速度执行 gradcheck 的方法,同时不影响其正确性。当检测到错误时,可以通过添加特殊逻辑来恢复可调试性。在这种情况下,我们可以运行默认版本,重建完整的矩阵,以便向用户提供详细信息。

这里的高级策略是找到一个可以通过数值和解析方法高效计算的标量量,并且能够很好地代表慢速 gradcheck 计算的全矩阵,从而确保能够捕捉到雅可比矩阵中的任何差异。

实际到实际的快速 gradcheck

我们在这里想要计算的标量量是对于给定的随机向量 vRMv \in \mathcal{R}^M 和随机单位范数向量 uRNu \in \mathcal{R}^NvTJfuv^T J_f u

对于数值评估,我们可以高效地计算

Jfuf(x+ueps)f(xueps)2eps.J_f u \approx \frac{f(x + u * eps) - f(x - u * eps)}{2 * eps}.

然后我们计算这个向量与 vv 的点积,以获得感兴趣的标量值。

对于解析版本,我们可以使用反向模式自动微分来直接计算 vTJfv^T J_f 。然后我们与 uu 进行点积,以获得期望值。

复数到实数的函数的快速梯度检查

与真实到真实的情况类似,我们希望对完整矩阵进行降维。但是, 2CW2 * CW 矩阵是复数矩阵,因此在这种情况下,我们将与复数标量进行比较。

由于在数值情况下对我们可以高效计算的内容存在一些限制,并且为了将数值评估的数量保持在最低,我们计算以下(尽管令人惊讶)的标量值:

s:=2vT(real(CW)ur+iimag(CW)ui)s := 2 * v^T (real(CW) ur + i * imag(CW) ui)

其中 vRMv \in \mathcal{R}^MurRNur \in \mathcal{R}^NuiRNui \in \mathcal{R}^N

快速复数输入数值评估

我们首先考虑如何使用数值方法计算 ss 。为此,考虑到我们正在考虑 g:CNRM,zyg: \mathcal{C}^N \to \mathcal{R}^M, z \to yz=a+ibz = a + i b ,以及 CW=12(ya+iyb)CW = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b}) ,我们将其重新写为如下:

s=2vT(real(CW)ur+iimag(CW)ui)=2vT(12yaur+i12ybui)=vT(yaur+iybui)=vT((yaur)+i(ybui))\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= 2 * v^T (\frac{1}{2} * \frac{\partial y}{\partial a} ur + i * \frac{1}{2} * \frac{\partial y}{\partial b} ui) \\ &= v^T (\frac{\partial y}{\partial a} ur + i * \frac{\partial y}{\partial b} ui) \\ &= v^T ((\frac{\partial y}{\partial a} ur) + i * (\frac{\partial y}{\partial b} ui)) \end{aligned}

在这个公式中,我们可以看到 yaur\frac{\partial y}{\partial a} urybui\frac{\partial y}{\partial b} ui 可以像实数到实数的快速版本一样进行评估。一旦这些实数值被计算出来,我们就可以重建右侧的复数向量,并与实数值的 vv 向量进行点积。

快速复数输入分析评估

对于分析情况,事情要简单得多,我们将公式重新写为:

s=2vT(real(CW)ur+iimag(CW)ui)=vTreal(2CW)ur+ivTimag(2CW)ui)=real(vT(2CW))ur+iimag(vT(2CW))ui\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= v^T real(2 * CW) ur + i * v^T imag(2 * CW) ui) \\ &= real(v^T (2 * CW)) ur + i * imag(v^T (2 * CW)) ui \end{aligned}

因此,我们可以利用向后模式 AD 为我们提供的高效计算 vT(2CW)v^T (2 * CW) 的方法,然后对实部与 urur 进行点积,对虚部与 uiui 进行点积,最后重构最终的复数标量 ss

那为什么不使用复数 uu 呢?

在这一点上,你可能想知道为什么我们没有选择复数 uu ,而是直接进行了 2vTCWu2 * v^T CW u' 的缩减。为了深入探讨这个问题,在本段中,我们将使用 uu 的复数版本,记为 u=ur+iuiu' = ur' + i ui' 。使用这样的复数 uu' ,问题是当进行数值评估时,我们需要计算:

2CWu=(ya+iyb)(ur+iui)=yaur+iyaui+iyburybui\begin{aligned} 2*CW u' &= (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})(ur' + i ui') \\ &= \frac{\partial y}{\partial a} ur' + i \frac{\partial y}{\partial a} ui' + i \frac{\partial y}{\partial b} ur' - \frac{\partial y}{\partial b} ui' \end{aligned}

这将需要四次实数到实数的有限差分评估(比上述方法多一倍)。由于这种方法没有更多的自由度(与实值变量数量相同)并且我们试图获得最快的评估,所以我们使用了上述的另一种公式。

复杂输出函数的快速 gradcheck ¶

就像在慢速情况下一样,我们考虑两个实值函数,并为每个函数使用上面的适当规则。

Gradgradcheck 实现 ¶

PyTorch 还提供了一个验证二阶梯度的实用工具。这里的目的是确保反向实现也是正确可微分的,并计算出正确的结果。

该功能通过考虑函数 F:x,vvTJfF: x, v \to v^T J_f 实现,并在该函数上使用上面定义的 gradcheck。注意,在这种情况下, vv 只是一个与 f(x)f(x) 相同类型的随机向量。

通过使用相同的函数 FF 上的 gradcheck 的快速版本,实现了 gradgradcheck 的快速版本。


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源