• 文档 >
  • 模块代码 >
  • torch >
  • torch.nn.utils.parametrize
快捷键

torch.nn.utils.parametrize 的源代码

# mypy: 允许未类型化装饰器
# mypy: 允许未类型化定义
导入 集合
导入 copyreg
来自 collections.abc 导入 序列
来自 contextlib 导入 contextmanager
来自 复制 导入 深拷贝
来自 打字 导入 可选, 联合

导入 火炬
来自 火炬 导入 张量
来自 torch.__future__ 导入 转换时获取交换模块参数
来自 torch.nn.modules.container 导入 模块, ModuleDict, 模块列表
来自 torch.nn 参数 导入 参数
来自 torch.utils._python_dispatch 导入 可追踪包装子类


__all__ = [
    缓存,
    参数化列表,
    "注册参数化",
    "是否参数化",
    "移除参数化",
    "参数化之前的类型",
    "参数化及参数传递",
]

缓存已启用 = 0
缓存: 字典[元组[int, 字符串] 可选[张量]] = {}


[文档]上下文管理器 def 缓存(): r"""上下文管理器,用于启用与 :func:`register_parametrization` 注册的参数化注册的缓存系统。""" 参数化对象的值在第一次需要时进行计算和缓存。当上下文管理器处于活动状态时。 缓存的值在离开上下文管理器时将被丢弃。 当上下文管理器处于活动状态时,它们是必需的。当上下文管理器处于活动状态时。 这在多次前向传递中使用参数化参数时很有用。 例如,当对 RNN 的循环核进行参数化或共享权重时。 激活缓存的简单方法是将神经网络的正向传递包装起来。 激活缓存的简单方法是将神经网络的正向传递包装起来。 .. 代码块 :: python 导入 torch.nn.utils.parametrize 作为 P ... 使用 P.cached() output = 模型(inputs) 在训练和评估中。也可以将模块中使用到的参数化张量的部分进行包装。 例如,具有参数化循环核的 RNN 的循环可以多次执行。 参数化循环核: .. 代码块 :: python 使用 P.cached() 对于 x in xs: out_rnn = self.rnn_cell(x, out_rnn) """ 全局变量 _cache 全局变量 _cache_enabled _cache_enabled += 1 try: yield finally: _cache_enabled -= 1 如果没有 _cache_enabled: _cache = {}
def _register_parameter_or_buffer(模块
, 名称, X): 如果 isinstance(X, 参数): 模块.注册参数(名称, X) 否则: 模块.注册缓冲区(名称, X) def _maybe_set(目标: 张量, : 张量) -> : 应该交换 = ( 在转换时获取交换模块参数() 可追踪包装子类(目标) ) 如果 应该交换: 如果 isinstance(目标, 参数) not isinstance(, 参数): = 参数(, 需要梯度=目标.需要梯度) 火炬.工具.交换张量(目标, ) 否则: 目标.集合() # type: ignore[call-overload]
[文档] 参数化列表(模块列表): r"""一个按顺序存储和管理参数化 :class:`torch.nn.Module` 的原始参数或缓冲区的容器。 这是 `module.parametrizations[tensor_name]` 的类型,当 `module[tensor_name]` 使用 `:func:`register_parametrization` 进行参数化时。 如果第一个注册的参数化具有返回一个张量的 `right_inverse` 或 没有具有 `right_inverse`(在这种情况下,我们假设 `right_inverse` 是恒等函数), 它将使用 ``original`` 名称存储张量。 如果它有一个返回多个张量的 ``right_inverse``,这些张量将被注册为 ``original0``、``original1``、... .. 警告:: 此类在 :func:`register_parametrization` 内部使用。它有文档说明 为了完整性,它不应被用户实例化。 Args: 模块(序列):表示参数化的模块序列 原始(参数或张量):被参数化的参数或缓冲区 不安全(布尔值):表示参数化是否安全的布尔标志 可能会改变张量的数据类型和形状。默认:`False` 警告:注册时不会检查参数化的一致性。 请自行承担启用此标志的风险。 "源代码" 原始: 张量 不安全: 布尔类型 def 初始化( self, 模块: 序列[模块] 原始: 联盟[张量, 参数] 不安全: 布尔类型 = 错误, ) -> : 我们需要这样要求,因为我们需要区分不同的第一次参数化。 # 此类不应抛出异常,除非从外部使用此类 如果 长度(模块) == 0: 提升 ValueError("参数化列表需要一个或多个模块。") 超级().初始化(模块) self.不安全 = 不安全 简单的话: # module.weight 必须保持其 dtype 和 shape。 # 此外,如果没有 right_inverse 或者 right_inverse 返回一个 tensor, # 这应该与原始 tensor 具有相同的 dtype # 我们检查以下不变量是否成立: X = 模块权重 Y = param.right_inverse(X) 断言 isinstance(Y, Tensor) 或 # (isinstance(Y, collections.abc.Sequence) and all(isinstance(t, Tensor) for t in Y)) # Z = param(Y) if isinstance(Y, Tensor) else param(*Y) # # 一致性检查 # # 断言 X.dtype == Z.dtype and X.shape == Z.shape # # 允许使用 set_ 来能够 # # 将数据移动到/从原始张量而不改变其 id(优化器使用它来跟踪参数) # # 如果 Y 是 Tensor 类型 # # 如果 Y 是张量 # 断言 X 的数据类型等于 Y 的数据类型 # 以下我们使用 original = X, new = Y # 原始形状 = 原始.形状 原始数据类型 = 原始.dtype # 计算新的 火炬.不梯度(): = 原始 模块 反转(self): # type: ignore[call-overload] 如果 有属性(模块, "右逆"): try: = 模块.右逆() # type: ignore[operator] 除了 不支持的操作异常: 通过 # 否则,或者如果它抛出异常,我们假设右逆是恒等变换 如果 not isinstance(, 张量) not isinstance(, 序列): 提升 ValueError( "'right_inverse'必须返回一个 Tensor 或一个 Tensor 序列(列表、元组等)。" f"已获取"{类型().__name__}" ) # 设置原始张量的数量 self.is_tensor = isinstance(, 张量) self.ntensors = 1 如果 self.is_tensor 否则 长度() # 注册张量(们) 如果 self.is_tensor: 如果 原始.dtype != .数据类型: 提升 ValueError( "当 `right_inverse` 输出一个张量时,它可能不会改变数据类型。输入文本翻译为简体中文为:\n" f"original.dtype: {原始.数据类型}输入文本翻译为简体中文为:\n" fright_inverse(original).dtype:{.数据类型}" ) 将 original 设置为 original,以便用户无需在优化器中重新注册参数 手动设置 火炬.不梯度(): _maybe_set(原始, ) 注册参数或缓冲区(self, 原始, 原始) 否则: i, 原始 i 列举(): 如果 not isinstance(原始 i, 张量): 提升 ValueError( "'right_inverse'必须返回一个张量或张量序列 " "(列表、元组...)。" f"获取序列的元素 "{i}的类型 {类型(原始的).__name__} ) 如果原始张量是一个需要梯度的参数,我们期望用户 在注册参数化之后将新参数添加到优化器中 (这已在文档中说明) 如果 isinstance(原始, 参数): 原始的 = 参数(原始的, 原始.需要梯度) 原始的.需要梯度_(原始.需要梯度) 注册参数或缓冲区(self, f原始{i}", 原始 i) 如果 not self.不安全: # 一致性检查: # 由于 f : A -> B, 右逆 : B -> A, Z 和原始应位于 B 中 # Z = 原始矩阵的右逆矩阵 Z = self() 如果 not isinstance(Z, 张量): 提升 ValueError( f"参数化必须返回一个张量。得到"{类型(Z).__name__} ) 如果 Z.dtype != 原始数据类型: 提升 ValueError( "注册参数化可能不会改变张量的数据类型,除非启用 `unsafe` 标志。"输入文本翻译为简体中文为:\n" f"未参数化的数据类型:"{原始数据类型}输入文本翻译为简体中文为:\n" f"参数化的数据类型:"{Z.数据类型}" ) 如果 Z.形状 != 原始形状: 提升 ValueError( "注册参数化可能不会改变张量的形状,除非启用 `unsafe` 标志。"输入文本翻译为简体中文为:\n" f"未参数化的形状:"{original_shape}输入文本翻译为简体中文为:\n" f"参数化形状:"{Z.shape}" )
[文档] def 右逆(self, : 张量) -> : r调用参数化中的 `right_inverse` 方法,按照逆注册顺序。 然后,如果 `right_inverse` 输出一个张量,则将其存储在 `self.original` 中; 如果它输出多个,则存储在 `self.original0`、`self.original1`、... 中。 Args: 值(Tensor):初始化模块的值。 "源代码" 函数中所有的异常几乎永远不会抛出。 例如,如果 right_inverse 函数在给定不同输入时返回不同的数据类型,可能会抛出异常。 这通常可能是由于用户代码中的 bug 引起的。 # 这应该是由用户代码中的 bug 引起的。 火炬.不梯度(): # 请见 https://github.com/pytorch/pytorch/issues/53103 模块 反转(self): # type: ignore[call-overload] 如果 有属性(模块, "right_inverse"): value = 模块.右逆() # type: ignore[operator] 否则: 提升 运行时错误( f"参数化"{类型(模块).__name__}未实现 "右逆." ) 如果 self.is_tensor: # 这些异常只有在右逆函数不 # 返回相同的数据类型时才会抛出,这通常是由错误引起的 如果 not isinstance(, 张量): 提升 ValueError( f"`右逆` 应返回一个张量。得到{类型().__name__}" ) 如果 .dtype != self.原始.数据类型: 提升 ValueError( f`right_inverse` 返回的张量数据类型为{.数据类型} " f而 `original` 的数据类型为{self.原始.数据类型}" ) 我们知道结果将具有相同的数据类型 _maybe_set(self.原始, ) 否则: 如果 not isinstance(, 集合.abc.序列): 提升 ValueError( " 'right_inverse' 必须返回一个张量序列。 " f"已获取"{类型().__name__} ) 如果 长度() != self.累计张量: 提升 ValueError( "'right_inverse' 必须返回一个长度为的张量序列" f"{self.累计张量}. 获取长度为的序列{长度()} ) i, 张量 列举(): 原始_i = getattr(self, f"原始"{i}") 如果 not isinstance(张量, 张量): 提升 ValueError( f"`右逆`必须返回一个张量序列。" f"获取元素"{i}的类型{类型(张量).__name__}" ) 如果 原始_i.dtype != 张量.数据类型: 提升 ValueError( f张量{i}由 `right_inverse` 返回的具有数据类型{张量.数据类型} " f而 `original{i}`具有数据类型`{原始_i.数据类型}" ) _可能设置(原始_i, 张量)
def 前向(self) -> 张量: 如果 火炬.算子.是否正在脚本化(): 提升 运行时错误("参数化在脚本中不起作用。") # 解包第一次参数化的原始数据 如果 self.is_tensor: x = self[0]self.原始) 否则: 原始 = (getattr(self, f"原始"{i}") i 范围(self.张量)) x = self[0]*原始) 在这里无法调用 self[1:],所以我们需要稍微隐晦一些 我们想要跳过所有非整数键 当前索引 = 1 while 有属性(self, 字符串(curr_idx)): x = self[curr_idx]x) 当前索引 += 1 返回 x
def 注入新类(模块: 模块) -> : r设置一个可参数化的模块。 这通过用扩展它的类替换模块的类来实现 以便能够注入一个属性 Args: 模块(nn.Module):要注入属性的模块 "源代码" = 模块. def 默认深拷贝(self, 描述): # 当当前类中不存在 __deepcopy__ 方法时,仅模拟标准深拷贝过程。 对象 = 描述.获取(id(self), ) 如果 对象 not : 返回 对象 复制品 = self.__new__(self.) 描述[id(self] = 复制品 复制品.字典 = 深拷贝(self.字典, 描述) # 也保存所有存在的槽位。 slots_to_save = 复制注册._槽位名称(self.) # 类型: 忽略[attr-defined] 要保存的槽位: 如果 有属性(self, 插槽): setattr(复制品, 插槽, 深拷贝(getattr(self, 插槽), 描述)) 返回 复制品 def 获取状态(self): 提升 运行时错误( 模块参数化序列化仅 通过 state_dict()支持。参见:输入文本翻译为简体中文为:\n" https://maskerprc.github.io/tutorials/beginner/saving_loading_models.html "#保存或加载通用检查点以进行推理或恢复训练" ) dct = {"__getstate__": getstate} 我们不允许序列化参数化模块,但仍然应该允许深拷贝。 默认的 'deepcopy' 函数在存在的情况下调用 __deepcopy__ 方法而不是 __getstate__。 如果 not 有属性(, __deepcopy__): dct["__deepcopy__"] = 默认深拷贝 # 类型:忽略[赋值] 参数类 = 类型( f"参数化{.__name__}", (,), dct, ) 模块. = param_cls def _inject_property(模块: 模块, tensor_name: 字符串) -> : r将属性注入到 module[tensor_name]中。 它假设模块中的类已经通过_inject_new_class 从原始类进行了修改, 并且在:attr:`tensor_name`下的张量已经被移出。 已经被移出。 Args: 模块(nn.Module):注入属性的模块 tensor_name (str):要创建的属性的名称 "源代码" # 检查前置条件。 # 如果 register_parametrization 实现正确,则此情况永远不会触发。 断言 not 有属性(模块, 张量名称) @torch.算子.未使用 def 获取缓存的参数化(参数化) -> 张量: 全局 _缓存 key = (id(模块), 张量名称) 张量 = _缓存.获取() 如果 张量 : 张量 = 参数化() _缓存[] = 张量 返回 张量 def 获取参数化(self) -> 张量: 如果 火炬.算子.是否正在脚本化(): 提升 运行时错误("参数化在脚本中不起作用。") 参数化 = self.参数化(复数)[tensor_name] 如果 _缓存启用: 如果 火炬.算子.是否正在脚本化(): # 脚本 提升 运行时错误( "缓存未实现于脚本。" "请禁用缓存或避免使用脚本。" ) elif 火炬._C._获取追踪状态() not : 跟踪 提升 运行时错误( "无法在缓存参数化时跟踪模型。" ) 否则: 返回 获取缓存的参数化(参数化) 否则: 如果缓存未激活,此函数仅评估参数化 返回 参数化() def 设置原始值(self, : 张量) -> : 如果 火炬.算子.是否正在脚本化(): 提升 运行时错误("参数化在脚本中不起作用。") self.参数化[张量名称].右逆() setattr(模块., 张量名称, 属性(获取参数化, 设置原始))
[文档]def 注册参数化( 模块: 模块, tensor 名称: 字符串, 参数化: 模块, *, 不安全: 布尔类型 = 错误, ) -> 模块: r将参数化注册到模块中的张量。 假设为了简单起见,``tensor_name="weight"``。当访问``module.weight``, 模块将返回参数化版本 `parametrization(module.weight)`。 如果原始张量需要梯度,反向传播将进行微分 通过:attr:`参数化`,优化器将相应地更新张量。 第一次模块注册参数化时,此函数将添加一个属性 模块类型 :class:`~ParametrizationList` 的 ``parametrizations`` 参数化。 张量 ``weight`` 上的参数化列表将在以下位置访问: ``module.parametrizations.weight``。 原始张量将在以下位置访问: ``module.parametrizations.weight.original``. 参数化可以通过在相同属性上注册多个参数化进行连接。 已注册的参数化的训练模式在注册时更新。 The training mode of a registered parametrization is updated on registration 与主机模块的训练模式相匹配 参数化和缓冲区具有内置的缓存系统,可以激活 使用上下文管理器:func:`cached`。 attr:`参数化` 可选实现一个具有签名的 .. 代码块 :: python def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]] 此方法在未参数化的张量上调用,当第一个参数化注册以计算原始张量的初始值时。 如果此方法未实现,则原始张量将只是未参数化的张量。 如果此方法未实现,则原始张量将保持为未参数化的张量。 如果一个张量上注册的所有参数化都实现了 `right_inverse`,则可以初始化一个参数化张量,如下例所示。 可以通过将以下内容分配给它来初始化一个参数化张量,如下例所示。 第一个参数化可能依赖于多个输入。 这可以通过从 `right_inverse` 返回一个张量元组来实现。 (以下为 ``RankOne`` 参数化的示例实现)。 在这种情况下,未受约束的张量也位于 ``module.parametrizations.weight`` 下。 命名为 ``original0``、``original1``、...。 .. 注意:: 如果 unsafe=False(默认值)则将调用 forward 和 right_inverse 方法。 一次执行多个一致性检查。 如果 unsafe=True,则当张量未参数化时将调用 right_inverse,否则不会调用。 在大多数情况下,right_inverse 将是一个函数,使得 .. 注意:: 在大多数情况下,`right_inverse` 将是一个函数,使得 ``forward(right_inverse(X)) == X``(见 `右逆函数 `_)。 有时,当参数化不是满射时,放宽这一点可能是合理的。 .. 警告:: 如果参数化依赖于多个输入,:func:`~register_parametrization` 将注册多个新参数。如果此类参数化已注册 在优化器创建后,这些新参数需要手动添加 将优化器传递进去。参见::meth:`torch.Optimizer.add_param_group`。 Args: nn.Module 模块:注册参数化的模块 tensor_name (str):注册参数或缓冲区的名称 参数化 nn.Module 参数化:要注册的参数化 关键字参数: 不安全(布尔值):表示参数化是否不安全的布尔标志 可能改变张量的数据类型和形状。默认:`False` 警告:注册时未检查参数化的一致性。 启用此标志风险自负。 抛出异常: ValueError:如果模块没有名为 :attr:`tensor_name` 的参数或缓冲区 示例: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) >>> 导入 torch >>> 从 torch.nn 导入 nn >>> 导入 torch.nn.utils.parametrize 作为 P ... >>> class Symmetric(nn.Module): >>> def forward(self, X): >>> return X.triu() + X.triu(1).T # 返回一个对称矩阵 ... >>> def right_inverse(self, A): >>> return A.triu() ... >>> m = nn.Linear(5, 5) >>> P.register_parametrization(m, "weight", Symmetric()) >>> print(torch.allclose(m.weight, m.weight.T)) # m.weight 现在是对称的 True >>> A = torch.rand(5, 5) >>> A = A + A.T # A 现在是对称的 >>> m.weight = A # 将权重初始化为对称矩阵 A >>> print(torch.allclose(m.weight, A)) True >>> class RankOne(nn.Module): >>> def forward(self, x, y): >>> # 通过两个向量的乘积形成一个秩为 1 的矩阵 >>> 返回 x.unsqueeze(-1) @ y.unsqueeze(-2) ... >>> def right_inverse(self, Z): >>> # 将 Z 投影到秩为 1 的矩阵上 >>> U, S, Vh = torch.linalg.svd(Z, full_matrices=False) >>> # 返回缩放的单个向量 >>> s0_sqrt = S[0].sqrt().unsqueeze(-1) >>> return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt ... >>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne()) >>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item()) 1 "源代码" 参数化.训练(模块.训练) 如果 是否参数化(模块, 张量名称): 正确性检查。 如果 A 是形状和 dtype 等于 module.weight 的张量空间 我们检查 parametrization.forward 和 parametrization.right_inverse 是否是 从 A 到 A 的函数 如果 not 不安全: Y = getattr(模块, 张量名称) X = 参数化(Y) 如果 not isinstance(X, 张量): 提升 ValueError( f"参数化必须返回一个张量。得到了"{类型(X).__name__} ) 如果 X.dtype != Y.数据类型: 提升 ValueError( "注册参数化可能不会改变张量的数据类型,除非启用了`unsafe`标志。输入文本翻译为简体中文为:\n" f"模块。{张量名称}.dtype: {Y.数据类型}输入文本翻译为简体中文为:\n" f"参数化(模块。{tensor_name}).dtype: {X.数据类型}" ) 如果 X.形状 != Y.shape: 提升 ValueError( 注册参数化可能不会改变张量的形状,除非启用`unsafe`标志。输入文本翻译为简体中文为:\n" f模块。{tensor_name}.shape: {Y.shape}输入文本翻译为简体中文为:\n" f"模块参数化(module."{张量名称}).shape: "{X.shape}" ) 如果 有属性(参数化, "右逆"): try: Z = 参数化.右逆(X) # type: ignore[operator] 除了 不支持的操作异常: 通过 否则: 如果 not isinstance(Z, 张量): 提升 ValueError( f"parametrization.right_inverse 必须返回一个张量。实际得到:"{类型(Z).__name__}" ) 如果 Z.dtype != Y.数据类型: 提升 ValueError( "parametrization.right_inverse 返回的张量必须具有相同的 dtype" f作为模块。{tensor_name},除非启用`unsafe`标志。输入文本翻译为简体中文为:\n" f"模块。{tensor 名称}.数据类型:{Y.数据类型}输入文本翻译为简体中文为:\n" f"返回的数据类型:{Z.数据类型}" ) 如果 Z.形状 != Y.shape: 提升 ValueError( "参数化.right_inverse 返回的张量必须具有相同的形状 " f"作为模块。{tensor_name},除非启用 `unsafe` 标志。输入文本翻译为简体中文为:\n" f模块。{张量名称}.形状:{Y.shape}输入文本翻译为简体中文为:\n" f"返回形状:{Z.shape}" ) # else 右逆元假定为恒等元 # 将新的参数化添加到参数化列表中 断言 isinstance(模块.参数化, ModuleDict) # 让 mypy 开心 模块.参数化[张量名称].追加(参数化) # 如果前一个参数化中的 unsafe 为 True,则保持启用 模块.参数化[张量名称].不安全 |= 不安全 忽略[索引,联合属性] elif 张量名称 模块._缓冲区 张量名称 模块.参数: 设置参数化机制 获取原始缓冲区或参数 原始 = getattr(模块, 张量名称) 我们提前创建这个以检查可能的错误 参数化 = 参数化列表( [参数化] 原始, 不安全=不安全 ) 删除之前的参数或缓冲区 delattr(模块, 张量名称) 如果这是模块上注册的第一个参数化 我们准备注入属性模块 如果 not 已参数化(模块): 修改类 注入新类(模块) 将 ``ModuleDict`` 注入到 module.parametrizations 实例下 模块.parametrizations = ModuleDict() 向类中添加一个属性 _inject_property(模块, tensor_name) # 添加参数化列表 断言 isinstance(模块.参数化, ModuleDict) # 让 mypy 开心 模块.参数化[张量名称] = 参数化 否则: 提升 ValueError( f模块{模块}'没有参数,缓冲区或“ f"具有名称的参数化元素"{tensor_name} ) 返回 模块
[docs]def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool: r"""判断一个模块是否有参数化。 Args: 模块(nn.Module):查询模块 tensor_name(str,可选):模块中的参数名称 默认值:`None` 返回值: “如果:attr:`module` 对名为:attr:`tensor_name` 的参数有参数化,则返回 ``True``, 或者当:attr:`tensor_name` 为 ``None`` 时,它有任何参数化; 否则返回 ``False``, """ parametrizations = getattr(module, "parametrizations", None) if parametrizations is None or not isinstance(parametrizations, ModuleDict): return False if tensor_name is None: 检查是否存在至少一个参数化缓冲区或参数 返回 parametrizations 的长度大于 0 else: 返回 tensor_name 是否在 parametrizations 中
[文档]def 移除参数化( 模块: 模块, 张量名称: 字符串, 保留参数化: 布尔类型 = True, ) -> 模块: r在模块中移除张量上的参数化。 - 如果 `leave_parametrized=True`,则 `module[tensor_name]` 将被设置为 其当前输出。在这种情况下,参数化不应更改 `dtype` 张量。 - 如果 `leave_parametrized=False`,则 `module[tensor_name]` 将被设置为 ``module.parametrizations[tensor_name].original``中的未参数化张量。 只有当参数化仅依赖于一个张量时,这才能实现。 Args: module (nn.Module):从中移除参数化的模块。 tensor_name (str):要移除的参数化名称。 leave_parametrized (bool, 可选): 保持 :attr:`tensor_name` 属性参数化。 默认:``True`` 返回: 模块: 模块 抛出异常: ValueError: 如果 ``module[tensor_name]`` 没有参数化 ValueError: 如果 ``leave_parametrized=False`` 并且参数化依赖于多个张量 "源代码" 如果 not 已参数化(模块, 张量名称): 提升 ValueError( f模块{模块}没有参数化在{张量名称}" ) 获取原始张量 断言 isinstance(模块.参数化, ModuleDict) 让 mypy 开心 参数化 = 模块.参数化[张量名称] 如果 参数化.is_tensor: 原始 = 参数化.原始 如果 留作参数化: 火炬.不梯度(): t = getattr(模块, tensor_name) 我们知道它们具有相同的 dtype,因为我们注册时已经检查过这一点 因此,我们可以使用 set_ 我们这样做是为了确保参数不会改变 id() 这种方式用户不需要更新优化器 火炬.不梯度(): 如果 类型(原始) 火炬.张量: 可能设置(原始, t) 否则: try: 可能设置(原始, t) 除了 运行时错误 作为 e: # TODO: 修复此问题,以使 tensor 子类成为参数: # 运行时错误:不允许在从 .data 或 .detach() 创建的 Tensor 上使用 set_storage()。 提升 运行时错误( "调用 remove_parametrizations() 并设置 leave_parametrized=True " "对于是一个张量子类的实例的参数,需要正确实现 set_() 方法。" "或者,可以选择进入 swap_tensors 路径。" "或者,可以选择进入 swap_tensors 路径。" "要么设置 leave_parametrized=False,要么提供一个有效的实现" "在张量子类中的 set_() 或设置 " "torch.__future__.set_swap_module_params_on_conversion(True)." ) 来自 e 否则: 如果 leave_parametrized: 我们不能使用 no_grad,因为我们需要知道是否有一个或多个原始张量需要梯度 我们将不得不相信用户将其添加到优化器中 t = getattr(模块, 张量名称) 我们将不得不相信用户将其添加到优化器中 原始 = 参数(t) 如果 t.requires_grad 否则 t 否则: 提升 ValueError( "无法不参数化(`leave_parametrized=False`)一个以张量序列为参数的张量" "该张量是以张量序列为参数进行参数化的。" ) # 删除管理参数化的属性 delattr(模块., tensor_name) # 删除参数化列表 删除 模块.参数化[tensor_name] 将参数/缓冲区恢复到主类中 注册参数或缓冲区(模块, 张量名称, 原始) # 将参数化类回滚,如果没有其他缓冲区或参数 # 当前此类中已参数化 如果 not is_parametrized(模块): delattr(模块, 参数化) 恢复类 原类 = 模块..__基类__[0] 模块. = 原类 返回 模块
def 参数化前的类型(模块: 模块) -> 类型: r返回在应用参数化之前模块的类型,如果没有,则返回模块类型。 Args: 模块 (nn.Module):获取类型用的模块 "源代码" 如果 已参数化(模块): 返回 模块..__基类__[0] 否则: 返回 类型(模块) def 转移参数化和参数( from_module: 模块, 到模块: 模块, 张量名称: 可选[字符串] = , ) -> 模块: r将参数化及其参数从 :attr:`from_module` 转移到 :attr:`to_module`。 如果指定了 :attr:`tensor_name`,则只转移指定的参数,否则 转移所有参数化参数。如果这些参数在 to_module 中不存在,则会创建它们。 如果 from_module 没有参数化,则不执行任何操作。 Args: 从模块(nn.Module):要转移的模块 到模块(nn.Module):要转移到的模块 tensor_name (str, 可选):要转移的参数 返回: 模块:到模块 "源代码" 如果 已参数化(from_module): 断言 isinstance(from_module.参数化, ModuleDict) # 用于 mypy # 获取所有参数的列表或单个要传递的参数 需要传递的参数: 联盟[列表, ModuleDict] = ( from_module.参数化 如果 张量名称 否则 [张量名称] ) 断言 有属性(需要传递的参数, "__iter__") # for mypy 参数名称 需要传递的参数: 初始化 to_module 中尚未存在的待传输参数 如果 not 有属性(到模块, 参数名称): setattr( 到模块, 参数名称, 参数(getattr(from_module, 参数名称)), ) 将 params 的参数化应用到 to_module param_func from_module.参数化[参数名]: 注册参数化(到模块, 参数名, 参数函数) 断言 isinstance(到模块.参数化, ModuleDict) # for mypy # 使值匹配,原始值可以存储在原始或 # original0, original1...,需要检查这两种情况 如果 有属性(from_module.参数化[参数名] "原始"): 到模块.参数化[ 参数名称 ].原始 = from_module.参数化[参数名称].原始 否则: 数字 = 0 原始编号 = "原始" + 字符串(数字) 遍历每个原始#直到所有值都已设置 while 有属性(from_module.参数化[参数名] 原数): setattr( 到模块.参数化[参数名称] 原始数, getattr(from_module.参数化[参数名称] 原始数字), ) 数字 = 数字 + 1 原始数字 = "原始" + 字符串(数字) 返回 到模块

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

查看 PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源,获取您的疑问解答

查看资源