今天,我们很高兴地宣布 PyTorch/XLA SPMD:将 GSPMD 集成到 PyTorch 中,并提供易于使用的 API。寻求卓越性能和扩展的 PyTorch 开发者可以在最大程度地利用 AI 加速器(如 Google Cloud TPUs)的同时,训练和部署最大的神经网络。
引言
GSPMD 是一个用于 ML 工作负载的自动并行化系统。XLA 编译器根据用户提供的分区提示,将单设备程序转换为具有适当归约的分区程序。这使得开发者可以像在单个大设备上一样编写 PyTorch 程序,而无需任何自定义的分区计算和/或集体通信操作来扩展模型。
PyTorch/XLA SPMD 允许 PyTorch 用户以更少的努力和更好的性能使用 GSPMD 并行化他们的机器学习工作负载。一些关键亮点包括:
- 更好的开发者体验。所有操作只需用户进行少量分片注释即可完成,PyTorch/XLA SPMD 实现的性能与最有效的 PyTorch 分片实现相当(请参阅下面的示例和结果部分)。PyTorch/XLA SPMD 将编程 ML 模型的任务与并行化的挑战分开。其对模型分片自动化的方法让用户免于实现带有适当集体操作的分片版本 ops。
- 一个 API,可以启用各种并行算法(包括数据并行、全分片数据并行、空间分区张量和流水线并行,以及这些算法的组合)以适应不同的机器学习工作负载和模型架构。
- 在大型模型训练中具有行业领先的性能。PyTorch/XLA SPMD 将强大的 XLA GSPMD 带到 PyTorch,使用户能够充分利用 Google Cloud TPUs 的全部功能。
- 允许 PyTorch 和 JAX 开发者利用相同的底层 XLA API 来扩展模型。
关键概念
分片注解 API 的关键概念包括:1)网格,2)分区规范,以及 3)使用网格和分区规范表达分片意图的 mark_sharding
API。更详细的设计概述可在用户指南中找到。
网格
对于给定的设备集群,物理网格是互连拓扑的表示。
我们根据这种拓扑结构推导出一个逻辑网格,以创建可以用于在模型中分区张量不同轴的设备子组。我们应用分片注释来将程序映射到逻辑网格;这将在程序图中自动插入通信集体以支持功能正确性(见下图)。
我们使用 Mesh API 抽象逻辑网格。逻辑网格的轴可以命名。以下是一个示例:
import numpy as np
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh
# Enable XLA SPMD execution mode.
xr.use_spmd()
# Assuming you are running on a TPU host that has 8 devices attached
num_devices = xr.global_runtime_device_count()
# mesh shape will be (4,2) in this example
mesh_shape = (num_devices // 2, 2)
device_ids = np.array(range(num_devices))
# axis_names 'x' nad 'y' are optional
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
mesh.get_logical_mesh()
>> array([[0, 1],
[2, 3],
[4, 5],
[6, 7]])
mesh.shape()
>> OrderedDict([('x', 4), ('y', 2)])
分区规范
partition_spec 与输入张量具有相同的秩。每个维度描述了相应的输入张量维度如何在设备网格(逻辑上由 mesh_shape 定义)上分片。 partition_spec
是一个包含 device_mesh
维度的元组 index
,None,或包含网格维度索引的元组。 index
可以是 int
或 str
,如果相应的网格维度被命名。这指定了每个输入秩是如何分片( index
到 mesh_shape
)或复制( None
)的。
# Provide optional mesh axis names and use them in the partition spec
mesh = Mesh(device_ids, (4, 2), ('data', 'model'))
partition_spec = ('model', 'data')
xs.mark_sharding(input_tensor, mesh, partition_spec)
我们支持原始 GSPMD 论文中描述的所有三种分片类型。例如,可以指定部分复制如下:
# Provide optional mesh axis names and use them in the partition spec
mesh = Mesh(device_ids, (2, 2, 2), ('x', 'y', 'z'))
# evenly shard across x and z and replicate among y
partition_spec = ('x', 'z') # equivalent to ('x', None, 'z')
xs.mark_sharding(input_tensor, mesh, partition_spec)
简单示例,使用分片注解
用户可以使用 mark_sharding
API(src)对原生 PyTorch 张量进行注解。该 API 以 torch.Tensor
作为输入,并返回一个 XLAShardedTensor。
def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, partition_spec: Tuple[Union[int, None]]) -> XLAShardedTensor
调用 mark_sharding
API 需要一个用户定义的逻辑网格和 partition_spec,为 XLA 编译器生成分片注解。分片规范附加到 XLATensor
以及原始输入张量上。以下是一个简单的使用示例,来自 [RFC],说明分片注解 API 的工作原理:
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.xla_sharding import Mesh
# Enable XLA SPMD execution mode.
xr.use_spmd()
# Device mesh, this and partition spec as well as the input tensor shape define the individual shard shape.
num_devices = xr.global_runtime_device_count()
mesh_shape = (2, num_devicese // 2) # 2x4 on v3-8, 2x2 on v4-8
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
t = torch.randn(8, 4).to(xm.xla_device())
# Mesh partitioning, each device holds 1/8-th of the input
partition_spec = (0, 1)
m1_sharded = xs.mark_sharding(t, mesh, partition_spec)
assert isinstance(m1_sharded, XLAShardedTensor) == True
# Note that the sharding annotation is also in-placed updated to t
我们可以在 PyTorch 程序中注解不同的张量,以启用不同的并行技术,如下面的注释所述:
# Sharding annotate the linear layer weights. SimpleLinear() is a nn.Module.
model = SimpleLinear().to(xm.xla_device())
xs.mark_sharding(model.fc1.weight, mesh, partition_spec)
# Training loop
model.train()
for step, (data, target) in enumerate(loader):
# Assumes `loader` returns data, target on XLA device
optimizer.zero_grad()
# Sharding annotate input data, we can shard any input
# dimensions. Sharding the batch dimension enables
# data parallelism, sharding the feature dimension enables
# spatial partitioning.
xs.mark_sharding(data, mesh, partition_spec)
ouput = model(data)
loss = loss_fn(output, target)
optimizer.step()
xm.mark_step()
更完整的单元测试用例和集成测试示例可在 PyTorch/XLA 仓库中找到。
结果
性能
我们使用 GPT-2 模型(src)对 PyTorch/XLA SPMD 的性能进行了测量,并将其与用户模式 FSDP 进行了比较。
在这里,SPMD 应用与 FSDP 图相同的分片方案(即 1D 分片)。用户应通过探索更高级的 SPMD 分片方案来获得更好的 MFU 结果。
我们使用模型 FLOPS 利用率(MFU)作为比较的指标。MFU 是“观察到的吞吐量与系统在峰值 FLOPs 下运行的理论最大吞吐量的比率”(PaLM 论文)。
flops_per_step = 6 * global_batch_size * seq_len * num_params
model_flops_utilization = flops_per_step / step_time(s) / chip_count / flops_per_chip
此估计假设输入维度远大于输入序列长度(d_model » seq_len)。如果违反此假设,自注意力 FLOPs 将开始变得足够显著,并且此表达式将低估真实的 MFU。
可扩展性
SPMD 的一个核心优势是灵活的分区,可以用来节省加速器内存(HBM)使用并提高可扩展性。为了可扩展性分析,我们提出了两项研究:1)我们使用 Hugging Face transformers(GPT-2)作为基础实现,检查了 4 个模型大小的峰值 HBM;2)我们检查了空间分区下的峰值 HBM 使用情况。
上图说明了未分片的 2B 参数模型峰值内存占用为 26GB(红色虚线)。通过硬编码模型权重(模型并行)可以减少峰值内存占用,从而使用给定的 TPU pod 切片实现更大的模型训练。在这些实验中,我们在 Google Cloud TPU v4-16 上实现了 39.75%的 MFU,在 4B 参数模型上。
我们还使用空间分区和简单的 ResNet50 示例(src)在 Cloud TPU v4-8 上运行了输入批处理可扩展性测试。输入批处理通常在批处理维度上分片以提高数据并行性(DDP,FSDP),但 PyTorch/XLA SPMD 允许在输入特征维度上进行输入分片。如图所示,通过空间分区可以将每个设备的批处理大小推高到 512,这是其他数据并行技术无法实现的。
PyTorch/XLA SPMD 的前路
我们对 PyTorch/XLA 的未来感到非常兴奋,并邀请社区加入我们。SPMD 仍然是实验性的,我们持续为其添加新功能。在未来版本中,我们计划解决异步数据加载、部分复制分片和其他改进。我们非常愿意听取您的意见,回答您对 PyTorch/XLA SPMD 的疑问,并了解您如何使用 SPMD。
喝彩!
谷歌的 PyTorch/XLA 团队