• 文档 >
  • 量化 >
  • 量化精度调试 >
  • torch.ao.ns._numeric_suite
快捷键

torch.ao.ns._numeric_suite

警告

此模块为早期原型,可能随时更改。

torch.ao.ns._numeric_suite.compare_weights(float_dict, quantized_dict)[source][source]

比较浮点模块与其对应的量化模块的权重。返回一个字典,键对应模块名称,每个条目都是一个字典,包含两个键‘float’和‘quantized’,分别包含浮点数和量化的权重。此字典可用于比较和计算浮点数和量化模型的权重量化误差。

演示用法:

wt_compare_dict = compare_weights(
    float_model.state_dict(), qmodel.state_dict())
for key in wt_compare_dict:
    print(
        key,
        compute_error(
            wt_compare_dict[key]['float'],
            wt_compare_dict[key]['quantized'].dequantize()
        )
    )
参数:
  • float_dict (dict[str, Any]) – 浮点模型的状态字典

  • quantized_dict (dict[str, Any]) – 量化模型的状态字典

返回:

与模块名称对应的键的字典,每个条目都是一个包含两个键‘float’和‘quantized’的字典,分别包含浮点数和量化权重

返回类型:

权重字典

torch.ao.ns._numeric_suite.get_logger_dict(mod, prefix='')[source][source]

遍历模块并将所有日志统计信息保存到目标字典中。这主要用于量化精度调试。

支持的日志记录器类型:

ShadowLogger:用于记录量化模块及其匹配的浮点阴影模块的输出,OutputLogger:用于记录模块的输出

参数:
  • mod(模块)- 我们想要保存所有日志统计信息的模块

  • prefix(字符串)- 当前模块的前缀

返回:

保存所有日志统计信息的字典

返回类型:

目标字典

类 torch.ao.ns._numeric_suite.Logger[source][source]

统计日志的基础类

forward(x)[source][source]
类 torch.ao.ns._numeric_suite.ShadowLogger[source][source] ¶

用于记录原始模块和阴影模块输出的类。

forward(x, y)[source][source]
class torch.ao.ns._numeric_suite.OutputLogger[source][source]

用于记录模块输出的类

forward(x)[source][source]
class torch.ao.ns._numeric_suite.Shadow(q_module, float_module, logger_cls)[source][source]

阴影模块将浮点模块与其匹配的量化模块关联起来作为阴影。然后它使用 Logger 模块处理这两个模块的输出。

参数:
  • q_module – 我们想要阴影的从 float_module 量化的模块

  • float_module – 用于阴影 q_module 的浮点模块

  • logger_cls – 用于处理 q_module 和 float_module 输出的日志器类型。可以使用 ShadowLogger 或自定义日志器。

forward(*x)[source][source]
返回类型:

张量

add(x, y)[source][source]
返回类型:

张量

add_scalar(x, y)[source][source]
返回类型:

张量

mul(x, y)[source][source]
返回类型:

张量

mul_scalar(x, y)[source][source]
返回类型:

张量

cat(x, dim=0)[source][source]
返回类型:

张量

add_relu(x, y)[source][source]
返回类型:

张量

torch.ao.ns._numeric_suite.prepare_model_with_stubs(float_module, q_module, module_swap_list, logger_cls)[source][source]

将浮点模块作为影子模块附加到其匹配的量化模块上,前提是浮点模块类型在 module_swap_list 列表中。

演示用法:

prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger)
q_model(data)
ob_dict = get_logger_dict(q_model)
参数:
  • float_module(模块)- 用于生成 q_module 的浮点模块

  • q_module(模块)- 从 float_module 量化的模块

  • module_swap_list(集合[type])- 要附加影子的浮点模块类型列表

  • logger_cls (Callable) – 在阴影模块中用于处理量化模块及其浮点阴影模块输出的日志记录器类型

torch.ao.ns._numeric_suite.compare_model_stub(float_model, q_model, module_swap_list, *data, logger_cls=<class 'torch.ao.ns._numeric_suite.ShadowLogger'>)[source][source]

比较模型中的量化模块与其浮点对应模块,为两者提供相同的输入。返回一个字典,键对应模块名称,每个条目都是一个包含两个键 'float' 和 'quantized' 的字典,分别包含量化及其匹配的浮点阴影模块的输出张量。此字典可用于比较和计算模块级别的量化误差。

此函数首先调用 prepare_model_with_stubs() 来交换我们想要与 Shadow 模块进行比较的量化模块,该函数接收量化模块、对应的浮点模块和日志记录器作为输入,并在内部创建一个前向路径,使浮点模块与影子量化模块共享相同的输入。日志记录器可以自定义,默认日志记录器为 ShadowLogger,它将保存量化模块和浮点模块的输出,可用于计算模块级别的量化误差。

演示用法:

module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock]
ob_dict = compare_model_stub(float_model,qmodel,module_swap_list, data)
for key in ob_dict:
    print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))
参数:
  • float_model(模块)- 用于生成 q_model 的浮点模型

  • q_model(模块)- 从 float_model 量化的模型

  • module_swap_list(set[type])- 将附加影子模块的浮点模块类型列表。

  • 数据 - 运行准备好的 q_model 所使用的输入数据

  • logger_cls - 在阴影模块中用于处理量化模块及其浮点阴影模块输出的日志器类型

返回类型:

dict[str, dict]

torch.ao.ns._numeric_suite.get_matching_activations(float_module, q_module)[source][source]

找到浮点模块和量化模块之间的匹配激活。

参数:
  • float_module(模块)- 用于生成 q_module 的浮点模块

  • q_module(模块)- 从 float_module 量化的模块

返回:

包含与量化模块名称对应的键的字典,每个条目都是一个包含两个键‘float’和‘quantized’的字典,包含匹配的浮点激活和量化激活

返回类型:

act_dict

torch.ao.ns._numeric_suite.prepare_model_outputs(float_module, q_module, logger_cls=<class 'torch.ao.ns._numeric_suite.OutputLogger'>, allow_list=None)[source][source]

准备模型,将记录器附加到 float 模块和量化模块,如果它们在 allow_list 中。

参数:
  • float_module (模块) – 用于生成 q_module 的 float 模块

  • q_module(模块)- 从 float_module 量化的模块

  • logger_cls - 要附加到 float_module 和 q_module 的日志记录器类型

  • allow_list - 要附加日志记录器的模块类型列表

torch.ao.ns._numeric_suite.compare_model_outputs(float_model, q_model, *data, logger_cls=<class 'torch.ao.ns._numeric_suite.OutputLogger'>, allow_list=None)[source][source]

比较相同输入下浮点模型和量化模型在对应位置的输出激活。返回一个字典,键对应量化模块名称,每个条目都是一个字典,包含两个键‘float’和‘quantized’,分别包含量化模型和浮点模型在匹配位置的激活。此字典可用于比较和计算传播量化误差。

演示用法:

act_compare_dict = compare_model_outputs(float_model, qmodel, data)
for key in act_compare_dict:
    print(
        key,
        compute_error(
            act_compare_dict[key]['float'],
            act_compare_dict[key]['quantized'].dequantize()
        )
    )
参数:
  • float_model(模块)- 用于生成 q_model 的浮点模型

  • q_model(模块)- 从 float_model 量化的模型

  • data – 运行准备好的 float_model 和 q_model 所使用的输入数据

  • logger_cls – 要附加到 float_module 和 q_module 的日志记录器类型

  • allow_list – 要附加日志记录器的模块类型列表

返回:

包含与量化模块名称对应的键的字典,每个条目都是一个包含两个键‘float’和‘quantized’的字典,包含匹配的浮点数和量化激活

返回类型:

act_compare_dict


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源