• 文档 >
  • torch.nn >
  • torch.nn.utils.rnn.pack_padded_sequence
快捷键

torch.nn.utils.rnn.pack_padded_sequence

torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)[source][source]

包含可变长度填充序列的 Tensor。

input 的大小可以是 T x B x * (如果 batch_firstFalse )或 B x T x * (如果 batch_firstTrue ),其中 T 是最长序列的长度, B 是批大小, * 是任意数量的维度(包括 0)。

对于未排序的序列,请使用 enforce_sorted = False。如果 enforce_sortedTrue ,则序列应按长度降序排序,即 input[:,0] 应为最长序列, input[:,B-1] 为最短序列。enforce_sorted = True 仅在 ONNX 导出时必要。

这是 pad_packed_sequence() 的逆操作,因此 pad_packed_sequence() 可以用来恢复 PackedSequence 中打包的底层张量。

注意

此函数接受至少具有两个维度的任何输入。您可以将其应用于打包标签,并使用 RNN 的输出与它们一起计算损失。可以通过访问 PackedSequence 对象的 .data 属性从 Tensor 中检索张量。

参数:
  • 输入(Tensor)- 可变长度的填充批次序列。

  • lengths(Tensor 或 list(int))- 每个批次元素的序列长度列表(如果提供为张量,则必须在 CPU 上)。

  • batch_first(bool,可选)- 如果 True ,则期望输入为 B x T x * 格式,否则为 T x B x * 。默认: False

  • enforce_sorted (bool, 可选) – 如果 True ,则期望输入按长度降序排序的序列。如果 False ,则无条件排序。默认: True

返回类型:

PackedSequence

警告

input 张量的维度将被截断,如果其长度大于 length 中的对应值。

返回值:

一个 PackedSequence 对象

返回类型:

PackedSequence


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源