torch.cuda.comm.gather¶
- torch.cuda.comm.gather(tensors, dim=0, destination=None, *, out=None)[source][source]¶
从多个 GPU 设备收集张量。
- 参数:
张量(可迭代张量)- 要收集的张量可迭代序列。除了
dim
维以外的所有维度的张量大小必须匹配。dim(int,可选)- 张量将沿其连接的维度。默认:
0
。destination(torch.device,str 或 int,可选)- 输出设备。可以是 CPU 或 CUDA。默认:当前 CUDA 设备。
out(Tensor,可选,关键字参数)- 存储收集结果的张量。它的大小必须与
tensors
匹配,除了dim
,其大小必须等于sum(tensor.size(dim) for tensor in tensors)
。可以在 CPU 或 CUDA 上。
注意
当指定
out
时,不得指定destination
。- 返回值:
- 如果指定了
destination
,
位于destination
设备上的张量,是沿着dim
拼接tensors
的结果。
- 如果指定了
- 如果指定了
out
,
现在包含沿着dim
拼接tensors
结果的out
张量。
- 如果指定了