• 文档 >
  • torch.func >
  • torch.func API 参考 >
  • torch.func.grad_and_value
快捷键

torch.func.grad_and_value

torch.func.grad_and_value(func, argnums=0, has_aux=False)[source]

返回一个函数,用于计算梯度与原值或前向计算的元组。

参数:
  • func (Callable) – 一个 Python 函数,可以接受一个或多个参数。必须返回一个单元素 Tensor。如果指定 has_aux 等于 True ,函数可以返回一个单元素 Tensor 和其他辅助对象: (output, aux)

  • argnums (int 或 Tuple[int]) – 指定计算梯度的参数。 argnums 可以是单个整数或整数的元组。默认:0。

  • has_aux (bool) – 标志表示 func 返回一个 Tensor 和其他辅助对象: (output, aux) 。默认:False。

返回值:

计算相对于其输入和前向计算的梯度的函数。默认情况下,函数的输出是相对于第一个参数的梯度张量(张量)和原计算。如果指定 has_aux 等于 True ,则返回梯度张量(张量)和前向计算(输出辅助对象)的元组。如果 argnums 是整数的元组,则返回相对于每个 argnums 值的输出梯度张量(张量)和前向计算的元组。

返回类型:

可调用

查看 grad() 以获取示例


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源