• 文档 >
  • torch >
  • torch.gather
快捷键

torch.gather

torch.gather(input, dim, index, *, sparse_grad=False, out=None) Tensor

沿由 dim 指定的轴收集值。

对于 3-D 张量,输出由以下指定:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

inputindex 必须具有相同的维度数。还要求所有维度 d != dimindex.size(d) <= input.size(d)out 的形状将与 index 相同。请注意, inputindex 之间不能进行广播。

参数:
  • 输入(张量)- 源张量

  • dim(整数)- 指索引的轴

  • index(长整型张量)- 要收集的元素的索引

关键字参数:
  • sparse_grad (bool, 可选) – 如果 True ,则 input 的梯度将是一个稀疏张量。

  • out (Tensor, 可选) – 目标张量

示例:

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1,  1],
        [ 4,  3]])

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源