torch.nn.utils.parametrize.cached
- torch.nn.utils.parametrize.cached()[source][source]
允许在注册的参数化中启用缓存系统的上下文管理器。
当此上下文管理器处于活动状态时,参数化对象的值在第一次需要时计算并缓存。离开上下文管理器时,将丢弃缓存的值。
这在正向传递中多次使用参数化参数时很有用。例如,当参数化 RNN 的循环内核或共享权重时。
激活缓存的简单方法是通过将神经网络的正向传递部分进行包装
import torch.nn.utils.parametrize as P ... with P.cached(): output = model(inputs)
在训练和评估过程中。还可以将使用多次参数化张量的模块部分进行包装。例如,具有参数化循环核的 RNN 循环:
with P.cached(): for x in xs: out_rnn = self.rnn_cell(x, out_rnn)