torch.argwhere¶
- torch.argwhere(input) → Tensor
返回一个包含所有非零元素索引的张量。结果中的每一行包含一个非零元素的索引。结果按字典序排序,最后一个索引变化最快(C 风格)。
如果
input
具有 维度,则结果索引张量out
的大小为 ,其中 是input
张量中非零元素的总数。注意
此函数类似于 NumPy 的 argwhere 函数。
当
input
在 CUDA 上时,此函数会导致主机-设备同步。- 参数:
{输入} –
示例:
>>> t = torch.tensor([1, 0, 1]) >>> torch.argwhere(t) tensor([[0], [2]]) >>> t = torch.tensor([[1, 0, 1], [0, 1, 1]]) >>> torch.argwhere(t) tensor([[0, 0], [0, 2], [1, 1], [1, 2]])