高斯负对数似然损失 ¶
- class torch.nn.GaussianNLLLoss(*, full=False, eps=1e-06, reduction='mean')[source][source]¶
高斯负对数似然损失。
目标被视为来自高斯分布的样本,其期望和方差由神经网络预测。对于一个以高斯分布建模的
target
张量模型,其期望为input
,正方差为var
的张量,损失为:用于稳定性的
eps
。默认情况下,损失函数的常数项被省略,除非full
是True
。如果var
的大小与input
不同(由于同方差假设),它必须具有 1 个最终维度或少一个维度(其他尺寸相同)才能正确广播。- 参数:
full(bool,可选)- 在损失计算中包含常数项。默认:
False
。eps(浮点数,可选)- 用于夹紧
var
的值(见下文说明),以保持稳定性。默认:1e-6。reduction(字符串,可选)- 指定应用于输出的降维方式:
'none'
|'mean'
|'sum'
.'none'
: 不应用降维,'mean'
: 输出为所有批次成员损失的均值,'sum'
: 输出为所有批次成员损失的总和。默认:'mean'
。
- 形状:
输入: 或 ,其中 表示任意数量的附加维度
目标: 或 ,与输入形状相同,或与输入形状相同但其中一个维度等于 1(以允许广播)
输入变量: 或 ,与输入形状相同,或与输入形状相同但其中一个维度等于 1,或与输入形状相同但少一个维度(以允许广播),或一个标量值
输出:如果
reduction
是'mean'
(默认)或'sum'
,则为标量。如果reduction
是'none'
,则 ,形状与输入相同
- 示例::
>>> loss = nn.GaussianNLLLoss() >>> input = torch.randn(5, 2, requires_grad=True) >>> target = torch.randn(5, 2) >>> var = torch.ones(5, 2, requires_grad=True) # heteroscedastic >>> output = loss(input, target, var) >>> output.backward()
>>> loss = nn.GaussianNLLLoss() >>> input = torch.randn(5, 2, requires_grad=True) >>> target = torch.randn(5, 2) >>> var = torch.ones(5, 1, requires_grad=True) # homoscedastic >>> output = loss(input, target, var) >>> output.backward()
注意
关于自动微分,
var
的钳位被忽略,因此梯度不受其影响。- 参考文献:
Nix, D. A. 和 Weigend, A. S.,“估计目标概率分布的均值和方差”,1994 年 IEEE 国际神经网络会议(ICNN’94)论文集,佛罗里达州奥兰多,1994 年,第 1 卷,第 55-60 页,doi: 10.1109/ICNN.1994.374138.