快捷键

DataParallel

class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)[source][source]

实现模块级别的数据并行。

此容器通过在批处理维度中分块来分割输入,以并行应用给定的 module ,通过在每个指定的设备上复制一次其他对象来在设备上分配输入。在正向传递过程中,该模块在每个设备上被复制,每个副本处理输入的一部分。在反向传递过程中,每个副本的梯度被汇总到原始模块中。

批处理大小应大于使用的 GPU 数量。

警告

即使只有一个节点,也建议使用 DistributedDataParallel ,而不是此类,来进行多 GPU 训练。请参阅:使用 nn.parallel.DistributedDataParallel 代替 multiprocessing 或 nn.DataParallel 和 Distributed Data Parallel。

允许将任意位置和关键字输入传递给 DataParallel,但某些类型会被特别处理。张量将在指定的 dim 上分散(默认为 0)。tuple、list 和 dict 类型将被浅复制。其他类型将在不同的线程之间共享,如果在模型的前向传递过程中写入,可能会被损坏。

并行化的 module 必须在运行此 DataParallel 模块之前将参数和缓冲区放在 device_ids[0] 上。

警告

在每次正向传播中, module 将在每个设备上复制,因此对正在运行的模块 forward 的任何更新都将丢失。例如,如果 module 有一个计数器属性,每次 forward 都会增加,它将始终保持在初始值,因为更新是在副本上完成的,这些副本在 forward 之后将被销毁。然而, DataParallel 保证 device[0] 上的副本将与基本并行化的 module 共享存储参数和缓冲区。因此,对参数或缓冲区上的就地更新将被记录。例如, BatchNorm2dspectral_norm() 依赖于这种行为来更新缓冲区。

警告

定义在 module 及其子模块上的正向和反向钩子将被调用 len(device_ids) 次,每次调用都带有位于特定设备上的输入。特别是,钩子仅保证按照对应设备上的操作顺序正确执行。例如,不能保证通过 register_forward_pre_hook() 设置的钩子会在所有@3@4 调用之前执行,但可以保证每个这样的钩子会在该设备的相应 forward() 调用之前执行。

警告

module 返回一个标量(即 0 维张量)在 forward() 中时,此包装器将返回一个长度等于数据并行使用中设备数量的向量,包含每个设备的结果。

注意

Module 中使用 pack sequence -> recurrent network -> unpack sequence 模式时存在一个细微差别。请参阅 FAQ 中的“我的循环神经网络不与数据并行工作”部分以获取详细信息。

参数:
  • 模块(Module)- 要并行化的模块

  • device_ids(Python 整数列表或 torch.device)- CUDA 设备(默认:所有设备)

  • 输出设备(int 或 torch.device)- 输出位置(默认:device_ids[0])

变量:

模块(Module)- 要并行化的模块

示例:

>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var)  # input_var can be on any device, including CPU

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源