BCEWithLogitsLoss¶
- class torch.nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)[source][source]¶
此损失函数结合了一个 Sigmoid 层和二进制交叉熵损失函数(BCELoss)在一个单独的类中。这个版本比使用普通的 Sigmoid 后跟 BCELoss 更数值稳定,因为通过将操作合并到一个层中,我们利用了 log-sum-exp 技巧来提高数值稳定性。
未归一化的(即
reduction
设置为'none'
)损失可以描述为:其中 是批量大小。如果
reduction
不等于'none'
(默认'mean'
),则这用于测量重建误差,例如在自动编码器中。注意,目标 t[i]应该是介于 0 和 1 之间的数字。
通过为正例添加权重,可以权衡召回率和精确度。在多标签分类的情况下,损失可以描述为:
其中 是类别编号( 用于多标签二分类, 用于单标签二分类), 是批次中的样本编号, 是类别 正回答的权重。
增加召回率, 增加精确率。
例如,如果一个数据集包含一个类别的 100 个正例和 300 个负例,那么该类别的
pos_weight
应该等于 。损失将像数据集包含 个正例一样起作用。示例:
>>> target = torch.ones([10, 64], dtype=torch.float32) # 64 classes, batch size = 10 >>> output = torch.full([10, 64], 1.5) # A prediction (logit) >>> pos_weight = torch.ones([64]) # All weights are equal to 1 >>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight) >>> criterion(output, target) # -log(sigmoid(1.5)) tensor(0.20...)
在上述例子中,
pos_weight
张量中的元素对应于多标签二分类场景中的 64 个不同类别。pos_weight
中的每个元素都旨在根据相应类别的正负样本不平衡来调整损失函数。这种方法在类别不平衡程度不同的数据集中很有用,确保损失计算准确考虑每个类的分布。- 参数:
weight(张量,可选)- 分配给每个批次元素的损失的手动缩放权重。如果提供,则必须是一个大小为 nbatch 的张量。
size_average(布尔值,可选)- 已弃用(参见
reduction
)。默认情况下,损失会在批次的每个损失元素上平均。注意,对于某些损失,每个样本可能有多个元素。如果字段size_average
设置为False
,则损失将改为对每个 minibatch 求和。当reduce
为False
时,将被忽略。默认:True
reduce(布尔值,可选)- 已弃用(参见
reduction
)。默认情况下,损失会在每个 minibatch 的观测上平均或求和,具体取决于size_average
。当reduce
为False
时,将返回每个批次的损失,并忽略size_average
。默认:True
reduction(字符串,可选)- 指定应用于输出的缩减方式:
'none'
|'mean'
|'sum'
。'none'
:不应用缩减,'mean'
:输出总和将除以输出中的元素数量,'sum'
:输出将被求和。注意:size_average
和reduce
正在被弃用,在此期间,指定这两个参数之一将覆盖reduction
。默认:'mean'
pos_weight(张量,可选)- 要与目标广播的正面示例的权重。必须是一个与类维度的类别数量相等的张量。请密切关注 PyTorch 的广播语义,以实现所需的操作。对于大小为[B, C, H, W](其中 B 是批次大小)的目标,pos_weight 的大小为[B, C, H, W]将应用于批次的每个元素或[C, H, W]在整个批次中使用相同的 pos_weights。对于 2D 多类目标[C, H, W],要在所有空间维度上应用相同的正权重,请使用:[C, 1, 1]。默认:
None
- 形状:
输入: ,其中 表示任意数量的维度。
目标: ,与输入形状相同。
输出:标量。如果
reduction
是'none'
,则 ,与输入形状相同。
示例:
>>> loss = nn.BCEWithLogitsLoss() >>> input = torch.randn(3, requires_grad=True) >>> target = torch.empty(3).random_(2) >>> output = loss(input, target) >>> output.backward()