torch.autograd.function.FunctionCtx.set_materialize_grads ¶
- FunctionCtx.set_materialize_grads(value)[source] ¶
Sets whether to materialize grad tensors. Default is True. This should only be called from inside the setup_context() or forward() methods. If True, undefined grad tensors will be expanded to tensors full of zeros prior to calling the backward() and jvp() methods.
- Example::
>>> class SimpleFunc(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         return g1 + g2  # No check for None necessary
>>>
>>> # We modify SimpleFunc to handle non-materialized grad outputs
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         ctx.set_materialize_grads(False)
>>>         ctx.save_for_backward(x)
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         x, = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         if g1 is not None:  # We must check for None now
>>>             grad_input += g1
>>>         if g2 is not None:
>>>             grad_input += g2
>>>         return grad_input
>>>
>>> a = torch.tensor(1., requires_grad=True)
>>> b, _ = Func.apply(a)  # induces g2 to be undefined
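As a brief follow-up (a sketch assuming the Func class defined in the example above, plus the standard torch imports), calling backward() through only the first output shows the non-materialized gradient arriving as None rather than as a zero tensor:

>>> # Assumes: import torch; from torch.autograd import Function;
>>> # from torch.autograd.function import once_differentiable;
>>> # and the Func class from the example above.
>>> a = torch.tensor(1., requires_grad=True)
>>> b, _ = Func.apply(a)  # the second output is discarded
>>> b.backward()          # backward() receives g1 = tensor(1.), g2 = None
>>> a.grad
tensor(1.)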