torch.multinomial¶
- torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) LongTensor ¶
返回一个张量,其中每行包含从多项式(更严格的定义是多元的,请参阅
torch.distributions.multinomial.Multinomial
获取更多详情)概率分布中采样的num_samples
索引。注意
input
的行不需要求和为 1(在这种情况下,我们使用这些值作为权重),但必须是非负的、有限的,并且总和不为零。索引按照每个索引被采样的时间顺序排列(首先采样的样本放在第一列)。
如果
input
是一个向量,则out
是一个大小为num_samples
的向量。如果
input
是一个有 m 行的矩阵,则out
是一个形状为 的矩阵。如果替换为
True
,则进行有放回的抽样。如果不是,则进行不放回抽取,这意味着当抽取一个样本索引后,该索引不能再次用于该行。
注意
不放回抽取时,
num_samples
必须小于input
中的非零元素数量(或input
每行的最小非零元素数量,如果它是一个矩阵)。- 参数:
输入(张量)- 包含概率的输入张量
num_samples(整数)- 要抽取的样本数量
替换(布尔值,可选)- 是否进行有放回抽取
- 关键字参数:
生成器(
torch.Generator
,可选)- 用于采样的伪随机数生成器输出(张量,可选)- 输出张量。
示例:
>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights >>> torch.multinomial(weights, 2) tensor([1, 2]) >>> torch.multinomial(weights, 5) # ERROR! RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement >>> torch.multinomial(weights, 4, replacement=True) tensor([ 2, 1, 1, 1])