torch.where¶
- torch.where(condition, input, other, *, out=None) → Tensor
根据需要从
input
或other
中选择元素,返回一个张量,具体取决于condition
。该操作定义为:
注意
张量
condition
、input
、other
必须可广播。- 参数:
condition (BoolTensor) – 当为 True(非零)时,输出输入,否则输出其他
input (Tensor 或 Scalar) – 值(如果
input
是标量)或根据condition
在True
中选择的索引的值other (Tensor 或 Scalar) – 值(如果
other
是标量)或根据condition
在False
中选择的索引的值
- 关键字参数:
输出(张量,可选)- 输出张量。
- 返回值:
形状等于
condition
,input
,other
广播形状的张量- 返回类型:
示例:
>>> x = torch.randn(3, 2) >>> y = torch.ones(3, 2) >>> x tensor([[-0.4620, 0.3139], [ 0.3898, -0.7197], [ 0.0478, -0.1657]]) >>> torch.where(x > 0, 1.0, 0.0) tensor([[0., 1.], [1., 0.], [1., 0.]]) >>> torch.where(x > 0, x, y) tensor([[ 1.0000, 0.3139], [ 0.3898, 1.0000], [ 0.0478, 1.0000]]) >>> x = torch.randn(2, 2, dtype=torch.double) >>> x tensor([[ 1.0779, 0.0383], [-0.8785, -1.1089]], dtype=torch.float64) >>> torch.where(x > 0, x, 0.) tensor([[1.0779, 0.0383], [0.0000, 0.0000]], dtype=torch.float64)
- torch.where(condition) → LongTensor 元组
torch.where(condition)
与torch.nonzero(condition, as_tuple=True)
相同。注意
参见
torch.nonzero()
.