• 文档 >
  • torch.distributed.tensor
快捷键

torch.distributed.tensor ¬

注意

torch.distributed.tensor 目前处于 alpha 状态,处于开发中,我们正在为文档中列出的大多数 API 承诺向后兼容性,但如果有必要,可能会有 API 变更。

PyTorch DTensor(分布式张量)¬

PyTorch DTensor 提供简单灵活的张量分片原语,透明地处理分布式逻辑,包括分片存储、算子计算以及跨设备/主机进行集体通信。 DTensor 可用于构建不同的并行解决方案,并在处理多维分片时支持分片状态字典表示。

请参阅基于 DTensor 构建的 PyTorch 原生并行解决方案的示例:

采用 SPMD(单程序,多数据)编程模型,让用户能够像编写单设备程序一样编写分布式程序,并保持相同的收敛属性。它通过指定 DeviceMeshPlacement 提供统一的张量划分布局(DTensor 布局):

  • DeviceMesh 使用 n 维数组表示集群的设备拓扑和通信器。

  • Placement 描述了逻辑张量在 DeviceMesh 上的划分布局。DTensor 支持三种类型的放置: ShardReplicatePartial

DTensor 类 APIs ¶

DTensor 是一个 torch.Tensor 子类。这意味着一旦创建了一个 DTensor ,它可以以非常相似的方式用于 torch.Tensor ,包括运行不同类型的 PyTorch 操作,就像在单个设备上运行一样,允许 PyTorch 操作进行适当的分布式计算。

除了现有的 torch.Tensor 方法外,它还提供了一组与 torch.Tensor 交互的附加方法,例如将 DTensor Layout 转换为新的 DTensor,获取所有设备上的完整张量内容等。

class torch.distributed.tensor.DTensor(local_tensor, spec, *, requires_grad)

DTensor (分布式张量)是 torch.Tensor 的一个子类,它提供了对多设备 torch.Tensor 的单设备抽象,通过 DeviceMesh 和以下类型的 Placement 描述分布式张量分片布局(DTensor Layout):

  • Shard : 张量在张量维度 dim 上分片,位于 DeviceMesh 维度的设备上

  • Replicate : 张量在 DeviceMesh 维度的设备上复制

  • Partial : 张量在 DeviceMesh 维度的设备上等待归约

当调用 PyTorch 运算符时, DTensor 会覆盖 PyTorch 运算符以执行分片计算并在必要时发出通信。除了运算符计算外, DTensor 还会根据运算符本身的语义正确地转换或传播放置(DTensor 布局)并生成新的 DTensor 输出。

确保在调用 PyTorch 运算符时 DTensor 的数值正确性, DTensor 要求运算符的每个 Tensor 参数必须是 DTensor。

注意

直接使用 Tensor 子类构造函数在这里不是推荐的方式创建一个 DTensor (即它不能正确处理 autograd,因此不是公共 API)。请参阅 create_dtensor 部分,了解如何创建一个 DTensor

返回类型:

DTensor

__create_chunk_list__()[source][source]

返回一个 ChunkStorageMetadata 列表,该数据类描述了当前 rank 上本地分片/副本的大小/偏移量。对于 DTensor,每个 rank 将有一个单独的本地分片/副本,因此返回的列表通常只有一个元素。

该魔术方法主要用于分布式检查点目的。

返回:

表示当前 rank 上分片大小/偏移量的 List[ ChunkStorageMetadata ]对象。

static from_local(local_tensor, device_mesh=None, placements=None, *, run_check=False, shape=None, stride=None)[source][source]

根据指定的 device_meshplacements 在每个 rank 上从本地 torch.Tensor 创建一个 DTensor

参数:
  • local_tensor (torch.Tensor) – 每个 rank 上的本地 torch.Tensor。

  • device_mesh ( DeviceMesh ,可选) – 放置张量的 DeviceMesh,如果未指定,必须在 DeviceMesh 上下文管理器下调用,默认:None

  • placements (List[ Placement ],可选) – 描述如何在 DeviceMesh 上放置本地 torch.Tensor 的 placements,必须与 device_mesh.ndim 具有相同数量的元素。

关键字参数:
  • run_check (bool, 可选) – 以额外的通信为代价,在所有进程间执行一致性检查,以检查每个本地张量的元信息以确保正确性。如果存在 Replicateplacements 中,设备网格维度上的第一个进程的数据将被广播到其他进程。默认:False

  • shape (torch.Size, 可选) – 一个指定 DTensor 大小的 int 列表,该 DTensor 基于 local_tensor 构建。注意,如果 local_tensor 的形状在不同进程间不同,则需要提供此信息。如果不提供,将假设给定的分布式张量在所有进程间均匀划分。默认:None

  • stride (tuple, 可选) – 指定 DTensor 步长的 int 列表。如果不提供,将假设给定的分布式张量在所有进程间均匀划分。默认:None

返回:

一个 DTensor 对象

返回类型:

DTensor

注意

run_check=False 时,用户有责任确保传入的本地张量在各个 rank 上正确(即张量已根据 Shard(dim) 放置进行分片或根据 Replicate() 放置进行复制)。如果不这样做,创建的 DTensor 的行为将是未定义的。

注意

如果 from_local 可微分,则创建的 DTensor 对象的 requires_grad 将取决于 local_tensor 是否需要 requires_grad。

full_tensor(*, grad_placements=None)[source][source]

返回此 DTensor 的全张量。它将执行必要的集体操作以收集其 DeviceMesh 中其他 rank 的本地张量并将它们连接在一起。这是以下代码的语法糖:

dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()

关键字参数:

grad_placements (List[ Placement ], 可选) – 该放置描述了从该函数返回的全 Tensor 的任何梯度布局的未来布局。full_tensor 将 DTensor 转换为全 torch.Tensor,返回的 torch.tensor 可能不会在代码中稍后用作原始复制的 DTensor 布局。此参数是用户可以向 autograd 提供的提示,如果返回张量的梯度布局与原始复制的 DTensor 布局不匹配。如果未指定,我们将假设全张量的梯度布局为复制。

返回:

表示此 DTensor 完整张量的一个对象。

返回类型:

张量

注意

full_tensor 是可微分的。

redistribute(device_mesh=None, placements=None, *, async_op=False)[source][source]

redistribute 执行必要的集体操作,将当前 DTensor 从其当前放置位置重新分配到新的放置位置,或从当前 DeviceMesh 重新分配到新的 DeviceMesh。即,我们可以通过为 DeviceMesh 的每个维度指定复制放置,将分片 DTensor 转换为复制 DTensor。

当从当前分配到新设备网格维度上的分配进行重新分配时,我们将执行以下操作,包括通信集体或局部操作:

  1. Shard(dim) -> Replicate(): all_gather

  2. Shard(src_dim) -> Shard(dst_dim): all_to_all

  3. Replicate() -> Shard(dim) : 本地分块(即 torch.chunk

  4. Partial() -> Replicate(): all_reduce

  5. Partial() -> Shard(dim): reduce_scatter

redistribute 将正确地确定为在 1-D 或 N-D DeviceMesh 上创建的 DTensors 分配所需的重新分配步骤。

参数:
  • device_mesh ( DeviceMesh , 可选) – 放置 DTensor 的 DeviceMesh。如未指定,将使用当前 DTensor 的 DeviceMesh。默认:None

  • placements (List[ Placement ], 可选) – 描述如何将 DTensor 放置到 DeviceMesh 中的新放置,必须与 device_mesh.ndim 的元素数量相同。默认:在所有网格维度上复制

关键字参数:

async_op (bool, 可选) – 是否异步执行 DTensor 重新分配操作。默认:False

返回:

一个 DTensor 对象

返回类型:

DTensor

注意

redistribute 可微,这意味着用户无需担心重新分配操作的逆向公式。

注意

redistribute 目前仅支持在同一 DeviceMesh 上重新分配 DTensor,如需将 DTensor 重新分配到不同的 DeviceMesh,请提交问题。

to_local(*, grad_placements=None)[source][source]

获取此 DTensor 在其当前 rank 上的本地张量。对于分片,它返回逻辑张量视图的本地分片;对于复制,它返回当前 rank 上的副本。

关键字参数:

grad_placements(List[ Placement ],可选)- 描述了从该函数返回的 Tensor 的任何梯度布局的未来布局。to_local 将 DTensor 转换为本地张量,返回的本地张量可能不会在代码的后续部分用作原始 DTensor 的布局。此参数是用户可以向 autograd 提供的提示,如果返回的张量梯度布局与原始 DTensor 布局不匹配。如果未指定,我们将假设梯度布局与原始 DTensor 保持相同,并使用该布局进行梯度计算。

返回:

一个 torch.TensorAsyncCollectiveTensor 对象。它代表当前 rank 上的本地张量。当返回 AsyncCollectiveTensor 对象时,表示本地张量尚未准备好(即通信尚未完成)。在这种情况下,用户需要调用 wait 以等待本地张量准备好。

返回类型:

张量

注意

to_local 可导,返回的局部张量的 requires_grad 将取决于 DTensor 是否需要梯度。

属性 device_meshDeviceMesh ¶

与此 DTensor 对象关联的 DeviceMesh 属性。

注意

device_mesh 是只读属性,不能设置。

属性放置 tuple[torch.distributed.tensor.placement_types.Placement...]

该 DTensor 的 placements 属性描述了该 DTensor 在其 DeviceMesh 上的布局。

注意

placements 是一个只读属性,不能设置。

DeviceMesh 作为分布式通信器。

DeviceMesh 是从 DTensor 构建而成的抽象,用于描述集群的设备拓扑结构并表示多维通信器(在 ProcessGroup 之上)。有关如何创建/使用 DeviceMesh 的详细信息,请参阅 DeviceMesh 菜谱。

DTensor 放置类型 ¶

DTensor 支持以下类型的 Placement 在每个 DeviceMesh 维度上:

class torch.distributed.tensor.placement_types.Shard(dim)[source][source]

Shard(dim) 放置描述了在 dim 索引的 tensor 维度上对 DTensor 的分片,在 DeviceMesh 维度上,每个 rank 只持有全局 Tensor 的一个分片/片段。 Shard(dim) 放置遵循 torch.chunk(dim) 语义,当 tensor 维度不能被 DeviceMesh 维度整除时,DeviceMesh 维度上的最后几个分片可能为空。 Shard 放置可以被所有 DTensor API(例如 distribute_tensor、from_local 等)使用。

参数:

dim(int)- 描述 DTensor 在其对应的 DeviceMesh 维度上分片的 tensor 维度。

警告

在 tensor 维度大小不能被 DeviceMesh 维度整除的情况下对 tensor 维度进行分片的功能目前处于实验阶段,可能发生变化。

dim: int
class torch.distributed.tensor.placement_types.Replicate[source][source]

Replicate() 放置描述了在相应的 DeviceMesh 维度上复制的 DTensor,其中每个在 DeviceMesh 维度上的 rank 持有全局 Tensor 的副本。 Replicate 放置可由所有 DTensor API 使用(即 distribute_tensorDTensor.from_local ,等等)

class torch.distributed.tensor.placement_types.Partial(reduce_op='sum')[source][source]

Partial(reduce_op) 放置描述了在指定 DeviceMesh 维度上待减小的 DTensor,其中每个在 DeviceMesh 维度上的 rank 持有全局 Tensor 的部分值。用户可以使用 redistributePartial DTensor 重新分配到指定 DeviceMesh 维度上的 ReplicateShard(dim) 放置,这将触发底层的必要通信操作(即 allreducereduce_scatter

参数:

reduce_op (字符串,可选) – 用于将部分 DTensor 转换为 Replicated/Sharded DTensor 的缩减操作。仅支持元素级缩减操作,包括:“sum”(求和)、“avg”(平均值)、“product”(乘积)、“max”(最大值)、“min”(最小值),默认:“sum”。

注意

Partial 放置可以由 DTensor 操作生成,并且只能由 DTensor.from_local API 使用。

reduce_opstr='sum'
class torch.distributed.tensor.placement_types.Placement[source][source]

放置类型的基类,其中描述了如何将 DTensor 放置到 DeviceMeshPlacementDeviceMesh 一起可以描述 DTensor 的布局。它是三个主要 DTensor 放置类型的基础类: ShardReplicate ,和 Partial

此类不打算直接使用,主要作为类型占位符。

is_partial()[来源][来源] ¶
返回类型:

布尔型

is_replicate()[来源][来源] ¶
返回类型:

布尔型

is_shard(dim=None)[source][source]
返回类型:

布尔型

不同方式创建 DTensor

有三种方式来构建一个 DTensor :
  • distribute_tensor() 从每个 rank 上的逻辑或“全局” torch.Tensor 创建一个 DTensor 。这可以用来分片叶节点 torch.Tensor (即模型参数/缓冲区和输入)。

  • 在每个进程上从本地创建一个,可用于从非叶子节点创建(即正向/反向过程中的中间激活张量)。

  • DTensor 提供了专门的张量工厂函数(例如, empty()ones()randn() 等),允许通过直接指定 DeviceMeshPlacement 来创建不同的 DTensor 。与 distribute_tensor() 相比,这可以直接在设备上物化分片内存,而不是在初始化逻辑张量内存后进行分片。

从逻辑 torch.Tensor 创建 DTensor

SPMD(单程序,多数据)编程模型在 torch.distributed 中启动多个进程(即通过 torchrun )来执行相同的程序,这意味着程序中的模型首先会在不同的进程中初始化(即模型可能首先在 CPU 上初始化,或在元设备上初始化,或者在有足够内存的情况下直接在 GPU 上初始化)。

提供了一个 API,可以将模型权重或张量分片到 DTensor ,其中它会在每个进程中创建一个由“逻辑”张量构成的 DTensor。这将使创建的 DTensor 能够符合单设备语义,这对于数值正确性至关重要。

torch.distributed.tensor.distribute_tensor(tensor, device_mesh=None, placements=None, *, src_data_rank=0)[source]

将叶 torch.Tensor (即 nn.Parameter/buffers)根据指定的 placements 分配到 device_meshdevice_meshplacements 的 rank 必须相同。要分配的是逻辑或“全局”张量,API 将使用 DeviceMesh 维度的第一个 rank 的 tensor 作为真实来源以保持单设备语义。如果您想在 Autograd 计算过程中构建 DTensor,请使用 DTensor.from_local()

参数:
  • tensor (torch.Tensor) – 要分配的 torch.Tensor。注意,如果您想在不是该网格维度中设备数量整除的维度上分片张量,我们将使用 torch.chunk 语义来分片张量并分散碎片。不均匀分片行为是实验性的,可能随时更改。

  • device_mesh ( DeviceMesh ,可选) – 将张量分布到 DeviceMesh 的 DeviceMesh,如果未指定,必须在 DeviceMesh 管理器下调用,默认:None

  • placements (List[ Placement ],可选) – 描述如何在 DeviceMesh 上放置张量的 placements,必须与 device_mesh.ndim 的元素数量相同。如果未指定,我们将默认将张量复制到 device_mesh 的设备_mesh 的每个维度的第一个秩。

关键字参数:

src_data_rank (int,可选) – 逻辑/全局张量的源数据秩,它由 distribute_tensor() 用于将碎片/副本分散/广播到其他秩。默认情况下,我们使用 group_rank=0 在每个 DeviceMesh 维度上作为源数据以保留单设备语义。如果显式传递 Nonedistribute_tensor() 将仅使用其本地数据,而不是尝试通过分散/广播来保留单设备语义。默认:0

返回:

一种 DTensorXLAShardedTensor 对象。

返回类型:

DTensor

注意

当使用 xla 设备类型初始化 DeviceMesh 时, distribute_tensor 返回 XLAShardedTensor。有关更多详细信息,请参阅此问题。XLA 集成是实验性的,可能会发生变化。

除了 distribute_tensor() 之外,DTensor 还提供了一个 distribute_module() API,以允许在 nn.Module 层面上更容易地进行分片。

torch.distributed.tensor.distribute_module(module, device_mesh=None, partition_fn=None, input_fn=None, output_fn=None)[source]

此函数暴露了三个函数来控制模块的参数/输入/输出:

1. 在运行时之前对模块进行分片,通过指定 partition_fn (即允许用户根据指定的 partition_fn 将 Module 参数转换为 DTensor 参数)。2. 通过指定 input_fnoutput_fn 来控制模块在运行时的输入或输出。(即转换输入为 DTensor ,将输出转换回 torch.Tensor

参数:
  • module ( nn.Module ) – 要分片的用户模块。

  • 设备网格( DeviceMesh )- 放置模块的设备网格。

  • partition_fn(可调用函数)- 分区参数的函数(即跨 device_mesh 分片某些参数)。如果未指定 partition_fn ,则默认情况下,将 module 的所有模块参数复制到网格中。

  • input_fn(可调用函数)- 指定输入分布,即可以控制模块的输入如何分片。 input_fn 将作为模块 forward_pre_hook (预前向钩子)安装。

  • output_fn(可调用函数)- 指定输出分布,即可以控制输出如何分片,或将它转换回 torch.Tensor。 output_fn 将作为模块 forward_hook (后向钩子)安装。

返回:

包含所有参数/缓冲区的模块,均为 DTensor

返回类型:

模块

注意

当使用 xla 设备类型初始化 DeviceMesh 时, distribute_module 返回具有 PyTorch/XLA SPMD 注解参数的 nn.Module。有关更多详细信息,请参阅此问题。XLA 集成是实验性的,可能会更改。

DTensor 工厂函数

DTensor 还提供了专门的张量工厂函数,允许通过指定 DeviceMeshPlacement 来直接使用 torch.Tensor 类似工厂函数 API(即 torch.ones、torch.empty 等)创建 DTensor

torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[source]

返回一个填充了标量值 0 的 DTensor

参数:

size (int...) – 定义输出 DTensor 形状的整数序列。可以是可变数量的参数或类似列表或元组的集合。例如:zeros(1,2,3..) 或 zeros([1,2,3..]) 或 zeros((1,2,3..))

关键字参数:
  • requires_grad(布尔值,可选)- 如果 autograd 应该记录返回的 DTensor 的操作。默认: False

  • dtype( torch.dtype ,可选)- 返回的 DTensor 的期望数据类型。默认:如果 None ,则使用全局默认值(见 torch.set_default_dtype() )。

  • layout( torch.layout ,可选)- 返回的 DTensor 的期望布局。默认: torch.strided

  • 装置网状结构 – DeviceMesh 类型,包含各 rank 的网状信息

  • 排放序列 – 一系列 Shard 类型: Replicate

返回:

每个 rank 上的 DTensor 对象

返回类型:

DTensor

torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]

返回一个填充了标量值 1 的 DTensor ,其形状由变量参数 size 定义。

参数:

size(int...)- 定义输出 DTensor 形状的一组整数。可以是可变数量的参数或列表、元组等集合。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))

关键字参数:
  • dtype ( torch.dtype ,可选) – 返回的 DTensor 所期望的数据类型。默认:如果 None ,则使用全局默认值(见 torch.set_default_dtype() )。

  • layout ( torch.layout ,可选) – 返回 DTensor 所期望的布局。默认: torch.strided

  • requires_grad (bool,可选) – 如果 autograd 应记录返回的 DTensor 上的操作。默认: False

  • device_mesh – DeviceMesh 类型,包含 ranks 的网格信息

  • 放置序列 – Placement 类型: Shard , Replicate

返回:

每个等级上的一个 DTensor 对象

返回类型:

DTensor

torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]

返回一个填充未初始化数据的 DTensorDTensor 的形状由变量参数 size 定义。

参数:

size (int...) – 定义输出 DTensor 形状的整数序列。可以是可变数量的参数或列表、元组等集合。例如:empty(1,2,3..) 或 empty([1,2,3..]) 或 empty((1,2,3..))

关键字参数:
  • dtype ( torch.dtype ,可选) – 返回 DTensor 的期望数据类型。默认:如果 None ,则使用全局默认值(见 torch.set_default_dtype() )。layout ( torch.layout ,可选):返回 DTensor 的期望布局。默认: torch.strided

  • requires_grad(布尔值,可选)- 如果 autograd 应记录对返回的 DTensor 的操作。默认: False

  • device_mesh - DeviceMesh 类型,包含 rank 的网格信息

  • placements - Placement 类型: ShardReplicate

返回:

每个 rank 上的 DTensor 对象

返回类型:

DTensor

torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]

返回一个根据 device_meshplacements 填充的 DTensor ,形状由参数 size 定义。

参数:
  • size (int...) – 定义输出 DTensor 形状的一组整数。可以是多个参数,也可以是列表或元组等集合。例如:ones(1,2,3..)或 ones([1,2,3..])或 ones((1,2,3..))

  • fill_value(标量)- 用于填充输出张量的值。

关键字参数:
  • dtype( torch.dtype ,可选)- 返回的 DTensor 期望的数据类型。默认:如果 None ,则使用全局默认值(见 torch.set_default_dtype() )。

  • layout( torch.layout ,可选)- 返回 DTensor 期望的布局。默认: torch.strided

  • requires_grad(布尔值,可选)- 如果 autograd 应记录对返回的 DTensor 的操作。默认: False

  • device_mesh - DeviceMesh 类型,包含 rank 的网格信息。

  • placements - Placement 类型的序列: ShardReplicate

返回:

每个 rank 上的 DTensor 对象

返回类型:

DTensor

torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[source]

返回一个用 DTensor 填充的随机数,这些随机数来自区间 [0, 1) 上的均匀分布。张量的形状由变量参数 size 定义。

参数:

size (int...) – 定义输出 DTensor 形状的一组整数。可以是多个参数的变量数量,也可以是列表或元组等集合。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))

关键字参数:
  • dtype ( torch.dtype , 可选) – 返回的 DTensor 的期望数据类型。默认:如果 None ,则使用全局默认值(见 torch.set_default_dtype() )。

  • layout ( torch.layout , 可选) – 返回 DTensor 的期望布局。默认: torch.strided

  • requires_grad (bool, 可选) – 如果 autograd 应记录返回的 DTensor 上的操作。默认: False

  • 装置网格 – DeviceMesh 类型,包含各 rank 的网格信息。

  • 排放 – Placement 类型: ShardReplicate

返回:

每个 rank 上的 DTensor 对象

返回类型:

DTensor

torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[source]

返回一个填充有均值为 0、方差为 1 的正态分布随机数的 DTensor 。张量的形状由变量参数 size 定义。

参数:

size(int...)- 定义输出 DTensor 形状的一组整数。可以是可变数量的参数或列表、元组等集合。例如:ones(1,2,3..)或 ones([1,2,3..])或 ones((1,2,3..))。

关键字参数:
  • dtype ( torch.dtype ,可选) – 返回的 DTensor 所期望的数据类型。默认:如果 None ,则使用全局默认值(见 torch.set_default_dtype() )。

  • layout ( torch.layout ,可选) – 返回 DTensor 所期望的布局。默认: torch.strided

  • requires_grad (bool,可选) – 如果 autograd 应记录返回的 DTensor 上的操作。默认: False

  • device_mesh – DeviceMesh 类型,包含 ranks 的网格信息。

  • 放置序列 – Placement 类型: ShardReplicate

返回:

每个等级上都有一个 DTensor 对象

返回类型:

DTensor

调试

记录日志

当启动程序时,您可以使用 torch._logging 中的 TORCH_LOGS 环境变量开启额外的日志记录:

  • TORCH_LOGS=+dtensor 将显示 logging.DEBUG 级别及其以上的日志信息。

  • TORCH_LOGS=dtensor 将显示 logging.INFO 级别及其以上的日志信息。

  • TORCH_LOGS=-dtensor 将显示 logging.WARNING 级别及其以上的日志信息。

调试工具 §

调试应用 DTensor 的程序,并深入了解底层发生的集体操作细节,DTensor 提供了 CommDebugMode

class torch.distributed.tensor.debug.CommDebugMode

CommDebugMode 是一个上下文管理器,用于统计其上下文中功能集体的数量。它通过 TorchDispatchMode 来实现。

注意

目前并非所有集体操作都得到支持。

演示用法

mod = ...
comm_mode = CommDebugMode()
with comm_mode:
    mod.sum().backward()
print(comm_mode.get_comm_counts())
generate_comm_debug_tracing_table(noise_level=3)[来源][来源] ¶

生成详细表格,显示模块级别的操作和集体跟踪信息。信息量取决于 noise_level

  1. 打印模块级别的集体计数

  2. 打印不包括在平凡操作中的 dTensor 操作,模块信息

  3. 打印不包括在平凡操作中的操作

  4. 打印所有操作

生成_json_dump(file_name='comm_mode_log.json', noise_level=3)[source][source] ¶

创建用于构建浏览器可视化的 json 文件 0. 打印模块级别的集体计数 1. 打印不包括在平凡操作中的 dTensor 操作 2. 打印不包括在平凡操作中的操作 3. 打印所有操作

get_comm_counts()[source][source]

返回通信计数作为字典。

返回:

通信计数作为字典。

返回类型:

dict[任意类型, int]

get_parameter_info()[来源][来源] ¶
返回类型:

dict[str, dict[str, Any]]

get_sharding_info()[来源][来源] ¶
返回类型:

dict[str, dict[str, Any]]

get_total_counts()[source][source]
返回类型:

int

log_comm_debug_tracing_table_to_file(file_name='comm_mode_log.txt', noise_level=3)[source][source]

控制台 CommDebugMode 输出的替代方案,将输出写入用户指定的文件

为了可视化维度少于 3 的 DTensor 的划分,DTensor 提供了 visualize_sharding()

torch.distributed.tensor.debug.visualize_sharding(dtensor, header='')[source]

在终端中可视化 1D 或 2D 的 DTensor 的划分。

注意

这需要 tabulate 包。对于空张量不会打印划分信息。

实验性功能

DTensor 还提供了一组实验性功能。这些功能处于原型设计阶段,或者基本功能已完成,但正在寻找用户反馈。如果您对这些功能有反馈,请向 PyTorch 提交问题。

torch.distributed.tensor.experimental.context_parallel(mesh, *, buffers=None, buffer_seq_dims=None, no_restore_buffers=None)[source]

context_parallel 是一个实验性 API,用于启用上下文并行(CP)。此 API 执行两个操作:1) 将 SDPA( torch.nn.functional.scaled_dot_product_attention )替换为启用 CP 的版本,2) 沿序列维度划分 buffers ,每个 rank 将根据 mesh 保留相应的 shard。

参数:
  • 网格( DeviceMesh )- 上下文并行性的设备网格。

  • 缓冲区(可选[List[torch.Tensor]])- 使用的缓冲区依赖于序列维度。例如,输入批次、标签和位置嵌入缓冲区。这些缓冲区必须沿序列维度分片以确保准确性。分片将在原地发生,缓冲区的形状将在上下文中改变。缓冲区将在上下文完成后恢复。可以使用 no_restore_buffers 来指定哪些缓冲区不需要恢复。请注意, buffers 不应包含任何 nn.Parameter。

  • buffer_seq_dims(可选[List[int]])- buffers 的序列维度。

  • no_restore_buffers(可选[Set[torch.Tensor]])- 这些集合中的缓冲区在上下文退出后不会恢复。此集合必须是 buffers 的子集。如果缓冲区在上下文退出后不再使用,可以将这些缓冲区放入此列表中,以避免额外的恢复时间。

返回类型:

生成器[None, None, None]

警告

torch.distributed._tensor.experimental.attention.context_parallel 是 PyTorch 中的一个原型功能。该 API 可能会发生变化。

torch.distributed.tensor.experimental.local_map(func, out_placements, in_placements=None, device_mesh=None, *, redistribute_inputs=False)[source]

local_map() 是一个实验性 API,允许用户将 DTensor 传递给一个编写为应用于 torch.Tensor 的函数。这是通过提取 DTensor 的本地组件,调用函数,并根据 out_placements 将输出封装到 DTensor 来实现的。

参数:
  • func (Callable) – 要应用于每个本地分片的 DTensor 的函数。

  • 输出放置(Union[放置类型,Tuple[放置类型, ...]])- func 展平输出中 DTensor 的期望放置。如果展平的 output 是单个值,则 out_placements 应为类型 PlacementType。否则,如果展平的 output 有多个值,则 out_placements 应为与展平的 output 1:1 映射的 PlacementType 值的元组。此外,对于 Tensor 输出,我们使用 PlacementType 作为其放置(一个 Tuple[放置] 值)。对于非 Tensor 输出,PlacementType 应为 None。注意,唯一的例外是未传入任何 DTensor 参数。在这种情况下,即使 out_placements 不为 None,结果函数也应忽略期望的放置,因为函数不是在 DTensor 下运行的。

  • in_placements (Tuple[PlacementType, …], optional) – func 的必需放置位置,在 in_placements 的扁平化输入中。如果指定了 local_map() ,则 DTensor 会检查每个 redistribute_inputs 参数的放置是否与必需放置相同。如果放置不同且 Falseredistribute_inputs ,则会引发异常。否则,如果 Truefunc ,则参数将首先重新分配到必需的分区放置,然后再将其本地张量传递给 None 。唯一的例外是当必需放置不是 torch.Tensor 且参数为 func 时。在这种情况下,将跳过放置检查,并将参数直接传递给 in_placements 。如果 None 为 @15# ,则不会执行放置检查。默认:None

  • device_mesh ( DeviceMesh , optional) – 所有 DTensor 放置在其上的设备网格。如果没有指定,则将从输入 DTensor 的设备网格中推断出来。local_map 要求每个 DTensor 都放置在相同的设备网格上。默认:None。

  • redistribute_inputs (bool, 可选) – 表示是否重新分片输入的 bool 值,当它们的放置与所需输入放置不同时。如果此值为 False 且某些 DTensor 输入的放置不同,将引发异常。默认:False。

返回:

一个 Callable ,将 func 应用到输入的每个本地分片 DTensor 上,并返回由 func 的返回值构造的 DTensor

引发:
  • AssertionError – 如果输入 DTensor 未放置在同一设备网格上,或者如果它们放置在与传入的 device_mesh 参数不同的设备网格上。

  • AssertionError – 对于任何非 DTensor 输出,我们要求其对应的输出放置在 out_placements 为 None。如果不满足此条件,将引发 AssertionError。

  • ValueError – 如果 redistribute_inputs=False 但输入 DTensor 需要根据 in_placements 进行重新分配。

示例

>>> def mm_allreduce_forward(device_mesh, W, X):
>>>     partial_sum_tensor = torch.mm(W, X)
>>>     reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh)
>>>     return reduced_tensor
>>>
>>> W = torch.randn(12, 8, requires_grad=False)
>>> X = torch.randn(8, 16, requires_grad=False)
>>> Y = torch.mm(W, X)
>>> row_wise = [Shard(0)]  # row-wise sharding placements on 1-d mesh
>>> col_wise = [Shard(1)]  # col-wise sharding placements on 1-d mesh
>>>
>>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor convertion
>>> local_mm_allreduce_forward = local_map(
>>>     mm_allreduce_forward,
>>>     out_placements=[Replicate()],
>>>     in_placements=[col_wise, row_wise],
>>>     device_mesh=device_mesh,
>>> )
>>>
>>> W_dt = distribute_tensor(
...     W, device_mesh, (col_wise)
... )  # col-wisely sharded W tensor
>>> X_dt = distribute_tensor(
...     X, device_mesh, (row_wise)
... )  # row-wisely sharded X tensor
>>> Y_dt = local_mm_allreduce_forward(
...     device_mesh, W_dt, X_dt
... )  # apply local_mm_allreduce_forward to DTensors

注意

此 API 目前处于实验阶段,可能随时更改。

torch.distributed.tensor.experimental.register_sharding(op)[source]

register_sharding() 是一个实验性 API,允许用户在张量输入和输出为 DTensor 时注册分片策略。当以下情况时可能很有用:(1) 对于 op 不存在默认分片策略,例如当 op 是一个不支持 DTensor 的自定义操作符时;(2) 当用户想要覆盖现有操作符的默认分片策略时。

参数:

op (Union[OpOverload, List[OpOverload]]) – 要注册自定义分片函数的操作符或操作符列表。

返回:

一个函数装饰器,可以用来包装定义 op 中指定操作符分片策略的函数。定义的分片策略将被注册到 DTensor,如果 DTensor 已经实现了该操作符,将覆盖默认分片策略。自定义分片函数接受与原始 op 相同的输入(如果参数是 torch.Tensor ,则将被 DTensor 内部使用的类似张量对象替换)。函数应返回一个由 2-元组组成的序列,每个 2-元组指定可接受的输出放置及其对应的输入放置。

示例

>>> @register_sharding(aten._softmax.default)
>>> def custom_softmax_sharding(x, dim, half_to_float):
>>>     softmax_dim = dim if dim >= 0 else dim + x.ndim
>>>     acceptable_shardings = []
>>>
>>>     all_replicate = ([Replicate()], [Replicate(), None, None])
>>>     acceptable_shardings.append(all_replicate)
>>>
>>>     for sharding_dim in range(x.ndim):
>>>         if sharding_dim != softmax_dim:
>>>             all_sharded = (
>>>                 [Shard(sharding_dim)],
>>>                 [Shard(sharding_dim), None, None],
>>>             )
>>>             acceptable_shardings.append(all_sharded)
>>>
>>>     return acceptable_shardings

注意

此 API 目前处于实验阶段,内容可能随时更改


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源