快捷键

torch.nn.functional.scaled_dot_product_attention

torch.nn.functional.scaled_dot_product_attention()
scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,

is_causal=False, scale=None, enable_gqa=False) -> Tensor:

计算查询、键和值张量的缩放点积注意力,如果传递了可选的注意力掩码,则应用,如果指定的概率大于 0.0,则应用 dropout。可选的缩放参数只能作为关键字参数指定。

# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_mask + attn_bias

    if enable_gqa:
        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

警告

此函数处于测试阶段,可能会发生变化。

警告

此函数始终根据指定的 dropout_p 参数应用 dropout。为了在评估期间禁用 dropout,请确保在调用此函数的模块不在训练模式时传递 0.0 的值。

例如:

class MyModel(nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, ...):
        return F.scaled_dot_product_attention(...,
            dropout_p=(self.p if self.training else 0.0))

注意

目前支持三种缩放点积注意力的实现:

该函数在使用 CUDA 后端时可能会调用优化内核以提高性能。对于所有其他后端,将使用 PyTorch 实现。

所有实现默认启用。缩放点积注意力尝试根据输入自动选择最优化实现。为了提供更细粒度的控制,以下函数用于启用和禁用实现。上下文管理器是首选机制:

  • torch.nn.attention.sdpa_kernel() :用于启用或禁用任何实现的上下文管理器。

  • 全局启用或禁用 FlashAttention。

  • 全局启用或禁用内存高效注意力。

  • 全局启用或禁用 PyTorch C++实现。

每个融合内核都有特定的输入限制。如果用户需要使用特定的融合实现,请使用 torch.nn.attention.sdpa_kernel() 禁用 PyTorch C++实现。如果融合实现不可用,将发出警告,说明为什么融合实现无法运行。

由于融合浮点运算的特性,此函数的输出可能因所选后端内核而异。C++ 实现支持 torch.float64,当需要更高精度时可以使用。对于数学后端,如果输入为 torch.half 或 torch.bfloat16,所有中间结果都保持在 torch.float。

更多信息请参阅数值精度

分组查询注意力(GQA)是一个实验性功能。目前它仅适用于 CUDA 张量上的 Flash_attention 和数学内核,不支持嵌套张量。GQA 的限制条件:

  • 查询头数 % 键值头数 == 0 且

  • 头数键 == 头数值

注意

在某些情况下,当在 CUDA 设备上给定张量并使用 CuDNN 时,此运算符可能会选择非确定性算法以提高性能。如果这不可取,您可以尝试通过设置 torch.backends.cudnn.deterministic = True 来使操作确定性(可能以性能成本为代价)。有关更多信息,请参阅可重现性。

参数:
  • 查询(张量)- 查询张量;形状 (N,...,Hq,L,E)(N, ..., Hq, L, E)

  • 密钥(张量)- 密钥张量;形状 (N,...,H,S,E)(N, ..., H, S, E)

  • 值(张量)- 值张量;形状 (N,...,H,S,Ev)(N, ..., H, S, Ev)

  • attn_mask(可选张量)- 注意力掩码;形状必须可广播到注意力权重形状,即 (N,...,L,S)(N,..., L, S) 。支持两种掩码类型。一种为布尔掩码,其中 True 值表示元素应参与注意力计算。另一种为与查询、键、值相同类型的浮点掩码,将其添加到注意力分数中。

  • dropout_p(浮点数)- Dropout 概率;如果大于 0.0,则应用 dropout

  • is_causal(布尔值)- 如果设置为 true,则当掩码为正方形矩阵时,注意力掩码为下三角矩阵。当掩码为非正方形矩阵时,由于对齐产生的注意力掩码形式为上左因果偏差(见 torch.nn.attention.bias.CausalBias )。如果同时设置了 attn_mask 和 is_causal,则抛出错误。

  • scale(可选 python:float,关键字参数)- 在 softmax 之前应用的缩放因子。如果为 None,则默认设置为 1E\frac{1}{\sqrt{E}}

  • enable_gqa(布尔值)- 如果设置为 True,则启用分组查询注意力(GQA),默认设置为 False。

返回值:

注意力输出;形状 (N,...,Hq,L,Ev)(N, ..., Hq, L, Ev)

返回类型:

输出(张量)

形状说明:
  • N:Batch size...:Any number of other batch dimensions (optional)N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}

  • S:Source sequence lengthS: \text{Source sequence length}

  • L:Target sequence lengthL: \text{Target sequence length}

  • E:Embedding dimension of the query and keyE: \text{Embedding dimension of the query and key}

  • Ev:Embedding dimension of the valueEv: \text{Embedding dimension of the value}

  • Hq:Number of heads of queryHq: \text{Number of heads of query}

  • H:Number of heads of key and valueH: \text{Number of heads of key and value}

示例

>>> # Optionally use the context manager to ensure one of the fused kernels is run
>>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
>>>     F.scaled_dot_product_attention(query,key,value)
>>> # Sample for GQA for llama3
>>> query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda")
>>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> with sdpa_kernel(backends=[SDPBackend.MATH]):
>>>     F.scaled_dot_product_attention(query,key,value,enable_gqa=True)

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源