• 文档 >
  • torch >
  • torch.multinomial
快捷键

torch.multinomial

torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) LongTensor

返回一个张量,其中每行包含从多项式(更严格的定义是多元的,请参阅 torch.distributions.multinomial.Multinomial 获取更多详情)概率分布中采样的 num_samples 索引。

注意

input 的行不需要求和为 1(在这种情况下,我们使用这些值作为权重),但必须是非负的、有限的,并且总和不为零。

索引按照每个索引被采样的时间顺序排列(首先采样的样本放在第一列)。

如果 input 是一个向量,则 out 是一个大小为 num_samples 的向量。

如果 input 是一个有 m 行的矩阵,则 out 是一个形状为 (m×num_samples)(m \times \text{num\_samples}) 的矩阵。

如果替换为 True ,则进行有放回的抽样。

如果不是,则进行不放回抽取,这意味着当抽取一个样本索引后,该索引不能再次用于该行。

注意

不放回抽取时, num_samples 必须小于 input 中的非零元素数量(或 input 每行的最小非零元素数量,如果它是一个矩阵)。

参数:
  • 输入(张量)- 包含概率的输入张量

  • num_samples(整数)- 要抽取的样本数量

  • 替换(布尔值,可选)- 是否进行有放回抽取

关键字参数:
  • 生成器( torch.Generator ,可选)- 用于采样的伪随机数生成器

  • 输出(张量,可选)- 输出张量。

示例:

>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights
>>> torch.multinomial(weights, 2)
tensor([1, 2])
>>> torch.multinomial(weights, 5) # ERROR!
RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement
>>> torch.multinomial(weights, 4, replacement=True)
tensor([ 2,  1,  1,  1])

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源