快捷键

展平 ¶

class torch.nn.Flatten(start_dim=1, end_dim=- 1)[source][source]

将连续的维度展平成一个张量。

用于与 Sequential 配合,详细信息请见 torch.flatten()

形状:
  • 输入: (,Sstart,...,Si,...,Send,)(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *) ,其中 SiS_{i} 是维度 ii* 的大小, * 表示任意数量的维度,包括零个维度。

  • 输出: (,i=startendSi,)(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)

参数:
  • start_dim(整型)- 要展平的第一个维度(默认=1)。

  • end_dim (int) – 最后一个要展平的维度(默认 = -1)。

示例::
>>> input = torch.randn(32, 1, 5, 5)
>>> # With default parameters
>>> m = nn.Flatten()
>>> output = m(input)
>>> output.size()
torch.Size([32, 25])
>>> # With non-default parameters
>>> m = nn.Flatten(0, 2)
>>> output = m(input)
>>> output.size()
torch.Size([160, 5])

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源