torch.split¶
- torch.split(tensor, split_size_or_sections, dim=0)[source][source]¶
将张量分割成块。每个块都是原始张量的视图。
如果
split_size_or_sections
是整数类型,则tensor
将分割成等大小的块(如果可能)。如果沿着给定维度dim
的张量大小不能被split_size
整除,则最后一个块将更小。如果
split_size_or_sections
是列表,则tensor
将分割成len(split_size_or_sections)
块,块的大小根据dim
和split_size_or_sections
确定。- 参数:
张量(Tensor)- 要分割的张量。
单个块的大小或块大小的列表(int)或(list(int))
沿着哪个维度分割张量(int)
- 返回类型:
tuple[torch.Tensor, …]
示例:
>>> a = torch.arange(10).reshape(5, 2) >>> a tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]) >>> torch.split(a, 2) (tensor([[0, 1], [2, 3]]), tensor([[4, 5], [6, 7]]), tensor([[8, 9]])) >>> torch.split(a, [1, 4]) (tensor([[0, 1]]), tensor([[2, 3], [4, 5], [6, 7], [8, 9]]))