• 文档 >
  • torch >
  • torch.dsplit
快捷键

torch.dsplit

torch.dsplit(input, indices_or_sections) → 列表形式的张量

将具有三个或更多维度的张量 input 沿着 indices_or_sections 深度方向分割成多个张量。每个分割都是 input 的视图。

这相当于调用 torch.tensor_split(input, indices_or_sections, dim=2)(分割维度为 2),除非 indices_or_sections 是一个整数,并且它必须能整除分割维度,否则会抛出运行时错误。

此函数基于 NumPy 的 numpy.dsplit()

参数:
  • input(张量)- 要分割的张量。

  • indices_or_sections(整数或列表或整数元组)- 请参阅 torch.tensor_split() 中的参数。

示例::
>>> t = torch.arange(16.0).reshape(2, 2, 4)
>>> t
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.]],
        [[ 8.,  9., 10., 11.],
         [12., 13., 14., 15.]]])
>>> torch.dsplit(t, 2)
(tensor([[[ 0.,  1.],
        [ 4.,  5.]],
       [[ 8.,  9.],
        [12., 13.]]]),
 tensor([[[ 2.,  3.],
          [ 6.,  7.]],
         [[10., 11.],
          [14., 15.]]]))
>>> torch.dsplit(t, [3, 6])
(tensor([[[ 0.,  1.,  2.],
          [ 4.,  5.,  6.]],
         [[ 8.,  9., 10.],
          [12., 13., 14.]]]),
 tensor([[[ 3.],
          [ 7.]],
         [[11.],
          [15.]]]),
 tensor([], size=(2, 2, 0)))

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源