快捷键

CrossEntropyLoss

class torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)[source][source] ¶

此准则计算输入 logits 与目标之间的交叉熵损失。

当训练具有 C 个类别的分类问题时很有用。如果提供,可选参数 weight 应该是一个分配给每个类别的权重的 1D Tensor。这在训练集不平衡时尤其有用。

输入应包含每个类别的未归一化 logits(通常不需要为正或总和为 1)。对于未批处理的输入,input 必须是一个大小为 (C)(C) 的 Tensor,对于 K 维情况,大小为 (minibatch,C)(minibatch, C)(minibatch,C,d1,d2,...,dK)(minibatch, C, d_1, d_2, ..., d_K) ,其中 K1K \geq 1 对于高维输入很有用,例如计算 2D 图像的每个像素的交叉熵损失。

该准则期望的目标应包含以下内容:

  • 该准则期望的目标应包含以下内容:在 [0,C)[0, C) 范围内的类别索引,其中 CC 是类别的数量;如果指定了 ignore_index,此损失也接受此类别索引(此索引不一定在类别范围内)。对于此情况,未减小的(即,将 reduction 设置为 'none' )损失可以描述为:

    (x,y)=L={l1,,lN},ln=wynlogexp(xn,yn)c=1Cexp(xn,c)1{ynignore_index}\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})} \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}

    其中 xx 是输入, yy 是目标, ww 是权重, CC 是类别数, NN 跨越 minibatch 维度,以及 d1,...,dkd_1, ..., d_k 在 K 维情况下的维度。如果 reduction 不等于 'none' (默认 'mean' ),则

    (x,y)={n=1N1n=1Nwyn1{ynignore_index}ln,if reduction=‘mean’;n=1Nln,if reduction=‘sum’.\ell(x, y) = \begin{cases} \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}} l_n, & \text{if reduction} = \text{`mean';}\\ \sum_{n=1}^N l_n, & \text{if reduction} = \text{`sum'.} \end{cases}

    注意,这种情况等价于先对输入应用 LogSoftmax ,然后应用 NLLLoss

  • 每个类别的概率;当需要每个 minibatch 项的标签超出单个类别时很有用,例如用于混合标签、标签平滑等。此情况未归一化的(即 reduction 设置为 'none' )损失可以描述为:

    (x,y)=L={l1,,lN},ln=c=1Cwclogexp(xn,c)i=1Cexp(xn,i)yn,c\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c}

    其中 xx 是输入, yy 是目标, ww 是权重, CC 是类别数, NN 跨越 minibatch 维度,以及 d1,...,dkd_1, ..., d_k 在 K 维情况下的维度。如果 reduction 不等于 'none' (默认 'mean' ),则

    (x,y)={n=1NlnN,if reduction=‘mean’;n=1Nln,if reduction=‘sum’.\ell(x, y) = \begin{cases} \frac{\sum_{n=1}^N l_n}{N}, & \text{if reduction} = \text{`mean';}\\ \sum_{n=1}^N l_n, & \text{if reduction} = \text{`sum'.} \end{cases}

注意

当目标包含类别索引时,此标准的性能通常更好,因为这允许进行优化计算。只有在单个类标签对每个 minibatch 项过于限制时,才考虑将目标作为类别概率提供。

参数:
  • weight(张量,可选)- 每个类别给出的手动缩放权重。如果提供,必须是一个大小为 C 的浮点型张量。

  • size_average(布尔值,可选)- 已弃用(参见 reduction )。默认情况下,损失会在批次的每个损失元素上平均。注意,对于某些损失,每个样本可能有多个元素。如果字段 size_average 设置为 False ,则损失将改为对每个 minibatch 求和。当 reduceFalse 时,将被忽略。默认: True

  • ignore_index(整数,可选)- 指定一个被忽略的目标值,它不会对输入梯度做出贡献。当 size_averageTrue 时,损失将在非忽略的目标上平均。注意, ignore_index 仅适用于目标包含类别索引时。

  • reduce(布尔值,可选)- 已弃用(参见 reduction )。默认情况下,损失会在每个 minibatch 的观测上平均或求和,具体取决于 size_average 。当 reduceFalse 时,将返回每个批次的损失,并忽略 size_average 。默认: True

  • reduction(字符串,可选)- 指定应用于输出的缩减方式: 'none' | 'mean' | 'sum''none' :不应用缩减, 'mean' :取输出加权平均值, 'sum' :输出将被求和。注意: size_averagereduce 正在被弃用,在此期间,指定这两个参数之一将覆盖 reduction 。默认: 'mean'

  • label_smoothing(浮点数,可选)- 一个位于[0.0, 1.0]之间的浮点数。指定计算损失时的平滑量,其中 0.0 表示无平滑。目标值成为原始真实值和均匀分布的混合,如 Rethinking the Inception Architecture for Computer Vision 中所述。默认值: 0.00.0

形状:
  • 输入:形状 (C)(C)(N,C)(N, C)(N,C,d1,d2,...,dK)(N, C, d_1, d_2, ..., d_K) ,在 K 维损失的情况下, K1K \geq 1

  • 目标:如果包含类别索引,形状 ()()(N)(N)(N,d1,d2,...,dK)(N, d_1, d_2, ..., d_K) ,在 K 维损失的情况下,每个值应在 [0,C)[0, C) 之间。当使用类别索引时,目标数据类型必须为 long。如果包含类别概率,目标必须与输入形状相同,每个值应在 [0,1][0, 1] 之间。这意味着当使用类别概率时,目标数据类型必须为 float。

  • 输出:如果 reduction 为‘none’,则形状为 ()()(N)(N)(N,d1,d2,...,dK)(N, d_1, d_2, ..., d_K) ,在 K 维损失的情况下,取决于输入的形状。否则,为标量。

哪儿

C=number of classesN=batch size\begin{aligned} C ={} & \text{number of classes} \\ N ={} & \text{batch size} \\ \end{aligned}

示例:

>>> # Example of target with class indices
>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> output = loss(input, target)
>>> output.backward()
>>>
>>> # Example of target with class probabilities
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.randn(3, 5).softmax(dim=1)
>>> output = loss(input, target)
>>> output.backward()

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源