跳过模块参数初始化 ¶
创建时间:2025 年 4 月 1 日 | 最后更新时间:2025 年 4 月 1 日 | 最后验证:未验证
简介
当创建模块时,其可学习参数将根据与模块类型关联的默认初始化方案进行初始化。例如,对于 torch.nn.Linear
模块的权重参数,其初始化来自均匀分布(-1/sqrt(in_features), 1/sqrt(in_features))。如果需要其他初始化方案,传统上需要在模块实例化后重新初始化参数:
from torch import nn
# Initializes weight from the default distribution: uniform(-1/sqrt(10), 1/sqrt(10)).
m = nn.Linear(10, 5)
# Re-initialize weight from a different distribution.
nn.init.orthogonal_(m.weight)
在这种情况下,在构造过程中进行的初始化是浪费的计算,如果权重参数很大,可能相当复杂。
跳过初始化 ¶
现在可以在模块构建期间跳过参数初始化,避免不必要的计算。这可以通过使用 torch.nn.utils.skip_init()
函数轻松实现:
from torch import nn
from torch.nn.utils import skip_init
m = skip_init(nn.Linear, 10, 5)
# Example: Do custom, non-default parameter initialization.
nn.init.orthogonal_(m.weight)
这可以应用于任何满足以下“更新模块以支持跳过初始化”部分所述条件的模块。请注意,torch.nn 提供的所有模块都满足这些条件,因此支持跳过初始化。
更新模块以支持跳过初始化 ¶
由于 torch.nn.utils.skip_init()
的实现方式(见实现细节),一个模块要与其功能兼容,必须满足两个要求。您可以通过遵守这些要求来为您的自定义模块启用参数初始化跳过功能:
1. 模块必须在构造函数中接受一个设备 kwargs 参数,并将其传递给在构造过程中创建的任何参数或缓冲区。
2. 模块在构造函数中不得对参数或缓冲区进行任何计算,除了初始化(即 torch.nn.init 中的函数)。
以下示例演示了一个更新后的模块,该模块支持设备 kwargs 参数,通过将其传递给任何创建的参数、缓冲区或子模块:
import torch
from torch import nn
class MyModule(torch.nn.Module):
def __init__(self, foo, bar, device=None):
super().__init__()
# ==== Case 1: Module creates parameters directly. ====
# Pass device along to any created parameters.
self.param1 = nn.Parameter(torch.empty((foo, bar), device=device))
self.register_parameter('param2', nn.Parameter(torch.empty(bar, device=device)))
# To ensure support for the meta device, avoid using ops except those in
# torch.nn.init on parameters in your module's constructor.
with torch.no_grad():
nn.init.kaiming_uniform_(self.param1)
nn.init.uniform_(self.param2)
# ==== Case 2: Module creates submodules. ====
# Pass device along recursively. All submodules will need to support
# them as well; this is the case for all torch.nn provided modules.
self.fc = nn.Linear(bar, 5, device=device)
# This also works with containers.
self.linears = nn.Sequential(
nn.Linear(5, 5, device=device),
nn.Linear(5, 1, device=device)
)
# ==== Case 3: Module creates buffers. ====
# Pass device along during buffer tensor creation.
self.register_buffer('some_buffer', torch.ones(7, device=device))
...
实现细节
在幕后, torch.nn.utils.skip_init()
函数通过两步模式实现:
# 1. Initialize module on the meta device; all torch.nn.init ops have
# no-op behavior on the meta device.
m = nn.Linear(10, 5, device='meta')
# 2. Materialize an uninitialized (empty) form of the module on the CPU device.
# The result of this is a module instance with uninitialized parameters.
m.to_empty(device='cpu')
它通过在“meta”设备上实例化模块来工作,该设备具有张量形状信息但不分配任何存储。torch.nn.init 操作为此元设备特别实现,使其具有无操作行为。这导致参数初始化逻辑实际上被跳过。
注意,此模式仅适用于在构建时正确支持设备 kwarg 的模块,如更新模块以支持跳过初始化中所述。