• 文档 >
  • 模块代码 >
  • torch >
  • torch.distributions.exp_family
快捷键

torch.distributions.exp_family 的源代码

# mypy: 允许未类型化定义
导入 火炬
来自 火炬 导入 张量
来自 torch.distributions.distribution 导入 分布


全部 = [指数族]


[文档] 指数族(分发): r""" 指数族是指数族概率分布的抽象基类, 其概率质量/密度函数的形式如下定义 .. math:: p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x)) 其中 :math:`\theta` 表示自然参数,:math:`t(x)` 表示充分统计量, math:`F(\theta)` 是给定族的对数归一化函数,:math:`k(x)` 是载波测度。 量。 注意: 这个类是`Distribution`类和属于指数分布族的分布之间的中介,主要用于检查`.entropy()`和解析 KL 散度方法的正确性。 我们使用这个类来计算熵和 KL 散度,使用 AD 框架和 Bregman 散度(感谢:Frank Nielsen 和 Richard Nock,熵和散度)。 我们使用这个类来计算熵和 KL 散度,使用 AD 框架和 Bregman 散度(感谢:Frank Nielsen 和 Richard Nock,熵和散度)。 我们使用这个类来计算熵和 KL 散度,使用 AD 框架和 Bregman 散度(感谢:Frank Nielsen 和 Richard Nock,熵和散度)。 交叉熵指数族 """ @property def 自然参数(self) -> 元组[张量, ...] """ 自然参数的抽象方法。返回一个包含张量的元组 在分布 """ 抛出 未实现异常 def 对数归一化器(self, *自然参数): """ 日志归一化函数的抽象方法。返回基于分布和输入的日志归一化器。 的期望载波度量,这是计算所需的。 """ 抛出 未实现异常 @property def _mean_carrier_measure(self) -> float: """ 期望载波度量的抽象方法,这是计算所需的。 熵。 """ 抛出 未实现异常
[文档] def 熵(self): """ 使用对数归一化器的 Bregman 散度计算熵的方法。 """ result = -self._mean_carrier_measure nparams = [p.detach().requires_grad_() for p in self._natural_params] lg_normal = self._log_normalizer(*nparams) gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True) result += lg_normal for np, g in zip(nparams, gradients): result -= (np * g).reshape(self._batch_shape + (-1,)).sum(-1) 返回结果

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

查看 PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源,获取您的疑问解答

查看资源