快捷键

torch.nn.attention.bias.CausalBias

类 torch.nn.attention.bias.CausalBias(variant, seq_len_q, seq_len_kv)[source][source] ¶

表示因果注意力模式的偏差。有关偏差结构的概述,请参阅 CausalVariant 枚举。

此类用于定义因果(三角形)注意力偏差。构建偏差时,存在两个工厂函数: causal_upper_left()causal_lower_right()

示例:

from torch.nn.attention.bias import causal_lower_right

bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8

# Create a lower-right causal bias
attn_bias = causal_lower_right(seqlen_q, seqlen_kv)

q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)

out = F.scaled_dot_product_attention(q, k, v, attn_bias)

警告

此类是一个原型,可能随时更改。


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源