快捷键

三元组损失 ¶

class torch.nn.TripletMarginLoss(margin=1.0, p=2.0, eps=1e-06, swap=False, size_average=None, reduce=None, reduction='mean')[source][source]

创建一个测量三元组损失的准则,给定输入张量 x1x1x2x2x3x3 和一个大于 00 的间隔。这用于测量样本之间的相对相似度。三元组由 a、p 和 n(即锚、正例和负例)组成。所有输入张量的形状应为 (N,D)(N, D)

论文中详细描述了距离交换,该论文为 V. Balntas, E. Riba 等人撰写的《使用三元损失的浅层卷积特征描述符学习》。

每个样本在迷你批次中的损失函数为:

L(a,p,n)=max{d(ai,pi)d(ai,ni)+margin,0}L(a, p, n) = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\}

其中

d(xi,yi)=xiyipd(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p

使用指定的 p 值计算范数,并添加一个小常数 ε\varepsilon 以增加数值稳定性。

参见 TripletMarginWithDistanceLoss ,它使用自定义距离函数计算输入张量的三元组损失。

参数:
  • 边距(float,可选)- 默认: 11

  • p(int,可选)- 对距离的范数度。默认: 22

  • eps(float,可选)- 用于数值稳定性的小常数。默认: 1e61e-6

  • 交换(bool,可选)- 距离交换在 V. Balntas,E. Riba 等人的论文《使用三元组损失的浅层卷积特征描述符学习》中详细描述。默认: False

  • size_average(布尔值,可选)- 已弃用(参见 reduction )。默认情况下,损失会在批次的每个损失元素上平均。注意,对于某些损失,每个样本可能有多个元素。如果字段 size_average 设置为 False ,则损失将改为对每个 minibatch 求和。当 reduceFalse 时,将被忽略。默认: True

  • reduce(布尔值,可选)- 已弃用(参见 reduction )。默认情况下,损失会在每个 minibatch 的观测上平均或求和,具体取决于 size_average 。当 reduceFalse 时,将返回每个批次的损失,并忽略 size_average 。默认: True

  • reduction(字符串,可选)- 指定应用于输出的缩减方式: 'none' | 'mean' | 'sum''none' :不应用缩减, 'mean' :输出总和将除以输出中的元素数量, 'sum' :输出将被求和。注意: size_averagereduce 正在被弃用,在此期间,指定这两个参数之一将覆盖 reduction 。默认: 'mean'

形状:
  • 输入: (N,D)(N, D)(D)(D) ,其中 DD 是向量维度。

  • 输出:如果 reduction'none' 且输入形状为 (N,D)(N, D) ,则形状为 (N)(N) 的张量;否则为标量。

示例:

>>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
>>> anchor = torch.randn(100, 128, requires_grad=True)
>>> positive = torch.randn(100, 128, requires_grad=True)
>>> negative = torch.randn(100, 128, requires_grad=True)
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源