开始使用 CommDebugMode
指令
创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日
作者:Anshul Sinha
在本教程中,我们将探讨如何使用 CommDebugMode
与 PyTorch 的分布式张量(DTensor)进行调试,通过跟踪分布式训练环境中的集体操作。
前提条件 _
Python 3.8 - 3.11
PyTorch 2.2 或更高版本
什么是 CommDebugMode
以及为什么它有用
随着模型尺寸的不断增大,用户正在寻求利用各种并行策略的组合来扩展分布式训练。然而,现有解决方案之间的互操作性不足构成了一个重大挑战,这主要是因为缺乏一个可以连接这些不同并行策略的统一抽象。为了解决这个问题,PyTorch 提出了 DistributedTensor(DTensor),它抽象了分布式训练中张量通信的复杂性,为用户提供无缝体验。然而,在处理现有的并行解决方案以及使用统一抽象如 DTensor 开发并行解决方案时,底层集体通信的透明度不足可能会给高级用户识别和解决问题带来挑战。为了应对这一挑战, CommDebugMode
,一个 Python 上下文管理器,将作为 DTensor 的主要调试工具之一,使用户能够查看在使用 DTensor 时何时以及为什么发生集体操作,从而有效解决这一问题。
使用 CommDebugMode
¶
这就是如何使用 CommDebugMode
的方法:
# The model used in this example is a MLPModule applying Tensor Parallel
comm_mode = CommDebugMode()
with comm_mode:
output = model(inp)
# print the operation level collective tracing information
print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))
# log the operation level collective tracing information to a file
comm_mode.log_comm_debug_tracing_table_to_file(
noise_level=1, file_name="transformer_operation_log.txt"
)
# dump the operation level collective tracing information to json file,
# used in the visual browser below
comm_mode.generate_json_dump(noise_level=2)
这是 MLPModule 在噪声级别 0 的输出示例:
Expected Output:
Global
FORWARD PASS
*c10d_functional.all_reduce: 1
MLPModule
FORWARD PASS
*c10d_functional.all_reduce: 1
MLPModule.net1
MLPModule.relu
MLPModule.net2
FORWARD PASS
*c10d_functional.all_reduce: 1
要使用 CommDebugMode
,您必须将运行模型的代码包裹在 CommDebugMode
中,并调用您想要用于显示数据的 API。您还可以使用 noise_level
参数来控制显示信息的详细程度。以下是每个噪声级别显示的内容:
在上述示例中,您可以看到集体操作 all_reduce 在 MLPModule
的前向传播中只发生一次。此外,您可以使用 CommDebugMode
来定位到 all-reduce 操作发生在 MLPModule
的第二层线性层中。
以下是您可以使用来自上传自己的 JSON 导出的交互式模块树可视化:
结论 ¶
在本教程中,我们学习了如何使用 CommDebugMode
来调试使用 PyTorch 通信归约的分布式张量和并行解决方案。您可以在嵌入的视觉浏览器中使用自己的 JSON 输出。
关于 CommDebugMode
的更详细信息,请参阅 comm_mode_features_example.py