快捷键

CTCLoss

class torch.nn.CTCLoss(空白=0, reduction='mean', 无穷为零=False)[source][source] ¶

连接主义时序分类损失。

计算连续(未分割)时间序列与目标序列之间的损失。CTCLoss 对输入与目标可能的对齐概率进行求和,产生一个相对于每个输入节点可微分的损失值。输入与目标的对齐假设为“多对一”,这限制了目标序列的长度,使其必须等于输入长度。

参数:
  • 空白(整型,可选)- 空白标签。默认值: 00

  • reduction(字符串,可选)- 指定应用于输出的归约方式: 'none' | 'mean' | 'sum''none' :不应用归约, 'mean' :输出损失将除以目标长度,然后取批次的平均值, 'sum' :输出损失将求和。默认值: 'mean'

  • zero_infinity(布尔型,可选)- 是否将无限损失及其相关梯度置零。默认值: False 无限损失主要发生在输入太短无法与目标对齐时。

形状:
  • Log_probs:大小为 (T,N,C)(T, N, C)(T,C)(T, C) 的张量,其中 T=input lengthT = \text{input length}N=batch sizeN = \text{batch size}C=number of classes (including blank)C = \text{number of classes (including blank)} 。表示输出(例如,通过 torch.nn.functional.log_softmax() 获得的)的对数概率。

  • Targets:大小为 (N,S)(N, S)(sum(target_lengths))(\operatorname{sum}(\text{target\_lengths})) 的张量,其中 N=batch sizeN = \text{batch size}S=max target length, if shape is (N,S)S = \text{max target length, if shape is } (N, S) 。它表示目标序列。目标序列中的每个元素是一个类别索引。目标索引不能为空(默认=0)。在 (N,S)(N, S) 形式中,目标被填充到最长序列的长度,并堆叠。在 (sum(target_lengths))(\operatorname{sum}(\text{target\_lengths})) 形式中,假设目标未填充,并在 1 个维度内连接。

  • Input_lengths:元组或大小为 (N)(N)()() 的张量,其中 N=batch sizeN = \text{batch size} 。它表示输入的长度(必须每个都是 T\leq T )。长度指定为每个序列,以在假设序列被填充到相等长度的情况下实现掩码。

  • 目标长度:大小为 (N)(N)()() 的元组或张量,其中 N=batch sizeN = \text{batch size} 。它表示目标的长度。长度为每个序列指定,以在假设序列被填充到相等长度的情况下实现掩码。如果目标形状为 (N,S)(N,S) ,则目标长度实际上是每个目标序列的停止索引 sns_n ,使得每个批次的每个目标 target_n = targets[n,0:s_n] 。长度必须每个都是 S\leq S 。如果目标作为 1d 张量给出,它是各个目标拼接的结果,则目标长度必须加起来等于张量的总长度。

  • 输出:如果 reduction'mean' (默认)或 'sum' ,则为标量。如果 reduction'none' ,则如果输入是批处理,则 (N)(N) ,如果输入是非批处理,则 ()() ,其中 N=batch sizeN = \text{batch size}

示例:

>>> # Target are to be padded
>>> T = 50      # Input sequence length
>>> C = 20      # Number of classes (including blank)
>>> N = 16      # Batch size
>>> S = 30      # Target sequence length of longest target in batch (padding length)
>>> S_min = 10  # Minimum target length, for demonstration purposes
>>>
>>> # Initialize random batch of input vectors, for *size = (T,N,C)
>>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
>>>
>>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
>>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
>>>
>>>
>>> # Target are to be un-padded
>>> T = 50      # Input sequence length
>>> C = 20      # Number of classes (including blank)
>>> N = 16      # Batch size
>>>
>>> # Initialize random batch of input vectors, for *size = (T,N,C)
>>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
>>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
>>> target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long)
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
>>>
>>>
>>> # Target are to be un-padded and unbatched (effectively N=1)
>>> T = 50      # Input sequence length
>>> C = 20      # Number of classes (including blank)
>>>
>>> # Initialize random batch of input vectors, for *size = (T,C)
>>> input = torch.randn(T, C).log_softmax(1).detach().requires_grad_()
>>> input_lengths = torch.tensor(T, dtype=torch.long)
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long)
>>> target = torch.randint(low=1, high=C, size=(target_lengths,), dtype=torch.long)
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
参考文献:

A. Graves 等人:连接时序分类:使用循环神经网络对未分割序列数据进行标记:https://www.cs.toronto.edu/~graves/icml_2006.pdf

注意

为了使用 CuDNN,必须满足以下条件: targets 必须为连接格式,所有 input_lengths 必须为 T。 blank=0blank=0target_lengths256\leq 256 ,整数参数的数据类型必须为 torch.int32

正规实现使用(在 PyTorch 中更常见)的 torch.long 数据类型。

注意

In some circumstances when using the CUDA backend with CuDNN, this operator may select a nondeterministic algorithm to increase performance. If this is undesirable, you can try to make the operation deterministic (potentially at a performance cost) by setting torch.backends.cudnn.deterministic = True . Please see the notes on Reproducibility for background.


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源