torch.func.hessian ¬
- torch.func.hessian(func, argnums=0)[source] ¬
通过正向-反向策略计算
func
关于索引argnum
的 arg(s) 的 Hessian。前向-反向策略(组合
jacfwd(jacrev(func))
)是性能良好的默认选择。可以通过其他组合jacfwd()
和jacrev()
,如jacfwd(jacfwd(func))
或jacrev(jacrev(func))
来计算 Hessian。- 参数:
func(函数)- 一个 Python 函数,接受一个或多个参数,其中必须有一个是 Tensor,并返回一个或多个 Tensor
argnums(整数或元组)- 可选,整数或整数元组,表示要获取 Hessian 的参数。默认:0。
- 返回值:
返回一个函数,该函数接受与
func
相同的输入,并返回func
相对于argnums
参数的 Hessian。
注意
您可能会看到此 API 因“X 运算符未实现前向模式 AD”而失败。如果是这样,请提交错误报告,我们将优先处理。另一种选择是使用
jacrev(jacrev(func))
,它具有更好的运算符覆盖范围。使用一个 R^N -> R^1 函数的基本用法给出一个 N x N 的 Hessian 矩阵:
>>> from torch.func import hessian >>> def f(x): >>> return x.sin().sum() >>> >>> x = torch.randn(5) >>> hess = hessian(f)(x) # equivalent to jacfwd(jacrev(f))(x) >>> assert torch.allclose(hess, torch.diag(-x.sin()))