torch.func.grad¶
- torch.func.grad(func, argnums=0, has_aux=False)[source]¶
grad
操作符帮助计算func
对指定输入的梯度。此操作符可以嵌套以计算高阶梯度。- 参数:
func (Callable) – 一个 Python 函数,可以接受一个或多个参数。必须返回一个单元素 Tensor。如果指定
has_aux
等于True
,函数可以返回一个单元素 Tensor 和其他辅助对象:(output, aux)
。argnums (int 或 Tuple[int]) – 指定计算梯度的参数。
argnums
可以是单个整数或整数的元组。默认:0。has_aux (bool) – 标志表示
func
返回一个 Tensor 和其他辅助对象:(output, aux)
。默认:False。
- 返回值:
计算相对于其输入的梯度的函数。默认情况下,函数的输出是相对于第一个参数的梯度张量(张量)。如果指定
has_aux
等于True
,则返回梯度张量和输出辅助对象的元组。如果argnums
是整数的元组,则返回相对于每个argnums
值的输出梯度的元组。- 返回类型:
使用
grad
的示例:>>> from torch.func import grad >>> x = torch.randn([]) >>> cos_x = grad(lambda x: torch.sin(x))(x) >>> assert torch.allclose(cos_x, x.cos()) >>> >>> # Second-order gradients >>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x) >>> assert torch.allclose(neg_sin_x, -x.sin())
当与
vmap
组合使用时,grad
可以用于计算每个样本的梯度:>>> from torch.func import grad, vmap >>> batch_size, feature_size = 3, 5 >>> >>> def model(weights, feature_vec): >>> # Very simple linear model with activation >>> assert feature_vec.dim() == 1 >>> return feature_vec.dot(weights).relu() >>> >>> def compute_loss(weights, example, target): >>> y = model(weights, example) >>> return ((y - target) ** 2).mean() # MSELoss >>> >>> weights = torch.randn(feature_size, requires_grad=True) >>> examples = torch.randn(batch_size, feature_size) >>> targets = torch.randn(batch_size) >>> inputs = (weights, examples, targets) >>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
使用
grad
、has_aux
和argnums
的示例:>>> from torch.func import grad >>> def my_loss_func(y, y_pred): >>> loss_per_sample = (0.5 * y_pred - y) ** 2 >>> loss = loss_per_sample.mean() >>> return loss, (y_pred, loss_per_sample) >>> >>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True) >>> y_true = torch.rand(4) >>> y_preds = torch.rand(4, requires_grad=True) >>> out = fn(y_true, y_preds) >>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))
注意
使用 PyTorch
torch.no_grad
与grad
结合。案例一:在函数中使用
torch.no_grad
:>>> def f(x): >>> with torch.no_grad(): >>> c = x ** 2 >>> return x - c
在这种情况下,
grad(f)(x)
将尊重内部torch.no_grad
。情况 2:在
torch.no_grad
上下文管理器中使用grad
:>>> with torch.no_grad(): >>> grad(f)(x)
在这种情况下,
grad
将尊重内部torch.no_grad
,但不尊重外部的一个。这是因为grad
是一个“函数转换”:其结果不应依赖于f
外部的上下文管理器的结果。