• 教程 >
  • 扩展 ONNX 导出器操作符支持
快捷键

ONNX 简介 || 将 PyTorch 模型导出为 ONNX || 扩展 ONNX 导出器操作符支持 || 将具有控制流的模型导出为 ONNX

扩展 ONNX 导出器操作符支持

创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日

作者:王帝泰,朱俊

概述 ¶

本教程介绍了如何为不支持的 PyTorch 运算符创建 ONNX 实现,或者用您自己的实现替换现有实现。

我们将涵盖三种需要扩展 ONNX 导出器运算符支持的场景:

  • 覆盖现有 PyTorch 运算符的实现

  • 使用自定义 ONNX 运算符

  • 支持自定义 PyTorch 算子

您将学习的内容:

  • 如何在 ONNX 中覆盖或添加对 PyTorch 算子的支持。

  • 如何集成自定义 ONNX 算子以适应特定运行时。

  • 如何实现并将自定义 PyTorch 算子转换为 ONNX。

前提条件 _

在开始本教程之前,请确保您已完成以下先决条件:

  • torch >= 2.6

  • 目标 PyTorch 运算符

  • 在继续之前完成 ONNX Script 教程

  • 使用 ONNX Script 实现的运算符

覆盖现有 PyTorch 运算符的实现

尽管 ONNX 导出团队尽力支持所有 PyTorch 运算符,但其中一些可能尚未得到支持。在本节中,我们将演示如何将不受支持的 PyTorch 运算符添加到 ONNX 注册表中。

备注

实现不受支持的 PyTorch 运算符的步骤与替换现有 PyTorch 运算符的实现为自定义实现的步骤相同。由于我们实际上没有不受支持的 PyTorch 运算符可用于本教程,我们将利用这一点,以与运算符未由 ONNX 导出器实现时相同的方式,用自定义实现替换 torch.ops.aten.add.Tensor

当模型由于不受支持的运算符而无法导出到 ONNX 时,ONNX 导出器将显示类似以下错误消息:

No decompositions registered for [...]

错误信息表明不支持的 PyTorch 运算符是 torch.ops.aten.add.Tensor 。该运算符的类型是 <class 'torch._ops.OpOverload'> ,我们将使用此运算符作为注册自定义实现的靶子。

import torch
import onnxscript

# Opset 18 is the standard supported version as of PyTorch 2.6
from onnxscript import opset18 as op


# Create a model that uses the operator torch.ops.aten.add.Tensor
class Model(torch.nn.Module):
    def forward(self, input_x, input_y):
        return torch.ops.aten.add.Tensor(input_x, input_y)


# NOTE: The function signature (including parameter names) must match the signature of the unsupported PyTorch operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# All attributes must be annotated with type hints.
def custom_aten_add(self, other, alpha: float = 1.0):
    if alpha != 1.0:
        alpha = op.CastLike(alpha, other)
        other = op.Mul(other, alpha)
    # To distinguish the custom implementation from the builtin one, we switch the order of the inputs
    return op.Add(other, self)


x = torch.tensor([1.0])
y = torch.tensor([2.0])

# Then we provide the custom implementation to the ONNX exporter as a ``custom_translation_table``.
onnx_program = torch.onnx.export(
    Model().eval(),
    (x, y),
    dynamo=True,
    custom_translation_table={
        torch.ops.aten.add.Tensor: custom_aten_add,
    },
)
# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()

现在让我们检查模型,并验证模型是否使用了自定义实现。

print(onnx_program.model)

翻译使用了我们的自定义实现:在节点 node_Add_0 中, input_y 现在排在第一位, input_x 排在第二位。

我们可以使用 ONNX Runtime 运行模型,并通过直接在输入张量上调用 torch.onnx.ONNXProgram 来验证结果。

result = onnx_program(x, y)[0]
torch.testing.assert_close(result, torch.tensor([3.0]))

使用自定义 ONNX 算子

在这种情况下,我们使用标准 PyTorch 算子创建模型,但运行时(如微软的 ONNX 运行时)可以为该内核提供自定义实现,从而有效地替换现有实现。

在下面的示例中,我们使用 ONNX 运行时提供的 com.microsoft.Gelu 算子,它不同于 ONNX 规范中的 Gelu

class GeluModel(torch.nn.Module):
    def forward(self, input_x):
        return torch.ops.aten.gelu(input_x)


# Create a namespace for the custom operator using ONNX Script
# ``com.microsoft`` is an official ONNX Runtime namespace
microsoft_op = onnxscript.values.Opset(domain="com.microsoft", version=1)

# NOTE: The function signature (including parameter names) must match the signature of the unsupported PyTorch operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# NOTE: All attributes must be annotated with type hints.
# The function must be scripted using the ``@onnxscript.script()`` decorator when
# using operators from custom domains. This may be improved in future versions.
from onnxscript import FLOAT


@onnxscript.script(microsoft_op)
def custom_aten_gelu(self: FLOAT, approximate: str = "none") -> FLOAT:
    return microsoft_op.Gelu(self)


onnx_program = torch.onnx.export(
    GeluModel().eval(),
    (x,),
    dynamo=True,
    custom_translation_table={
        torch.ops.aten.gelu.default: custom_aten_gelu,
    },
)

# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()

让我们检查模型并验证模型是否使用了命名空间 com.microsoft 中的 op_type Gelu

print(onnx_program.model)

与上一个示例类似,我们可以使用 ONNX Runtime 来运行模型并验证结果。

result = onnx_program(x)[0]
torch.testing.assert_close(result, torch.ops.aten.gelu(x))

支持自定义 PyTorch 运算符

在这种情况下,运算符是由用户实现并注册到 PyTorch 的运算符。

在下面的示例中,我们希望使用一个自定义运算符,该运算符接受一个张量输入,并返回一个输出。该运算符将输入加到自身上,并返回四舍五入的结果。

首先,我们假设自定义操作符已通过 torch.library.custom_op() 实现并注册。您可以参考《在 Python 中创建新的自定义操作符》以获取如何创建自定义操作符的详细指南。

# Define and use the operator in PyTorch
@torch.library.custom_op("mylibrary::add_and_round_op", mutates_args=())
def add_and_round_op(input: torch.Tensor) -> torch.Tensor:
    return torch.round(input + input)


@add_and_round_op.register_fake
def _add_and_round_op_fake(tensor_x):
    return torch.empty_like(tensor_x)


class AddAndRoundModel(torch.nn.Module):
    def forward(self, input):
        return add_and_round_op(input)


# Implement the custom operator in ONNX using ONNX Script
def onnx_add_and_round(input):
    return op.Round(op.Add(input, input))


onnx_program = torch.onnx.export(
    AddAndRoundModel().eval(),
    (x,),
    dynamo=True,
    custom_translation_table={
        torch.ops.mylibrary.add_and_round_op.default: onnx_add_and_round,
    },
)

# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()
print(onnx_program)

翻译使用我们的自定义实现将 torch.export.ExportedProgram` 中的 torch.ops.mylibrary.add_and_round_op.default 操作符转换为 ONNX 操作符 AddRound

最后我们验证结果。

result = onnx_program(x)[0]
torch.testing.assert_close(result, add_and_round_op(x))

结论 ¶

恭喜!在本教程中,我们探讨了 custom_translation_table 选项,并发现了如何使用 ONNX Script 创建对不支持或现有 PyTorch 操作符的自定义实现。

终于,我们利用 ONNX Runtime 执行模型,并将结果与 PyTorch 进行比较,从而全面了解了在 ONNX 生态系统中处理不受支持的算子。

进一步阅读 ¶

以下列表涉及从基本示例到高级场景的教程,不一定按照列表顺序排列。您可以自由地直接跳转到您感兴趣的具体主题,或者耐心地逐一浏览,以了解有关 ONNX 导出器的所有内容。

1. 将 PyTorch 模型导出到 ONNX
2. 扩展 ONNX 导出器操作符支持
3. 将具有控制流的模型导出到 ONNX

脚本总运行时间:(0 分钟 0.000 秒)

由 Sphinx-Gallery 生成的画廊


评分这个教程

© 版权所有 2024,PyTorch。

使用 Sphinx 构建,主题由 Read the Docs 提供。
//暂时添加调查链接

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源