我们非常高兴向 PyTorch 生态系统介绍 depyf
,这是一个旨在帮助用户理解、学习和适应 torch.compile
的新项目!
动机
torch.compile
是 PyTorch 2.x 的基石,只需一行代码即可轻松加速机器学习工作流程,无论是训练还是推理。仅仅包含 @torch.compile
就可以显著提高代码的性能。然而,确定 torch.compile
的最佳插入点并不容易,更不用说调整各种旋钮以实现最大效率的复杂性了。
torch.compile
堆栈的复杂性,包括 Dynamo、AOTAutograd、Inductor 等,学习曲线陡峭。这些对于深度学习性能优化至关重要的组件,如果没有对该主题的坚实基础,可能会令人望而却步。
注意:有关 torch.compile 的工作原理的入门示例,请参阅此说明性教程。
常用工具: TORCH_COMPILE_DEBUG
为了消除神秘感 torch.compile
,常用的方法涉及利用 TORCH_COMPILE_DEBUG
环境变量。虽然它提供了更多信息,但解读输出仍然是一项艰巨的任务。
例如,当我们有以下代码:
# test.py
import torch
from torch import _dynamo as torchdynamo
from typing import List
@torch.compile
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b
def main():
for _ in range(100):
toy_example(torch.randn(10), torch.randn(10))
if __name__ == "__main__":
main()
然后使用 TORCH_COMPILE_DEBUG=1 python test.py
运行它,我们将得到一个名为 torch_compile_debug/run_2024_02_05_23_02_45_552124-pid_9520
的目录,在该目录下有这些文件:
.
├── torchdynamo
│ └── debug.log
└── torchinductor
├── aot_model___0_debug.log
├── aot_model___10_debug.log
├── aot_model___11_debug.log
├── model__4_inference_10.1
│ ├── fx_graph_readable.py
│ ├── fx_graph_runnable.py
│ ├── fx_graph_transformed.py
│ ├── ir_post_fusion.txt
│ ├── ir_pre_fusion.txt
│ └── output_code.py
├── model__5_inference_11.2
│ ├── fx_graph_readable.py
│ ├── fx_graph_runnable.py
│ ├── fx_graph_transformed.py
│ ├── ir_post_fusion.txt
│ ├── ir_pre_fusion.txt
│ └── output_code.py
└── model___9.0
├── fx_graph_readable.py
├── fx_graph_runnable.py
├── fx_graph_transformed.py
├── ir_post_fusion.txt
├── ir_pre_fusion.txt
└── output_code.py
生成的文件和日志往往提出更多问题,而不是回答,让开发者对数据中的含义和关系感到困惑。常见的难题包括:
-
model__4_inference_10.1
是什么意思? - 我有一个函数,但在目录中有三个
model__xxx.py
,它们之间有什么对应关系? -
debug.log
中的那些LOAD_GLOBAL
是什么东西?
一款更好的工具: depyf
来拯救
让我们看看 depyf
如何帮助开发者解决上述挑战。要使用 depyf
,只需执行 pip install depyf
或访问项目页面 https://github.com/thuml/depyf 安装最新版本,然后将主代码放在 with depyf.prepare_debug
中。
# test.py
import torch
from torch import _dynamo as torchdynamo
from typing import List
@torch.compile
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b
def main():
for _ in range(100):
toy_example(torch.randn(10), torch.randn(10))
if __name__ == "__main__":
import depyf
with depyf.prepare_debug("depyf_debug_dir"):
main()
执行 python test.py
后, depyf
将生成一个名为 depyf_debug_dir
的目录( prepare_debug
函数的参数)。在该目录下,将有以下文件:
.
├── __compiled_fn_0 AFTER POST GRAD 0.py
├── __compiled_fn_0 Captured Graph 0.py
├── __compiled_fn_0 Forward graph 0.py
├── __compiled_fn_0 kernel 0.py
├── __compiled_fn_3 AFTER POST GRAD 0.py
├── __compiled_fn_3 Captured Graph 0.py
├── __compiled_fn_3 Forward graph 0.py
├── __compiled_fn_3 kernel 0.py
├── __compiled_fn_4 AFTER POST GRAD 0.py
├── __compiled_fn_4 Captured Graph 0.py
├── __compiled_fn_4 Forward graph 0.py
├── __compiled_fn_4 kernel 0.py
├── __transformed_code_0_for_torch_dynamo_resume_in_toy_example_at_8.py
├── __transformed_code_0_for_toy_example.py
├── __transformed_code_1_for_torch_dynamo_resume_in_toy_example_at_8.py
└── full_code_for_toy_example_0.py
并且有两个明显的优势:
- 长而难以理解的
torchdynamo/debug.log
已消失。其内容已清理并显示为可读源代码,在full_code_for_xxx.py
和__transformed_code_{n}_for_xxx.py
中。值得注意的是,depyf
的最繁琐和困难的工作是将torchdynamo/debug.log
内部的字节码反编译成 Python 源代码,从而让开发者摆脱 Python 内部令人畏惧的部分。 - 函数名与计算图之间的对应关系得到了尊重。例如,在
__transformed_code_0_for_toy_example.py
中,我们可以看到一个名为__compiled_fn_0
的函数,我们立即会知道其对应的计算图在__compiled_fn_0_xxx.py
中,因为它们具有相同的__compiled_fn_0
前缀名称。
从 full_code_for_xxx.py
开始,并跟随相关的函数,用户将清楚地了解 torch.compile
对其代码做了什么。
还有一件事:单步调试功能
使用调试器逐行执行代码是一种理解代码工作原理的好方法。然而,在 TORCH_COMPILE_DEBUG
下,这些文件仅供用户参考,无法使用用户关心的数据进行执行。
注意:这里的“调试”指的是检查和改进程序的过程,而不是纠正有错误的代码。
depyf
的一个突出特点是它能够方便地实现 torch.compile
的逐步调试:它生成的所有文件都与 Python 解释器内部的运行时代码对象相关联,我们可以在这些文件中设置断点。使用方法简单,只需添加一个上下文管理器 with depyf.debug()
,它就应该能解决问题:
# test.py
import torch
from torch import _dynamo as torchdynamo
from typing import List
@torch.compile
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b
def main():
for _ in range(100):
toy_example(torch.randn(10), torch.randn(10))
if __name__ == "__main__":
import depyf
with depyf.prepare_debug("depyf_debug_dir"):
main()
with depyf.debug():
main()
只有一点需要注意:调试 torch.compile
的工作流程与标准调试工作流程不同。由于 torch.compile
,许多代码是动态生成的。因此,我们需要:
- 启动程序
- 当程序退出
with depyf.prepare_debug("depyf_debug_dir")
时,代码将在depyf_debug_dir
中可用。 - 当程序进入
with depyf.debug()
时,它将自动内部设置断点,使程序暂停。 - 导航到
depyf_debug_dir
以设置断点。 - 继续运行代码,调试器将触发这些断点!
这是一张截图,展示了它的样子。所有代码和张量变量都是实时状态,我们可以检查任何变量,并逐步执行代码,就像我们日常的调试工作流程一样!唯一的区别是我们现在调试的是 torch.compile
生成的代码,而不是人类编写的代码。
结论
torch.compile
是一个加速 PyTorch 代码的无价工具。对于那些想要深入了解 torch.compile
,无论是为了充分利用其全部功能还是为了集成自定义操作,学习曲线可能会非常陡峭。 depyf
旨在降低这一障碍,提供用户友好的体验,以便理解、学习和适应 torch.compile
。
请探索 depyf
并亲身体验其益处!该项目是开源的,可在 https://github.com/thuml/depyf 轻松获取。安装简单,通过 pip install depyf
即可完成。我们希望 depyf
能够通过 torch.compile
增强每个人的开发工作流程。