展平 ¶
- class torch.nn.Unflatten(dim, unflattened_size)[source][source]¶
将张量 dim 展开到期望的形状。用于与
Sequential
一起使用。dim
指定要展开的输入张量的维度,当使用 Tensor 或 NamedTensor 时,分别可以是 int 或 str。unflattened_size
是展开维度的新形状,对于 Tensor 输入可以是 int 的元组、int 的列表或 torch.Size;对于 NamedTensor 输入是 NamedShape((name, size) 元组的元组)。
- 形状:
输入: ,其中 是维度
dim
的大小, 表示包括零个在内的任意数量的维度。输出: ,其中 =
unflattened_size
和 。
- 参数:
dim(int 或 str 的联合)- 要展开的维度
unflattened_size(torch.Size、元组、列表、NamedShape)- 展开维度后的新形状
示例
>>> input = torch.randn(2, 50) >>> # With tuple of ints >>> m = nn.Sequential( >>> nn.Linear(50, 50), >>> nn.Unflatten(1, (2, 5, 5)) >>> ) >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With torch.Size >>> m = nn.Sequential( >>> nn.Linear(50, 50), >>> nn.Unflatten(1, torch.Size([2, 5, 5])) >>> ) >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With namedshape (tuple of tuples) >>> input = torch.randn(2, 50, names=('N', 'features')) >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5))) >>> output = unflatten(input) >>> output.size() torch.Size([2, 2, 5, 5])
- NamedShape
tuple
[tuple
[str
,int
]]的别名