torch.sort¶
- torch.sort(input, dim=- 1, descending=False, stable=False, *, out=None)¶
按给定维度对
input
张量的元素进行升序排序。如果未指定
dim
,则选择输入的最后维度。如果
descending
是True
,则元素按值降序排序。如果
stable
是True
,则排序算法变为稳定,保留等价元素的顺序。返回一个包含 (values, indices) 的 namedtuple,其中 values 是排序后的值,indices 是原始输入张量中元素的索引。
- 参数:
input (Tensor) – 输入张量。
dim(int,可选)- 要排序的维度
descending(bool,可选)- 控制排序顺序(升序或降序)
stable (布尔值,可选) – 使排序算法稳定,这保证了等价元素的顺序被保留。
- 关键字参数:
out (元组,可选) – 可选地提供输出元组 (Tensor, LongTensor),用作输出缓冲区。
示例:
>>> x = torch.randn(3, 4) >>> sorted, indices = torch.sort(x) >>> sorted tensor([[-0.2162, 0.0608, 0.6719, 2.3332], [-0.5793, 0.0061, 0.6058, 0.9497], [-0.5071, 0.3343, 0.9553, 1.0960]]) >>> indices tensor([[ 1, 0, 2, 3], [ 3, 1, 0, 2], [ 0, 3, 1, 2]]) >>> sorted, indices = torch.sort(x, 0) >>> sorted tensor([[-0.5071, -0.2162, 0.6719, -0.5793], [ 0.0608, 0.0061, 0.9497, 0.3343], [ 0.6058, 0.9553, 1.0960, 2.3332]]) >>> indices tensor([[ 2, 0, 0, 1], [ 0, 1, 1, 2], [ 1, 2, 2, 0]]) >>> x = torch.tensor([0, 1] * 9) >>> x.sort() torch.return_types.sort( values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) >>> x.sort(stable=True) torch.return_types.sort( values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17]))