torch.gather¶
- torch.gather(input, dim, index, *, sparse_grad=False, out=None) Tensor ¶
沿由 dim 指定的轴收集值。
对于 3-D 张量,输出由以下指定:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
input
和index
必须具有相同的维度数。还要求所有维度d != dim
的index.size(d) <= input.size(d)
。out
的形状将与index
相同。请注意,input
和index
之间不能进行广播。- 参数:
输入(张量)- 源张量
dim(整数)- 指索引的轴
index(长整型张量)- 要收集的元素的索引
- 关键字参数:
sparse_grad (bool, 可选) – 如果
True
,则input
的梯度将是一个稀疏张量。out (Tensor, 可选) – 目标张量
示例:
>>> t = torch.tensor([[1, 2], [3, 4]]) >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])) tensor([[ 1, 1], [ 4, 3]])