torch.nn.utils.rnn.pad_packed_sequence¶
- torch.nn.utils.rnn.pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None)[source][source]¶
填充可变长度序列的打包批次。
这是
pack_padded_sequence()的逆操作。返回的 Tensor 数据大小将为
T x B x *(如果batch_first是False)或B x T x *(如果batch_first是True),其中T是最长序列的长度,B是批大小。示例
>>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence >>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]]) >>> lens = [2, 1, 3] >>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False) >>> packed PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]), sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0])) >>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True) >>> seq_unpacked tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]]) >>> lens_unpacked tensor([2, 1, 3])
注意
total_length在Module中实现pack sequence -> recurrent network -> unpack sequence模式很有用,其中DataParallel被DataParallel包裹。请参阅此 FAQ 部分以获取详细信息。- 参数:
序列(PackedSequence)- 批填充
batch_first(布尔值,可选)- 如果
True,输出将采用B x T x *格式,否则为T x B x *。填充值(浮点数,可选)- 填充元素的值。
total_length(整数,可选)- 如果不
None,输出将被填充以达到长度total_length。如果total_length小于sequence中的最大序列长度,此方法将抛出ValueError。
- 返回值:
包含填充序列的 Tensor 元组,以及包含每个序列长度的列表的 Tensor。批处理元素将按它们在将批处理传递给
pack_padded_sequence或pack_sequence时原始顺序重新排序。- 返回类型: