备注
点击此处下载完整示例代码
(可选) 从 PyTorch 导出模型到 ONNX 并使用 ONNX Runtime 运行它
创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日
备注
截至 PyTorch 2.1,ONNX 导出器有两个版本。
torch.onnx.dynamo_export
是基于 PyTorch 2.0 一起发布的 TorchDynamo 技术的最新(仍处于测试版)导出器。torch.onnx.export
基于 TorchScript 后端,自 PyTorch 1.2.0 起可用。
在本教程中,我们将介绍如何使用 TorchScript torch.onnx.export
ONNX 导出器将定义在 PyTorch 中的模型转换为 ONNX 格式。
导出的模型将使用 ONNX Runtime 执行。ONNX Runtime 是一个以性能为导向的 ONNX 模型引擎,它可以在多个平台和硬件(Windows、Linux、Mac 以及 CPU 和 GPU)上高效地进行推理。ONNX Runtime 已被证明可以显著提高多个模型的表现,具体解释如下。
对于本教程,您需要安装 ONNX 和 ONNX Runtime。您可以通过以下方式获取 ONNX 和 ONNX Runtime 的二进制构建版本。
%%bash
pip install onnx onnxruntime
ONNX Runtime 建议使用最新的稳定运行时版本。
# Some standard imports
import numpy as np
from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx
超分辨率是一种提高图像、视频分辨率的手段,在图像处理或视频编辑中应用广泛。本教程中,我们将使用一个小型超分辨率模型。
首先,让我们在 PyTorch 中创建一个 SuperResolution
模型。该模型使用了“实时单图像和视频超分辨率使用高效子像素卷积神经网络”- Shi 等人所描述的高效子像素卷积层,通过提升因子增加图像的分辨率。该模型期望输入图像的 Y 分量,并输出超分辨率后的放大 Y 分量。
模型直接来自 PyTorch 的示例,未做任何修改:
# Super Resolution model definition in PyTorch
import torch.nn as nn
import torch.nn.init as init
class SuperResolutionNet(nn.Module):
def __init__(self, upscale_factor, inplace=False):
super(SuperResolutionNet, self).__init__()
self.relu = nn.ReLU(inplace=inplace)
self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
self._initialize_weights()
def forward(self, x):
x = self.relu(self.conv1(x))
x = self.relu(self.conv2(x))
x = self.relu(self.conv3(x))
x = self.pixel_shuffle(self.conv4(x))
return x
def _initialize_weights(self):
init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
init.orthogonal_(self.conv4.weight)
# Create the super-resolution model by using the above model definition.
torch_model = SuperResolutionNet(upscale_factor=3)
通常,您现在会训练这个模型;然而,对于本教程,我们将下载一些预训练的权重。请注意,此模型并未完全训练以达到良好的准确度,此处仅用于演示目的。
在导出模型之前,调用 torch_model.eval()
或 torch_model.train(False)
是很重要的,以将模型转换为推理模式。这是必需的,因为像 dropout 或 batchnorm 这样的操作在推理和训练模式下的行为不同。
# Load pretrained model weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 64 # just a random number
# Initialize model with the pretrained weights
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))
# set the model to inference mode
torch_model.eval()
在 PyTorch 中导出模型通过跟踪或脚本化完成。本教程将以跟踪导出的模型为例。要导出模型,我们调用 torch.onnx.export()
函数。这将执行模型,记录用于计算输出的操作符的跟踪。因为 export
运行模型,我们需要提供一个输入张量 x
。其中的值可以是随机的,只要它是正确的类型和大小。请注意,除非指定为动态轴,否则导出的 ONNX 图中所有输入维度的输入大小将是固定的。在本例中,我们导出具有 batch_size 1 的输入的模型,但在 torch.onnx.export()
中的 dynamic_axes
参数中指定第一个维度为动态。因此,导出的模型将接受大小为[batch_size, 1, 224, 224]的输入,其中 batch_size 可以是可变的。
要了解更多关于 PyTorch 导出接口的详细信息,请查看 torch.onnx 文档。
# Input to the model
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
torch_out = torch_model(x)
# Export the model
torch.onnx.export(torch_model, # model being run
x, # model input (or a tuple for multiple inputs)
"super_resolution.onnx", # where to save the model (can be a file or file-like object)
export_params=True, # store the trained parameter weights inside the model file
opset_version=10, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names = ['input'], # the model's input names
output_names = ['output'], # the model's output names
dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes
'output' : {0 : 'batch_size'}})
我们还计算了 torch_out
,即模型的输出,我们将用它来验证我们导出的模型在 ONNX Runtime 中运行时计算相同的值。
在验证模型输出之前,我们将使用 ONNX API 检查 ONNX 模型。首先, onnx.load("super_resolution.onnx")
将加载保存的模型,并输出一个 onnx.ModelProto
结构(一个用于打包 ML 模型的顶层文件/容器格式。有关 onnx.proto 文档的更多信息。)。然后, onnx.checker.check_model(onnx_model)
将验证模型的结构,并确认模型具有有效的模式。通过检查模型的版本、图的结构以及节点及其输入和输出,验证 ONNX 图的正确性。
import onnx
onnx_model = onnx.load("super_resolution.onnx")
onnx.checker.check_model(onnx_model)
现在我们将使用 ONNX Runtime 的 Python API 来计算输出。这部分通常可以在单独的进程或另一台机器上完成,但我们将继续在同一进程中,以便我们可以验证 ONNX Runtime 和 PyTorch 是否为网络计算相同的值。
为了使用 ONNX Runtime 运行模型,我们需要为模型创建一个具有所选配置参数的推理会话(这里我们使用默认配置)。会话创建后,我们使用 run() API 评估模型。此调用的输出是一个包含由 ONNX Runtime 计算出的模型输出的列表。
import onnxruntime
ort_session = onnxruntime.InferenceSession("super_resolution.onnx", providers=["CPUExecutionProvider"])
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)
# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")
我们应该看到 PyTorch 和 ONNX Runtime 的输出在给定的精度( rtol=1e-03
和 atol=1e-05
)下数值上是一致的。作为旁注,如果它们不一致,那么 ONNX 导出器存在问题,请在此情况下联系我们。
模型之间的时序比较
由于 ONNX 模型优化推理速度,在 ONNX 模型上运行相同的数据,而不是在原生 PyTorch 模型上运行,应该会带来高达 2 倍的改进。随着批量大小的增加,改进更为明显。
import time
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
start = time.time()
torch_out = torch_model(x)
end = time.time()
print(f"Inference of Pytorch model used {end - start} seconds")
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
start = time.time()
ort_outs = ort_session.run(None, ort_inputs)
end = time.time()
print(f"Inference of ONNX model used {end - start} seconds")
使用 ONNX Runtime 在图像上运行模型
到目前为止,我们已经从 PyTorch 导出了一个模型,并展示了如何使用一个虚拟张量作为输入来加载和运行它。
在本教程中,我们将使用一个广为人知的猫图像,其外观如下。

首先,让我们加载这张图像,使用标准的 PIL Python 库进行预处理。请注意,这种预处理是处理数据以进行神经网络训练/测试的标准做法。
我们首先将图像调整大小以适应模型输入的大小(224x224)。然后,我们将图像分解为其 Y、Cb 和 Cr 分量。这些分量代表一个灰度图像(Y)和蓝差(Cb)以及红差(Cr)色度分量。由于 Y 分量对人类眼睛更敏感,我们对此分量感兴趣,并将其进行转换。在提取 Y 分量后,我们将其转换为张量,这将成为我们模型的输入。
from PIL import Image
import torchvision.transforms as transforms
img = Image.open("./_static/img/cat.jpg")
resize = transforms.Resize([224, 224])
img = resize(img)
img_ycbcr = img.convert('YCbCr')
img_y, img_cb, img_cr = img_ycbcr.split()
to_tensor = transforms.ToTensor()
img_y = to_tensor(img_y)
img_y.unsqueeze_(0)
现在,作为下一步,让我们将代表灰度缩放猫图像的张量在 ONNX Runtime 中运行,如前所述。
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img_y)}
ort_outs = ort_session.run(None, ort_inputs)
img_out_y = ort_outs[0]
到目前为止,模型的输出是一个张量。现在,我们将处理模型的输出,从输出张量中构建最终的输出图像,并保存该图像。后处理步骤已从此处 PyTorch 实现的超分辨率模型中采用。
img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode='L')
# get the output image follow post-processing step from PyTorch implementation
final_img = Image.merge(
"YCbCr", [
img_out_y,
img_cb.resize(img_out_y.size, Image.BICUBIC),
img_cr.resize(img_out_y.size, Image.BICUBIC),
]).convert("RGB")
# Save the image, we will compare this with the output image from mobile device
final_img.save("./_static/img/cat_superres_with_ort.jpg")
# Save resized original image (without super-resolution)
img = transforms.Resize([img_out_y.size[0], img_out_y.size[1]])(img)
img.save("cat_resized.jpg")
下面是两张图像的比较:

低分辨率图像

超分辨率后的图像
ONNX Runtime 是一个跨平台引擎,您可以在多个平台上运行它,包括 CPU 和 GPU。
ONNX Runtime 还可以部署到云上,使用 Azure Machine Learning 服务进行模型推理。更多信息请见此处。
关于 ONNX Runtime 性能的更多信息请见此处。
关于 ONNX Runtime 的更多信息请参阅此处。
脚本总运行时间:(0 分钟 0.000 秒)