快捷键

常见问题

作者:马克·萨拉菲姆

torch.compile 是否支持训练?¶

torch.compile 支持训练,使用 AOTAutograd 捕获反向传播:

  1. .forward() 图和 optimizer.step() 被 TorchDynamo 的 Python evalframe 前端捕获。

  2. 对于 torchdynamo 捕获的每个段,它使用 AOTAutograd 生成反向图段。

  3. 前向和反向图对(可选)进行最小割划分,以节省前向和反向之间的最小状态。

  4. 前向和反向对被封装在 autograd.function 模块中。

  5. 用户代码调用 .backward() 仍然触发 eager 的自动微分引擎,该引擎将每个编译后的反向图作为单个操作运行,同时运行任何非编译的 eager 操作 .backward() 的函数。

你支持分布式代码吗?¶

torch.compile 支持 DistributedDataParallel (DDP)。正在考虑支持其他分布式训练库。

分布式代码在 dynamo 中具有挑战性的主要原因是 AOTAutograd 展开正向和反向传递,并为后端提供 2 个图进行优化。这对于分布式代码来说是一个问题,因为我们理想情况下希望重叠通信操作与计算。Eager PyTorch 通过不同的方式为 DDP/FSDP 实现这一点——使用自动微分钩子、模块钩子和模块状态的修改/变异。

使用 Dynamo 优化 DDP 的基本策略在 distributed.py 中概述,主要思想是在 DDP 桶边界上进行图拆分。

当 DDP 中的每个节点需要与其他节点同步其权重时,它会将其梯度和参数组织成桶,这减少了通信时间,并允许节点向其他等待节点广播其梯度的一部分。

分布式代码中的图拆分意味着您可以期望 Dynamo 及其后端优化分布式程序的计算开销,但不是通信开销。图拆分可能会干扰编译速度提升,如果减少的图大小剥夺了编译器的融合机会。然而,随着图大小的增加,收益递减,因为大多数当前的计算优化都是局部融合。因此,在实践中,这种方法可能是足够的。

我还需要导出整个图吗?

对于绝大多数模型,你可能不需要,你可以直接使用 torch.compile() ,但在某些情况下,需要完整的图,你可以通过简单地运行 torch.compile(..., fullgraph=True) 来确保完整的图。这些情况包括:

  • 大规模训练运行,如$250K+,需要管道并行和其他高级分片策略。

  • 依赖于比训练优化器更激进融合的推理优化器,如 TensorRT 或 AITemplate。

  • 移动训练或推理。

未来工作将包括将通信操作追踪到图中,协调这些操作与计算优化,并优化通信操作。

为什么我的代码崩溃了? ¶

如果你的代码在没有 torch.compile 的情况下运行良好,而启用后开始崩溃,那么最重要的第一步是找出你的失败发生在堆栈的哪个部分。为了排查这个问题,请按照以下步骤进行,并且只有在前一步成功后才能尝试下一步。

  1. torch.compile(..., backend="eager") 只运行 TorchDynamo 前向图捕获,然后使用 PyTorch 运行捕获到的图。如果这失败了,那么就有可能是 TorchDynamo 的问题。

  2. torch.compile(..., backend="aot_eager") 使用 TorchDynamo 捕获前向图,然后使用 AOTAutograd 进行反向图的跟踪,无需任何额外的后端编译步骤。PyTorch eager 将用于运行前向和反向图。如果失败,则说明 AOTAutograd 存在问题。

  3. torch.compile(..., backend="inductor") 使用 TorchDynamo 捕获前向图,然后使用 TorchInductor 编译器跟踪反向图。如果失败,则说明 TorchInductor 存在问题。

为什么编译速度慢?

  • Dynamo 编译 - TorchDynamo 内置了一个用于收集和显示每个编译阶段花费时间的统计函数。这些统计信息可以通过调用 torch._dynamo.utils.compile_times() 在执行 torch._dynamo 后访问。默认情况下,这将返回每个 TorchDynamo 函数按名称花费的编译时间的字符串表示。

  • 电感器编译 - TorchInductor 内置统计和跟踪功能,用于显示每个编译阶段花费的时间、输出代码、输出图可视化以及 IR 转储。 env TORCH_COMPILE_DEBUG=1 python repro.py 。这是一个调试工具,旨在通过类似以下输出的方式使调试/理解 TorchInductor 的内部结构更加容易。该调试跟踪中的每个文件都可以通过 torch._inductor.config.trace.* 启用/禁用。由于生成成本较高,默认情况下禁用了配置文件和图表。请参阅示例调试目录输出以获取更多示例。

  • 当 TorchDynamo 编译一个函数(或其一部分)时,它会针对局部变量和全局变量做出某些假设,以便允许编译器优化,并将这些假设表达为在运行时检查特定值的守卫。如果这些守卫中的任何一个失败,Dynamo 将重新编译该函数(或其部分)多达 torch._dynamo.config.recompile_limit 次。如果你的程序正在达到缓存限制,你首先需要确定哪个守卫失败以及你的程序中的哪个部分触发了它。重新编译分析器会自动将 TorchDynamo 的缓存限制设置为 1,并在仅记录任何守卫失败原因的观察-only '编译器'下运行你的程序。你应该确保你的程序运行的时间(或迭代次数)至少与你在遇到问题时运行的时间一样长,分析器将在此期间累积统计数据。

为什么在生产环境中你会重新编译?

在某些情况下,您可能不希望在程序预热后出现意外的编译。例如,如果您正在处理一个对延迟敏感的生产流量。为此,TorchDynamo 提供了一个替代模式,其中使用先前编译的图,但不会生成新的图:

frozen_toy_example = dynamo.run(toy_example)
frozen_toy_example(torch.randn(10), torch.randn(10))

你是如何加速我的代码的? ¶

加速 PyTorch 代码有 3 种主要方法:

  1. 通过垂直融合进行内核融合,将顺序操作融合在一起以避免过多的读写。例如,融合两个连续的余弦函数意味着您可以进行 1 次读取和 1 次写入,而不是 2 次读取和 2 次写入。水平融合:最简单的例子是批处理,其中单个矩阵与一批示例相乘,但更一般的情况是分组 GEMM,其中一组矩阵乘法被一起调度。

  2. 指令乱序执行:一种针对编译器的通用优化方法,通过在图中查看确切的数据依赖关系,我们可以决定执行节点的最佳时机以及哪些缓冲区可以重用

  3. 自动工作分配:与指令乱序执行类似,但通过将图中的节点与物理硬件或内存等资源匹配,我们可以设计合适的调度方案

上述内容是加速 PyTorch 代码的一般原则,但不同的后端将在优化内容上做出不同的权衡。例如,Inductor 首先负责融合尽可能多的内容,然后才生成 Triton 内核。

此外,Triton 还通过自动内存归约、内存管理和每个流多处理器内的调度来提供加速,并且已被设计用于处理分块计算。

然而,无论你使用什么后端,最好使用基准和查看方法,尝试使用 PyTorch 分析器,直观检查生成的内核,并尝试亲自看看发生了什么。

为什么我没有看到速度提升?¶

图断点

你不会看到使用 dynamo 所期望的速度提升的主要原因是图断开过多。那么什么是图断开呢?

给定一个程序,例如:

def some_fun(x):
    ...

torch.compile(some_fun)(x)
...

Torchdynamo 将尝试将 some_fun() 中的所有 torch/tensor 操作编译成一个单一的 FX 图,但它可能无法将所有内容都捕获到一个图中。

一些图断裂的原因对 TorchDynamo 来说是不可逾越的,例如调用 PyTorch 之外的 C 扩展,这对 TorchDynamo 来说是不可见的,并且可能执行任意操作,而 TorchDynamo 无法引入必要的保护措施来确保编译后的程序可以安全重用。

为了最大化性能,尽可能减少图断裂是很重要的。

识别图断裂的原因

要识别程序中的所有图断点及其原因,可以使用 torch._dynamo.explain 。该工具在提供的函数上运行 TorchDynamo 并汇总遇到的图断点。以下是一个示例用法:

import torch
import torch._dynamo as dynamo
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    print("woo")
    if b.sum() < 0:
        b = b * -1
    return x * b
explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
print(explanation)
"""
Graph Count: 3
Graph Break Count: 2
Op Count: 5
Break Reasons:
  Break Reason 1:
    Reason: builtin: print [<class 'torch._dynamo.variables.constant.ConstantVariable'>] False
    User Stack:
      <FrameSummary file foo.py, line 5 in toy_example>
  Break Reason 2:
    Reason: generic_jump TensorVariable()
    User Stack:
      <FrameSummary file foo.py, line 6 in torch_dynamo_resume_in_toy_example_at_5>
Ops per Graph:
  ...
Out Guards:
  ...
"""

要在遇到第一个图断点时抛出错误,可以通过使用 fullgraph=True 禁用 Python 回退,如果您使用过基于导出的编译器,这应该很熟悉。

def toy_example(a, b):
   ...

torch.compile(toy_example, fullgraph=True, backend=<compiler>)(a, b)

为什么我修改了代码后代码没有重新编译?

如果您通过设置 env TORCHDYNAMO_DYNAMIC_SHAPES=1 python model.py 启用了动态形状,则在形状变化时您的代码不会重新编译。我们已经添加了对动态形状的支持,这避免了形状变化小于 2 倍因子时重新编译的情况。这在 CV 中图像大小变化或 NLP 中序列长度变化等场景中特别有用。在推理场景中,通常无法事先知道批大小,因为您需要从不同的客户端应用程序中获取尽可能多的信息。

通常情况下,TorchDynamo 会非常努力地避免不必要的重新编译,例如,如果 TorchDynamo 找到了 3 个图,而您的更改只修改了一个图,那么只有那个图会重新编译。因此,另一个避免潜在缓慢的编译时间的技巧是先编译模型一次进行预热,之后的编译将会快得多。冷启动编译时间仍然是我们跟踪的可见指标。

为什么我会得到错误的结果?

如果您设置环境变量 TORCHDYNAMO_REPRO_LEVEL=4 ,可以减少精度问题,它的工作方式与 git bisect 模型类似,完整的重现可能如下 TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4 我们需要这样做的原因是下游编译器会生成代码,无论是 Triton 代码还是 C++后端,这些下游编译器的数值可能在细微之处有所不同,但会对您的训练稳定性产生重大影响。因此,精度调试器对我们检测代码生成或后端编译器中的错误非常有用。

如果您想确保 torch 和 triton 之间的随机数生成相同,则可以启用 torch._inductor.config.fallback_random = True

为什么我会遇到内存溢出(OOM)?

Dynamo 仍然是一个 alpha 产品,所以存在一些导致 OOM 的原因。如果您遇到了 OOM,请尝试以下顺序禁用以下配置,并在 GitHub 上创建一个 issue,以便我们解决根本问题:1. 如果您使用动态形状,请尝试禁用它们,我们已默认禁用: env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py 2. 在 inductor 中,CUDA graphs 与 Triton 默认启用,移除它们可能有助于缓解一些 OOM 问题: torch._inductor.config.triton.cudagraphs = False

torch.func 是否与 torch.compile (用于 grad 和 vmap 转换)兼容?

torch.func 转换应用于使用 torch.compile 的函数是可行的:

import torch

@torch.compile
def f(x):
    return torch.sin(x)

def g(x):
    return torch.grad(f)(x)

x = torch.randn(2, 3)
g(x)

在函数内部处理 torch.func 调用 torch.compile

使用 torch.compile 编译 torch.func.grad

import torch

def wrapper_fn(x):
    return torch.func.grad(lambda x: x.sin().sum())(x)

x = torch.randn(3, 3, 3)
grad_x = torch.compile(wrapper_fn)(x)

使用 torch.compile 编译 torch.vmap

import torch

def my_fn(x):
    return torch.vmap(lambda x: x.sum(1))(x)

x = torch.randn(3, 3, 3)
output = torch.compile(my_fn)(x)

编译除支持的函数之外的其他函数(逃生口) ¶

对于其他转换,作为解决方案,请使用 torch._dynamo.allow_in_graph

allow_in_graph 是一个逃生门。如果你的代码与 torch.compile 不兼容,其中 torch.compile 可以检查 Python 字节码,但你认为通过符号跟踪方法(如 jax.jit )可以工作,那么请使用 allow_in_graph

使用 allow_in_graph 注解函数时,你必须确保你的代码满足以下要求:

  • 函数中的所有输出仅依赖于输入,不依赖于任何捕获的张量。

  • 您的功能是功能性的。也就是说,它不会改变任何状态。这可能被放宽;实际上,我们支持从外部看起来是功能性的函数:它们可能包含原地 PyTorch 操作,但可能不会改变全局状态或函数的输入。

  • 您的函数不会引发数据相关的错误。

import torch

@torch.compile
def f(x):
    return torch._dynamo.allow_in_graph(torch.vmap(torch.sum))(x)

x = torch.randn(2, 3)
f(x)

一个常见的陷阱是使用 allow_in_graph 来注释一个调用 nn.Module 的函数。这是因为输出现在依赖于 nn.Module 的参数。为了使其正常工作,请使用 torch.func.functional_call 来提取模块状态。

NumPy 是否与 torch.compile 兼容? ¶

从 2.1 版本开始, torch.compile 理解在 NumPy 数组上运行的本地 NumPy 程序,以及将 PyTorch 转换为 NumPy 并反向转换的混合 PyTorch-NumPy 程序,通过 x.numpy()torch.from_numpy 和相关函数实现。

torch.compile 支持哪些 NumPy 功能? ¶

torch.compile 中的 NumPy 遵循 NumPy 2.0 预发布版。

通常, torch.compile 能够追踪大多数 NumPy 结构,当它无法追踪时,会回退到急切模式,并让 NumPy 执行那部分代码。即使如此,也有一些功能, torch.compile 的语义与 NumPy 略有不同:

  • NumPy 标量:我们将它们建模为 0 维数组。也就是说, np.float32(3)torch.compile 下返回一个 0 维数组。为了避免图断裂,最好使用这个 0 维数组。如果这破坏了你的代码,你可以通过将 NumPy 标量强制转换为相关的 Python 标量类型 bool/int/float 来解决这个问题。

  • 负步长: np.flip 和使用负步长的切片会返回一个副本。

  • 类型提升:NumPy 的类型提升将在 NumPy 2.0 中发生变化。新规则在 NEP 50 中描述。 torch.compile 实现 NEP 50 而不是即将被弃用的当前规则。

  • {tril,triu}_indices_from/{tril,triu}_indices 返回数组而不是数组的元组。

对于我们不支持追踪的其他功能,我们将优雅地回退到 NumPy 执行:

  • 非数值数据类型,如日期时间、字符串、字符、void、结构化数据类型和记录数组。

  • 长数据类型 np.float128/np.complex256 以及一些无符号数据类型 np.uint16/np.uint32/np.uint64

  • ndarray 子类。

  • 隐藏数组。

  • axes=[(n,k),(k,m)->(n,m)] 和 ufunc 方法(例如, np.add.reduce )之类的神秘 ufunc 机制。

  • 排序/排序 complex64/complex128 数组。

  • NumPy np.poly1dnp.polynomial

  • 函数中带有 2 个或更多返回值的参数位置( out=tuple 也可以工作)。

  • __array_function____array_interface____array_wrap__

  • ndarray.ctypes 属性。

我可以使用 torch.compile 编译 NumPy 代码吗?¶

当然可以! torch.compile 可以原生理解 NumPy 代码,并将其视为 PyTorch 代码。为此,只需用 torch.compile 装饰器将 NumPy 代码包裹起来即可。

import torch
import numpy as np

@torch.compile
def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

X = np.random.randn(1024, 64)
Y = np.random.randn(1024, 64)
Z = numpy_fn(X, Y)
assert isinstance(Z, np.ndarray)

在设置环境变量 TORCH_LOGS=output_code 后执行此示例,我们可以看到 torch.compile 能够将乘法和求和合并为一个 C++内核。它还能够使用 OpenMP 并行执行它们(原生 NumPy 是单线程的)。这可以使您的 NumPy 代码速度提高 n 倍,其中 n 是您处理器的核心数!

以这种方式跟踪 NumPy 代码也支持编译代码中的图断点。

我可以使用 torch.compile 在 CUDA 上执行 NumPy 代码并通过它计算梯度吗?¶

是的,你可以!为此,你只需在 torch.device("cuda") 上下文中执行你的代码即可。考虑以下示例

import torch
import numpy as np

@torch.compile
def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

X = np.random.randn(1024, 64)
Y = np.random.randn(1024, 64)
with torch.device("cuda"):
    Z = numpy_fn(X, Y)
assert isinstance(Z, np.ndarray)

在这个示例中, numpy_fn 将在 CUDA 中执行。为了实现这一点, torch.compile 会自动将 XY 从 CPU 移动到 CUDA,然后它将结果 Z 从 CUDA 移动到 CPU。如果我们想在同一个程序运行中多次执行此函数,我们可能希望避免这些相对昂贵的内存复制。为此,我们只需调整我们的 numpy_fn ,使其接受 cuda 张量并返回张量。我们可以通过使用 torch.compiler.wrap_numpy 来实现这一点:

@torch.compile(fullgraph=True)
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
    return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

X = torch.randn(1024, 64, device="cuda")
Y = torch.randn(1024, 64, device="cuda")
Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
assert Z.device.type == "cuda"

在这里,我们明确地在 CUDA 内存中创建张量,并将它们传递给函数,该函数在 CUDA 设备上执行所有计算。 wrap_numpy 负责在编译器中标记任何 torch.Tensor 输入为具有 np.ndarray 语义的 torch.compile 级别的输入。在编译器中标记张量是一个非常便宜的操作,因此在运行时不会发生数据复制或数据移动。

使用这个装饰器,我们还可以在 NumPy 代码中进行微分!

@torch.compile(fullgraph=True)
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
    return np.mean(np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)))

X = torch.randn(1024, 64, device="cuda", requires_grad=True)
Y = torch.randn(1024, 64, device="cuda")
Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
Z.backward()
# X.grad now holds the gradient of the computation
print(X.grad)

我们在此上下文中使用 fullgraph=True 作为图断点存在问题。当发生图断点时,我们需要实例化 NumPy 数组。由于 NumPy 数组没有 devicerequires_grad 的概念,因此在图断点期间,此信息会丢失。

我们无法通过图断点传播梯度,因为图断点代码可能执行任意代码,而这些代码不知道如何进行微分。另一方面,在 CUDA 执行的情况下,我们可以像第一个例子中那样解决这个问题,通过使用 torch.device("cuda") 上下文管理器:

@torch.compile
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
    prod = X[:, :, None] * Y[:, None, :]
    print("oops, a graph break!")
    return np.sum(prod, axis=(-2, -1))

X = torch.randn(1024, 64, device="cuda")
Y = torch.randn(1024, 64, device="cuda")

with torch.device("cuda"):
    Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
assert Z.device.type == "cuda"

在图断点期间,中间张量仍然需要移动到 CPU,但当图断点后的跟踪恢复时,剩余的图仍然在 CUDA 上跟踪。鉴于这种 CUDA <> CPU 和 CPU <> CUDA 的移动,图断点在 NumPy 上下文中相当昂贵,应该避免,但至少它们允许跟踪复杂的代码片段。

我如何在 torch.compile 下调试 NumPy 代码?

调试即时编译代码具有挑战性,因为现代编译器的复杂性以及它们引发的令人畏惧的错误。torch.compile 故障排除文档包含一些技巧和窍门,介绍如何应对这项任务。

如果上述方法不足以定位问题的根源,我们还可以使用一些其他特定的 NumPy 工具。我们可以通过禁用 NumPy 函数的跟踪来区分错误是否完全在 PyTorch 代码中:

from torch._dynamo import config
config.trace_numpy = False

如果错误在于跟踪的 NumPy 代码,我们可以使用 PyTorch 作为后端,通过导入 import torch._numpy as np 来积极执行 NumPy 代码(无需 torch.compile )。这仅应用于调试目的,绝对不能替代 PyTorch API,因为它性能较差,并且作为私有 API,可能会随时更改。无论如何, torch._numpy 是以 PyTorch 为基础的 NumPy 的 Python 实现,它被 torch.compile 内部使用,以将 NumPy 代码转换为 Pytorch 代码。它易于阅读和修改,所以如果您在其中发现任何错误,请随时提交 PR 修复它或简单地打开一个 issue。

如果程序在导入 torch._numpy as np 时仍然工作,那么很可能是 TorchDynamo 中存在 bug。如果是这种情况,请随时提交一个最小化复现问题的 issue。

我尝试了 NumPy 代码,但没有看到任何加速效果。¶

开始的最佳方式是查看教程,其中包含有关如何调试此类 torch.compile 问题的通用建议。

一些图断开可能是因为使用了不受支持的功能。请参阅 torch.compile 支持哪些 NumPy 功能?更普遍地说,记住一些广泛使用的 NumPy 功能与编译器不兼容是有用的。例如,就地修改在编译器内部推理困难,并且通常比它们的就地对应物性能更差。因此,最好避免使用它们。同样适用于使用 out= 参数。相反,优先使用就地操作,让 torch.compile 优化内存使用。同样适用于依赖于数据的操作,如通过布尔掩码的掩码索引,或依赖于数据的控制流,如 ifwhile 构造。

哪个 API 用于细粒度跟踪?

在某些情况下,您可能需要排除代码中的小部分内容,使其不参与 torch.compile 编译。本节提供了一些答案,您可以在 TorchDynamo API 中找到更多关于细粒度跟踪的信息。

如何在函数上执行图断点?

在函数上执行图断点不足以充分表达您希望 PyTorch 执行的操作。您需要更具体地说明您的用例。以下是一些您可能需要考虑的常见用例:

  • 如果您想禁用此函数帧及其递归调用的帧上的编译,请使用 torch._dynamo.disable

  • 如果您想使特定操作符,如 fbgemm ,使用急切模式,请使用 torch._dynamo.disallow_in_graph

一些不常见的用例包括:

  • 如果您想在函数帧上禁用 TorchDynamo 但希望在递归调用的帧上重新启用它,请使用 torch._dynamo.disable(recursive=False)

  • 如果你想防止函数帧内联,请在要防止内联的函数开头使用 torch._dynamo.graph_break

torch._dynamo.disabletorch._dynamo.disallow_in_graph 有什么区别?

Disallow-in-graph 在操作符级别上工作,或者更具体地说,是你在 TorchDynamo 提取的图中看到的操作符。

Disable 在函数帧级别上工作,并决定 TorchDynamo 是否应该检查函数帧。

torch._dynamo.disabletorch._dynamo_skip 有什么区别?

注意

torch._dynamo_skip 已弃用。

你很可能会需要 torch._dynamo.disable 。但在一个不太可能的情况下,你可能需要更精细的控制。假设你只想禁用 a_fn 函数的跟踪,但希望在 aa_fnab_fn 中继续跟踪。下面的图片展示了这个用例:

diagram of torch.compile + disable(a_fn, recursive=False)

在这种情况下,你可以使用 torch._dynamo.disable(recursive=False) 。在之前的版本中,此功能由 torch._dynamo.skip 提供。现在由 torch._dynamo.disable 中的 recursive 标志支持。


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

查看 PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源