快捷键

torch.nn.functional.gumbel_softmax

torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=- 1)[source][source]

从 Gumbel-Softmax 分布中采样(链接 1 链接 2)并可选择进行离散化。

参数:
  • logits (Tensor) – […, num_features] 未归一化的对数概率

  • tau (浮点数) – 非负标量温度

  • hard (布尔值) – 如果 True ,则返回的样本将被离散化为独热向量,但在 autograd 中将被区分对待,就像它是软样本一样

  • dim (整数) – softmax 将要计算的维度。默认值:-1。

返回值:

与 logits 形状相同的 Gumbel-Softmax 分布的采样张量。如果 hard=True ,则返回的样本将是一维热向量,否则它们将是概率分布,沿 dim 维度求和为 1。

返回类型:

张量

注意

此功能出于历史原因存在,未来可能会从 nn.Functional 中移除。

注意

硬化操作的主要技巧是执行 y_hard - y_soft.detach() + y_soft

实现了两件事:- 使输出值正好是 one-hot(因为我们添加然后减去 y_soft 值)- 使梯度等于 y_soft 梯度(因为我们去除了所有其他梯度)

示例::
>>> logits = torch.randn(20, 32)
>>> # Sample soft categorical using reparametrization trick:
>>> F.gumbel_softmax(logits, tau=1, hard=False)
>>> # Sample hard categorical using "Straight-through" trick:
>>> F.gumbel_softmax(logits, tau=1, hard=True)

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源