• 文档 >
  • 自动微分包 - torch.autograd >
  • torch.autograd.Function.forward
快捷键

torch.autograd.Function.forward

static Function.forward(*args, **kwargs)[source]

定义自定义 autograd Function 的前向操作。

此函数需由所有子类重写。定义前向传播有两种方式:

使用方法 1(组合前向和上下文):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • 它必须接受一个上下文 ctx 作为第一个参数,后跟任意数量的参数(张量或其他类型)。

  • 更多详情请参阅组合或分离的前向()和 setup_context()。

使用说明 2(分离前向和 ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • 前向操作不再接受 ctx 参数。

  • 取而代之,您还必须重写 torch.autograd.Function.setup_context() 静态方法来处理设置 ctx 对象。 output 是前向操作的输出, inputs 是前向操作的输入元组。

  • 更多详情请参阅扩展 torch.autograd。

上下文可以用来存储任意数据,这些数据可以在反向传播过程中被检索。不应直接在 ctx 上存储张量(尽管为了向后兼容,目前没有强制执行)。相反,如果张量打算用于 backward (等价于 vjp )或 ctx.save_for_forward() 如果打算用于 jvp ,则应使用 ctx.save_for_backward() 保存。

返回类型:

任何


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源