• 文档 >
  • torch.Tensor >
  • torch.Tensor.register_post_accumulate_grad_hook
快捷键

torch.Tensor.register_post_accumulate_grad_hook

Tensor.register_post_accumulate_grad_hook(hook)[source][source]

注册在梯度累积之后运行的反向钩子。

累积张量所有梯度之后将调用钩子,这意味着该张量的.grad 字段已被更新。累积梯度后的钩子仅适用于叶子张量(没有.grad_fn 字段的张量)。在非叶子张量上注册此钩子将引发错误!

钩子应具有以下签名:

hook(param: Tensor) -> None

注意,与其他自动微分钩子不同,此钩子作用于需要梯度的张量,而不是梯度本身。钩子可以就地修改和访问其 Tensor 参数,包括其.grad 字段。

此函数返回一个句柄,该句柄具有一个 handle.remove() 方法,用于从模块中移除钩子。

注意

有关此钩子何时执行以及其执行相对于其他钩子的顺序的更多信息,请参阅反向钩子执行。由于此钩子在反向传播期间运行,它将在 no_grad 模式下运行(除非 create_graph 为 True)。如果您需要在钩子中重新启用自动微分,可以使用 torch.enable_grad()。

示例:

>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> lr = 0.01
>>> # simulate a simple SGD update
>>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr))
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v
tensor([-0.0100, -0.0200, -0.0300], requires_grad=True)

>>> h.remove()  # removes the hook

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源