由曹杰、米尔·莫哈马迪、亚历克斯·沃特海姆、郑耀诺、乔·斯皮萨克、威尔·克罗玛、沙乌欣·扎希拉扎米所著

今天,我们很高兴地分享我们为 PyTorch/XLA 2.0 的最新工作。PyTorch 2.0 的发布是这一传奇社区又一个重要的里程碑,我们很兴奋能继续成为其中的一员。当 PyTorch/XLA 项目在 2018 年由谷歌和 Meta 启动时,重点是引入前沿的 Cloud TPUs 以支持 PyTorch 社区。在这个过程中,社区中的其他人,如亚马逊,也加入了该项目,社区很快得到了扩展。我们对 XLA 的方向感到兴奋,并认为这个项目继续为 PyTorch 社区带来的好处。在这篇博客中,我们想展示一些正在开发中的关键特性,展示代码片段,并通过一些基准测试来展示其好处。

TorchDynamo / torch.compile(实验性)

TorchDynamo(Dynamo)是一个 Python 级别的 JIT 编译器,旨在使未经修改的 PyTorch 程序运行更快。它提供了一个干净的 API,让编译器后端可以挂钩;其最大特点是能够在执行前动态修改 Python 字节码。在 PyTorch/XLA 2.0 版本中,为 Dynamo 提供了一个实验性的后端,用于推理和训练。

当 Dynamo 识别到模型模式时,会提供一个 Torch FX(FX)图,PyTorch/XLA 使用 Lazy Tensor 方法编译 FX 图并返回编译后的函数。要了解更多关于 PyTorch/XLA 的 Dynamo 实现的技术细节,请查看这篇 dev-discuss 帖子以及 dynamo 文档。

下面是一个运行 ResNet18 的小代码示例:

import torch
import torchvision
import torch_xla.core.xla_model as xm

def eval_model(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.eval()
  dynamo_resnet18 = torch.compile(
      xla_resnet18, backend='torchxla_trace_once')
  for data, _ in loader:
    output = dynamo_resnet18(data)

使用 torch.compile ,PyTorch/XLA 只在初始化时跟踪一次 ResNet18 模型,并在每次调用 dynamo_resnet18 时执行编译后的二进制文件,而不是每一步都跟踪模型。为了说明 Dynamo+XLA 的优势,下面是使用 TorchBench 在 Cloud TPU v4-8 上比较 Dynamo 和 LazyTensor(没有 Dynamo)的推理速度分析,y 轴是加速倍数。

Inference Speedup - PyTorch/XLA Dynamo on TPU

动态训练处于开发阶段,其实现阶段早于推理阶段。开发者可以测试这个早期功能,然而,在 2.0 版本中,PyTorch/XLA 支持前向和反向传递图,但不支持优化器图;优化器图可在夜间构建中找到,并将纳入 PyTorch/XLA 2.1 版本。以下是使用 ResNet18 示例进行训练的示例, torch.compile

import torch
import torchvision
import torch_xla.core.xla_model as xm

def train_model(model, data, target):
  loss_fn = torch.nn.CrossEntropyLoss()
  pred = model(data)
  loss = loss_fn(pred, target)
  loss.backward()
  return pred

def train_model_main(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.train()
  dynamo_train_model = torch.compile(
        train_model, backend='aot_torchxla_trace_once')
  for data, target in loader:
    output = dynamo_train_model(xla_resnet18, data, target)

注意,训练的后端是 aot_torchxla_trace_once (API 将在稳定版本中更新),而推理的后端是 torchxla_trace_once (名称将进行更改)。我们预计使用 Lazy 张量时,每一步训练将提取和执行 3 个图,而不是 1 个训练步骤。以下是使用 Cloud TPU v4-8 上的 TorchBench 比较 Dynamo 和 Lazy 的训练加速分析。

Training Speedup - PyTorch/XLA Dynamo on TPU

PJRT 运行时(Beta)

PyTorch/XLA 正在从 XRT 迁移到新的 PJRT 运行时。PJRT 是一个维护得更好的堆栈,具有证明的性能优势,包括在 TorchBench 2.0 模型上平均提高 35% 的性能。它还支持更丰富的功能,使 SPMD 等技术成为可能。在 PyTorch/XLA 2.0 版本中,PJRT 是 TPU 和 CPU 的默认运行时;GPU 支持处于实验状态。PyTorch/XLA 2.0 版本中包含的 PJRT 功能有:

  • 使用 libtpu PJRT 插件 API 在 libtpu 中改进了性能,最高可达 30%
  • torch.distributed 支持 TPU v2 和 v3,包括 pjrt:// init_method (实验性)
  • 单主机 GPU 支持。多主机支持即将推出。(实验性)

切换到 PJRT 无需更改(或对 GPU 而言仅需最小更改)用户代码(更多详情请参阅 pjrt.md)。运行时配置简单,只需将 PJRT_DEVICE 环境变量设置为本地设备类型(即 TPUGPUCPU )。以下是使用不同设备上的 PJRT 运行时的示例。

# TPU Device
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
# TPU Pod Device
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git"

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"
# GPU Device (Experimental)
PJRT_DEVICE=GPU GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1

下面是在 v4-8 TPU 上,通过 TorchBench 2.0 任务对 XRT 和 PJRT 进行性能比较。要了解更多关于 PJRT 与 XRT 的信息,请查阅文档。

TorchBench Training Time

并行化

GSPMD(实验性)

我们很高兴向大家介绍 PyTorch 中的通用和可扩展并行化计算图(GSPMD),作为一项新的实验性数据与模型分片解决方案。GSPMD 为常见的机器学习工作负载提供自动并行化,允许开发者像在单个大设备上一样编写 PyTorch 程序,无需自定义分片计算操作和/或集体通信操作。XLA 编译器根据用户提供的分片提示,将单设备程序转换为具有适当集体操作的分区程序。API(RFC)将在 PyTorch/XLA 2.0 版本中作为单个 TPU 虚拟机主机上的实验性功能提供。

GSPMD 的下一步

GSPMD 在 2.0 版本中处于实验状态。为了将其提升至稳定状态,我们计划在后续版本中解决许多功能差距和已知问题,包括多主机支持、DTensor 集成、部分复制分片、异步数据加载和检查点。

FSDP(Beta)

PyTorch/XLA 在 1.12 版本中引入了完全分片数据并行(FSDP)的实验性支持。这个特性是 PyTorch FSDP 的并行表示,XLA 和上游 CUDA 内核的设置存在细微差别。 auto_wrap_policy 是一个新参数,允许开发者自动指定将分区规范传播到神经网络子模块的条件。 auto_wrap_policy 可以作为参数传递,当使用 FSDP 包装模型时。以下两个值得注意的可调用对象是: size_based_auto_wrap_policytransformer_auto_wrap_policy

size_based_auto_wrap_policy 允许用户使用最少参数数量包装子模块。以下示例展示了如何包装具有至少 1000 万个参数的模型子模块。

auto_wrap_policy = partial(size_based_auto_wrap_policy, min_num_params=1e7)

transformer_auto_wrap_policy 允许用户包装所有匹配特定层类型的子模块。以下示例展示了如何包装名为 torch.nn.Conv2d 的模型子模块。了解更多信息,请参阅 Ronghang Hu 提供的 ResNet 示例。

auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})

PyTorch/XLA 的 FSDP 现在已集成到 HuggingFace 训练器类中(PR),使用户能够在 PyTorch/XLA 上训练更大的模型(官方 Hugging Face 文档)。使用此 FSDP 配置在 Cloud TPU v4-64 上训练的 16B 参数的 GPT2 模型实现了 39%的硬件利用率。

TPU 加速器 - 设备数量 v4-64
GPT2 参数数量 16B
使用 FSDP 包装的层 GPT2 块
每片芯片的 TFLOPs 275
每步 PFLOPs 50
硬件利用率 39%

FSDP 与 GSPMD 的区别

FSDP 是一种数据并行技术,通过将模型参数、优化器状态和梯度全部分片来减少设备内存占用。请注意,实际的计算仍然在设备本地进行,并且需要在前向和反向传播过程中对所有分片的模型参数进行全局收集,因此得名“数据并行”。FSDP 是 PyTorch/XLA 中用于扩展大型模型训练的最新功能之一。

相反,GSPMD 是一个通用的并行化系统,它能够实现各种类型的并行性,包括数据和模型并行性。PyTorch/XLA 提供了一个分片注解 API 和 XLAShardedTensor 抽象,用户可以在 PyTorch 程序中注解任何张量以指定分片规范。开发者无需手动实现分片计算或注入集体通信操作即可正确实现。XLA 编译器会完成这项工作,以便每个计算可以在多个设备上以分布式方式运行。

例子与初步结果

要了解 PyTorch/XLA 并行分片 API,请访问我们的 RFC 并查看示例代码引用。以下是一个简单的示例,用于启用数据和模型并行。

model = SimpleLinear().to(xm.xla_device())
# Sharding annotate the linear layer weights.
xs.mark_sharding(model.fc1.weight, mesh, partition_spec)
# Training loop
model.train()
for step, (data, target) in enumerate(loader):
  optimizer.zero_grad()
  data = data.to(xm.xla_device())
  target = target.to(xm.xla_device())
  # Sharding annotate input data, we can shard any input
  # dimensions. Sharidng the batch dimension enables 
  # data parallelism, sharding the feature dimension enables
  # spatial partitioning.
  xs.mark_sharding(data, mesh, partition_spec)
  ouput = model(data)
  loss = loss_fn(output, target)
  optimizer.step()
  xm.mark_step()

以下图表突出了 PyTorch/XLA FSDP 和 SPMD 在 Cloud TPU v4-8 上运行 ResNet50 时的内存效率优势。

Batch Size Scaling with Spatial Partitioning

结束语…

我们很兴奋将这些功能带给 PyTorch 社区,这仅仅是一个开始。动态形状、对 OpenXLA 的更深入支持等领域正在开发中,我们计划发布更多博客来深入探讨细节。PyTorch/XLA 是完全开源开发的,我们邀请您加入开发者社区,通过 GitHub 提交问题、提交拉取请求和发送 RFC。您可以在各种 XLA 设备上尝试 PyTorch/XLA,包括 TPUs 和 GPU。以下是开始使用的方法。

再次祝贺 PyTorch 社区取得这一里程碑!

干杯,

谷歌的 PyTorch 团队