• 文档 >
  • torch >
  • torch.sort
快捷键

torch.sort

torch.sort(input, dim=- 1, descending=False, stable=False, *, out=None)

按给定维度对 input 张量的元素进行升序排序。

如果未指定 dim ,则选择输入的最后维度。

如果 descendingTrue ,则元素按值降序排序。

如果 stableTrue ,则排序算法变为稳定,保留等价元素的顺序。

返回一个包含 (values, indices) 的 namedtuple,其中 values 是排序后的值,indices 是原始输入张量中元素的索引。

参数:
  • input (Tensor) – 输入张量。

  • dim(int,可选)- 要排序的维度

  • descending(bool,可选)- 控制排序顺序(升序或降序)

  • stable (布尔值,可选) – 使排序算法稳定,这保证了等价元素的顺序被保留。

关键字参数:

out (元组,可选) – 可选地提供输出元组 (Tensor, LongTensor),用作输出缓冲区。

示例:

>>> x = torch.randn(3, 4)
>>> sorted, indices = torch.sort(x)
>>> sorted
tensor([[-0.2162,  0.0608,  0.6719,  2.3332],
        [-0.5793,  0.0061,  0.6058,  0.9497],
        [-0.5071,  0.3343,  0.9553,  1.0960]])
>>> indices
tensor([[ 1,  0,  2,  3],
        [ 3,  1,  0,  2],
        [ 0,  3,  1,  2]])

>>> sorted, indices = torch.sort(x, 0)
>>> sorted
tensor([[-0.5071, -0.2162,  0.6719, -0.5793],
        [ 0.0608,  0.0061,  0.9497,  0.3343],
        [ 0.6058,  0.9553,  1.0960,  2.3332]])
>>> indices
tensor([[ 2,  0,  0,  1],
        [ 0,  1,  1,  2],
        [ 1,  2,  2,  0]])
>>> x = torch.tensor([0, 1] * 9)
>>> x.sort()
torch.return_types.sort(
    values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
    indices=tensor([ 2, 16,  4,  6, 14,  8,  0, 10, 12,  9, 17, 15, 13, 11,  7,  5,  3,  1]))
>>> x.sort(stable=True)
torch.return_types.sort(
    values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
    indices=tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16,  1,  3,  5,  7,  9, 11, 13, 15, 17]))

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源