torch.tensor_split¶
- torch.tensor_split(input, indices_or_sections, dim=0) → 列表形式的张量
将张量分割成多个子张量,所有子张量都是对
input
的视图,根据indices_or_sections
指定的索引或部分数量,沿维度dim
进行分割。此函数基于 NumPy 的numpy.array_split()
。- 参数:
输入(张量)- 要分割的张量
indices_or_sections (Tensor, int 或 list 或 int 的元组) –
如果indices_or_sections
是一个整数n
或一个零维 long 张量,则input
沿着维度dim
被分割成n
部分。如果input
沿着维度dim
可以被n
整除,则每个部分的大小将相等,为input.size(dim) / n
。如果input
不能被n
整除,则前int(input.size(dim) % n)
个部分的大小为int(input.size(dim) / n) + 1
,其余部分的大小为int(input.size(dim) / n)
。如果
indices_or_sections
是一个 list 或 tuple 的 int,或者一个一维 long 张量,那么input
将沿着维度dim
在列表、元组或张量中的每个索引处进行分割。例如,indices_or_sections=[2, 3]
和dim=0
将产生张量input[:2]
、input[2:3]
和input[3:]
。如果
indices_or_sections
是一个张量,它必须在 CPU 上是一个零维或一维 long 张量。dim(int,可选)- 沿着分割张量的维度。默认:
0
示例:
>>> x = torch.arange(8) >>> torch.tensor_split(x, 3) (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7])) >>> x = torch.arange(7) >>> torch.tensor_split(x, 3) (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) >>> torch.tensor_split(x, (1, 6)) (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6])) >>> x = torch.arange(14).reshape(2, 7) >>> x tensor([[ 0, 1, 2, 3, 4, 5, 6], [ 7, 8, 9, 10, 11, 12, 13]]) >>> torch.tensor_split(x, 3, dim=1) (tensor([[0, 1, 2], [7, 8, 9]]), tensor([[ 3, 4], [10, 11]]), tensor([[ 5, 6], [12, 13]])) >>> torch.tensor_split(x, (1, 6), dim=1) (tensor([[0], [7]]), tensor([[ 1, 2, 3, 4, 5], [ 8, 9, 10, 11, 12]]), tensor([[ 6], [13]]))