torch.quantile¶
- torch.quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) Tensor ¶
计算沿
dim
维度的input
张量每行的 q-分位数。为了计算分位数,我们将 q 在 [0, 1] 的范围内映射到索引范围 [0, n],以找到排序输入中分位数的位置。如果分位数位于两个数据点
a < b
(索引为i
和j
)之间,则根据给定的interpolation
方法计算结果,如下所示:linear
:a + (b - a) * fraction
,其中fraction
是计算出的分位数索引的分数部分。lower
:a
.higher
:b
.nearest
:a
或b
,取其索引更接近计算出的分位数索引者(对于 .5 的分数进行向下取整)。midpoint
:(a + b) / 2
.
如果
q
是一个一维张量,输出张量的第一个维度表示分位数,其大小等于q
的大小,其余维度是缩减后的剩余维度。注意
默认情况下,
dim
是None
,导致在计算之前将input
张量展平。- 参数:
input (Tensor) – 输入张量。
q (浮点数或张量) – 范围 [0, 1] 内的标量或 1D 张量。
dim(整数)- 要减少的维度。
keepdim(布尔值)- 输出张量是否保留
dim
。
- 关键字参数:
interpolation (字符串) – 当所需的量分值位于两个数据点之间时使用的插值方法。可以是
linear
,lower
,higher
,midpoint
和nearest
。默认为linear
。输出(张量,可选)- 输出张量。
示例:
>>> a = torch.randn(2, 3) >>> a tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]]) >>> q = torch.tensor([0.25, 0.5, 0.75]) >>> torch.quantile(a, q, dim=1, keepdim=True) tensor([[[-0.5661], [ 0.5795]], [[ 0.0795], [ 0.6706]], [[ 0.5280], [ 0.9206]]]) >>> torch.quantile(a, q, dim=1, keepdim=True).shape torch.Size([3, 2, 1]) >>> a = torch.arange(4.) >>> a tensor([0., 1., 2., 3.]) >>> torch.quantile(a, 0.6, interpolation='linear') tensor(1.8000) >>> torch.quantile(a, 0.6, interpolation='lower') tensor(1.) >>> torch.quantile(a, 0.6, interpolation='higher') tensor(2.) >>> torch.quantile(a, 0.6, interpolation='midpoint') tensor(1.5000) >>> torch.quantile(a, 0.6, interpolation='nearest') tensor(2.) >>> torch.quantile(a, 0.4, interpolation='nearest') tensor(1.)