备注
点击此处下载完整示例代码
ONNX 简介 || 将 PyTorch 模型导出为 ONNX || 扩展 ONNX 导出器操作符支持 || 将带有控制流的模型导出为 ONNX
将具有控制流的模型导出为 ONNX ¶
作者:Xavier Dupré
概述 ¶
本教程演示了在将 PyTorch 模型导出为 ONNX 时如何处理控制流逻辑。它突出了直接导出条件语句的挑战,并提供了绕过这些挑战的解决方案。
除非将它们重构为使用 torch.cond()
,否则条件逻辑无法导出到 ONNX。让我们从一个实现测试的简单模型开始。
您将学习的内容:
如何将模型重构以使用
torch.cond()
进行导出。如何将具有控制流逻辑的模型导出到 ONNX。
如何使用 ONNX 优化器优化导出的模型。
前提条件 _
torch >= 2.6
import torch
定义模型 ¶
定义了两种模型:
ForwardWithControlFlowTest
:一个包含 if-else 条件的正向方法的模型。
ModelWithControlFlowTest
:将 ForwardWithControlFlowTest
作为简单 MLP 的一部分的模型。使用随机输入张量测试这些模型以确认它们按预期执行。
class ForwardWithControlFlowTest(torch.nn.Module):
def forward(self, x):
if x.sum():
return x * 2
return -x
class ModelWithControlFlowTest(torch.nn.Module):
def __init__(self):
super().__init__()
self.mlp = torch.nn.Sequential(
torch.nn.Linear(3, 2),
torch.nn.Linear(2, 1),
ForwardWithControlFlowTest(),
)
def forward(self, x):
out = self.mlp(x)
return out
model = ModelWithControlFlowTest()
导出模型:第一次尝试 ¶
使用 torch.export.export 导出此模型失败,因为前向传递中的控制流逻辑创建了一个导出器无法处理的图断点。这种行为是预期的,因为未使用 torch.cond()
编写的条件逻辑不受支持。
使用 try-except 块捕获导出过程中的预期失败。如果导出意外成功,将引发 AssertionError
异常。
x = torch.randn(3)
model(x)
try:
torch.export.export(model, (x,), strict=False)
raise AssertionError("This export should failed unless PyTorch now supports this model.")
except Exception as e:
print(e)
使用 torch.onnx.export()
与 JIT Tracing
当使用 torch.onnx.export()
和 dynamo=True 参数导出模型时,导出器默认使用 JIT 跟踪。这种回退允许模型导出,但生成的 ONNX 图可能无法忠实反映原始模型逻辑,因为跟踪存在局限性。
onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)
建议补丁:使用 torch.cond()
¶ 进行重构
为了使控制流程可导出,教程演示了将 ForwardWithControlFlowTest
中的 forward 方法替换为使用 torch.cond`()
的重构版本。
重构详情:
两个辅助函数(identity2 和 neg)代表条件逻辑的分支:* torch.cond`()
用于指定条件和两个分支以及输入参数。* 然后将更新后的 forward 方法动态分配给模型中的 ForwardWithControlFlowTest
实例。打印出子模块列表以确认替换。
def new_forward(x):
def identity2(x):
return x * 2
def neg(x):
return -x
return torch.cond(x.sum() > 0, identity2, neg, (x,))
print("the list of submodules")
for name, mod in model.named_modules():
print(name, type(mod))
if isinstance(mod, ForwardWithControlFlowTest):
mod.forward = new_forward
让我们看看 FX 图的样子。
print(torch.export.export(model, (x,), strict=False))
让我们再次导出。
onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)
我们可以优化模型并移除创建来捕获控制流分支的模型局部函数。
onnx_program.optimize()
print(onnx_program.model)
结论 ¶
本教程演示了将具有条件逻辑的模型导出到 ONNX 的挑战,并提出了使用 torch.cond()
的实用解决方案。虽然默认导出器可能会失败或生成不完美的图,但重构模型的逻辑确保了兼容性并生成了忠实的 ONNX 表示。
通过理解这些技术,我们可以克服在使用 PyTorch 模型中的控制流时遇到的常见问题,并确保与 ONNX 工作流程的顺利集成。
进一步阅读 ¶
以下列表涉及从基本示例到高级场景的教程,不一定按照列表顺序排列。您可以自由地直接跳转到您感兴趣的具体主题,或者耐心地逐一浏览,以了解有关 ONNX 导出器的所有内容。
脚本总运行时间:(0 分钟 0.000 秒)