• 文档 >
  • torch.nn >
  • 高斯负对数似然损失
快捷键

高斯负对数似然损失 ¶

class torch.nn.GaussianNLLLoss(*, full=False, eps=1e-06, reduction='mean')[source][source]

高斯负对数似然损失。

目标被视为来自高斯分布的样本,其期望和方差由神经网络预测。对于一个以高斯分布建模的 target 张量模型,其期望为 input ,正方差为 var 的张量,损失为:

loss=12(log(max(var, eps))+(inputtarget)2max(var, eps))+const.\text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var}, \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{target}\right)^2} {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}

用于稳定性的 eps 。默认情况下,损失函数的常数项被省略,除非 fullTrue 。如果 var 的大小与 input 不同(由于同方差假设),它必须具有 1 个最终维度或少一个维度(其他尺寸相同)才能正确广播。

参数:
  • full(bool,可选)- 在损失计算中包含常数项。默认: False

  • eps(浮点数,可选)- 用于夹紧 var 的值(见下文说明),以保持稳定性。默认:1e-6。

  • reduction(字符串,可选)- 指定应用于输出的降维方式: 'none' | 'mean' | 'sum' . 'none' : 不应用降维, 'mean' : 输出为所有批次成员损失的均值, 'sum' : 输出为所有批次成员损失的总和。默认: 'mean'

形状:
  • 输入: (N,)(N, *)()(*) ,其中 * 表示任意数量的附加维度

  • 目标: (N,)(N, *)()(*) ,与输入形状相同,或与输入形状相同但其中一个维度等于 1(以允许广播)

  • 输入变量: (N,)(N, *)()(*) ,与输入形状相同,或与输入形状相同但其中一个维度等于 1,或与输入形状相同但少一个维度(以允许广播),或一个标量值

  • 输出:如果 reduction'mean' (默认)或 'sum' ,则为标量。如果 reduction'none' ,则 (N,)(N, *) ,形状与输入相同

示例::
>>> loss = nn.GaussianNLLLoss()
>>> input = torch.randn(5, 2, requires_grad=True)
>>> target = torch.randn(5, 2)
>>> var = torch.ones(5, 2, requires_grad=True)  # heteroscedastic
>>> output = loss(input, target, var)
>>> output.backward()
>>> loss = nn.GaussianNLLLoss()
>>> input = torch.randn(5, 2, requires_grad=True)
>>> target = torch.randn(5, 2)
>>> var = torch.ones(5, 1, requires_grad=True)  # homoscedastic
>>> output = loss(input, target, var)
>>> output.backward()

注意

关于自动微分, var 的钳位被忽略,因此梯度不受其影响。

参考文献:

Nix, D. A. 和 Weigend, A. S.,“估计目标概率分布的均值和方差”,1994 年 IEEE 国际神经网络会议(ICNN’94)论文集,佛罗里达州奥兰多,1994 年,第 1 卷,第 55-60 页,doi: 10.1109/ICNN.1994.374138.


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源