• 文档 >
  • torch >
  • torch.cat
快捷键

torch.cat

torch.cat(tensors, dim=0, *, out=None) Tensor

将给定的张量序列在指定维度上连接起来。所有张量必须具有相同的形状(除了连接维度),或者可以是大小为 (0,) 的 1-D 空张量。

torch.cat() 可以看作是 torch.split()torch.chunk() 的逆操作。

torch.cat() 最好通过例子来理解。

参见

torch.stack() 沿着新的维度连接给定的序列。

参数:
  • 张量(张量序列)- 提供的非空张量必须在 cat 维度之外具有相同的形状。

  • dim(int,可选)- 张量连接的维度

关键字参数:

输出(张量,可选)- 输出张量。

示例:

>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
         -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
         -0.5790,  0.1497]])

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源