探索 TorchRec 分片¶
创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日
本教程将主要介绍通过 EmbeddingPlanner
和 DistributedModelParallel
API 对嵌入表进行分片方案,并通过显式配置它们来探讨不同分片方案对嵌入表的益处。
安装¶
需求:- python >= 3.7
当使用 torchRec 时,我们强烈推荐使用 CUDA。如果使用 CUDA:- cuda >= 11.0
# install conda to make installying pytorch with cudatoolkit 11.3 easier.
!sudo rm Miniconda3-py37_4.9.2-Linux-x86_64.sh Miniconda3-py37_4.9.2-Linux-x86_64.sh.*
!sudo wget https://repo.anaconda.com/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh
!sudo chmod +x Miniconda3-py37_4.9.2-Linux-x86_64.sh
!sudo bash ./Miniconda3-py37_4.9.2-Linux-x86_64.sh -b -f -p /usr/local
# install pytorch with cudatoolkit 11.3
!sudo conda install pytorch cudatoolkit=11.3 -c pytorch-nightly -y
安装 torchRec 还会安装 FBGEMM,这是一个包含 CUDA 内核和 GPU 启用操作的集合,用于运行
# install torchrec
!pip3 install torchrec-nightly
安装多进程,它可以在 colab 中与 ipython 一起使用,以进行多进程编程
!pip3 install multiprocess
需要以下步骤以使 Colab 运行时检测到添加的共享库。运行时会在 /usr/lib 中搜索共享库,因此我们将安装在 /usr/local/lib/ 中的库复制过来。这是一个非常必要的步骤,仅适用于 Colab 运行时。
!sudo cp /usr/local/lib/lib* /usr/lib/
在此点重新启动您的运行时,以便可以看到新安装的包。在重新启动后立即运行以下步骤,以便 Python 知道在哪里查找包。始终在重新启动运行时后运行此步骤。
import sys
sys.path = ['', '/env/python', '/usr/local/lib/python37.zip', '/usr/local/lib/python3.7', '/usr/local/lib/python3.7/lib-dynload', '/usr/local/lib/python3.7/site-packages', './.local/lib/python3.7/site-packages']
分布式设置
由于笔记本环境,我们无法在此运行 SPMD 程序,但可以在笔记本内部进行多进程操作以模拟设置。当使用 Torchrec 时,用户应负责设置自己的 SPMD 启动器。我们设置了环境,以便基于 torch 分布式的通信后端可以工作。
import os
import torch
import torchrec
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
构建我们的嵌入模型
在这里,我们使用 TorchRec 提供的 EmbeddingBagCollection 来构建我们的嵌入包模型,并使用嵌入表。
在这里,我们创建了一个包含四个嵌入包的 EmbeddingBagCollection(EBC)。我们有两种类型的表:大表和小表,通过它们的行大小差异来区分:4096 与 1024。每个表仍然由 64 维嵌入表示。
我们为表配置了 ParameterConstraints
数据结构,这为模型并行 API 提供了提示,以帮助决定表的分片和放置策略。在 TorchRec 中,我们支持以下选项:* table-wise
:将整个表放置在一个设备上;* row-wise
:按行维度均匀分片表,并将每个分片放置在通信世界的每个设备上;* column-wise
:按嵌入维度均匀分片表,并将每个分片放置在通信世界的每个设备上;* table-row-wise
:针对主机内通信进行优化的特殊分片,适用于可用的快速主机间设备互连,例如 NVLink;* data_parallel
:为每个设备复制表;
注意我们最初在设备“meta”上分配 EBC。这将告诉 EBC 暂时不分配内存。
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.types import ShardingType
from typing import Dict
large_table_cnt = 2
small_table_cnt = 2
large_tables=[
torchrec.EmbeddingBagConfig(
name="large_table_" + str(i),
embedding_dim=64,
num_embeddings=4096,
feature_names=["large_table_feature_" + str(i)],
pooling=torchrec.PoolingType.SUM,
) for i in range(large_table_cnt)
]
small_tables=[
torchrec.EmbeddingBagConfig(
name="small_table_" + str(i),
embedding_dim=64,
num_embeddings=1024,
feature_names=["small_table_feature_" + str(i)],
pooling=torchrec.PoolingType.SUM,
) for i in range(small_table_cnt)
]
def gen_constraints(sharding_type: ShardingType = ShardingType.TABLE_WISE) -> Dict[str, ParameterConstraints]:
large_table_constraints = {
"large_table_" + str(i): ParameterConstraints(
sharding_types=[sharding_type.value],
) for i in range(large_table_cnt)
}
small_table_constraints = {
"small_table_" + str(i): ParameterConstraints(
sharding_types=[sharding_type.value],
) for i in range(small_table_cnt)
}
constraints = {**large_table_constraints, **small_table_constraints}
return constraints
ebc = torchrec.EmbeddingBagCollection(
device="cuda",
tables=large_tables + small_tables
)
分布式模型并行在多进程下
现在,我们有一个单进程执行函数,用于模拟 SPMD 执行期间一个 rank 的工作。
此代码将与其他进程共同分片模型并相应地分配内存。它首先设置进程组,并使用规划器进行嵌入表放置,然后使用 DistributedModelParallel
生成分片模型。
def single_rank_execution(
rank: int,
world_size: int,
constraints: Dict[str, ParameterConstraints],
module: torch.nn.Module,
backend: str,
) -> None:
import os
import torch
import torch.distributed as dist
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.types import ModuleSharder, ShardingEnv
from typing import cast
def init_distributed_single_host(
rank: int,
world_size: int,
backend: str,
# pyre-fixme[11]: Annotation `ProcessGroup` is not defined as a type.
) -> dist.ProcessGroup:
os.environ["RANK"] = f"{rank}"
os.environ["WORLD_SIZE"] = f"{world_size}"
dist.init_process_group(rank=rank, world_size=world_size, backend=backend)
return dist.group.WORLD
if backend == "nccl":
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
else:
device = torch.device("cpu")
topology = Topology(world_size=world_size, compute_device="cuda")
pg = init_distributed_single_host(rank, world_size, backend)
planner = EmbeddingShardingPlanner(
topology=topology,
constraints=constraints,
)
sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())]
plan: ShardingPlan = planner.collective_plan(module, sharders, pg)
sharded_model = DistributedModelParallel(
module,
env=ShardingEnv.from_process_group(pg),
plan=plan,
sharders=sharders,
device=device,
)
print(f"rank:{rank},sharding plan: {plan}")
return sharded_model
多进程执行 ¶
现在让我们以多个 GPU rank 的形式执行代码。
import multiprocess
def spmd_sharing_simulation(
sharding_type: ShardingType = ShardingType.TABLE_WISE,
world_size = 2,
):
ctx = multiprocess.get_context("spawn")
processes = []
for rank in range(world_size):
p = ctx.Process(
target=single_rank_execution,
args=(
rank,
world_size,
gen_constraints(sharding_type),
ebc,
"nccl"
),
)
p.start()
processes.append(p)
for p in processes:
p.join()
assert 0 == p.exitcode
表级分片 ¶
现在让我们为 2 个 GPU 执行两个进程的代码。我们可以在计划打印中看到我们的表是如何跨 GPU 分片的。每个节点将有一个大表和一个小表,这表明我们的规划器试图为嵌入表进行负载均衡。表级分片是许多小型中型表在设备上负载均衡的默认选择方案。
spmd_sharing_simulation(ShardingType.TABLE_WISE)
rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)])), 'large_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:0/cuda:0)])), 'small_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:1/cuda:1)]))}}
rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)])), 'large_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:0/cuda:0)])), 'small_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:1/cuda:1)]))}}
探索其他分片模式
我们最初探讨了表级分片的样子以及它是如何平衡表的位置。现在我们专注于负载均衡的细粒度分片模式:行级分片。行级分片专门针对单个设备由于大嵌入行数导致的内存大小增加而无法容纳的大表。它可以解决模型中超级大表的位置问题。用户可以在打印的计划日志的 shard_sizes
部分看到,表在行维度上减半以分布到两个 GPU 上。
spmd_sharing_simulation(ShardingType.ROW_WISE)
rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)])), 'small_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)]))}}
rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)])), 'small_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)]))}}
另一方面,列级分片针对具有大嵌入维度的表解决负载不平衡问题。我们将垂直分割表。用户可以在打印的计划日志的 shard_sizes
部分看到,表在嵌入维度上减半以分布到两个 GPU 上。
spmd_sharing_simulation(ShardingType.COLUMN_WISE)
rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)])), 'small_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)]))}}
rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)])), 'small_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)]))}}
对于 table-row-wise
,由于其多主机环境下的运行特性,我们无法模拟它。我们将在未来提供一个 Python SPMD 示例来训练带有 table-row-wise
的模型。
在数据并行的情况下,我们将为所有设备重复表。
spmd_sharing_simulation(ShardingType.DATA_PARALLEL)
rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'large_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None)}}
rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'large_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None)}}