快捷键

多头注意力

class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[source][source]

允许模型联合关注来自不同表示子空间的信息。

注意

请参阅此教程,深入了解 PyTorch 提供的性能构建块,用于构建自己的 transformer 层。

方法见论文:Attention Is All You Need。

多头注意力被定义为:

MultiHead(Q,K,V)=Concat(head1,,headh)WO\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O

其中 headi=Attention(QWiQ,KWiK,VWiV)\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

当可能时,将使用 scaled_dot_product_attention() 的优化实现 nn.MultiheadAttention

除了支持新的 scaled_dot_product_attention() 函数外,为了加速推理,MHA 将使用支持嵌套张量的 fastpath 推理,如果:

  • 自注意力正在被计算(即 querykeyvalue 是相同的张量)。

  • 输入以批处理(3D)方式与 batch_first==True 结合

  • 要么自动微分被禁用(使用 torch.inference_modetorch.no_grad ),要么没有张量参数 requires_grad

  • 训练被禁用(使用 .eval()

  • add_bias_kvFalse

  • add_zero_attnFalse

  • kdimvdim 等于 embed_dim

  • 如果传递了 NestedTensor,则既不传递 key_padding_mask 也不传递 attn_mask

  • 自动广播被禁用

如果正在使用优化的推理快速路径实现,则可以通过 query / key / value 传递 NestedTensor 来更有效地表示填充,比使用填充掩码更高效。在这种情况下,将返回 NestedTensor,并且可以期望额外的加速,其比例与输入中填充的部分成比例。

参数:
  • embed_dim – 模型的总维度。

  • num_heads – 并行注意力头数。注意 embed_dim 将被分割到 num_heads (即每个头将有维度 embed_dim // num_heads )。

  • dropout – 在 attn_output_weights 上的 Dropout 概率。默认: 0.0 (无 dropout)。

  • bias – 如果指定,则向输入/输出投影层添加偏置。默认: True

  • add_bias_kv – 如果指定,则在 dim=0 的键和值序列上添加偏置。默认: False

  • add_zero_attn – 如果指定,则在 dim=1 的键和值序列中添加一个新的零批次。默认: False

  • kdim – 键的总特征数。默认: None (使用 kdim=embed_dim )。

  • vdim – 值的总特征数。默认: None (使用 vdim=embed_dim )。

  • batch_first – 如果 True ,则输入和输出张量提供为(batch,seq,feature)。默认: False (seq,batch,feature)。

示例:

>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[source][source]

使用查询、键和值嵌入计算注意力输出。

支持可选的填充、掩码和注意力权重参数。

参数:
  • 查询(张量)- 对于未批处理的输入,查询嵌入的形状为 (L,Eq)(L, E_q) ,当 batch_first=False 时为 (L,N,Eq)(L, N, E_q) ,当 (N,L,Eq)(N, L, E_q) 时为 batch_first=True ,其中 LL 是目标序列长度, NN 是批大小, EqE_q 是查询嵌入维度 embed_dim 。查询与键值对进行比较以生成输出。参见“Attention Is All You Need”获取更多详细信息。

  • 键(张量)- 对于未批处理的输入,键嵌入的形状为 (S,Ek)(S, E_k) ,当 batch_first=False 时为 (S,N,Ek)(S, N, E_k) ,当 (N,S,Ek)(N, S, E_k) 时为 batch_first=True ,其中 SS 是源序列长度, NN 是批大小, EkE_k 是键嵌入维度 kdim 。参见“Attention Is All You Need”获取更多详细信息。

  • 值(Tensor)- 值嵌入的形状为 (S,Ev)(S, E_v) ,对于未批处理输入, (S,N,Ev)(S, N, E_v) ,当 batch_first=False(N,S,Ev)(N, S, E_v) 时为 batch_first=True ,其中 SS 是源序列长度, NN 是批处理大小, EvE_v 是值嵌入维度 vdim 。参见“Attention Is All You Need”获取更多详细信息。

  • key_padding_mask(可选[Tensor])- 如果指定,则表示要忽略的元素掩码,形状为 (N,S)(N, S) ,用于注意力机制(即视为“填充”)。对于未批处理的查询,形状应为 (S)(S) 。支持二进制和浮点掩码。对于二进制掩码, True 值表示将忽略对应的 key 值。对于浮点掩码,它将直接添加到相应的 key 值。

  • need_weights(布尔值)- 如果指定,则除了 attn_outputs 外还返回 attn_output_weights 。设置 need_weights=False 以使用优化的 scaled_dot_product_attention 并实现 MHA 的最佳性能。默认: True

  • attn_mask(可选[Tensor])- 如果指定,则是一个 2D 或 3D 掩码,用于阻止对某些位置的注意力。其形状必须为 (L,S)(L, S)(Nnum_heads,L,S)(N\cdot\text{num\_heads}, L, S) ,其中 NN 是批量大小, LL 是目标序列长度, SS 是源序列长度。2D 掩码将在批量中广播,而 3D 掩码允许每个批量的条目有不同的掩码。支持二进制和浮点掩码。对于二进制掩码, True 值表示相应的位置不允许进行注意力。对于浮点掩码,掩码值将添加到注意力权重中。如果同时提供了 attn_mask 和 key_padding_mask,则它们的数据类型应匹配。

  • average_attn_weights(布尔值)- 如果为 true,表示返回的 attn_weights 将在头部之间平均。否则, attn_weights 将按每个头部分别提供。请注意,此标志仅在 need_weights=True 时才有作用。默认: True (即平均头部权重)

  • is_causal(布尔值)- 如果指定,则将因果掩码作为注意力掩码应用。默认: False 。警告: is_causal 提供提示, attn_mask 是因果掩码。提供错误的提示可能导致执行错误,包括前向和反向兼容性。

返回类型:

tuple[torch.Tensor, Optional[torch.Tensor]]

输出:
  • attn_output - 注意力输出形状为 (L,E)(L, E) ,当输入未批处理时, (L,N,E)(L, N, E) ,当 batch_first=False(N,L,E)(N, L, E) 时, batch_first=True ,其中 LL 是目标序列长度, NN 是批大小, EE 是嵌入维度 embed_dim

  • attn_output_weights - 只有在 need_weights=True 时返回。如果 average_attn_weights=True ,则返回跨头的平均注意力权重,形状为 (L,S)(L, S) 当输入未批处理时,或 (N,L,S)(N, L, S) ,其中 NN 是批大小, LL 是目标序列长度, SS 是源序列长度。如果 average_attn_weights=False ,则返回每个头的注意力权重,形状为 (num_heads,L,S)(\text{num\_heads}, L, S) 当输入未批处理时,或 (N,num_heads,L,S)(N, \text{num\_heads}, L, S)

注意

batch_first 参数对于未批处理的输入被忽略。

merge_masks(attn_mask, key_padding_mask, query)[source][source]

确定掩码类型并在必要时合并掩码。

如果只提供一个掩码,则返回该掩码及其对应的掩码类型。如果提供两个掩码,它们将被扩展到形状 (batch_size, num_heads, seq_len, seq_len) ,通过逻辑 or 合并,并返回掩码类型 2:@param attn_mask: 注意力掩码的形状 (seq_len, seq_len) ,掩码类型 0:@param key_padding_mask: 填充掩码的形状 (batch_size, seq_len) ,掩码类型 1:@param query: 查询嵌入的形状 (batch_size, seq_len, embed_dim)

返回值:

合并掩码掩码类型:合并掩码类型(0、1 或 2)

返回类型:

merged_mask


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源