• 文档 >
  • torch.func >
  • torch.func API 参考 >
  • torch.func.hessian
快捷键

torch.func.hessian ¬

torch.func.hessian(func, argnums=0)[source] ¬

通过正向-反向策略计算 func 关于索引 argnum 的 arg(s) 的 Hessian。

前向-反向策略(组合 jacfwd(jacrev(func)) )是性能良好的默认选择。可以通过其他组合 jacfwd()jacrev() ,如 jacfwd(jacfwd(func))jacrev(jacrev(func)) 来计算 Hessian。

参数:
  • func(函数)- 一个 Python 函数,接受一个或多个参数,其中必须有一个是 Tensor,并返回一个或多个 Tensor

  • argnums(整数或元组)- 可选,整数或整数元组,表示要获取 Hessian 的参数。默认:0。

返回值:

返回一个函数,该函数接受与 func 相同的输入,并返回 func 相对于 argnums 参数的 Hessian。

注意

您可能会看到此 API 因“X 运算符未实现前向模式 AD”而失败。如果是这样,请提交错误报告,我们将优先处理。另一种选择是使用 jacrev(jacrev(func)) ,它具有更好的运算符覆盖范围。

使用一个 R^N -> R^1 函数的基本用法给出一个 N x N 的 Hessian 矩阵:

>>> from torch.func import hessian
>>> def f(x):
>>>   return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hess, torch.diag(-x.sin()))

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源