快捷键

展平 ¶

class torch.nn.Unflatten(dim, unflattened_size)[source][source]

将张量 dim 展开到期望的形状。用于与 Sequential 一起使用。

  • dim 指定要展开的输入张量的维度,当使用 Tensor 或 NamedTensor 时,分别可以是 int 或 str。

  • unflattened_size 是展开维度的新形状,对于 Tensor 输入可以是 int 的元组、int 的列表或 torch.Size;对于 NamedTensor 输入是 NamedShape((name, size) 元组的元组)。

形状:
  • 输入: (,Sdim,)(*, S_{\text{dim}}, *) ,其中 SdimS_{\text{dim}} 是维度 dim 的大小, * 表示包括零个在内的任意数量的维度。

  • 输出: (,U1,...,Un,)(*, U_1, ..., U_n, *) ,其中 UU = unflattened_sizei=1nUi=Sdim\prod_{i=1}^n U_i = S_{\text{dim}}

参数:
  • dim(int 或 str 的联合)- 要展开的维度

  • unflattened_size(torch.Size、元组、列表、NamedShape)- 展开维度后的新形状

示例

>>> input = torch.randn(2, 50)
>>> # With tuple of ints
>>> m = nn.Sequential(
>>>     nn.Linear(50, 50),
>>>     nn.Unflatten(1, (2, 5, 5))
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With torch.Size
>>> m = nn.Sequential(
>>>     nn.Linear(50, 50),
>>>     nn.Unflatten(1, torch.Size([2, 5, 5]))
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With namedshape (tuple of tuples)
>>> input = torch.randn(2, 50, names=('N', 'features'))
>>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
>>> output = unflatten(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
NamedShape

tuple [ tuple [ str , int ]]的别名


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源