torch.nn.functional.torch.nn.parallel.data_parallel¶
- torch.nn.parallel.data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None)[source][source]¶
在给定的 device_ids 中的 GPU 上并行评估 module(input)。
这是 DataParallel 模块的功能版本。
- 参数:
模块(Module)- 要并行评估的模块
输入(Tensor)- 模块的输入
device_ids(列表,元素类型为 python:int 或 torch.device)- 用于复制模块的 GPU ID
输出设备(Python 整型列表或 torch.device)- GPU 位置输出。使用-1 表示 CPU。(默认:device_ids[0])
- 返回值:
包含模块(input)在 output_device 上结果的 Tensor
- 返回类型: