展平 ¶
- class torch.nn.Flatten(start_dim=1, end_dim=- 1)[source][source]¶
将连续的维度展平成一个张量。
用于与
Sequential
配合,详细信息请见torch.flatten()
。- 形状:
输入: ,其中 是维度 和 的大小, 表示任意数量的维度,包括零个维度。
输出: 。
- 参数:
start_dim(整型)- 要展平的第一个维度(默认=1)。
end_dim (int) – 最后一个要展平的维度(默认 = -1)。
- 示例::
>>> input = torch.randn(32, 1, 5, 5) >>> # With default parameters >>> m = nn.Flatten() >>> output = m(input) >>> output.size() torch.Size([32, 25]) >>> # With non-default parameters >>> m = nn.Flatten(0, 2) >>> output = m(input) >>> output.size() torch.Size([160, 5])