• 教程 >
  • PyTorch 菜谱 >
  • 支持 TorchScript 的分布式优化器
快捷键

支持 TorchScript 的分布式优化器 ¶

创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日

警告

TorchScript 不再处于活跃开发状态。

在本菜谱中,您将学习:

  • 支持 TorchScript 的分布式优化器的高级思想及其带来的特性

  • 如何编写支持 TorchScript 的定制化分布式优化器

需求

什么是分布式优化器? ¶

DistributedOptimizer 接受一个远程参数(RRef)列表,并在参数所在的 worker 上本地运行优化器,这通常与分布式 RPC/Autograd 一起使用,以进行模型并行训练。它可以使用任何本地优化算法(无论是 torch.optim 中提供的预定义算法还是自定义定义的算法)来在每个 worker 上应用梯度。

什么是支持 TorchScript 的分布式优化器? ¶

分布式优化器在分布式模型并行训练中得到广泛应用,在有些常见的用例中,由于性能关注和资源利用率,训练需要在多线程方式下进行,而不是多进程(或者至少部分多线程,即参数服务器托管模型和参数的一部分,新线程按请求更新参数)。PyTorch 本身不支持原生的多线程训练,因为它受到 Python 的全局解释器锁(GIL)的影响,但它可以利用 TorchScript 来摆脱 GIL,并以多线程方式运行模型。

对于关键模型训练工作负载,提高训练性能是一个重要话题。研究人员通常会希望通过图表示(即通过算子融合)或实现自定义算子内核来加速训练。

支持 TorchScript 的分布式优化器可以帮助摆脱全局解释器锁(GIL),从而提高 PyTorch 在多线程环境中的训练性能,它还解锁了通过使用 TorchScript 提供的先进编译技术(例如 CPU/GPU 融合)进一步提升性能的潜力。

如何编写支持 TorchScript 的自定义分布式优化器?

下面的代码展示了如何编写一个自定义分布式优化器,给定现有的本地优化器实现,这解锁了 TorchScript 带来的好处,包括 GIL 移除和性能提升的机会。

假设您已经有一个在训练过程中使用的本地优化器,在这种情况下,我们将以准双曲动量(QHM)为例,展示如何启用 TorchScript 支持,请注意,这也适用于任何继承自 torch.optim.Optimizer 的任何自定义优化器。

首先,我们需要将计算和状态管理从优化器实现中分离出来,这样我们就可以提取计算部分并将其作为一个自由函数,使其对 TorchScript 友好。这有两个好处:1. 计算逻辑变得更容易检查,它允许我们快速将参数更新/计算部分转换为 TorchScript,并利用 TorchScript IR 进行进一步优化(操作融合等)。2. 分布式优化器底层使用不同的机制来获取梯度和更新参数(我们单独存储梯度而不是在反向传播过程中直接填充 param.grad 字段)。分离计算允许分布式优化器在多线程模式下启用优化器更新的可能性,因为它消除了对 param.grad 的潜在竞争条件。

import torch
from torch import Tensor
from typing import List


def qhm_update(params: List[Tensor],
            dp_list: List[Tensor],
            momentum_buffer_list: List[Tensor],
            lr: float,
            nu: float,
            weight_decay: float,
            weight_decay_type: str,
            momentum: float):

    for p, d_p, momentum_buffer in zip(params, dp_list, momentum_buffer_list):
        if weight_decay != 0:
            if weight_decay_type == "grad":
                d_p.add_(weight_decay, p)
            elif weight_decay_type == "direct":
                p.mul_(1.0 - lr * weight_decay)
            else:
                raise ValueError("Invalid weight decay type provided")

        momentum_buffer.mul_(momentum).add_(1.0 - momentum, d_p)

        p.data.add_(-lr * nu, momentum_buffer)
        p.data.add_(-lr * (1.0 - nu), d_p)

接下来我们将定义一个具有 TorchScript 兼容性的分布式功能优化器,以管理优化器状态并调用我们上面定义的 TorchScript 兼容的更新函数。请注意,与正常自定义优化器相比,存在一些不同的约定:1. 我们不继承 torch.optim.Optimizer ,因为 TorchScript 不支持多态性;2. step 接受梯度列表而不是损失闭包。

import torch
from torch import Tensor
from typing import List, Optional, Dict

# define this as a TorchScript class
@torch.jit.script
class FunctionalQHM(object):
    def __init__(self,
                params: List[Tensor],
                lr: float,
                momentum: float,
                nu: float,
                weight_decay: float = 0.0,
                weight_decay_type: str = "grad"):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if weight_decay_type not in ("grad", "direct"):
            raise ValueError("Invalid weight_decay_type value: {}".format(weight_decay_type))

        self.defaults = {
            "lr": lr,
            "momentum": momentum,
            "nu": nu,
            "weight_decay": weight_decay,
        }
        self.weight_decay_type = weight_decay_type

        # NOTE: we only have one param_group here and don't allow user to add additional
        # param group as it's not a common use case.
        self.param_group = {"params": params}

        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})

    def step(self, gradients: List[Optional[Tensor]]):
        params = self.param_group['params']
        params_with_grad = []
        grads = []
        momentum_buffer_list: List[Tensor] = []

        if len(params) != len(gradients):
            raise ValueError(
                "the gradients passed in does not equal to the size of the parameters!"
                + f"Params length: {len(params)}. "
                + f"Gradients length: {len(gradients)}"
            )

        for param, gradient in zip(self.param_group['params'], gradients):
            if gradient is not None:
                params_with_grad.append(param)
                grads.append(gradient)
                state = self.state[param]
                state['momentum_buffer'] = torch.zeros_like(param, memory_format=torch.preserve_format)
                momentum_buffer_list.append(state['momentum_buffer'])

        # calls into the update function we just defined
        with torch.no_grad():
            qhm_update(params_with_grad,
                    grads,
                    momentum_buffer_list,
                    self.defaults['lr'],
                    self.defaults['nu'],
                    self.defaults['weight_decay'],
                    self.weight_decay_type,
                    self.defaults['momentum'])

最后,我们将我们新定义的分布式功能优化器注册到 functional_optim_map 中,这样 DistributedOptimizer 将尝试获取我们的自定义实现,而不是预定义的默认实现。

from torch.distributed.optim import DistributedOptimizer

DistributedOptimizer.functional_optim_map[QHM] = FunctionalQHM

现在,您可以通过将其传递给 DistributedOptimizer 在分布式训练中正常使用 QHM 优化器。

...
remote_params_list = [...]
dist_optim = DistributedOptimizer(
    QHM, remote_params_list, *args, **kwargs
)

DistributedOptimizer 将自动将 QHM 优化器转换为 FunctionalQHM ,并启用 TorchScript 支持。这将解锁由多线程训练带来的性能提升,并为进一步改进提供更多潜力(例如,TorchScript 融合等)。

注意,PyTorch 内置的大多数优化器已经使用这种方法来加速分布式训练。如果您看到有关某些优化器尚未转换的警告,您可以按照以下食谱编写自己的转换。


评分这个教程

© 版权所有 2024,PyTorch。

使用 Sphinx 构建,主题由 Read the Docs 提供。
//暂时添加调查链接

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源