• 文档 >
  • torch.nn >
  • torch.nn.utils.rnn.pad_packed_sequence
快捷键

torch.nn.utils.rnn.pad_packed_sequence

torch.nn.utils.rnn.pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None)[source][source]

填充可变长度序列的打包批次。

这是 pack_padded_sequence() 的逆操作。

返回的 Tensor 数据大小将为 T x B x * (如果 batch_firstFalse )或 B x T x * (如果 batch_firstTrue ),其中 T 是最长序列的长度, B 是批大小。

示例

>>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
>>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]])
>>> lens = [2, 1, 3]
>>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
>>> packed
PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
               sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
>>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
>>> seq_unpacked
tensor([[1, 2, 0],
        [3, 0, 0],
        [4, 5, 6]])
>>> lens_unpacked
tensor([2, 1, 3])

注意

total_lengthModule 中实现 pack sequence -> recurrent network -> unpack sequence 模式很有用,其中 DataParallelDataParallel 包裹。请参阅此 FAQ 部分以获取详细信息。

参数:
  • 序列(PackedSequence)- 批填充

  • batch_first(布尔值,可选)- 如果 True ,输出将采用 B x T x * 格式,否则为 T x B x *

  • 填充值(浮点数,可选)- 填充元素的值。

  • total_length(整数,可选)- 如果不 None ,输出将被填充以达到长度 total_length 。如果 total_length 小于 sequence 中的最大序列长度,此方法将抛出 ValueError

返回值:

包含填充序列的 Tensor 元组,以及包含每个序列长度的列表的 Tensor。批处理元素将按它们在将批处理传递给 pack_padded_sequencepack_sequence 时原始顺序重新排序。

返回类型:

tuple[torch.Tensor, torch.Tensor]


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源