常见问题
作者:马克·萨拉菲姆
torch.compile
是否支持训练?¶
torch.compile
支持训练,使用 AOTAutograd 捕获反向传播:
.forward()
图和optimizer.step()
被 TorchDynamo 的 Pythonevalframe
前端捕获。对于 torchdynamo 捕获的每个段,它使用 AOTAutograd 生成反向图段。
前向和反向图对(可选)进行最小割划分,以节省前向和反向之间的最小状态。
前向和反向对被封装在
autograd.function
模块中。用户代码调用
.backward()
仍然触发 eager 的自动微分引擎,该引擎将每个编译后的反向图作为单个操作运行,同时运行任何非编译的 eager 操作.backward()
的函数。
你支持分布式代码吗?¶
torch.compile
支持 DistributedDataParallel
(DDP)。正在考虑支持其他分布式训练库。
分布式代码在 dynamo 中具有挑战性的主要原因是 AOTAutograd 展开正向和反向传递,并为后端提供 2 个图进行优化。这对于分布式代码来说是一个问题,因为我们理想情况下希望重叠通信操作与计算。Eager PyTorch 通过不同的方式为 DDP/FSDP 实现这一点——使用自动微分钩子、模块钩子和模块状态的修改/变异。
使用 Dynamo 优化 DDP 的基本策略在 distributed.py 中概述,主要思想是在 DDP 桶边界上进行图拆分。
当 DDP 中的每个节点需要与其他节点同步其权重时,它会将其梯度和参数组织成桶,这减少了通信时间,并允许节点向其他等待节点广播其梯度的一部分。
分布式代码中的图拆分意味着您可以期望 Dynamo 及其后端优化分布式程序的计算开销,但不是通信开销。图拆分可能会干扰编译速度提升,如果减少的图大小剥夺了编译器的融合机会。然而,随着图大小的增加,收益递减,因为大多数当前的计算优化都是局部融合。因此,在实践中,这种方法可能是足够的。
我还需要导出整个图吗?
对于绝大多数模型,你可能不需要,你可以直接使用 torch.compile()
,但在某些情况下,需要完整的图,你可以通过简单地运行 torch.compile(..., fullgraph=True)
来确保完整的图。这些情况包括:
大规模训练运行,如$250K+,需要管道并行和其他高级分片策略。
依赖于比训练优化器更激进融合的推理优化器,如 TensorRT 或 AITemplate。
移动训练或推理。
未来工作将包括将通信操作追踪到图中,协调这些操作与计算优化,并优化通信操作。
为什么我的代码崩溃了? ¶
如果你的代码在没有 torch.compile
的情况下运行良好,而启用后开始崩溃,那么最重要的第一步是找出你的失败发生在堆栈的哪个部分。为了排查这个问题,请按照以下步骤进行,并且只有在前一步成功后才能尝试下一步。
torch.compile(..., backend="eager")
只运行 TorchDynamo 前向图捕获,然后使用 PyTorch 运行捕获到的图。如果这失败了,那么就有可能是 TorchDynamo 的问题。torch.compile(..., backend="aot_eager")
使用 TorchDynamo 捕获前向图,然后使用 AOTAutograd 进行反向图的跟踪,无需任何额外的后端编译步骤。PyTorch eager 将用于运行前向和反向图。如果失败,则说明 AOTAutograd 存在问题。torch.compile(..., backend="inductor")
使用 TorchDynamo 捕获前向图,然后使用 TorchInductor 编译器跟踪反向图。如果失败,则说明 TorchInductor 存在问题。
为什么编译速度慢?
Dynamo 编译 - TorchDynamo 内置了一个用于收集和显示每个编译阶段花费时间的统计函数。这些统计信息可以通过调用
torch._dynamo.utils.compile_times()
在执行torch._dynamo
后访问。默认情况下,这将返回每个 TorchDynamo 函数按名称花费的编译时间的字符串表示。电感器编译 - TorchInductor 内置统计和跟踪功能,用于显示每个编译阶段花费的时间、输出代码、输出图可视化以及 IR 转储。
env TORCH_COMPILE_DEBUG=1 python repro.py
。这是一个调试工具,旨在通过类似以下输出的方式使调试/理解 TorchInductor 的内部结构更加容易。该调试跟踪中的每个文件都可以通过torch._inductor.config.trace.*
启用/禁用。由于生成成本较高,默认情况下禁用了配置文件和图表。请参阅示例调试目录输出以获取更多示例。当 TorchDynamo 编译一个函数(或其一部分)时,它会针对局部变量和全局变量做出某些假设,以便允许编译器优化,并将这些假设表达为在运行时检查特定值的守卫。如果这些守卫中的任何一个失败,Dynamo 将重新编译该函数(或其部分)多达
torch._dynamo.config.recompile_limit
次。如果你的程序正在达到缓存限制,你首先需要确定哪个守卫失败以及你的程序中的哪个部分触发了它。重新编译分析器会自动将 TorchDynamo 的缓存限制设置为 1,并在仅记录任何守卫失败原因的观察-only '编译器'下运行你的程序。你应该确保你的程序运行的时间(或迭代次数)至少与你在遇到问题时运行的时间一样长,分析器将在此期间累积统计数据。
为什么在生产环境中你会重新编译?
在某些情况下,您可能不希望在程序预热后出现意外的编译。例如,如果您正在处理一个对延迟敏感的生产流量。为此,TorchDynamo 提供了一个替代模式,其中使用先前编译的图,但不会生成新的图:
frozen_toy_example = dynamo.run(toy_example)
frozen_toy_example(torch.randn(10), torch.randn(10))
你是如何加速我的代码的? ¶
加速 PyTorch 代码有 3 种主要方法:
通过垂直融合进行内核融合,将顺序操作融合在一起以避免过多的读写。例如,融合两个连续的余弦函数意味着您可以进行 1 次读取和 1 次写入,而不是 2 次读取和 2 次写入。水平融合:最简单的例子是批处理,其中单个矩阵与一批示例相乘,但更一般的情况是分组 GEMM,其中一组矩阵乘法被一起调度。
指令乱序执行:一种针对编译器的通用优化方法,通过在图中查看确切的数据依赖关系,我们可以决定执行节点的最佳时机以及哪些缓冲区可以重用
自动工作分配:与指令乱序执行类似,但通过将图中的节点与物理硬件或内存等资源匹配,我们可以设计合适的调度方案
上述内容是加速 PyTorch 代码的一般原则,但不同的后端将在优化内容上做出不同的权衡。例如,Inductor 首先负责融合尽可能多的内容,然后才生成 Triton 内核。
此外,Triton 还通过自动内存归约、内存管理和每个流多处理器内的调度来提供加速,并且已被设计用于处理分块计算。
然而,无论你使用什么后端,最好使用基准和查看方法,尝试使用 PyTorch 分析器,直观检查生成的内核,并尝试亲自看看发生了什么。
为什么我没有看到速度提升?¶
图断点
你不会看到使用 dynamo 所期望的速度提升的主要原因是图断开过多。那么什么是图断开呢?
给定一个程序,例如:
def some_fun(x):
...
torch.compile(some_fun)(x)
...
Torchdynamo 将尝试将 some_fun()
中的所有 torch/tensor 操作编译成一个单一的 FX 图,但它可能无法将所有内容都捕获到一个图中。
一些图断裂的原因对 TorchDynamo 来说是不可逾越的,例如调用 PyTorch 之外的 C 扩展,这对 TorchDynamo 来说是不可见的,并且可能执行任意操作,而 TorchDynamo 无法引入必要的保护措施来确保编译后的程序可以安全重用。
为了最大化性能,尽可能减少图断裂是很重要的。
识别图断裂的原因
要识别程序中的所有图断点及其原因,可以使用 torch._dynamo.explain
。该工具在提供的函数上运行 TorchDynamo 并汇总遇到的图断点。以下是一个示例用法:
import torch
import torch._dynamo as dynamo
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
print("woo")
if b.sum() < 0:
b = b * -1
return x * b
explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
print(explanation)
"""
Graph Count: 3
Graph Break Count: 2
Op Count: 5
Break Reasons:
Break Reason 1:
Reason: builtin: print [<class 'torch._dynamo.variables.constant.ConstantVariable'>] False
User Stack:
<FrameSummary file foo.py, line 5 in toy_example>
Break Reason 2:
Reason: generic_jump TensorVariable()
User Stack:
<FrameSummary file foo.py, line 6 in torch_dynamo_resume_in_toy_example_at_5>
Ops per Graph:
...
Out Guards:
...
"""
要在遇到第一个图断点时抛出错误,可以通过使用 fullgraph=True
禁用 Python 回退,如果您使用过基于导出的编译器,这应该很熟悉。
def toy_example(a, b):
...
torch.compile(toy_example, fullgraph=True, backend=<compiler>)(a, b)
为什么我修改了代码后代码没有重新编译?
如果您通过设置 env TORCHDYNAMO_DYNAMIC_SHAPES=1 python model.py
启用了动态形状,则在形状变化时您的代码不会重新编译。我们已经添加了对动态形状的支持,这避免了形状变化小于 2 倍因子时重新编译的情况。这在 CV 中图像大小变化或 NLP 中序列长度变化等场景中特别有用。在推理场景中,通常无法事先知道批大小,因为您需要从不同的客户端应用程序中获取尽可能多的信息。
通常情况下,TorchDynamo 会非常努力地避免不必要的重新编译,例如,如果 TorchDynamo 找到了 3 个图,而您的更改只修改了一个图,那么只有那个图会重新编译。因此,另一个避免潜在缓慢的编译时间的技巧是先编译模型一次进行预热,之后的编译将会快得多。冷启动编译时间仍然是我们跟踪的可见指标。
为什么我会得到错误的结果?
如果您设置环境变量 TORCHDYNAMO_REPRO_LEVEL=4
,可以减少精度问题,它的工作方式与 git bisect 模型类似,完整的重现可能如下 TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4
我们需要这样做的原因是下游编译器会生成代码,无论是 Triton 代码还是 C++后端,这些下游编译器的数值可能在细微之处有所不同,但会对您的训练稳定性产生重大影响。因此,精度调试器对我们检测代码生成或后端编译器中的错误非常有用。
如果您想确保 torch 和 triton 之间的随机数生成相同,则可以启用 torch._inductor.config.fallback_random = True
为什么我会遇到内存溢出(OOM)?
Dynamo 仍然是一个 alpha 产品,所以存在一些导致 OOM 的原因。如果您遇到了 OOM,请尝试以下顺序禁用以下配置,并在 GitHub 上创建一个 issue,以便我们解决根本问题:1. 如果您使用动态形状,请尝试禁用它们,我们已默认禁用: env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py
2. 在 inductor 中,CUDA graphs 与 Triton 默认启用,移除它们可能有助于缓解一些 OOM 问题: torch._inductor.config.triton.cudagraphs = False
。
torch.func
是否与 torch.compile
(用于 grad 和 vmap 转换)兼容?
将 torch.func
转换应用于使用 torch.compile
的函数是可行的:
import torch
@torch.compile
def f(x):
return torch.sin(x)
def g(x):
return torch.grad(f)(x)
x = torch.randn(2, 3)
g(x)
在函数内部处理 torch.func
调用 torch.compile
¶
使用 torch.compile
编译 torch.func.grad
¶
import torch
def wrapper_fn(x):
return torch.func.grad(lambda x: x.sin().sum())(x)
x = torch.randn(3, 3, 3)
grad_x = torch.compile(wrapper_fn)(x)
使用 torch.compile
编译 torch.vmap
¶
import torch
def my_fn(x):
return torch.vmap(lambda x: x.sum(1))(x)
x = torch.randn(3, 3, 3)
output = torch.compile(my_fn)(x)
编译除支持的函数之外的其他函数(逃生口) ¶
对于其他转换,作为解决方案,请使用 torch._dynamo.allow_in_graph
allow_in_graph
是一个逃生门。如果你的代码与 torch.compile
不兼容,其中 torch.compile
可以检查 Python 字节码,但你认为通过符号跟踪方法(如 jax.jit
)可以工作,那么请使用 allow_in_graph
。
使用 allow_in_graph
注解函数时,你必须确保你的代码满足以下要求:
函数中的所有输出仅依赖于输入,不依赖于任何捕获的张量。
您的功能是功能性的。也就是说,它不会改变任何状态。这可能被放宽;实际上,我们支持从外部看起来是功能性的函数:它们可能包含原地 PyTorch 操作,但可能不会改变全局状态或函数的输入。
您的函数不会引发数据相关的错误。
import torch
@torch.compile
def f(x):
return torch._dynamo.allow_in_graph(torch.vmap(torch.sum))(x)
x = torch.randn(2, 3)
f(x)
一个常见的陷阱是使用 allow_in_graph
来注释一个调用 nn.Module
的函数。这是因为输出现在依赖于 nn.Module
的参数。为了使其正常工作,请使用 torch.func.functional_call
来提取模块状态。
NumPy 是否与 torch.compile
兼容? ¶
从 2.1 版本开始, torch.compile
理解在 NumPy 数组上运行的本地 NumPy 程序,以及将 PyTorch 转换为 NumPy 并反向转换的混合 PyTorch-NumPy 程序,通过 x.numpy()
、 torch.from_numpy
和相关函数实现。
torch.compile
支持哪些 NumPy 功能? ¶
torch.compile
中的 NumPy 遵循 NumPy 2.0 预发布版。
通常, torch.compile
能够追踪大多数 NumPy 结构,当它无法追踪时,会回退到急切模式,并让 NumPy 执行那部分代码。即使如此,也有一些功能, torch.compile
的语义与 NumPy 略有不同:
NumPy 标量:我们将它们建模为 0 维数组。也就是说,
np.float32(3)
在torch.compile
下返回一个 0 维数组。为了避免图断裂,最好使用这个 0 维数组。如果这破坏了你的代码,你可以通过将 NumPy 标量强制转换为相关的 Python 标量类型bool/int/float
来解决这个问题。负步长:
np.flip
和使用负步长的切片会返回一个副本。类型提升:NumPy 的类型提升将在 NumPy 2.0 中发生变化。新规则在 NEP 50 中描述。
torch.compile
实现 NEP 50 而不是即将被弃用的当前规则。{tril,triu}_indices_from/{tril,triu}_indices
返回数组而不是数组的元组。
对于我们不支持追踪的其他功能,我们将优雅地回退到 NumPy 执行:
非数值数据类型,如日期时间、字符串、字符、void、结构化数据类型和记录数组。
长数据类型
np.float128/np.complex256
以及一些无符号数据类型np.uint16/np.uint32/np.uint64
。ndarray
子类。隐藏数组。
如
axes=[(n,k),(k,m)->(n,m)]
和 ufunc 方法(例如,np.add.reduce
)之类的神秘 ufunc 机制。排序/排序
complex64/complex128
数组。NumPy
np.poly1d
和np.polynomial
。函数中带有 2 个或更多返回值的参数位置(
out=tuple
也可以工作)。__array_function__
,__array_interface__
和__array_wrap__
。ndarray.ctypes
属性。
我可以使用 torch.compile
编译 NumPy 代码吗?¶
当然可以! torch.compile
可以原生理解 NumPy 代码,并将其视为 PyTorch 代码。为此,只需用 torch.compile
装饰器将 NumPy 代码包裹起来即可。
import torch
import numpy as np
@torch.compile
def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))
X = np.random.randn(1024, 64)
Y = np.random.randn(1024, 64)
Z = numpy_fn(X, Y)
assert isinstance(Z, np.ndarray)
在设置环境变量 TORCH_LOGS=output_code
后执行此示例,我们可以看到 torch.compile
能够将乘法和求和合并为一个 C++内核。它还能够使用 OpenMP 并行执行它们(原生 NumPy 是单线程的)。这可以使您的 NumPy 代码速度提高 n
倍,其中 n
是您处理器的核心数!
以这种方式跟踪 NumPy 代码也支持编译代码中的图断点。
我可以使用 torch.compile
在 CUDA 上执行 NumPy 代码并通过它计算梯度吗?¶
是的,你可以!为此,你只需在 torch.device("cuda")
上下文中执行你的代码即可。考虑以下示例
import torch
import numpy as np
@torch.compile
def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))
X = np.random.randn(1024, 64)
Y = np.random.randn(1024, 64)
with torch.device("cuda"):
Z = numpy_fn(X, Y)
assert isinstance(Z, np.ndarray)
在这个示例中, numpy_fn
将在 CUDA 中执行。为了实现这一点, torch.compile
会自动将 X
和 Y
从 CPU 移动到 CUDA,然后它将结果 Z
从 CUDA 移动到 CPU。如果我们想在同一个程序运行中多次执行此函数,我们可能希望避免这些相对昂贵的内存复制。为此,我们只需调整我们的 numpy_fn
,使其接受 cuda 张量并返回张量。我们可以通过使用 torch.compiler.wrap_numpy
来实现这一点:
@torch.compile(fullgraph=True)
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))
X = torch.randn(1024, 64, device="cuda")
Y = torch.randn(1024, 64, device="cuda")
Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
assert Z.device.type == "cuda"
在这里,我们明确地在 CUDA 内存中创建张量,并将它们传递给函数,该函数在 CUDA 设备上执行所有计算。 wrap_numpy
负责在编译器中标记任何 torch.Tensor
输入为具有 np.ndarray
语义的 torch.compile
级别的输入。在编译器中标记张量是一个非常便宜的操作,因此在运行时不会发生数据复制或数据移动。
使用这个装饰器,我们还可以在 NumPy 代码中进行微分!
@torch.compile(fullgraph=True)
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
return np.mean(np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)))
X = torch.randn(1024, 64, device="cuda", requires_grad=True)
Y = torch.randn(1024, 64, device="cuda")
Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
Z.backward()
# X.grad now holds the gradient of the computation
print(X.grad)
我们在此上下文中使用 fullgraph=True
作为图断点存在问题。当发生图断点时,我们需要实例化 NumPy 数组。由于 NumPy 数组没有 device
或 requires_grad
的概念,因此在图断点期间,此信息会丢失。
我们无法通过图断点传播梯度,因为图断点代码可能执行任意代码,而这些代码不知道如何进行微分。另一方面,在 CUDA 执行的情况下,我们可以像第一个例子中那样解决这个问题,通过使用 torch.device("cuda")
上下文管理器:
@torch.compile
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
prod = X[:, :, None] * Y[:, None, :]
print("oops, a graph break!")
return np.sum(prod, axis=(-2, -1))
X = torch.randn(1024, 64, device="cuda")
Y = torch.randn(1024, 64, device="cuda")
with torch.device("cuda"):
Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
assert Z.device.type == "cuda"
在图断点期间,中间张量仍然需要移动到 CPU,但当图断点后的跟踪恢复时,剩余的图仍然在 CUDA 上跟踪。鉴于这种 CUDA <> CPU 和 CPU <> CUDA 的移动,图断点在 NumPy 上下文中相当昂贵,应该避免,但至少它们允许跟踪复杂的代码片段。
我如何在 torch.compile
下调试 NumPy 代码?
调试即时编译代码具有挑战性,因为现代编译器的复杂性以及它们引发的令人畏惧的错误。torch.compile 故障排除文档包含一些技巧和窍门,介绍如何应对这项任务。
如果上述方法不足以定位问题的根源,我们还可以使用一些其他特定的 NumPy 工具。我们可以通过禁用 NumPy 函数的跟踪来区分错误是否完全在 PyTorch 代码中:
from torch._dynamo import config
config.trace_numpy = False
如果错误在于跟踪的 NumPy 代码,我们可以使用 PyTorch 作为后端,通过导入 import torch._numpy as np
来积极执行 NumPy 代码(无需 torch.compile
)。这仅应用于调试目的,绝对不能替代 PyTorch API,因为它性能较差,并且作为私有 API,可能会随时更改。无论如何, torch._numpy
是以 PyTorch 为基础的 NumPy 的 Python 实现,它被 torch.compile
内部使用,以将 NumPy 代码转换为 Pytorch 代码。它易于阅读和修改,所以如果您在其中发现任何错误,请随时提交 PR 修复它或简单地打开一个 issue。
如果程序在导入 torch._numpy as np
时仍然工作,那么很可能是 TorchDynamo 中存在 bug。如果是这种情况,请随时提交一个最小化复现问题的 issue。
我尝试了 NumPy 代码,但没有看到任何加速效果。¶
开始的最佳方式是查看教程,其中包含有关如何调试此类 torch.compile 问题的通用建议。
一些图断开可能是因为使用了不受支持的功能。请参阅 torch.compile 支持哪些 NumPy 功能?更普遍地说,记住一些广泛使用的 NumPy 功能与编译器不兼容是有用的。例如,就地修改在编译器内部推理困难,并且通常比它们的就地对应物性能更差。因此,最好避免使用它们。同样适用于使用 out=
参数。相反,优先使用就地操作,让 torch.compile
优化内存使用。同样适用于依赖于数据的操作,如通过布尔掩码的掩码索引,或依赖于数据的控制流,如 if
或 while
构造。
哪个 API 用于细粒度跟踪?
在某些情况下,您可能需要排除代码中的小部分内容,使其不参与 torch.compile 编译。本节提供了一些答案,您可以在 TorchDynamo API 中找到更多关于细粒度跟踪的信息。
如何在函数上执行图断点?
在函数上执行图断点不足以充分表达您希望 PyTorch 执行的操作。您需要更具体地说明您的用例。以下是一些您可能需要考虑的常见用例:
如果您想禁用此函数帧及其递归调用的帧上的编译,请使用
torch._dynamo.disable
。如果您想使特定操作符,如
fbgemm
,使用急切模式,请使用torch._dynamo.disallow_in_graph
。
一些不常见的用例包括:
如果您想在函数帧上禁用 TorchDynamo 但希望在递归调用的帧上重新启用它,请使用
torch._dynamo.disable(recursive=False)
。如果你想防止函数帧内联,请在要防止内联的函数开头使用
torch._dynamo.graph_break
。
torch._dynamo.disable
和 torch._dynamo.disallow_in_graph
有什么区别?
Disallow-in-graph 在操作符级别上工作,或者更具体地说,是你在 TorchDynamo 提取的图中看到的操作符。
Disable 在函数帧级别上工作,并决定 TorchDynamo 是否应该检查函数帧。
torch._dynamo.disable
和 torch._dynamo_skip
有什么区别?
注意
torch._dynamo_skip
已弃用。
你很可能会需要 torch._dynamo.disable
。但在一个不太可能的情况下,你可能需要更精细的控制。假设你只想禁用 a_fn
函数的跟踪,但希望在 aa_fn
和 ab_fn
中继续跟踪。下面的图片展示了这个用例:

在这种情况下,你可以使用 torch._dynamo.disable(recursive=False)
。在之前的版本中,此功能由 torch._dynamo.skip
提供。现在由 torch._dynamo.disable
中的 recursive
标志支持。