快捷键

torch.Tensor.scatter_add_

Tensor.scatter_add_(dim, index, src) Tensor

将张量 src 中的所有值按 index 张量中指定的索引添加到 self 中,类似于 scatter_() 。对于 src 中的每个值,它被添加到由 src 中的索引指定的 self 中的索引,以及由 index 中的对应值指定的 dimension != dim 中的索引。

对于一个三维张量, self 的更新如下:

self[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2

selfindexsrc 应具有相同的维度。还要求对于所有维度 dindex.size(d) <= src.size(d) ,以及对于所有维度 d != dimindex.size(d) <= self.size(d) 。注意, indexsrc 不能进行广播。

注意

当在 CUDA 设备上给定张量时,此操作可能表现出非确定性。有关更多信息,请参阅可重现性。

注意

仅实现了 src.shape == index.shape 的反向传播。

参数:
  • dim(整数)- 指索引的轴

  • 索引(LongTensor)- 要分散和添加的元素的索引,可以是空的,也可以与 src 的维度相同。当为空时,该操作返回 self 不变。

  • 源(张量)- 要散列并添加的源元素

示例:

>>> src = torch.ones((2, 5))
>>> index = torch.tensor([[0, 1, 2, 0, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src)
tensor([[1., 0., 0., 1., 1.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.]])
>>> index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src)
tensor([[2., 0., 0., 1., 1.],
        [0., 2., 0., 0., 0.],
        [0., 0., 2., 1., 1.]])

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源