torch.autograd.function.FunctionCtx.save_for_backward¶
- FunctionCtx.save_for_backward(*tensors)[source][source]¶
保存给定的张量以供未来调用
backward()
。应最多调用一次,在
setup_context()
或forward()
方法中,并且仅与张量一起调用。所有打算在反向传播中使用的张量都应使用
save_for_backward
(而不是直接在ctx
上)保存,以防止梯度错误和内存泄漏,并启用保存的张量钩子。请参阅torch.autograd.graph.saved_tensors_hooks
。注意,如果保存了中间张量(既不是输入也不是输出),则您的自定义函数可能不支持双向求导。不支持双向求导的自定义函数应使用
@once_differentiable
装饰其backward()
方法,以便执行双向求导时引发错误。如果您想支持双向求导,可以基于输入在反向传播期间重新计算中间项,或者将中间项作为自定义函数的输出返回。有关更多详细信息,请参阅双向求导教程。在
backward()
中,保存的张量可以通过saved_tensors
属性进行访问。在将它们返回给用户之前,会进行检查以确保它们没有被用于任何修改其内容的就地操作。参数也可以
None
。这是一个空操作。查看 torch.autograd 的扩展部分以获取此方法的使用详情。
- 示例::
>>> class Func(Function): >>> @staticmethod >>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int): >>> w = x * z >>> out = x * y + y * z + w * y >>> ctx.save_for_backward(x, y, w, out) >>> ctx.z = z # z is not a tensor >>> return out >>> >>> @staticmethod >>> @once_differentiable >>> def backward(ctx, grad_out): >>> x, y, w, out = ctx.saved_tensors >>> z = ctx.z >>> gx = grad_out * (y + y * z) >>> gy = grad_out * (x + z + w) >>> gz = None >>> return gx, gy, gz >>> >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double) >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double) >>> c = 4 >>> d = Func.apply(a, b, c)