torch.take_along_dim¶
- torch.take_along_dim(input, indices, dim=None, *, out=None) Tensor ¶
从
input
中选取 1 维索引indices
沿给定dim
的值。如果
dim
为 None,则将输入数组视为已展平为 1d。返回沿维度索引的函数,如
torch.argmax()
和torch.argsort()
,设计用于与该函数一起使用。下面是示例。注意
此函数类似于 NumPy 的 take_along_axis。另请参阅
torch.gather()
。- 参数:
input (Tensor) – 输入张量。
索引(LongTensor)-
input
中的索引。必须具有 long 数据类型。dim(int,可选)- 要选择的维度。默认:0
- 关键字参数:
输出(张量,可选)- 输出张量。
示例:
>>> t = torch.tensor([[10, 30, 20], [60, 40, 50]]) >>> max_idx = torch.argmax(t) >>> torch.take_along_dim(t, max_idx) tensor([60]) >>> sorted_idx = torch.argsort(t, dim=1) >>> torch.take_along_dim(t, sorted_idx, dim=1) tensor([[10, 20, 30], [40, 50, 60]])