快捷键

torch.nn.functional.ctc_loss

torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False)[source][source]

应用连接主义时序分类损失。

详细内容请见 CTCLoss

注意

在某些情况下,当在 CUDA 设备上给定张量并使用 CuDNN 时,此运算符可能会选择非确定性算法以提高性能。如果这不可取,您可以尝试通过设置 torch.backends.cudnn.deterministic = True 来使操作确定性(可能以性能成本为代价)。有关更多信息,请参阅可重现性。

注意

当在 CUDA 设备上给定张量时,此操作可能会产生非确定性的梯度。有关更多信息,请参阅可重现性。

参数:
  • log_probs (Tensor) – (T,N,C)(T, N, C)(T,C)(T, C) 其中 C = 字母表中的字符数(包括空白),T = 输入长度,N = 批处理大小。输出(例如,通过 torch.nn.functional.log_softmax() 获得的)的对数概率。

  • targets (Tensor) – (N,S)(N, S) 或 (sum(target_lengths))。目标不能为空白。在第二种形式中,假设目标已连接。

  • 输入长度(Tensor)- (N)(N)()() 。输入的长度(必须每个都是 T\leq T

  • 目标长度(Tensor)- (N)(N)()() 。目标的长度

  • 空白(int,可选)- 空白标签。默认 00

  • reduction(str,可选)- 指定应用于输出的缩减方式: 'none' | 'mean' | 'sum''none' :不应用缩减, 'mean' :输出损失将除以目标长度,然后取批次的平均值, 'sum' :输出将求和。默认: 'mean'

  • zero_infinity(布尔型,可选)- 是否将无限损失及其相关梯度置零。默认值: False 无限损失主要发生在输入太短无法与目标对齐时。

返回类型:

张量

示例:

>>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
>>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
>>> input_lengths = torch.full((16,), 50, dtype=torch.long)
>>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
>>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)
>>> loss.backward()

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源