多头注意力
- class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[source][source]¶
允许模型联合关注来自不同表示子空间的信息。
注意
请参阅此教程,深入了解 PyTorch 提供的性能构建块,用于构建自己的 transformer 层。
方法见论文:Attention Is All You Need。
多头注意力被定义为:
其中 。
当可能时,将使用
scaled_dot_product_attention()
的优化实现nn.MultiheadAttention
。除了支持新的
scaled_dot_product_attention()
函数外,为了加速推理,MHA 将使用支持嵌套张量的 fastpath 推理,如果:自注意力正在被计算(即
query
、key
和value
是相同的张量)。输入以批处理(3D)方式与
batch_first==True
结合要么自动微分被禁用(使用
torch.inference_mode
或torch.no_grad
),要么没有张量参数requires_grad
训练被禁用(使用
.eval()
)add_bias_kv
是False
add_zero_attn
是False
kdim
和vdim
等于embed_dim
如果传递了 NestedTensor,则既不传递
key_padding_mask
也不传递attn_mask
自动广播被禁用
如果正在使用优化的推理快速路径实现,则可以通过
query
/key
/value
传递 NestedTensor 来更有效地表示填充,比使用填充掩码更高效。在这种情况下,将返回 NestedTensor,并且可以期望额外的加速,其比例与输入中填充的部分成比例。- 参数:
embed_dim – 模型的总维度。
num_heads – 并行注意力头数。注意
embed_dim
将被分割到num_heads
(即每个头将有维度embed_dim // num_heads
)。dropout – 在
attn_output_weights
上的 Dropout 概率。默认:0.0
(无 dropout)。bias – 如果指定,则向输入/输出投影层添加偏置。默认:
True
。add_bias_kv – 如果指定,则在 dim=0 的键和值序列上添加偏置。默认:
False
。add_zero_attn – 如果指定,则在 dim=1 的键和值序列中添加一个新的零批次。默认:
False
。kdim – 键的总特征数。默认:
None
(使用kdim=embed_dim
)。vdim – 值的总特征数。默认:
None
(使用vdim=embed_dim
)。batch_first – 如果
True
,则输入和输出张量提供为(batch,seq,feature)。默认:False
(seq,batch,feature)。
示例:
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
- forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[source][source]¶
使用查询、键和值嵌入计算注意力输出。
支持可选的填充、掩码和注意力权重参数。
- 参数:
查询(张量)- 对于未批处理的输入,查询嵌入的形状为 ,当
batch_first=False
时为 ,当 时为batch_first=True
,其中 是目标序列长度, 是批大小, 是查询嵌入维度embed_dim
。查询与键值对进行比较以生成输出。参见“Attention Is All You Need”获取更多详细信息。键(张量)- 对于未批处理的输入,键嵌入的形状为 ,当
batch_first=False
时为 ,当 时为batch_first=True
,其中 是源序列长度, 是批大小, 是键嵌入维度kdim
。参见“Attention Is All You Need”获取更多详细信息。值(Tensor)- 值嵌入的形状为 ,对于未批处理输入, ,当
batch_first=False
或 时为batch_first=True
,其中 是源序列长度, 是批处理大小, 是值嵌入维度vdim
。参见“Attention Is All You Need”获取更多详细信息。key_padding_mask(可选[Tensor])- 如果指定,则表示要忽略的元素掩码,形状为 ,用于注意力机制(即视为“填充”)。对于未批处理的查询,形状应为 。支持二进制和浮点掩码。对于二进制掩码,
True
值表示将忽略对应的key
值。对于浮点掩码,它将直接添加到相应的key
值。need_weights(布尔值)- 如果指定,则除了
attn_outputs
外还返回attn_output_weights
。设置need_weights=False
以使用优化的scaled_dot_product_attention
并实现 MHA 的最佳性能。默认:True
。attn_mask(可选[Tensor])- 如果指定,则是一个 2D 或 3D 掩码,用于阻止对某些位置的注意力。其形状必须为 或 ,其中 是批量大小, 是目标序列长度, 是源序列长度。2D 掩码将在批量中广播,而 3D 掩码允许每个批量的条目有不同的掩码。支持二进制和浮点掩码。对于二进制掩码,
True
值表示相应的位置不允许进行注意力。对于浮点掩码,掩码值将添加到注意力权重中。如果同时提供了 attn_mask 和 key_padding_mask,则它们的数据类型应匹配。average_attn_weights(布尔值)- 如果为 true,表示返回的
attn_weights
将在头部之间平均。否则,attn_weights
将按每个头部分别提供。请注意,此标志仅在need_weights=True
时才有作用。默认:True
(即平均头部权重)is_causal(布尔值)- 如果指定,则将因果掩码作为注意力掩码应用。默认:
False
。警告:is_causal
提供提示,attn_mask
是因果掩码。提供错误的提示可能导致执行错误,包括前向和反向兼容性。
- 返回类型:
tuple[torch.Tensor, Optional[torch.Tensor]]
- 输出:
attn_output - 注意力输出形状为 ,当输入未批处理时, ,当
batch_first=False
或 时,batch_first=True
,其中 是目标序列长度, 是批大小, 是嵌入维度embed_dim
。attn_output_weights - 只有在
need_weights=True
时返回。如果average_attn_weights=True
,则返回跨头的平均注意力权重,形状为 当输入未批处理时,或 ,其中 是批大小, 是目标序列长度, 是源序列长度。如果average_attn_weights=False
,则返回每个头的注意力权重,形状为 当输入未批处理时,或 。
注意
batch_first 参数对于未批处理的输入被忽略。
- merge_masks(attn_mask, key_padding_mask, query)[source][source]¶
确定掩码类型并在必要时合并掩码。
如果只提供一个掩码,则返回该掩码及其对应的掩码类型。如果提供两个掩码,它们将被扩展到形状
(batch_size, num_heads, seq_len, seq_len)
,通过逻辑or
合并,并返回掩码类型 2:@param attn_mask: 注意力掩码的形状(seq_len, seq_len)
,掩码类型 0:@param key_padding_mask: 填充掩码的形状(batch_size, seq_len)
,掩码类型 1:@param query: 查询嵌入的形状(batch_size, seq_len, embed_dim)
- 返回值:
合并掩码掩码类型:合并掩码类型(0、1 或 2)
- 返回类型:
merged_mask