• 教程 >
  • 开始使用 DeviceMesh
快捷键

入门学习 DeviceMesh ¶

创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日

作者:张 Iris,梁 万超

备注

在 github 上查看和编辑此教程。

前提条件:

设置分布式通信器,即 NVIDIA 集体通信库(NCCL)通信器,对于分布式训练可能是一个重大挑战。对于需要组合不同并行性的工作负载,用户需要手动设置和管理 NCCL 通信器(例如, ProcessGroup )以针对每个并行性解决方案。这个过程可能很复杂,容易出错。 DeviceMesh 可以简化这个过程,使其更易于管理且更不容易出错。

什么是 DeviceMesh ¶

DeviceMesh 是一个高级抽象,用于管理 ProcessGroup 。它允许用户轻松创建跨节点和节点内进程组,无需担心如何为不同的子进程组正确设置排名。用户还可以通过 DeviceMesh 轻松管理多维并行性背后的进程组/设备。

PyTorch DeviceMesh

为什么 DeviceMesh 很有用 ¶

当需要处理多维度并行(即 3-D 并行)且需要并行组合性时,DeviceMesh 非常有用。例如,当您的并行解决方案需要跨主机和每个主机内的通信时。上面的图片显示,我们可以创建一个 2D 网格,将每个主机内的设备连接起来,并在同质设置中将每个设备与其对应的主机上的设备连接起来。

没有 DeviceMesh,用户在应用任何并行之前需要手动设置 NCCL 通信器、每个进程上的 cuda 设备,这可能相当复杂。以下代码片段展示了没有 DeviceMesh 的情况下混合分片 2-D 并行模式设置。首先,我们需要手动计算分片组和复制组。然后,我们需要将正确的分片组和复制组分配给每个 rank。

import os

import torch
import torch.distributed as dist

# Understand world topology
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
print(f"Running example on {rank=} in a world with {world_size=}")

# Create process groups to manage 2-D like parallel pattern
dist.init_process_group("nccl")
torch.cuda.set_device(rank)

# Create shard groups (e.g. (0, 1, 2, 3), (4, 5, 6, 7))
# and assign the correct shard group to each rank
num_node_devices = torch.cuda.device_count()
shard_rank_lists = list(range(0, num_node_devices // 2)), list(range(num_node_devices // 2, num_node_devices))
shard_groups = (
    dist.new_group(shard_rank_lists[0]),
    dist.new_group(shard_rank_lists[1]),
)
current_shard_group = (
    shard_groups[0] if rank in shard_rank_lists[0] else shard_groups[1]
)

# Create replicate groups (for example, (0, 4), (1, 5), (2, 6), (3, 7))
# and assign the correct replicate group to each rank
current_replicate_group = None
shard_factor = len(shard_rank_lists[0])
for i in range(num_node_devices // 2):
    replicate_group_ranks = list(range(i, num_node_devices, shard_factor))
    replicate_group = dist.new_group(replicate_group_ranks)
    if rank in replicate_group_ranks:
        current_replicate_group = replicate_group

要运行上述代码片段,我们可以利用 PyTorch Elastic。让我们创建一个名为 2d_setup.py 的文件。然后,运行以下 torch elastic/torchrun 命令。

torchrun --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 2d_setup.py

备注

为了演示的简便,我们仅使用一个节点来模拟 2D 并行。请注意,此代码片段也可以在多主机设置上运行。

init_device_mesh() 的帮助下,我们只需两行代码即可完成上述 2D 设置,如果需要,我们仍然可以访问底层的 ProcessGroup

from torch.distributed.device_mesh import init_device_mesh
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))

# Users can access the underlying process group thru `get_group` API.
replicate_group = mesh_2d.get_group(mesh_dim="replicate")
shard_group = mesh_2d.get_group(mesh_dim="shard")

让我们创建一个名为 2d_setup_with_device_mesh.py 的文件。然后,运行以下 torch elastic/torchrun 命令。

torchrun --nproc_per_node=8 2d_setup_with_device_mesh.py

如何使用 DeviceMesh 与 HSDP

混合分片数据并行(HSDP)是一种在主机内执行全分布式并行(FSDP)和在主机间执行数据并行(DDP)的二维策略。

让我们看看 DeviceMesh 如何通过简单的设置帮助您将 HSDP 应用于模型的一个示例。使用 DeviceMesh,用户无需手动创建和管理分片组和复制组。

import torch
import torch.nn as nn

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


# HSDP: MeshShape(2, 4)
mesh_2d = init_device_mesh("cuda", (2, 4))
model = FSDP(
    ToyModel(), device_mesh=mesh_2d, sharding_strategy=ShardingStrategy.HYBRID_SHARD
)

让我们创建一个名为 hsdp.py 的文件。然后,运行以下 torch elastic/torchrun 命令。

torchrun --nproc_per_node=8 hsdp.py

如何使用 DeviceMesh 为您的自定义并行解决方案

在处理大规模训练时,您可能需要更复杂的自定义并行训练组合。例如,您可能需要为不同的并行解决方案切割出子网格。DeviceMesh 允许用户从父网格中切割出子网格,并复用在父网格初始化时已创建的 NCCL 通信器。

from torch.distributed.device_mesh import init_device_mesh
mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("replicate", "shard", "tp"))

# Users can slice child meshes from the parent mesh.
hsdp_mesh = mesh_3d["replicate", "shard"]
tp_mesh = mesh_3d["tp"]

# Users can access the underlying process group thru `get_group` API.
replicate_group = hsdp_mesh["replicate"].get_group()
shard_group = hsdp_mesh["shard"].get_group()
tp_group = tp_mesh.get_group()

结论 ¶

总结来说,我们学习了 DeviceMeshinit_device_mesh() 的内容,以及如何使用它们来描述集群中设备的布局。

如需更多信息,请参阅以下内容:


评分这个教程

© 版权所有 2024,PyTorch。

使用 Sphinx 构建,主题由 Read the Docs 提供。
//暂时添加调查链接

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源