• 教程 >
  • (原型)FX 图模式量化用户指南
快捷键

(原型)FX 图形模式量化用户指南 ¶

创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日

作者:张杰

FX 图模式量化需要具有符号可追踪性的模型。我们使用 FX 框架将符号可追踪的 nn.Module 实例转换为 IR,并在 IR 上执行量化过程。请在 PyTorch 讨论论坛中发布您关于在 PyTorch 中进行模型符号追踪的问题

量化过程仅适用于模型中的符号可追踪部分。数据依赖的控制流(例如 if 语句、for 循环等使用符号可追踪值的)是一种常见的模式,目前不支持。如果您的模型不是端到端符号可追踪的,您有几种选择来仅对模型的一部分启用 FX 图模式量化。您可以使用以下选项的任意组合:

  1. 不可追踪的代码无需量化
    1. 仅符号追踪需要量化的代码

    2. 跳过不可追踪代码的符号追踪

  2. 不可追踪的代码需要量化
    1. 重新设计你的代码以使其可符号追踪

    2. 编写自己的观测和量化子模块

如果不需要量化的代码不可符号追踪,我们可以选择以下两种方式运行 FX 图模式量化:

仅符号化追踪需要量化的代码 ¶

当整个模型不可符号追踪,但我们要量化的子模块可符号追踪时,我们只需在该子模块上运行量化。

在之前:

class M(nn.Module):
    def forward(self, x):
        x = non_traceable_code_1(x)
        x = traceable_code(x)
        x = non_traceable_code_2(x)
        return x

after:之后

class FP32Traceable(nn.Module):
    def forward(self, x):
        x = traceable_code(x)
        return x

class M(nn.Module):
    def __init__(self):
        self.traceable_submodule = FP32Traceable(...)
    def forward(self, x):
        x = self.traceable_code_1(x)
        # We'll only symbolic trace/quantize this submodule
        x = self.traceable_submodule(x)
        x = self.traceable_code_2(x)
        return x

量化代码:

qconfig_mapping = QConfigMapping().set_global(qconfig)
model_fp32.traceable_submodule = \
  prepare_fx(model_fp32.traceable_submodule, qconfig_mapping, example_inputs)

如果原始模型需要保留,您必须在调用量化 API 之前自行复制。

跳过符号化追踪不可追踪的代码 ¶

当我们在模块中有一些不可追踪的代码,并且这部分代码不需要量化时,我们可以将这部分代码提取到一个子模块中,并跳过对该子模块的符号化追踪。

在...之前

class M(nn.Module):

    def forward(self, x):
        x = self.traceable_code_1(x)
        x = non_traceable_code(x)
        x = self.traceable_code_2(x)
        return x

之后,不可追踪的部分移动到模块中并标记为叶子

class FP32NonTraceable(nn.Module):

    def forward(self, x):
        x = non_traceable_code(x)
        return x

class M(nn.Module):

    def __init__(self):
        ...
        self.non_traceable_submodule = FP32NonTraceable(...)

    def forward(self, x):
        x = self.traceable_code_1(x)
        # we will configure the quantization call to not trace through
        # this submodule
        x = self.non_traceable_submodule(x)
        x = self.traceable_code_2(x)
        return x

量化代码:

qconfig_mapping = QConfigMapping.set_global(qconfig)

prepare_custom_config_dict = {
    # option 1
    "non_traceable_module_name": "non_traceable_submodule",
    # option 2
    "non_traceable_module_class": [MNonTraceable],
}
model_prepared = prepare_fx(
    model_fp32,
    qconfig_mapping,
    example_inputs,
    prepare_custom_config_dict=prepare_custom_config_dict,
)

如果需要量化不可符号追踪的代码,我们有以下两种选择:

优化你的代码以实现符号可追踪性 ¶

如果代码易于重构并且可以符号化追踪,我们可以重构代码并移除 Python 中不可追踪的结构。

关于符号化追踪支持的更多信息可以在此处找到。

在之前:

def transpose_for_scores(self, x):
    new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
    x = x.view(*new_x_shape)
    return x.permute(0, 2, 1, 3)

这不是符号化可追踪的,因为在 x.view(*new_x_shape)中不支持解包,但是很容易移除解包,因为 x.view 也支持列表输入。

after:之后

def transpose_for_scores(self, x):
    new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
    x = x.view(new_x_shape)
    return x.permute(0, 2, 1, 3)

这可以与其他方法结合使用,量化代码取决于模型。

编写你自己的观测和量化子模块 ¶

如果无法将不可追踪的代码重构为可符号追踪的代码,例如,它包含一些无法消除的循环,如 nn.LSTM,我们需要将不可追踪的代码提取到一个子模块中(在 fx 图模式量化中我们称之为 CustomModule)并定义子模块的观测和量化版本(在训练后静态量化或量化感知训练中进行静态量化)或定义量化版本(在训练后动态和仅权重量化中进行)。

在之前:

class M(nn.Module):

    def forward(self, x):
        x = traceable_code_1(x)
        x = non_traceable_code(x)
        x = traceable_code_1(x)
        return x

after:之后

1. 将不可追踪代码提取到 FP32NonTraceable 非追踪逻辑中,并在模块中封装

class FP32NonTraceable:
    ...

2. 定义 FP32NonTraceable 的观测版本

class ObservedNonTraceable:

    @classmethod
    def from_float(cls, ...):
        ...

3. 定义 FP32NonTraceable 的静态量化版本,并添加一个类方法“from_observed”,用于将 ObservedNonTraceable 转换为 StaticQuantNonTraceable

class StaticQuantNonTraceable:

    @classmethod
    def from_observed(cls, ...):
        ...
# refactor parent class to call FP32NonTraceable
class M(nn.Module):

   def __init__(self):
        ...
        self.non_traceable_submodule = FP32NonTraceable(...)

    def forward(self, x):
        x = self.traceable_code_1(x)
        # this part will be quantized manually
        x = self.non_traceable_submodule(x)
        x = self.traceable_code_1(x)
        return x

量化代码:

# post training static quantization or
# quantization aware training (that produces a statically quantized module)v
prepare_custom_config_dict = {
    "float_to_observed_custom_module_class": {
        "static": {
            FP32NonTraceable: ObservedNonTraceable,
        }
    },
}

model_prepared = prepare_fx(
    model_fp32,
    qconfig_mapping,
    example_inputs,
    prepare_custom_config_dict=prepare_custom_config_dict)

校准/训练(未展示)

convert_custom_config_dict = {
    "observed_to_quantized_custom_module_class": {
        "static": {
            ObservedNonTraceable: StaticQuantNonTraceable,
        }
    },
}
model_quantized = convert_fx(
    model_prepared,
    convert_custom_config_dict)

在训练后动态/仅权重量化这两种模式下,我们不需要观察原始模型,因此我们只需要定义量化模型

class DynamicQuantNonTraceable: # or WeightOnlyQuantMNonTraceable
   ...
   @classmethod
   def from_observed(cls, ...):
       ...

   prepare_custom_config_dict = {
       "non_traceable_module_class": [
           FP32NonTraceable
       ]
   }
# The example is for post training quantization
model_fp32.eval()
model_prepared = prepare_fx(
    model_fp32,
    qconfig_mapping,
    example_inputs,
    prepare_custom_config_dict=prepare_custom_config_dict)

convert_custom_config_dict = {
    "observed_to_quantized_custom_module_class": {
        "dynamic": {
            FP32NonTraceable: DynamicQuantNonTraceable,
        }
    },
}
model_quantized = convert_fx(
    model_prepared,
    convert_custom_config_dict)

你也可以在测试 test_custom_module_class 中找到自定义模块的示例 torch/test/quantization/test_quantize_fx.py


评分这个教程

© 版权所有 2024,PyTorch。

使用 Sphinx 构建,主题由 Read the Docs 提供。
//暂时添加调查链接

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源