torch.autograd.Function.forward¶
- static Function.forward(*args, **kwargs)[source]¶
定义自定义 autograd Function 的前向操作。
此函数需由所有子类重写。定义前向传播有两种方式:
使用方法 1(组合前向和上下文):
@staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: pass
它必须接受一个上下文 ctx 作为第一个参数,后跟任意数量的参数(张量或其他类型)。
更多详情请参阅组合或分离的前向()和 setup_context()。
使用说明 2(分离前向和 ctx):
@staticmethod def forward(*args: Any, **kwargs: Any) -> Any: pass @staticmethod def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: pass
前向操作不再接受 ctx 参数。
取而代之,您还必须重写
torch.autograd.Function.setup_context()
静态方法来处理设置ctx
对象。output
是前向操作的输出,inputs
是前向操作的输入元组。更多详情请参阅扩展 torch.autograd。
上下文可以用来存储任意数据,这些数据可以在反向传播过程中被检索。不应直接在 ctx 上存储张量(尽管为了向后兼容,目前没有强制执行)。相反,如果张量打算用于
backward
(等价于vjp
)或ctx.save_for_forward()
如果打算用于jvp
,则应使用ctx.save_for_backward()
保存。- 返回类型: