• 文档 >
  • torch >
  • torch.tensor_split
快捷键

torch.tensor_split

torch.tensor_split(input, indices_or_sections, dim=0) → 列表形式的张量

将张量分割成多个子张量,所有子张量都是对 input 的视图,根据 indices_or_sections 指定的索引或部分数量,沿维度 dim 进行分割。此函数基于 NumPy 的 numpy.array_split()

参数:
  • 输入(张量)- 要分割的张量

  • indices_or_sections (Tensor, int 或 list 或 int 的元组) –

    如果 indices_or_sections 是一个整数 n 或一个零维 long 张量,则 input 沿着维度 dim 被分割成 n 部分。如果 input 沿着维度 dim 可以被 n 整除,则每个部分的大小将相等,为 input.size(dim) / n 。如果 input 不能被 n 整除,则前 int(input.size(dim) % n) 个部分的大小为 int(input.size(dim) / n) + 1 ,其余部分的大小为 int(input.size(dim) / n)

    如果 indices_or_sections 是一个 list 或 tuple 的 int,或者一个一维 long 张量,那么 input 将沿着维度 dim 在列表、元组或张量中的每个索引处进行分割。例如, indices_or_sections=[2, 3]dim=0 将产生张量 input[:2]input[2:3]input[3:]

    如果 indices_or_sections 是一个张量,它必须在 CPU 上是一个零维或一维 long 张量。

  • dim(int,可选)- 沿着分割张量的维度。默认: 0

示例:

>>> x = torch.arange(8)
>>> torch.tensor_split(x, 3)
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7]))

>>> x = torch.arange(7)
>>> torch.tensor_split(x, 3)
(tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
>>> torch.tensor_split(x, (1, 6))
(tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6]))

>>> x = torch.arange(14).reshape(2, 7)
>>> x
tensor([[ 0,  1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12, 13]])
>>> torch.tensor_split(x, 3, dim=1)
(tensor([[0, 1, 2],
        [7, 8, 9]]),
 tensor([[ 3,  4],
        [10, 11]]),
 tensor([[ 5,  6],
        [12, 13]]))
>>> torch.tensor_split(x, (1, 6), dim=1)
(tensor([[0],
        [7]]),
 tensor([[ 1,  2,  3,  4,  5],
        [ 8,  9, 10, 11, 12]]),
 tensor([[ 6],
        [13]]))

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源