torch.func.grad_and_value¶
- torch.func.grad_and_value(func, argnums=0, has_aux=False)[source]¶
返回一个函数,用于计算梯度与原值或前向计算的元组。
- 参数:
func (Callable) – 一个 Python 函数,可以接受一个或多个参数。必须返回一个单元素 Tensor。如果指定
has_aux
等于True
,函数可以返回一个单元素 Tensor 和其他辅助对象:(output, aux)
。argnums (int 或 Tuple[int]) – 指定计算梯度的参数。
argnums
可以是单个整数或整数的元组。默认:0。has_aux (bool) – 标志表示
func
返回一个 Tensor 和其他辅助对象:(output, aux)
。默认:False。
- 返回值:
计算相对于其输入和前向计算的梯度的函数。默认情况下,函数的输出是相对于第一个参数的梯度张量(张量)和原计算。如果指定
has_aux
等于True
,则返回梯度张量(张量)和前向计算(输出辅助对象)的元组。如果argnums
是整数的元组,则返回相对于每个argnums
值的输出梯度张量(张量)和前向计算的元组。- 返回类型:
查看
grad()
以获取示例