• 文档 >
  • 线性并行 - torch.distributed.tensor.parallel
快捷键

线性并行 - torch.distributed.tensor.parallel ¶

线性并行(TP)建立在 PyTorch 分布式张量(DTensor)之上,并提供了不同的并行方式:列并行、行并行和序列并行。

警告

线性并行 API 是实验性的,可能会发生变化。

使用张量并行化并行化您的 nn.Module 的入口点是:

torch.distributed.tensor.parallel.parallelize_module(module, device_mesh=None, parallelize_plan=None, *, src_data_rank=0)[source][source]

通过根据用户指定的计划并行化模块或子模块,在 PyTorch 中应用张量并行化。

我们根据 parallelize_plan 并行化模块或子模块。parallelize_plan 包含 ParallelStyle ,表示用户希望如何并行化模块或子模块。

用户还可以为每个模块的完全限定名(FQN)指定不同的并行样式。

注意, parallelize_module 只接受 1-D DeviceMesh ,如果您有 2-D 或 N-D DeviceMesh ,请先将 DeviceMesh 切片为 1-D 子 DeviceMesh,然后再传递给此 API(即 device_mesh["tp"] )。

参数:
  • 模块( nn.Module )- 要并行化的模块。

  • device_mesh( DeviceMesh ,可选)- 描述 DTensor 的设备网格拓扑的对象。如果未指定,调用必须在 DeviceMesh 上下文中进行。

  • parallelize_plan (Union[ ParallelStyle , Dict[str, ParallelStyle ]], optional) – 用于并行化模块的计划。它可以是包含如何为 Tensor Parallelism 准备输入/输出的 ParallelStyle 对象,也可以是模块 FQN 及其对应的 ParallelStyle 对象的字典。如果未指定,调用将目前不执行任何操作。

关键字参数:

src_data_rank (int, optional) – 逻辑/全局张量的源数据排名,它被 distribute_tensor() 用于将碎片/副本分散/广播到其他排名。默认情况下,我们使用每个 DeviceMesh 维度的 group_rank=0 作为源数据以保留单设备语义。如果显式传递 Noneparallelize_module() 将仅使用其本地数据,而不是尝试通过分散/广播来保留单设备语义。默认:0

返回:

一个 nn.Module 对象已并行化。

返回类型:

模块

示例::
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>>
>>> # Define the module.
>>> m = Model(...)
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
>>>

注意

对于像 Attention、MLP 层这样的复杂模块架构,我们建议将不同的 ParallelStyles 组合在一起(即 ColwiseParallelRowwiseParallel )并作为 parallelize_plan 传递,以达到所需的分片计算。

张量并行支持以下并行风格:

class torch.distributed.tensor.parallel.ColwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)[source][source]

以列式方式分割兼容的 nn.Module。目前支持 nn.Linear 和 nn.Embedding。用户可以将它与 RowwiseParallel 结合使用,以实现更复杂模块的划分(例如 MLP、Attention)。

关键字参数:
  • input_layouts(放置,可选)- nn.Module 的输入张量的 DTensor 布局,用于将输入张量注释为 DTensor。如果未指定,我们假设输入张量是复制的。

  • output_layouts(放置,可选)- nn.Module 的输出 DTensor 布局,用于确保 nn.Module 的输出具有用户期望的布局。如果未指定,输出张量将在最后一个维度上进行划分。

  • 使用本地输出(布尔值,可选)- 是否使用本地 torch.Tensor 而不是 DTensor 作为模块输出,默认:True。

返回:

代表 Colwise 分片的 nn.Module 对象 ParallelStyle

示例::
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "w1" nn.Linear submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor
>>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim.
>>>
>>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()})
>>> ...

注意

默认情况下,如果未指定 output_layoutsColwiseParallel 输出将在最后一个维度上进行分片,如果存在需要特定张量形状的运算符(即 RowwiseParallel 之前),请注意,如果输出被分片,运算符可能需要调整到分片大小。

class torch.distributed.tensor.parallel.RowwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)[source][source]

以行式方式分割兼容的 nn.Module。目前支持 nn.Linear 和 nn.Embedding。用户可以使用 ColwiseParallel 与之组合,以实现更复杂模块的划分(例如 MLP、Attention)。

关键字参数:
  • input_layouts(放置,可选)- nn.Module 的输入张量的 DTensor 布局,用于将输入张量标注为 DTensor。如果未指定,我们假设输入张量在最后一个维度上进行划分。

  • output_layouts(放置,可选)- nn.Module 的输出 DTensor 布局,用于确保 nn.Module 的输出符合用户期望的布局。如果未指定,输出张量将被复制。

  • 使用本地输出(布尔值,可选)- 是否使用本地 torch.Tensor 而不是 DTensor 作为模块输出,默认:True。

返回:

代表按行分片 nn.Module 的 ParallelStyle 对象。

示例::
>>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "w2" nn.Linear submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim
>>> # and the output of "w2" will return a replicated :class:`torch.Tensor`.
>>>
>>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()}),
>>> ...
class torch.distributed.tensor.parallel.SequenceParallel(*, sequence_dim=1, use_local_output=False)[source][source]

SequenceParallel 复制兼容的 nn.Module 参数,并在序列维度上分片输入的情况下运行分片计算。目前支持 nn.LayerNormnn.Dropout ,以及 RMSNorm 的 Python 实现

该风格实现了论文《减少大型 Transformer 模型中的激活重计算》中描述的操作

如果传递给此 nn.Module 的输入是 torch.Tensor ,则假定输入已经在序列维度上进行了分片,并将输入转换为序列维度上的 DTensor 分片。如果传递给此 nn.Module 的输入已经是 DTensor 但未在序列维度上分片,则将重新分配输入以在序列维度上分片。

nn.Module 的输出将在序列维度上进行分片。

关键字参数:
  • sequence_dim (int, 可选) – 输入张量的序列维度,用于将输入张量标注为在序列维度上分片的 DTensor,默认:1。

  • use_local_output (bool, 可选) – 是否使用局部输出代替默认输出,默认:False。

返回:

代表序列并行的 nn.Module 对象。

示例::
>>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "norm" nn.LayerNorm submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim
>>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`.
>>>
>>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}),
>>> ...

注意

SequenceParallel 风格假设在 nn.Module(即 nn.LayerNormRMSNorm )中存在权重时进行初始化(即默认为全 1 初始化)。如果您在这些模块的权重上有自定义初始化,需要在并行化前后广播权重以确保它们被复制。

仅需配置 nn.Module 的输入和输出,使用 DTensor 布局并执行必要的布局重分配,无需将模块参数分布到 DTensor 中,在调用 parallelize_module 时可以使用以下 ParallelStyle s:

class torch.distributed.tensor.parallel.PrepareModuleInput(*, input_layouts=None, desired_input_layouts=None, input_kwarg_layouts=None, desired_input_kwarg_layouts=None, use_local_output=False)[source][source]

配置 nn.Module 的输入,在运行时根据 input_layouts 将 nn.Module 的输入张量转换为 DTensor,并根据 desired_input_layouts 执行布局重分配。

关键字参数:
  • input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – 输入张量的 DTensor 布局,用于 nn.Module,这是将输入张量转换为 DTensor 的。如果某些输入不是 torch.Tensor 或不需要转换为 DTensor,则需要指定 None 作为占位符。默认:None。

  • desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – nn.Module 输入张量的期望 DTensor 布局,用于确保 nn.Module 的输入具有期望的 DTensor 布局。此参数的长度需要与 input_layouts 相同。默认:None。

  • input_kwarg_layouts (Dict[str, Placement]) – nn.Module 输入 kwargs 的 DTensor 布局,用于将输入 kwargs 张量转换为 DTensor。默认:None

  • desired_input_kwarg_layouts – (Dict[str, Placement]):nn.Module 输入 kwargs 的期望 DTensor 布局,用于确保 nn.Module 的输入具有期望的 DTensor 布局。默认:None。

  • 使用本地输出(bool,可选)- 是否使用本地 torch.Tensor 代替 DTensor 作为模块输入,默认:False。

返回:

一个 ParallelStyle 对象,用于准备 nn.Module 输入的分区布局。

示例::
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # According to the style specified below, the first input of attn will be annotated to Sharded DTensor
>>> # and then redistributed to Replicated DTensor.
>>> parallelize_module(
>>>     block, # this can be a submodule or module
>>>     tp_mesh,
>>>     parallelize_plan={
>>>         "attn": PrepareModuleInput(
>>>             input_layouts=(Shard(0), None, None, ...),
>>>             desired_input_layouts=(Replicate(), None, None, ...)
>>>         ),
>>>     }
>>> )
class torch.distributed.tensor.parallel.PrepareModuleOutput(*, output_layouts, desired_output_layouts, use_local_output=True)[source][source]

配置 nn.Module 的输出,以便在运行时根据 output_layouts 将 nn.Module 的输出张量转换为 DTensors,并根据 desired_output_layouts 进行布局重分配。

关键字参数:
  • output_layouts(Union[Placement, Tuple[Placement]])- nn.Module 的输出张量的 DTensor 布局,用于将输出张量转换为 DTensor(如果它们是 torch.Tensor )。如果某些输出不是 torch.Tensor 或不需要转换为 DTensor, None 需要指定为占位符。

  • desired_output_layouts(Union[Placement, Tuple[Placement]])- nn.Module 的期望输出张量的 DTensor 布局,用于确保 nn.Module 的输出具有期望的 DTensor 布局。

  • use_local_output(bool,可选)- 是否使用本地 torch.Tensor 而不是 DTensor 作为模块输出,默认:True。

返回:

一个用于准备 nn.Module 输出分片布局的 ParallelStyle 对象。

示例::
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor
>>> # and then redistributed to Sharded DTensor.
>>> parallelize_module(
>>>     block, # this can be a submodule or module
>>>     tp_mesh,
>>>     parallelize_plan = PrepareModuleOutput(
>>>         output_layouts=Replicate(),
>>>         desired_output_layouts=Shard(0)
>>>     )
>>> )

注意

当使用 Shard(dim) 作为上述 ParallelStyle 的输入/输出布局时,我们假设输入/输出激活张量在 TP 操作的 dim 张量维度上均匀分片。例如,由于 RowwiseParallel 接受在最后一个维度上分片的输入,它假设输入张量已经在最后一个维度上均匀分片。对于不均匀分片激活张量的情况,可以将 DTensor 直接传递给分片模块,并在每个 ParallelStyle 之后使用 use_local_output=False 返回 DTensor,其中 DTensor 可以跟踪不均匀分片信息。

对于像 Transformer 这样的模型,我们建议用户在 parallelize_plan 中使用 ColwiseParallelRowwiseParallel 一起,以实现整个模型(即 Attention 和 MLP)所需的分片。

支持通过以下上下文管理器进行并行化的交叉熵损失计算(损失并行):

torch.distributed.tensor.parallel.loss_parallel()[source][source]

一个上下文管理器,启用损失并行,当输入在类别维度上分片时,可以执行高效的并行化损失计算。目前仅支持交叉熵损失。

在此上下文管理器中,可以像往常一样使用 cross_entropy()CrossEntropyLoss ,对输入参数有以下假设。如果有的话,相应的 backward() 调用也需要在此上下文管理器下进行。

参数:
  • 输入( DTensor )- 输入 logits。假设其在类别维度上分片。

  • 目标(联合 torch.TensorDTensor )- 必须是真实类别索引(目前不支持类别概率)。假定在 DeviceMesh 上复制。

  • 权重(联合 torch.TensorDTensor ,可选)- 如果提供,假定在 DeviceMesh 上复制。

  • 标签平滑 - 目前不支持。

返回:

复制的 DTensor

示例

在这里手动创建了一个分片 DTensor 以展示其用法。在实际应用中,它通常是 TP 模块的输出。

>>> from torch.distributed.tensor.parallel import loss_parallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> device_mesh = init_device_mesh("cuda", (8,))
>>> input = torch.randn(4, 16, device="cuda", requires_grad=True)
>>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)])
>>> target = torch.randint(16, (4,), device="cuda")
>>> with loss_parallel():
>>>     loss = F.cross_entropy(dist_input, target, reduction="mean")
>>>     loss.backward()
>>> ...

警告

loss_parallel API 是实验性的,可能会发生变化。


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源