CTCLoss¶
- class torch.nn.CTCLoss(空白=0, reduction='mean', 无穷为零=False)[source][source] ¶
连接主义时序分类损失。
计算连续(未分割)时间序列与目标序列之间的损失。CTCLoss 对输入与目标可能的对齐概率进行求和,产生一个相对于每个输入节点可微分的损失值。输入与目标的对齐假设为“多对一”,这限制了目标序列的长度,使其必须等于输入长度。
- 参数:
空白(整型,可选)- 空白标签。默认值: 。
reduction(字符串,可选)- 指定应用于输出的归约方式:
'none'
|'mean'
|'sum'
。'none'
:不应用归约,'mean'
:输出损失将除以目标长度,然后取批次的平均值,'sum'
:输出损失将求和。默认值:'mean'
。zero_infinity(布尔型,可选)- 是否将无限损失及其相关梯度置零。默认值:
False
无限损失主要发生在输入太短无法与目标对齐时。
- 形状:
Log_probs:大小为 或 的张量,其中 、 和 。表示输出(例如,通过
torch.nn.functional.log_softmax()
获得的)的对数概率。Targets:大小为 或 的张量,其中 和 。它表示目标序列。目标序列中的每个元素是一个类别索引。目标索引不能为空(默认=0)。在 形式中,目标被填充到最长序列的长度,并堆叠。在 形式中,假设目标未填充,并在 1 个维度内连接。
Input_lengths:元组或大小为 或 的张量,其中 。它表示输入的长度(必须每个都是 )。长度指定为每个序列,以在假设序列被填充到相等长度的情况下实现掩码。
目标长度:大小为 或 的元组或张量,其中 。它表示目标的长度。长度为每个序列指定,以在假设序列被填充到相等长度的情况下实现掩码。如果目标形状为 ,则目标长度实际上是每个目标序列的停止索引 ,使得每个批次的每个目标
target_n = targets[n,0:s_n]
。长度必须每个都是 。如果目标作为 1d 张量给出,它是各个目标拼接的结果,则目标长度必须加起来等于张量的总长度。输出:如果
reduction
是'mean'
(默认)或'sum'
,则为标量。如果reduction
是'none'
,则如果输入是批处理,则 ,如果输入是非批处理,则 ,其中 。
示例:
>>> # Target are to be padded >>> T = 50 # Input sequence length >>> C = 20 # Number of classes (including blank) >>> N = 16 # Batch size >>> S = 30 # Target sequence length of longest target in batch (padding length) >>> S_min = 10 # Minimum target length, for demonstration purposes >>> >>> # Initialize random batch of input vectors, for *size = (T,N,C) >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() >>> >>> # Initialize random batch of targets (0 = blank, 1:C = classes) >>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long) >>> >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) >>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long) >>> ctc_loss = nn.CTCLoss() >>> loss = ctc_loss(input, target, input_lengths, target_lengths) >>> loss.backward() >>> >>> >>> # Target are to be un-padded >>> T = 50 # Input sequence length >>> C = 20 # Number of classes (including blank) >>> N = 16 # Batch size >>> >>> # Initialize random batch of input vectors, for *size = (T,N,C) >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) >>> >>> # Initialize random batch of targets (0 = blank, 1:C = classes) >>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long) >>> target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long) >>> ctc_loss = nn.CTCLoss() >>> loss = ctc_loss(input, target, input_lengths, target_lengths) >>> loss.backward() >>> >>> >>> # Target are to be un-padded and unbatched (effectively N=1) >>> T = 50 # Input sequence length >>> C = 20 # Number of classes (including blank) >>> >>> # Initialize random batch of input vectors, for *size = (T,C) >>> input = torch.randn(T, C).log_softmax(1).detach().requires_grad_() >>> input_lengths = torch.tensor(T, dtype=torch.long) >>> >>> # Initialize random batch of targets (0 = blank, 1:C = classes) >>> target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long) >>> target = torch.randint(low=1, high=C, size=(target_lengths,), dtype=torch.long) >>> ctc_loss = nn.CTCLoss() >>> loss = ctc_loss(input, target, input_lengths, target_lengths) >>> loss.backward()
- 参考文献:
A. Graves 等人:连接时序分类:使用循环神经网络对未分割序列数据进行标记:https://www.cs.toronto.edu/~graves/icml_2006.pdf
注意
为了使用 CuDNN,必须满足以下条件:
targets
必须为连接格式,所有input_lengths
必须为 T。 ,target_lengths
, ,整数参数的数据类型必须为torch.int32
。正规实现使用(在 PyTorch 中更常见)的 torch.long 数据类型。
注意
In some circumstances when using the CUDA backend with CuDNN, this operator may select a nondeterministic algorithm to increase performance. If this is undesirable, you can try to make the operation deterministic (potentially at a performance cost) by settingtorch.backends.cudnn.deterministic = True
. Please see the notes on Reproducibility for background.