• 文档 >
  • 模块代码 >
  • torch >
  • torch.distributions.constraint_registry
快捷键

torch.distributions.constraint_registry 的源代码

# mypy: 允许未类型化定义
r"""
PyTorch 提供了两个全局 :class:`ConstraintRegistry` 对象,它们链接
`torch.distributions.constraints.Constraint` 对象
`torch.distributions.transforms.Transform` 对象。这两个对象都
输入约束并返回变换,但它们在双射性方面有不同的保证。
双射性。

1. `biject_to(constraint)` 查找从 `constraints.real` 到给定 `constraint` 的双射映射
class:`~torch.distributions.transforms.Transform`
到给定的 `constraint`。返回的转换保证具有 `.bijective = True` 并且应该实现 `.log_abs_det_jacobian()`。
`.log_abs_det_jacobian()`。
`transform_to(constraint)` 查找从 `constraints.real` 到给定 `constraint` 的非必要双射
`:class:`~torch.distributions.transforms.Transform`
。返回的转换不保证实现 `.log_abs_det_jacobian()`。


transform_to() 注册表对于执行无约束操作很有用
优化概率分布的约束参数
指示每个分布的 `.arg_constraints` 字典。这些转换通常
过度参数化空间以避免旋转;因此它们更
适用于坐标优化算法如 Adam

    loc = torch.zeros(100, requires_grad=True)
    unconstrained = torch.zeros(100, requires_grad=True)
    scale = transform_to(Normal.arg_constraints["scale"])(unconstrained)
    loss = -Normal(loc, scale).log_prob(data).sum()

The ``biject_to()`` 注册表对于哈密顿蒙特卡洛方法很有用,
从具有约束 ``.support`` 的概率分布中抽取样本,
在非约束空间中传播,算法通常是旋转
不变项.::

距离 = 指数速率
无约束 = torch.zeros(100, requires_grad=True)
    sample = biject_to(dist.support)(unconstrained)
    potential_energy = -dist.log_prob(sample).sum()

.. 注意:

一个示例,其中 `transform_to` 和 `biject_to` 不同
`constraints.simplex`:`transform_to(constraints.simplex)` 返回一个
class:`~torch.distributions.transforms.SoftmaxTransform`,它只是
指数化并归一化其输入;这是一个便宜且主要
坐标-wise 操作适用于 SVI 等算法。
对比,`biject_to(constraints.simplex)` 返回
`torch.distributions.transforms.StickBreakingTransform` 类
将输入映射到一个维度更低的子空间;这是一个更昂贵的、数值稳定性较差的变换,但对于像 HMC 这样的算法是必需的。
这是一个昂贵的、数值稳定性较差的变换,但对于像 HMC 这样的算法是必需的。
对于像 HMC 这样的算法是必需的。

``biject_to`` 和 ``transform_to`` 对象可以被用户自定义扩展。
使用它们的 ``.register()`` 方法注册约束和转换,既可以作为单例约束的函数:
或者作为参数化约束的装饰器:

    transform_to.register(my_constraint, my_transform)

或作为参数化约束的装饰器:

    @transform_to.register(MyConstraintClass)
    def my_factory(constraint):
        assert isinstance(constraint, MyConstraintClass)
        return MyTransform(constraint.param1, constraint.param2)

您可以通过创建一个新的 :class:`ConstraintRegistry` 来创建自己的注册表
对象.
"""

来自 torch.distributions 导入 约束, 转换
来自 torch 的类型 导入 数值


全部 = [
    "约束注册表",
    "双射到",
    "转换成",
]


[文档] 约束注册表: """ 将约束链接到转换的注册表。 "沉浸式翻译" 定义 __init__(): ._注册表 = {} 超级().__init__()
[文档] def register(self, constraint, factory=None): """ 在此注册表中注册一个 :class:`~torch.distributions.constraints.Constraint` 子类。用法示例:: 子类。用法示例:: @my_registry.register(MyConstraintClass) def construct_transform(constraint): assert isinstance(constraint, MyConstraint) return MyTransform(constraint.arg_constraints) 参数: 约束(:class:`~torch.distributions.constraints.Constraint`子类): 一个 :class:`~torch.distributions.constraints.Constraint` 的子类,或 一个所需类的单例对象。 工厂(Callable):输入约束对象并返回的可调用对象。 一个 `:class:`~torch.distributions.transforms.Transform` 对象。 """ # 支持用作装饰器。 如果工厂为空: 返回 lambda 工厂: self.register(constraint, factory) # 支持在单例实例上调用。 如果 isinstance(constraint, constraints.Constraint): constraint = type(constraint) 如果 constraint 不是 type 类型或者不是 constraints.Constraint 的子类 constraint, constraints.Constraint ): raise TypeError( 期望约束是约束子类或实例,但实际上得到了 {constraint} ) self._registry[constraint] = factory 返回工厂
定义
__调用__(, 约束): """ 根据约束对象查找约束空间中的变换。 用法: constraint = Normal.arg_constraints["scale"] scale = transform_to(constraint)(torch.zeros(1)) # constrained u = transform_to(constraint).inv(scale) # unconstrained 参数: constraint (:class:`~torch.distributions.constraints.Constraint`): 约束对象。 返回: 一个 `torch.distributions.transforms.Transform` 对象。 抛出异常: 如果未注册任何转换,则抛出 `NotImplementedError`。 "沉浸式翻译" 通过约束子类查找。 尝试: 工厂 = ._注册表[类型(约束)] 除了 键错误: 抛出 不支持的操作异常( f无法转换{类型(约束).__name__}约束" ) 来自 返回 工厂(约束)
映射到 = 约束注册表() 转换为 = 约束注册表() ################################################################################ # 注册表 ################################################################################ @biject_to.注册(约束.真实) @transform_to.注册(约束.真实) 定义 _转换为真实(约束): 返回 转换.身份转换 @biject_to.注册(约束.独立) 定义 双射到独立(约束): 基础转换 = 双射到(约束.基础约束) 返回 转换.独立转换( 基础转换, 约束.重新解释的批量维度数 ) @transform_to.注册(约束.独立) def _转换为独立(约束): 基础转换 = 转换为(约束.基础约束) 返回 转换.独立转换( 基础转换, 约束.重新解释的批量维度数 ) @biject_to.注册(约束.正的) @biject_to.注册(约束.非负的) @transform_to.注册(约束.正面) @transform_to.注册(约束.非负的) def 转换为正数(约束): 返回 转换.指数转换() @biject_to.注册(约束.大于) @biject_to.注册(约束.大于等于) @transform_to.注册(约束.大于) @transform_to.注册(约束.大于等于) 定义 _转换为大于(约束): 返回 转换.组合转换( [ 转换.深度转换(), 转换.变换(约束.下界, 1), ] ) @biject_to.注册(约束.小于) @transform_to.注册(约束.小于) def 转换为小于(约束): 返回 转换.组合转换( [ 转换.指数转换(), 转换.平移变换(约束.上限, -1), ] ) @biject_to.注册(约束.区间) @biject_to.注册(约束.半开区间) @transform_to.注册(约束.区间) @transform_to.注册(约束.半开区间) def _转换为区间(约束): # 处理单位区间的特殊情况。 下限为 0 = ( isinstance(约束.下界, ) 约束.下限 == 0 ) 上限为 1 = ( isinstance(约束.上限, ) 约束.上界 == 1 ) 如果 lower_is_0 upper_is_1: 返回 转换.Sigmoid 变换() 定位 = 约束.下限 缩放 = 约束.上界 - 约束.下限 返回 转换.组合变换( [变换.Sigmoid 变换(), 变换.变换矩阵(位置, 比例)] ) @biject_to.注册(约束.单形) 定义 双射到单纯形(约束): 返回 变换.StickBreakingTransform() @transform_to.注册(约束.单形) def 转换为单形(约束): 返回 转换.Softmax 转换() 定义 LowerCholeskyTransform 的映射 @transform_to.注册(约束.下三角 Cholesky) def _转换为下三角 Cholesky_(约束): 返回 转换.下三角 Cholesky 变换() @transform_to.注册(约束.正定) @transform_to.注册(约束.正半定) def 转换为正定(约束): 返回 转换.正定变换() @biject_to.注册(约束.相关乔列斯基) @transform_to.注册(约束.相关乔列斯基) def _转换到相关乔列斯基_(约束): 返回 转换.CorrCholeskyTransform() @biject_to.注册(约束.) def _biject_to_cat(约束): 返回 转换.猫转换( [双射到(c) for c 约束.cseq] 约束.暗淡, 约束.长度 ) @transform_to.注册(约束.) def _转换成猫(约束): 返回 转换.猫转换( [转换为(c) for c 约束.cseq] 约束.暗淡, 约束.长度 ) @biject_to.注册(约束.) def 双射到栈(约束): 返回 转换.栈转换( [双射到(c) for c 约束.c 序列] 约束.维度 ) @transform_to.注册(约束.) def _转换成堆栈(约束): 返回 转换.栈转换( [转换为(c) for c 约束.连续序列] 约束.维度 )

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

查看 PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源,获取您的疑问解答

查看资源