• 教程 >
  • 将带有控制流的模型导出为 ONNX
快捷键

ONNX 简介 || 将 PyTorch 模型导出为 ONNX || 扩展 ONNX 导出器操作符支持 || 将带有控制流的模型导出为 ONNX

将具有控制流的模型导出为 ONNX ¶

作者:Xavier Dupré

概述 ¶

本教程演示了在将 PyTorch 模型导出为 ONNX 时如何处理控制流逻辑。它突出了直接导出条件语句的挑战,并提供了绕过这些挑战的解决方案。

除非将它们重构为使用 torch.cond() ,否则条件逻辑无法导出到 ONNX。让我们从一个实现测试的简单模型开始。

您将学习的内容:

  • 如何将模型重构以使用 torch.cond() 进行导出。

  • 如何将具有控制流逻辑的模型导出到 ONNX。

  • 如何使用 ONNX 优化器优化导出的模型。

前提条件 _

  • torch >= 2.6

import torch

定义模型 ¶

定义了两种模型:

ForwardWithControlFlowTest :一个包含 if-else 条件的正向方法的模型。

ModelWithControlFlowTest :将 ForwardWithControlFlowTest 作为简单 MLP 的一部分的模型。使用随机输入张量测试这些模型以确认它们按预期执行。

class ForwardWithControlFlowTest(torch.nn.Module):
    def forward(self, x):
        if x.sum():
            return x * 2
        return -x


class ModelWithControlFlowTest(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(3, 2),
            torch.nn.Linear(2, 1),
            ForwardWithControlFlowTest(),
        )

    def forward(self, x):
        out = self.mlp(x)
        return out


model = ModelWithControlFlowTest()

导出模型:第一次尝试 ¶

使用 torch.export.export 导出此模型失败,因为前向传递中的控制流逻辑创建了一个导出器无法处理的图断点。这种行为是预期的,因为未使用 torch.cond() 编写的条件逻辑不受支持。

使用 try-except 块捕获导出过程中的预期失败。如果导出意外成功,将引发 AssertionError 异常。

x = torch.randn(3)
model(x)

try:
    torch.export.export(model, (x,), strict=False)
    raise AssertionError("This export should failed unless PyTorch now supports this model.")
except Exception as e:
    print(e)

使用 torch.onnx.export() 与 JIT Tracing

当使用 torch.onnx.export() 和 dynamo=True 参数导出模型时,导出器默认使用 JIT 跟踪。这种回退允许模型导出,但生成的 ONNX 图可能无法忠实反映原始模型逻辑,因为跟踪存在局限性。

onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)

建议补丁:使用 torch.cond() ¶ 进行重构

为了使控制流程可导出,教程演示了将 ForwardWithControlFlowTest 中的 forward 方法替换为使用 torch.cond`() 的重构版本。

重构详情:

两个辅助函数(identity2 和 neg)代表条件逻辑的分支:* torch.cond`() 用于指定条件和两个分支以及输入参数。* 然后将更新后的 forward 方法动态分配给模型中的 ForwardWithControlFlowTest 实例。打印出子模块列表以确认替换。

def new_forward(x):
    def identity2(x):
        return x * 2

    def neg(x):
        return -x

    return torch.cond(x.sum() > 0, identity2, neg, (x,))


print("the list of submodules")
for name, mod in model.named_modules():
    print(name, type(mod))
    if isinstance(mod, ForwardWithControlFlowTest):
        mod.forward = new_forward

让我们看看 FX 图的样子。

print(torch.export.export(model, (x,), strict=False))

让我们再次导出。

onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)

我们可以优化模型并移除创建来捕获控制流分支的模型局部函数。

onnx_program.optimize()
print(onnx_program.model)

结论 ¶

本教程演示了将具有条件逻辑的模型导出到 ONNX 的挑战,并提出了使用 torch.cond() 的实用解决方案。虽然默认导出器可能会失败或生成不完美的图,但重构模型的逻辑确保了兼容性并生成了忠实的 ONNX 表示。

通过理解这些技术,我们可以克服在使用 PyTorch 模型中的控制流时遇到的常见问题,并确保与 ONNX 工作流程的顺利集成。

进一步阅读 ¶

以下列表涉及从基本示例到高级场景的教程,不一定按照列表顺序排列。您可以自由地直接跳转到您感兴趣的具体主题,或者耐心地逐一浏览,以了解有关 ONNX 导出器的所有内容。

1. 将 PyTorch 模型导出到 ONNX
2. 扩展 ONNX 导出器操作符支持
3. 将具有控制流的模型导出到 ONNX

脚本总运行时间:(0 分钟 0.000 秒)

由 Sphinx-Gallery 生成的画廊


评分这个教程

© 版权所有 2024,PyTorch。

使用 Sphinx 构建,主题由 Read the Docs 提供。
//暂时添加调查链接

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源