• 文档 >
  • torch.nn >
  • torch.nn.utils.parametrize.register_parametrization
快捷键

torch.nn.utils.parametrize.register_parametrization

torch.nn.utils.parametrize.register_parametrization(module, tensor_name, parametrization, *, unsafe=False)[source][source]

在模块中注册一个张量的参数化。

假设为简便起见使用 tensor_name="weight" 。当访问 module.weight 时,模块将返回参数化版本 parametrization(module.weight) 。如果原始张量需要梯度,反向传播将通过 parametrization 进行微分,并且优化器将相应地更新张量。

当模块首次注册参数化时,此函数将为模块添加一个类型为 ParametrizationList 的属性 parametrizations

张量 weight 上的参数化列表可通过 module.parametrizations.weight 访问。

原始张量可通过 module.parametrizations.weight.original 访问。

参数化可以通过在相同属性上注册多个参数化进行连接。

已注册的参数化的训练模式在注册时更新,以匹配宿主模块的训练模式。

参数化参数和缓冲区具有内置的缓存系统,可以使用上下文管理器 cached() 激活。

parametrization 可以可选地实现一个具有以下签名的函数。

def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]

当第一次注册参数化时,该方法会在未参数化的张量上调用,以计算原始张量的初始值。如果此方法未实现,则原始张量将保持为未参数化的张量。

如果注册在张量上的所有参数化都实现了 right_inverse,则可以通过将其赋值给参数化张量来初始化它,如下例所示。

首次参数化可能依赖于多个输入。这可以通过从 right_inverse 返回张量元组来实现,如下面的 RankOne 参数化示例实现所示。

在这种情况下,无约束张量也位于 module.parametrizations.weight 下,名称为 original0original1 、……

注意

如果 unsafe=False(默认值),则将分别调用 forward 和 right_inverse 方法一次以执行一系列一致性检查。如果 unsafe=True,则当张量未参数化时将调用 right_inverse,否则不调用任何方法。

注意

在大多数情况下, right_inverse 将是一个函数,使得 forward(right_inverse(X)) == X (见右逆)。有时,当参数化不是满射时,放宽这个条件可能是合理的。

警告

如果一个参数化依赖于多个输入, register_parametrization() 将注册多个新参数。如果这种参数化是在创建优化器之后注册的,则需要手动将这些新参数添加到优化器中。参见 torch.Optimizer.add_param_group()

参数:
  • 模块(nn.Module)- 在其上注册参数化的模块

  • tensor_name (str) – 要注册参数化的参数或缓冲区的名称

  • parametrization (nn.Module) – 要注册的参数化

关键字参数:

unsafe (bool) – 一个布尔标志,表示参数化是否可以更改张量的数据类型和形状。默认:False 警告:注册时不会检查参数化的一致性。启用此标志需自行承担风险。

引发:

ValueError – 如果模块没有名为 tensor_name 的参数或缓冲区

返回类型:

模块

示例

>>> import torch
>>> import torch.nn as nn
>>> import torch.nn.utils.parametrize as P
>>>
>>> class Symmetric(nn.Module):
>>>     def forward(self, X):
>>>         return X.triu() + X.triu(1).T  # Return a symmetric matrix
>>>
>>>     def right_inverse(self, A):
>>>         return A.triu()
>>>
>>> m = nn.Linear(5, 5)
>>> P.register_parametrization(m, "weight", Symmetric())
>>> print(torch.allclose(m.weight, m.weight.T))  # m.weight is now symmetric
True
>>> A = torch.rand(5, 5)
>>> A = A + A.T   # A is now symmetric
>>> m.weight = A  # Initialize the weight to be the symmetric matrix A
>>> print(torch.allclose(m.weight, A))
True
>>> class RankOne(nn.Module):
>>>     def forward(self, x, y):
>>>         # Form a rank 1 matrix multiplying two vectors
>>>         return x.unsqueeze(-1) @ y.unsqueeze(-2)
>>>
>>>     def right_inverse(self, Z):
>>>         # Project Z onto the rank 1 matrices
>>>         U, S, Vh = torch.linalg.svd(Z, full_matrices=False)
>>>         # Return rescaled singular vectors
>>>         s0_sqrt = S[0].sqrt().unsqueeze(-1)
>>>         return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt
>>>
>>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne())
>>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item())
1

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源