torch.Tensor.masked_scatter_¶
- Tensor.masked_scatter_(mask, source)¶
从
source
复制元素到self
张量中,位置由mask
为 True 的位置决定。从source
复制的元素将按顺序逐个复制到source
的起始位置 0,直到mask
为 True 的所有出现。mask
的形状必须可以广播到基础张量的形状。source
应至少包含与mask
中 1 的数量相同数量的元素。- 参数:
布尔掩码(BoolTensor)-布尔掩码
源(Tensor)-要复制的张量
注意
mask
操作于self
张量,而不是给定的source
张量。示例
>>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]) >>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool) >>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) >>> self.masked_scatter_(mask, source) tensor([[0, 0, 0, 0, 1], [2, 3, 0, 4, 5]])