快捷键

torch.Tensor.masked_scatter_

Tensor.masked_scatter_(mask, source)

source 复制元素到 self 张量中,位置由 mask 为 True 的位置决定。从 source 复制的元素将按顺序逐个复制到 source 的起始位置 0,直到 mask 为 True 的所有出现。 mask 的形状必须可以广播到基础张量的形状。 source 应至少包含与 mask 中 1 的数量相同数量的元素。

参数:
  • 布尔掩码(BoolTensor)-布尔掩码

  • 源(Tensor)-要复制的张量

注意

mask 操作于 self 张量,而不是给定的 source 张量。

示例

>>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
>>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool)
>>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
>>> self.masked_scatter_(mask, source)
tensor([[0, 0, 0, 0, 1],
        [2, 3, 0, 4, 5]])

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源