torch.topk¶
- torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)¶
返回给定
input
张量沿给定维度上的k
最大元素。如果未指定
dim
,则选择输入的最后维度。如果
largest
等于False
,则返回 k 个最小的元素。返回一个包含(值,索引)的 namedtuple,其中包含输入张量每一行在给定维度 dim 上最大的 k 个元素的值和索引。
如果
sorted
为True
,则确保返回的 k 个元素本身是有序的。注意
当使用 torch.topk 时,保证相同值的元素的索引稳定性是不确定的,并且可能在不同调用之间变化。
- 参数:
input (Tensor) – 输入张量。
k(int)- “top-k”中的 k。
dim(int,可选)- 要排序的维度
最大(布尔值,可选)- 控制返回最大或最小元素
排序(布尔值,可选)- 控制是否按顺序返回元素
- 关键字参数:
out (元组,可选) – 可选地提供输出元组 (Tensor, LongTensor),用作输出缓冲区。
示例:
>>> x = torch.arange(1., 6.) >>> x tensor([ 1., 2., 3., 4., 5.]) >>> torch.topk(x, 3) torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))