备注
点击此处下载完整示例代码
torch.compile
简介
创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日
作者:威廉·文
torch.compile
是加速 PyTorch 代码的最新方法! torch.compile
通过 JIT 编译将 PyTorch 代码转换为优化的内核,同时只需进行最少的代码更改,从而加快 PyTorch 代码的运行速度。
在本教程中,我们介绍了 torch.compile
的基本用法,并展示了 torch.compile
相较于之前的 PyTorch 编译器解决方案(如 TorchScript 和 FX Tracing)的优势。
目录
必需的 pip 依赖
torch >= 2.0
torchvision
numpy
scipy
tabulate
系统要求 - C++ 编译器,例如 g++
- Python 开发包( python-devel
/ python-dev
)
注意:为了重现以下和别处记录的加速数值,建议使用现代的 NVIDIA GPU(H100、A100 或 V100)
import torch
import warnings
gpu_ok = False
if torch.cuda.is_available():
device_cap = torch.cuda.get_device_capability()
if device_cap in ((7, 0), (8, 0), (9, 0)):
gpu_ok = True
if not gpu_ok:
warnings.warn(
"GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower "
"than expected."
)
基本用法
torch.compile
已包含在最新的 PyTorch 中。在 GPU 上运行 TorchInductor 需要 Triton,它包含在 PyTorch 2.0 夜间构建版本中。如果 Triton 仍然缺失,请尝试通过 pip 安装 torchtriton
( pip install torchtriton --extra-index-url "https://download.pytorch.org/whl/nightly/cu117"
用于 CUDA 11.7)。
可以通过传递可调用对象到 torch.compile
来优化任意 Python 函数。然后我们可以用返回的优化函数替换原始函数。
def foo(x, y):
a = torch.sin(x)
b = torch.cos(y)
return a + b
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))
或者,我们也可以使用装饰器。
t1 = torch.randn(10, 10)
t2 = torch.randn(10, 10)
@torch.compile
def opt_foo2(x, y):
a = torch.sin(x)
b = torch.cos(y)
return a + b
print(opt_foo2(t1, t2))
我们还可以优化 torch.nn.Module
实例。
t = torch.randn(10, 100)
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(100, 10)
def forward(self, x):
return torch.nn.functional.relu(self.lin(x))
mod = MyModule()
opt_mod = torch.compile(mod)
print(opt_mod(t))
torch.compile 和嵌套调用
装饰器函数内的嵌套函数调用也将被编译。
以同样的方式,在编译一个模块时,所有不在跳过列表中的子模块和方法也将被编译。
class OuterModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.inner_module = MyModule()
self.outer_lin = torch.nn.Linear(10, 2)
def forward(self, x):
x = self.inner_module(x)
return torch.nn.functional.relu(self.outer_lin(x))
outer_mod = OuterModule()
opt_outer_mod = torch.compile(outer_mod)
print(opt_outer_mod(t))
我们还可以通过使用 torch.compiler.disable
来禁用某些函数的编译。假设你只想禁用 complex_function
函数的跟踪,但想在 complex_conjugate
中继续跟踪。在这种情况下,你可以使用 torch.compiler.disable(recursive=False)
选项。否则,默认是 recursive=True
。
def complex_conjugate(z):
return torch.conj(z)
@torch.compiler.disable(recursive=False)
def complex_function(real, imag):
# Assuming this function cause problems in the compilation
z = torch.complex(real, imag)
return complex_conjugate(z)
def outer_function():
real = torch.tensor([2, 3], dtype=torch.float32)
imag = torch.tensor([4, 5], dtype=torch.float32)
z = complex_function(real, imag)
return torch.abs(z)
# Try to compile the outer_function
try:
opt_outer_function = torch.compile(outer_function)
print(opt_outer_function())
except Exception as e:
print("Compilation of outer_function failed:", e)
最佳实践与建议
torch.compile
在嵌套模块和函数调用中的行为
当您使用 torch.compile
时,编译器将尝试递归编译目标函数或模块内部的所有函数调用(不包括在跳过列表中的,如内置函数、torch.*命名空间中的某些函数)。
最佳实践:
1. 最高级编译:一种方法是尽可能在最高级别进行编译(即顶层模块初始化/调用时)并在遇到过多的图断裂或错误时选择性禁用编译。如果仍然存在许多编译问题,则编译单个子组件。
2. 模块化测试:在将它们集成到更大的模型之前,使用 torch.compile
测试单个函数和模块以隔离潜在问题。
3. 选择性禁用编译:如果某些函数或子模块无法由 torch.compile 处理,请使用 torch.compiler.disable 上下文管理器递归地排除它们。
4. 首先编译叶函数:在具有多个嵌套函数和模块的复杂模型中,首先编译叶函数或模块。更多信息请参阅 TorchDynamo API 进行细粒度跟踪。
展示加速效果
现在我们来演示使用 torch.compile
可以加速真实模型。我们将通过在随机数据上评估和训练一个 torchvision
模型来比较标准 eager 模式和 torch.compile
。
在开始之前,我们需要定义一些实用函数。
# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
result = fn()
end.record()
torch.cuda.synchronize()
return result, start.elapsed_time(end) / 1000
# Generates random input and targets data for the model, where `b` is
# batch size.
def generate_data(b):
return (
torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
torch.randint(1000, (b,)).cuda(),
)
N_ITERS = 10
from torchvision.models import densenet121
def init_model():
return densenet121().to(torch.float32).cuda()
首先,让我们比较推理。
注意在调用 torch.compile
时,我们还有一个额外的 mode
参数,我们将在下面进行讨论。
model = init_model()
# Reset since we are using a different mode.
import torch._dynamo
torch._dynamo.reset()
model_opt = torch.compile(model, mode="reduce-overhead")
inp = generate_data(16)[0]
with torch.no_grad():
print("eager:", timed(lambda: model(inp))[1])
print("compile:", timed(lambda: model_opt(inp))[1])
注意到 torch.compile
的完成时间比急切模式长得多。这是因为 torch.compile
在执行过程中将模型编译成优化的内核。在我们的例子中,模型的架构没有变化,因此不需要重新编译。所以如果我们多次运行我们的优化模型,我们应该看到与急切模式相比有显著的改进。
eager_times = []
for i in range(N_ITERS):
inp = generate_data(16)[0]
with torch.no_grad():
_, eager_time = timed(lambda: model(inp))
eager_times.append(eager_time)
print(f"eager eval time {i}: {eager_time}")
print("~" * 10)
compile_times = []
for i in range(N_ITERS):
inp = generate_data(16)[0]
with torch.no_grad():
_, compile_time = timed(lambda: model_opt(inp))
compile_times.append(compile_time)
print(f"compile eval time {i}: {compile_time}")
print("~" * 10)
import numpy as np
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert(speedup > 1)
print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)
事实上,我们可以看到使用 torch.compile
运行我们的模型会导致显著的加速。加速主要来自于减少 Python 开销和 GPU 读写,因此观察到的加速可能会因模型架构和批量大小等因素而有所不同。例如,如果模型的架构简单且数据量很大,那么瓶颈将是 GPU 计算,观察到的加速可能不太显著。
你也可能看到根据选择的 mode
参数不同的加速结果。 "reduce-overhead"
模式使用 CUDA 图进一步减少 Python 的开销。对于你自己的模型,你可能需要尝试不同的模式以最大化加速。你可以在这里了解更多关于模式的信息。
你可能会注意到,第二次运行我们的模型时,使用 torch.compile
的速度比其他运行慢得多,尽管它比第一次运行快得多。这是因为 "reduce-overhead"
模式运行了一些 CUDA 图的预热迭代。
对于一般的 PyTorch 基准测试,你可以尝试使用 torch.utils.benchmark
来代替我们上面定义的 timed
函数。在这个教程中,我们编写了自己的计时函数来展示 torch.compile
的编译延迟。
现在,让我们来比较一下训练。
model = init_model()
opt = torch.optim.Adam(model.parameters())
def train(mod, data):
opt.zero_grad(True)
pred = mod(data[0])
loss = torch.nn.CrossEntropyLoss()(pred, data[1])
loss.backward()
opt.step()
eager_times = []
for i in range(N_ITERS):
inp = generate_data(16)
_, eager_time = timed(lambda: train(model, inp))
eager_times.append(eager_time)
print(f"eager train time {i}: {eager_time}")
print("~" * 10)
model = init_model()
opt = torch.optim.Adam(model.parameters())
train_opt = torch.compile(train, mode="reduce-overhead")
compile_times = []
for i in range(N_ITERS):
inp = generate_data(16)
_, compile_time = timed(lambda: train_opt(model, inp))
compile_times.append(compile_time)
print(f"compile train time {i}: {compile_time}")
print("~" * 10)
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert(speedup > 1)
print(f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)
同样,我们可以看到 torch.compile
在第一次迭代中花费的时间更长,因为它必须编译模型,但在后续迭代中,与即时模式相比,我们看到了显著的加速。
我们注意到本教程中展示的加速数值仅用于演示目的。官方的加速值可以在 TorchInductor 性能仪表板上查看。
与 TorchScript 和 FX Tracing 的比较 §
我们已经看到 torch.compile
可以加速 PyTorch 代码。那么,为什么我们还要使用 torch.compile
而不是现有的 PyTorch 编译器解决方案,如 TorchScript 或 FX Tracing 呢? torch.compile
的主要优势在于其能够以最小的现有代码改动处理任意 Python 代码。
torch.compile
可以处理的一个案例是其他编译器解决方案难以处理的数据相关控制流(如下面的 if x.sum() < 0:
行)。
def f1(x, y):
if x.sum() < 0:
return -y
return y
# Test that `fn1` and `fn2` return the same result, given
# the same arguments `args`. Typically, `fn1` will be an eager function
# while `fn2` will be a compiled function (torch.compile, TorchScript, or FX graph).
def test_fns(fn1, fn2, args):
out1 = fn1(*args)
out2 = fn2(*args)
return torch.allclose(out1, out2)
inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)
TorchScript 跟踪 f1
导致静默错误的结果,因为只有实际的流程控制路径被跟踪。
traced_f1 = torch.jit.trace(f1, (inp1, inp2))
print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))
FX 跟踪 f1
由于存在数据相关的流程控制,导致错误。
import traceback as tb
try:
torch.fx.symbolic_trace(f1)
except:
tb.print_exc()
如果我们在尝试 FX 跟踪 f1
时提供一个值给 x
,那么我们会遇到与 TorchScript 跟踪相同的问题,因为在跟踪函数中移除了数据相关的流程控制。
fx_f1 = torch.fx.symbolic_trace(f1, concrete_args={"x": inp1})
print("fx 1, 1:", test_fns(f1, fx_f1, (inp1, inp2)))
print("fx 1, 2:", test_fns(f1, fx_f1, (-inp1, inp2)))
现在,我们可以看到 torch.compile
正确处理了数据相关的流程控制。
# Reset since we are using a different mode.
torch._dynamo.reset()
compile_f1 = torch.compile(f1)
print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
print("~" * 10)
TorchScript 脚本可以处理数据相关的控制流,但这种解决方案也带来了一系列问题。具体来说,TorchScript 脚本可能需要大量代码更改,并且当使用不支持的 Python 版本时会产生错误。
在下面的示例中,我们忘记了 TorchScript 类型注解,因此收到了一个 TorchScript 错误,因为参数 y
的输入类型,一个 int
,与默认参数类型 torch.Tensor
不匹配。
def f2(x, y):
return x + y
inp1 = torch.randn(5, 5)
inp2 = 3
script_f2 = torch.jit.script(f2)
try:
script_f2(inp1, inp2)
except:
tb.print_exc()
然而, torch.compile
可以轻松处理 f2
。
compile_f2 = torch.compile(f2)
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
print("~" * 10)
与之前的编译器解决方案相比, torch.compile
在处理非 PyTorch 函数方面表现得很好。
import scipy
def f3(x):
x = x * 2
x = scipy.fft.dct(x.numpy())
x = torch.from_numpy(x)
x = x * 2
return x
TorchScript 跟踪将非 PyTorch 函数调用的结果视为常量,因此我们的结果可能被静默错误。
inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)
traced_f3 = torch.jit.trace(f3, (inp1,))
print("traced 3:", test_fns(f3, traced_f3, (inp2,)))
TorchScript 脚本和 FX 跟踪禁止非 PyTorch 函数调用。
try:
torch.jit.script(f3)
except:
tb.print_exc()
try:
torch.fx.symbolic_trace(f3)
except:
tb.print_exc()
相比之下, torch.compile
容易处理非 PyTorch 函数调用。
compile_f3 = torch.compile(f3)
print("compile 3:", test_fns(f3, compile_f3, (inp2,)))
TorchDynamo 和 FX 图 §
torch.compile
的重要组件之一是 TorchDynamo。TorchDynamo 负责将任意 Python 代码即时编译成 FX 图,然后可以进一步优化。TorchDynamo 通过在运行时分析 Python 字节码并检测对 PyTorch 操作的调用来提取 FX 图。
通常,TorchInductor 是 torch.compile
的另一个组件,它进一步将 FX 图编译成优化的内核,但 TorchDynamo 允许使用不同的后端。为了检查 TorchDynamo 输出的 FX 图,让我们创建一个自定义后端,该后端输出 FX 图并简单地返回图的无优化前向方法。
from typing import List
def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
print("custom backend called with FX graph:")
gm.graph.print_tabular()
return gm.forward
# Reset since we are using a different backend.
torch._dynamo.reset()
opt_model = torch.compile(init_model(), backend=custom_backend)
opt_model(generate_data(16)[0])
使用我们的自定义后端,我们现在可以查看 TorchDynamo 如何处理数据相关的控制流。考虑以下函数,其中 if b.sum() < 0
是数据相关控制流的来源。
def bar(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b
opt_bar = torch.compile(bar, backend=custom_backend)
inp1 = torch.randn(10)
inp2 = torch.randn(10)
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)
输出显示,TorchDynamo 提取了 3 个不同的 FX 图,对应以下代码(顺序可能与上面的输出不同):
x = a / (torch.abs(a) + 1)
b = b * -1; return x * b
return x * b
当 TorchDynamo 遇到不支持 Python 特性,如数据依赖的控制流时,它会中断计算图,让默认 Python 解释器处理不支持代码,然后继续捕获图。
让我们通过示例来调查 TorchDynamo 如何逐步执行 bar
。如果 b.sum() < 0
,那么 TorchDynamo 将运行图 1,让 Python 确定条件的结果,然后运行图 2。另一方面,如果 not b.sum() < 0
,那么 TorchDynamo 将运行图 1,让 Python 确定条件的结果,然后运行图 3。
这突出了 TorchDynamo 与之前 PyTorch 编译器解决方案之间的一个主要区别。当遇到不支持 Python 特性时,之前的解决方案要么引发错误,要么静默失败。而 TorchDynamo 则会中断计算图。
我们可以通过使用 torch._dynamo.explain
来看到 TorchDynamo 在哪里中断图:
# Reset since we are using a different backend.
torch._dynamo.reset()
explain_output = torch._dynamo.explain(bar)(torch.randn(10), torch.randn(10))
print(explain_output)
为了最大化加速,应限制图断点。我们可以通过使用 fullgraph=True
来强制 TorchDynamo 在遇到第一个图断点时抛出错误:
opt_bar = torch.compile(bar, fullgraph=True)
try:
opt_bar(torch.randn(10), torch.randn(10))
except:
tb.print_exc()
下面,我们演示了 TorchDynamo 不会在上述用于演示加速的模型上断开图。
opt_model = torch.compile(init_model(), fullgraph=True)
print(opt_model(generate_data(16)[0]))
我们可以使用 torch.export
(从 PyTorch 2.1+开始)从输入的 PyTorch 程序中提取一个单独的可导出 FX 图。导出的图旨在在不同的(即无 Python 环境)环境中运行。一个重要的限制是 torch.export
不支持图断点。请查看此教程以获取有关 torch.export
的更多详细信息。
结论 §
在本教程中,我们通过介绍 torch.compile
的基本用法、演示相对于 eager 模式的加速、与之前的 PyTorch 编译器解决方案的比较以及简要调查 TorchDynamo 及其与 FX 图的交互来介绍了 torch.compile
。我们希望您能尝试使用 torch.compile
!
脚本总运行时间:(0 分钟 0.000 秒)