(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)
Created On: Apr 01, 2025 | Last Updated: Apr 01, 2025 | Last Verified: Nov 05, 2024
Author: Driss Guessous
Summary
In this tutorial, we want to highlight a new torch.nn.functional function that can be helpful for implementing transformer architectures. The function is named torch.nn.functional.scaled_dot_product_attention. For a detailed description of the function, see the PyTorch documentation. This function has already been incorporated into torch.nn.MultiheadAttention and torch.nn.TransformerEncoderLayer.
Overview
At a high level, this PyTorch function calculates the scaled dot product attention (SDPA) between query, key, and value according to the definition found in the paper Attention is all you need. While this function can be written in PyTorch using existing functions, a fused implementation can provide large performance benefits over a naive implementation.
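For reference, the computation can be written with ordinary PyTorch ops (a softmax of the scaled query-key product, followed by a matmul with the values). The sketch below is ours, not part of the PyTorch API, and the helper name naive_scaled_dot_product_attention is hypothetical; it only illustrates what the fused kernels compute, ignoring dropout.

import math
import torch

def naive_scaled_dot_product_attention(query, key, value, is_causal=False):
    # Illustrative only: query, key, value have shape (..., seq_len, head_dim)
    scale = 1.0 / math.sqrt(query.size(-1))
    # Raw attention scores: (..., seq_len_q, seq_len_k)
    attn_scores = query @ key.transpose(-2, -1) * scale
    if is_causal:
        # Mask out positions where a query would attend to a later key
        seq_len_q, seq_len_k = query.size(-2), key.size(-2)
        causal_mask = torch.ones(seq_len_q, seq_len_k, dtype=torch.bool, device=query.device).tril()
        attn_scores = attn_scores.masked_fill(~causal_mask, float("-inf"))
    attn_weights = torch.softmax(attn_scores, dim=-1)
    return attn_weights @ value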
Fused implementations
For CUDA tensor inputs, the function will dispatch into one of the following implementations:
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- Memory-Efficient Attention
- A PyTorch implementation defined in C++
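If you want to check which of these backends are currently allowed before benchmarking, torch.backends.cuda exposes global toggles for each of them. The snippet below is a minimal sketch using those flags; note that they report whether a backend is enabled, not whether your particular inputs or hardware can actually use it.

import torch

# Global on/off switches for the CUDA SDPA backends (all enabled by default)
print("Flash SDP enabled:           ", torch.backends.cuda.flash_sdp_enabled())
print("Memory-efficient SDP enabled:", torch.backends.cuda.mem_efficient_sdp_enabled())
print("Math SDP enabled:            ", torch.backends.cuda.math_sdp_enabled())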
Note
This tutorial requires PyTorch 2.0.0 or later.
import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"
# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
Explicit Dispatcher Control
While the function will implicitly dispatch to one of the three implementations, the user can also explicitly control the dispatch by using a context manager. This context manager allows users to explicitly disable certain implementations. If a user wants to ensure the function is indeed using the fastest implementation for their specific inputs, the context manager can be used to sweep through and measure performance.
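In recent PyTorch releases, sdpa_kernel also accepts a list of backends, which gives a compact way to disable a specific implementation (here the C++ math fallback) while leaving the fused kernels available. The following is a minimal sketch under that assumption, reusing the query, key, and value tensors from the example above; if neither fused kernel supports the inputs, the call raises a RuntimeError.

from torch.nn.attention import SDPBackend, sdpa_kernel

# Allow only the two fused kernels; the math fallback is disabled inside this block.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    try:
        out = F.scaled_dot_product_attention(query, key, value)
    except RuntimeError:
        print("Neither fused kernel supports these inputs. See warnings for reasons.")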
# Lets define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6
# Lets define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32
dtype = torch.float16
query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
# Lets explore the speed of each of the 3 implementations
from torch.nn.attention import SDPBackend, sdpa_kernel
with sdpa_kernel(SDPBackend.MATH):
    math_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
    print(f"The math implementation runs in {math_time:.3f} microseconds")

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        flash_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The flash attention implementation runs in {flash_time:.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    try:
        efficient_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The memory efficient implementation runs in {efficient_time:.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")
Hardware dependence
Depending on what machine you ran the above cell on and what hardware is available, your results might be different.
- If you don't have a GPU and are running on CPU, then with FP32 the context manager will have no effect and all three runs should return similar timings.
- Depending on what compute capability your graphics card supports, flash attention or memory-efficient attention might have failed.
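To see what your GPU reports before interpreting the numbers above, you can query its compute capability. The snippet below is a small, illustrative check; as a rough guide, FlashAttention targets relatively recent cards (for example Ampere-class, sm80, and newer), while memory-efficient attention supports a broader range of devices.

import torch

if torch.cuda.is_available():
    # Returns a (major, minor) tuple, e.g. (8, 0) for an A100
    major, minor = torch.cuda.get_device_capability()
    print(f"GPU: {torch.cuda.get_device_name()} (compute capability {major}.{minor})")
else:
    print("No CUDA device available; only the C++ math implementation will be used.")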
Causal Self Attention
Below is an example implementation of a multi-headed causal self attention block inspired by Andrej Karpathy's NanoGPT repository.
class CausalSelfAttention(nn.Module):

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool = False, is_causal: bool = False, dropout: float = 0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal

    def forward(self, x):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query_projected = self.c_attn(x)

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False

        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)

        y = self.resid_dropout(self.c_proj(y))
        return y
num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)
NestedTensor and Dense tensor support
SDPA supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable length sequences, without needing to pad every sequence in the batch to the maximum length. For more information about NestedTensors, see torch.nested and the NestedTensors tutorial.
import random
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension,
                            dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )
random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)
# Currently the fused implementations don't support ``NestedTensor`` for training
model.eval()
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")
Using SDPA with torch.compile
With the release of PyTorch 2.0, a new feature called torch.compile() was introduced, which can provide significant performance improvements over eager mode. Scaled dot product attention is fully composable with torch.compile(). To demonstrate this, let's compile the CausalSelfAttention module using torch.compile() and observe the resulting performance improvements.
batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
embed_dimension, device=device, dtype=dtype)
print(
f"The non compiled module runs in {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")
compiled_model = torch.compile(model)
# Let's compile it
compiled_model(x)
print(
f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")
The exact execution time depends on the machine; however, the results for mine were: the non-compiled module ran in 166.616 microseconds and the compiled module ran in 166.726 microseconds. That is not what we were expecting. Let's dig a little deeper. PyTorch comes with an amazing built-in profiler that you can use to inspect the performance characteristics of your code.
from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Non-Compiled Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results
#
# .. code-block:: python
#
#    prof.export_chrome_trace("compiled_causal_attention_trace.json")
The previous code snippet generates a report of the top 10 PyTorch functions that consumed the most GPU execution time, for both the compiled and non-compiled module. The analysis shows that the majority of time spent on the GPU is concentrated on the same set of functions in both cases. The reason for this is that torch.compile is very good at removing the framework overhead associated with PyTorch. If your model is launching large, efficient CUDA kernels, as CausalSelfAttention does in this case, then the overhead of PyTorch can be hidden.
In reality, your module does not normally consist of a single CausalSelfAttention block. When experimenting with Andrej Karpathy's NanoGPT repository, compiling the module took the time per train step from 6090.49ms to 3273.17ms! This was done on commit ae3a8d5 of NanoGPT, training on the Shakespeare dataset.
Using SDPA with attn_bias subclasses
# As of PyTorch 2.3, we have added a new submodule that contains tensor subclasses.
# Designed to be used with ``torch.nn.functional.scaled_dot_product_attention``.
# The module is named ``torch.nn.attention.bias`` and contains the following two
# utilities for generating causal attention variants:
#
# - ``torch.nn.attention.bias.causal_upper_left``
# - ``torch.nn.attention.bias.causal_lower_right``
#
# .. note::
# The current argument ``is_causal`` in ``torch.nn.functional.scaled_dot_product_attention``
# is the same as using ``torch.nn.attention.bias.causal_upper_left``.
#
from torch.nn.attention.bias import causal_lower_right, causal_upper_left
batch_size = 32
sequence_length_q = 2
sequence_length_kv = 10
num_heads = 16
embed_dimension = 32
dtype = torch.float16
query = torch.rand(batch_size, num_heads, sequence_length_q, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
upper_left_bias = causal_upper_left(sequence_length_q, sequence_length_kv)
lower_right_bias = causal_lower_right(sequence_length_q, sequence_length_kv)
print(type(upper_left_bias))
print(type(lower_right_bias))
assert type(upper_left_bias) == type(lower_right_bias)
assert issubclass(type(upper_left_bias), torch.Tensor)
# As you can see from the previous output, both bias objects are of the same type,
# ``torch.nn.attention.bias.CausalBias``, which is a subclass of ``torch.Tensor``
# Lets see what these tensors look like
print(upper_left_bias)
print(lower_right_bias)
# Upper Left Bias aligns the causal attention mask to the upper left corner of the attention scores matrix.
# This only has an impact when the attention scores matrix is not square, which is common for decoding use cases.
# Another way of thinking about this concept is that when you use upper left bias,
# the 0th token in the query is aligned to the 0th token in the key. Assuming the
# attention score matrix is two dimensional, ``attn_score[0][0]`` is the attention score
# between the 0th token in the query and the 0th token in the key.
# For lower right bias, the sequence of q is aligned so that the last token in q is aligned
# to the last token in k (for example, ``attn_score[-1][-1]`` is all True since the last token
# in q is at the same position as the last token in k, even if the sequence lengths of q and k
# are different).
# These objects are intended to be used with sdpa
out_upper_left = F.scaled_dot_product_attention(query, key, value, upper_left_bias)
out_lower_right = F.scaled_dot_product_attention(query, key, value, lower_right_bias)
out_is_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)
assert torch.allclose(out_upper_left, out_is_causal)
assert not torch.allclose(out_upper_left, out_lower_right)
# These attention biases should also be compatible with torch.compile
compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)
Conclusion
In this tutorial, we demonstrated the basic usage of torch.nn.functional.scaled_dot_product_attention. We showed how the sdpa_kernel context manager can be used to assert that a certain implementation is used on GPU. We also built a simple CausalSelfAttention module that works with NestedTensor and is torch compilable. In the process, we showed how the profiling tools can be used to explore the performance characteristics of a user-defined module.