备注
点击此处下载完整示例代码
ONNX 简介 || 将 PyTorch 模型导出为 ONNX || 扩展 ONNX 导出器操作符支持 || 将具有控制流的模型导出为 ONNX
将 PyTorch 模型导出为 ONNX ¶
创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日
作者:王铁台,崔俊,塞亚戈·克里帕尔迪。
备注
截至 PyTorch 2.5 版本,ONNX 导出器有两个版本。
torch.onnx.export(..., dynamo=True)
是最新的(仍处于测试版)导出器,使用torch.export
和 Torch FX 来捕获图。它与 PyTorch 2.5 一起发布。torch.onnx.export
使用 TorchScript,自 PyTorch 1.2.0 版本以来一直可用。
在 60 分钟闪电战中,我们有机会从高层次了解 PyTorch,并训练一个小型神经网络来分类图像。在本教程中,我们将扩展这一内容,描述如何使用 torch.onnx.export(..., dynamo=True)
ONNX 导出器将定义在 PyTorch 中的模型转换为 ONNX 格式。
虽然 PyTorch 在迭代模型开发方面非常出色,但模型可以使用不同的格式部署到生产环境中,包括 ONNX(开放神经网络交换)!
ONNX 是一种灵活的开放标准格式,用于表示机器学习模型,标准化的机器学习表示允许它们在各种硬件平台和运行环境中执行,从大规模基于云的超计算机到资源受限的边缘设备,例如您的网络浏览器和手机。
在本教程中,我们将学习如何:
安装所需的依赖项。
编写一个简单的图像分类器模型。
将模型导出为 ONNX 格式。
将 ONNX 模型保存到文件中。
使用 Netron 可视化 ONNX 模型图。
使用 ONNX Runtime 执行 ONNX 模型
将 PyTorch 的结果与 ONNX Runtime 的结果进行比较。
1. 安装所需的依赖项 ¶
因为 ONNX 导出器使用 onnx
和 onnxscript
将 PyTorch 算子转换为 ONNX 算子,所以我们需要安装它们。
pip install --upgrade onnx onnxscript
2. 编写一个简单的图像分类器模型
当你的环境设置完毕后,让我们开始使用 PyTorch 构建我们的图像分类器模型,就像我们在 60 分钟闪电战中所做的那样。
import torch
import torch.nn as nn
import torch.nn.functional as F
class ImageClassifierModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x: torch.Tensor):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
3. 将模型导出为 ONNX 格式
现在我们已经定义了模型,我们需要实例化它并创建一个随机的 32x32 输入。接下来,我们可以将模型导出为 ONNX 格式。
torch_model = ImageClassifierModel()
# Create example inputs for exporting the model. The inputs should be a tuple of tensors.
example_inputs = (torch.randn(1, 1, 32, 32),)
onnx_program = torch.onnx.export(torch_model, example_inputs, dynamo=True)
3.5.(可选)优化 ONNX 模型
ONNX 模型可以通过常量折叠和消除冗余节点进行优化。优化是在原地进行的,因此原始 ONNX 模型会被修改。
onnx_program.optimize()
如我们所见,我们不需要对模型进行任何代码更改。生成的 ONNX 模型以二进制 protobuf 文件的形式存储在 torch.onnx.ONNXProgram
中。
4. 将 ONNX 模型保存到文件中
虽然在许多应用中将导出的模型加载到内存中很有用,但我们可以使用以下代码将其保存到磁盘:
onnx_program.save("image_classifier_model.onnx")
您可以使用以下代码将 ONNX 文件重新加载到内存中并检查其是否格式正确:
import onnx
onnx_model = onnx.load("image_classifier_model.onnx")
onnx.checker.check_model(onnx_model)
5. 使用 Netron 可视化 ONNX 模型图
现在我们已经将模型保存到文件中,我们可以使用 Netron 来可视化它。Netron 可以安装在 macos、Linux 或 Windows 计算机上,或者直接从浏览器运行。让我们尝试使用 Web 版本,通过打开以下链接:https://netron.app/。

打开 Netron 后,我们可以将我们的 image_classifier_model.onnx
文件拖放到浏览器中,或者点击“打开模型”按钮后选择它。

就这样!我们已经成功将 PyTorch 模型导出为 ONNX 格式,并用 Netron 进行了可视化。
6. 使用 ONNX Runtime 执行 ONNX 模型
最后一步是使用 ONNX Runtime 执行 ONNX 模型,但在我们这样做之前,让我们先安装 ONNX Runtime。
pip install onnxruntime
ONNX 标准不支持 PyTorch 所支持的所有数据结构和类型,因此我们需要在将 PyTorch 输入传递给 ONNX Runtime 之前将其转换为 ONNX 格式。在我们的示例中,输入恰好相同,但在更复杂的模型中,它可能比原始 PyTorch 模型有更多的输入。
ONNX Runtime 需要额外的步骤,涉及将所有 PyTorch 张量转换为 Numpy(在 CPU 上)并将它们包装在一个字典中,键是包含输入名称的字符串,值是 Numpy 张量。
现在我们可以创建一个 ONNX Runtime 推理会话,使用处理后的输入执行 ONNX 模型并获取输出。在本教程中,ONNX Runtime 在 CPU 上执行,但也可以在 GPU 上执行。
import onnxruntime
onnx_inputs = [tensor.numpy(force=True) for tensor in example_inputs]
print(f"Input length: {len(onnx_inputs)}")
print(f"Sample input: {onnx_inputs}")
ort_session = onnxruntime.InferenceSession(
"./image_classifier_model.onnx", providers=["CPUExecutionProvider"]
)
onnxruntime_input = {input_arg.name: input_value for input_arg, input_value in zip(ort_session.get_inputs(), onnx_inputs)}
# ONNX Runtime returns a list of outputs
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)[0]
7. 比较 PyTorch 的结果和 ONNX Runtime 的结果
确定导出的模型是否表现良好的最佳方式是通过与 PyTorch 进行数值评估,PyTorch 是我们的真实来源。
为了做到这一点,我们需要使用相同的输入执行 PyTorch 模型,并将结果与 ONNX Runtime 的结果进行比较。在比较结果之前,我们需要将 PyTorch 的输出转换为与 ONNX 格式相匹配。
torch_outputs = torch_model(*example_inputs)
assert len(torch_outputs) == len(onnxruntime_outputs)
for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))
print("PyTorch and ONNX Runtime output matched!")
print(f"Output length: {len(onnxruntime_outputs)}")
print(f"Sample output: {onnxruntime_outputs}")
结论 ¶
就这些了!我们已经成功将我们的 PyTorch 模型导出为 ONNX 格式,将其保存到磁盘,使用 Netron 查看它,使用 ONNX Runtime 执行它,最后将其数值结果与 PyTorch 的结果进行比较。
进一步阅读 ¶
以下列表涉及从基本示例到高级场景的教程,不一定按照列表顺序排列。您可以自由地直接跳转到您感兴趣的具体主题,或者耐心地逐一浏览,以了解有关 ONNX 导出器的所有内容。
脚本总运行时间:(0 分钟 0.000 秒)