torch.argmax¶
- torch.argmax(input) LongTensor ¶
返回
input
张量中所有元素的最大值的索引。这是该函数返回的第二个值。请参阅其文档以了解此方法的精确语义。
注意
如果存在多个最大值,则返回第一个最大值的索引。
- 参数:
input (Tensor) – 输入张量。
示例:
>>> a = torch.randn(4, 4) >>> a tensor([[ 1.3398, 0.2663, -0.2686, 0.2450], [-0.7401, -0.8805, -0.3402, -1.1936], [ 0.4907, -1.3948, -1.0691, -0.3132], [-1.6092, 0.5419, -0.2993, 0.3195]]) >>> torch.argmax(a) tensor(0)
- torch.argmax(input, dim, keepdim=False) LongTensor
返回张量在指定维度上的最大值的索引。
这是
torch.max()
返回的第二个值。请参阅其文档以了解此方法的精确语义。- 参数:
input (Tensor) – 输入张量。
dim(int)- 要降低的维度。如果为
None
,则返回展平输入的 argmax。keepdim(布尔值)- 输出张量是否保留
dim
。
示例:
>>> a = torch.randn(4, 4) >>> a tensor([[ 1.3398, 0.2663, -0.2686, 0.2450], [-0.7401, -0.8805, -0.3402, -1.1936], [ 0.4907, -1.3948, -1.0691, -0.3132], [-1.6092, 0.5419, -0.2993, 0.3195]]) >>> torch.argmax(a, dim=1) tensor([ 0, 2, 0, 1])