torch.nn.attention.flex_attention¶
- torch.nn.attention.flex_attention.flex_attention(查询,键,值,score_mod=None,block_mask=None,scale=None,enable_gqa=False,return_lse=False,kernel_options=None)[source][source] ¶
此函数实现了具有任意注意力分数修改函数的缩放点积注意力。
此函数计算查询、键和值张量之间的缩放点积注意力,并使用用户定义的注意力分数修改函数。注意力分数修改函数将在计算查询和键张量之间的注意力分数之后应用。注意力分数的计算方法如下:
score_mod
函数应具有以下签名:def score_mod( score: Tensor, batch: Tensor, head: Tensor, q_idx: Tensor, k_idx: Tensor ) -> Tensor:
- 位置:
score
:表示注意力分数的标量张量,其数据类型和设备与查询、键和值张量相同。batch
,head
,q_idx
,k_idx
:分别表示批处理索引、查询头索引、查询索引和键/值索引的标量张量。这些张量应具有torch.int
数据类型,并且位于与分数张量相同的设备上。
- 参数:
查询(张量)- 查询张量;形状 。
密钥(张量)- 密钥张量;形状 。
值(张量)- 值张量;形状 。
score_mod(可选[可调用])- 修改注意力分数的函数。默认不应用 score_mod。
block_mask (Optional[BlockMask]) – 控制注意力稀疏模式的 BlockMask 对象。
scale (Optional[float]) – 在 softmax 之前应用的缩放因子。如果没有指定,默认值设置为 。enable_gqa (bool) – 如果设置为 True,则启用分组查询注意力(GQA)并将键/值头广播到查询头。
return_lse (bool) – 是否返回注意力分数的 logsumexp。默认为 False。
kernel_options(可选[Dict[str, Any]])- 将传递给 Triton 内核的选项。
- 返回:
注意力输出;形状 。
- 返回类型:
输出(张量)
- 形状说明:
警告
torch.nn.attention.flex_attention 是 PyTorch 中的一个原型功能。请期待 PyTorch 未来版本中更稳定的实现。更多关于功能分类的信息请参阅:https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
BlockMask 工具
- torch.nn.attention.flex_attention.create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN, device='cuda', BLOCK_SIZE=128, _compile=False)[source][source]¶
此函数从 mask_mod 函数创建一个块掩码元组。
- 参数:
mask_mod (Callable) – mask_mod 函数。这是一个可调用的函数,用于定义注意力机制的掩码模式。它接受四个参数:b(批大小)、h(头数)、q_idx(查询索引)和 kv_idx(键/值索引)。它应该返回一个布尔张量,指示哪些注意力连接被允许(True)或被掩码(False)。
B (int) – 批大小。
H (int) – 查询头数。
Q_LEN (int) – 查询序列长度。
KV_LEN(int)- 键/值序列长度。
device(str)- 在其上运行掩码创建的设备。
BLOCK_SIZE(int 或 tuple[int, int])- 块掩码的块大小。如果提供一个单独的 int,则用于查询和键/值。
- 返回:
包含块掩码信息的 BlockMask 对象。
- 返回类型:
- 示例用法:
def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx block_mask = create_block_mask(causal_mask, 1, 1, 8192, 8192, device="cuda") query = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) key = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) value = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) output = flex_attention(query, key, value, block_mask=block_mask)
- torch.nn.attention.flex_attention.create_mask(mod_fn, B, H, Q_LEN, KV_LEN, device='cuda')[source][source]¶
此函数从 mod_fn 函数创建一个掩码张量。
- 参数:
mod_fn (Union[_score_mod_signature, _mask_mod_signature]) – 函数用于修改注意力分数。
B (int) – 批量大小。
H (int) – 查询头数量。
Q_LEN (int) – 查询序列长度。
KV_LEN(int)- 键/值序列长度。
device(str)- 在其上运行掩码创建的设备。
- 返回:
形状为(B,H,M,N)的掩码张量。
- 返回类型:
mask(张量)
- torch.nn.attention.flex_attention.create_nested_block_mask(mask_mod, B, H, q_nt, kv_nt=None, BLOCK_SIZE=128, _compile=False)[source][source]¶
此函数从 mask_mod 函数创建一个与嵌套张量兼容的块掩码元组。返回的 BlockMask 将位于输入嵌套张量指定的设备上。
- 参数:
mask_mod (Callable) – mask_mod 函数。这是一个可调用的函数,用于定义注意力机制的掩码模式。它接受四个参数:b(批大小)、h(头数)、q_idx(查询索引)和 kv_idx(键/值索引)。它应返回一个布尔张量,指示哪些注意力连接被允许(True)或屏蔽(False)。
B (int) – 批大小。
H (int) – 查询头数量。
q_nt (torch.Tensor) – 杂乱布局嵌套张量(NJT),用于定义查询的序列长度结构。块掩码将构建为在 NJT 的序列长度
S
上操作“堆叠序列”的长度sum(S)
。默认:Nonekv_nt (torch.Tensor) – 杂乱布局嵌套张量(NJT),用于定义键/值的序列长度结构,允许交叉注意力。块掩码将构建为在 NJT 的序列长度
S
上操作“堆叠序列”的长度sum(S)
。如果此值为 None,则使用q_nt
定义键/值的结构。默认:NoneBLOCK_SIZE (int 或 tuple[int, int]) – 块掩码的块大小。如果提供一个单个 int,则用于查询和键/值。默认:None
- 返回:
包含块掩码信息的 BlockMask 对象。
- 返回类型:
- 示例用法:
# shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch query = torch.nested.nested_tensor(..., layout=torch.jagged) key = torch.nested.nested_tensor(..., layout=torch.jagged) value = torch.nested.nested_tensor(..., layout=torch.jagged) def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True) output = flex_attention(query, key, value, block_mask=block_mask)
# shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch query = torch.nested.nested_tensor(..., layout=torch.jagged) key = torch.nested.nested_tensor(..., layout=torch.jagged) value = torch.nested.nested_tensor(..., layout=torch.jagged) def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx # cross attention case: pass both query and key/value NJTs block_mask = create_nested_block_mask(causal_mask, 1, 1, query, key, _compile=True) output = flex_attention(query, key, value, block_mask=block_mask)
- torch.nn.attention.flex_attention.and_masks(*mask_mods)[source][source]¶
返回一个与提供的 mask_mods 交集的 mask_mod
- 返回类型:
可调用函数 [Tensor, Tensor, Tensor, Tensor] -> Tensor
- torch.nn.attention.flex_attention.or_masks(*mask_mods)[source][source]
返回一个与提供的 mask_mods 并集的 mask_mod
- 返回类型:
可调用[[Tensor, Tensor, Tensor, Tensor], Tensor]
块掩码
- class torch.nn.attention.flex_attention.BlockMask(seq_lengths, kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices, q_num_blocks, q_indices, full_q_num_blocks, full_q_indices, BLOCK_SIZE, mask_mod)[source][source]¶
BlockMask 是我们表示块稀疏注意力掩码的格式。它在 BCSR 和非稀疏格式之间有所交叉。
块稀疏掩码意味着,我们不是表示掩码中单个元素的稀疏性,而是将 KV_BLOCK_SIZE x Q_BLOCK_SIZE 的块视为稀疏的,只有当该块中的所有元素都是稀疏的时候。这与硬件相吻合,因为硬件通常期望执行连续的加载和计算。
此格式主要针对 1. 简单性和 2. 内核效率进行优化。值得注意的是,它没有针对大小进行优化,因为这个掩码总是减少了 KV_BLOCK_SIZE * Q_BLOCK_SIZE 的因子。如果大小是一个问题,可以通过增加块大小来减小张量的大小。
我们格式的要点是:
num_blocks_in_row: Tensor[ROWS]: 描述每行中存在的块的数量。
col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL]: col_indices[i] 是行 i 中块位置的序列。此行在 col_indices[i][num_blocks_in_row[i]] 之后的所有值都是未定义的。
例如,要从这种格式中重建原始张量:
dense_mask = torch.zeros(ROWS, COLS) for row in range(ROWS): for block_idx in range(num_blocks_in_row[row]): dense_mask[row, col_indices[row, block_idx]] = 1
显然,这种格式使得对掩码行进行降维实现起来更加容易。
我们格式的最基本要求只需要 kv_num_blocks 和 kv_indices。但是,这个对象最多可以有 8 个张量。这代表了 4 对:
1. (kv_num_blocks, kv_indices):用于注意力的正向传递,因为我们是在 KV 维度上进行降维。
2. [可选] (full_kv_num_blocks, full_kv_indices):这是可选的,纯粹是优化。实际上,对每个块应用掩码是非常昂贵的!如果我们具体知道哪些块是“完整”的,并且根本不需要掩码,那么我们可以跳过对这些块应用 mask_mod。这要求用户从 score_mod 中分离出一个单独的 mask_mod。对于因果掩码,这大约可以提升 15% 的速度。
3. [自动生成] (q_num_blocks, q_indices):反向传播所必需,因为计算 dKV 需要沿着 Q 维度遍历掩码。这些将自动从 1 生成。
4. [自动生成] (full_q_num_blocks, full_q_indices):与上面相同,但用于反向传播。这些将自动从 2 生成。
- BLOCK_SIZE 元组[int, int] ¶
- as_tuple(flatten=True)[源代码][源代码] ¶
返回 BlockMask 属性的元组。
- 参数:
flatten (布尔值) – 如果为 True,则将(KV_BLOCK_SIZE, Q_BLOCK_SIZE)元组进行展平。
- @classmethod from_kv_blocks(kv_num_blocks, kv_indices, full_kv_num_blocks=None, full_kv_indices=None, BLOCK_SIZE=128, mask_mod=None, seq_lengths=None)[source][source] ¶
从键值块信息创建 BlockMask 实例。
- 参数:
kv_num_blocks (Tensor) – 每个 Q_BLOCK_SIZE 行瓦片中的 kv_blocks 数量。
kv_indices (Tensor) – 每个 Q_BLOCK_SIZE 行瓦片中的键值块索引。
full_kv_num_blocks (Optional[Tensor]) – 每个 Q_BLOCK_SIZE 行瓦片中的完整 kv_blocks 数量。
full_kv_indices (Optional[Tensor]) – 每个 Q_BLOCK_SIZE 行瓦片中的完整键值块索引。
BLOCK_SIZE(联合整数,整数元组)- KV_BLOCK_SIZE x Q_BLOCK_SIZE 瓦片的大小。
mask_mod(可选[可调用])- 修改掩码的函数。
- 返回:
通过 _transposed_ordered 生成的包含完整 Q 信息的实例。
- 返回类型:
- 引发:
运行时错误 - 如果 kv_indices 维度小于 2。
断言错误 - 如果只提供了 full_kv_* 参数中的一个。
- full_kv_indices 可选 [Tensor]
- full_kv_num_blocks 可选 [Tensor]
- full_q_indicesOptional[张量] ¶
- full_q_num_blocksOptional[张量] ¶
- kv_indices 张量 ¶
- kv_num_blocks 张量 ¶
- property shape¶
- 转移到指定的设备上[源][源] ¶
将 BlockMask 移动到指定的设备
- 参数:
device (torch.device 或 str) – 将 BlockMask 移动到的目标设备。可以是 torch.device 对象或字符串(例如,‘cpu’,‘cuda:0’)。
- 返回:
创建一个新的 BlockMask 实例,其中所有张量组件都已移动到指定的设备。
- 返回类型:
注意
此方法不会就地修改原始 BlockMask。相反,它返回一个新的 BlockMask 实例,其中各个张量属性可能会或可能不会移动到指定的设备,具体取决于它们当前的设备位置。
- to_dense()[来源][来源] ¶
返回一个与块掩码等效的密集块。
- 返回类型:
- to_string(grid_size=(20, 20), limit=4)[来源][来源] ¶
返回块掩码的字符串表示。非常巧妙。
如果 grid_size 为 None,则输出未压缩版本。警告,这可能相当大!