备注
点击此处下载完整示例代码
ONNX 简介 || 将 PyTorch 模型导出为 ONNX || 扩展 ONNX 导出器操作符支持 || 将具有控制流的模型导出为 ONNX
扩展 ONNX 导出器操作符支持
创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日
作者:王帝泰,朱俊
概述 ¶
本教程介绍了如何为不支持的 PyTorch 运算符创建 ONNX 实现,或者用您自己的实现替换现有实现。
我们将涵盖三种需要扩展 ONNX 导出器运算符支持的场景:
覆盖现有 PyTorch 运算符的实现
使用自定义 ONNX 运算符
支持自定义 PyTorch 算子
您将学习的内容:
如何在 ONNX 中覆盖或添加对 PyTorch 算子的支持。
如何集成自定义 ONNX 算子以适应特定运行时。
如何实现并将自定义 PyTorch 算子转换为 ONNX。
前提条件 _
在开始本教程之前,请确保您已完成以下先决条件:
torch >= 2.6
目标 PyTorch 运算符
在继续之前完成 ONNX Script 教程
使用 ONNX Script 实现的运算符
覆盖现有 PyTorch 运算符的实现
尽管 ONNX 导出团队尽力支持所有 PyTorch 运算符,但其中一些可能尚未得到支持。在本节中,我们将演示如何将不受支持的 PyTorch 运算符添加到 ONNX 注册表中。
备注
实现不受支持的 PyTorch 运算符的步骤与替换现有 PyTorch 运算符的实现为自定义实现的步骤相同。由于我们实际上没有不受支持的 PyTorch 运算符可用于本教程,我们将利用这一点,以与运算符未由 ONNX 导出器实现时相同的方式,用自定义实现替换 torch.ops.aten.add.Tensor
。
当模型由于不受支持的运算符而无法导出到 ONNX 时,ONNX 导出器将显示类似以下错误消息:
No decompositions registered for [...]
错误信息表明不支持的 PyTorch 运算符是 torch.ops.aten.add.Tensor
。该运算符的类型是 <class 'torch._ops.OpOverload'>
,我们将使用此运算符作为注册自定义实现的靶子。
import torch
import onnxscript
# Opset 18 is the standard supported version as of PyTorch 2.6
from onnxscript import opset18 as op
# Create a model that uses the operator torch.ops.aten.add.Tensor
class Model(torch.nn.Module):
def forward(self, input_x, input_y):
return torch.ops.aten.add.Tensor(input_x, input_y)
# NOTE: The function signature (including parameter names) must match the signature of the unsupported PyTorch operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# All attributes must be annotated with type hints.
def custom_aten_add(self, other, alpha: float = 1.0):
if alpha != 1.0:
alpha = op.CastLike(alpha, other)
other = op.Mul(other, alpha)
# To distinguish the custom implementation from the builtin one, we switch the order of the inputs
return op.Add(other, self)
x = torch.tensor([1.0])
y = torch.tensor([2.0])
# Then we provide the custom implementation to the ONNX exporter as a ``custom_translation_table``.
onnx_program = torch.onnx.export(
Model().eval(),
(x, y),
dynamo=True,
custom_translation_table={
torch.ops.aten.add.Tensor: custom_aten_add,
},
)
# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()
现在让我们检查模型,并验证模型是否使用了自定义实现。
print(onnx_program.model)
翻译使用了我们的自定义实现:在节点 node_Add_0
中, input_y
现在排在第一位, input_x
排在第二位。
我们可以使用 ONNX Runtime 运行模型,并通过直接在输入张量上调用 torch.onnx.ONNXProgram
来验证结果。
result = onnx_program(x, y)[0]
torch.testing.assert_close(result, torch.tensor([3.0]))
使用自定义 ONNX 算子
在这种情况下,我们使用标准 PyTorch 算子创建模型,但运行时(如微软的 ONNX 运行时)可以为该内核提供自定义实现,从而有效地替换现有实现。
在下面的示例中,我们使用 ONNX 运行时提供的 com.microsoft.Gelu
算子,它不同于 ONNX 规范中的 Gelu
。
class GeluModel(torch.nn.Module):
def forward(self, input_x):
return torch.ops.aten.gelu(input_x)
# Create a namespace for the custom operator using ONNX Script
# ``com.microsoft`` is an official ONNX Runtime namespace
microsoft_op = onnxscript.values.Opset(domain="com.microsoft", version=1)
# NOTE: The function signature (including parameter names) must match the signature of the unsupported PyTorch operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# NOTE: All attributes must be annotated with type hints.
# The function must be scripted using the ``@onnxscript.script()`` decorator when
# using operators from custom domains. This may be improved in future versions.
from onnxscript import FLOAT
@onnxscript.script(microsoft_op)
def custom_aten_gelu(self: FLOAT, approximate: str = "none") -> FLOAT:
return microsoft_op.Gelu(self)
onnx_program = torch.onnx.export(
GeluModel().eval(),
(x,),
dynamo=True,
custom_translation_table={
torch.ops.aten.gelu.default: custom_aten_gelu,
},
)
# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()
让我们检查模型并验证模型是否使用了命名空间 com.microsoft
中的 op_type Gelu
。
print(onnx_program.model)
与上一个示例类似,我们可以使用 ONNX Runtime 来运行模型并验证结果。
result = onnx_program(x)[0]
torch.testing.assert_close(result, torch.ops.aten.gelu(x))
支持自定义 PyTorch 运算符
在这种情况下,运算符是由用户实现并注册到 PyTorch 的运算符。
在下面的示例中,我们希望使用一个自定义运算符,该运算符接受一个张量输入,并返回一个输出。该运算符将输入加到自身上,并返回四舍五入的结果。
首先,我们假设自定义操作符已通过 torch.library.custom_op()
实现并注册。您可以参考《在 Python 中创建新的自定义操作符》以获取如何创建自定义操作符的详细指南。
# Define and use the operator in PyTorch
@torch.library.custom_op("mylibrary::add_and_round_op", mutates_args=())
def add_and_round_op(input: torch.Tensor) -> torch.Tensor:
return torch.round(input + input)
@add_and_round_op.register_fake
def _add_and_round_op_fake(tensor_x):
return torch.empty_like(tensor_x)
class AddAndRoundModel(torch.nn.Module):
def forward(self, input):
return add_and_round_op(input)
# Implement the custom operator in ONNX using ONNX Script
def onnx_add_and_round(input):
return op.Round(op.Add(input, input))
onnx_program = torch.onnx.export(
AddAndRoundModel().eval(),
(x,),
dynamo=True,
custom_translation_table={
torch.ops.mylibrary.add_and_round_op.default: onnx_add_and_round,
},
)
# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()
print(onnx_program)
翻译使用我们的自定义实现将 torch.export.ExportedProgram`
中的 torch.ops.mylibrary.add_and_round_op.default
操作符转换为 ONNX 操作符 Add
和 Round
。
最后我们验证结果。
result = onnx_program(x)[0]
torch.testing.assert_close(result, add_and_round_op(x))
结论 ¶
恭喜!在本教程中,我们探讨了 custom_translation_table
选项,并发现了如何使用 ONNX Script 创建对不支持或现有 PyTorch 操作符的自定义实现。
终于,我们利用 ONNX Runtime 执行模型,并将结果与 PyTorch 进行比较,从而全面了解了在 ONNX 生态系统中处理不受支持的算子。
进一步阅读 ¶
以下列表涉及从基本示例到高级场景的教程,不一定按照列表顺序排列。您可以自由地直接跳转到您感兴趣的具体主题,或者耐心地逐一浏览,以了解有关 ONNX 导出器的所有内容。
脚本总运行时间:(0 分钟 0.000 秒)