快捷键

torch.onnx.verification

ONNX 验证模块提供了一套验证 ONNX 模型正确性的工具。

torch.onnx.verification.verify_onnx_program(onnx_program, args=None, kwargs=None, compare_intermediates=False)[source]

通过比较与 ExportedProgram 的预期值来验证 ONNX 模型。

参数:
  • onnx_program (_onnx_program.ONNXProgram) – 要验证的 ONNX 程序。

  • args (tuple[Any, ...] | None) – 模型的输入参数。

  • kwargs (dict[str, Any] | None) – 模型的关键字参数。

  • compare_intermediates (bool) – 是否验证中间值。这将花费更长的时间,因此默认情况下是禁用的。

返回:

包含每个值的验证信息的 VerificationInfo 对象。

返回类型:

list[VerificationInfo]

class torch.onnx.verification.VerificationInfo(name, max_abs_diff, max_rel_diff, abs_diff_hist, rel_diff_hist, expected_dtype, actual_dtype)

ONNX 程序中一个值的验证信息。

此类包含预期值和实际值之间的最大绝对差异、最大相对差异以及绝对和相对差异的直方图。它还包括预期和实际的数据类型。

直方图表示为张量的元组,其中第一个张量是直方图计数,第二个张量是箱边。

变量:
  • 名称(str)- 值的名称(输出或中间值)。

  • max_abs_diff (浮点数) – 预期值与实际值之间的最大绝对差。

  • max_rel_diff (浮点数) – 预期值与实际值之间的最大相对差。

  • abs_diff_hist (tuple[torch.Tensor, torch.Tensor]) – 表示绝对差值直方图的张量元组。第一个张量是直方图计数,第二个张量是边界。

  • rel_diff_hist (tuple[torch.Tensor, torch.Tensor]) – 表示相对差值直方图的张量元组。第一个张量是直方图计数,第二个张量是边界。

  • 预期数据类型(torch.dtype)- 预期值的类型。

  • 实际数据类型(torch.dtype)- 实际值的类型。

classmethod from_tensors(name, expected, actual)[source][source]

从两个张量创建一个 VerificationInfo 对象。

参数:
  • 名称(str)- 值的名称。

  • 预期(torch.Tensor | float | int | bool)- 预期的张量。

  • 实际(torch.Tensor | float | int | bool)- 实际的张量。

返回:

VerificationInfo 对象。

返回类型:

验证信息

torch.onnx.verification.verify(model, input_args, input_kwargs=None, do_constant_folding=True, dynamic_axes=None, input_names=None, output_names=None, training=<TrainingMode.EVAL: 0>, opset_version=None, keep_initializers_as_inputs=True, verbose=False, fixed_batch_size=False, use_external_data=False, additional_test_inputs=None, options=None)[source][source]

验证模型导出为 ONNX 与原始 PyTorch 模型的一致性。

自版本 2.7 开始已弃用:请考虑使用 torch.onnx.export(..., dynamo=True) 并使用返回的 ONNXProgram 来测试 ONNX 模型。

参数:
  • 模型(Union[Module, ScriptModule]) – 查看 torch.onnx.export()

  • 输入参数(Union[Tensor, tuple[Any, ...]]) – 查看 torch.onnx.export()

  • 输入关键字参数(collections.abc.Mapping[str, Any] | None) – 查看 torch.onnx.export()

  • 执行常量折叠(bool) – 查看 torch.onnx.export()

  • dynamic_axes (collections.abc.Mapping[str, collections.abc.Mapping[int, str] | collections.abc.Mapping[str, collections.abc.Sequence[int]]] | None) – 查看 torch.onnx.export() .

  • input_names (collections.abc.Sequence[str] | None) – 查看 torch.onnx.export() .

  • output_names (collections.abc.Sequence[str] | None) – 查看 torch.onnx.export() .

  • training (TrainingMode) – 查看 torch.onnx.export() .

  • opset_version (int | None) – 查看第 0 页# 。

  • keep_initializers_as_inputs (bool) – 查看第 0 页# 。

  • verbose (bool) – 查看第 0 页# 。

  • fixed_batch_size (bool) – 旧版参数,仅由 rnn 测试用例使用。

  • 使用外部数据(布尔值)- 明确指定是否导出包含外部数据的模型。

  • additional_test_inputs(collections.abc.Sequence[Union[torch.Tensor, tuple[Any, ...]]] | None)- 元组列表。每个元组是一组用于测试的输入参数。目前仅支持 *args

  • options(torch.onnx.verification.VerificationOptions | None)- 控制验证行为的 VerificationOptions 对象。

引发:
  • AssertionError - 如果 ONNX 模型和 PyTorch 模型的输出在指定的精度上不相等。

  • 值错误 - 如果提供的参数无效。

已弃用 ¶

以下类和函数已弃用。

class torch.onnx.verification.check_export_model_diff[source][source]
class torch.onnx.verification.GraphInfo[source][source]
class torch.onnx.verification.GraphInfoPrettyPrinter[source][source]
class torch.onnx.verification.OnnxBackend[source][source]
class torch.onnx.verification.OnnxTestCaseRepro[source][source]
class torch.onnx.verification.VerificationOptions[source][source]
torch.onnx.verification.find_mismatch()[source][source]
torch.onnx.verification.verify_aten_graph()[source][source]

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源