Transformer 编码器 ¶
- class torch.nn.TransformerEncoder(encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True)[source][source]¶
Transformer 编码器是一个 N 个编码层的堆叠。
注意
请参阅此教程,深入了解 PyTorch 提供的性能构建块,用于构建自己的 transformer 层。
用户可以使用相应的参数构建 BERT(https://arxiv.org/abs/1810.04805)模型。
- 参数:
编码器层(TransformerEncoderLayer)- TransformerEncoderLayer()类的实例(必需)。
num_layers(int)- 编码器中子编码器层的数量(必需)。
norm(Optional[Module])- 层归一化组件(可选)。
enable_nested_tensor (bool) – 如果为 True,输入将自动转换为嵌套张量(并在输出时转换回)。这将在填充率较高时提高 TransformerEncoder 的整体性能。默认:
True
(启用)。
- 示例::
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) >>> src = torch.rand(10, 32, 512) >>> out = transformer_encoder(src)
- forward(src, mask=None, src_key_padding_mask=None, is_causal=None)[source][source]¶
依次通过编码器层传递输入。
- 参数:
src (Tensor) – 传递给编码器的序列(必需)。
mask(可选[Tensor])- 源序列的掩码(可选)。
src_key_padding_mask(可选[Tensor])- 每批次的源键掩码(可选)。
is_causal(可选[bool])- 如果指定,则应用因果掩码。默认:
None
;尝试检测因果掩码。警告:is_causal
提供提示,mask
是因果掩码。提供错误的提示可能导致执行错误,包括前向和后向兼容性。
- 返回类型:
- 形状:
请参阅Transformer
中的文档。