• 文档 >
  • torch.nn >
  • torch.nn.utils.parametrize.cached
快捷键

torch.nn.utils.parametrize.cached

torch.nn.utils.parametrize.cached()[source][source]

允许在注册的参数化中启用缓存系统的上下文管理器。

当此上下文管理器处于活动状态时,参数化对象的值在第一次需要时计算并缓存。离开上下文管理器时,将丢弃缓存的值。

这在正向传递中多次使用参数化参数时很有用。例如,当参数化 RNN 的循环内核或共享权重时。

激活缓存的简单方法是通过将神经网络的正向传递部分进行包装

import torch.nn.utils.parametrize as P
...
with P.cached():
    output = model(inputs)

在训练和评估过程中。还可以将使用多次参数化张量的模块部分进行包装。例如,具有参数化循环核的 RNN 循环:

with P.cached():
    for x in xs:
        out_rnn = self.rnn_cell(x, out_rnn)

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源