• 教程 >
  • 融合模块食谱
快捷键

融合模块食谱 ¶

创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日

本食谱演示了如何将一系列 PyTorch 模块融合成一个单一模块,以及如何进行性能测试以比较融合模型与其非融合版本。

简介

在对模型应用量化以减小其大小和内存占用之前(有关量化的详细信息,请参阅量化食谱),模型中的模块列表可能首先被融合成一个单一模块。融合是可选的,但它可能节省内存访问,使模型运行更快,并提高其准确性。

前置条件 ¶

PyTorch 1.6.0 或 1.7.0

步骤 ¶

按以下步骤融合示例模型,量化它,脚本化它,为移动端优化它,保存它,并使用 Android 基准工具进行测试。

1. 定义示例模型 ¶

使用在 PyTorch Mobile 性能秘籍中定义的相同示例模型:

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

class AnnotatedConvBnReLUModel(torch.nn.Module):
    def __init__(self):
        super(AnnotatedConvBnReLUModel, self).__init__()
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
        self.relu = torch.nn.ReLU(inplace=True)
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = x.contiguous(memory_format=torch.channels_last)
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x

2. 生成带和不带 fuse_modules 的两个模型

在上面的模型定义下方添加以下代码并运行脚本:

model = AnnotatedConvBnReLUModel()
print(model)

def prepare_save(model, fused):
    model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
    torch.quantization.prepare(model, inplace=True)
    torch.quantization.convert(model, inplace=True)
    torchscript_model = torch.jit.script(model)
    torchscript_model_optimized = optimize_for_mobile(torchscript_model)
    torch.jit.save(torchscript_model_optimized, "model.pt" if not fused else "model_fused.pt")

prepare_save(model, False)

model = AnnotatedConvBnReLUModel()
model_fused = torch.quantization.fuse_modules(model, [['bn', 'relu']], inplace=False)
print(model_fused)

prepare_save(model_fused, True)

原始模型及其融合版本的图表将按以下方式打印:

AnnotatedConvBnReLUModel(
  (conv): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (bn): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (quant): QuantStub()
  (dequant): DeQuantStub()
)

AnnotatedConvBnReLUModel(
  (conv): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (bn): BNReLU2d(
    (0): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): ReLU(inplace=True)
  )
  (relu): Identity()
  (quant): QuantStub()
  (dequant): DeQuantStub()
)

在第二个融合模型输出中,列表中的第一个 bn 项被融合模块替换,其余模块(例如本例中的 relu)被替换为恒等映射。此外,生成了非融合版本和融合版本的模型,分别为 model.pt 和 model_fused.pt。

3. 构建安卓基准测试工具

获取 PyTorch 源代码,并按照以下步骤构建 Android 基准工具:

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
git submodule update --init --recursive
BUILD_PYTORCH_MOBILE=1 ANDROID_ABI=arm64-v8a ./scripts/build_android.sh -DBUILD_BINARY=ON

这将在 build_android/bin 文件夹中生成 Android 基准二进制文件 speed_benchmark_torch。

4. 测试比较融合和非融合模型

连接您的 Android 设备,然后复制 speed_benchmark_torch 和模型文件,并在它们上运行基准工具:

adb push build_android/bin/speed_benchmark_torch /data/local/tmp
adb push model.pt /data/local/tmp
adb push model_fused.pt /data/local/tmp
adb shell "/data/local/tmp/speed_benchmark_torch --model=/data/local/tmp/model.pt" --input_dims="1,3,224,224" --input_type="float"
adb shell "/data/local/tmp/speed_benchmark_torch --model=/data/local/tmp/model_fused.pt" --input_dims="1,3,224,224" --input_type="float"

最后两个命令的结果应该如下:

Main run finished. Microseconds per iter: 6189.07. Iters per second: 161.575

以及

Main run finished. Microseconds per iter: 6216.65. Iters per second: 160.858

对于这个示例模型,融合模型和非融合模型之间的性能差异并不大。但可以使用类似的步骤来融合和准备一个真实的深度模型,并测试性能提升。请注意,目前 torch.quantization.fuse_modules 只能融合以下模块序列:

  • 卷积,归一化

  • 卷积,归一化,ReLU

  • 卷积,ReLU

  • 线性,ReLU

  • 批标准化,ReLU

如果向 fuse_modules 调用提供了任何其他序列列表,它将被简单地忽略。

了解更多 ¶

请参阅此处 torch.quantization.fuse_modules 的官方文档。


评分这个教程

© 版权所有 2024,PyTorch。

使用 Sphinx 构建,主题由 Read the Docs 提供。
//暂时添加调查链接

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源