• 文档 >
  • torch.nn >
  • TransformerDecoder
快捷键

TransformerDecoder

class torch.nn.TransformerDecoder(decoder_layer, num_layers, norm=None)[source][source]

TransformerDecoder 是一个由 N 个解码器层堆叠而成的结构。

注意

请参阅此教程,深入了解 PyTorch 提供的性能构建块,用于构建自己的 transformer 层。

参数:
  • decoder_layer (TransformerDecoderLayer) – TransformerDecoderLayer() 类的一个实例(必需)。

  • num_layers (int) – 解码器中子解码器层的数量(必需)。

  • norm (Optional[Module]) – 层归一化组件(可选)。

示例::
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
>>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
>>> memory = torch.rand(10, 32, 512)
>>> tgt = torch.rand(20, 32, 512)
>>> out = transformer_decoder(tgt, memory)
forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_is_causal=None, memory_is_causal=False)[source][source]

依次将输入(和掩码)通过解码器层。

参数:
  • tgt(张量)- 解码器的序列(必需)。

  • memory(张量)- 编码器最后一层的序列(必需)。

  • tgt_mask(可选[Tensor])- tgt 序列的掩码(可选)。

  • memory_mask(可选[Tensor])- 内存序列的掩码(可选)。

  • tgt_key_padding_mask(可选[Tensor])- 每批次的 tgt 键掩码(可选)。

  • memory_key_padding_mask(可选[Tensor])- 每批次的内存键掩码(可选)。

  • tgt_is_causal(可选[bool])- 如果指定,则应用因果掩码。默认: None ;尝试检测因果掩码。警告: tgt_is_causal 提供提示, tgt_mask 是因果掩码。提供错误的提示可能导致执行错误,包括前向和后向兼容性。

  • memory_is_causal (bool) – 如果指定,则应用因果掩码作为 memory mask 。默认: False 。警告: memory_is_causal 提供了 memory_mask 是因果掩码的提示。提供错误的提示可能导致执行错误,包括向前和向后兼容性。

返回类型:

张量

形状:

查看文档在 Transformer


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源