torch.nn.attention.bias.CausalBias
- class torch.nn.attention.bias.CausalBias(variant, seq_len_q, seq_len_kv)[source]
A bias representing causal attention patterns. For an overview of the bias structure, see the
CausalVariant
enum. This class is used for defining causal (triangular) attention biases. Two factory functions are provided for constructing the bias:
causal_upper_left()
and causal_lower_right()
. Example:
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right

bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8

# Create a lower-right causal bias
attn_bias = causal_lower_right(seqlen_q, seqlen_kv)

q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)

out = F.scaled_dot_product_attention(q, k, v, attn_bias)
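To illustrate how the two variants differ, the sketch below builds equivalent boolean masks by hand with torch.tril, reusing the seqlen_q=4, seqlen_kv=12 sizes from the example above. This is only an illustration of the mask alignment, not how CausalBias is implemented internally.

import torch

seqlen_q, seqlen_kv = 4, 12

# Upper-left variant: the diagonal is anchored at the top-left corner,
# so query position i attends to key positions 0..i (the standard
# behavior of is_causal=True).
upper_left = torch.tril(torch.ones(seqlen_q, seqlen_kv, dtype=torch.bool))

# Lower-right variant: the diagonal is anchored at the bottom-right corner,
# so the last query position attends to every key position.
lower_right = torch.tril(
    torch.ones(seqlen_q, seqlen_kv, dtype=torch.bool),
    diagonal=seqlen_kv - seqlen_q,
)

The two variants coincide when seq_len_q == seq_len_kv; they differ only when the query and key/value sequence lengths differ, as in the example above.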
Warning
This class is a prototype and subject to change.