备注
点击此处下载完整示例代码
使用用户定义的 Triton 内核 torch.compile
¶
创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日
作者:Oğuz Ulgen
用户定义的 Triton 内核可用于优化模型计算中的特定部分。这些内核是用 Triton 语言编写的,该语言旨在使实现硬件性能峰值更加容易。通过使用用户定义的 Triton 内核与 torch.compile
,您可以将这些优化的计算集成到您的 PyTorch 模型中,从而可能实现显著的性能提升。
本示例演示了如何使用用户定义的 Triton 内核与 torch.compile
。
前提条件 _
在开始此菜谱之前,请确保您有以下条件:
对
torch.compile
和 Triton 的基本理解。参见:PyTorch 2.3 或更高版本
支持 Triton 的 GPU
import torch
from torch.utils._triton import has_triton
基本用法 ¶
在本例中,我们将使用 Triton 文档中的一个简单向量加法内核。仅供参考,请参阅 Triton 文档。
if not has_triton():
print("Skipping because triton is not supported on this device.")
else:
import triton
from triton import language as tl
@triton.jit
def add_kernel(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)
@torch.compile(fullgraph=True)
def add_fn(x, y):
output = torch.zeros_like(x)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
return output
x = torch.randn(4, device="cuda")
y = torch.randn(4, device="cuda")
out = add_fn(x, y)
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
高级用法 ¶
Triton 的自动调优功能是一个强大的工具,它可以自动优化您的 Triton 内核的配置参数。它会探索一系列可能的配置,并选择最适合您特定用例的最佳性能配置。
当与 torch.compile
结合使用时, triton.autotune
可以帮助确保您的 PyTorch 模型运行得尽可能高效。以下是一个使用 torch.compile
和 triton.autotune
的示例。
备注
torch.compile
仅支持 triton.autotune
的配置和关键参数。
if not has_triton():
print("Skipping because triton is not supported on this device.")
else:
import triton
from triton import language as tl
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
],
key=[],
)
@triton.jit
def add_kernel_autotuned(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)
@torch.compile(fullgraph=True)
def add_fn(x, y):
output = torch.zeros_like(x)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
add_kernel_autotuned[grid](x, y, output, n_elements)
return output
x = torch.randn(4, device="cuda")
y = torch.randn(4, device="cuda")
out = add_fn(x, y)
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
可组合性 ¶
用户定义的 Triton 内核默认不支持所有 PyTorch 子系统。这可以在以下用例中看到:
添加 CPU 回退
添加
FlopCounter
公式使用张量子类进行创作
要与额外的 PyTorch 子系统进行组合,请使用 torch.library.triton_op
。
triton_op is
是一种以结构化方式定义由一个或多个 Triton 内核支持的定制操作符的方法:与常规定制操作符( torch.library.custom_op
)一样,您可以通过 torch.library
指定与 PyTorch 子系统的交互。然而,与 torch.library.custom_op
不同,它创建了对 torch.compile
透明的可调用对象,而 torch.compile
则追踪到 triton_op
以应用优化。
这里有一张图表,说明了在将 Triton 内核与 PyTorch 集成时应使用哪个 API。
特里顿内核(无显式 |
|
|
|
---|---|---|---|
支持推理 |
是的 |
是的 |
是的 |
支持训练 |
在大多数情况下 |
是的 |
是的 |
支持 |
是的 |
是的 |
是的 |
支持 |
在大多数情况下 |
在大多数情况下 |
在所有情况下 |
torch.compile 的 trace 是否会被编译到实现中? |
是的 |
是的 |
不 |
支持 AOT 电感 |
是的 |
是的 |
不 |
支持 PyTorch 子系统,如 FlopCounterMode、CPU 回退、张量子类 |
不 |
是的 |
是的 |
使用 triton_op
包裹 Triton 内核
使用 torch.library.triton_op
包装可能调用一个或多个 Triton 内核的函数。使用 torch.library.wrap_triton
包装对 Triton 内核的调用。
from torch.library import triton_op, wrap_triton
@triton_op("mylib::mysin", mutates_args={})
def mysin(x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
n_elements = x.numel()
wrap_triton(sin_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
return out
@triton.jit
def sin_kernel(
in_ptr0,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
output = tl.sin(x)
tl.store(out_ptr + offsets, output, mask=mask)
def sin_triton(x):
out = torch.empty_like(x)
n_elements = x.numel()
sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
return out
您可以通过以下两种方式之一调用 triton_op
。
x = torch.randn(3, device="cuda")
y = mysin(x)
z = torch.ops.mylib.mysin.default(x)
assert torch.allclose(y, x.sin())
assert torch.allclose(z, x.sin())
结果的 triton_op
与 torch.compile
和 AOTInductor
兼容。
y = torch.compile(mysin)(x)
assert torch.allclose(y, x.sin())
添加训练支持
使用 register_autograd
为 triton_op
添加自动求导公式。优先使用此方法,而不是使用 torch.autograd.Function
(它与 torch.compile
有各种可组合性陷阱)。
def backward(ctx, grad):
x, = ctx.saved_tensors
return grad * x.cos()
def setup_context(ctx, inputs, output):
x, = inputs
ctx.save_for_backward(x)
mysin.register_autograd(backward, setup_context=setup_context)
注意反向传播必须是 PyTorch 理解的运算符的组合。如果您想反向传播调用 Triton 内核,则这些内核也必须包装在 triton_op
中:
@triton.jit
def cos_kernel(
in_ptr0,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
output = tl.cos(x)
tl.store(out_ptr + offsets, output, mask=mask)
@triton_op("mylib::mycos", mutates_args={})
def mycos(x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
n_elements = x.numel()
wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
return out
def backward(ctx, grad):
x, = ctx.saved_tensors
return grad * mycos(x)
def setup_context(ctx, inputs, output):
x, = inputs
ctx.save_for_backward(x)
mysin.register_autograd(backward, setup_context=setup_context)
添加 CPU 回退
Triton 内核不支持在 CPU 上运行。使用 register_kernel
添加 CPU(或任何其他设备)降级以用于 triton_op
:
@mysin.register_kernel("cpu")
def _(x):
return torch.sin(x)
x = torch.randn(3)
y = mysin(x)
assert torch.allclose(y, x.sin())
降级必须由 PyTorch 运算符组成。
添加 FlopCounter 公式
要指定在 PyTorch 的 flop 计数器下 triton 内核报告的 flops 数量,请使用 register_flop_formula
。
from torch.utils.flop_counter import FlopCounterMode, register_flop_formula
@register_flop_formula(torch.ops.mylib.mysin)
def _(x_shape):
numel = 1
for s in x_shape:
numel *= s
return numel
x = torch.randn(3, device="cuda")
FlopCounterMode
需要 tabulate。在运行以下代码之前,请确保您已安装 tabulate
或通过运行 pip install tabulate
进行安装。
>>> with FlopCounterMode() as flop_counter:
>>> y = mysin(x)
局限性
截至 PyTorch 2.3, torch.compile
对用户定义的 Triton 内核的支持包括动态形状、 torch.autograd.Function
、JIT 诱导器和 AOT 诱导器。您可以使用这些功能一起构建复杂、高性能的模型。
PyTorch 2.6 新增了 torch.library.triton_op
,该功能为 tensor 子类和其他高级特性添加了对用户定义的 Triton 内核的支持。
然而,需要注意一些限制:
Triton 特性:
triton.heuristics
可以单独使用或在triton.autotune
之前使用,但不能在triton.autotune
之后使用。这意味着如果需要同时使用triton.heuristics
和triton.autotune
,则必须首先使用triton.heuristics
。
结论 ¶
在本菜谱中,我们探讨了如何使用 torch.compile
利用用户定义的 Triton 内核。我们深入研究了简单向量加法内核的基本用法以及涉及 Triton 自动调优功能的高级用法。我们还讨论了用户定义的 Triton 内核与其他 PyTorch 特性的可组合性,并强调了当前的一些限制。