备注
点击此处下载完整示例代码
雅可比矩阵、海森矩阵、hvp、vhp 以及更多:函数变换的组成
创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日
计算雅可比矩阵或海森矩阵在许多非传统深度学习模型中很有用。使用 PyTorch 的常规自动微分 API( Tensor.backward()
, torch.autograd.grad
)来高效地计算这些量是困难的(或令人烦恼的)。PyTorch 的 JAX 启发式函数变换 API 提供了高效计算各种高阶自动微分量的方法。
备注
本教程需要 PyTorch 2.0.0 或更高版本。
计算雅可比矩阵
import torch
import torch.nn.functional as F
from functools import partial
_ = torch.manual_seed(0)
让我们从我们想要计算雅可比矩阵的函数开始。这是一个简单的线性函数,具有非线性激活。
def predict(weight, bias, x):
return F.linear(x, weight, bias).tanh()
让我们添加一些虚拟数据:一个权重、一个偏置和一个特征向量 x。
D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D) # feature vector
让我们将 predict
视为一个函数,它将输入 x
从\(R^D \to R^D\)映射。PyTorch Autograd 计算向量-雅可比乘积。为了计算这个\(R^D \to R^D\)函数的完整雅可比矩阵,我们需要逐行计算,每次使用不同的单位向量。
def compute_jac(xp):
jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
for vec in unit_vectors]
return torch.stack(jacobian_rows)
xp = x.clone().requires_grad_()
unit_vectors = torch.eye(D)
jacobian = compute_jac(xp)
print(jacobian.shape)
print(jacobian[0]) # show first row
与逐行计算雅可比矩阵不同,我们可以使用 PyTorch 的 torch.vmap
函数转换来消除 for 循环并矢量化计算。我们不能直接将 vmap
应用于 torch.autograd.grad
;相反,PyTorch 提供了一个 torch.func.vjp
转换,它可以与 torch.vmap
组合:
from torch.func import vmap, vjp
_, vjp_fn = vjp(partial(predict, weight, bias), x)
ft_jacobian, = vmap(vjp_fn)(unit_vectors)
# let's confirm both methods compute the same result
assert torch.allclose(ft_jacobian, jacobian)
在后续教程中,反向模式自动微分和 vmap
的组合将给我们每个样本的梯度。在本教程中,反向模式自动微分和 vmap
的组合将给我们雅可比矩阵计算! vmap
和自动微分转换的不同组合可以给我们不同的有趣量。
PyTorch 提供了一个 torch.func.jacrev
便利函数,该函数执行 vmap-vjp
组合来计算雅可比矩阵。 jacrev
接受一个 argnums
参数,表示我们希望对其计算雅可比矩阵的参数。
from torch.func import jacrev
ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
# Confirm by running the following:
assert torch.allclose(ft_jacobian, jacobian)
让我们比较两种计算雅可比矩阵的性能。函数转换版本要快得多(并且随着输出数量的增加而变得更快)。
通常情况下,我们期望通过 vmap
进行向量化可以帮助消除开销,并更好地利用您的硬件。
vmap
通过将外部循环推入函数的基本操作中来实现这种魔法,以获得更好的性能。
让我们快速编写一个函数来评估性能,并处理微秒和毫秒的测量:
def get_perf(first, first_descriptor, second, second_descriptor):
"""takes torch.benchmark objects and compares delta of second vs first."""
faster = second.times[0]
slower = first.times[0]
gain = (slower-faster)/slower
if gain < 0: gain *=-1
final_gain = gain*100
print(f" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} ")
然后运行性能比较:
from torch.utils.benchmark import Timer
without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
no_vmap_timer = without_vmap.timeit(500)
with_vmap_timer = with_vmap.timeit(500)
print(no_vmap_timer)
print(with_vmap_timer)
让我们来对上面的内容与我们的 get_perf
函数进行相对性能比较:
get_perf(no_vmap_timer, "without vmap", with_vmap_timer, "vmap")
此外,将问题反过来也很容易,我们想要计算模型参数(权重、偏置)的雅可比矩阵,而不是输入
# note the change in input via ``argnums`` parameters of 0,1 to map to weight and bias
ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
反向模式雅可比矩阵( jacrev
)与正向模式雅可比矩阵( jacfwd
)比较
我们提供了两个 API 来计算雅可比矩阵: jacrev
和 jacfwd
jacrev
使用反向模式自动微分。正如您上面所看到的,它是由我们的vjp
和vmap
变换组成的。jacfwd
使用正向模式自动微分。它被实现为我们的jvp
和vmap
变换的组合。
jacfwd
和 jacrev
可以互相替换,但它们具有不同的性能特征。
一般而言,如果您正在计算 \(R^N \to R^M\) 函数的雅可比矩阵,并且输出比输入多得多(例如,\(M > N\)),则建议使用 jacfwd
,否则使用 jacrev
。尽管有例外,但以下是对此的非严谨论证:
在反向模式自动微分中,我们是逐行计算雅可比矩阵,而在正向模式自动微分(计算雅可比-向量积)中,我们是逐列计算。雅可比矩阵有 M 行 N 列,所以如果它在某一方向上更高或更宽,我们可能更喜欢处理行数或列数较少的方法。
首先,让我们用比输出更多的输入进行基准测试:
Din = 32
Dout = 2048
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
# remember the general rule about taller vs wider... here we have a taller matrix:
print(weight.shape)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)
print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
然后进行相对基准测试:
get_perf(jacfwd_timing, "jacfwd", jacrev_timing, "jacrev", );
现在是反向的 - 输出(M)比输入(N)多:
Din = 2048
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)
print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
相对性能比较:
get_perf(jacrev_timing, "jacrev", jacfwd_timing, "jacfwd")
使用 functorch.hessian 进行 Hessian 计算
我们提供了一个方便的 API 来计算 Hessian: torch.func.hessiani
。Hessian 是雅可比的雅可比(或偏导数的偏导数,即二阶导数)。
这表明可以直接组合 functorch 的雅可比变换来计算 Hessian。实际上, hessian(f)
在底层只是 jacfwd(jacrev(f))
。
注意:为了提高性能:根据您的模型,您可能还想使用 jacfwd(jacfwd(f))
或 jacrev(jacrev(f))
来利用上述关于更宽矩阵与更高矩阵的规则来计算 Hessian。
from torch.func import hessian
# lets reduce the size in order not to overwhelm Colab. Hessians require
# significant memory:
Din = 512
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
hess_api = hessian(predict, argnums=2)(weight, bias, x)
hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)
让我们验证使用 Hessian API 和使用 jacfwd(jacfwd())
的结果是否相同。
torch.allclose(hess_api, hess_fwdfwd)
批量雅可比和批量 Hessian
在上述示例中,我们一直在使用单个特征向量。在某些情况下,您可能希望对输入的批量输出求雅可比,即给定形状为 (B, N)
的输入批量和一个从\(R^N\)到\(R^M\)的函数,我们希望得到形状为 (B, M, N)
的雅可比。
做这件事最简单的方法是使用 vmap
:
batch_size = 64
Din = 31
Dout = 33
weight = torch.randn(Dout, Din)
print(f"weight shape = {weight.shape}")
bias = torch.randn(Dout)
x = torch.randn(batch_size, Din)
compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian0 = compute_batch_jacobian(weight, bias, x)
如果你有从 (B, N) 到 (B, M) 的函数,并且确定每个输入都产生独立的输出,那么有时也可以在不使用 vmap
的情况下完成此操作,方法是先对输出求和,然后计算该函数的雅可比矩阵:
def predict_with_output_summed(weight, bias, x):
return predict(weight, bias, x).sum(0)
batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)
assert torch.allclose(batch_jacobian0, batch_jacobian1)
如果你有一个从 \(R^N \to R^M\) 的函数,但输入是批量的,你可以将 vmap
与 jacrev
组合起来计算批量的雅可比矩阵:
最后,批量的海森矩阵也可以类似地计算。最简单的方法是使用 vmap
来批量化海森矩阵的计算,但在某些情况下,求和技巧也适用。
计算 Hessian-向量乘积
计算 Hessian-向量乘积(hvp)的朴素方法是实现完整的 Hessian 并执行与向量的点积。我们可以做得更好:实际上,我们不需要实现完整的 Hessian 来完成这个操作。我们将介绍两种(许多)不同的计算 Hessian-向量乘积的策略:- 组合反向模式 AD 与反向模式 AD - 组合反向模式 AD 与正向模式 AD
与反向模式与反向模式相比,组合反向模式 AD 与正向模式 AD 通常是计算 hvp 的更内存高效的方式,因为正向模式 AD 不需要构建 Autograd 图并保存中间结果以供反向操作:
下面是一些示例用法。
def f(x):
return x.sin().sum()
x = torch.randn(2048)
tangent = torch.randn(2048)
result = hvp(f, (x,), (tangent,))
如果 PyTorch 前向-AD 没有覆盖您的操作,那么我们可以用反向模式 AD 与反向模式 AD 组合:
def hvp_revrev(f, primals, tangents):
_, vjp_fn = vjp(grad(f), *primals)
return vjp_fn(*tangents)
result_hvp_revrev = hvp_revrev(f, (x,), (tangent,))
assert torch.allclose(result, result_hvp_revrev[0])
脚本总运行时间:(0 分钟 0.000 秒)