• 文档 >
  • torch >
  • torch.quantile
快捷键

torch.quantile

torch.quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) Tensor

计算沿 dim 维度的 input 张量每行的 q-分位数。

为了计算分位数,我们将 q 在 [0, 1] 的范围内映射到索引范围 [0, n],以找到排序输入中分位数的位置。如果分位数位于两个数据点 a < b (索引为 ij )之间,则根据给定的 interpolation 方法计算结果,如下所示:

  • linear : a + (b - a) * fraction ,其中 fraction 是计算出的分位数索引的分数部分。

  • lower: a.

  • higher: b.

  • nearest : ab ,取其索引更接近计算出的分位数索引者(对于 .5 的分数进行向下取整)。

  • midpoint: (a + b) / 2.

如果 q 是一个一维张量,输出张量的第一个维度表示分位数,其大小等于 q 的大小,其余维度是缩减后的剩余维度。

注意

默认情况下, dimNone ,导致在计算之前将 input 张量展平。

参数:
  • input (Tensor) – 输入张量。

  • q (浮点数或张量) – 范围 [0, 1] 内的标量或 1D 张量。

  • dim(整数)- 要减少的维度。

  • keepdim(布尔值)- 输出张量是否保留 dim

关键字参数:
  • interpolation (字符串) – 当所需的量分值位于两个数据点之间时使用的插值方法。可以是 linearlowerhighermidpointnearest 。默认为 linear

  • 输出(张量,可选)- 输出张量。

示例:

>>> a = torch.randn(2, 3)
>>> a
tensor([[ 0.0795, -1.2117,  0.9765],
        [ 1.1707,  0.6706,  0.4884]])
>>> q = torch.tensor([0.25, 0.5, 0.75])
>>> torch.quantile(a, q, dim=1, keepdim=True)
tensor([[[-0.5661],
        [ 0.5795]],

        [[ 0.0795],
        [ 0.6706]],

        [[ 0.5280],
        [ 0.9206]]])
>>> torch.quantile(a, q, dim=1, keepdim=True).shape
torch.Size([3, 2, 1])
>>> a = torch.arange(4.)
>>> a
tensor([0., 1., 2., 3.])
>>> torch.quantile(a, 0.6, interpolation='linear')
tensor(1.8000)
>>> torch.quantile(a, 0.6, interpolation='lower')
tensor(1.)
>>> torch.quantile(a, 0.6, interpolation='higher')
tensor(2.)
>>> torch.quantile(a, 0.6, interpolation='midpoint')
tensor(1.5000)
>>> torch.quantile(a, 0.6, interpolation='nearest')
tensor(2.)
>>> torch.quantile(a, 0.4, interpolation='nearest')
tensor(1.)

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源