torch.unflatten¶
- torch.unflatten(input, dim, sizes) Tensor ¶
扩展输入张量的一个维度到多个维度。
参见
torch.flatten()
该函数的逆。它将多个维度合并为一个。- 参数:
input (Tensor) – 输入张量。
dim (int) – 要展开的维度,指定为
input.shape
中的索引。sizes (Tuple[int]) – 展开维度后的新形状。其中元素可以是 -1,在这种情况下,相应的输出维度将被推断。否则,
sizes
的乘积必须等于input.shape[dim]
。
- 返回值:
指定维度未展平的输入视图。
- 示例::
>>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape torch.Size([3, 2, 2, 1]) >>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape torch.Size([3, 2, 2, 1]) >>> torch.unflatten(torch.randn(5, 12, 3), -2, (2, 2, 3, 1, 1)).shape torch.Size([5, 2, 2, 3, 1, 1, 3])