• 文档 >
  • 多进程最佳实践
快捷键

多进程最佳实践 ¶

torch.multiprocessing 是 Python 的 multiprocessing 模块的直接替换。它支持完全相同的操作,但进行了扩展,因此所有通过 multiprocessing.Queue 发送的张量,其数据都将移动到共享内存中,而只发送另一个进程的句柄。

注意

Tensor 发送到另一个进程时, Tensor 数据是共享的。如果 torch.Tensor.grad 不是 None ,它也是共享的。在发送没有 torch.Tensor.grad 字段的 Tensor 之后,它会在另一个进程创建一个标准的进程特定 .grad Tensor ,这与 Tensor 的数据已被共享不同,不会自动在所有进程中共享。

这允许实现各种训练方法,如 Hogwild、A3C 或任何需要异步操作的其他方法。

CUDA 在多进程下

CUDA 运行时不支持 fork start 方法;必须使用 spawnforkserver start 方法才能在子进程中使用 CUDA。

注意

可以通过创建上下文使用 multiprocessing.get_context(...) 或直接使用 multiprocessing.set_start_method(...) 来设置启动方法。

与 CPU 张量不同,发送过程需要保留原始张量,直到接收过程保留张量的副本。这是在底层实现的,但需要用户遵循最佳实践以确保程序正确运行。例如,发送过程必须保持活跃,直到消费者过程还有对张量的引用,如果消费者过程通过致命信号异常退出,则引用计数无法拯救你。参见本节。

参见:使用 nn.parallel.DistributedDataParallel 代替 multiprocessing 或 nn.DataParallel

最佳实践和技巧

避免和解决死锁

当一个新的进程被创建时,可能会出现很多问题,其中最常见的死锁原因是后台线程。如果有任何线程持有锁或导入模块,并且调用了 fork ,那么子进程很可能处于损坏状态,并可能以不同的方式死锁或失败。请注意,即使你没有这样做,Python 的内置库也会这样做——无需进一步查找 multiprocessingmultiprocessing.Queue 实际上是一个非常复杂的类,它创建了多个用于序列化、发送和接收对象的线程,它们也可能导致上述问题。如果你发现自己处于这种情况,尝试使用 SimpleQueue ,它不使用任何额外的线程。

我们正在尽力让这变得简单,并确保这些死锁不会发生,但有些事情超出了我们的控制范围。如果你有任何无法应对一段时间的问题,请尝试在论坛上寻求帮助,我们将看看是否是我们能够解决的问题。

重复使用通过队列传递的缓冲区

记住每次将一个 Tensor 放入一个 multiprocessing.Queue 中时,它必须移动到共享内存中。如果它已经共享,则不会执行任何操作,否则将产生额外的内存复制,这可能会减慢整个过程。即使有一个进程池将数据发送到单个进程,也请让它发送回缓冲区 - 这几乎是不需要成本的,并且可以让您在下一次发送批次时避免复制。

异步多进程训练(例如 Hogwild)

使用 torch.multiprocessing ,可以异步训练模型,参数要么始终共享,要么定期同步。在前一种情况下,我们建议发送整个模型对象,而在后一种情况下,我们建议只发送 state_dict()

我们建议使用 multiprocessing.Queue 在进程之间传递各种 PyTorch 对象。例如,当使用 fork 启动方法时,可以继承已存在于共享内存中的张量和存储,但这非常容易出错,应该谨慎使用,并且仅由高级用户使用。队列虽然有时不是那么优雅的解决方案,但在所有情况下都能正常工作。

警告

应该小心使用没有用 if __name__ == '__main__' 保护的全局语句。如果使用了不同于 fork 的启动方法,它们将在所有子进程中执行。

Hogwild

具体的 Hogwild 实现可以在示例仓库中找到,但为了展示代码的整体结构,下面也提供了一个最小示例:

import torch.multiprocessing as mp
from model import MyModel

def train(model):
    # Construct data_loader, optimizer, etc.
    for data, labels in data_loader:
        optimizer.zero_grad()
        loss_fn(model(data), labels).backward()
        optimizer.step()  # This will update the shared parameters

if __name__ == '__main__':
    num_processes = 4
    model = MyModel()
    # NOTE: this is required for the ``fork`` method to work
    model.share_memory()
    processes = []
    for rank in range(num_processes):
        p = mp.Process(target=train, args=(model,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

CPU 在多进程 ¶

不适当的进程并行处理可能导致 CPU 过载,使得不同的进程竞争 CPU 资源,从而降低效率。

本教程将解释什么是 CPU 过载以及如何避免它。

CPU 过载 ¶

CPU 过载是一个技术术语,指的是分配给系统的虚拟 CPU 总数超过了硬件上可用的虚拟 CPU 总数的情况。

这会导致对 CPU 资源的严重竞争。在这种情况下,进程之间会频繁切换,从而增加了进程切换的开销并降低了整体系统效率。

请参阅示例仓库中 Hogwild 实现中的代码示例,了解 CPU 过载的情况。

当在 CPU 上使用以下命令运行训练示例时:

python main.py --num-processes 4

假设机器上有 N 个 vCPU 可用,执行上述命令将生成 4 个子进程。每个子进程将为自己分配 N 个 vCPU,从而需要 4*N 个 vCPU。然而,机器上只有 N 个 vCPU 可用。因此,不同的进程将竞争资源,导致频繁的进程切换。

以下观察表明存在 CPU 过载:

  1. 高 CPU 利用率:通过使用 htop 命令,您可以观察到 CPU 利用率持续很高,通常达到或超过其最大容量。这表明对 CPU 资源的需求超过了可用的物理核心,导致进程之间为 CPU 时间竞争。

  2. 频繁的上下文切换导致系统效率低下:在 CPU 过载的情况下,进程之间为 CPU 时间竞争,操作系统需要快速在不同进程之间切换以公平分配资源。这种频繁的上下文切换增加了开销并降低了整体系统效率。

避免 CPU 过载

避免 CPU 过载的好方法是合理分配资源。确保同时运行的进程或线程数量不超过可用的 CPU 资源。

在这种情况下,解决方案是为子进程指定适当的线程数。这可以通过在子进程中使用 torch.set_num_threads(int) 函数来实现。

假设机器上有 N 个 vCPU 和 M 个进程将被生成,每个进程使用的最大 num_threads 值将是 floor(N/M) 。为了避免在 mnist_hogwild 示例中 CPU 过载,需要在示例仓库中的 train.py 文件中进行以下更改。

def train(rank, args, model, device, dataset, dataloader_kwargs):
    torch.manual_seed(args.seed + rank)

    #### define the num threads used in current sub-processes
    torch.set_num_threads(floor(N/M))

    train_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    for epoch in range(1, args.epochs + 1):
        train_epoch(epoch, args, model, device, train_loader, optimizer)

使用 torch.set_num_threads(floor(N/M)) 为每个进程设置 num_thread 。其中,用 N 替换可用的 vCPU 数量,用 M 替换所选的进程数。适当的 num_thread 值将根据具体任务而变化。然而,作为一个一般性指南, num_thread 的最大值应该是 floor(N/M) ,以避免 CPU 过载。在避免 CPU 过载的 mnist_hogwild 训练示例中,可以实现 30 倍的性能提升。


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源