Dynamo 深度解析 ¶
TorchDynamo(或简称 Dynamo)是 torch.compile
中的追踪器,它往往是那些疯狂回溯的罪魁祸首。然而,我们不能盲目地将这些错误归咎于 Dynamo。为了提供用户所需的灵活性,Dynamo 被赋予了理解任何 Python 程序的艰巨任务。特别是,Dynamo 必须在内部实现 Python 编程语言的大部分功能!
在本文中,我们将从头开始介绍 Dynamo 的内部设计。我们将讨论它提供的功能以及它的实现方式。到本文结束时,您将更好地理解当您 torch.compiled
PyTorch 程序时出了什么问题,以及编译出错,或者成功但加速效果不符合预期的情况。
动态规划的温和介绍
在深入探讨所有实现细节之前,让我们先讨论一下 Dynamo 的功能是什么。
Dynamo 是一个跟踪器。这意味着,给定一个函数和输入,它将执行该函数并将一系列线性指令(没有控制流)记录到图中。例如,考虑以下程序:
import torch
@torch.compile
def mse(x, y):
z = (x - y) ** 2
return z.sum()
x = torch.randn(200)
y = torch.randn(200)
mse(x, y)
如果我们将这个程序保存到文件 example.py
中并运行
TORCH_LOGS=graph_code python example.py
我们看到了 Dynamo 追踪的结果
def forward(l_x_: torch.Tensor, l_y_: torch.Tensor):
# File: example.py:5, code: z = (x - y) ** 2
sub = l_x_ - l_y_
z = sub ** 2
# File: example.py:6, code: return z.sum()
sum_1 = z.sum()
return (sum_1,)
我们称这为给定输入的函数的图(或追踪)。这通过 FX 图来表示。我们可以简单地将 FX 图视为存储函数调用列表的容器。
我们首先应该注意到,该图是 PyTorch 操作的线性序列。[1] Dynamo 记录所有 PyTorch 操作并按顺序存储。例如,它将 z = (x - y) ** 2
分解为其两个构成的操作, sub = l_x_ - l_y_
和 z = sub ** 2
。
当我们说追踪是线性的,我们的意思是没有任何分支或任何控制流。为了看到这一点,请考虑
import torch
@torch.compile
def fn(x, n):
y = x ** 2
if n >= 0:
return (n + 1) * y
else:
return y / n
x = torch.randn(200)
fn(x, 2)
当与 TORCH_LOGS=graph_code
一起执行时,返回
def forward(l_x_: torch.Tensor):
# File: example.py:5, code: y = x ** 2
y = l_x_ ** 2
# File: example.py:7, code: return (n + 1) * y
mul = 3 * y
return (mul,)
我们可以看到,Dynamo 完全从跟踪中移除了 if
语句,只记录了使用输入执行的操作。
因此,应该很清楚,函数的跟踪依赖于输入。特别是,这意味着当我们在 fn(x, 2)
函数中写入 @torch.compile
时不会生成跟踪,而是在使用实际参数执行函数时生成跟踪。
另一个值得注意的有趣之处在于,Dynamo 移除了函数的第二个参数。相反,它将其视为一个常数,并在图中记录了操作 n + 1
的结果。这是 Dynamo 的另一个特性:Dynamo 会将任何非张量值视为常数……除了整数。现在让我们看看整数是如何特殊的。
动力机的最后一个定义属性是它知道如何处理动态形状。符号形状指的是 Dynamo 追踪形状的能力,更一般地,是整数,而不是将它们作为常数。这允许避免重新编译和部署适用于任何大小的通用模型。动态形状出现的主要例子包括批量大小,我们可能使用固定的批量大小来训练模型,但随后对任意批量大小进行推理,或者处理文本或音频时遇到的变量序列长度。
我们可以通过多次执行上面的例子来看到这一点
import torch
@torch.compile
def fn(x, n):
y = x ** 2
if n >= 0:
return (n + 1) * y
else:
return y / n
x = torch.randn(200)
fn(x, 2)
fn(x, 3)
fn(x, -2)
在这种情况下, TORCH_LOGS=graph_code
生成了两个额外的图
# Graph for n==2 omitted
def forward(self, l_x_: torch.Tensor, l_n_: torch.SymInt):
# File: a.py:5, code: y = x ** 2
y = l_x_ ** 2
# File: a.py:7, code: return (n + 1) * y
add = l_n_ + 1
mul = add * y
return (mul,)
def forward(self, l_x_: torch.Tensor, l_n_: torch.SymInt):
# File: a.py:5, code: y = x ** 2
y = l_x_ ** 2
# File: a.py:9, code: return y / n
truediv = y / l_n_
return (truediv,)
动力机检测到某个整数在第一次调用后改变了其值,并开始追踪它。我们看到这些图是通用的,并通过类型为 SymInt
的对象符号地追踪变量 n
。
如果在这些调用之后我们调用 fn(x, 4)
,Dynamo 将不会重新编译,而是重用已经追踪的图。
总结如下:1. Dynamo 是一个 Python 追踪器 2. 给定一些输入,它返回一个包含 PyTorch 函数的 FX 图 3. 如果它检测到整数在调用之间发生了变化,它还可以追踪整数 4. 它专门处理任何不是张量或标量的其他值
当然,Dynamo 还能做更多的事情,比如确定何时需要重新追踪,重写函数的字节码,实现图断点……为了使介绍简短,我们将在后续内容中逐步讨论所有这些。
PEP 523:向 CPython 添加帧评估 API
现在想象一下,如果我们被赋予实现 Dynamo 的任务,我们应该从哪里开始呢?对我们来说,非常方便的是,PEP 523 与 Python 3.6 一同发布。这个 PEP 旨在允许第三方为 Python 创建 JIT 编译器。让我们看看如何实现。
关于 CPython 的说明:CPython 在内部实现为一个栈式机器。Python 程序被编译成字节码,然后由这个解释器执行。要了解更多关于这些字节码的信息,请参阅标准库中的 dis 模块。还可以查看 CPython 解释器的开发者文档以获取介绍。我们假设读者熟悉栈式机器的概念。
PEP 523 提供了一个 API,用户可以添加一个自定义的函数解释器。然后,CPython 将使用这个解释器而不是它自己的来执行函数。为了能够执行函数,在进入时,CPython 为自定义解释器提供了诸如- 函数的字节码- 函数的参数值(即局部变量)及其名称- 全局变量的值及其名称- 内置函数如 abs
或 print
等信息。
您可以在这里看到所有字段。[2]
总结来说,CPython 为用户的解释器提供了执行函数所需的所有信息。[3]
通过这个 API,我们可以通过实现一个运行代码并记录执行过程中发生的所有 PyTorch 操作的解释器来实现一个追踪器。这正是 Dynamo 所做的事情。
Dynamo 使用这个 CPython API 解析所有这些对象,并将它们打包成 Python 结构。在完成这些操作后……它从 C 语言返回到 Python。除了与 CPython 通信的这一段代码外,Dynamo 完全是用 Python 实现的。
应该很清楚,装饰器 @torch.compile
的任务是安装必要的脚手架,以便在函数被调用时将字节码、参数、全局变量等传递给 Dynamo。再次强调, @torch.compile
并没有实际编译任何内容。
在 Python 中实现 CPython ¶
那么,我们又回到了 Python 的世界。我们有了函数的字节码,以及执行它所需的所有上下文。特别是,我们落在了_convert_frame_assert 函数上。这是装饰器 torch.compile
返回的函数!我们从_dynamo.optimize 进入这个函数。装饰器 torch.compile
只是围绕 _dynamo.optimize
的一个很好的 API。
在开始实现 Python 解释器之前,我们想要定义一个 IR。特别是,我们想要将所有局部和全局变量包裹在我们的内部类中。这使我们能够更好地跟踪这些对象,并将可以以相同方式处理的对象分组在一起,以 Dynamo 的视角来看。
内部类结构的父类是 VariableTracker
,它代表了 Dynamo 理解的不同对象。例如, ListVariable
代表一个 list
对象,并在内部维护一个 VariableTrackers 列表。另一个例子是 ConstantVariable,它包裹了 Dynamo 认为的所有常量对象。我们还有针对需要特殊关注的对象的特殊子类,如 TensorVariable。所有这些内部类都定义在 torch/_dynamo/variables 文件夹中。
Python 对象在 VariableBuilder._wrap 中被包装成对应的 VariableTracker
类。这个函数实际上是一个非常长的 elif
链,它试图递归地将 Python 输入匹配到适当类型的 VariableTracker
。
调试技巧。当我们从 dynamo 得到意外结果时,有时是由构建器引起的。如果构建器的逻辑错误,有时 Dynamo 可能会将变量错误地包装在错误的 VariableTracker
类型中,这可能会在以后引起问题。在遇到 Dynamo 错误时,查看出现在错误中的 VariableTracker
类型以及抛出异常的 VariableTracker
方法非常有用。特别是,有时我们发现一个对象被跟踪为 UserDefinedObjectVariable
(这是 Dynamo 的通用类),而它应该被跟踪为更具体的东西。在这些情况下,通常要归咎于 SourceBuilder.__call__
逻辑。
调试技巧。当使用 TORCH_LOGS=dynamo
运行程序时,打印出来的一个工件是类似以下形式的行
TRACE LOAD_GLOBAL y [TorchInGraphFunctionVariable(<built-in method any>), TensorVariable()]
这是原始程序的字节码以及该点的堆栈状态。这非常有用,可以帮助找到对象没有被正确追踪到 VariableTracker
的位置。
好的,我们已经有了一个用于跟踪器的 IR,现在我们只需要重新实现 CPython 的栈式机器。这通过 symbolic_convert.py 中的 InstructorTranslatorBase 实现。
InstructionTranslatorBase
有大约 200 个方法,实现了几乎所有的 Python 字节码。例如,我们可以看到 BUILD_LIST
的实现。
def BUILD_LIST(self, inst):
items = self.popn(inst.argval)
self.push(ListVariable(items, mutation_type=ValueMutationNew()))
这是由 l = [2, 3, 4]
等构造生成的字节码。在这种情况下,由于有三个元素,生成的字节码是 BUILD_LIST 3
。这意味着我们从栈顶弹出 3
个元素,并将由这三个元素组成的新列表对象推送到栈顶。
生成输出图
通过一种符号执行 Python 代码的方式,我们准备提取在给定输入的程序符号执行过程中发生的 PyTorch 操作。这通过 Dynamo 中的 OutputGraph 对象实现。 OutputGraph
对象绑定到`InstructionTranslator`对象,并跟踪创建 FX 图所需的所有数据,该图将由 Dynamo 返回。
FX 图的全部输入和中间元素都是 fx.Node
。在 Dynamo 中, fx.Node
被包装在 fx.Proxy
中。 fx.Proxy
用于构建 FX 图。特别是,它们将它们上执行的每个 PyTorch 操作记录到图中。您可以通过调用 create_proxy 创建要添加到图中的新操作。然后,我们可以通过 wrap_fx_proxy 函数将其添加到图中。
图存储张量上的操作…以及符号整数的操作。我们将在后面讨论符号整数,但首先我们将讨论 Dynamo 如何解决一个相当重要的正确性问题。
让 Dynamo 发声:守卫(Guards)
到目前为止,我们已经找到了一种完全不考虑控制流来追踪程序的方法。为此,我们重新实现了整个 CPython…如果这听起来有点过度,那是因为确实如此。torch.jit.trace 已经实现了这一点,而不需要所有这些机制,那么问题出在哪里呢?
如其文档中警告的那样, torch.jit.trace
的问题在于它仅在追踪的程序不依赖于数据时才有效。换句话说,如果程序本身是线性的,它才会正常工作。这意味着我们需要编写不使用 if-elses、for-while 循环、异常的程序。更进一步,我们使用的所有库都不能使用任何控制流!总的来说,在像 Python 这样动态的语言中不使用控制流实际上是一个巨大的限制。
JAX 通过在重新追踪后始终重新追踪和缓存图来解决此问题。另一方面,Dynamo 使用守卫来避免每次都重新追踪整个程序。
守卫是为了针对一组示例输入专门化一个帧而做出的一个假设(一个关于输入的布尔表达式)。只有在这些假设在新输入上成立的情况下,重用图才是有效的。
例如,任何作为函数输入的常量,如字符串,都会安装一个守卫,声明该输入应为类型 str
且等于我们传递的字符串。运行
import torch
@torch.compile
def fn(a, b):
return a * len(b)
fn(torch.arange(10), "Hello")
TORCH_LOGS=guards
将打印(以及其他守卫)
___check_type_id(L['b'], 94334122025024)
L['b'] == 'Hello'
这表示“局部变量 b
应具有特定的类型(在本例中为 str
,由常量 9433...
表示)且其值应为 'Hello'
”。如果我们再次执行函数并传递不同的参数
import torch
@torch.compile
def fn(a, b):
return a * len(b)
fn(torch.arange(10), "Hello")
fn(torch.arange(10), "Hi")
我们可以通过运行 TORCH_LOGS=recompiles
来看到失败的守卫
Recompiling function fn in script.py:3
triggered by the following guard failure(s):
- L['b'] == 'Hello'
守卫在函数输入被包装在构建器中以及程序执行期间累积。我们将在下一节展示更多守卫的例子,但首先让我们讨论来源。
来源追踪如何从进入当前帧时存在的原始局部或全局变量中重建一个变量。特别是,它追踪原始的局部和全局对象以及它们包含的任何对象。在
def foo(x: Tensor, y: List[Tensor]):
a = x * y[0]
return a * x
x
和 y
的来源是 LocalSource,而 y[0]
有 GetItemSource,它存储了一个 LocalSource
。另一方面, a
将不会有来源,因为它是一个仅存在于 fx 图中的中间变量。
所有这些都在 torch/_dynamo/source.py 中定义。以下示例中我们可以看到由 GetItemSource
生成的守卫:
import torch
@torch.compile
def fn(x, l):
return x * len(l[0])
fn(torch.randn(8), ["Hi", "Hello"])
生成以下守卫
___check_type_id(L['l'], 94439025877664)
len(L['l']) == 2
___check_type_id(L['l'][0], 94439025840192)
L['l'][0] == 'Hi'
___check_type_id(L['l'][1], 94439025840192)
L['l'][1] == 'Hello'
这里,我们看到由 GetItemSource
( [0]
和 [1]
)包裹的 LocalSource
( L['l']
)生成的代码。
到这一点,有了源和守卫,我们就能实现一个缓存系统,以避免重新编译,而无需每次都回溯。我们将在后续内容中更详细地讨论这个缓存系统。
仔细的读者可能会注意到,这还没有解释为什么我们需要对 Python 解释器有如此精细的控制,以至于不得不重新实现它。我们展示的守卫示例依赖于输入对象,因此我们仍然可以在执行函数之前计算这些值。换句话说,我们可以在 torch.jit.trace
之上实现这个守卫系统,以更少的努力获得相同的功能……进入符号形状。
符号形状
在引言中我们讨论的另一个观点是,Dynamo 知道如何追踪整数。为了实现这一点,我们使用一个符号类 torch.SymInt,它类似于 int
,但它记录了在输出 FX 图上对其执行的所有操作。[4]我们在介绍符号整数追踪时已经看到了这个类。
让我们讨论一下定义 Dynamo 中符号形状追踪的三个属性,以及如何实现它们。
默认静态
Dynamo 默认认为每个整数,无论是输入还是张量的形状,都是静态的。换句话说,在函数第一次执行时,不会追踪任何整数。然后,只有当它检测到整数或形状在执行过程中发生变化时,才会追踪它并生成一个关于该变量的通用图。
我们已经在介绍中看到了这种行为,使用了整数。现在让我们通过张量形状的例子来看一下。
import torch
@torch.compile
def fn(a, b):
return a.shape[0] * a * b
fn(torch.randn(4, 3), torch.randn(4, 3))
fn(torch.randn(8, 3), torch.randn(8, 3))
使用 TORCH_LOGS=graph_code
运行此程序,我们看到这两个调用都被追踪了
def forward(self, l_a_: torch.Tensor, l_b_: torch.Tensor):
mul = 4 * l_a_
mul_1 = mul * l_b_
return (mul_1,)
def forward(self, s0: torch.SymInt, l_a_: torch.Tensor, l_b_: torch.Tensor):
size = l_a_.size()
getitem = size[0]
mul = getitem * l_a_
mul_1 = mul * l_b_
return (mul_1,)
在第一个图中,形状被追踪为一个常量,但一旦它发生变化,它就会使用 SymInt
符号来象征性地追踪它。一般来说,通过运行程序使用 TORCH_LOGS=graph_sizes
可以更简单地看到中间值的形状。
TRACED GRAPH TENSOR SIZES
===== __compiled_fn_1 =====
l_a_: (s0, 3)
l_a_ (concrete): (8, 3)
l_b_: (s0, 3)
l_b_ (concrete): (8, 3)
mul: (s0, 3)
mul (concrete): (8, 3)
mul_1: (s0, 3)
mul_1 (concrete): (8, 3)
我们可以看到,由于它由 s0
变量表示,两个张量参数的第一个维度是动态的。
通过运行 TORCH_LOGS=guards
,我们可以找到 Dynamo 是如何实现这一点的。
# Guards first call
check_tensor(L['a'], torch.float32, device=None, requires_grad=False, size=[4, 3], stride=[3, 1])
check_tensor(L['b'], torch.float32, device=None, requires_grad=False, size=[4, 3], stride=[3, 1])
# Guards second call
check_tensor(L['a'], torch.float32, device=None, requires_grad=False, size=[None, 3], stride=[3, 1])
check_tensor(L['b'], torch.float32, device=None, requires_grad=False, size=[None, 3], stride=[3, 1])
L['b'].size()[0] == L['a'].size()[0]
2 <= L['a'].size()[0]
我们看到,在第一次调用时,守卫检查张量是否有某些固定的尺寸和步长。这些守卫在第二次执行中失败,因此它重新追踪。由于是 int
守卫失败的,所以在第二次迭代中,它符号性地追踪这个 int
符号,并在更通用的内核上安装更通用的守卫。
编译性能技巧。如果你知道一个维度的大小会变化,你可以在调用 torch.compile
之前使用 torch._dynamo.mark_dynamic 将其标记为动态。这将避免具有静态形状的第一个编译。还有其他有用的实用函数,如 maybe_mark_dynamic
或 mark_static
。你还可以通过调用 torch.compile(dynamic=True)
使所有整数和形状被追踪。这主要用于调试目的。
0 和 1 总是具有专用性
无论我们是否将一个维度标记为动态,如果我们传递一个包含该维度为 0 或 1 的输入,Dynamo 都会将其追踪为非动态,并为它生成一个特定的图。这就是为什么在上面的例子中我们会找到形式为 2 <= L['a'].size()[0]
的守卫。
这个选择有几个原因。其中两个尤为重要——一个张量如果其任何维度为零,则它是空的——一个张量只有当其中一个步长为 1 时才能是连续的
这个策略决策不适用于普通的 Python 整型;如果我们认为 Python 整型应该被动态编译,我们不会默认为其专用化;相反,它是否被专用化取决于其使用情况。
鸭子形状 ¶
Dynamo 执行我们所说的“鸭子形状”。如果两个动态整数在跟踪时间具有相同的值,我们将假设它们是相等的,并对其加以保护。实际上,这意味着我们不再像上面示例中那样有两个符号 s0
, s1
,而是将它们统一为 s0
,并有了保护 L['b'].size()[0] == L['a'].size()[0]
。这使我们在编译器内部执行融合的同时,能够生成足够通用的内核。
符号整数的保护 ¶
现在,我们已经了解了符号形状在高级别上的实现及其属性。那么,为什么符号形状要我们通过控制 CPython 解释器的复杂途径呢?考虑以下示例:
import torch
@torch.compile(dynamic=True)
def fn(a):
if a.shape[0] * 2 < 16:
return a
else:
return a + 1
fn(torch.randn(8))
这段代码有一个形式的守卫 2*L['a'].size()[0] >= 16
。从函数的输入角度来看,这是一个非平凡的守卫,但它被注册在程序执行的中间。更重要的是,我们只有在看到基于 SymNodeVariable
参数的 if
语句的条件之后,才能知道这个守卫是必需的。这样的条件对 torch.jit.trace
来说是不可见的,需要深入分析 Python 代码。
调试技巧 使用 TORCH_LOGS=dynamo
运行此代码可以告诉我们守卫添加的位置
eval 2*s0 >= 16 [guard added] at script.py:5 in fn (_dynamo/variables/tensor.py:812 in evaluate_expr)
在那里设置断点并查看回溯对于理解守卫的来源非常有用。
使 Dynamo 完美:图断点 ¶
我们讨论的所有工具中,都有一个可以追踪张量和整数上 PyTorch 操作的追踪器,并且它有一个缓存系统,知道何时可以重用之前追踪的图,何时需要重新追踪。所有这些都可以执行任意的 Python 代码!
这里有一个小问题。说“执行任意的 Python 代码”可能有点过于笼统。Dynamo 实现了 Python 的大部分功能,但它是否实现了更复杂的部分,比如协程或异步?它是否实现了整个 Python 标准库?NumPy 也有 Python API。 torch.compile
是否也理解 NumPy 和 Django?[5]
Python 的生态系统非常庞大,其中很大一部分是用 C++或 Rust 等其他更高效的语言编写的,并且它们只是暴露了 Python 绑定。在 Dynamo 中追踪由 C++实现的 Python 对象是没有希望的。当追踪器遇到它不理解的操作时,它能做什么呢?
机器学习追踪器通常处理这个问题的方法是通知用户他们卡住的操作并完全放弃追踪。这在 PyTorch 的情况下会引发真正的可用性问题,因为它的用户已经习惯了它给予他们的灵活性。作为一个现实世界的例子, doctr_det_predictor
模型使用 NumPy 和 cv2
库来后处理模型的结果。
这又是拥有访问 CPython 权限有趣的地方。Dynamo 不会出错,而是让 CPython 运行那段有问题的代码!为此,Dynamo 在追踪时生成一个包含问题代码之前所有操作的图,以及一个包含问题代码之后所有操作的图。[6]然后,在运行时,它将委托给 CPython 执行第一个图,然后是问题代码,然后是第二个图。这个过程称为停止追踪和生成多个图,称为图断开。
一项小小的坦白:我在整个引言和第一部分都撒了谎。Dynamo 并不是生成一个图,而是生成多个图!从第二个图开始回溯,可以被认为是开始追踪一个新的函数。在图断点之后的新的图将拥有自己的守卫、新的局部变量集等等。
要讨论如何实现图断点,我们首先需要回顾 Dynamo 如何与 CPython 交互。使用 PEP 523,CPython 允许用户使用他们自己的帧评估机制。我们还没有讨论的是,CPython 还向其他人公开了自己的帧评估。Dynamo 利用这一点,让快速的 CPython 解释器运行编译后的代码。对于一个没有图断点的函数,当函数被带有相同参数的调用两次时,整个追踪/执行过程如下:
在第一次调用函数时
Dynamo 将函数追踪到 FX 图中
FX 图被编译器(感应器)编译成高效的底层代码……但这又是另一个故事了
它重写函数的字节码,使其直接调用编译后的函数
它给 CPython 这个新的字节码,并要求它运行[这里]
在对函数的第二次调用中
它检查第一次调用中的守卫与新参数[这里]是否相同。由于它们与之前相同,因此通过
它要求 CPython 运行与这些守卫相关的字节码[这里]
单独来看,这个过程看起来过于复杂。为什么生成新的字节码并要求 CPython 运行它,而不是简单地创建一个指向编译函数的 C++绑定并执行它呢?嗯,这种模式允许我们实现图断点!图断点生成的字节码具有以下结构:
执行第一个图的字节码
字节码,如果 CPython 执行了第一个图,则它将像那样离开栈。它还会回放任何对局部或全局变量的修改,这些修改在此点可见
导致 Dynamo 图崩溃的字节码
执行第二个图的字节码
让我们用一个简单的例子来看看
import torch
@torch.compile
def fn(a):
b = a + 2
print("Hi")
return b + a
fn(torch.randn(4))
使用 TORCH_LOGS=bytecode
运行此代码可以显示初始字节码和修改后的字节码
MODIFIED BYTECODE fn script.py line 3
0 LOAD_GLOBAL 1 (__compiled_fn_0)
2 LOAD_FAST 0 (a)
4 CALL_FUNCTION 1
6 STORE_FAST 3 (graph_out_0)
8 LOAD_GLOBAL 0 (print)
10 LOAD_CONST 2 ('Hi')
12 LOAD_FAST 3 (graph_out_0)
14 LOAD_CONST 3 (0)
16 BINARY_SUBSCR
18 STORE_FAST 1 (b)
20 CALL_FUNCTION 1
22 LOAD_GLOBAL 2 (__resume_at_14_1)
24 ROT_TWO
26 LOAD_FAST 0 (a)
28 LOAD_FAST 1 (b)
30 CALL_FUNCTION 3
32 RETURN_VALUE
MODIFIED BYTECODE resume_in_fn script.py line 6
0 LOAD_GLOBAL 1 (__compiled_fn_2)
2 LOAD_FAST 2 (b)
4 LOAD_FAST 1 (a)
6 CALL_FUNCTION 2
8 UNPACK_SEQUENCE 1
10 RETURN_VALUE
我们可以看到修改后的字节码被拆分为两个函数, fn
,原始函数,以及一个名为 resume_in_fn
的函数。这个第二个函数是由 Dynamo 创建的,用于实现从图断点开始程序的执行。这通常被称为延续函数。这个延续函数简单地调用第二个编译函数,并传递正确的参数。初始函数的代码被重写,实现了我们之前描述的策略
L0-4. 调用编译函数(
a + 2
)。L6. 将其结果存储在名为
graph_out_0
的局部变量中。graph_out_0
是一个元组L8-18. 将堆栈保留在图断点处的状态
L20. 执行导致图断点的代码
L22-32. 调用编译后的延续函数(
a + b
)
在 Dynamo 中,堆栈的代码生成委托给 VariableTracker
子类。Dynamo 中的每个 VariableTracker
对象都有一个 reconstruct 方法,该方法生成必要的字节码,以在堆栈上创建它所表示的 Python 对象。
调试技巧。图形中断会影响性能,因此最好避免它们。使用 TORCH_LOGS=graph_breaks
运行程序是查找我们的程序触发了多少图形中断的好方法。它返回的信息是以 VariableTracker
对象为单位的,因此上述调试技巧有时也有助于找出导致图形中断的原因。
结论 ¶
Dynamo 是一个复杂的软件。一旦你注册实现 CPython 解释器,你就知道你将面临一段旅程。话虽如此,我们希望这篇帖子能帮助您稍微揭开它的神秘面纱。
Dynamo 主要是用 Python 实现的。我们留下了许多关于我们讨论的代码片段的链接。我们希望阅读这些代码片段,搜索调用它们的地方,或者在它们上设置断点并查看调用栈,有助于理解代码库的其余部分。
当然,了解一个软件如何工作的最佳方式是通过扩展它。在这种情况下,最好的方式是查看 GitHub 上的开放 Dynamo 问题。许多问题只需要对代码进行非常小的修改,一旦你找到需要修改的地方。