• 文档 >
  • torch >
  • torch.split
快捷键

torch.split

torch.split(tensor, split_size_or_sections, dim=0)[source][source]

将张量分割成块。每个块都是原始张量的视图。

如果 split_size_or_sections 是整数类型,则 tensor 将分割成等大小的块(如果可能)。如果沿着给定维度 dim 的张量大小不能被 split_size 整除,则最后一个块将更小。

如果 split_size_or_sections 是列表,则 tensor 将分割成 len(split_size_or_sections) 块,块的大小根据 dimsplit_size_or_sections 确定。

参数:
  • 张量(Tensor)- 要分割的张量。

  • 单个块的大小或块大小的列表(int)或(list(int))

  • 沿着哪个维度分割张量(int)

返回类型:

tuple[torch.Tensor, …]

示例:

>>> a = torch.arange(10).reshape(5, 2)
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
         [2, 3]]),
 tensor([[4, 5],
         [6, 7]]),
 tensor([[8, 9]]))
>>> torch.split(a, [1, 4])
(tensor([[0, 1]]),
 tensor([[2, 3],
         [4, 5],
         [6, 7],
         [8, 9]]))

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源