• 文档 >
  • 量化 >
  • 量化 API 参考 >
  • 合并模块
快捷键

合并模块 ¶

class torch.ao.quantization.fuse_modules.fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=<function fuse_known_modules>, fuse_custom_config_dict=None)[source][source]

将一系列模块合并成一个模块。

仅融合以下模块序列:conv, bn conv, bn, relu conv, relu linear, relu bn, relu 其他序列保持不变。对于这些序列,将列表中的第一个模块替换为融合模块,其余模块替换为恒等模块。

参数:
  • model – 包含要融合的模块的模型

  • modules_to_fuse – 要融合的模块名称列表的列表。如果只有一个要融合的模块列表,也可以是字符串列表。

  • inplace – 布尔值,指定融合是否在模型中就地发生,默认返回新模型

  • fuser_func – 函数接收一个模块列表,并输出相同长度的融合模块列表。例如,fuser_func([convModule, BNModule]) 返回列表 [ConvBNModule, nn.Identity()] 默认为 torch.ao.quantization.fuse_known_modules

  • fuse_custom_config_dict – 融合的自定义配置

# Example of fuse_custom_config_dict
fuse_custom_config_dict = {
    # Additional fuser_method mapping
    "additional_fuser_method_mapping": {
        (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
    },
}
返回值:

拥有融合模块的模型。如果 inplace=True,则创建一个新的副本。

示例:

>>> m = M().eval()
>>> # m is a module containing the sub-modules below
>>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
>>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
>>> output = fused_m(input)

>>> m = M().eval()
>>> # Alternately provide a single list of modules to fuse
>>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
>>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
>>> output = fused_m(input)

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

查看 PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源