• 文档 >
  • 量化 >
  • 量化 API 参考 >
  • 多头注意力
快捷键

多头注意力

class torch.ao.nn.quantizable.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[source][source]
dequantize()[source][source]

将量化后的 MHA 转换回浮点数的实用工具。

将量化版本中使用的权重格式转换回浮点格式并非易事,这就是动机所在。

forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[source][source]
注意:::

请参阅 forward() 获取更多信息

参数:
  • query (Tensor) – 将查询和一组键值对映射到输出。详见“Attention Is All You Need”。

  • key (Tensor) – 将查询和一组键值对映射到输出。详见“Attention Is All You Need”获取更多详情。

  • value (Tensor) – 将查询和一组键值对映射到输出。详见“Attention Is All You Need”获取更多详情。

  • key_padding_mask (Optional[Tensor]) – 如果提供,指定在键中的填充元素将由注意力忽略。当给定二进制掩码且值为 True 时,对应于注意力层的值将被忽略。

  • need_weights (bool) – 输出 attn_output_weights。

  • attn_mask(可选[Tensor])- 2D 或 3D 掩码,用于阻止对某些位置的注意力。2D 掩码将广播到所有批次,而 3D 掩码允许为每个批次的条目指定不同的掩码。

返回类型:

tuple[torch.Tensor, Optional[torch.Tensor]]

形状:
  • 输入:

  • query: (L,N,E)(L, N, E) 其中 L 是目标序列长度,N 是批次大小,E 是嵌入维度。 (N,L,E)(N, L, E) 如果 batch_firstTrue

  • key: (S,N,E)(S, N, E) ,其中 S 为源序列长度,N 为批大小,E 为嵌入维度。 (N,S,E)(N, S, E) 如果 batch_firstTrue

  • value: (S,N,E)(S, N, E) 其中 S 为源序列长度,N 为批大小,E 为嵌入维度。 (N,S,E)(N, S, E) 如果 batch_firstTrue

  • key_padding_mask: (N,S)(N, S) 其中 N 为批大小,S 为源序列长度。如果提供一个布尔张量,则值为 True 的位置将被忽略,而值为 False 的位置将保持不变。

  • attn_mask: 2D mask (L,S)(L, S) 其中 L 为目标序列长度,S 为源序列长度。3D mask (Nnumheads,L,S)(N*num_heads, L, S) 其中 N 为批大小,L 为目标序列长度,S 为源序列长度。attn_mask 确保位置 i 可以关注未掩码的位置。如果提供一个布尔张量,则值为 True 的位置不允许关注,而 False 值将保持不变。如果提供一个浮点张量,它将被添加到注意力权重中。

  • 如果指定,则应用因果掩码作为注意力掩码。与提供 attn_mask 互斥。默认: False

  • average_attn_weights:如果为 true,表示返回的 attn_weights 将在头部之间平均。否则,提供每个头部的 attn_weights 。注意,此标志仅在 need_weights=True. 时才有效。默认:True(即在头部之间平均权重)

  • 输出:

  • attn_output: (L,N,E)(L, N, E) ,其中 L 是目标序列长度,N 是批次大小,E 是嵌入维度。 (N,L,E)(N, L, E) 如果 batch_firstTrue

  • attn_output_weights:如果 average_attn_weights=True ,则返回平均头部注意力的形状为 (N,L,S)(N, L, S) ,其中 N 是批次大小,L 是目标序列长度,S 是源序列长度。如果 average_attn_weights=False ,则返回每个头部的注意力权重形状为 (N,numheads,L,S)(N, num_heads, L, S)


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

查看 PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源