• 文档 >
  • torch >
  • torch.argwhere
快捷键

torch.argwhere

torch.argwhere(input) → Tensor

返回一个包含所有非零元素索引的张量。结果中的每一行包含一个非零元素的索引。结果按字典序排序,最后一个索引变化最快(C 风格)。

如果 input 具有 nn 维度,则结果索引张量 out 的大小为 (z×n)(z \times n) ,其中 zzinput 张量中非零元素的总数。

注意

此函数类似于 NumPy 的 argwhere 函数。

input 在 CUDA 上时,此函数会导致主机-设备同步。

参数:

{输入} –

示例:

>>> t = torch.tensor([1, 0, 1])
>>> torch.argwhere(t)
tensor([[0],
        [2]])
>>> t = torch.tensor([[1, 0, 1], [0, 1, 1]])
>>> torch.argwhere(t)
tensor([[0, 0],
        [0, 2],
        [1, 1],
        [1, 2]])

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源