• 文档 >
  • torch.nn >
  • Transformer 编码器
快捷键

Transformer 编码器 ¶

class torch.nn.TransformerEncoder(encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True)[source][source]

Transformer 编码器是一个 N 个编码层的堆叠。

注意

请参阅此教程,深入了解 PyTorch 提供的性能构建块,用于构建自己的 transformer 层。

用户可以使用相应的参数构建 BERT(https://arxiv.org/abs/1810.04805)模型。

参数:
  • 编码器层(TransformerEncoderLayer)- TransformerEncoderLayer()类的实例(必需)。

  • num_layers(int)- 编码器中子编码器层的数量(必需)。

  • norm(Optional[Module])- 层归一化组件(可选)。

  • enable_nested_tensor (bool) – 如果为 True,输入将自动转换为嵌套张量(并在输出时转换回)。这将在填充率较高时提高 TransformerEncoder 的整体性能。默认: True (启用)。

示例::
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
>>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src)
forward(src, mask=None, src_key_padding_mask=None, is_causal=None)[source][source]

依次通过编码器层传递输入。

参数:
  • src (Tensor) – 传递给编码器的序列(必需)。

  • mask(可选[Tensor])- 源序列的掩码(可选)。

  • src_key_padding_mask(可选[Tensor])- 每批次的源键掩码(可选)。

  • is_causal(可选[bool])- 如果指定,则应用因果掩码。默认: None ;尝试检测因果掩码。警告: is_causal 提供提示, mask 是因果掩码。提供错误的提示可能导致执行错误,包括前向和后向兼容性。

返回类型:

张量

形状:

请参阅 Transformer 中的文档。


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源