由亚历山大·索阿雷和弗朗西斯科·马萨撰写

引言

基于 FX 的特征提取是 TorchVision 的一个新工具,它允许我们在 PyTorch 模块的前向传递过程中访问输入的中间变换。它是通过符号跟踪前向方法来生成一个图,其中每个节点代表一个单独的操作。节点以人类可读的方式命名,以便可以轻松指定想要访问的节点。

这听起来有点复杂吗?不用担心,这篇文章中有一小部分内容适合每个人。无论你是初学者还是高级深度视觉从业者,你可能会想了解 FX 特征提取。如果你还想了解更多关于特征提取的一般背景,请继续阅读。如果你已经熟悉这一点,并想了解如何在 PyTorch 中实现它,请快速浏览“PyTorch 中的现有方法:优缺点”。如果你已经了解在 PyTorch 中进行特征提取的挑战,请随意快速浏览到“FX 拯救”部分。

特征提取回顾

我们都习惯了有一个深度神经网络(DNN)接受输入并产生输出的想法,我们并不一定考虑其中发生了什么。让我们以 ResNet-50 分类模型为例:

CResNet-50 takes an image of a bird and transforms that into the abstract concept 'bird'
图 1:ResNet-50 将一只鸟的图像转化为抽象概念“鸟”。来源:ImageNet 中的鸟图像。

我们知道,在 ResNet-50 架构中存在许多顺序的“层”,它们逐步转换输入。在下面的图 2 中,我们揭开盖子,展示 ResNet-50 中的层,同时也展示了输入在这些层中传递时的中间转换。

ResNet-50 transforms the input image in multiple steps. Conceptually, we may access the intermediate transformation of the image after each one of these steps.
图 2:ResNet-50 通过多个步骤转换输入图像。从概念上讲,我们可能访问在这些步骤之后图像的中间转换。来源:ImageNet 中的鸟图像。

PyTorch 中现有方法的优缺点

在基于 FX 的特征提取被引入之前,PyTorch 中已经存在几种进行特征提取的方法。

为了说明这些方法,让我们考虑一个简单的卷积神经网络,它执行以下操作

  • 应用多个“块”,每个块内部包含多个卷积层。
  • 经过几个块之后,它使用全局平均池化和展平操作。
  • 最后它使用一个单一的输出分类层。
import torch
from torch import nn


class ConvBlock(nn.Module):
   """
   Applies `num_layers` 3x3 convolutions each followed by ReLU then downsamples
   via 2x2 max pool.
   """

   def __init__(self, num_layers, in_channels, out_channels):
       super().__init__()
       self.convs = nn.ModuleList(
           [nn.Sequential(
               nn.Conv2d(in_channels if i==0 else out_channels, out_channels, 3, padding=1),
               nn.ReLU()
            )
            for i in range(num_layers)]
       )
       self.downsample = nn.MaxPool2d(kernel_size=2, stride=2)
      
   def forward(self, x):
       for conv in self.convs:
           x = conv(x)
       x = self.downsample(x)
       return x
      

class CNN(nn.Module):
   """
   Applies several ConvBlocks each doubling the number of channels, and
   halving the feature map size, before taking a global average and classifying.
   """

   def __init__(self, in_channels, num_blocks, num_classes):
       super().__init__()
       first_channels = 64
       self.blocks = nn.ModuleList(
           [ConvBlock(
               2 if i==0 else 3,
               in_channels=(in_channels if i == 0 else first_channels*(2**(i-1))),
               out_channels=first_channels*(2**i))
            for i in range(num_blocks)]
       )
       self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
       self.cls = nn.Linear(first_channels*(2**(num_blocks-1)), num_classes)

   def forward(self, x):
       for block in self.blocks:
           x = block(x)
       x = self.global_pool(x)
       x = x.flatten(1)
       x = self.cls(x)
       return x


model = CNN(3, 4, 10)
out = model(torch.zeros(1, 3, 32, 32))  # This will be the final logits over classes

假设我们想要获取全局平均池化之前的最终特征图。我们可以这样做:

修改前向方法

def forward(self, x):
   for block in self.blocks:
       x = block(x)
   self.final_feature_map = x
   x = self.global_pool(x)
   x = x.flatten(1)
   x = self.cls(x)
   return x

直接返回它:

def forward(self, x):
   for block in self.blocks:
       x = block(x)
   final_feature_map = x
   x = self.global_pool(x)
   x = x.flatten(1)
   x = self.cls(x)
   return x, final_feature_map

看起来很简单。但这里有一些缺点,都源于同一个根本问题:那就是修改源代码并不理想:

  • 由于项目的实际考虑,修改并不总是容易访问和更改。
  • 如果我们想要灵活性(开关特征提取,或者对其有变体),我们需要进一步修改源代码以支持这一点。
  • 不总是仅仅插入一行代码的问题。想想看,按照我写的这个模块,你是如何获取中间块中的特征图的。
  • 总体来说,当我们实际上不需要改变模型的工作方式时,我们更愿意避免维护模型源代码的开销。

当处理更大、更复杂的模型,并试图从嵌套子模块中获取特征时,这种缺点可能会变得更加棘手。

使用原始模块的参数编写一个新的模块

接着上面的例子,如果我们想从每个块中获取一个特征图,我们可以这样编写一个新的模块:

class CNNFeatures(nn.Module):
   def __init__(self, backbone):
       super().__init__()
       self.blocks = backbone.blocks

   def forward(self, x):
       feature_maps = []
       for block in self.blocks:
           x = block(x)
           feature_maps.append(x)
       return feature_maps


backbone = CNN(3, 4, 10)
model = CNNFeatures(backbone)
out = model(torch.zeros(1, 3, 32, 32))  # This is now a list of Tensors, each representing a feature map

事实上,这与 TorchVision 内部用于制作其许多检测模型的方法非常相似。

虽然这种方法解决了直接修改源代码的一些问题,但仍然存在一些主要的缺点:

  • 只能直接访问顶层子模块的输出,处理嵌套子模块会迅速变得复杂。
  • 我们必须小心,不要在输入和输出之间错过任何重要操作。我们在将原始模块的精确功能转录到新模块中引入了潜在的错误。

总体来说,这种方法和最后一种方法都存在将特征提取与模型源代码本身结合在一起的复杂性。实际上,如果我们检查 TorchVision 模型的源代码,我们可能会怀疑一些设计选择受到了希望以这种方式用于下游任务的影响。

使用钩子

钩子使我们从编写源代码的范式转向了指定输出的范式。考虑到我们上面的玩具 CNN 示例以及获取每个层的特征图的目标,我们可以使用如下钩子:

model = CNN(3, 4, 10)
feature_maps = []  # This will be a list of Tensors, each representing a feature map

def hook_feat_map(mod, inp, out):
	feature_maps.append(out)

for block in model.blocks:
	block.register_forward_hook(hook_feat_map)

out = model(torch.zeros(1, 3, 32, 32))  # This will be the final logits over classes

现在我们对访问嵌套子模块具有完全的灵活性,我们摆脱了修改源代码的责任。但这种方法也有其自身的缺点:

  • 我们只能将钩子应用于模块。如果我们有需要获取输出的功能操作(重塑、视图、功能非线性等),钩子无法直接应用于它们。
  • 我们没有修改任何源代码,因此整个前向传播都会执行,无论是否有钩子。如果我们只需要访问早期特征而不需要最终输出,这可能会导致大量的无用计算。
  • 钩子与 TorchScript 不兼容。

这里是对不同方法和它们的优缺点的一个总结:

  可以直接使用源代码,无需任何修改或重写 在访问功能方面具有完全的灵活性 跳过不必要的计算步骤 火炬脚本友好
修改前向方法 NO 技术上是的。取决于你愿意写多少代码。所以实际上,不是。 YES YES
新模块,重用原始模块的子模块/参数 NO 技术上是的。取决于你愿意写多少代码。所以实际上,不是。 YES YES
钩子 YES 主要是的。仅输出子模块的结果 NO NO

表 1:使用 PyTorch 进行特征提取的一些现有方法的优缺点

在本文的下一段,让我们看看我们如何让 YES 全面实现。

外汇救星

对于一些刚开始学习 Python 和编码的新手来说,此时可能出现的自然问题是:“我们能不能直接指向一行代码,并告诉 Python 或 PyTorch 我们想要那行代码的结果?”对于那些已经花费更多时间编码的人来说,为什么不能这样做的原因是显而易见的:一行代码中可能发生多个操作,无论是明确写在那里,还是作为子操作隐含的。只需以这个简单的模块为例:

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.submodule = MySubModule()

    def forward(self, x):
        return self.submodule(x + self.param).clamp(min=0.0, max=1.0)

前向方法只有一行代码,我们可以将其展开为:

  1. 添加 self.paramx
  2. 将 x 通过 self.submodule 传递。在这里,我们需要考虑该子模块中发生的步骤。我将仅使用示例操作名称:I. submodule.op_1 II. submodule.op_2
  3. 应用钳位操作

因此,即使我们指向这一行,问题随之而来:“我们想要提取输出的哪个步骤?”

FX 是一个核心 PyTorch 工具包,(简化地说)它执行了我刚才提到的展开操作。它执行一种称为“符号跟踪”的操作,这意味着 Python 代码被解释并逐个操作地执行,使用一些虚拟代理来模拟真实输入。引入一些术语,上述每一步都被视为一个“节点”,连续的节点相互连接形成一个“图”(这与常见的数学概念图类似)。这里是将上述“步骤”翻译成这种图的概念。

Graphical representation of the result of symbolically tracing our example of a simple forward method.
图 3:符号追踪我们简单前向方法示例的结果的图形表示

请注意,我们将其称为图,而不仅仅是步骤集合,因为图可以分支并重新组合。想想残差块中的跳过连接。这看起来可能像这样:

Graphical representation of a residual skip connection. The middle node is like the main branch of a residual block, and the final node represents the sum of the input and output of the main branch.
图 4:残差跳连接的图形表示。中间节点类似于残差块的主分支,最终节点表示主分支的输入和输出的总和。

现在,TorchVision 的 get_graph_node_names 函数按照上述描述应用 FX,并在执行此过程中,为每个节点添加一个可读性强的名称。让我们用上一节中的玩具 CNN 模型来试一试:

model = CNN(3, 4, 10)
from torchvision.models.feature_extraction import get_graph_node_names
nodes, _ = get_graph_node_names(model)
print(nodes)

这将导致:

['x', 'blocks.0.convs.0.0', 'blocks.0.convs.0.1', 'blocks.0.convs.1.0', 'blocks.0.convs.1.1', 'blocks.0.downsample', 'blocks.1.convs.0.0', 'blocks.1.convs.0.1', 'blocks.1.convs.1.0', 'blocks.1.convs.1.1', 'blocks.1.convs.2.0', 'blocks.1.convs.2.1', 'blocks.1.downsample', 'blocks.2.convs.0.0', 'blocks.2.convs.0.1', 'blocks.2.convs.1.0', 'blocks.2.convs.1.1', 'blocks.2.convs.2.0', 'blocks.2.convs.2.1', 'blocks.2.downsample', 'blocks.3.convs.0.0', 'blocks.3.convs.0.1', 'blocks.3.convs.1.0', 'blocks.3.convs.1.1', 'blocks.3.convs.2.0', 'blocks.3.convs.2.1', 'blocks.3.downsample', 'global_pool', 'flatten', 'cls']

我们可以将这些节点名称视为对感兴趣操作的分层组织“地址”。例如,'blocks.1.downsample' 指的是第二个 ConvBlock 中的 MaxPool2d 层。

create_feature_extractor ,这是所有魔法发生的地方,比 get_graph_node_names 多走几步。它将所需的节点名称作为输入参数之一,然后使用更多的 FX 核心功能:

  1. 将所需的节点分配为输出。
  2. 删除不必要的下游节点及其相关参数。
  3. 将生成的图翻译回 Python 代码。
  4. 将另一个 PyTorch 模块返回给用户。该模块的 forward 方法包含步骤 3 中的 Python 代码。

作为演示,这是如何应用 create_feature_extractor 从我们的玩具 CNN 模型中获取 4 个特征图

from torchvision.models.feature_extraction import create_feature_extractor
# Confused about the node specification here?
# We are allowed to provide truncated node names, and `create_feature_extractor`
# will choose the last node with that prefix.
feature_extractor = create_feature_extractor(
	model, return_nodes=['blocks.0', 'blocks.1', 'blocks.2', 'blocks.3'])
# `out` will be a dict of Tensors, each representing a feature map
out = feature_extractor(torch.zeros(1, 3, 32, 32))

简单来说。归根结底,FX 特征提取只是让我们能够实现一些人在刚开始编程时可能天真地希望的事情:“只给我这段代码的输出结果(指着屏幕)”*。

  • … 不需要我们修改源代码。
  • … 在访问任何中间转换方面提供了完全的灵活性,无论是模块的结果还是函数操作的结果。
  • ...提取特征后,会删除不必要的计算步骤
  • ...而且之前没提过,它还兼容 TorchScript!

这里是再次添加了 FX 特征提取行的新表格

  可以直接使用源代码,无需任何修改或重写 全面的功能访问灵活性 减少不必要的计算步骤 支持 TorchScript
修改前向方法 NO 技术上是的。取决于你愿意写多少代码。所以实际上,不是。 YES YES
新模块,重用原始模块的子模块/参数 NO 技术上是的。取决于你愿意写多少代码。所以实际上,不是。 YES YES
钩子 YES 主要为 YES。只有子模块的输出 NO NO
FX YES YES YES YES

表 2:表 1 的副本,增加了一行用于 FX 特征提取。FX 特征提取全面得到 YES!

当前 FX 的限制

虽然我很乐意在那里结束这篇文章,但 FX 确实有一些自己的限制,归结为:

  1. 在解释和翻译成图的过程中,可能存在一些尚未被 FX 处理的 Python 代码。
  2. 动态控制流无法用静态图来表示。

当这些问题出现时,最简单的方法是将底层代码打包成一个“叶子节点”。回想一下图 3 中的示例图?从概念上讲,我们可能同意 submodule 应该被视为一个节点,而不是表示底层操作的节点集合。如果我们这样做,我们可以重新绘制这个图,如下所示:

The individual operations within `submodule` may (left - within red box), may be consolidated into one node (right - node #2) if we consider the `submodule` as a 'leaf' node.
图 5:在`submodule`内部的各个操作(左 - 红色框内),如果我们将`submodule`视为“叶子”节点,则可能被合并成一个节点(右 - 节点#2)。

如果子模块中存在一些问题代码,我们可能想要这样做,但我们没有从其中提取任何中间转换的需求。实际上,通过为 create_feature_extractor 或 get_graph_node_names 提供关键字参数就可以轻松实现。

model = CNN(3, 4, 10)
nodes, _ = get_graph_node_names(model, tracer_kwargs={'leaf_modules': [ConvBlock]})
print(nodes)

输出将是:

['x', 'blocks.0', 'blocks.1', 'blocks.2', 'blocks.3', 'global_pool', 'flatten', 'cls']

注意,与之前相比,任何给定 ConvBlock 的所有节点都合并成了一个节点。

我们可以对函数做类似的事情。例如,Python 内置的 len 需要被包装,并且结果应该被视为一个叶子节点。以下是使用核心 FX 功能实现此操作的方法:

torch.fx.wrap('len')

class MyModule(nn.Module):
   def forward(self, x):
       x += 1
       len(x)

model = MyModule()
feature_extractor = create_feature_extractor(model, return_nodes=['add'])

对于您定义的函数,您可以使用另一个关键字参数来 create_feature_extractor (细节:这里说明为什么您可能想这样做):

def myfunc(x):
   return len(x)

class MyModule(nn.Module):
   def forward(self, x):
       x += 1
       myfunc(x)

model = MyModule()
feature_extractor = create_feature_extractor(
   model, return_nodes=['add'], tracer_kwargs={'autowrap_functions': [myfunc]})

注意,上述所有修复都没有涉及修改源代码。

当然,有时您试图获取的中间转换可能就在导致问题的同一个前向方法或函数中。在这种情况下,我们不能简单地将该模块或函数视为叶子节点,因为那样我们就无法访问其中的中间转换。在这些情况下,可能需要对源代码进行一些重写。以下是一些示例(非详尽):

  • 当尝试通过带有 assert 语句的代码进行跟踪时,FX 将引发错误。在这种情况下,您可能需要删除该断言或将其替换为 torch._assert (这不是一个公共函数 - 因此请将其视为临时解决方案并谨慎使用)。
  • 符号化追踪张量切片的内部更改不受支持。您需要为切片创建一个新的变量,然后应用操作,最后使用连接或堆叠重建原始张量。
  • 在静态图中表示动态控制流在逻辑上是不可能的。看看您是否可以将编码逻辑简化为非动态的内容——请参阅 FX 文档以获取提示。

通常,您可以查阅 FX 文档以获取有关符号追踪限制和可能的解决方案的更多详细信息。

结论

我们快速回顾了特征提取及其原因。尽管 PyTorch 中已有进行特征提取的方法,但它们都存在相当显著的不足。我们学习了 TorchVision 的 FX 特征提取工具的工作原理以及它相较于现有方法的灵活性。虽然对于后者还有一些小问题需要解决,但我们理解了其局限性,可以根据我们的用例与其他方法的局限性进行权衡。希望将这个新工具添加到您的 PyTorch 工具箱后,您现在可以处理大多数可能遇到的特征提取需求。

开心编码!